Compare commits

...

21 commits
v0.2.5 ... main

Author SHA1 Message Date
58050a44a4 Fixed double pointer for ParseApplicationConfig 2025-03-06 07:51:26 +01:00
63f484e269 Honor the useStaticDirectory variable in the static handler to minimize unnecessary os.Stat 2025-03-06 07:01:18 +01:00
1cf9318bc8 Added html_template package for easier HTML page generation 2024-12-19 13:46:22 +01:00
ac7887b6c7 Add README.md 2024-09-17 07:28:04 +02:00
53efcaedc9 Execute websocket handlers as goroutines. Bumped to 0.2.17 2024-07-30 07:45:41 +02:00
256dda9a48 Removed validation on session since a lot of other software is depending on this package 2024-06-30 11:41:03 +02:00
6623db9574 Added caching headers for static content 2024-05-22 07:58:46 +02:00
cff9082aac Added application config parsing 2024-05-22 07:58:46 +02:00
452a109204 Read session UUID from cookie if header is missing. Bump to v0.2.14 2024-05-08 18:10:06 +02:00
aecafc2986 Bumped to v0.2.13 2024-05-08 17:52:15 +02:00
a62d2287ae Added X-Session-ID cookie 2024-05-08 17:51:41 +02:00
9ce19b94b4 Updated doc, exposed Config and bumped to v0.2.12 2024-04-27 11:54:33 +02:00
57c1cfb412 Bumped to v0.2.11 2024-04-22 09:19:38 +02:00
d650542b72 Authentication with lowercased username 2024-04-22 09:19:21 +02:00
548cddb773 Bumped version string to 0.2.10 2024-03-29 08:45:07 +01:00
ad1cb56a06 Merge branch 'main' of ssh://git.gibonuddevalla.se:2222/go/webservice 2024-02-20 08:05:42 +01:00
825bc4d8b2 Added a function to change password 2024-02-20 08:05:33 +01:00
f28b5188a6 Create user without confilct 2024-02-15 13:25:58 +01:00
77f15a197a Added mutex locking around modifications to websocket connections 2024-02-15 10:35:57 +01:00
92c8ac444f Added MFA to sessions and better websocket management 2024-02-14 14:34:36 +01:00
c848ae60b5 Return ID when creating a user 2024-02-13 13:51:01 +01:00
10 changed files with 383 additions and 52 deletions

10
README.md Normal file
View file

@ -0,0 +1,10 @@
# godoc
Installera verktyget godoc:
```
go install golang.org/x/tools/cmd/godoc@latest
```
Kör godoc i katalogen för webservice-repot.
Gå till http://localhost:6060/

View file

@ -5,7 +5,6 @@ import (
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
// Standard // Standard
"errors"
"os" "os"
) )
@ -33,6 +32,8 @@ type Config struct {
Session struct { Session struct {
DaysValid int DaysValid int
} }
Application any
} }
func New(filename string) (config Config, err error) { func New(filename string) (config Config, err error) {
@ -47,9 +48,10 @@ func New(filename string) (config Config, err error) {
return return
} }
if config.Session.DaysValid == 0 {
err = errors.New("Configuration: session.daysvalid needs to be higher than 0.")
}
return return
} }
func (config *Config) ParseApplicationConfig(v any) {
yStr, _ := yaml.Marshal(config.Application)
yaml.Unmarshal(yStr, v)
}

View file

