webservice/database/pkg.go

322 lines
6.8 KiB
Go
Raw Normal View History

2024-01-04 20:19:47 +01:00
package database
import (
// External
"git.gibonuddevalla.se/go/dbschema"
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
// Internal
"git.gibonuddevalla.se/go/webservice/config"
"git.gibonuddevalla.se/go/webservice/session"
// Standard
"database/sql"
"encoding/json"
2024-01-04 20:19:47 +01:00
"fmt"
"log/slog"
)
type (
	// SqlProvider returns the SQL script for a given schema version and
	// whether that version exists. The string argument is named dbname in
	// webserviceSQLProvider; presumably the schema upgrader passes the
	// database name there — TODO confirm against dbschema.
	SqlProvider func(string, int) ([]byte, bool)

	// T holds the database configuration, the open connection and the
	// callbacks used when upgrading the "_webservice" and "_db" schemas.
	T struct {
		cfg         config.DatabaseDetails // connection settings used by Connect and Upgrade
		Conn        *sqlx.DB               // set by Connect
		logger      *slog.Logger           // set with SetLogger; used by Connect and defaultLogProvider
		sqlProvider SqlProvider            // application migrations for the "_db" schema
		logProvider func(string, string)   // log callback handed to the schema upgrader
	}
)
// New returns a database handle configured with cfg. It does not connect;
// call Connect (and optionally Upgrade) afterwards.
//
// The logger defaults to slog.Default() so that Connect and the default log
// provider do not dereference a nil *slog.Logger when SetLogger has not been
// called. Schema-upgrade log output goes through defaultLogProvider until
// SetLogProvider replaces it.
func New(cfg config.DatabaseDetails) (db *T) { // {{{
	db = &T{
		cfg:    cfg,
		logger: slog.Default(),
	}
	db.logProvider = db.defaultLogProvider
	return db
} // }}}
// SetLogger replaces the logger used by Connect and by the default log
// provider.
func (db *T) SetLogger(logger *slog.Logger) { // {{{
	db.logger = logger
} // }}}
// SetSQLProvider registers the callback that supplies the application's
// migrations for the "_db" schema. Upgrade hands it to the schema upgrader.
func (db *T) SetSQLProvider(provider func(string, int) ([]byte, bool)) { // {{{
	db.sqlProvider = provider
} // }}}
// SetLogProvider replaces the log callback that Upgrade hands to the schema
// upgrader. New installs defaultLogProvider. Call this before Upgrade for it
// to take effect.
func (db *T) SetLogProvider(provider func(string, string)) { // {{{
	db.logProvider = provider
} // }}}
// defaultLogProvider is the log callback New installs for the schema
// upgrader. It logs at Info level with the message "database", using category
// as the attribute key and msg as its value.
//
// If no logger has been set, it falls back to slog.Default() instead of
// panicking on a nil *slog.Logger.
func (db *T) defaultLogProvider(category, msg string) { // {{{
	logger := db.logger
	if logger == nil {
		logger = slog.Default()
	}
	logger.Info("database", category, msg)
} // }}}
// webserviceSQLProvider supplies the migrations for the library's own
// "_webservice" schema, one SQL script per version:
//
//  1. user and session tables, the pgcrypto extension, and the
//     password_hash(salt_hex, pass) function. That function returns the
//     32-char hex salt followed by the hex SHA-256 of (salt bytes || pass).
//  2. an mfa jsonb column (default '{}') on the session table.
//
// It returns the script and true for a known version. For any other version
// it returns an empty slice and false. dbname is not used.
func webserviceSQLProvider(dbname string, version int) ([]byte, bool) { // {{{
	var statement string
	switch version {
	case 1:
		statement = `
CREATE TABLE _webservice.user (
id serial NOT NULL,
"name" varchar NOT NULL,
"username" varchar NOT NULL,
"password" char(96) NOT NULL,
last_login timestamp with time zone NOT NULL DEFAULT '1970-01-01 00:00:00',
CONSTRAINT user_pk PRIMARY KEY (id),
CONSTRAINT user_un UNIQUE (username)
);
CREATE TABLE "_webservice"."session" (
id serial NOT NULL,
user_id int4 NULL,
"uuid" char(36) NOT NULL,
created timestamp with time zone NOT NULL DEFAULT NOW(),
last_used timestamp with time zone NOT NULL DEFAULT NOW(),
CONSTRAINT session_pk PRIMARY KEY (id),
CONSTRAINT session_un UNIQUE ("uuid"),
CONSTRAINT session_user_fk FOREIGN KEY (user_id) REFERENCES "_webservice"."user"(id) ON DELETE CASCADE ON UPDATE CASCADE
);
CREATE EXTENSION IF NOT EXISTS pgcrypto SCHEMA _webservice;
CREATE FUNCTION _webservice.password_hash(salt_hex char(32), pass bytea)
RETURNS char(96)
LANGUAGE plpgsql
AS
$$
BEGIN
RETURN (
SELECT
salt_hex ||
encode(
sha256(
decode(salt_hex, 'hex') || /* salt in binary */
pass /* password */
),
'hex'
)
);
END;
$$;
`
	case 2:
		statement = `
ALTER TABLE "_webservice"."session" ADD mfa jsonb DEFAULT '{}' NOT NULL;
`
	default:
		// Unknown version: nothing to apply.
		return []byte(""), false
	}
	return []byte(statement), true
} // }}}
// Upgrade brings both managed schemas up to date using dbschema upgraders.
// It first upgrades the library's own "_webservice" schema (users and
// sessions, from webserviceSQLProvider), then the application's "_db" schema
// using the provider registered with SetSQLProvider. It stops at the first
// failure and returns the error wrapped with the schema name (errors.Is and
// errors.As still see the original error).
//
// NOTE(review): if SetSQLProvider was never called, a nil callback is handed
// to the "_db" upgrader; whether dbschema tolerates that is not visible here.
func (db *T) Upgrade() (err error) { // {{{
	if err = db.upgradeSchema("_webservice", webserviceSQLProvider); err != nil {
		return err
	}
	return db.upgradeSchema("_db", db.sqlProvider)
} // }}}

