// Package ws_conn_manager tracks upgraded websocket connections, dispatches
// incoming messages to registered handlers and fans out broadcasts.
package ws_conn_manager

import (
	// Standard
	"log/slog"
	"net"
	"net/http"
	"net/url"
	"slices"
	"strings"
	"sync"

	// External
	"github.com/google/uuid"
	"github.com/gorilla/websocket"
)

// ReadHandler is called for every message read from a connection, in the
// order handlers were registered with AddMsgHandler.
type ReadHandler func(*ConnectionManager, *WsConnection, []byte)

// WsConnection is a single upgraded websocket connection.
type WsConnection struct {
	ConnectionManager *ConnectionManager
	UUID              string // Random UUIDv4 assigned by NewConnection; key in ConnectionManager.connections.
	Conn              *websocket.Conn
	Pruned            bool   // Set to true by Prune once the connection has been closed.
	SessionUUID       string // Not set in this file; presumably assigned by the application. Matched by SendToSessionUUID.

	// writeMu serializes writes to Conn: gorilla/websocket allows at most one
	// concurrent writer per connection, while Send, SendToSessionUUID and
	// BroadcastLoop may run on different goroutines.
	writeMu sync.Mutex
}

// ConnectionManager owns the set of live connections. It contains a mutex, so
// use it through a pointer and do not copy it after first use.
type ConnectionManager struct {
	connections    map[string]*WsConnection // Keyed by WsConnection.UUID; guarded by connSync.
	broadcastQueue chan any                 // Filled by Broadcast, drained by BroadcastLoop.
	sendQueue      chan SendRequest         // NOTE(review): not read or written anywhere in this file.
	logger         *slog.Logger
	domains        []string    // Hosts accepted by validateOrigin.
	readHandlers   []ReadHandler // Not lock-protected; see AddMsgHandler.
	connSync       sync.Mutex
}

// SendRequest pairs a message with the connection it is destined for.
// NOTE(review): nothing in this file produces or consumes SendRequests.
type SendRequest struct {
	WsConn *WsConnection
	Msg    any
}

// NewConnectionManager returns a manager that logs to logger and accepts
// websocket upgrades only for the given domains (see validateOrigin).
//
// The returned value contains a mutex: store it once and use it via a pointer.
// BroadcastLoop is not started here; run it in its own goroutine so that
// messages passed to Broadcast are delivered.
func NewConnectionManager(logger *slog.Logger, domains []string) ConnectionManager { // {{{
	// A composite literal (rather than a named result) avoids copying an
	// existing mutex value on return.
	return ConnectionManager{
		connections:    make(map[string]*WsConnection),
		sendQueue:      make(chan SendRequest, 65536),
		broadcastQueue: make(chan any, 65536),
		logger:         logger,
		domains:        domains,
	}
} // }}}

// NewConnection upgrades the HTTP request to a websocket connection, assigns
// it a random UUIDv4 and registers it with the manager. A goroutine running
// ReadLoop for the connection is started before it is returned.
//
// If the upgrade fails, the upgrader has already replied to the client with an
// HTTP error and the upgrade error is returned unchanged.
func (cm *ConnectionManager) NewConnection(w http.ResponseWriter, r *http.Request) (*WsConnection, error) { // {{{
	upgrader := websocket.Upgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
		// Accept only requests whose host (X-Forwarded-Host when behind a
		// proxy) and, if present, Origin are in the configured domains.
		CheckOrigin: cm.validateOrigin,
	}

	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		return nil, err
	}

	wsConn := &WsConnection{
		ConnectionManager: cm,
		UUID:              uuid.NewString(),
		Conn:              conn,
	}

	// Keep track of all connections.
	cm.connSync.Lock()
	cm.connections[wsConn.UUID] = wsConn
	cm.connSync.Unlock()

	// Successfully upgraded to a websocket connection.
	cm.logger.Info("websocket", "uuid", wsConn.UUID, "remote_addr", r.RemoteAddr)

	// The read loop starts right away, so handlers may see wsConn before the
	// caller has had a chance to set SessionUUID.
	go cm.ReadLoop(wsConn)
	return wsConn, nil
} // }}}

// AddMsgHandler registers handler to be called for every incoming message.
//
// readHandlers is appended to without a lock and read by every ReadLoop, so
// handlers should be registered before connections are accepted.
func (cm *ConnectionManager) AddMsgHandler(handler ReadHandler) { // {{{
	cm.readHandlers = append(cm.readHandlers, handler)
} // }}}

// validateOrigin is the upgrader's CheckOrigin hook. It accepts the request
// only when:
//   - the request host (X-Forwarded-Host if set, otherwise r.Host, with any
//     port removed) is one of the configured domains, and
//   - the Origin header, when present, names a host that is also one of the
//     configured domains.
//
// The host check alone cannot stop cross-site websocket hijacking: a page on
// another site connecting here still sends this server's own host. Browsers
// send Origin on websocket handshakes, so checking it rejects such requests.
// Requests without an Origin header (typically non-browser clients) fall back
// to the host check, mirroring gorilla/websocket's default CheckOrigin, which
// also accepts a missing Origin.
//
// NOTE(review): X-Forwarded-Host is client-controlled unless a proxy
// overwrites it.
func (cm *ConnectionManager) validateOrigin(r *http.Request) bool { // {{{
	host := r.Header.Get("X-Forwarded-Host")
	if host == "" {
		host = r.Host
	}
	host = hostWithoutPort(host)

	origin := r.Header.Get("Origin")
	originHost := ""
	if origin != "" {
		if u, err := url.Parse(origin); err == nil {
			originHost = u.Hostname()
		}
	}

	cm.logger.Debug("websocket", "op", "new connection", "allowed", cm.domains, "host", host, "origin", origin)

	if !slices.Contains(cm.domains, host) {
		return false
	}
	if origin == "" {
		return true
	}
	// An unparsable or opaque Origin (e.g. "null") yields an empty host and is rejected.
	return originHost != "" && slices.Contains(cm.domains, originHost)
} // }}}

