Added MFA to sessions and better websocket management

This commit is contained in:
Magnus Åhall 2024-02-14 14:34:36 +01:00
parent c848ae60b5
commit 92c8ac444f
6 changed files with 74 additions and 19 deletions

View File

@ -12,6 +12,7 @@ import (
// Standard // Standard
"database/sql" "database/sql"
"encoding/json"
"fmt" "fmt"
"log/slog" "log/slog"
) )
@ -93,6 +94,10 @@ func webserviceSQLProvider(dbname string, version int) ([]byte, bool) { // {{{
END; END;
$$; $$;
`, `,
2: `
ALTER TABLE "_webservice"."session" ADD mfa jsonb DEFAULT '{}' NOT NULL;
`,
} }
statement, found := sql[version] statement, found := sql[version]
@ -189,13 +194,14 @@ func (db *T) RetrieveSession(uuid string) (sess *session.T, err error) {// {{{
WHERE WHERE
uuid=$1 uuid=$1
RETURNING RETURNING
uuid, created, last_used, user_id uuid, created, last_used, user_id, mfa
) )
SELECT SELECT
sd.uuid, sd.created, sd.last_used, sd.uuid, sd.created, sd.last_used,
COALESCE(u.username, '') AS username, COALESCE(u.username, '') AS username,
COALESCE(u.name, '') AS name, COALESCE(u.name, '') AS name,
COALESCE(u.id, 0) AS user_id COALESCE(u.id, 0) AS user_id,
mfa
FROM session_data sd FROM session_data sd
LEFT JOIN _webservice.user u ON sd.user_id = u.id LEFT JOIN _webservice.user u ON sd.user_id = u.id
`, `,
@ -231,6 +237,23 @@ func (db *T) SetSessionUser(uuid string, userID int) (err error) { // {{{
} }
return return
} // }}} } // }}}
func (db *T) SetSessionMFA(uuid string, mfa any) (err error) { // {{{
mfaByte, _ := json.Marshal(mfa)
_, err = db.Conn.Exec(`
UPDATE _webservice.session
SET
mfa = $2
WHERE
uuid = $1
`,
uuid,
mfaByte,
)
if err != nil {
return
}
return
} // }}}
func (db *T) UpdateUserTime(userID int) (err error) { // {{{ func (db *T) UpdateUserTime(userID int) (err error) { // {{{
_, err = db.Conn.Exec(`UPDATE _webservice.user SET last_login=NOW() WHERE id=$1`, userID) _, err = db.Conn.Exec(`UPDATE _webservice.user SET last_login=NOW() WHERE id=$1`, userID)
return return

4
pkg.go
View File

@ -113,7 +113,7 @@ func New(configFilename, version string, logger *slog.Logger) (service *Service,
return return
} // }}} } // }}}
func (service *Service) defaultAuthenticationHandler(req AuthenticationRequest, alreadyAuthenticated bool) (resp AuthenticationResponse, err error) { // {{{ func (service *Service) defaultAuthenticationHandler(req AuthenticationRequest, sess *session.T, alreadyAuthenticated bool) (resp AuthenticationResponse, err error) { // {{{
resp.Authenticated = alreadyAuthenticated resp.Authenticated = alreadyAuthenticated
service.logger.Info("webservice", "op", "authentication", "username", req.Username, "authenticated", resp.Authenticated) service.logger.Info("webservice", "op", "authentication", "username", req.Username, "authenticated", resp.Authenticated)
return return
@ -183,7 +183,7 @@ func (service *Service) Register(path string, requireSession, requireAuthenticat
return return
} }
sess, found = service.retrieveSession(headerSessionUUID) sess, found = service.RetrieveSession(headerSessionUUID)
if !found { if !found {
service.errorHandler(fmt.Errorf("Session '%s' not found", headerSessionUUID), "001-0001", w) service.errorHandler(fmt.Errorf("Session '%s' not found", headerSessionUUID), "001-0001", w)
return return

View File

@ -25,9 +25,10 @@ type AuthenticationResponse struct {
OK bool OK bool
Authenticated bool Authenticated bool
UserID int UserID int
MFA any
} }
type AuthenticationHandler func(AuthenticationRequest, bool) (AuthenticationResponse, error) type AuthenticationHandler func(AuthenticationRequest, *session.T, bool) (AuthenticationResponse, error)
type AuthorizationHandler func(*session.T, *http.Request) (bool, error) type AuthorizationHandler func(*session.T, *http.Request) (bool, error)
func (service *Service) sessionNew(w http.ResponseWriter, r *http.Request, foo *session.T) { // {{{ func (service *Service) sessionNew(w http.ResponseWriter, r *http.Request, foo *session.T) { // {{{
@ -48,7 +49,7 @@ func (service *Service) sessionNew(w http.ResponseWriter, r *http.Request, foo *
break break
} else { } else {
if _, found = service.retrieveSession(sess.UUID); found { if _, found = service.RetrieveSession(sess.UUID); found {
continue continue
} }
@ -91,8 +92,8 @@ func (service *Service) sessionAuthenticate(w http.ResponseWriter, r *http.Reque
} }
// Authenticate against webservice user table if using a database. // Authenticate against webservice user table if using a database.
var userID int var userID int = sess.UserID
if service.Db != nil { if service.Db != nil && userID == 0 {
authenticatedByFramework, userID, err = service.Db.Authenticate(authRequest.Username, authRequest.Password) authenticatedByFramework, userID, err = service.Db.Authenticate(authRequest.Username, authRequest.Password)
if err != nil { if err != nil {
service.errorHandler(err, "001-A002", w) service.errorHandler(err, "001-A002", w)
@ -103,15 +104,24 @@ func (service *Service) sessionAuthenticate(w http.ResponseWriter, r *http.Reque
// The authentication handler is provided with the authenticated response of the possible database authentication, // The authentication handler is provided with the authenticated response of the possible database authentication,
// and given a chance to override it. // and given a chance to override it.
authResponse, err = service.authenticationHandler(authRequest, authenticatedByFramework) authResponse, err = service.authenticationHandler(authRequest, sess, authenticatedByFramework)
if err != nil { if err != nil {
service.errorHandler(err, "001-F002", w) service.errorHandler(err, "001-F002", w)
return return
} }
authResponse.UserID = userID authResponse.UserID = userID
authResponse.OK = true authResponse.OK = true
sess.Authenticated = authResponse.Authenticated sess.Authenticated = authResponse.Authenticated
if authResponse.MFA != nil {
err = service.Db.SetSessionMFA(sess.UUID, authResponse.MFA)
if err != nil {
service.errorHandler(err, "001-A003", w)
return
}
}
if authResponse.Authenticated && userID > 0 { if authResponse.Authenticated && userID > 0 {
err = service.Db.SetSessionUser(sess.UUID, userID) err = service.Db.SetSessionUser(sess.UUID, userID)
if err != nil { if err != nil {
@ -136,7 +146,7 @@ func (service *Service) sessionRetrieve(w http.ResponseWriter, r *http.Request,
w.Write(out) w.Write(out)
} // }}} } // }}}
func (service *Service) retrieveSession(uuid string) (session *session.T, found bool) { // {{{ func (service *Service) RetrieveSession(uuid string) (session *session.T, found bool) { // {{{
var err error var err error
if service.Db == nil { if service.Db == nil {

View File

@ -10,6 +10,7 @@ type T struct {
Created time.Time Created time.Time
LastUsed time.Time `db:"last_used"` LastUsed time.Time `db:"last_used"`
Authenticated bool Authenticated bool
MFA any
UserID int `db:"user_id"` UserID int `db:"user_id"`
Username string Username string

View File

@ -17,7 +17,7 @@ export class Websocket {
start() {//{{{ start() {//{{{
this.connect() this.connect()
//this.loop() this.loop()
}//}}} }//}}}
loop() {//{{{ loop() {//{{{
setInterval(() => { setInterval(() => {
@ -27,6 +27,9 @@ export class Websocket {
} }
}, 1000) }, 1000)
}//}}} }//}}}
send(data) {//{{{
this.websocket.send(data)
}//}}}
connect() {//{{{ connect() {//{{{
const protocol = location.protocol; const protocol = location.protocol;

View File

@ -8,15 +8,18 @@ import (
// Standard // Standard
"log/slog" "log/slog"
"net/http" "net/http"
"strings"
"slices" "slices"
"strings"
) )
type ReadHandler func(*ConnectionManager, *WsConnection, []byte)
type WsConnection struct { type WsConnection struct {
ConnectionManager *ConnectionManager ConnectionManager *ConnectionManager
UUID string UUID string
Conn *websocket.Conn Conn *websocket.Conn
Pruned bool Pruned bool
SessionUUID string
} }
type ConnectionManager struct { type ConnectionManager struct {
connections map[string]*WsConnection connections map[string]*WsConnection
@ -24,6 +27,7 @@ type ConnectionManager struct {
sendQueue chan SendRequest sendQueue chan SendRequest
logger *slog.Logger logger *slog.Logger
domains []string domains []string
readHandlers []ReadHandler
} }
type SendRequest struct { type SendRequest struct {
WsConn *WsConnection WsConn *WsConnection
@ -72,6 +76,9 @@ func (cm *ConnectionManager) NewConnection(w http.ResponseWriter, r *http.Reques
return &wsConn, nil return &wsConn, nil
} // }}} } // }}}
func (cm *ConnectionManager) AddMsgHandler(handler ReadHandler) { // {{{
cm.readHandlers = append(cm.readHandlers, handler)
} // }}}
// validateOrigin matches host from X-Forwarded-Host or request host against a list of configured domains. // validateOrigin matches host from X-Forwarded-Host or request host against a list of configured domains.
func (cm *ConnectionManager) validateOrigin(r *http.Request) bool { // {{{ func (cm *ConnectionManager) validateOrigin(r *http.Request) bool { // {{{
@ -102,7 +109,10 @@ func (cm *ConnectionManager) ReadLoop(wsConn *WsConnection) { // {{{
} }
cm.logger.Debug("websocket", "op", "read", "data", data) cm.logger.Debug("websocket", "op", "read", "data", data)
//cm.Send(wsConn, response)
for _, handler := range cm.readHandlers {
handler(cm, wsConn, data)
}
} }
} // }}} } // }}}
func (cm *ConnectionManager) Read(wsConn *WsConnection) ([]byte, bool) { // {{{ func (cm *ConnectionManager) Read(wsConn *WsConnection) ([]byte, bool) { // {{{
@ -118,6 +128,14 @@ func (cm *ConnectionManager) Read(wsConn *WsConnection) ([]byte, bool) { // {{{
func (cm *ConnectionManager) Send(wsConn *WsConnection, msg interface{}) { // {{{ func (cm *ConnectionManager) Send(wsConn *WsConnection, msg interface{}) { // {{{
wsConn.Conn.WriteJSON(msg) wsConn.Conn.WriteJSON(msg)
} // }}} } // }}}
func (cm *ConnectionManager) SendToSessionUUID(uuid string, msg interface{}) { // {{{
for _, wsConn := range cm.connections {
if wsConn.SessionUUID == uuid {
wsConn.Conn.WriteJSON(msg)
break
}
}
} // }}}
func (cm *ConnectionManager) Broadcast(msg interface{}) { // {{{ func (cm *ConnectionManager) Broadcast(msg interface{}) { // {{{
cm.broadcastQueue <- msg cm.broadcastQueue <- msg
} // }}} } // }}}