Compare commits

...

2 commits
v1.7.0 ... main

Author SHA1 Message Date
ca41ea4c50 Fixed no row-error 2025-09-29 09:34:14 +02:00
ad601219ec Upgraded to pgx 2025-09-23 21:47:39 +02:00
3 changed files with 67 additions and 51 deletions

View file

@ -1,8 +1,11 @@
package dbschema package dbschema
import ( import (
// External
"github.com/jackc/pgx/v5/pgxpool"
// Standard // Standard
"database/sql" "context"
"fmt" "fmt"
) )
@ -13,10 +16,10 @@ func newDatabase(host string, port int, dbName, user, pass string) (dbase Databa
dbase.Username = user dbase.Username = user
dbase.Password = pass dbase.Password = pass
dbase.db, err = sql.Open("postgres", dbase.sqlConnString()) dbase.db, err = pgxpool.New(context.Background(), dbase.sqlConnString())
return return
}// }}} }// }}}
func databaseFromInstance(db *sql.DB) (dbase Database, err error) { func databaseFromInstance(db *pgxpool.Pool) (dbase Database, err error) {
dbase.db = db dbase.db = db
return return
} }

View file

@ -21,10 +21,7 @@ package dbschema
import ( import (
// External // External
_ "github.com/lib/pq" "github.com/jackc/pgx/v5/pgxpool"
// Standard
"database/sql"
) )
// An upgrader verifies the schema for one or more databases and upgrades them if possible. // An upgrader verifies the schema for one or more databases and upgrades them if possible.
@ -42,7 +39,7 @@ type Database struct {
Username string Username string
Password string Password string
db *sql.DB db *pgxpool.Pool
upgrader *Upgrader upgrader *Upgrader
} }

View file

