Upgraded to pgx
This commit is contained in:
		
							parent
							
								
									825cf0fc9a
								
							
						
					
					
						commit
						ad601219ec
					
				
					 3 changed files with 66 additions and 49 deletions
				
			
		| 
						 | 
				
			
			@ -1,8 +1,11 @@
 | 
			
		|||
package dbschema
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	// External
 | 
			
		||||
	"github.com/jackc/pgx/v5/pgxpool"
 | 
			
		||||
 | 
			
		||||
	// Standard
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -13,10 +16,10 @@ func newDatabase(host string, port int, dbName, user, pass string) (dbase Databa
 | 
			
		|||
	dbase.Username = user
 | 
			
		||||
	dbase.Password = pass
 | 
			
		||||
 | 
			
		||||
	dbase.db, err = sql.Open("postgres", dbase.sqlConnString())
 | 
			
		||||
	dbase.db, err = pgxpool.New(context.Background(), dbase.sqlConnString())
 | 
			
		||||
	return
 | 
			
		||||
}// }}}
 | 
			
		||||
func databaseFromInstance(db *sql.DB) (dbase Database, err error) {
 | 
			
		||||
func databaseFromInstance(db *pgxpool.Pool) (dbase Database, err error) {
 | 
			
		||||
	dbase.db = db
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -21,10 +21,7 @@ package dbschema
 | 
			
		|||
 | 
			
		||||
import (
 | 
			
		||||
	// External
 | 
			
		||||
	_ "github.com/lib/pq"
 | 
			
		||||
 | 
			
		||||
	// Standard
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"github.com/jackc/pgx/v5/pgxpool"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// An upgrader verifies the schema for one or more databases and upgrades them if possible.
 | 
			
		||||
| 
						 | 
				
			
			@ -42,7 +39,7 @@ type Database struct {
 | 
			
		|||
	Username string
 | 
			
		||||
	Password string
 | 
			
		||||
 | 
			
		||||
	db       *sql.DB
 | 
			
		||||
	db       *pgxpool.Pool
 | 
			
		||||
	upgrader *Upgrader
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										99
									
								
								upgrader.go
									
										
									
									
									
								
							
							
						
						
									
										99
									
								
								upgrader.go
									
										
									
									
									
								
							| 
						 | 
				
			
			@ -2,19 +2,22 @@ package dbschema
 | 
			
		|||
 | 
			
		||||
import (
 | 
			
		||||
	// External
 | 
			
		||||
	"github.com/lib/pq"
 | 
			
		||||
	"github.com/jackc/pgx/v5"
 | 
			
		||||
	"github.com/jackc/pgx/v5/pgconn"
 | 
			
		||||
	"github.com/jackc/pgx/v5/pgxpool"
 | 
			
		||||
 | 
			
		||||
	// Standard
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"fmt"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func defaultCallback(topic, msg string) {// {{{
 | 
			
		||||
func defaultCallback(topic, msg string) { // {{{
 | 
			
		||||
	fmt.Printf("[%s] %s\n", topic, msg)
 | 
			
		||||
}// }}}
 | 
			
		||||
} // }}}
 | 
			
		||||
 | 
			
		||||
// NewUpgrader creates an upgrader with an empty list of databases.
 | 
			
		||||
func NewUpgrader(schema ...string) (upgrader Upgrader) {// {{{
 | 
			
		||||
func NewUpgrader(schema ...string) (upgrader Upgrader) { // {{{
 | 
			
		||||
	// Using a variadic function for backward compatibility.
 | 
			
		||||
	if len(schema) > 0 {
 | 
			
		||||
		upgrader.schema = schema[0]
 | 
			
		||||
| 
						 | 
				
			
			@ -25,18 +28,18 @@ func NewUpgrader(schema ...string) (upgrader Upgrader) {// {{{
 | 
			
		|||
	upgrader.logCallback = defaultCallback
 | 
			
		||||
	upgrader.databases = map[string]Database{}
 | 
			
		||||
	return
 | 
			
		||||
}// }}}
 | 
			
		||||
} // }}}
 | 
			
		||||
 | 
			
		||||
// SetLogCallback allows to set a callback for custom logging.
 | 
			
		||||
func (upgrader *Upgrader) SetLogCallback(callback func(string, string)) {// {{{
 | 
			
		||||
func (upgrader *Upgrader) SetLogCallback(callback func(string, string)) { // {{{
 | 
			
		||||
	upgrader.logCallback = callback
 | 
			
		||||
}// }}}
 | 
			
		||||
} // }}}
 | 
			
		||||
// SetSqlCallback is required for providing the SQL schema updates.
 | 
			
		||||
func (upgrader *Upgrader) SetSqlCallback(callback func(string, int) ([]byte, bool)) {// {{{
 | 
			
		||||
func (upgrader *Upgrader) SetSqlCallback(callback func(string, int) ([]byte, bool)) { // {{{
 | 
			
		||||
	upgrader.sqlCallback = callback
 | 
			
		||||
}// }}}
 | 
			
		||||
} // }}}
 | 
			
		||||
// Version returns the current dbschema version for the given database name.
 | 
			
		||||
func (upgrader *Upgrader) Version(dbName string) (version int, err error) {// {{{
 | 
			
		||||
func (upgrader *Upgrader) Version(dbName string) (version int, err error) { // {{{
 | 
			
		||||
	dbase, found := upgrader.databases[dbName]
 | 
			
		||||
	if !found {
 | 
			
		||||
		err = fmt.Errorf("Database %s not previously added to the upgrader", dbName)
 | 
			
		||||
| 
						 | 
				
			
			@ -45,21 +48,22 @@ func (upgrader *Upgrader) Version(dbName string) (version int, err error) {// {{
 | 
			
		|||
 | 
			
		||||
	version, err = dbase.Version()
 | 
			
		||||
	return
 | 
			
		||||
}// }}}
 | 
			
		||||
} // }}}
 | 
			
		||||
 | 
			
		||||
func (dbase Database) createSchemaTable() (err error) {// {{{
 | 
			
		||||
func (dbase Database) createSchemaTable() (err error) { // {{{
 | 
			
		||||
	dbase.upgrader.logCallback("create", fmt.Sprintf("%s, %s.schema", dbase.DbName, dbase.upgrader.schema))
 | 
			
		||||
	_, err = dbase.db.Exec(`CREATE SCHEMA "`+dbase.upgrader.schema+`"`)
 | 
			
		||||
	_, err = dbase.db.Exec(context.Background(), `CREATE SCHEMA "` + dbase.upgrader.schema + `"`)
 | 
			
		||||
 | 
			
		||||
	// Error code 42P06 "duplicate_schema" is an OK error,
 | 
			
		||||
	// table can still be missing and created.
 | 
			
		||||
	pqErr, _ := err.(*pq.Error)
 | 
			
		||||
	pqErr, _ := err.(*pgconn.PgError)
 | 
			
		||||
	if pqErr != nil && pqErr.Code != "42P06" {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, err = dbase.db.Exec(`
 | 
			
		||||
		CREATE TABLE "`+dbase.upgrader.schema+`"."schema" (
 | 
			
		||||
	_, err = dbase.db.Exec(
 | 
			
		||||
		context.Background(),
 | 
			
		||||
		`CREATE TABLE "` + dbase.upgrader.schema + `"."schema" (
 | 
			
		||||
			version  int4      NOT NULL,
 | 
			
		||||
			updated  timestamp NOT NULL DEFAULT NOW(),
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -67,19 +71,20 @@ func (dbase Database) createSchemaTable() (err error) {// {{{
 | 
			
		|||
		)`,
 | 
			
		||||
	)
 | 
			
		||||
	return
 | 
			
		||||
}// }}}
 | 
			
		||||
func (dbase Database) appendSchemaVersion(version int) (err error) {// {{{
 | 
			
		||||
	_, err = dbase.db.Exec(`INSERT INTO `+dbase.upgrader.schema+`.schema(version) VALUES($1)`, version)
 | 
			
		||||
} // }}}
 | 
			
		||||
func (dbase Database) appendSchemaVersion(version int) (err error) { // {{{
 | 
			
		||||
	_, err = dbase.db.Exec(context.Background(), `INSERT INTO `+dbase.upgrader.schema+`.schema(version) VALUES($1)`, version)
 | 
			
		||||
	return
 | 
			
		||||
}// }}}
 | 
			
		||||
} // }}}
 | 
			
		||||
 | 
			
		||||
func (dbase Database) verifySchemaTable() (err error) {// {{{
 | 
			
		||||
	var rows *sql.Rows
 | 
			
		||||
func (dbase Database) verifySchemaTable() (err error) { // {{{
 | 
			
		||||
	var rows pgx.Rows
 | 
			
		||||
	if rows, err = dbase.db.Query(
 | 
			
		||||
		context.Background(),
 | 
			
		||||
		`SELECT EXISTS (
 | 
			
		||||
			SELECT FROM pg_catalog.pg_class c
 | 
			
		||||
			JOIN   pg_catalog.pg_namespace n ON n.oid = c.relnamespace
 | 
			
		||||
			WHERE  n.nspname = '`+dbase.upgrader.schema+`'
 | 
			
		||||
			WHERE  n.nspname = '` + dbase.upgrader.schema + `'
 | 
			
		||||
			AND    c.relname = 'schema'
 | 
			
		||||
		)`,
 | 
			
		||||
	); err != nil {
 | 
			
		||||
| 
						 | 
				
			
			@ -97,11 +102,11 @@ func (dbase Database) verifySchemaTable() (err error) {// {{{
 | 
			
		|||
	}
 | 
			
		||||
	err = dbase.createSchemaTable()
 | 
			
		||||
	return
 | 
			
		||||
}// }}}
 | 
			
		||||
func (dbase Database) verifySchemaEntry() (err error) {// {{{
 | 
			
		||||
} // }}}
 | 
			
		||||
func (dbase Database) verifySchemaEntry() (err error) { // {{{
 | 
			
		||||
	var version int
 | 
			
		||||
	var row *sql.Row
 | 
			
		||||
	row = dbase.db.QueryRow(`SELECT version FROM `+dbase.upgrader.schema+`.schema LIMIT 1`)
 | 
			
		||||
	var row pgx.Row
 | 
			
		||||
	row = dbase.db.QueryRow(context.Background(), `SELECT version FROM `+dbase.upgrader.schema+`.schema LIMIT 1`)
 | 
			
		||||
 | 
			
		||||
	err = row.Scan(&version)
 | 
			
		||||
	if err == sql.ErrNoRows {
 | 
			
		||||
| 
						 | 
				
			
			@ -110,11 +115,12 @@ func (dbase Database) verifySchemaEntry() (err error) {// {{{
 | 
			
		|||
	}
 | 
			
		||||
 | 
			
		||||
	return
 | 
			
		||||
}// }}}
 | 
			
		||||
func (dbase Database) Version() (version int, err error) {// {{{
 | 
			
		||||
	var rows *sql.Rows
 | 
			
		||||
} // }}}
 | 
			
		||||
func (dbase Database) Version() (version int, err error) { // {{{
 | 
			
		||||
	var rows pgx.Rows
 | 
			
		||||
	rows, err = dbase.db.Query(
 | 
			
		||||
		`SELECT version FROM `+dbase.upgrader.schema+`.schema ORDER BY version DESC LIMIT 1`,
 | 
			
		||||
		context.Background(),
 | 
			
		||||
		`SELECT version FROM ` + dbase.upgrader.schema + `.schema ORDER BY version DESC LIMIT 1`,
 | 
			
		||||
	)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
| 
						 | 
				
			
			@ -127,15 +133,15 @@ func (dbase Database) Version() (version int, err error) {// {{{
 | 
			
		|||
		err = fmt.Errorf(`Database "%s" is missing an entry in `+dbase.upgrader.schema+`.schema`, dbase.DbName)
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}// }}}
 | 
			
		||||
} // }}}
 | 
			
		||||
 | 
			
		||||
// AddDatabase sets a database up for the Run() function with verifying/creating the _db.schema table.
 | 
			
		||||
func (upgrader Upgrader) AddDatabase(host string, port int, dbName, user, pass string) (db Database, err error) {// {{{
 | 
			
		||||
func (upgrader Upgrader) AddDatabase(host string, port int, dbName, user, pass string) (db Database, err error) { // {{{
 | 
			
		||||
	if db, err = newDatabase(host, port, dbName, user, pass); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	db.upgrader = &upgrader
 | 
			
		||||
	
 | 
			
		||||
 | 
			
		||||
	upgrader.databases[dbName] = db
 | 
			
		||||
 | 
			
		||||
	if err = db.verifySchemaTable(); err != nil {
 | 
			
		||||
| 
						 | 
				
			
			@ -144,13 +150,24 @@ func (upgrader Upgrader) AddDatabase(host string, port int, dbName, user, pass s
 | 
			
		|||
 | 
			
		||||
	err = db.verifySchemaEntry()
 | 
			
		||||
	return
 | 
			
		||||
}// }}}
 | 
			
		||||
func (upgrader Upgrader) AddDatabaseInstance(sqlDB *sql.DB) (db Database, err error) {// {{{
 | 
			
		||||
	return databaseFromInstance(sqlDB)
 | 
			
		||||
}// }}}
 | 
			
		||||
} // }}}
 | 
			
		||||
func (upgrader Upgrader) AddDatabaseInstance(sqlDB *pgxpool.Pool, dbName string) (db Database, err error) { // {{{
 | 
			
		||||
	db, err = databaseFromInstance(sqlDB)
 | 
			
		||||
 | 
			
		||||
	db.upgrader = &upgrader
 | 
			
		||||
 | 
			
		||||
	upgrader.databases[dbName] = db
 | 
			
		||||
 | 
			
		||||
	if err = db.verifySchemaTable(); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = db.verifySchemaEntry()
 | 
			
		||||
	return
 | 
			
		||||
} // }}}
 | 
			
		||||
 | 
			
		||||
// Run executes the actual schema updates until there are no more available.
 | 
			
		||||
func (upgrader Upgrader) Run() (err error) {// {{{
 | 
			
		||||
func (upgrader Upgrader) Run() (err error) { // {{{
 | 
			
		||||
	var version int
 | 
			
		||||
 | 
			
		||||
	for dbName, dbase := range upgrader.databases {
 | 
			
		||||
| 
						 | 
				
			
			@ -168,7 +185,7 @@ func (upgrader Upgrader) Run() (err error) {// {{{
 | 
			
		|||
			}
 | 
			
		||||
 | 
			
		||||
			upgrader.logCallback("exec", fmt.Sprintf("%s.%s: %d", dbName, upgrader.schema, version))
 | 
			
		||||
			if _, err = dbase.db.Exec(string(sql)); err != nil {
 | 
			
		||||
			if _, err = dbase.db.Exec(context.Background(), string(sql)); err != nil {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			if err = dbase.appendSchemaVersion(version); err != nil {
 | 
			
		||||
| 
						 | 
				
			
			@ -177,6 +194,6 @@ func (upgrader Upgrader) Run() (err error) {// {{{
 | 
			
		|||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}// }}}
 | 
			
		||||
} // }}}
 | 
			
		||||
 | 
			
		||||
// vim: foldmethod=marker
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		
		Reference in a new issue