package dbschema

import (
	// Standard
	"context"
	"errors"
	"fmt"

	// External
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/pgxpool"
)

// defaultCallback is the log callback installed by NewUpgrader. It prints the
// topic and message to standard output as "[topic] msg".
func defaultCallback(topic, msg string) { // {{{
	fmt.Printf("[%s] %s\n", topic, msg)
} // }}}

// NewUpgrader creates an upgrader with an empty list of databases.
//
// The optional first schema argument names the database schema holding the
// version table; without it "_db" is used. Any further arguments are ignored.
func NewUpgrader(schema ...string) (upgrader Upgrader) { // {{{
	// Using a variadic function for backward compatibility.
	upgrader.schema = "_db"
	if len(schema) > 0 {
		upgrader.schema = schema[0]
	}

	upgrader.logCallback = defaultCallback
	upgrader.databases = map[string]Database{}
	return
} // }}}

// SetLogCallback allows to set a callback for custom logging.
//
// The callback receives a topic and a message. Passing nil restores the
// default callback instead of leaving a nil function that would panic on the
// next log call.
func (upgrader *Upgrader) SetLogCallback(callback func(string, string)) { // {{{
	if callback == nil {
		callback = defaultCallback
	}
	upgrader.logCallback = callback
} // }}}

// SetSqlCallback is required for providing the SQL schema updates.
//
// Run calls the callback with a database name and a version number; it returns
// the SQL for that version and whether such a version exists.
func (upgrader *Upgrader) SetSqlCallback(callback func(string, int) ([]byte, bool)) { // {{{
	upgrader.sqlCallback = callback
} // }}}

// Version returns the current dbschema version for the given database name.
//
// An error is returned when dbName was not registered with the upgrader, or
// when reading the version from the database fails.
func (upgrader *Upgrader) Version(dbName string) (version int, err error) { // {{{
	dbase, found := upgrader.databases[dbName]
	if !found {
		err = fmt.Errorf("Database %s not previously added to the upgrader", dbName)
		return
	}

	return dbase.Version()
} // }}}

// schemaTable returns the quoted, schema-qualified name of the version table,
// so every statement refers to exactly the identifier that was created.
func (dbase Database) schemaTable() string { // {{{
	return pgx.Identifier{dbase.upgrader.schema, "schema"}.Sanitize()
} // }}}

// createSchemaTable creates the upgrader's schema and the "schema" version
// table inside it.
//
// An already existing schema is tolerated; every other failure of the
// CREATE SCHEMA statement is returned without attempting to create the table.
func (dbase Database) createSchemaTable() (err error) { // {{{
	// PostgreSQL error code for "duplicate_schema".
	const duplicateSchema = "42P06"

	ctx := context.Background()
	schema := dbase.upgrader.schema

	dbase.upgrader.logCallback("create", fmt.Sprintf("%s, %s.schema", dbase.DbName, schema))

	_, err = dbase.db.Exec(ctx, `CREATE SCHEMA `+pgx.Identifier{schema}.Sanitize())
	if err != nil {
		// A duplicate schema is an OK error, the table can still be missing
		// and created. Anything else (including non-PostgreSQL errors such as
		// a lost connection) is a real failure.
		var pgErr *pgconn.PgError
		if !errors.As(err, &pgErr) || pgErr.Code != duplicateSchema {
			return
		}
	}

	_, err = dbase.db.Exec(
		ctx,
		`CREATE TABLE `+dbase.schemaTable()+` (
			version int4 NOT NULL,
			updated timestamp NOT NULL DEFAULT NOW(),
			CONSTRAINT schema_pk PRIMARY KEY (version)
		)`,
	)
	return
} // }}}

// appendSchemaVersion records version as applied by inserting it into the
// version table.
func (dbase Database) appendSchemaVersion(version int) (err error) { // {{{
	_, err = dbase.db.Exec(
		context.Background(),
		`INSERT INTO `+dbase.schemaTable()+`(version) VALUES($1)`,
		version,
	)
	return
} // }}}

// verifySchemaTable checks the system catalog for the version table and
// creates the schema and table when they are missing.
func (dbase Database) verifySchemaTable() (err error) { // {{{
	var exists bool

	// The schema name is passed as a query parameter rather than being
	// spliced into the SQL text.
	err = dbase.db.QueryRow(
		context.Background(),
		`SELECT EXISTS (
			SELECT FROM pg_catalog.pg_class c
			JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
			WHERE n.nspname = $1 AND c.relname = 'schema'
		)`,
		dbase.upgrader.schema,
	).Scan(&exists)
	if err != nil || exists {
		return
	}

	return dbase.createSchemaTable()
} // }}}

// verifySchemaEntry makes sure the version table holds at least one row,
// inserting version 0 when it is empty.
func (dbase Database) verifySchemaEntry() (err error) { // {{{
	var version int

	err = dbase.db.QueryRow(
		context.Background(),
		`SELECT version FROM `+dbase.schemaTable()+` LIMIT 1`,
	).Scan(&version)

	if errors.Is(err, pgx.ErrNoRows) {
		dbase.upgrader.logCallback("initiate version", dbase.DbName)
		err = dbase.appendSchemaVersion(0)
	}
	return
} // }}}

// Version returns the highest version recorded in the version table.
//
// An empty table yields a "missing an entry" error; a failing query is
// returned as is rather than being reported as a missing entry.
func (dbase Database) Version() (version int, err error) { // {{{
	err = dbase.db.QueryRow(
		context.Background(),
		`SELECT version FROM `+dbase.schemaTable()+` ORDER BY version DESC LIMIT 1`,
	).Scan(&version)

	if errors.Is(err, pgx.ErrNoRows) {
		// The schema name is passed as an argument so that a '%' in it
		// cannot corrupt the format string.
		err = fmt.Errorf(`Database "%s" is missing an entry in %s.schema`, dbase.DbName, dbase.upgrader.schema)
	}
	return
} // }}}

// AddDatabase sets a database up for the Run() function with verifying/creating the _db.schema table.
func (upgrader Upgrader) AddDatabase(host string, port int, dbName, user, pass string) (db Database, err error) { // {{{ if db, err = newDatabase(host, port, dbName, user, pass); err != nil { return } db.upgrader = &upgrader upgrader.databases[dbName] = db if err = db.verifySchemaTable(); err != nil { return } err = db.verifySchemaEntry() return } // }}} func (upgrader Upgrader) AddDatabaseInstance(sqlDB *pgxpool.Pool, dbName string) (db Database, err error) { // {{{ db, err = databaseFromInstance(sqlDB) db.upgrader = &upgrader upgrader.databases[dbName] = db if err = db.verifySchemaTable(); err != nil { return } err = db.verifySchemaEntry() return } // }}} // Run executes the actual schema updates until there are no more available. func (upgrader Upgrader) Run() (err error) { // {{{ var version int for dbName, dbase := range upgrader.databases { version, err = dbase.Version() if err != nil { return } upgrader.logCallback("version", fmt.Sprintf("%s.%s: %d", dbName, upgrader.schema, version)) for { version++ sql, found := upgrader.sqlCallback(dbName, version) if !found { break } upgrader.logCallback("exec", fmt.Sprintf("%s.%s: %d", dbName, upgrader.schema, version)) if _, err = dbase.db.Exec(context.Background(), string(sql)); err != nil { return } if err = dbase.appendSchemaVersion(version); err != nil { return } } } return } // }}} // vim: foldmethod=marker