// upgradeSchema runs one dbschema upgrader for schema against the configured
// database, with provider supplying the SQL and db.logProvider receiving log
// output.
func (db *T) upgradeSchema(schema string, provider SqlProvider) error {
	upgrader := dbschema.NewUpgrader(schema)
	upgrader.SetSqlCallback(provider)
	upgrader.SetLogCallback(db.logProvider)

	if err := upgrader.AddDatabase(
		db.cfg.Host,
		db.cfg.Port,
		db.cfg.Name,
		db.cfg.Username,
		db.cfg.Password,
	); err != nil {
		return fmt.Errorf("adding database for schema %q: %w", schema, err)
	}
	if err := upgrader.Run(); err != nil {
		return fmt.Errorf("upgrading schema %q: %w", schema, err)
	}
	return nil
}
func (db *T) Connect() (err error) { // {{{
db.logger.Info("database", "host", db.cfg.Host, "port", db.cfg.Port, "name", db.cfg.Name)
dbConn := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
db.cfg.Host,
db.cfg.Port,
db.cfg.Username,
db.cfg.Password,
db.cfg.Name,
)
if db.Conn, err = sqlx.Connect("postgres", dbConn); err != nil {
return
}
return
} // }}}
2024-01-05 09:00:09 +01:00
2024-01-04 20:19:47 +01:00
// Authenticate checks username (compared case-insensitively) and password
// against _webservice.user. The stored password is a 32-char hex salt
// followed by the hash. The query re-hashes the supplied password with that
// salt via _webservice.password_hash and compares the results.
//
// On a match it returns authenticated=true and the user's id. When no user
// matches it returns authenticated=false with err=nil. A query, scan or
// iteration error is returned in err with authenticated=false and userID=0.
//
// NOTE(review): the UNIQUE constraint on username is case-sensitive while this
// lookup is not, so usernames differing only in case could both match; only
// the first returned row is used.
// NOTE(review): $2::bytea casts the text password using PostgreSQL's bytea
// input syntax. Confirm that passwords containing backslashes hash as
// intended (CreateUser and ChangePassword use the same cast).
func (db *T) Authenticate(username, password string) (authenticated bool, userID int, err error) { // {{{
	var rows *sql.Rows
	if rows, err = db.Conn.Query(`
SELECT id
FROM _webservice.user
WHERE
LOWER(username) = LOWER($1) AND
password = _webservice.password_hash(SUBSTRING(password FROM 1 FOR 32), $2::bytea)
`,
		username,
		password,
	); err != nil {
		return false, 0, err
	}
	defer rows.Close()

	if !rows.Next() {
		// No matching user/password. rows.Err separates "no row" from a
		// driver or network failure during iteration.
		return false, 0, rows.Err()
	}
	if err = rows.Scan(&userID); err != nil {
		return false, 0, err
	}
	return userID > 0, userID, nil
} // }}}
// NewSession inserts a session row identified by uuid. No user is attached
// (user_id stays NULL), and the other columns take their table defaults.
func (db *T) NewSession(uuid string) (err error) { // {{{
	const insertSession = "INSERT INTO _webservice.session(uuid) VALUES($1)"
	if _, err = db.Conn.Exec(insertSession, uuid); err != nil {
		return err
	}
	return nil
} // }}}
func (db *T) RetrieveSession(uuid string) (sess *session.T, err error) { // {{{
2024-01-04 20:19:47 +01:00
var rows *sqlx.Rows
rows, err = db.Conn.Queryx(`
WITH session_data AS (
UPDATE _webservice.session
SET
last_used=NOW()
WHERE
uuid=$1
RETURNING
uuid, created, last_used, user_id, mfa
2024-01-04 20:19:47 +01:00
)
SELECT
sd.uuid, sd.created, sd.last_used,
COALESCE(u.username, '') AS username,
2024-01-05 09:00:09 +01:00
COALESCE(u.name, '') AS name,
COALESCE(u.id, 0) AS user_id,
mfa
2024-01-04 20:19:47 +01:00
FROM session_data sd
LEFT JOIN _webservice.user u ON sd.user_id = u.id
`,
uuid,
)
if err != nil {
return
}
defer rows.Close()
for rows.Next() {
sess = new(session.T)
err = rows.StructScan(sess)
2024-01-05 09:00:09 +01:00
sess.Authenticated = sess.UserID > 0
2024-01-04 20:19:47 +01:00
}
return
} // }}}
2024-01-04 20:19:47 +01:00
func (db *T) SetSessionUser(uuid string, userID int) (err error) { // {{{
_, err = db.Conn.Exec(`
UPDATE _webservice.session
SET
user_id = CASE
WHEN $1 <= 0 THEN NULL
ELSE $1
END
WHERE uuid=$2
`,
userID,
uuid,
)
2024-01-04 20:19:47 +01:00
if err != nil {
return
}
return
} // }}}
// SetSessionMFA JSON-encodes mfa and stores it in the jsonb mfa column of the
// session identified by uuid.
//
// If mfa cannot be JSON-encoded, the encoding error is returned. The
// previous version ignored it and tried to write a nil value into the
// NOT NULL mfa column.
func (db *T) SetSessionMFA(uuid string, mfa any) (err error) { // {{{
	var mfaJSON []byte
	if mfaJSON, err = json.Marshal(mfa); err != nil {
		return fmt.Errorf("encoding session MFA data: %w", err)
	}
	_, err = db.Conn.Exec(`
UPDATE _webservice.session
SET
mfa = $2
WHERE
uuid = $1
`,
		uuid,
		mfaJSON,
	)
	return err
} // }}}
func (db *T) UpdateUserTime(userID int) (err error) { // {{{
2024-02-03 10:53:12 +01:00
_, err = db.Conn.Exec(`UPDATE _webservice.user SET last_login=NOW() WHERE id=$1`, userID)
return
} // }}}
2024-01-05 09:00:09 +01:00
func (db *T) CreateUser(username, password, name string) (userID int64, err error) { // {{{
2024-02-13 13:51:01 +01:00
var row *sql.Row
row = db.Conn.QueryRow(`
2024-01-04 20:19:47 +01:00
INSERT INTO _webservice.user(username, password, name)
VALUES(
$1,
_webservice.password_hash(
/* salt in hex */
ENCODE(_webservice.gen_random_bytes(16), 'hex'),
/* password */
$2::bytea
),
$3
)
2024-02-15 13:25:58 +01:00
ON CONFLICT (username) DO UPDATE
SET username = EXCLUDED.username
2024-02-13 13:51:01 +01:00
RETURNING id
2024-01-04 20:19:47 +01:00
`,
username,
password,
name,
)
2024-02-13 13:51:01 +01:00
err = row.Scan(&userID)
2024-01-04 20:19:47 +01:00
return
} // }}}
2024-02-20 08:05:33 +01:00
func (db *T) ChangePassword(userID int, currentPassword, newPassword string) (changed bool, err error) { // {{{
var res sql.Result
res, err = db.Conn.Exec(`
UPDATE _webservice.user
SET
"password" = _webservice.password_hash(
/* salt in hex */
ENCODE(_webservice.gen_random_bytes(16), 'hex'),
/* password */
$3::bytea
)
WHERE
id = $1 AND
"password" = _webservice.password_hash(SUBSTRING(password FROM 1 FOR 32), $2::bytea)
`,
userID,
currentPassword,
newPassword,
)
var rowsAffected int64
rowsAffected, err = res.RowsAffected()
if err != nil {
return
}
changed = (rowsAffected == 1)
return
} // }}}
2024-01-04 20:19:47 +01:00
// vim: foldmethod=marker