dbschema/upgrader.go

173 lines
4.2 KiB
Go
Raw Permalink Normal View History

2023-08-01 17:10:35 +02:00
package dbschema
import (
// External
"github.com/lib/pq"
2023-08-01 17:10:35 +02:00
// Standard
"database/sql"
"fmt"
)
// defaultCallback is the fallback logger: it prints the topic in square
// brackets followed by the message and a newline to standard output.
func defaultCallback(topic, msg string) { // {{{
	fmt.Println("[" + topic + "] " + msg)
} // }}}
// NewUpgrader creates an upgrader with an empty list of databases.
func NewUpgrader() (upgrader Upgrader) {// {{{
2023-08-01 17:10:35 +02:00
upgrader.logCallback = defaultCallback
upgrader.databases = map[string]Database{}
return
}// }}}
// SetLogCallback allows to set a callback for custom logging.
2023-08-01 17:10:35 +02:00
func (upgrader *Upgrader) SetLogCallback(callback func(string, string)) {// {{{
upgrader.logCallback = callback
}// }}}
// SetSqlCallback is required for providing the SQL schema updates.
2023-08-01 17:10:35 +02:00
func (upgrader *Upgrader) SetSqlCallback(callback func(string, int) ([]byte, bool)) {// {{{
upgrader.sqlCallback = callback
}// }}}
// Version reports the highest schema version recorded for the database
// registered under dbName. It returns an error when no database with that
// name is registered, or when reading the version fails.
func (upgrader *Upgrader) Version(dbName string) (version int, err error) { // {{{
	if dbase, found := upgrader.databases[dbName]; found {
		return dbase.version()
	}
	return 0, fmt.Errorf("Database %s not previously added to the upgrader", dbName)
} // }}}
// createSchemaTable creates the "_db" schema and the "_db"."schema" table
// that records applied schema versions.
//
// An already existing "_db" schema is tolerated, since the table can still be
// missing. Any other failure to create the schema is returned to the caller,
// including errors that are not a *pq.Error.
func (dbase Database) createSchemaTable() (err error) { // {{{
	dbase.upgrader.logCallback("create", fmt.Sprintf("%s, _db.schema", dbase.DbName))

	_, err = dbase.db.Exec(`CREATE SCHEMA "_db"`)
	if err != nil {
		// Error code 42P06 "duplicate_schema" is the only acceptable error.
		// Errors of any other type or code must not be silently dropped.
		pqErr, isPqErr := err.(*pq.Error)
		if !isPqErr || pqErr.Code != "42P06" {
			return err
		}
	}

	_, err = dbase.db.Exec(`
		CREATE TABLE "_db"."schema" (
			version int4 NOT NULL,
			updated timestamp NOT NULL DEFAULT NOW(),
			CONSTRAINT schema_pk PRIMARY KEY (version)
		)`,
	)
	return err
} // }}}
// appendSchemaVersion records version as applied by inserting a new row into
// _db.schema.
func (dbase Database) appendSchemaVersion(version int) (err error) { // {{{
	const insertVersion = `INSERT INTO _db.schema(version) VALUES($1)`
	_, err = dbase.db.Exec(insertVersion, version)
	return err
} // }}}
func (dbase Database) verifySchemaTable() (err error) {// {{{
2023-08-01 17:10:35 +02:00
var rows *sql.Rows
if rows, err = dbase.db.Query(
2023-08-01 17:10:35 +02:00
`SELECT EXISTS (
SELECT FROM pg_catalog.pg_class c
JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
WHERE n.nspname = '_db'
AND c.relname = 'schema'
)`,
); err != nil {
return
}
defer rows.Close()
var exists bool
rows.Next()
if err = rows.Scan(&exists); err != nil {
return
}
if exists {
return
2023-08-01 17:10:35 +02:00
}
err = dbase.createSchemaTable()
2023-08-01 17:10:35 +02:00
return
}// }}}
func (dbase Database) verifySchemaEntry() (err error) {// {{{
var version int
var row *sql.Row
row = dbase.db.QueryRow(`SELECT version FROM _db.schema LIMIT 1`)
2023-08-01 17:10:35 +02:00
err = row.Scan(&version)
if err == sql.ErrNoRows {
dbase.upgrader.logCallback("initiate version", dbase.DbName)
err = dbase.appendSchemaVersion(0)
2023-08-01 17:10:35 +02:00
}
return
}// }}}
func (dbase Database) version() (version int, err error) {// {{{
2023-08-01 17:10:35 +02:00
var rows *sql.Rows
rows, err = dbase.db.Query(
`SELECT version FROM _db.schema ORDER BY version DESC LIMIT 1`,
2023-08-01 17:10:35 +02:00
)
if err != nil {
return
}
defer rows.Close()
if rows.Next() {
err = rows.Scan(&version)
} else {
err = fmt.Errorf(`Database "%s" is missing an entry in _db.schema`, dbase.DbName)
2023-08-01 17:10:35 +02:00
}
return
}// }}}
// AddDatabase sets a database up for the Run() function with verifying/creating the _db.schema table.
2023-08-01 17:10:35 +02:00
func (upgrader Upgrader) AddDatabase(host string, port int, dbName, user, pass string) (err error) {// {{{
var db Database
if db, err = newDatabase(host, port, dbName, user, pass); err != nil {
return
}
db.upgrader = &upgrader
2023-08-01 17:10:35 +02:00
upgrader.databases[dbName] = db
if err = db.verifySchemaTable(); err != nil {
return
}
err = db.verifySchemaEntry()
2023-08-01 17:10:35 +02:00
return
}// }}}
// Run executes the actual schema updates until there are no more available.
2023-08-01 17:10:35 +02:00
func (upgrader Upgrader) Run() (err error) {// {{{
var version int
for dbName, dbase := range upgrader.databases {
version, err = dbase.version()
2023-08-01 17:10:35 +02:00
if err != nil {
return
}
upgrader.logCallback("version", fmt.Sprintf("%s: %d", dbName, version))
for {
version++
sql, found := upgrader.sqlCallback(dbName, version)
if !found {
break
}
upgrader.logCallback("exec", fmt.Sprintf("%s: %d", dbName, version))
if _, err = dbase.db.Exec(string(sql)); err != nil {
2023-08-01 17:10:35 +02:00
return
}
if err = dbase.appendSchemaVersion(version); err != nil {
2023-08-01 17:10:35 +02:00
return
}
}
}
return
}// }}}
// vim: foldmethod=marker