@ -12,6 +12,7 @@ import (
// Standard // Standard
"database/sql" "database/sql"
"encoding/json"
"fmt" "fmt"
"log/slog" "log/slog"
) )
@ -93,6 +94,10 @@ func webserviceSQLProvider(dbname string, version int) ([]byte, bool) { // {{{
END; END;
$$; $$;
`, `,
2: `
ALTER TABLE "_webservice"."session" ADD mfa jsonb DEFAULT '{}' NOT NULL;
`,
} }
statement, found := sql[version] statement, found := sql[version]
@ -159,7 +164,7 @@ func (db *T) Authenticate(username, password string) (authenticated bool, userID
SELECT id SELECT id
FROM _webservice.user FROM _webservice.user
WHERE WHERE
username = $1 AND LOWER(username) = LOWER($1) AND
password = _webservice.password_hash(SUBSTRING(password FROM 1 FOR 32), $2::bytea) password = _webservice.password_hash(SUBSTRING(password FROM 1 FOR 32), $2::bytea)
`, `,
username, username,
@ -179,7 +184,7 @@ func (db *T) NewSession(uuid string) (err error) { // {{{
_, err = db.Conn.Exec("INSERT INTO _webservice.session(uuid) VALUES($1)", uuid) _, err = db.Conn.Exec("INSERT INTO _webservice.session(uuid) VALUES($1)", uuid)
return return
} // }}} } // }}}
func (db *T) RetrieveSession(uuid string) (sess *session.T, err error) {// {{{ func (db *T) RetrieveSession(uuid string) (sess *session.T, err error) { // {{{
var rows *sqlx.Rows var rows *sqlx.Rows
rows, err = db.Conn.Queryx(` rows, err = db.Conn.Queryx(`
WITH session_data AS ( WITH session_data AS (
@ -189,13 +194,14 @@ func (db *T) RetrieveSession(uuid string) (sess *session.T, err error) {// {{{
WHERE WHERE
uuid=$1 uuid=$1
RETURNING RETURNING
uuid, created, last_used, user_id uuid, created, last_used, user_id, mfa
) )
SELECT SELECT
sd.uuid, sd.created, sd.last_used, sd.uuid, sd.created, sd.last_used,
COALESCE(u.username, '') AS username, COALESCE(u.username, '') AS username,
COALESCE(u.name, '') AS name, COALESCE(u.name, '') AS name,
COALESCE(u.id, 0) AS user_id COALESCE(u.id, 0) AS user_id,
mfa
FROM session_data sd FROM session_data sd
LEFT JOIN _webservice.user u ON sd.user_id = u.id LEFT JOIN _webservice.user u ON sd.user_id = u.id
`, `,
@ -212,7 +218,7 @@ func (db *T) RetrieveSession(uuid string) (sess *session.T, err error) {// {{{
sess.Authenticated = sess.UserID > 0 sess.Authenticated = sess.UserID > 0
} }
return return
}// }}} } // }}}
func (db *T) SetSessionUser(uuid string, userID int) (err error) { // {{{ func (db *T) SetSessionUser(uuid string, userID int) (err error) { // {{{
_, err = db.Conn.Exec(` _, err = db.Conn.Exec(`
UPDATE _webservice.session UPDATE _webservice.session
@ -231,13 +237,31 @@ func (db *T) SetSessionUser(uuid string, userID int) (err error) { // {{{
} }
return return
} // }}} } // }}}
func (db *T) UpdateUserTime(userID int) (err error) {// {{{ func (db *T) SetSessionMFA(uuid string, mfa any) (err error) { // {{{
mfaByte, _ := json.Marshal(mfa)
_, err = db.Conn.Exec(`
UPDATE _webservice.session
SET
mfa = $2
WHERE
uuid = $1
`,
uuid,
mfaByte,
)
if err != nil {
return
}
return
} // }}}
func (db *T) UpdateUserTime(userID int) (err error) { // {{{
_, err = db.Conn.Exec(`UPDATE _webservice.user SET last_login=NOW() WHERE id=$1`, userID) _, err = db.Conn.Exec(`UPDATE _webservice.user SET last_login=NOW() WHERE id=$1`, userID)
return return
}// }}} } // }}}
func (db *T) CreateUser(username, password, name string) (err error) {// {{{ func (db *T) CreateUser(username, password, name string) (userID int64, err error) { // {{{
_, err = db.Conn.Exec(` var row *sql.Row
row = db.Conn.QueryRow(`
INSERT INTO _webservice.user(username, password, name) INSERT INTO _webservice.user(username, password, name)
VALUES( VALUES(
$1, $1,
@ -250,12 +274,48 @@ func (db *T) CreateUser(username, password, name string) (err error) {// {{{
), ),
$3 $3
) )
ON CONFLICT (username) DO UPDATE
SET username = EXCLUDED.username
RETURNING id
`, `,
username, username,
password, password,
name, name,
) )
err = row.Scan(&userID)
return return
}// }}} } // }}}
func (db *T) ChangePassword(userID int, currentPassword, newPassword string) (changed bool, err error) { // {{{
var res sql.Result
res, err = db.Conn.Exec(`
UPDATE _webservice.user
SET
"password" = _webservice.password_hash(
/* salt in hex */
ENCODE(_webservice.gen_random_bytes(16), 'hex'),
/* password */
$3::bytea
)
WHERE
id = $1 AND
"password" = _webservice.password_hash(SUBSTRING(password FROM 1 FOR 32), $2::bytea)
`,
userID,
currentPassword,
newPassword,
)
var rowsAffected int64
rowsAffected, err = res.RowsAffected()
if err != nil {
return
}
changed = (rowsAffected == 1)
return
} // }}}
// vim: foldmethod=marker // vim: foldmethod=marker

