diff --git a/schema.go b/schema.go index c7b5894..576929c 100644 --- a/schema.go +++ b/schema.go @@ -29,7 +29,8 @@ import ( // An upgrader verifies the schema for one or more databases and upgrades them if possible. type Upgrader struct { - databases map[string]Database + schema string + databases map[string]Database logCallback func(string, string) sqlCallback func(string, int) ([]byte, bool) } @@ -41,7 +42,7 @@ type Database struct { Username string Password string - db *sql.DB + db *sql.DB upgrader *Upgrader } diff --git a/upgrader.go b/upgrader.go index 7cb65f2..f06205d 100644 --- a/upgrader.go +++ b/upgrader.go @@ -14,7 +14,14 @@ func defaultCallback(topic, msg string) {// {{{ }// }}} // NewUpgrader creates an upgrader with an empty list of databases. -func NewUpgrader() (upgrader Upgrader) {// {{{ +func NewUpgrader(schema ...string) (upgrader Upgrader) {// {{{ + // Using a variadic function for backward compatibility. + if len(schema) > 0 { + upgrader.schema = schema[0] + } else { + upgrader.schema = "_db" + } + upgrader.logCallback = defaultCallback upgrader.databases = map[string]Database{} return @@ -41,8 +48,8 @@ func (upgrader *Upgrader) Version(dbName string) (version int, err error) {// {{ }// }}} func (dbase Database) createSchemaTable() (err error) {// {{{ - dbase.upgrader.logCallback("create", fmt.Sprintf("%s, _db.schema", dbase.DbName)) - _, err = dbase.db.Exec(`CREATE SCHEMA "_db"`) + dbase.upgrader.logCallback("create", fmt.Sprintf("%s, %s.schema", dbase.DbName, dbase.upgrader.schema)) + _, err = dbase.db.Exec(`CREATE SCHEMA "`+dbase.upgrader.schema+`"`) // Error code 42P06 "duplicate_schema" is an OK error, // table can still be missing and created. @@ -52,7 +59,7 @@ func (dbase Database) createSchemaTable() (err error) {// {{{ } _, err = dbase.db.Exec(` - CREATE TABLE "_db"."schema" ( + CREATE TABLE "`+dbase.upgrader.schema+`"."schema" ( version int4 NOT NULL, updated timestamp NOT NULL DEFAULT NOW(), @@ -62,7 +69,7 @@ func (dbase Database) createSchemaTable() (err error) {// {{{ return }// }}} func (dbase Database) appendSchemaVersion(version int) (err error) {// {{{ - _, err = dbase.db.Exec(`INSERT INTO _db.schema(version) VALUES($1)`, version) + _, err = dbase.db.Exec(`INSERT INTO `+dbase.upgrader.schema+`.schema(version) VALUES($1)`, version) return }// }}} @@ -72,7 +79,7 @@ func (dbase Database) verifySchemaTable() (err error) {// {{{ `SELECT EXISTS ( SELECT FROM pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace - WHERE n.nspname = '_db' + WHERE n.nspname = '`+dbase.upgrader.schema+`' AND c.relname = 'schema' )`, ); err != nil { @@ -94,7 +101,7 @@ func (dbase Database) verifySchemaTable() (err error) {// {{{ func (dbase Database) verifySchemaEntry() (err error) {// {{{ var version int var row *sql.Row - row = dbase.db.QueryRow(`SELECT version FROM _db.schema LIMIT 1`) + row = dbase.db.QueryRow(`SELECT version FROM `+dbase.upgrader.schema+`.schema LIMIT 1`) err = row.Scan(&version) if err == sql.ErrNoRows { @@ -107,7 +114,7 @@ func (dbase Database) verifySchemaEntry() (err error) {// {{{ func (dbase Database) version() (version int, err error) {// {{{ var rows *sql.Rows rows, err = dbase.db.Query( - `SELECT version FROM _db.schema ORDER BY version DESC LIMIT 1`, + `SELECT version FROM `+dbase.upgrader.schema+`.schema ORDER BY version DESC LIMIT 1`, ) if err != nil { return @@ -117,7 +124,7 @@ func (dbase Database) version() (version int, err error) {// {{{ if rows.Next() { err = rows.Scan(&version) } else { - err = fmt.Errorf(`Database "%s" is missing an entry in _db.schema`, dbase.DbName) + err = fmt.Errorf(`Database "%s" is missing an entry in `+dbase.upgrader.schema+`.schema`, dbase.DbName) } return }// }}}