// Package database manages the service's PostgreSQL database: schema
// upgrades, user accounts and sessions.
package database

import (
	// External
	"git.gibonuddevalla.se/go/dbschema"
	"github.com/jmoiron/sqlx"
	_ "github.com/lib/pq"

	// Internal
	"git.gibonuddevalla.se/go/webservice/config"
	"git.gibonuddevalla.se/go/webservice/session"

	// Standard
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"strings"
)

// SqlProvider returns the SQL script for one schema version and whether that
// version exists. The arguments are (database name, version), matching
// webserviceSQLProvider; Upgrade hands it to a dbschema upgrader.
type SqlProvider func(string, int) ([]byte, bool)

// T is a handle to the database. Create it with New; Upgrade applies schema
// migrations and Connect opens Conn, which all query methods use.
type T struct {
	cfg config.DatabaseDetails

	// Conn is the pool opened by Connect (nil before Connect or after it fails).
	Conn *sqlx.DB

	logger *slog.Logger

	// sqlProvider supplies the migrations for the application's "_db" upgrader.
	sqlProvider SqlProvider

	// logProvider receives (category, message) pairs from the upgraders.
	logProvider func(string, string)
}

// connValueEscaper escapes backslashes and single quotes for use inside a
// single-quoted value of a key=value PostgreSQL connection string.
var connValueEscaper = strings.NewReplacer(`\`, `\\`, `'`, `\'`)

// New returns a database handle for cfg. It does not connect. Log output goes
// to slog.Default() until SetLogger is called, so Connect and the default log
// provider are safe to use without SetLogger.
func New(cfg config.DatabaseDetails) (db *T) { // {{{
	db = &T{
		cfg:    cfg,
		logger: slog.Default(),
	}
	db.logProvider = db.defaultLogProvider
	return
} // }}}

// SetLogger replaces the logger used by Connect and the default log provider.
func (db *T) SetLogger(l *slog.Logger) { // {{{
	db.logger = l
} // }}}

// SetSQLProvider sets the migration source for the application's own schema,
// which Upgrade runs after the built-in _webservice migrations.
func (db *T) SetSQLProvider(fn func(string, int) ([]byte, bool)) { // {{{
	db.sqlProvider = fn
} // }}}

// SetLogProvider replaces the callback that receives upgrader log lines.
func (db *T) SetLogProvider(fn func(string, string)) { // {{{
	db.logProvider = fn
} // }}}

// defaultLogProvider logs an upgrader line to db.logger as an Info record with
// message "database" and a single attribute whose key is category.
func (db *T) defaultLogProvider(category, msg string) { // {{{
	db.logger.Info("database", category, msg)
} // }}}

// webserviceSQLProvider returns migration script number version for the
// built-in _webservice schema (users, sessions, password hashing). The dbname
// argument is ignored.
func webserviceSQLProvider(dbname string, version int) ([]byte, bool) { // {{{
	scripts := map[int]string{
		1: ` CREATE TABLE _webservice.user ( id serial NOT NULL, "name" varchar NOT NULL, "username" varchar NOT NULL, "password" char(96) NOT NULL, last_login timestamp with time zone NOT NULL DEFAULT '1970-01-01 00:00:00', CONSTRAINT user_pk PRIMARY KEY (id), CONSTRAINT user_un UNIQUE (username) ); CREATE TABLE "_webservice"."session" ( id serial NOT NULL, user_id int4 NULL, "uuid" char(36) NOT NULL, created timestamp with time zone NOT NULL DEFAULT NOW(), last_used timestamp with time zone NOT NULL DEFAULT NOW(), CONSTRAINT session_pk PRIMARY KEY (id), CONSTRAINT session_un UNIQUE ("uuid"), CONSTRAINT session_user_fk FOREIGN KEY (user_id) REFERENCES "_webservice"."user"(id) ON DELETE CASCADE ON UPDATE CASCADE ); CREATE EXTENSION IF NOT EXISTS pgcrypto SCHEMA _webservice; CREATE FUNCTION _webservice.password_hash(salt_hex char(32), pass bytea) RETURNS char(96) LANGUAGE plpgsql AS $$ BEGIN RETURN ( SELECT
salt_hex || encode( sha256( decode(salt_hex, 'hex') || /* salt in binary */ pass /* password */ ), 'hex' ) ); END; $$; `,
		2: ` ALTER TABLE "_webservice"."session" ADD mfa jsonb DEFAULT '{}' NOT NULL; `,
	}
	statement, found := scripts[version]
	return []byte(statement), found
} // }}}

// Upgrade runs pending migrations. It runs the built-in "_webservice" upgrader
// first, then the application's "_db" upgrader fed by SetSQLProvider. It stops
// at the first error.
//
// NOTE(review): if SetSQLProvider was never called, the "_db" upgrader gets a
// nil callback; how dbschema handles that is not visible here.
func (db *T) Upgrade() (err error) { // {{{
	if err = db.runUpgrader("_webservice", webserviceSQLProvider); err != nil {
		return
	}
	return db.runUpgrader("_db", db.sqlProvider)
} // }}}

// runUpgrader registers this database with a dbschema upgrader named name,
// feeds it migrations from provider and runs it.
func (db *T) runUpgrader(name string, provider SqlProvider) (err error) { // {{{
	upgrader := dbschema.NewUpgrader(name)
	upgrader.SetSqlCallback(provider)
	upgrader.SetLogCallback(db.logProvider)

	if err = upgrader.AddDatabase(
		db.cfg.Host,
		db.cfg.Port,
		db.cfg.Name,
		db.cfg.Username,
		db.cfg.Password,
	); err != nil {
		return
	}
	return upgrader.Run()
} // }}}

// quoteConnValue single-quotes v for a key=value connection string, so that
// empty values and values containing spaces, quotes or backslashes parse intact.
func quoteConnValue(v string) string { // {{{
	return "'" + connValueEscaper.Replace(v) + "'"
} // }}}

// Connect opens the PostgreSQL pool described by the configuration (SSL
// disabled) and stores it in db.Conn. On failure db.Conn is set to whatever
// sqlx.Connect returned along with the error.
func (db *T) Connect() (err error) { // {{{
	db.logger.Info("database", "host", db.cfg.Host, "port", db.cfg.Port, "name", db.cfg.Name)

	// Values are quoted: an unquoted empty password would otherwise swallow
	// the following "dbname=..." pair as its value.
	dsn := fmt.Sprintf(
		"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
		quoteConnValue(db.cfg.Host),
		db.cfg.Port,
		quoteConnValue(db.cfg.Username),
		quoteConnValue(db.cfg.Password),
		quoteConnValue(db.cfg.Name),
	)
	db.Conn, err = sqlx.Connect("postgres", dsn)
	return
} // }}}

// Authenticate checks password for username (compared case-insensitively).
// On success it returns true and the user's id. An unknown user or a wrong
// password yields (false, 0, nil); err is reserved for database failures.
//
// NOTE(review): the password is sent as text and cast with ::bytea, so
// PostgreSQL applies bytea input syntax to it (backslash escapes, a leading
// \x selects hex). Passwords containing backslashes are therefore decoded
// rather than hashed literally. Changing this would alter stored hashes.
// Also, salted SHA-256 is a fast hash; a slow KDF would resist offline
// guessing better.
func (db *T) Authenticate(username, password string) (authenticated bool, userID int, err error) { // {{{
	err = db.Conn.QueryRow(
		` SELECT id FROM _webservice.user WHERE LOWER(username) = LOWER($1) AND password = _webservice.password_hash(SUBSTRING(password FROM 1 FOR 32), $2::bytea) `,
		username,
		password,
	).Scan(&userID)
	if errors.Is(err, sql.ErrNoRows) {
		return false, 0, nil
	}
	if err != nil {
		return false, 0, err
	}
	authenticated = userID > 0
	return
} // }}}

// NewSession inserts an anonymous session row identified by uuid.
func (db *T) NewSession(uuid string) (err error) { // {{{
	_, err = db.Conn.Exec("INSERT INTO _webservice.session(uuid) VALUES($1)", uuid)
	return
} // }}}

// RetrieveSession loads the session identified by uuid and bumps its
// last_used timestamp in the same statement. The session is joined with its
// user, if any; user fields are empty/zero for anonymous sessions.
//
// It returns (nil, nil) when no session matches uuid. On error the returned
// session is always nil.
func (db *T) RetrieveSession(uuid string) (sess *session.T, err error) { // {{{
	var rows *sqlx.Rows
	rows, err = db.Conn.Queryx(
		` WITH session_data AS ( UPDATE _webservice.session SET last_used=NOW() WHERE uuid=$1 RETURNING uuid, created, last_used, user_id, mfa ) SELECT sd.uuid, sd.created, sd.last_used, COALESCE(u.username, '') AS username, COALESCE(u.name, '') AS name, COALESCE(u.id, 0) AS user_id, mfa FROM session_data sd LEFT JOIN _webservice.user u ON sd.user_id = u.id `,
		uuid,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		s := new(session.T)
		if err = rows.StructScan(s); err != nil {
			return nil, err
		}
		s.Authenticated = s.UserID > 0
		sess = s
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return sess, nil
} // }}}

// SetSessionUser attaches userID to the session identified by uuid. A userID
// of 0 or less detaches the session from any user (user_id becomes NULL).
func (db *T) SetSessionUser(uuid string, userID int) (err error) { // {{{
	_, err = db.Conn.Exec(
		` UPDATE _webservice.session SET user_id = CASE WHEN $1 <= 0 THEN NULL ELSE $1 END WHERE uuid=$2 `,
		userID,
		uuid,
	)
	return
} // }}}

// SetSessionMFA stores mfa, JSON-encoded, in the mfa column of the session
// identified by uuid. A value that cannot be JSON-encoded is reported as an
// error and nothing is written.
func (db *T) SetSessionMFA(uuid string, mfa any) (err error) { // {{{
	mfaJSON, err := json.Marshal(mfa)
	if err != nil {
		return fmt.Errorf("encoding session MFA: %w", err)
	}

	_, err = db.Conn.Exec(
		` UPDATE _webservice.session SET mfa = $2 WHERE uuid = $1 `,
		uuid,
		mfaJSON,
	)
	return
} // }}}

// UpdateUserTime sets the user's last_login to the database's current time.
func (db *T) UpdateUserTime(userID int) (err error) { // {{{
	_, err = db.Conn.Exec(`UPDATE _webservice.user SET last_login=NOW() WHERE id=$1`, userID)
	return
} // }}}

// CreateUser inserts a user, hashing password with a fresh 16-byte random
// salt, and returns the new id. If username already exists, the existing
// user's id is returned and its password and name are left unchanged.
func (db *T) CreateUser(username, password, name string) (userID int64, err error) { // {{{
	row := db.Conn.QueryRow(
		` INSERT INTO _webservice.user(username, password, name) VALUES( $1, _webservice.password_hash( /* salt in hex */ ENCODE(_webservice.gen_random_bytes(16), 'hex'), /* password */ $2::bytea ), $3 ) ON CONFLICT (username) DO UPDATE SET username = EXCLUDED.username RETURNING id `,
		username,
		password,
		name,
	)
	err = row.Scan(&userID)
	return
} // }}}

// ChangePassword replaces the password of userID with newPassword (re-salted),
// but only if currentPassword matches the stored hash. changed reports whether
// exactly one row was updated. A wrong current password gives (false, nil).
func (db *T) ChangePassword(userID int, currentPassword, newPassword string) (changed bool, err error) { // {{{
	var res sql.Result
	res, err = db.Conn.Exec(
		` UPDATE _webservice.user SET "password" = _webservice.password_hash( /* salt in hex */ ENCODE(_webservice.gen_random_bytes(16),
'hex'), /* password */ $3::bytea ) WHERE id = $1 AND "password" = _webservice.password_hash(SUBSTRING(password FROM 1 FOR 32), $2::bytea) `,
		userID,
		currentPassword,
		newPassword,
	)
	// res is nil when Exec fails; check before touching it.
	if err != nil {
		return false, err
	}

	var rowsAffected int64
	if rowsAffected, err = res.RowsAffected(); err != nil {
		return false, err
	}
	changed = rowsAffected == 1
	return
} // }}}

// vim: foldmethod=marker