package dbschema

import (
	// Standard
	"database/sql"
	"errors"
	"fmt"
	"sort"
)

// defaultCallback is the log callback installed by NewUpgrader. It prints
// each message to stdout as "[topic] msg".
func defaultCallback(topic, msg string) { // {{{
	fmt.Printf("[%s] %s\n", topic, msg)
} // }}}

// scanFirstRow scans the first row of rows into dest and always closes rows.
// found is false (with a nil error) when the result set is empty; an error
// raised while iterating or scanning is returned instead of being mistaken
// for an empty result.
func scanFirstRow(rows *sql.Rows, dest ...interface{}) (found bool, err error) { // {{{
	defer rows.Close()
	if !rows.Next() {
		// Next returns false both for "no rows" and for a failure; Err tells them apart.
		return false, rows.Err()
	}
	if err = rows.Scan(dest...); err != nil {
		return false, err
	}
	return true, nil
} // }}}

// NewUpgrader opens the database that holds the _db.schema version table
// (via newDatabase) and makes sure that table exists, creating it when
// missing. Application databases are registered afterwards with AddDatabase.
// Log messages go to defaultCallback until SetLogCallback replaces it.
//
// On error the returned Upgrader is the zero value and must not be used.
func NewUpgrader(host string, port int, dbName, user, pass string) (upgrader Upgrader, err error) { // {{{
	upgrader.logCallback = defaultCallback
	upgrader.databases = map[string]Database{}

	// Check newDatabase's error before using schemaDb; otherwise it would be
	// silently overwritten by verifySchemaTable's result.
	if upgrader.schemaDb, err = newDatabase(host, port, dbName, user, pass); err != nil {
		return Upgrader{}, fmt.Errorf("schema database %q: %w", dbName, err)
	}
	if err = upgrader.verifySchemaTable(); err != nil {
		return Upgrader{}, err
	}
	return upgrader, nil
} // }}}

// SetLogCallback replaces the function that receives progress messages. It
// is called with a short topic ("create", "initiate version", "version" or
// "exec") and a message.
func (upgrader *Upgrader) SetLogCallback(callback func(string, string)) { // {{{
	upgrader.logCallback = callback
} // }}}

// SetSqlCallback sets the function Run uses to fetch upgrade SQL. It is
// called with a database name and a version number and returns the SQL that
// upgrades that database to the given version, or found=false once no
// further version exists. Run returns an error if no callback has been set.
func (upgrader *Upgrader) SetSqlCallback(callback func(string, int) ([]byte, bool)) { // {{{
	upgrader.sqlCallback = callback
} // }}}

// verifySchemaTable creates the "_db" schema and the "_db"."schema" version
// table in the schema database when the table does not exist yet.
func (u Upgrader) verifySchemaTable() error { // {{{
	rows, err := u.schemaDb.db.Query(`
		SELECT EXISTS (
			SELECT FROM pg_catalog.pg_class c
			JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
			WHERE n.nspname = '_db' AND c.relname = 'schema'
		)`,
	)
	if err != nil {
		return fmt.Errorf("checking for _db.schema: %w", err)
	}

	var exists bool
	found, err := scanFirstRow(rows, &exists)
	if err != nil {
		return fmt.Errorf("checking for _db.schema: %w", err)
	}
	if !found {
		return errors.New("checking for _db.schema: query returned no rows")
	}
	if exists {
		return nil
	}

	u.logCallback("create", "_db.schema")

	// IF NOT EXISTS lets an already present "_db" schema through while still
	// surfacing real failures such as missing privileges.
	if _, err = u.schemaDb.db.Exec(`CREATE SCHEMA IF NOT EXISTS "_db"`); err != nil {
		return fmt.Errorf("creating schema _db: %w", err)
	}
	if _, err = u.schemaDb.db.Exec(`
		CREATE TABLE "_db"."schema" (
			database varchar NOT NULL,
			version int4 NOT NULL,
			updated timestamp NOT NULL DEFAULT NOW(),
			CONSTRAINT schema_pk PRIMARY KEY (database)
		);
	`); err != nil {
		return fmt.Errorf("creating table _db.schema: %w", err)
	}
	return nil
} // }}}

