dbschema/upgrader.go
2023-08-01 17:10:35 +02:00

155 lines
3.5 KiB
Go

package dbschema
import (
// Standard
"database/sql"
"fmt"
)
// defaultCallback is the log callback installed by NewUpgrader. It writes
// the message to standard output as "[topic] msg" followed by a newline.
func defaultCallback(topic, msg string) {// {{{
	fmt.Println("[" + topic + "] " + msg)
}// }}}
// NewUpgrader creates an Upgrader whose version bookkeeping is stored in the
// database described by host, port, dbName, user and pass.
//
// The returned Upgrader logs through defaultCallback and starts with an empty
// set of databases. Before returning, verifySchemaTable is called so that the
// "_db"."schema" table exists.
//
// A non-nil error is returned when newDatabase fails or when the schema table
// could not be verified; the returned Upgrader must not be used in that case.
func NewUpgrader(host string, port int, dbName, user, pass string) (upgrader Upgrader, err error) {// {{{
	upgrader.logCallback = defaultCallback
	upgrader.databases = map[string]Database{}

	upgrader.schemaDb, err = newDatabase(host, port, dbName, user, pass)
	if err != nil {
		// Stop here: continuing would overwrite this error with the result of
		// verifySchemaTable and operate on an unusable schema database.
		return
	}

	err = upgrader.verifySchemaTable()
	return
}// }}}
// SetLogCallback replaces the function used for log output. The callback is
// invoked with a topic and a message.
//
// Passing nil restores defaultCallback instead of storing a nil function,
// which would panic the next time the upgrader logs something.
func (upgrader *Upgrader) SetLogCallback(callback func(string, string)) {// {{{
	if callback == nil {
		callback = defaultCallback
	}
	upgrader.logCallback = callback
}// }}}
// SetSqlCallback installs the function Run uses to obtain upgrade scripts.
//
// Run calls it with a database name and the version it wants to upgrade to;
// the callback returns the SQL for that step and whether such a step exists.
func (upgrader *Upgrader) SetSqlCallback(callback func(dbName string, version int) (sql []byte, found bool)) {// {{{
	upgrader.sqlCallback = callback
}// }}}
// verifySchemaTable makes sure the bookkeeping table "_db"."schema" exists in
// the schema database. When the table is missing, the creation is announced
// through the log callback and both the "_db" schema and the table are created.
//
// It returns the first error from the existence check or from either CREATE
// statement.
func (upgrader Upgrader) verifySchemaTable() (err error) {// {{{
	var rows *sql.Rows
	if rows, err = upgrader.schemaDb.db.Query(
		`SELECT EXISTS (
SELECT FROM pg_catalog.pg_class c
JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
WHERE n.nspname = '_db'
AND c.relname = 'schema'
)`,
	); err != nil {
		return
	}

	var exists bool
	if rows.Next() {
		err = rows.Scan(&exists)
	} else if err = rows.Err(); err == nil {
		// SELECT EXISTS should always yield one row; report it rather than
		// scanning a result set that has nothing to scan.
		err = fmt.Errorf("checking for _db.schema: query returned no rows")
	}

	// Close before running the statements below so the result set does not
	// keep a connection occupied while they execute.
	if closeErr := rows.Close(); err == nil {
		err = closeErr
	}
	if err != nil || exists {
		return
	}

	upgrader.logCallback("create", "_db.schema")

	// The schema can exist without the table. IF NOT EXISTS keeps that case
	// working while still letting genuine failures be reported.
	if _, err = upgrader.schemaDb.db.Exec(`CREATE SCHEMA IF NOT EXISTS "_db"`); err != nil {
		return
	}

	_, err = upgrader.schemaDb.db.Exec(`
CREATE TABLE "_db"."schema" (
database varchar NOT NULL,
version int4 NOT NULL,
updated timestamp NOT NULL DEFAULT NOW(),
CONSTRAINT schema_pk PRIMARY KEY (database)
);
`,
	)
	return
}// }}}
// verifySchemaEntry makes sure _db.schema contains a row for dbase.DbName.
// When no row exists, this is announced through the log callback and a row
// with version 0 is inserted.
//
// It returns any error from the lookup or from the insert.
func (upgrader Upgrader) verifySchemaEntry(dbase Database) (err error) {// {{{
	var rows *sql.Rows
	rows, err = upgrader.schemaDb.db.Query(`SELECT version FROM _db.schema WHERE database=$1`, dbase.DbName)
	if err != nil {
		return
	}

	found := rows.Next()

	// Next also returns false when iteration fails, so consult Err before
	// concluding that the entry is missing.
	err = rows.Err()

	// Release the result set before the INSERT below needs a connection.
	if closeErr := rows.Close(); err == nil {
		err = closeErr
	}
	if err != nil || found {
		return
	}

	upgrader.logCallback("initiate version", dbase.DbName)
	_, err = upgrader.schemaDb.db.Exec(`INSERT INTO _db.schema(database, version) VALUES($1, 0)`, dbase.DbName)
	return
}// }}}
// version returns the version recorded in _db.schema for dbName.
//
// An error is returned when the query fails, when the stored value cannot be
// scanned, or when _db.schema has no row for dbName.
func (upgrader Upgrader) version(dbName string) (version int, err error) {// {{{
	var rows *sql.Rows
	rows, err = upgrader.schemaDb.db.Query(
		`SELECT version FROM _db.schema WHERE database=$1`,
		dbName,
	)
	if err != nil {
		return
	}
	defer rows.Close()

	if rows.Next() {
		err = rows.Scan(&version)
		return
	}

	// Next returns false both for "no rows" and for an iteration failure;
	// report the latter as it is instead of as a missing entry.
	if err = rows.Err(); err != nil {
		return
	}

	err = fmt.Errorf(`Database "%s" is missing an entry in _db.schema`, dbName)
	return
}// }}}
// AddDatabase registers a database to be upgraded by Run.
//
// A Database is created with newDatabase from the given parameters, stored
// under dbName, and verifySchemaEntry is then called for it. The database is
// stored before that verification, so it stays registered even when the
// verification returns an error.
func (upgrader Upgrader) AddDatabase(host string, port int, dbName, user, pass string) (err error) {// {{{
	added, err := newDatabase(host, port, dbName, user, pass)
	if err != nil {
		return err
	}

	upgrader.databases[dbName] = added
	return upgrader.verifySchemaEntry(added)
}// }}}
// Run upgrades every database registered through AddDatabase.
//
// For each database the current version is read from _db.schema. The SQL
// callback is then asked for version+1, version+2, and so on; each script it
// returns is executed against that database and the new version is written
// to _db.schema. A database is finished when the callback reports that no
// script exists for the requested version.
//
// Databases are visited in map iteration order. Run stops at the first error
// and returns it. The script and the version update are issued as two
// separate Exec calls, so a failure between them leaves the recorded version
// one step behind the executed script.
//
// An error is returned when databases are registered but no SQL callback has
// been set with SetSqlCallback.
func (upgrader Upgrader) Run() (err error) {// {{{
	// Calling a nil callback would panic; report the misconfiguration instead.
	if upgrader.sqlCallback == nil && len(upgrader.databases) > 0 {
		return fmt.Errorf("dbschema: no SQL callback set, call SetSqlCallback before Run")
	}

	for dbName, db := range upgrader.databases {
		var version int
		version, err = upgrader.version(dbName)
		if err != nil {
			return
		}
		upgrader.logCallback("version", fmt.Sprintf("%s: %d", dbName, version))

		for {
			version++
			// Named script rather than sql so the database/sql package
			// imported by this file is not shadowed.
			script, found := upgrader.sqlCallback(dbName, version)
			if !found {
				break
			}

			upgrader.logCallback("exec", fmt.Sprintf("%s: %d", dbName, version))
			if _, err = db.db.Exec(string(script)); err != nil {
				return
			}

			_, err = upgrader.schemaDb.db.Exec(`
UPDATE _db.schema
SET
version=$1,
updated=NOW()
WHERE database=$2
`, version, dbName)
			if err != nil {
				return
			}
		}
	}
	return
}// }}}
// vim: foldmethod=marker