31
html_template/page.go Normal file
View file

@ -0,0 +1,31 @@
package HTMLTemplate
type Page interface {
GetVersion() string
GetLayout() string
GetPage() string
GetData() any
}
type SimplePage struct {
Version string
Layout string
Page string
Data any
}
func (s SimplePage) GetVersion() string {
return s.Version
}
func (s SimplePage) GetLayout() string {
return s.Layout
}
func (s SimplePage) GetPage() string {
return s.Page
}
func (s SimplePage) GetData() any {
return s.Data
}

158
html_template/pkg.go Normal file
View file

@ -0,0 +1,158 @@
package HTMLTemplate
import (
// External
werr "git.gibonuddevalla.se/go/wrappederror"
// Standard
"fmt"
"html/template"
"io/fs"
"net/http"
"os"
"regexp"
)
type Engine struct {
parsedTemplates map[string]*template.Template
viewFS fs.FS
staticEmbeddedFS http.Handler
staticLocalFS http.Handler
componentFilenames []string
DevMode bool
}
func NewEngine(viewFS, staticFS fs.FS, devmode bool) (e Engine, err error) { // {{{
e.parsedTemplates = make(map[string]*template.Template)
e.viewFS = viewFS
e.DevMode = devmode
e.componentFilenames, err = e.getComponentFilenames()
// Set up fileservers for static resources.
// The embedded FS is using the embedded files intented for production use.
// The local FS is for development of Javascript to avoid server rebuild (devmode).
var staticSubFS fs.FS
staticSubFS, err = fs.Sub(staticFS, "static")
if err != nil {
return
}
e.staticEmbeddedFS = http.FileServer(http.FS(staticSubFS))
e.staticLocalFS = http.FileServer(http.Dir("static"))
return
} // }}}
func (e *Engine) getComponentFilenames() (files []string, err error) { // {{{
files = []string{}
if err := fs.WalkDir(e.viewFS, "views/components", func(path string, d fs.DirEntry, err error) error {
if d == nil {
return nil
}
if d.IsDir() {
return nil
}
files = append(files, path)
return nil
}); err != nil {
return nil, err
}
return files, nil
} // }}}
func (e *Engine) ReloadTemplates() { // {{{
e.parsedTemplates = make(map[string]*template.Template)
} // }}}
func (e *Engine) StaticResource(w http.ResponseWriter, r *http.Request) { // {{{
var err error
// URLs with pattern /(css|images)/v1.0.0/foobar are stripped of the version.
// To get rid of problems with cached content in browser on a new version release,
// while also not disabling cache altogether.
if r.URL.Path == "/favicon.ico" {
e.staticEmbeddedFS.ServeHTTP(w, r)
return
}
rxp := regexp.MustCompile("^/(css|images|js|fonts)/v[0-9]+/(.*)$")
if comp := rxp.FindStringSubmatch(r.URL.Path); comp != nil {
w.Header().Add("Pragma", "public")
w.Header().Add("Cache-Control", "max-age=604800")
r.URL.Path = fmt.Sprintf("/%s/%s", comp[1], comp[2])
if e.DevMode {
p := fmt.Sprintf("static/%s/%s", comp[1], comp[2])
_, err = os.Stat(p)
if err == nil {
e.staticLocalFS.ServeHTTP(w, r)
}
return
}
}
e.staticEmbeddedFS.ServeHTTP(w, r)
} // }}}
func (e *Engine) getPage(layout, page string) (tmpl *template.Template, err error) { // {{{
layoutFilename := fmt.Sprintf("views/layouts/%s.gotmpl", layout)
pageFilename := fmt.Sprintf("views/pages/%s.gotmpl", page)
if tmpl, found := e.parsedTemplates[page]; found {
return tmpl, nil
}
funcMap := template.FuncMap{
/*
"format_time": func(t time.Time) template.HTML {
return template.HTML(
t.In(smonConfig.Timezone()).Format(`<span class="date">2006-01-02</span> <span class="time">15:04:05<span class="seconds">:05</span></span>`),
)
},
*/
}
filenames := []string{layoutFilename, pageFilename}
filenames = append(filenames, e.componentFilenames...)
if e.DevMode {
tmpl, err = template.New(layout+".gotmpl").Funcs(funcMap).ParseFS(os.DirFS("."), filenames...)
} else {
tmpl, err = template.New(layout+".gotmpl").Funcs(funcMap).ParseFS(e.viewFS, filenames...)
}
if err != nil {
err = werr.Wrap(err).Log()
return
}
e.parsedTemplates[page] = tmpl
return
} // }}}
func (e *Engine) Render(p Page, w http.ResponseWriter, r *http.Request) (err error) { // {{{
if e.DevMode {
e.ReloadTemplates()
}
var tmpl *template.Template
tmpl, err = e.getPage(p.GetLayout(), p.GetPage())
if err != nil {
err = werr.Wrap(err)
return
}
data := map[string]any{
"VERSION": p.GetVersion(),
"LAYOUT": p.GetLayout(),
"PAGE": p.GetPage(),
"ERROR": r.URL.Query().Get("_err"),
"Data": p.GetData(),
}
err = tmpl.Execute(w, data)
if err != nil {
err = werr.Wrap(err)
}
return
} // }}}
// vim: foldmethod=marker