// verifySchemaEntry makes sure _db.schema has a row for dbase.DbName,
// inserting one at version 0 when missing, so that Run starts by requesting
// version 1 for that database.
func (u Upgrader) verifySchemaEntry(dbase Database) error { // {{{
	rows, err := u.schemaDb.db.Query(`SELECT version FROM _db.schema WHERE database=$1`, dbase.DbName)
	if err != nil {
		return err
	}

	// Only existence matters; scanFirstRow closes rows before the INSERT below.
	var version int
	found, err := scanFirstRow(rows, &version)
	if err != nil {
		return err
	}
	if found {
		return nil
	}

	u.logCallback("initiate version", dbase.DbName)
	_, err = u.schemaDb.db.Exec(`INSERT INTO _db.schema(database, version) VALUES($1, 0)`, dbase.DbName)
	return err
} // }}}

// version returns the version recorded in _db.schema for dbName. It returns
// an error if the query fails or if no entry exists for dbName.
func (u Upgrader) version(dbName string) (int, error) { // {{{
	rows, err := u.schemaDb.db.Query(
		`SELECT version FROM _db.schema WHERE database=$1`,
		dbName,
	)
	if err != nil {
		return 0, err
	}

	var version int
	found, err := scanFirstRow(rows, &version)
	if err != nil {
		return 0, err
	}
	if !found {
		return 0, fmt.Errorf(`Database "%s" is missing an entry in _db.schema`, dbName)
	}
	return version, nil
} // }}}

// AddDatabase opens an application database (via newDatabase), registers it
// for Run under dbName and creates its _db.schema entry if needed. It must be
// called on an Upgrader created by NewUpgrader, since it writes to the
// databases map that NewUpgrader initializes.
//
// The database stays registered even if creating its _db.schema entry fails.
func (u Upgrader) AddDatabase(host string, port int, dbName, user, pass string) error { // {{{
	db, err := newDatabase(host, port, dbName, user, pass)
	if err != nil {
		return err
	}
	// u is a copy, but databases is a map, so the registration is shared
	// with the caller's Upgrader.
	u.databases[dbName] = db
	return u.verifySchemaEntry(db)
} // }}}

// Run upgrades every registered database, in database-name order. It stops
// at the first error.
//
// For each database it reads the current version from _db.schema. It then
// repeatedly asks the SQL callback for the next version's SQL, executes it on
// that database and records the new version. This continues until the
// callback reports that no further version exists.
//
// The upgrade SQL and the version bump go through separate handles (db.db and
// schemaDb.db) and are not run in a transaction. A failure between them
// leaves that version applied but unrecorded.
func (u Upgrader) Run() error { // {{{
	// Map iteration order is random; sort for a reproducible upgrade order.
	names := make([]string, 0, len(u.databases))
	for name := range u.databases {
		names = append(names, name)
	}
	sort.Strings(names)

	for _, dbName := range names {
		if err := u.upgradeDatabase(dbName, u.databases[dbName]); err != nil {
			return err
		}
	}
	return nil
} // }}}

// upgradeDatabase applies every pending upgrade script to a single registered
// database and records each applied version in _db.schema.
func (u Upgrader) upgradeDatabase(dbName string, db Database) error { // {{{
	version, err := u.version(dbName)
	if err != nil {
		return err
	}
	u.logCallback("version", fmt.Sprintf("%s: %d", dbName, version))

	if u.sqlCallback == nil {
		return errors.New("no SQL callback set; call SetSqlCallback before Run")
	}

	for {
		version++
		// Named query (not sql) to avoid shadowing the database/sql package.
		query, found := u.sqlCallback(dbName, version)
		if !found {
			return nil
		}

		u.logCallback("exec", fmt.Sprintf("%s: %d", dbName, version))
		if _, err = db.db.Exec(string(query)); err != nil {
			return fmt.Errorf("upgrading %s to version %d: %w", dbName, version, err)
		}

		_, err = u.schemaDb.db.Exec(`
			UPDATE _db.schema
			SET version=$1, updated=NOW()
			WHERE database=$2
		`, version, dbName)
		if err != nil {
			return fmt.Errorf("recording version %d for %s: %w", version, dbName, err)
		}
	}
} // }}}

// vim: foldmethod=marker