322 lines
6.8 KiB
Go
322 lines
6.8 KiB
Go
package database
|
|
|
|
import (
|
|
// External
|
|
"git.gibonuddevalla.se/go/dbschema"
|
|
"github.com/jmoiron/sqlx"
|
|
_ "github.com/lib/pq"
|
|
|
|
// Internal
|
|
"git.gibonuddevalla.se/go/webservice/config"
|
|
"git.gibonuddevalla.se/go/webservice/session"
|
|
|
|
// Standard
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log/slog"
|
|
)
|
|
|
|
// SqlProvider is a callback that looks up schema-upgrade SQL. It receives a
// string and an integer (a database name and a schema version, going by the
// parameter names of webserviceSQLProvider) and returns the SQL text together
// with a flag reporting whether SQL exists for that request.
type SqlProvider func(string, int) ([]byte, bool)
|
|
|
|
type T struct {
|
|
cfg config.DatabaseDetails
|
|
Conn *sqlx.DB
|
|
logger *slog.Logger
|
|
|
|
sqlProvider SqlProvider
|
|
logProvider func(string, string)
|
|
}
|
|
|
|
func New(cfg config.DatabaseDetails) (db *T) { // {{{
|
|
db = new(T)
|
|
db.cfg = cfg
|
|
db.logProvider = db.defaultLogProvider
|
|
return
|
|
} // }}}
|
|
|
|
func (db *T) SetLogger(l *slog.Logger) { // {{{
|
|
db.logger = l
|
|
} // }}}
|
|
func (db *T) SetSQLProvider(fn func(string, int) ([]byte, bool)) { // {{{
|
|
db.sqlProvider = fn
|
|
} // }}}
|
|
func (db *T) SetLogProvider(fn func(string, string)) { // {{{
|
|
db.logProvider = fn
|
|
} // }}}
|
|
|
|
func (db *T) defaultLogProvider(category, msg string) { // {{{
|
|
db.logger.Info("database", category, msg)
|
|
} // }}}
|
|
// webserviceSQLProvider returns the built-in SQL for the given revision of
// the _webservice schema. dbname is not used. The boolean is false, and the
// returned slice empty, when no SQL exists for the requested version.
func webserviceSQLProvider(dbname string, version int) ([]byte, bool) { // {{{
	var statement string

	switch version {
	case 1:
		// Users, sessions and the salted password hashing function.
		statement = `
			CREATE TABLE _webservice.user (
				id serial NOT NULL,
				"name" varchar NOT NULL,
				"username" varchar NOT NULL,
				"password" char(96) NOT NULL,
				last_login timestamp with time zone NOT NULL DEFAULT '1970-01-01 00:00:00',
				CONSTRAINT user_pk PRIMARY KEY (id),
				CONSTRAINT user_un UNIQUE (username)
			);

			CREATE TABLE "_webservice"."session" (
				id serial NOT NULL,
				user_id int4 NULL,
				"uuid" char(36) NOT NULL,
				created timestamp with time zone NOT NULL DEFAULT NOW(),
				last_used timestamp with time zone NOT NULL DEFAULT NOW(),
				CONSTRAINT session_pk PRIMARY KEY (id),
				CONSTRAINT session_un UNIQUE ("uuid"),
				CONSTRAINT session_user_fk FOREIGN KEY (user_id) REFERENCES "_webservice"."user"(id) ON DELETE CASCADE ON UPDATE CASCADE
			);

			CREATE EXTENSION IF NOT EXISTS pgcrypto SCHEMA _webservice;

			CREATE FUNCTION _webservice.password_hash(salt_hex char(32), pass bytea)
			RETURNS char(96)
			LANGUAGE plpgsql
			AS
			$$
			BEGIN
				RETURN (
					SELECT
						salt_hex ||
						encode(
							sha256(
								decode(salt_hex, 'hex') || /* salt in binary */
								pass /* password */
							),
							'hex'
						)
				);
			END;
			$$;
		`

	case 2:
		// Per-session MFA state.
		statement = `
			ALTER TABLE "_webservice"."session" ADD mfa jsonb DEFAULT '{}' NOT NULL;
		`

	default:
		// Unknown version: empty statement, not found.
		return []byte(statement), false
	}

	return []byte(statement), true
} // }}}
|
|
|
|
func (db *T) Upgrade() (err error) { // {{{
|
|
upgrader := dbschema.NewUpgrader("_webservice")
|
|
upgrader.SetSqlCallback(webserviceSQLProvider)
|
|
upgrader.SetLogCallback(db.logProvider)
|
|
if err = upgrader.AddDatabase(
|
|
db.cfg.Host,
|
|
db.cfg.Port,
|
|
db.cfg.Name,
|
|
db.cfg.Username,
|
|
db.cfg.Password,
|
|
); err != nil {
|
|
return
|
|
}
|
|
|
|
err = upgrader.Run()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
upgrader = dbschema.NewUpgrader("_db")
|
|
upgrader.SetSqlCallback(db.sqlProvider)
|
|
upgrader.SetLogCallback(db.logProvider)
|
|
if err = upgrader.AddDatabase(
|
|
db.cfg.Host,
|
|
db.cfg.Port,
|
|
db.cfg.Name,
|
|
db.cfg.Username,
|
|
db.cfg.Password,
|
|
); err != nil {
|
|
return
|
|
}
|
|
|
|
err = upgrader.Run()
|
|
return
|
|
} // }}}
|
|
func (db *T) Connect() (err error) { // {{{
|
|
db.logger.Info("database", "host", db.cfg.Host, "port", db.cfg.Port, "name", db.cfg.Name)
|
|
|
|
dbConn := fmt.Sprintf(
|
|
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
|
db.cfg.Host,
|
|
db.cfg.Port,
|
|
db.cfg.Username,
|
|
db.cfg.Password,
|
|
db.cfg.Name,
|
|
)
|
|
|
|
if db.Conn, err = sqlx.Connect("postgres", dbConn); err != nil {
|
|
return
|
|
}
|
|
|
|
return
|
|
} // }}}
|
|
|
|
func (db *T) Authenticate(username, password string) (authenticated bool, userID int, err error) { // {{{
|
|
var rows *sql.Rows
|
|
if rows, err = db.Conn.Query(`
|
|
SELECT id
|
|
FROM _webservice.user
|
|
WHERE
|
|
LOWER(username) = LOWER($1) AND
|
|
password = _webservice.password_hash(SUBSTRING(password FROM 1 FOR 32), $2::bytea)
|
|
`,
|
|
username,
|
|
password,
|
|
); err != nil {
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
if rows.Next() {
|
|
rows.Scan(&userID)
|
|
authenticated = userID > 0
|
|
}
|
|
return
|
|
} // }}}
|
|
func (db *T) NewSession(uuid string) (err error) { // {{{
|
|
_, err = db.Conn.Exec("INSERT INTO _webservice.session(uuid) VALUES($1)", uuid)
|
|
return
|
|
} // }}}
|
|
func (db *T) RetrieveSession(uuid string) (sess *session.T, err error) { // {{{
|
|
var rows *sqlx.Rows
|
|
rows, err = db.Conn.Queryx(`
|
|
WITH session_data AS (
|
|
UPDATE _webservice.session
|
|
SET
|
|
last_used=NOW()
|
|
WHERE
|
|
uuid=$1
|
|
RETURNING
|
|
uuid, created, last_used, user_id, mfa
|
|
)
|
|
SELECT
|
|
sd.uuid, sd.created, sd.last_used,
|
|
COALESCE(u.username, '') AS username,
|
|
COALESCE(u.name, '') AS name,
|
|
COALESCE(u.id, 0) AS user_id,
|
|
mfa
|
|
FROM session_data sd
|
|
LEFT JOIN _webservice.user u ON sd.user_id = u.id
|
|
`,
|
|
uuid,
|
|
)
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
sess = new(session.T)
|
|
err = rows.StructScan(sess)
|
|
sess.Authenticated = sess.UserID > 0
|
|
}
|
|
return
|
|
} // }}}
|
|
func (db *T) SetSessionUser(uuid string, userID int) (err error) { // {{{
|
|
_, err = db.Conn.Exec(`
|
|
UPDATE _webservice.session
|
|
SET
|
|
user_id = CASE
|
|
WHEN $1 <= 0 THEN NULL
|
|
ELSE $1
|
|
END
|
|
WHERE uuid=$2
|
|
`,
|
|
userID,
|
|
uuid,
|
|
)
|
|
if err != nil {
|
|
return
|
|
}
|
|
return
|
|
} // }}}
|
|
func (db *T) SetSessionMFA(uuid string, mfa any) (err error) { // {{{
|
|
mfaByte, _ := json.Marshal(mfa)
|
|
_, err = db.Conn.Exec(`
|
|
UPDATE _webservice.session
|
|
SET
|
|
mfa = $2
|
|
WHERE
|
|
uuid = $1
|
|
`,
|
|
uuid,
|
|
mfaByte,
|
|
)
|
|
if err != nil {
|
|
return
|
|
}
|
|
return
|
|
} // }}}
|
|
func (db *T) UpdateUserTime(userID int) (err error) { // {{{
|
|
_, err = db.Conn.Exec(`UPDATE _webservice.user SET last_login=NOW() WHERE id=$1`, userID)
|
|
return
|
|
} // }}}
|
|
|
|
func (db *T) CreateUser(username, password, name string) (userID int64, err error) { // {{{
|
|
var row *sql.Row
|
|
row = db.Conn.QueryRow(`
|
|
INSERT INTO _webservice.user(username, password, name)
|
|
VALUES(
|
|
$1,
|
|
_webservice.password_hash(
|
|
/* salt in hex */
|
|
ENCODE(_webservice.gen_random_bytes(16), 'hex'),
|
|
|
|
/* password */
|
|
$2::bytea
|
|
),
|
|
$3
|
|
)
|
|
ON CONFLICT (username) DO UPDATE
|
|
SET username = EXCLUDED.username
|
|
RETURNING id
|
|
`,
|
|
username,
|
|
password,
|
|
name,
|
|
)
|
|
|
|
err = row.Scan(&userID)
|
|
return
|
|
} // }}}
|
|
func (db *T) ChangePassword(userID int, currentPassword, newPassword string) (changed bool, err error) { // {{{
|
|
var res sql.Result
|
|
res, err = db.Conn.Exec(`
|
|
UPDATE _webservice.user
|
|
SET
|
|
"password" = _webservice.password_hash(
|
|
/* salt in hex */
|
|
ENCODE(_webservice.gen_random_bytes(16), 'hex'),
|
|
|
|
/* password */
|
|
$3::bytea
|
|
)
|
|
WHERE
|
|
id = $1 AND
|
|
"password" = _webservice.password_hash(SUBSTRING(password FROM 1 FOR 32), $2::bytea)
|
|
|
|
`,
|
|
userID,
|
|
currentPassword,
|
|
newPassword,
|
|
)
|
|
|
|
var rowsAffected int64
|
|
rowsAffected, err = res.RowsAffected()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
changed = (rowsAffected == 1)
|
|
return
|
|
} // }}}
|
|
|
|
// vim: foldmethod=marker
|