Added MFA to sessions and better websocket management
This commit is contained in:
parent
c848ae60b5
commit
92c8ac444f
@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
// Standard
|
// Standard
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
)
|
)
|
||||||
@ -93,6 +94,10 @@ func webserviceSQLProvider(dbname string, version int) ([]byte, bool) { // {{{
|
|||||||
END;
|
END;
|
||||||
$$;
|
$$;
|
||||||
`,
|
`,
|
||||||
|
|
||||||
|
2: `
|
||||||
|
ALTER TABLE "_webservice"."session" ADD mfa jsonb DEFAULT '{}' NOT NULL;
|
||||||
|
`,
|
||||||
}
|
}
|
||||||
|
|
||||||
statement, found := sql[version]
|
statement, found := sql[version]
|
||||||
@ -179,7 +184,7 @@ func (db *T) NewSession(uuid string) (err error) { // {{{
|
|||||||
_, err = db.Conn.Exec("INSERT INTO _webservice.session(uuid) VALUES($1)", uuid)
|
_, err = db.Conn.Exec("INSERT INTO _webservice.session(uuid) VALUES($1)", uuid)
|
||||||
return
|
return
|
||||||
} // }}}
|
} // }}}
|
||||||
func (db *T) RetrieveSession(uuid string) (sess *session.T, err error) {// {{{
|
func (db *T) RetrieveSession(uuid string) (sess *session.T, err error) { // {{{
|
||||||
var rows *sqlx.Rows
|
var rows *sqlx.Rows
|
||||||
rows, err = db.Conn.Queryx(`
|
rows, err = db.Conn.Queryx(`
|
||||||
WITH session_data AS (
|
WITH session_data AS (
|
||||||
@ -189,13 +194,14 @@ func (db *T) RetrieveSession(uuid string) (sess *session.T, err error) {// {{{
|
|||||||
WHERE
|
WHERE
|
||||||
uuid=$1
|
uuid=$1
|
||||||
RETURNING
|
RETURNING
|
||||||
uuid, created, last_used, user_id
|
uuid, created, last_used, user_id, mfa
|
||||||
)
|
)
|
||||||
SELECT
|
SELECT
|
||||||
sd.uuid, sd.created, sd.last_used,
|
sd.uuid, sd.created, sd.last_used,
|
||||||
COALESCE(u.username, '') AS username,
|
COALESCE(u.username, '') AS username,
|
||||||
COALESCE(u.name, '') AS name,
|
COALESCE(u.name, '') AS name,
|
||||||
COALESCE(u.id, 0) AS user_id
|
COALESCE(u.id, 0) AS user_id,
|
||||||
|
mfa
|
||||||
FROM session_data sd
|
FROM session_data sd
|
||||||
LEFT JOIN _webservice.user u ON sd.user_id = u.id
|
LEFT JOIN _webservice.user u ON sd.user_id = u.id
|
||||||
`,
|
`,
|
||||||
@ -212,7 +218,7 @@ func (db *T) RetrieveSession(uuid string) (sess *session.T, err error) {// {{{
|
|||||||
sess.Authenticated = sess.UserID > 0
|
sess.Authenticated = sess.UserID > 0
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}// }}}
|
} // }}}
|
||||||
func (db *T) SetSessionUser(uuid string, userID int) (err error) { // {{{
|
func (db *T) SetSessionUser(uuid string, userID int) (err error) { // {{{
|
||||||
_, err = db.Conn.Exec(`
|
_, err = db.Conn.Exec(`
|
||||||
UPDATE _webservice.session
|
UPDATE _webservice.session
|
||||||
@ -231,12 +237,29 @@ func (db *T) SetSessionUser(uuid string, userID int) (err error) { // {{{
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
} // }}}
|
} // }}}
|
||||||
func (db *T) UpdateUserTime(userID int) (err error) {// {{{
|
func (db *T) SetSessionMFA(uuid string, mfa any) (err error) { // {{{
|
||||||
|
mfaByte, _ := json.Marshal(mfa)
|
||||||
|
_, err = db.Conn.Exec(`
|
||||||
|
UPDATE _webservice.session
|
||||||
|
SET
|
||||||
|
mfa = $2
|
||||||
|
WHERE
|
||||||
|
uuid = $1
|
||||||
|
`,
|
||||||
|
uuid,
|
||||||
|
mfaByte,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
} // }}}
|
||||||
|
func (db *T) UpdateUserTime(userID int) (err error) { // {{{
|
||||||
_, err = db.Conn.Exec(`UPDATE _webservice.user SET last_login=NOW() WHERE id=$1`, userID)
|
_, err = db.Conn.Exec(`UPDATE _webservice.user SET last_login=NOW() WHERE id=$1`, userID)
|
||||||
return
|
return
|
||||||
}// }}}
|
} // }}}
|
||||||
|
|
||||||
func (db *T) CreateUser(username, password, name string) (userID int64, err error) {// {{{
|
func (db *T) CreateUser(username, password, name string) (userID int64, err error) { // {{{
|
||||||
var row *sql.Row
|
var row *sql.Row
|
||||||
row = db.Conn.QueryRow(`
|
row = db.Conn.QueryRow(`
|
||||||
INSERT INTO _webservice.user(username, password, name)
|
INSERT INTO _webservice.user(username, password, name)
|
||||||
@ -260,6 +283,6 @@ func (db *T) CreateUser(username, password, name string) (userID int64, err erro
|
|||||||
|
|
||||||
err = row.Scan(&userID)
|
err = row.Scan(&userID)
|
||||||
return
|
return
|
||||||
}// }}}
|
} // }}}
|
||||||
|
|
||||||
// vim: foldmethod=marker
|
// vim: foldmethod=marker
|
||||||
|
4
pkg.go
4
pkg.go
@ -113,7 +113,7 @@ func New(configFilename, version string, logger *slog.Logger) (service *Service,
|
|||||||
return
|
return
|
||||||
} // }}}
|
} // }}}
|
||||||
|
|
||||||
func (service *Service) defaultAuthenticationHandler(req AuthenticationRequest, alreadyAuthenticated bool) (resp AuthenticationResponse, err error) { // {{{
|
func (service *Service) defaultAuthenticationHandler(req AuthenticationRequest, sess *session.T, alreadyAuthenticated bool) (resp AuthenticationResponse, err error) { // {{{
|
||||||
resp.Authenticated = alreadyAuthenticated
|
resp.Authenticated = alreadyAuthenticated
|
||||||
service.logger.Info("webservice", "op", "authentication", "username", req.Username, "authenticated", resp.Authenticated)
|
service.logger.Info("webservice", "op", "authentication", "username", req.Username, "authenticated", resp.Authenticated)
|
||||||
return
|
return
|
||||||
@ -183,7 +183,7 @@ func (service *Service) Register(path string, requireSession, requireAuthenticat
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
sess, found = service.retrieveSession(headerSessionUUID)
|
sess, found = service.RetrieveSession(headerSessionUUID)
|
||||||
if !found {
|
if !found {
|
||||||
service.errorHandler(fmt.Errorf("Session '%s' not found", headerSessionUUID), "001-0001", w)
|
service.errorHandler(fmt.Errorf("Session '%s' not found", headerSessionUUID), "001-0001", w)
|
||||||
return
|
return
|
||||||
|
22
session.go
22
session.go
@ -25,9 +25,10 @@ type AuthenticationResponse struct {
|
|||||||
OK bool
|
OK bool
|
||||||
Authenticated bool
|
Authenticated bool
|
||||||
UserID int
|
UserID int
|
||||||
|
MFA any
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthenticationHandler func(AuthenticationRequest, bool) (AuthenticationResponse, error)
|
type AuthenticationHandler func(AuthenticationRequest, *session.T, bool) (AuthenticationResponse, error)
|
||||||
type AuthorizationHandler func(*session.T, *http.Request) (bool, error)
|
type AuthorizationHandler func(*session.T, *http.Request) (bool, error)
|
||||||
|
|
||||||
func (service *Service) sessionNew(w http.ResponseWriter, r *http.Request, foo *session.T) { // {{{
|
func (service *Service) sessionNew(w http.ResponseWriter, r *http.Request, foo *session.T) { // {{{
|
||||||
@ -48,7 +49,7 @@ func (service *Service) sessionNew(w http.ResponseWriter, r *http.Request, foo *
|
|||||||
break
|
break
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
if _, found = service.retrieveSession(sess.UUID); found {
|
if _, found = service.RetrieveSession(sess.UUID); found {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -91,8 +92,8 @@ func (service *Service) sessionAuthenticate(w http.ResponseWriter, r *http.Reque
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Authenticate against webservice user table if using a database.
|
// Authenticate against webservice user table if using a database.
|
||||||
var userID int
|
var userID int = sess.UserID
|
||||||
if service.Db != nil {
|
if service.Db != nil && userID == 0 {
|
||||||
authenticatedByFramework, userID, err = service.Db.Authenticate(authRequest.Username, authRequest.Password)
|
authenticatedByFramework, userID, err = service.Db.Authenticate(authRequest.Username, authRequest.Password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
service.errorHandler(err, "001-A002", w)
|
service.errorHandler(err, "001-A002", w)
|
||||||
@ -103,15 +104,24 @@ func (service *Service) sessionAuthenticate(w http.ResponseWriter, r *http.Reque
|
|||||||
|
|
||||||
// The authentication handler is provided with the authenticated response of the possible database authentication,
|
// The authentication handler is provided with the authenticated response of the possible database authentication,
|
||||||
// and given a chance to override it.
|
// and given a chance to override it.
|
||||||
authResponse, err = service.authenticationHandler(authRequest, authenticatedByFramework)
|
authResponse, err = service.authenticationHandler(authRequest, sess, authenticatedByFramework)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
service.errorHandler(err, "001-F002", w)
|
service.errorHandler(err, "001-F002", w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
authResponse.UserID = userID
|
authResponse.UserID = userID
|
||||||
authResponse.OK = true
|
authResponse.OK = true
|
||||||
sess.Authenticated = authResponse.Authenticated
|
sess.Authenticated = authResponse.Authenticated
|
||||||
|
|
||||||
|
if authResponse.MFA != nil {
|
||||||
|
err = service.Db.SetSessionMFA(sess.UUID, authResponse.MFA)
|
||||||
|
if err != nil {
|
||||||
|
service.errorHandler(err, "001-A003", w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if authResponse.Authenticated && userID > 0 {
|
if authResponse.Authenticated && userID > 0 {
|
||||||
err = service.Db.SetSessionUser(sess.UUID, userID)
|
err = service.Db.SetSessionUser(sess.UUID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -136,7 +146,7 @@ func (service *Service) sessionRetrieve(w http.ResponseWriter, r *http.Request,
|
|||||||
w.Write(out)
|
w.Write(out)
|
||||||
} // }}}
|
} // }}}
|
||||||
|
|
||||||
func (service *Service) retrieveSession(uuid string) (session *session.T, found bool) { // {{{
|
func (service *Service) RetrieveSession(uuid string) (session *session.T, found bool) { // {{{
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if service.Db == nil {
|
if service.Db == nil {
|
||||||
|
@ -10,6 +10,7 @@ type T struct {
|
|||||||
Created time.Time
|
Created time.Time
|
||||||
LastUsed time.Time `db:"last_used"`
|
LastUsed time.Time `db:"last_used"`
|
||||||
Authenticated bool
|
Authenticated bool
|
||||||
|
MFA any
|
||||||
|
|
||||||
UserID int `db:"user_id"`
|
UserID int `db:"user_id"`
|
||||||
Username string
|
Username string
|
||||||
|
@ -17,7 +17,7 @@ export class Websocket {
|
|||||||
|
|
||||||
start() {//{{{
|
start() {//{{{
|
||||||
this.connect()
|
this.connect()
|
||||||
//this.loop()
|
this.loop()
|
||||||
}//}}}
|
}//}}}
|
||||||
loop() {//{{{
|
loop() {//{{{
|
||||||
setInterval(() => {
|
setInterval(() => {
|
||||||
@ -27,6 +27,9 @@ export class Websocket {
|
|||||||
}
|
}
|
||||||
}, 1000)
|
}, 1000)
|
||||||
}//}}}
|
}//}}}
|
||||||
|
send(data) {//{{{
|
||||||
|
this.websocket.send(data)
|
||||||
|
}//}}}
|
||||||
|
|
||||||
connect() {//{{{
|
connect() {//{{{
|
||||||
const protocol = location.protocol;
|
const protocol = location.protocol;
|
||||||
|
@ -8,15 +8,18 @@ import (
|
|||||||
// Standard
|
// Standard
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
"slices"
|
"slices"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ReadHandler func(*ConnectionManager, *WsConnection, []byte)
|
||||||
|
|
||||||
type WsConnection struct {
|
type WsConnection struct {
|
||||||
ConnectionManager *ConnectionManager
|
ConnectionManager *ConnectionManager
|
||||||
UUID string
|
UUID string
|
||||||
Conn *websocket.Conn
|
Conn *websocket.Conn
|
||||||
Pruned bool
|
Pruned bool
|
||||||
|
SessionUUID string
|
||||||
}
|
}
|
||||||
type ConnectionManager struct {
|
type ConnectionManager struct {
|
||||||
connections map[string]*WsConnection
|
connections map[string]*WsConnection
|
||||||
@ -24,6 +27,7 @@ type ConnectionManager struct {
|
|||||||
sendQueue chan SendRequest
|
sendQueue chan SendRequest
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
domains []string
|
domains []string
|
||||||
|
readHandlers []ReadHandler
|
||||||
}
|
}
|
||||||
type SendRequest struct {
|
type SendRequest struct {
|
||||||
WsConn *WsConnection
|
WsConn *WsConnection
|
||||||
@ -72,6 +76,9 @@ func (cm *ConnectionManager) NewConnection(w http.ResponseWriter, r *http.Reques
|
|||||||
|
|
||||||
return &wsConn, nil
|
return &wsConn, nil
|
||||||
} // }}}
|
} // }}}
|
||||||
|
func (cm *ConnectionManager) AddMsgHandler(handler ReadHandler) { // {{{
|
||||||
|
cm.readHandlers = append(cm.readHandlers, handler)
|
||||||
|
} // }}}
|
||||||
|
|
||||||
// validateOrigin matches host from X-Forwarded-Host or request host against a list of configured domains.
|
// validateOrigin matches host from X-Forwarded-Host or request host against a list of configured domains.
|
||||||
func (cm *ConnectionManager) validateOrigin(r *http.Request) bool { // {{{
|
func (cm *ConnectionManager) validateOrigin(r *http.Request) bool { // {{{
|
||||||
@ -102,7 +109,10 @@ func (cm *ConnectionManager) ReadLoop(wsConn *WsConnection) { // {{{
|
|||||||
}
|
}
|
||||||
|
|
||||||
cm.logger.Debug("websocket", "op", "read", "data", data)
|
cm.logger.Debug("websocket", "op", "read", "data", data)
|
||||||
//cm.Send(wsConn, response)
|
|
||||||
|
for _, handler := range cm.readHandlers {
|
||||||
|
handler(cm, wsConn, data)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} // }}}
|
} // }}}
|
||||||
func (cm *ConnectionManager) Read(wsConn *WsConnection) ([]byte, bool) { // {{{
|
func (cm *ConnectionManager) Read(wsConn *WsConnection) ([]byte, bool) { // {{{
|
||||||
@ -118,6 +128,14 @@ func (cm *ConnectionManager) Read(wsConn *WsConnection) ([]byte, bool) { // {{{
|
|||||||
func (cm *ConnectionManager) Send(wsConn *WsConnection, msg interface{}) { // {{{
|
func (cm *ConnectionManager) Send(wsConn *WsConnection, msg interface{}) { // {{{
|
||||||
wsConn.Conn.WriteJSON(msg)
|
wsConn.Conn.WriteJSON(msg)
|
||||||
} // }}}
|
} // }}}
|
||||||
|
func (cm *ConnectionManager) SendToSessionUUID(uuid string, msg interface{}) { // {{{
|
||||||
|
for _, wsConn := range cm.connections {
|
||||||
|
if wsConn.SessionUUID == uuid {
|
||||||
|
wsConn.Conn.WriteJSON(msg)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // }}}
|
||||||
func (cm *ConnectionManager) Broadcast(msg interface{}) { // {{{
|
func (cm *ConnectionManager) Broadcast(msg interface{}) { // {{{
|
||||||
cm.broadcastQueue <- msg
|
cm.broadcastQueue <- msg
|
||||||
} // }}}
|
} // }}}
|
||||||
|
Loading…
Reference in New Issue
Block a user