81
pkg.go
View file

@ -1,6 +1,10 @@
/* /*
The webservice package is used to provide a webservice with sessions: The webservice package is used to provide a webservice with sessions:
const VERSION = "v1"
var logger *slog.Logger
func sqlProvider(dbname string, version int) (sql []byte, found bool) { func sqlProvider(dbname string, version int) (sql []byte, found bool) {
var err error var err error
sql, err = embeddedSQL.ReadFile(fmt.Sprintf("sql/%05d.sql", version)) sql, err = embeddedSQL.ReadFile(fmt.Sprintf("sql/%05d.sql", version))
@ -11,21 +15,29 @@ The webservice package is used to provide a webservice with sessions:
return return
} }
service, err := webservice.New("/etc/some/webservice.yaml")
if err != nil { func init() {
logger.Error("application", "error", err) opts := slog.HandlerOptions{}
os.Exit(1) logger = slog.New(slog.NewJSONHandler(os.Stdout, &opts))
} }
service.SetDatabase(sqlProvider) func main() {
service.SetAuthenticationHandler(authenticate) service, err := webservice.New("/etc/some/webservice.yaml", VERSION, logger)
service.SetAuthorizationHandler(authorize) if err != nil {
service.Register("/foo", true, true, foo) logger.Error("application", "error", err)
service.Register("/bar", true, false, bar) os.Exit(1)
err = service.Start() }
if err != nil {
logger.Error("webserver", "error", err) service.SetDatabase(sqlProvider)
os.Exit(1) service.SetAuthenticationHandler(authenticate)
service.SetAuthorizationHandler(authorize)
service.Register("/foo", true, true, foo)
service.Register("/bar", true, false, bar)
err = service.Start()
if err != nil {
logger.Error("webserver", "error", err)
os.Exit(1)
}
} }
*/ */
package webservice package webservice
@ -52,7 +64,7 @@ import (
"strings" "strings"
) )
const VERSION = "v0.1.0" const VERSION = "v0.2.17"
type HttpHandler func(http.ResponseWriter, *http.Request) type HttpHandler func(http.ResponseWriter, *http.Request)
@ -66,7 +78,7 @@ type ServiceError struct {
type Service struct { type Service struct {
logger *slog.Logger logger *slog.Logger
sessions map[string]*session.T sessions map[string]*session.T
config config.Config Config config.Config
Db *database.T Db *database.T
Version string Version string
WsConnectionManager ws_conn_manager.ConnectionManager WsConnectionManager ws_conn_manager.ConnectionManager
@ -87,11 +99,11 @@ type ServiceHandler func(http.ResponseWriter, *http.Request, *session.T)
func New(configFilename, version string, logger *slog.Logger) (service *Service, err error) { // {{{ func New(configFilename, version string, logger *slog.Logger) (service *Service, err error) { // {{{
service = new(Service) service = new(Service)
service.config, err = config.New(configFilename) service.Config, err = config.New(configFilename)
if err != nil { if err != nil {
return return
} }
logger.Debug("config", "config", service.config) logger.Debug("config", "config", service.Config)
service.Version = version service.Version = version
service.logger = logger service.logger = logger
@ -99,7 +111,7 @@ func New(configFilename, version string, logger *slog.Logger) (service *Service,
service.errorHandler = service.defaultErrorHandler service.errorHandler = service.defaultErrorHandler
service.authenticationHandler = service.defaultAuthenticationHandler service.authenticationHandler = service.defaultAuthenticationHandler
service.authorizationHandler = service.defaultAuthorizationHandler service.authorizationHandler = service.defaultAuthorizationHandler
service.WsConnectionManager = ws_conn_manager.NewConnectionManager(service.logger, service.config.Websocket.Domains) service.WsConnectionManager = ws_conn_manager.NewConnectionManager(service.logger, service.Config.Websocket.Domains)
service.Register("/_session/new", false, false, service.sessionNew) service.Register("/_session/new", false, false, service.sessionNew)
service.Register("/_session/authenticate", true, false, service.sessionAuthenticate) service.Register("/_session/authenticate", true, false, service.sessionAuthenticate)
@ -113,7 +125,7 @@ func New(configFilename, version string, logger *slog.Logger) (service *Service,
return return
} // }}} } // }}}
func (service *Service) defaultAuthenticationHandler(req AuthenticationRequest, alreadyAuthenticated bool) (resp AuthenticationResponse, err error) { // {{{ func (service *Service) defaultAuthenticationHandler(req AuthenticationRequest, sess *session.T, alreadyAuthenticated bool) (resp AuthenticationResponse, err error) { // {{{
resp.Authenticated = alreadyAuthenticated resp.Authenticated = alreadyAuthenticated
service.logger.Info("webservice", "op", "authentication", "username", req.Username, "authenticated", resp.Authenticated) service.logger.Info("webservice", "op", "authentication", "username", req.Username, "authenticated", resp.Authenticated)
return return
@ -160,7 +172,7 @@ func (service *Service) SetStaticDirectory(directory string, useDirectory bool)
} // }}} } // }}}
func (service *Service) SetDatabase(sqlProv database.SqlProvider) { // {{{ func (service *Service) SetDatabase(sqlProv database.SqlProvider) { // {{{
service.Db = database.New(service.config.Database) service.Db = database.New(service.Config.Database)
service.Db.SetLogger(service.logger) service.Db.SetLogger(service.logger)
service.Db.SetSQLProvider(sqlProv) service.Db.SetSQLProvider(sqlProv)
return return
@ -183,7 +195,7 @@ func (service *Service) Register(path string, requireSession, requireAuthenticat
return return
} }
sess, found = service.retrieveSession(headerSessionUUID) sess, found = service.RetrieveSession(headerSessionUUID)
if !found { if !found {
service.errorHandler(fmt.Errorf("Session '%s' not found", headerSessionUUID), "001-0001", w) service.errorHandler(fmt.Errorf("Session '%s' not found", headerSessionUUID), "001-0001", w)
return return
@ -223,7 +235,7 @@ func (service *Service) InitDatabaseConnection() (err error) { // {{{
} }
return return
} // }}} } // }}}
func (service *Service) CreateUser(username, password, name string) (err error) { // {{{ func (service *Service) CreateUser(username, password, name string) (userID int64, err error) { // {{{
if service.Db != nil { if service.Db != nil {
err = service.InitDatabaseConnection() err = service.InitDatabaseConnection()
if err != nil { if err != nil {
@ -231,7 +243,7 @@ func (service *Service) CreateUser(username, password, name string) (err error)
} }
} }
err = service.Db.CreateUser(username, password, name) userID, err = service.Db.CreateUser(username, password, name)
return return
} // }}} } // }}}
func (service *Service) CreateUserPrompt() { // {{{ func (service *Service) CreateUserPrompt() { // {{{
@ -251,7 +263,7 @@ func (service *Service) CreateUserPrompt() { // {{{
password, _ = reader.ReadString('\n') password, _ = reader.ReadString('\n')
password = strings.TrimSpace(password) password = strings.TrimSpace(password)
err = service.CreateUser(username, password, name) _, err = service.CreateUser(username, password, name)
if err != nil { if err != nil {
service.logger.Error("application", "error", err) service.logger.Error("application", "error", err)
os.Exit(1) os.Exit(1)
@ -267,7 +279,7 @@ func (service *Service) Start() (err error) { // {{{
go service.WsConnectionManager.BroadcastLoop() go service.WsConnectionManager.BroadcastLoop()
listen := fmt.Sprintf("%s:%d", service.config.Network.Address, service.config.Network.Port) listen := fmt.Sprintf("%s:%d", service.Config.Network.Address, service.Config.Network.Port)
service.logger.Info("webserver", "listen", listen) service.logger.Info("webserver", "listen", listen)
err = http.ListenAndServe(listen, nil) err = http.ListenAndServe(listen, nil)
return return
@ -288,11 +300,19 @@ func (service *Service) StaticHandler(w http.ResponseWriter, r *http.Request, se
rxp := regexp.MustCompile("^/(css|images|js|fonts)/v[0-9]+/(.*)$") rxp := regexp.MustCompile("^/(css|images|js|fonts)/v[0-9]+/(.*)$")
if comp := rxp.FindStringSubmatch(r.URL.Path); comp != nil { if comp := rxp.FindStringSubmatch(r.URL.Path); comp != nil {
w.Header().Add("Pragma", "public")
w.Header().Add("Cache-Control", "max-age=604800")
r.URL.Path = fmt.Sprintf("/%s/%s", comp[1], comp[2]) r.URL.Path = fmt.Sprintf("/%s/%s", comp[1], comp[2])
p := fmt.Sprintf(service.staticDirectory+"/%s/%s", comp[1], comp[2]) p := fmt.Sprintf(service.staticDirectory+"/%s/%s", comp[1], comp[2])
_, err = os.Stat(p)
if err == nil { if service.useStaticDirectory {
service.staticLocalFileserver.ServeHTTP(w, r) _, err = os.Stat(p)
if err == nil {
service.staticLocalFileserver.ServeHTTP(w, r)
} else {
service.staticEmbeddedFileserver.ServeHTTP(w, r)
}
} else { } else {
service.staticEmbeddedFileserver.ServeHTTP(w, r) service.staticEmbeddedFileserver.ServeHTTP(w, r)
} }
@ -364,6 +384,11 @@ func sessionUUID(r *http.Request) (string, error) { // {{{
headers := r.Header["X-Session-Id"] headers := r.Header["X-Session-Id"]
if len(headers) > 0 { if len(headers) > 0 {
return headers[0], nil return headers[0], nil
} else {
cookie, err := r.Cookie("X-Session-ID")
if err == nil && cookie.Value != "" {
return cookie.Value, nil
}
} }
return "", errors.New("Invalid session") return "", errors.New("Invalid session")
} // }}} } // }}}

View file

@ -25,9 +25,10 @@ type AuthenticationResponse struct {
OK bool OK bool
Authenticated bool Authenticated bool
UserID int UserID int
MFA any
} }
type AuthenticationHandler func(AuthenticationRequest, bool) (AuthenticationResponse, error) type AuthenticationHandler func(AuthenticationRequest, *session.T, bool) (AuthenticationResponse, error)
type AuthorizationHandler func(*session.T, *http.Request) (bool, error) type AuthorizationHandler func(*session.T, *http.Request) (bool, error)
func (service *Service) sessionNew(w http.ResponseWriter, r *http.Request, foo *session.T) { // {{{ func (service *Service) sessionNew(w http.ResponseWriter, r *http.Request, foo *session.T) { // {{{
@ -48,7 +49,7 @@ func (service *Service) sessionNew(w http.ResponseWriter, r *http.Request, foo *
break break
} else { } else {
if _, found = service.retrieveSession(sess.UUID); found { if _, found = service.RetrieveSession(sess.UUID); found {
continue continue
} }
@ -74,6 +75,13 @@ func (service *Service) sessionNew(w http.ResponseWriter, r *http.Request, foo *
}, },
) )
cookie := http.Cookie{}
cookie.Name = "X-Session-ID"
cookie.Value = sess.UUID
cookie.MaxAge = 86400 * 365
cookie.Path = "/"
http.SetCookie(w, &cookie)
w.Write(respJSON) w.Write(respJSON)
} // }}} } // }}}
func (service *Service) sessionAuthenticate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ func (service *Service) sessionAuthenticate(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
@ -91,8 +99,8 @@ func (service *Service) sessionAuthenticate(w http.ResponseWriter, r *http.Reque
} }
// Authenticate against webservice user table if using a database. // Authenticate against webservice user table if using a database.
var userID int var userID int = sess.UserID
if service.Db != nil { if service.Db != nil && userID == 0 {
authenticatedByFramework, userID, err = service.Db.Authenticate(authRequest.Username, authRequest.Password) authenticatedByFramework, userID, err = service.Db.Authenticate(authRequest.Username, authRequest.Password)
if err != nil { if err != nil {
service.errorHandler(err, "001-A002", w) service.errorHandler(err, "001-A002", w)
@ -103,15 +111,24 @@ func (service *Service) sessionAuthenticate(w http.ResponseWriter, r *http.Reque
// The authentication handler is provided with the authenticated response of the possible database authentication, // The authentication handler is provided with the authenticated response of the possible database authentication,
// and given a chance to override it. // and given a chance to override it.
authResponse, err = service.authenticationHandler(authRequest, authenticatedByFramework) authResponse, err = service.authenticationHandler(authRequest, sess, authenticatedByFramework)
if err != nil { if err != nil {
service.errorHandler(err, "001-F002", w) service.errorHandler(err, "001-F002", w)
return return
} }
authResponse.UserID = userID authResponse.UserID = userID
authResponse.OK = true authResponse.OK = true
sess.Authenticated = authResponse.Authenticated sess.Authenticated = authResponse.Authenticated
if authResponse.MFA != nil {
err = service.Db.SetSessionMFA(sess.UUID, authResponse.MFA)
if err != nil {
service.errorHandler(err, "001-A003", w)
return
}
}
if authResponse.Authenticated && userID > 0 { if authResponse.Authenticated && userID > 0 {
err = service.Db.SetSessionUser(sess.UUID, userID) err = service.Db.SetSessionUser(sess.UUID, userID)
if err != nil { if err != nil {
@ -136,7 +153,7 @@ func (service *Service) sessionRetrieve(w http.ResponseWriter, r *http.Request,
w.Write(out) w.Write(out)
} // }}} } // }}}
func (service *Service) retrieveSession(uuid string) (session *session.T, found bool) { // {{{ func (service *Service) RetrieveSession(uuid string) (session *session.T, found bool) { // {{{
var err error var err error
if service.Db == nil { if service.Db == nil {

View file

@ -10,6 +10,7 @@ type T struct {
Created time.Time Created time.Time
LastUsed time.Time `db:"last_used"` LastUsed time.Time `db:"last_used"`
Authenticated bool Authenticated bool
MFA any
UserID int `db:"user_id"` UserID int `db:"user_id"`
Username string Username string

View file

@ -17,7 +17,7 @@ export class Websocket {
start() {//{{{ start() {//{{{
this.connect() this.connect()
//this.loop() this.loop()
}//}}} }//}}}
loop() {//{{{ loop() {//{{{
setInterval(() => { setInterval(() => {
@ -27,6 +27,9 @@ export class Websocket {
} }
}, 1000) }, 1000)
}//}}} }//}}}
send(data) {//{{{
this.websocket.send(data)
}//}}}
connect() {//{{{ connect() {//{{{
const protocol = location.protocol; const protocol = location.protocol;

View file

@ -8,15 +8,19 @@ import (
// Standard // Standard
"log/slog" "log/slog"
"net/http" "net/http"
"strings"
"slices" "slices"
"strings"
"sync"
) )
type ReadHandler func(*ConnectionManager, *WsConnection, []byte)
type WsConnection struct { type WsConnection struct {
ConnectionManager *ConnectionManager ConnectionManager *ConnectionManager
UUID string UUID string
Conn *websocket.Conn Conn *websocket.Conn
Pruned bool Pruned bool
SessionUUID string
} }
type ConnectionManager struct { type ConnectionManager struct {
connections map[string]*WsConnection connections map[string]*WsConnection
@ -24,6 +28,8 @@ type ConnectionManager struct {
sendQueue chan SendRequest sendQueue chan SendRequest
logger *slog.Logger logger *slog.Logger
domains []string domains []string
readHandlers []ReadHandler
connSync sync.Mutex
} }
type SendRequest struct { type SendRequest struct {
WsConn *WsConnection WsConn *WsConnection
@ -63,7 +69,9 @@ func (cm *ConnectionManager) NewConnection(w http.ResponseWriter, r *http.Reques
} }
// Keep track of all connections. // Keep track of all connections.
cm.connSync.Lock()
cm.connections[wsConn.UUID] = &wsConn cm.connections[wsConn.UUID] = &wsConn
cm.connSync.Unlock()
// Successfully upgraded to a websocket connection. // Successfully upgraded to a websocket connection.
cm.logger.Info("websocket", "uuid", wsConn.UUID, "remote_addr", r.RemoteAddr) cm.logger.Info("websocket", "uuid", wsConn.UUID, "remote_addr", r.RemoteAddr)
@ -72,6 +80,9 @@ func (cm *ConnectionManager) NewConnection(w http.ResponseWriter, r *http.Reques
return &wsConn, nil return &wsConn, nil
} // }}} } // }}}
func (cm *ConnectionManager) AddMsgHandler(handler ReadHandler) { // {{{
cm.readHandlers = append(cm.readHandlers, handler)
} // }}}
// validateOrigin matches host from X-Forwarded-Host or request host against a list of configured domains. // validateOrigin matches host from X-Forwarded-Host or request host against a list of configured domains.
func (cm *ConnectionManager) validateOrigin(r *http.Request) bool { // {{{ func (cm *ConnectionManager) validateOrigin(r *http.Request) bool { // {{{
@ -90,7 +101,9 @@ func (cm *ConnectionManager) Prune(wsConn *WsConnection, err error) { // {{{
cm.logger.Info("websocket", "op", "prune", "uuid", wsConn.UUID) cm.logger.Info("websocket", "op", "prune", "uuid", wsConn.UUID)
wsConn.Conn.Close() wsConn.Conn.Close()
wsConn.Pruned = true wsConn.Pruned = true
cm.connSync.Lock()
delete(cm.connections, wsConn.UUID) delete(cm.connections, wsConn.UUID)
cm.connSync.Unlock()
} // }}} } // }}}
func (cm *ConnectionManager) ReadLoop(wsConn *WsConnection) { // {{{ func (cm *ConnectionManager) ReadLoop(wsConn *WsConnection) { // {{{
var data []byte var data []byte
@ -102,7 +115,10 @@ func (cm *ConnectionManager) ReadLoop(wsConn *WsConnection) { // {{{
} }
cm.logger.Debug("websocket", "op", "read", "data", data) cm.logger.Debug("websocket", "op", "read", "data", data)
//cm.Send(wsConn, response)
for _, handler := range cm.readHandlers {
go handler(cm, wsConn, data)
}
} }
} // }}} } // }}}
func (cm *ConnectionManager) Read(wsConn *WsConnection) ([]byte, bool) { // {{{ func (cm *ConnectionManager) Read(wsConn *WsConnection) ([]byte, bool) { // {{{
@ -118,6 +134,14 @@ func (cm *ConnectionManager) Read(wsConn *WsConnection) ([]byte, bool) { // {{{
func (cm *ConnectionManager) Send(wsConn *WsConnection, msg interface{}) { // {{{ func (cm *ConnectionManager) Send(wsConn *WsConnection, msg interface{}) { // {{{
wsConn.Conn.WriteJSON(msg) wsConn.Conn.WriteJSON(msg)
} // }}} } // }}}
func (cm *ConnectionManager) SendToSessionUUID(uuid string, msg interface{}) { // {{{
for _, wsConn := range cm.connections {
if wsConn.SessionUUID == uuid {
wsConn.Conn.WriteJSON(msg)
break
}
}
} // }}}
func (cm *ConnectionManager) Broadcast(msg interface{}) { // {{{ func (cm *ConnectionManager) Broadcast(msg interface{}) { // {{{
cm.broadcastQueue <- msg cm.broadcastQueue <- msg
} // }}} } // }}}