Added configurable schema

This commit is contained in:
Magnus Åhall 2023-12-30 17:40:34 +01:00
parent 2cb694f534
commit 95335125d3
2 changed files with 19 additions and 11 deletions

View File

@ -29,6 +29,7 @@ import (
// An upgrader verifies the schema for one or more databases and upgrades them if possible. // An upgrader verifies the schema for one or more databases and upgrades them if possible.
type Upgrader struct { type Upgrader struct {
schema string
databases map[string]Database databases map[string]Database
logCallback func(string, string) logCallback func(string, string)
sqlCallback func(string, int) ([]byte, bool) sqlCallback func(string, int) ([]byte, bool)

View File

@ -14,7 +14,14 @@ func defaultCallback(topic, msg string) {// {{{
}// }}} }// }}}
// NewUpgrader creates an upgrader with an empty list of databases. // NewUpgrader creates an upgrader with an empty list of databases.
func NewUpgrader() (upgrader Upgrader) {// {{{ func NewUpgrader(schema ...string) (upgrader Upgrader) {// {{{
// Using a variadic function for backward compatibility.
if len(schema) > 0 {
upgrader.schema = schema[0]
} else {
upgrader.schema = "_db"
}
upgrader.logCallback = defaultCallback upgrader.logCallback = defaultCallback
upgrader.databases = map[string]Database{} upgrader.databases = map[string]Database{}
return return
@ -41,8 +48,8 @@ func (upgrader *Upgrader) Version(dbName string) (version int, err error) {// {{
}// }}} }// }}}
func (dbase Database) createSchemaTable() (err error) {// {{{ func (dbase Database) createSchemaTable() (err error) {// {{{
dbase.upgrader.logCallback("create", fmt.Sprintf("%s, _db.schema", dbase.DbName)) dbase.upgrader.logCallback("create", fmt.Sprintf("%s, %s.schema", dbase.DbName, dbase.upgrader.schema))
_, err = dbase.db.Exec(`CREATE SCHEMA "_db"`) _, err = dbase.db.Exec(`CREATE SCHEMA "`+dbase.upgrader.schema+`"`)
// Error code 42P06 "duplicate_schema" is an OK error, // Error code 42P06 "duplicate_schema" is an OK error,
// table can still be missing and created. // table can still be missing and created.
@ -52,7 +59,7 @@ func (dbase Database) createSchemaTable() (err error) {// {{{
} }
_, err = dbase.db.Exec(` _, err = dbase.db.Exec(`
CREATE TABLE "_db"."schema" ( CREATE TABLE "`+dbase.upgrader.schema+`"."schema" (
version int4 NOT NULL, version int4 NOT NULL,
updated timestamp NOT NULL DEFAULT NOW(), updated timestamp NOT NULL DEFAULT NOW(),
@ -62,7 +69,7 @@ func (dbase Database) createSchemaTable() (err error) {// {{{
return return
}// }}} }// }}}
func (dbase Database) appendSchemaVersion(version int) (err error) {// {{{ func (dbase Database) appendSchemaVersion(version int) (err error) {// {{{
_, err = dbase.db.Exec(`INSERT INTO _db.schema(version) VALUES($1)`, version) _, err = dbase.db.Exec(`INSERT INTO `+dbase.upgrader.schema+`.schema(version) VALUES($1)`, version)
return return
}// }}} }// }}}
@ -72,7 +79,7 @@ func (dbase Database) verifySchemaTable() (err error) {// {{{
`SELECT EXISTS ( `SELECT EXISTS (
SELECT FROM pg_catalog.pg_class c SELECT FROM pg_catalog.pg_class c
JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
WHERE n.nspname = '_db' WHERE n.nspname = '`+dbase.upgrader.schema+`'
AND c.relname = 'schema' AND c.relname = 'schema'
)`, )`,
); err != nil { ); err != nil {
@ -94,7 +101,7 @@ func (dbase Database) verifySchemaTable() (err error) {// {{{
func (dbase Database) verifySchemaEntry() (err error) {// {{{ func (dbase Database) verifySchemaEntry() (err error) {// {{{
var version int var version int
var row *sql.Row var row *sql.Row
row = dbase.db.QueryRow(`SELECT version FROM _db.schema LIMIT 1`) row = dbase.db.QueryRow(`SELECT version FROM `+dbase.upgrader.schema+`.schema LIMIT 1`)
err = row.Scan(&version) err = row.Scan(&version)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -107,7 +114,7 @@ func (dbase Database) verifySchemaEntry() (err error) {// {{{
func (dbase Database) version() (version int, err error) {// {{{ func (dbase Database) version() (version int, err error) {// {{{
var rows *sql.Rows var rows *sql.Rows
rows, err = dbase.db.Query( rows, err = dbase.db.Query(
`SELECT version FROM _db.schema ORDER BY version DESC LIMIT 1`, `SELECT version FROM `+dbase.upgrader.schema+`.schema ORDER BY version DESC LIMIT 1`,
) )
if err != nil { if err != nil {
return return
@ -117,7 +124,7 @@ func (dbase Database) version() (version int, err error) {// {{{
if rows.Next() { if rows.Next() {
err = rows.Scan(&version) err = rows.Scan(&version)
} else { } else {
err = fmt.Errorf(`Database "%s" is missing an entry in _db.schema`, dbase.DbName) err = fmt.Errorf(`Database "%s" is missing an entry in `+dbase.upgrader.schema+`.schema`, dbase.DbName)
} }
return return
}// }}} }// }}}