// hostWithoutPort strips an optional ":port" suffix and IPv6 brackets from
// hostport, e.g. "example.com:8080" -> "example.com", "[::1]:8080" -> "::1".
func hostWithoutPort(hostport string) string { // {{{
	hostport = strings.TrimSpace(hostport)
	if host, _, err := net.SplitHostPort(hostport); err == nil {
		return host
	}
	// No port present.
	return strings.TrimSuffix(strings.TrimPrefix(hostport, "["), "]")
} // }}}

// Prune closes the connection and removes it from the manager. If the cause
// was non-fatal, the client will simply have to reconnect. err is the error
// that triggered the prune; it is only logged.
func (cm *ConnectionManager) Prune(wsConn *WsConnection, err error) { // {{{
	cm.logger.Info("websocket", "op", "prune", "uuid", wsConn.UUID, "error", err)

	// The connection is being discarded either way; a failing Close leaves
	// nothing further to do.
	_ = wsConn.Conn.Close()
	wsConn.Pruned = true

	cm.connSync.Lock()
	delete(cm.connections, wsConn.UUID)
	cm.connSync.Unlock()
} // }}}

// ReadLoop reads messages from wsConn until a read fails (at which point Read
// prunes the connection) and passes each message to every registered handler
// in registration order. Handlers run on this goroutine, so a slow handler
// delays further reads from this connection.
func (cm *ConnectionManager) ReadLoop(wsConn *WsConnection) { // {{{
	for {
		data, ok := cm.Read(wsConn)
		if !ok {
			return
		}

		cm.logger.Debug("websocket", "op", "read", "data", string(data))
		for _, handler := range cm.readHandlers {
			handler(cm, wsConn, data)
		}
	}
} // }}}

// Read returns the payload of the next message on wsConn. On any read error,
// including the peer closing the connection, the connection is pruned and ok
// is false.
func (cm *ConnectionManager) Read(wsConn *WsConnection) (data []byte, ok bool) { // {{{
	_, data, err := wsConn.Conn.ReadMessage()
	if err != nil {
		cm.Prune(wsConn, err)
		return nil, false
	}
	return data, true
} // }}}

// writeJSON encodes v as JSON and writes it to the connection, holding writeMu
// so that at most one goroutine writes to Conn at a time.
func (wsConn *WsConnection) writeJSON(v any) error { // {{{
	wsConn.writeMu.Lock()
	defer wsConn.writeMu.Unlock()
	return wsConn.Conn.WriteJSON(v)
} // }}}

// Send writes msg to wsConn as JSON. Write errors are logged; a broken
// connection is left for its ReadLoop to detect and prune.
func (cm *ConnectionManager) Send(wsConn *WsConnection, msg any) { // {{{
	if err := wsConn.writeJSON(msg); err != nil {
		cm.logger.Warn("websocket", "op", "send", "uuid", wsConn.UUID, "error", err)
	}
} // }}}

// SendToSessionUUID sends msg to a connection whose SessionUUID equals
// sessionUUID. If several connections share the session, only one of them,
// chosen arbitrarily (map iteration order), receives the message. Nothing is
// sent if no connection matches.
//
// NOTE(review): SessionUUID is read here without synchronization; confirm
// that it is only assigned before other goroutines can observe the connection.
func (cm *ConnectionManager) SendToSessionUUID(sessionUUID string, msg any) { // {{{
	var target *WsConnection

	// Hold the lock only for the lookup; writing can block on the network.
	cm.connSync.Lock()
	for _, wsConn := range cm.connections {
		if wsConn.SessionUUID == sessionUUID {
			target = wsConn
			break
		}
	}
	cm.connSync.Unlock()

	if target != nil {
		cm.Send(target, msg)
	}
} // }}}

// Broadcast queues msg for delivery to every connection by BroadcastLoop. It
// blocks while the queue (capacity 65536) is full.
func (cm *ConnectionManager) Broadcast(msg any) { // {{{
	cm.broadcastQueue <- msg
} // }}}

// BroadcastLoop delivers queued broadcast messages to every connection that is
// registered at the time the message is taken from the queue. It runs until
// broadcastQueue is closed, which never happens in this file, so it is
// expected to run in its own goroutine for the manager's lifetime.
func (cm *ConnectionManager) BroadcastLoop() { // {{{
	for msg := range cm.broadcastQueue {
		for _, wsConn := range cm.connectionsSnapshot() {
			cm.Send(wsConn, msg)
		}
	}
} // }}}

// connectionsSnapshot returns the currently registered connections. The map is
// copied under connSync so callers can iterate and perform blocking writes
// without holding the lock or racing with NewConnection/Prune.
func (cm *ConnectionManager) connectionsSnapshot() []*WsConnection { // {{{
	cm.connSync.Lock()
	defer cm.connSync.Unlock()

	conns := make([]*WsConnection, 0, len(cm.connections))
	for _, wsConn := range cm.connections {
		conns = append(conns, wsConn)
	}
	return conns
} // }}}

// vim: foldmethod=marker