@ -2,19 +2,21 @@ package dbschema
import ( import (
// External // External
"github.com/lib/pq" "github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
// Standard // Standard
"database/sql" "context"
"fmt" "fmt"
) )
func defaultCallback(topic, msg string) {// {{{ func defaultCallback(topic, msg string) { // {{{
fmt.Printf("[%s] %s\n", topic, msg) fmt.Printf("[%s] %s\n", topic, msg)
}// }}} } // }}}
// NewUpgrader creates an upgrader with an empty list of databases. // NewUpgrader creates an upgrader with an empty list of databases.
func NewUpgrader(schema ...string) (upgrader Upgrader) {// {{{ func NewUpgrader(schema ...string) (upgrader Upgrader) { // {{{
// Using a variadic function for backward compatibility. // Using a variadic function for backward compatibility.
if len(schema) > 0 { if len(schema) > 0 {
upgrader.schema = schema[0] upgrader.schema = schema[0]
@ -25,18 +27,18 @@ func NewUpgrader(schema ...string) (upgrader Upgrader) {// {{{
upgrader.logCallback = defaultCallback upgrader.logCallback = defaultCallback
upgrader.databases = map[string]Database{} upgrader.databases = map[string]Database{}
return return
}// }}} } // }}}
// SetLogCallback allows to set a callback for custom logging. // SetLogCallback allows to set a callback for custom logging.
func (upgrader *Upgrader) SetLogCallback(callback func(string, string)) {// {{{ func (upgrader *Upgrader) SetLogCallback(callback func(string, string)) { // {{{
upgrader.logCallback = callback upgrader.logCallback = callback
}// }}} } // }}}
// SetSqlCallback is required for providing the SQL schema updates. // SetSqlCallback is required for providing the SQL schema updates.
func (upgrader *Upgrader) SetSqlCallback(callback func(string, int) ([]byte, bool)) {// {{{ func (upgrader *Upgrader) SetSqlCallback(callback func(string, int) ([]byte, bool)) { // {{{
upgrader.sqlCallback = callback upgrader.sqlCallback = callback
}// }}} } // }}}
// Version returns the current dbschema version for the given database name. // Version returns the current dbschema version for the given database name.
func (upgrader *Upgrader) Version(dbName string) (version int, err error) {// {{{ func (upgrader *Upgrader) Version(dbName string) (version int, err error) { // {{{
dbase, found := upgrader.databases[dbName] dbase, found := upgrader.databases[dbName]
if !found { if !found {
err = fmt.Errorf("Database %s not previously added to the upgrader", dbName) err = fmt.Errorf("Database %s not previously added to the upgrader", dbName)
@ -45,21 +47,22 @@ func (upgrader *Upgrader) Version(dbName string) (version int, err error) {// {{
version, err = dbase.Version() version, err = dbase.Version()
return return
}// }}} } // }}}
func (dbase Database) createSchemaTable() (err error) {// {{{ func (dbase Database) createSchemaTable() (err error) { // {{{
dbase.upgrader.logCallback("create", fmt.Sprintf("%s, %s.schema", dbase.DbName, dbase.upgrader.schema)) dbase.upgrader.logCallback("create", fmt.Sprintf("%s, %s.schema", dbase.DbName, dbase.upgrader.schema))
_, err = dbase.db.Exec(`CREATE SCHEMA "`+dbase.upgrader.schema+`"`) _, err = dbase.db.Exec(context.Background(), `CREATE SCHEMA "` + dbase.upgrader.schema + `"`)
// Error code 42P06 "duplicate_schema" is an OK error, // Error code 42P06 "duplicate_schema" is an OK error,
// table can still be missing and created. // table can still be missing and created.
pqErr, _ := err.(*pq.Error) pqErr, _ := err.(*pgconn.PgError)
if pqErr != nil && pqErr.Code != "42P06" { if pqErr != nil && pqErr.Code != "42P06" {
return return
} }
_, err = dbase.db.Exec(` _, err = dbase.db.Exec(
CREATE TABLE "`+dbase.upgrader.schema+`"."schema" ( context.Background(),
`CREATE TABLE "` + dbase.upgrader.schema + `"."schema" (
version int4 NOT NULL, version int4 NOT NULL,
updated timestamp NOT NULL DEFAULT NOW(), updated timestamp NOT NULL DEFAULT NOW(),
@ -67,19 +70,20 @@ func (dbase Database) createSchemaTable() (err error) {// {{{
)`, )`,
) )
return return
}// }}} } // }}}
func (dbase Database) appendSchemaVersion(version int) (err error) {// {{{ func (dbase Database) appendSchemaVersion(version int) (err error) { // {{{
_, err = dbase.db.Exec(`INSERT INTO `+dbase.upgrader.schema+`.schema(version) VALUES($1)`, version) _, err = dbase.db.Exec(context.Background(), `INSERT INTO `+dbase.upgrader.schema+`.schema(version) VALUES($1)`, version)
return return
}// }}} } // }}}
func (dbase Database) verifySchemaTable() (err error) {// {{{ func (dbase Database) verifySchemaTable() (err error) { // {{{
var rows *sql.Rows var rows pgx.Rows
if rows, err = dbase.db.Query( if rows, err = dbase.db.Query(
context.Background(),
`SELECT EXISTS ( `SELECT EXISTS (
SELECT FROM pg_catalog.pg_class c SELECT FROM pg_catalog.pg_class c
JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
WHERE n.nspname = '`+dbase.upgrader.schema+`' WHERE n.nspname = '` + dbase.upgrader.schema + `'
AND c.relname = 'schema' AND c.relname = 'schema'
)`, )`,
); err != nil { ); err != nil {
@ -97,24 +101,25 @@ func (dbase Database) verifySchemaTable() (err error) {// {{{
} }
err = dbase.createSchemaTable() err = dbase.createSchemaTable()
return return
}// }}} } // }}}
func (dbase Database) verifySchemaEntry() (err error) {// {{{ func (dbase Database) verifySchemaEntry() (err error) { // {{{
var version int var version int
var row *sql.Row var row pgx.Row
row = dbase.db.QueryRow(`SELECT version FROM `+dbase.upgrader.schema+`.schema LIMIT 1`) row = dbase.db.QueryRow(context.Background(), `SELECT version FROM `+dbase.upgrader.schema+`.schema LIMIT 1`)
err = row.Scan(&version) err = row.Scan(&version)
if err == sql.ErrNoRows { if err == pgx.ErrNoRows {
dbase.upgrader.logCallback("initiate version", dbase.DbName) dbase.upgrader.logCallback("initiate version", dbase.DbName)
err = dbase.appendSchemaVersion(0) err = dbase.appendSchemaVersion(0)
} }
return return
}// }}} } // }}}
func (dbase Database) Version() (version int, err error) {// {{{ func (dbase Database) Version() (version int, err error) { // {{{
var rows *sql.Rows var rows pgx.Rows
rows, err = dbase.db.Query( rows, err = dbase.db.Query(
`SELECT version FROM `+dbase.upgrader.schema+`.schema ORDER BY version DESC LIMIT 1`, context.Background(),
`SELECT version FROM ` + dbase.upgrader.schema + `.schema ORDER BY version DESC LIMIT 1`,
) )
if err != nil { if err != nil {
return return
@ -127,15 +132,15 @@ func (dbase Database) Version() (version int, err error) {// {{{
err = fmt.Errorf(`Database "%s" is missing an entry in `+dbase.upgrader.schema+`.schema`, dbase.DbName) err = fmt.Errorf(`Database "%s" is missing an entry in `+dbase.upgrader.schema+`.schema`, dbase.DbName)
} }
return return
}// }}} } // }}}
// AddDatabase sets a database up for the Run() function with verifying/creating the _db.schema table. // AddDatabase sets a database up for the Run() function with verifying/creating the _db.schema table.
func (upgrader Upgrader) AddDatabase(host string, port int, dbName, user, pass string) (db Database, err error) {// {{{ func (upgrader Upgrader) AddDatabase(host string, port int, dbName, user, pass string) (db Database, err error) { // {{{
if db, err = newDatabase(host, port, dbName, user, pass); err != nil { if db, err = newDatabase(host, port, dbName, user, pass); err != nil {
return return
} }
db.upgrader = &upgrader db.upgrader = &upgrader
upgrader.databases[dbName] = db upgrader.databases[dbName] = db
if err = db.verifySchemaTable(); err != nil { if err = db.verifySchemaTable(); err != nil {
@ -144,13 +149,24 @@ func (upgrader Upgrader) AddDatabase(host string, port int, dbName, user, pass s
err = db.verifySchemaEntry() err = db.verifySchemaEntry()
return return
}// }}} } // }}}
func (upgrader Upgrader) AddDatabaseInstance(sqlDB *sql.DB) (db Database, err error) {// {{{ func (upgrader Upgrader) AddDatabaseInstance(sqlDB *pgxpool.Pool, dbName string) (db Database, err error) { // {{{
return databaseFromInstance(sqlDB) db, err = databaseFromInstance(sqlDB)
}// }}}
db.upgrader = &upgrader
upgrader.databases[dbName] = db
if err = db.verifySchemaTable(); err != nil {
return
}
err = db.verifySchemaEntry()
return
} // }}}
// Run executes the actual schema updates until there are no more available. // Run executes the actual schema updates until there are no more available.
func (upgrader Upgrader) Run() (err error) {// {{{ func (upgrader Upgrader) Run() (err error) { // {{{
var version int var version int
for dbName, dbase := range upgrader.databases { for dbName, dbase := range upgrader.databases {
@ -168,7 +184,7 @@ func (upgrader Upgrader) Run() (err error) {// {{{
} }
upgrader.logCallback("exec", fmt.Sprintf("%s.%s: %d", dbName, upgrader.schema, version)) upgrader.logCallback("exec", fmt.Sprintf("%s.%s: %d", dbName, upgrader.schema, version))
if _, err = dbase.db.Exec(string(sql)); err != nil { if _, err = dbase.db.Exec(context.Background(), string(sql)); err != nil {
return return
} }
if err = dbase.appendSchemaVersion(version); err != nil { if err = dbase.appendSchemaVersion(version); err != nil {
@ -177,6 +193,6 @@ func (upgrader Upgrader) Run() (err error) {// {{{
} }
} }
return return
}// }}} } // }}}
// vim: foldmethod=marker // vim: foldmethod=marker