From ad601219ec1c3700531daf906a32fa098253ab58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Magnus=20=C3=85hall?= Date: Tue, 23 Sep 2025 21:47:39 +0200 Subject: [PATCH] Upgraded to pgx --- database.go | 9 +++-- schema.go | 7 ++-- upgrader.go | 99 +++++++++++++++++++++++++++++++---------------------- 3 files changed, 66 insertions(+), 49 deletions(-) diff --git a/database.go b/database.go index 5f309b8..b756e9d 100644 --- a/database.go +++ b/database.go @@ -1,8 +1,11 @@ package dbschema import ( + // External + "github.com/jackc/pgx/v5/pgxpool" + // Standard - "database/sql" + "context" "fmt" ) @@ -13,10 +16,10 @@ func newDatabase(host string, port int, dbName, user, pass string) (dbase Databa dbase.Username = user dbase.Password = pass - dbase.db, err = sql.Open("postgres", dbase.sqlConnString()) + dbase.db, err = pgxpool.New(context.Background(), dbase.sqlConnString()) return }// }}} -func databaseFromInstance(db *sql.DB) (dbase Database, err error) { +func databaseFromInstance(db *pgxpool.Pool) (dbase Database, err error) { dbase.db = db return } diff --git a/schema.go b/schema.go index 576929c..8ee3646 100644 --- a/schema.go +++ b/schema.go @@ -21,10 +21,7 @@ package dbschema import ( // External - _ "github.com/lib/pq" - - // Standard - "database/sql" + "github.com/jackc/pgx/v5/pgxpool" ) // An upgrader verifies the schema for one or more databases and upgrades them if possible. @@ -42,7 +39,7 @@ type Database struct { Username string Password string - db *sql.DB + db *pgxpool.Pool upgrader *Upgrader } diff --git a/upgrader.go b/upgrader.go index 6b662c3..3c0e0b0 100644 --- a/upgrader.go +++ b/upgrader.go @@ -2,19 +2,22 @@ package dbschema import ( // External - "github.com/lib/pq" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxpool" // Standard + "context" "database/sql" "fmt" ) -func defaultCallback(topic, msg string) {// {{{ +func defaultCallback(topic, msg string) { // {{{ fmt.Printf("[%s] %s\n", topic, msg) -}// }}} +} // }}} // NewUpgrader creates an upgrader with an empty list of databases. -func NewUpgrader(schema ...string) (upgrader Upgrader) {// {{{ +func NewUpgrader(schema ...string) (upgrader Upgrader) { // {{{ // Using a variadic function for backward compatibility. if len(schema) > 0 { upgrader.schema = schema[0] @@ -25,18 +28,18 @@ func NewUpgrader(schema ...string) (upgrader Upgrader) {// {{{ upgrader.logCallback = defaultCallback upgrader.databases = map[string]Database{} return -}// }}} +} // }}} // SetLogCallback allows to set a callback for custom logging. -func (upgrader *Upgrader) SetLogCallback(callback func(string, string)) {// {{{ +func (upgrader *Upgrader) SetLogCallback(callback func(string, string)) { // {{{ upgrader.logCallback = callback -}// }}} +} // }}} // SetSqlCallback is required for providing the SQL schema updates. -func (upgrader *Upgrader) SetSqlCallback(callback func(string, int) ([]byte, bool)) {// {{{ +func (upgrader *Upgrader) SetSqlCallback(callback func(string, int) ([]byte, bool)) { // {{{ upgrader.sqlCallback = callback -}// }}} +} // }}} // Version returns the current dbschema version for the given database name. -func (upgrader *Upgrader) Version(dbName string) (version int, err error) {// {{{ +func (upgrader *Upgrader) Version(dbName string) (version int, err error) { // {{{ dbase, found := upgrader.databases[dbName] if !found { err = fmt.Errorf("Database %s not previously added to the upgrader", dbName) @@ -45,21 +48,22 @@ func (upgrader *Upgrader) Version(dbName string) (version int, err error) {// {{ version, err = dbase.Version() return -}// }}} +} // }}} -func (dbase Database) createSchemaTable() (err error) {// {{{ +func (dbase Database) createSchemaTable() (err error) { // {{{ dbase.upgrader.logCallback("create", fmt.Sprintf("%s, %s.schema", dbase.DbName, dbase.upgrader.schema)) - _, err = dbase.db.Exec(`CREATE SCHEMA "`+dbase.upgrader.schema+`"`) + _, err = dbase.db.Exec(context.Background(), `CREATE SCHEMA "` + dbase.upgrader.schema + `"`) // Error code 42P06 "duplicate_schema" is an OK error, // table can still be missing and created. - pqErr, _ := err.(*pq.Error) + pqErr, _ := err.(*pgconn.PgError) if pqErr != nil && pqErr.Code != "42P06" { return } - _, err = dbase.db.Exec(` - CREATE TABLE "`+dbase.upgrader.schema+`"."schema" ( + _, err = dbase.db.Exec( + context.Background(), + `CREATE TABLE "` + dbase.upgrader.schema + `"."schema" ( version int4 NOT NULL, updated timestamp NOT NULL DEFAULT NOW(), @@ -67,19 +71,20 @@ func (dbase Database) createSchemaTable() (err error) {// {{{ )`, ) return -}// }}} -func (dbase Database) appendSchemaVersion(version int) (err error) {// {{{ - _, err = dbase.db.Exec(`INSERT INTO `+dbase.upgrader.schema+`.schema(version) VALUES($1)`, version) +} // }}} +func (dbase Database) appendSchemaVersion(version int) (err error) { // {{{ + _, err = dbase.db.Exec(context.Background(), `INSERT INTO `+dbase.upgrader.schema+`.schema(version) VALUES($1)`, version) return -}// }}} +} // }}} -func (dbase Database) verifySchemaTable() (err error) {// {{{ - var rows *sql.Rows +func (dbase Database) verifySchemaTable() (err error) { // {{{ + var rows pgx.Rows if rows, err = dbase.db.Query( + context.Background(), `SELECT EXISTS ( SELECT FROM pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace - WHERE n.nspname = '`+dbase.upgrader.schema+`' + WHERE n.nspname = '` + dbase.upgrader.schema + `' AND c.relname = 'schema' )`, ); err != nil { @@ -97,11 +102,11 @@ func (dbase Database) verifySchemaTable() (err error) {// {{{ } err = dbase.createSchemaTable() return -}// }}} -func (dbase Database) verifySchemaEntry() (err error) {// {{{ +} // }}} +func (dbase Database) verifySchemaEntry() (err error) { // {{{ var version int - var row *sql.Row - row = dbase.db.QueryRow(`SELECT version FROM `+dbase.upgrader.schema+`.schema LIMIT 1`) + var row pgx.Row + row = dbase.db.QueryRow(context.Background(), `SELECT version FROM `+dbase.upgrader.schema+`.schema LIMIT 1`) err = row.Scan(&version) if err == sql.ErrNoRows { @@ -110,11 +115,12 @@ func (dbase Database) verifySchemaEntry() (err error) {// {{{ } return -}// }}} -func (dbase Database) Version() (version int, err error) {// {{{ - var rows *sql.Rows +} // }}} +func (dbase Database) Version() (version int, err error) { // {{{ + var rows pgx.Rows rows, err = dbase.db.Query( - `SELECT version FROM `+dbase.upgrader.schema+`.schema ORDER BY version DESC LIMIT 1`, + context.Background(), + `SELECT version FROM ` + dbase.upgrader.schema + `.schema ORDER BY version DESC LIMIT 1`, ) if err != nil { return @@ -127,15 +133,15 @@ func (dbase Database) Version() (version int, err error) {// {{{ err = fmt.Errorf(`Database "%s" is missing an entry in `+dbase.upgrader.schema+`.schema`, dbase.DbName) } return -}// }}} +} // }}} // AddDatabase sets a database up for the Run() function with verifying/creating the _db.schema table. -func (upgrader Upgrader) AddDatabase(host string, port int, dbName, user, pass string) (db Database, err error) {// {{{ +func (upgrader Upgrader) AddDatabase(host string, port int, dbName, user, pass string) (db Database, err error) { // {{{ if db, err = newDatabase(host, port, dbName, user, pass); err != nil { return } db.upgrader = &upgrader - + upgrader.databases[dbName] = db if err = db.verifySchemaTable(); err != nil { @@ -144,13 +150,24 @@ func (upgrader Upgrader) AddDatabase(host string, port int, dbName, user, pass s err = db.verifySchemaEntry() return -}// }}} -func (upgrader Upgrader) AddDatabaseInstance(sqlDB *sql.DB) (db Database, err error) {// {{{ - return databaseFromInstance(sqlDB) -}// }}} +} // }}} +func (upgrader Upgrader) AddDatabaseInstance(sqlDB *pgxpool.Pool, dbName string) (db Database, err error) { // {{{ + db, err = databaseFromInstance(sqlDB) + + db.upgrader = &upgrader + + upgrader.databases[dbName] = db + + if err = db.verifySchemaTable(); err != nil { + return + } + + err = db.verifySchemaEntry() + return +} // }}} // Run executes the actual schema updates until there are no more available. -func (upgrader Upgrader) Run() (err error) {// {{{ +func (upgrader Upgrader) Run() (err error) { // {{{ var version int for dbName, dbase := range upgrader.databases { @@ -168,7 +185,7 @@ func (upgrader Upgrader) Run() (err error) {// {{{ } upgrader.logCallback("exec", fmt.Sprintf("%s.%s: %d", dbName, upgrader.schema, version)) - if _, err = dbase.db.Exec(string(sql)); err != nil { + if _, err = dbase.db.Exec(context.Background(), string(sql)); err != nil { return } if err = dbase.appendSchemaVersion(version); err != nil { @@ -177,6 +194,6 @@ func (upgrader Upgrader) Run() (err error) {// {{{ } } return -}// }}} +} // }}} // vim: foldmethod=marker