From 92c8ac444ffad7a6ffc8b270e71c0227009c4c25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Magnus=20=C3=85hall?= Date: Wed, 14 Feb 2024 14:34:36 +0100 Subject: [PATCH] Added MFA to sessions and better websocket management --- database/pkg.go | 39 +++++++++++++++++++++++++++++++-------- pkg.go | 4 ++-- session.go | 22 ++++++++++++++++------ session/pkg.go | 1 + static/js/websocket.mjs | 5 ++++- ws_conn_manager/pkg.go | 22 ++++++++++++++++++++-- 6 files changed, 74 insertions(+), 19 deletions(-) diff --git a/database/pkg.go b/database/pkg.go index 6cab492..efa3cc8 100644 --- a/database/pkg.go +++ b/database/pkg.go @@ -12,6 +12,7 @@ import ( // Standard "database/sql" + "encoding/json" "fmt" "log/slog" ) @@ -93,6 +94,10 @@ func webserviceSQLProvider(dbname string, version int) ([]byte, bool) { // {{{ END; $$; `, + + 2: ` + ALTER TABLE "_webservice"."session" ADD mfa jsonb DEFAULT '{}' NOT NULL; + `, } statement, found := sql[version] @@ -179,7 +184,7 @@ func (db *T) NewSession(uuid string) (err error) { // {{{ _, err = db.Conn.Exec("INSERT INTO _webservice.session(uuid) VALUES($1)", uuid) return } // }}} -func (db *T) RetrieveSession(uuid string) (sess *session.T, err error) {// {{{ +func (db *T) RetrieveSession(uuid string) (sess *session.T, err error) { // {{{ var rows *sqlx.Rows rows, err = db.Conn.Queryx(` WITH session_data AS ( @@ -189,13 +194,14 @@ func (db *T) RetrieveSession(uuid string) (sess *session.T, err error) {// {{{ WHERE uuid=$1 RETURNING - uuid, created, last_used, user_id + uuid, created, last_used, user_id, mfa ) SELECT sd.uuid, sd.created, sd.last_used, COALESCE(u.username, '') AS username, COALESCE(u.name, '') AS name, - COALESCE(u.id, 0) AS user_id + COALESCE(u.id, 0) AS user_id, + mfa FROM session_data sd LEFT JOIN _webservice.user u ON sd.user_id = u.id `, @@ -212,7 +218,7 @@ func (db *T) RetrieveSession(uuid string) (sess *session.T, err error) {// {{{ sess.Authenticated = sess.UserID > 0 } return -}// }}} +} // }}} func (db *T) SetSessionUser(uuid string, userID int) (err error) { // {{{ _, err = db.Conn.Exec(` UPDATE _webservice.session @@ -231,12 +237,29 @@ func (db *T) SetSessionUser(uuid string, userID int) (err error) { // {{{ } return } // }}} -func (db *T) UpdateUserTime(userID int) (err error) {// {{{ +func (db *T) SetSessionMFA(uuid string, mfa any) (err error) { // {{{ + mfaByte, _ := json.Marshal(mfa) + _, err = db.Conn.Exec(` + UPDATE _webservice.session + SET + mfa = $2 + WHERE + uuid = $1 + `, + uuid, + mfaByte, + ) + if err != nil { + return + } + return +} // }}} +func (db *T) UpdateUserTime(userID int) (err error) { // {{{ _, err = db.Conn.Exec(`UPDATE _webservice.user SET last_login=NOW() WHERE id=$1`, userID) return -}// }}} +} // }}} -func (db *T) CreateUser(username, password, name string) (userID int64, err error) {// {{{ +func (db *T) CreateUser(username, password, name string) (userID int64, err error) { // {{{ var row *sql.Row row = db.Conn.QueryRow(` INSERT INTO _webservice.user(username, password, name) @@ -260,6 +283,6 @@ func (db *T) CreateUser(username, password, name string) (userID int64, err erro err = row.Scan(&userID) return -}// }}} +} // }}} // vim: foldmethod=marker diff --git a/pkg.go b/pkg.go index 3ea745c..e51f557 100644 --- a/pkg.go +++ b/pkg.go @@ -113,7 +113,7 @@ func New(configFilename, version string, logger *slog.Logger) (service *Service, return } // }}} -func (service *Service) defaultAuthenticationHandler(req AuthenticationRequest, alreadyAuthenticated bool) (resp AuthenticationResponse, err error) { // {{{ +func (service *Service) defaultAuthenticationHandler(req AuthenticationRequest, sess *session.T, alreadyAuthenticated bool) (resp AuthenticationResponse, err error) { // {{{ resp.Authenticated = alreadyAuthenticated service.logger.Info("webservice", "op", "authentication", "username", req.Username, "authenticated", resp.Authenticated) return @@ -183,7 +183,7 @@ func (service *Service) Register(path string, requireSession, requireAuthenticat return } - sess, found = service.retrieveSession(headerSessionUUID) + sess, found = service.RetrieveSession(headerSessionUUID) if !found { service.errorHandler(fmt.Errorf("Session '%s' not found", headerSessionUUID), "001-0001", w) return diff --git a/session.go b/session.go index a2d2251..5608ab7 100644 --- a/session.go +++ b/session.go @@ -25,9 +25,10 @@ type AuthenticationResponse struct { OK bool Authenticated bool UserID int + MFA any } -type AuthenticationHandler func(AuthenticationRequest, bool) (AuthenticationResponse, error) +type AuthenticationHandler func(AuthenticationRequest, *session.T, bool) (AuthenticationResponse, error) type AuthorizationHandler func(*session.T, *http.Request) (bool, error) func (service *Service) sessionNew(w http.ResponseWriter, r *http.Request, foo *session.T) { // {{{ @@ -48,7 +49,7 @@ func (service *Service) sessionNew(w http.ResponseWriter, r *http.Request, foo * break } else { - if _, found = service.retrieveSession(sess.UUID); found { + if _, found = service.RetrieveSession(sess.UUID); found { continue } @@ -91,8 +92,8 @@ func (service *Service) sessionAuthenticate(w http.ResponseWriter, r *http.Reque } // Authenticate against webservice user table if using a database. - var userID int - if service.Db != nil { + var userID int = sess.UserID + if service.Db != nil && userID == 0 { authenticatedByFramework, userID, err = service.Db.Authenticate(authRequest.Username, authRequest.Password) if err != nil { service.errorHandler(err, "001-A002", w) @@ -103,15 +104,24 @@ func (service *Service) sessionAuthenticate(w http.ResponseWriter, r *http.Reque // The authentication handler is provided with the authenticated response of the possible database authentication, // and given a chance to override it. - authResponse, err = service.authenticationHandler(authRequest, authenticatedByFramework) + authResponse, err = service.authenticationHandler(authRequest, sess, authenticatedByFramework) if err != nil { service.errorHandler(err, "001-F002", w) return } + authResponse.UserID = userID authResponse.OK = true sess.Authenticated = authResponse.Authenticated + if authResponse.MFA != nil { + err = service.Db.SetSessionMFA(sess.UUID, authResponse.MFA) + if err != nil { + service.errorHandler(err, "001-A003", w) + return + } + } + if authResponse.Authenticated && userID > 0 { err = service.Db.SetSessionUser(sess.UUID, userID) if err != nil { @@ -136,7 +146,7 @@ func (service *Service) sessionRetrieve(w http.ResponseWriter, r *http.Request, w.Write(out) } // }}} -func (service *Service) retrieveSession(uuid string) (session *session.T, found bool) { // {{{ +func (service *Service) RetrieveSession(uuid string) (session *session.T, found bool) { // {{{ var err error if service.Db == nil { diff --git a/session/pkg.go b/session/pkg.go index c193d60..ea8a894 100644 --- a/session/pkg.go +++ b/session/pkg.go @@ -10,6 +10,7 @@ type T struct { Created time.Time LastUsed time.Time `db:"last_used"` Authenticated bool + MFA any UserID int `db:"user_id"` Username string diff --git a/static/js/websocket.mjs b/static/js/websocket.mjs index 459bae1..01276b8 100644 --- a/static/js/websocket.mjs +++ b/static/js/websocket.mjs @@ -17,7 +17,7 @@ export class Websocket { start() {//{{{ this.connect() - //this.loop() + this.loop() }//}}} loop() {//{{{ setInterval(() => { @@ -27,6 +27,9 @@ export class Websocket { } }, 1000) }//}}} + send(data) {//{{{ + this.websocket.send(data) + }//}}} connect() {//{{{ const protocol = location.protocol; diff --git a/ws_conn_manager/pkg.go b/ws_conn_manager/pkg.go index f41aab6..386482e 100644 --- a/ws_conn_manager/pkg.go +++ b/ws_conn_manager/pkg.go @@ -8,15 +8,18 @@ import ( // Standard "log/slog" "net/http" - "strings" "slices" + "strings" ) +type ReadHandler func(*ConnectionManager, *WsConnection, []byte) + type WsConnection struct { ConnectionManager *ConnectionManager UUID string Conn *websocket.Conn Pruned bool + SessionUUID string } type ConnectionManager struct { connections map[string]*WsConnection @@ -24,6 +27,7 @@ type ConnectionManager struct { sendQueue chan SendRequest logger *slog.Logger domains []string + readHandlers []ReadHandler } type SendRequest struct { WsConn *WsConnection @@ -72,6 +76,9 @@ func (cm *ConnectionManager) NewConnection(w http.ResponseWriter, r *http.Reques return &wsConn, nil } // }}} +func (cm *ConnectionManager) AddMsgHandler(handler ReadHandler) { // {{{ + cm.readHandlers = append(cm.readHandlers, handler) +} // }}} // validateOrigin matches host from X-Forwarded-Host or request host against a list of configured domains. func (cm *ConnectionManager) validateOrigin(r *http.Request) bool { // {{{ @@ -102,7 +109,10 @@ func (cm *ConnectionManager) ReadLoop(wsConn *WsConnection) { // {{{ } cm.logger.Debug("websocket", "op", "read", "data", data) - //cm.Send(wsConn, response) + + for _, handler := range cm.readHandlers { + handler(cm, wsConn, data) + } } } // }}} func (cm *ConnectionManager) Read(wsConn *WsConnection) ([]byte, bool) { // {{{ @@ -118,6 +128,14 @@ func (cm *ConnectionManager) Read(wsConn *WsConnection) ([]byte, bool) { // {{{ func (cm *ConnectionManager) Send(wsConn *WsConnection, msg interface{}) { // {{{ wsConn.Conn.WriteJSON(msg) } // }}} +func (cm *ConnectionManager) SendToSessionUUID(uuid string, msg interface{}) { // {{{ + for _, wsConn := range cm.connections { + if wsConn.SessionUUID == uuid { + wsConn.Conn.WriteJSON(msg) + break + } + } +} // }}} func (cm *ConnectionManager) Broadcast(msg interface{}) { // {{{ cm.broadcastQueue <- msg } // }}}