2024-01-05 19:59:18 +01:00
|
|
|
package ws_conn_manager
|
|
|
|
|
|
|
|
import (
	// Standard
	"log/slog"
	"net/http"
	"net/url"
	"slices"
	"strings"
	"sync"

	// External
	"github.com/google/uuid"
	"github.com/gorilla/websocket"
)
|
|
|
|
|
2024-02-14 14:34:36 +01:00
|
|
|
type ReadHandler func(*ConnectionManager, *WsConnection, []byte)
|
|
|
|
|
2024-01-05 19:59:18 +01:00
|
|
|
type WsConnection struct {
|
|
|
|
ConnectionManager *ConnectionManager
|
|
|
|
UUID string
|
|
|
|
Conn *websocket.Conn
|
|
|
|
Pruned bool
|
2024-02-14 14:34:36 +01:00
|
|
|
SessionUUID string
|
2024-01-05 19:59:18 +01:00
|
|
|
}
|
|
|
|
type ConnectionManager struct {
|
|
|
|
connections map[string]*WsConnection
|
|
|
|
broadcastQueue chan interface{}
|
|
|
|
sendQueue chan SendRequest
|
|
|
|
logger *slog.Logger
|
2024-01-07 11:58:06 +01:00
|
|
|
domains []string
|
2024-02-14 14:34:36 +01:00
|
|
|
readHandlers []ReadHandler
|
2024-02-15 10:35:57 +01:00
|
|
|
connSync sync.Mutex
|
2024-01-05 19:59:18 +01:00
|
|
|
}
|
|
|
|
type SendRequest struct {
|
|
|
|
WsConn *WsConnection
|
|
|
|
Msg interface{}
|
|
|
|
}
|
|
|
|
|
2024-01-07 11:58:06 +01:00
|
|
|
func NewConnectionManager(logger *slog.Logger, domains []string) (cm ConnectionManager) { // {{{
|
2024-01-05 19:59:18 +01:00
|
|
|
cm.connections = make(map[string]*WsConnection)
|
|
|
|
cm.sendQueue = make(chan SendRequest, 65536)
|
|
|
|
cm.broadcastQueue = make(chan interface{}, 65536)
|
|
|
|
cm.logger = logger
|
2024-01-07 11:58:06 +01:00
|
|
|
cm.domains = domains
|
2024-01-05 19:59:18 +01:00
|
|
|
return
|
|
|
|
} // }}}
|
|
|
|
|
|
|
|
// NewConnection creates a new connection, which is assigned a UUIDv4 for
|
|
|
|
// identification. This is then put into the connection collection.
|
|
|
|
func (cm *ConnectionManager) NewConnection(w http.ResponseWriter, r *http.Request) (*WsConnection, error) { // {{{
|
|
|
|
var err error
|
|
|
|
wsConn := WsConnection{
|
|
|
|
UUID: uuid.NewString(),
|
|
|
|
ConnectionManager: cm,
|
|
|
|
}
|
2024-01-07 11:58:06 +01:00
|
|
|
|
|
|
|
upgrader := websocket.Upgrader{
|
|
|
|
ReadBufferSize: 1024,
|
|
|
|
WriteBufferSize: 1024,
|
|
|
|
|
|
|
|
// CheckOrigin is to match DOMAIN constant.
|
|
|
|
// Use X-Forwarded-Server if behind proxy.
|
|
|
|
CheckOrigin: cm.validateOrigin,
|
|
|
|
}
|
|
|
|
|
2024-01-05 19:59:18 +01:00
|
|
|
wsConn.Conn, err = upgrader.Upgrade(w, r, nil)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
// Keep track of all connections.
|
2024-02-15 10:35:57 +01:00
|
|
|
cm.connSync.Lock()
|
2024-01-05 19:59:18 +01:00
|
|
|
cm.connections[wsConn.UUID] = &wsConn
|
2024-02-15 10:35:57 +01:00
|
|
|
cm.connSync.Unlock()
|
2024-01-05 19:59:18 +01:00
|
|
|
|
|
|
|
// Successfully upgraded to a websocket connection.
|
|
|
|
cm.logger.Info("websocket", "uuid", wsConn.UUID, "remote_addr", r.RemoteAddr)
|
|
|
|
|
|
|
|
go cm.ReadLoop(&wsConn)
|
|
|
|
|
|
|
|
return &wsConn, nil
|
|
|
|
} // }}}
|
2024-02-14 14:34:36 +01:00
|
|
|
func (cm *ConnectionManager) AddMsgHandler(handler ReadHandler) { // {{{
|
|
|
|
cm.readHandlers = append(cm.readHandlers, handler)
|
|
|
|
} // }}}
|
2024-01-05 19:59:18 +01:00
|
|
|
|
2024-01-07 11:58:06 +01:00
|
|
|
// validateOrigin matches host from X-Forwarded-Host or request host against a list of configured domains.
|
|
|
|
func (cm *ConnectionManager) validateOrigin(r *http.Request) bool { // {{{
|
|
|
|
host := r.Header.Get("X-Forwarded-Host")
|
|
|
|
if host == "" {
|
|
|
|
components := strings.Split(r.Host, ":")
|
|
|
|
host = components[0]
|
|
|
|
}
|
|
|
|
cm.logger.Debug("websocket", "op", "new connection", "allowed", cm.domains, "host", host)
|
|
|
|
return slices.Contains(cm.domains, host)
|
|
|
|
} // }}}
|
|
|
|
|
2024-01-05 19:59:18 +01:00
|
|
|
// Prune closes an deletes connections. If this happened to be non-fatal, the
|
|
|
|
// user will just have to reconnect.
|
|
|
|
func (cm *ConnectionManager) Prune(wsConn *WsConnection, err error) { // {{{
|
|
|
|
cm.logger.Info("websocket", "op", "prune", "uuid", wsConn.UUID)
|
|
|
|
wsConn.Conn.Close()
|
|
|
|
wsConn.Pruned = true
|
2024-02-15 10:35:57 +01:00
|
|
|
cm.connSync.Lock()
|
2024-01-05 19:59:18 +01:00
|
|
|
delete(cm.connections, wsConn.UUID)
|
2024-02-15 10:35:57 +01:00
|
|
|
cm.connSync.Unlock()
|
2024-01-05 19:59:18 +01:00
|
|
|
} // }}}
|
|
|
|
func (cm *ConnectionManager) ReadLoop(wsConn *WsConnection) { // {{{
|
|
|
|
var data []byte
|
|
|
|
var ok bool
|
|
|
|
|
|
|
|
for {
|
|
|
|
if data, ok = cm.Read(wsConn); !ok {
|
|
|
|
break
|
|
|
|
}
|
|
|
|
|
|
|
|
cm.logger.Debug("websocket", "op", "read", "data", data)
|
2024-02-14 14:34:36 +01:00
|
|
|
|
|
|
|
for _, handler := range cm.readHandlers {
|
|
|
|
handler(cm, wsConn, data)
|
|
|
|
}
|
2024-01-05 19:59:18 +01:00
|
|
|
}
|
|
|
|
} // }}}
|
|
|
|
func (cm *ConnectionManager) Read(wsConn *WsConnection) ([]byte, bool) { // {{{
|
|
|
|
var err error
|
|
|
|
var requestData []byte
|
|
|
|
_, requestData, err = wsConn.Conn.ReadMessage()
|
|
|
|
if err != nil {
|
|
|
|
cm.Prune(wsConn, err)
|
|
|
|
return nil, false
|
|
|
|
}
|
|
|
|
return requestData, true
|
|
|
|
} // }}}
|
|
|
|
func (cm *ConnectionManager) Send(wsConn *WsConnection, msg interface{}) { // {{{
|
|
|
|
wsConn.Conn.WriteJSON(msg)
|
|
|
|
} // }}}
|
2024-02-14 14:34:36 +01:00
|
|
|
func (cm *ConnectionManager) SendToSessionUUID(uuid string, msg interface{}) { // {{{
|
|
|
|
for _, wsConn := range cm.connections {
|
|
|
|
if wsConn.SessionUUID == uuid {
|
|
|
|
wsConn.Conn.WriteJSON(msg)
|
|
|
|
break
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} // }}}
|
2024-01-05 19:59:18 +01:00
|
|
|
func (cm *ConnectionManager) Broadcast(msg interface{}) { // {{{
|
|
|
|
cm.broadcastQueue <- msg
|
|
|
|
} // }}}
|
|
|
|
|
|
|
|
func (cm *ConnectionManager) BroadcastLoop() { // {{{
|
|
|
|
for {
|
|
|
|
msg := <-cm.broadcastQueue
|
|
|
|
for _, wsConn := range cm.connections {
|
|
|
|
cm.Send(wsConn, msg)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} // }}}
|
|
|
|
|
|
|
|
// vim: foldmethod=marker
|