package dbschema

import (
	// External
	"github.com/lib/pq"

	// Standard
	"database/sql"
	"fmt"
)

// defaultCallback is the log callback installed by NewUpgrader.
// It prints "[topic] msg" to standard output.
func defaultCallback(topic, msg string) { // {{{
	fmt.Printf("[%s] %s\n", topic, msg)
} // }}}

// NewUpgrader creates an upgrader with an empty list of databases.
//
// The optional argument names the PostgreSQL schema that holds the version
// table. It defaults to "_db" when omitted or empty. Only the first value is
// used; the parameter is variadic purely for backward compatibility.
func NewUpgrader(schema ...string) (upgrader Upgrader) { // {{{
	upgrader.schema = "_db"
	if len(schema) > 0 && schema[0] != "" {
		upgrader.schema = schema[0]
	}
	upgrader.logCallback = defaultCallback
	upgrader.databases = map[string]Database{}
	return
} // }}}

// SetLogCallback allows to set a callback for custom logging.
// The callback receives a topic (such as "create", "initiate version",
// "version" or "exec") and a message.
func (upgrader *Upgrader) SetLogCallback(callback func(string, string)) { // {{{
	upgrader.logCallback = callback
} // }}}

// SetSqlCallback is required for providing the SQL schema updates.
// Run calls it with a database name and a version number. It must return the
// SQL for that version and true, or false when no such version exists.
func (upgrader *Upgrader) SetSqlCallback(callback func(string, int) ([]byte, bool)) { // {{{
	upgrader.sqlCallback = callback
} // }}}

// Version returns the current dbschema version for the given database name.
// The database must have been added with AddDatabase first.
func (upgrader *Upgrader) Version(dbName string) (version int, err error) { // {{{
	dbase, found := upgrader.databases[dbName]
	if !found {
		err = fmt.Errorf("Database %s not previously added to the upgrader", dbName)
		return
	}
	version, err = dbase.version()
	return
} // }}}

// schemaTableName returns the quoted, schema-qualified name of the version
// table, e.g. "_db"."schema".
//
// Every statement that touches the table uses it, so CREATE, INSERT and
// SELECT all resolve the same case-sensitive identifier. Embedded quotes in
// the schema name are escaped.
func (dbase Database) schemaTableName() string { // {{{
	return pq.QuoteIdentifier(dbase.upgrader.schema) + `."schema"`
} // }}}

// createSchemaTable creates the upgrader's schema, if it does not already
// exist, and then the version table inside it.
func (dbase Database) createSchemaTable() (err error) { // {{{
	dbase.upgrader.logCallback("create", fmt.Sprintf("%s, %s.schema", dbase.DbName, dbase.upgrader.schema))

	_, err = dbase.db.Exec(`CREATE SCHEMA ` + pq.QuoteIdentifier(dbase.upgrader.schema))
	if err != nil {
		// Error code 42P06 "duplicate_schema" is an OK error; the table
		// can still be missing and is created below. Any other error is
		// returned, including errors that do not come from PostgreSQL at
		// all (e.g. a lost connection).
		if pqErr, ok := err.(*pq.Error); !ok || pqErr.Code != "42P06" {
			return
		}
	}

	_, err = dbase.db.Exec(`
		CREATE TABLE ` + dbase.schemaTableName() + ` (
			version int4 NOT NULL,
			updated timestamp NOT NULL DEFAULT NOW(),
			CONSTRAINT schema_pk PRIMARY KEY (version)
		)`,
	)
	return
} // }}}

// appendSchemaVersion records version as applied by inserting it into the
// version table.
func (dbase Database) appendSchemaVersion(version int) (err error) { // {{{
	_, err = dbase.db.Exec(`INSERT INTO `+dbase.schemaTableName()+`(version) VALUES($1)`, version)
	return
} // }}}

// verifySchemaTable creates the version table unless a relation named
// "schema" already exists in the upgrader's schema.
func (dbase Database) verifySchemaTable() (err error) { // {{{
	var exists bool
	// The schema name is a bind parameter instead of being spliced into a
	// string literal. QueryRow also surfaces query and iteration errors,
	// which an unchecked rows.Next() would hide behind a generic Scan error.
	err = dbase.db.QueryRow(
		`SELECT EXISTS (
			SELECT FROM pg_catalog.pg_class c
			JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
			WHERE n.nspname = $1 AND c.relname = 'schema'
		)`,
		dbase.upgrader.schema,
	).Scan(&exists)
	if err != nil || exists {
		return
	}
	err = dbase.createSchemaTable()
	return
} // }}}

// verifySchemaEntry makes sure the version table holds at least one row.
// It inserts version 0 when the table is empty.
func (dbase Database) verifySchemaEntry() (err error) { // {{{
	var version int
	err = dbase.db.QueryRow(`SELECT version FROM ` + dbase.schemaTableName() + ` LIMIT 1`).Scan(&version)
	if err == sql.ErrNoRows {
		dbase.upgrader.logCallback("initiate version", dbase.DbName)
		err = dbase.appendSchemaVersion(0)
	}
	return
} // }}}

// version returns the highest version recorded in the version table.
// An empty table is reported as an error.
func (dbase Database) version() (version int, err error) { // {{{
	err = dbase.db.QueryRow(
		`SELECT version FROM ` + dbase.schemaTableName() + ` ORDER BY version DESC LIMIT 1`,
	).Scan(&version)
	if err == sql.ErrNoRows {
		// The schema name is an argument, never part of the format string.
		err = fmt.Errorf(`Database "%s" is missing an entry in %s.schema`, dbase.DbName, dbase.upgrader.schema)
	}
	return
} // }}}

// AddDatabase sets a database up for the Run() function, verifying or creating the version table in the upgrader's schema (_db.schema by default).
func (upgrader Upgrader) AddDatabase(host string, port int, dbName, user, pass string) (err error) {// {{{ var db Database if db, err = newDatabase(host, port, dbName, user, pass); err != nil { return } db.upgrader = &upgrader upgrader.databases[dbName] = db if err = db.verifySchemaTable(); err != nil { return } err = db.verifySchemaEntry() return }// }}} // Run executes the actual schema updates until there are no more available. func (upgrader Upgrader) Run() (err error) {// {{{ var version int for dbName, dbase := range upgrader.databases { version, err = dbase.version() if err != nil { return } upgrader.logCallback("version", fmt.Sprintf("%s: %d", dbName, version)) for { version++ sql, found := upgrader.sqlCallback(dbName, version) if !found { break } upgrader.logCallback("exec", fmt.Sprintf("%s: %d", dbName, version)) if _, err = dbase.db.Exec(string(sql)); err != nil { return } if err = dbase.appendSchemaVersion(version); err != nil { return } } } return }// }}} // vim: foldmethod=marker