// Package ws_conn_manager keeps track of upgraded WebSocket connections and
// provides helpers to read from, write to and broadcast to them.
package ws_conn_manager

import (
	// Standard
	"log/slog"
	"net"
	"net/http"
	"slices"
	"sync"

	// External
	"github.com/google/uuid"
	"github.com/gorilla/websocket"
)

// WsConnection is a single upgraded WebSocket connection tracked by a
// ConnectionManager.
type WsConnection struct {
	ConnectionManager *ConnectionManager
	UUID              string
	Conn              *websocket.Conn
	// Pruned is set by Prune once the connection has been closed and removed
	// from its manager.
	Pruned bool

	// writeMu serializes writes on Conn: gorilla/websocket allows at most one
	// concurrent writer per connection, and Send is exported and also called
	// from BroadcastLoop.
	writeMu sync.Mutex
}

// ConnectionManager owns the set of live connections. Create it with
// NewConnectionManager; the zero value is not usable.
type ConnectionManager struct {
	// mu guards connections. It is a pointer so that copies of the value
	// returned by NewConnectionManager share the lock, just as they share the
	// map.
	mu             *sync.RWMutex
	connections    map[string]*WsConnection
	broadcastQueue chan any
	// NOTE(review): sendQueue is created but not consumed anywhere in this file.
	sendQueue chan SendRequest
	logger    *slog.Logger
	domains   []string
}

// SendRequest pairs a message with the connection it is addressed to.
type SendRequest struct {
	WsConn *WsConnection
	Msg    any
}

// NewConnectionManager returns a ConnectionManager that logs to logger and
// accepts upgrade requests whose host is listed in domains. Messages passed to
// Broadcast are only delivered while BroadcastLoop is running.
func NewConnectionManager(logger *slog.Logger, domains []string) (cm ConnectionManager) { // {{{
	cm.mu = new(sync.RWMutex)
	cm.connections = make(map[string]*WsConnection)
	cm.sendQueue = make(chan SendRequest, 65536)
	cm.broadcastQueue = make(chan any, 65536)
	cm.logger = logger
	cm.domains = domains
	return
} // }}}

// NewConnection upgrades the request to a WebSocket connection, assigns it a
// random UUIDv4 for identification, registers it with the manager and starts a
// goroutine running ReadLoop for it. The upgrade is only allowed when
// validateOrigin accepts the request. On upgrade failure the error is returned
// (per the gorilla/websocket Upgrade docs, an HTTP error response has already
// been written to w).
func (cm *ConnectionManager) NewConnection(w http.ResponseWriter, r *http.Request) (*WsConnection, error) { // {{{
	upgrader := websocket.Upgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
		// CheckOrigin matches the request host (X-Forwarded-Host when behind
		// a proxy) against the configured domains.
		CheckOrigin: cm.validateOrigin,
	}

	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		return nil, err
	}

	wsConn := &WsConnection{
		ConnectionManager: cm,
		UUID:              uuid.NewString(),
		Conn:              conn,
	}

	// Keep track of all connections. HTTP handlers, read loops and the
	// broadcast loop all touch the map concurrently, hence the lock.
	cm.mu.Lock()
	cm.connections[wsConn.UUID] = wsConn
	cm.mu.Unlock()

	// Successfully upgraded to a websocket connection.
	cm.logger.Info("websocket", "uuid", wsConn.UUID, "remote_addr", r.RemoteAddr)

	go cm.ReadLoop(wsConn)
	return wsConn, nil
} // }}}

// validateOrigin is the upgrader's CheckOrigin. It takes the host from the
// X-Forwarded-Host header or, when that is empty, from the request Host with
// any port stripped, and accepts the request only if that host is one of
// cm.domains.
//
// NOTE(review): despite its name this never inspects the Origin header, and
// both Host and X-Forwarded-Host are client-controlled unless a trusted proxy
// overwrites them. It therefore does not give the cross-site WebSocket
// hijacking protection that gorilla's default CheckOrigin provides. Confirm
// whether the Origin header should also be matched against cm.domains.
func (cm *ConnectionManager) validateOrigin(r *http.Request) bool { // {{{
	host := r.Header.Get("X-Forwarded-Host")
	if host == "" {
		host = r.Host
		// SplitHostPort handles bracketed IPv6 literals ("[::1]:8080"), which
		// naive splitting on ":" did not. It errors when there is no port, in
		// which case the host is used unchanged.
		if h, _, err := net.SplitHostPort(host); err == nil {
			host = h
		}
	}
	cm.logger.Debug("websocket", "op", "new connection", "allowed", cm.domains, "host", host)
	return slices.Contains(cm.domains, host)
} // }}}

// Prune closes wsConn, marks it Pruned and removes it from the manager. err is
// the reason for pruning and is only logged. If the cause was non-fatal, the
// user will just have to reconnect.
func (cm *ConnectionManager) Prune(wsConn *WsConnection, err error) { // {{{
	cm.logger.Info("websocket", "op", "prune", "uuid", wsConn.UUID, "error", err)
	// Best effort: the connection is discarded whether or not Close succeeds.
	_ = wsConn.Conn.Close()

	cm.mu.Lock()
	wsConn.Pruned = true
	delete(cm.connections, wsConn.UUID)
	cm.mu.Unlock()
} // }}}

// ReadLoop reads messages from wsConn until Read reports a failure (Read has
// already pruned the connection by then). Received messages are currently
// only logged at debug level.
func (cm *ConnectionManager) ReadLoop(wsConn *WsConnection) { // {{{
	for {
		data, ok := cm.Read(wsConn)
		if !ok {
			return
		}
		cm.logger.Debug("websocket", "op", "read", "data", data)
		//cm.Send(wsConn, response)
	}
} // }}}

// Read returns the next message from wsConn and true. On any read error it
// prunes the connection and returns nil, false.
func (cm *ConnectionManager) Read(wsConn *WsConnection) ([]byte, bool) { // {{{
	_, data, err := wsConn.Conn.ReadMessage()
	if err != nil {
		cm.Prune(wsConn, err)
		return nil, false
	}
	return data, true
} // }}}

// Send writes msg to wsConn as JSON. Writes to the same connection are
// serialized through its writeMu. A failed write is logged; Send does not
// prune the connection itself.
func (cm *ConnectionManager) Send(wsConn *WsConnection, msg any) { // {{{
	wsConn.writeMu.Lock()
	err := wsConn.Conn.WriteJSON(msg)
	wsConn.writeMu.Unlock()

	if err != nil {
		cm.logger.Warn("websocket", "op", "send", "uuid", wsConn.UUID, "error", err)
	}
} // }}}

// Broadcast queues msg for delivery to every connection by BroadcastLoop. It
// blocks while the queue (capacity 65536 when created by
// NewConnectionManager) is full.
func (cm *ConnectionManager) Broadcast(msg any) { // {{{
	cm.broadcastQueue <- msg
} // }}}

// BroadcastLoop sends every queued broadcast message to all connections that
// are registered when the message is dequeued. It only returns if the
// broadcast queue is closed (nothing in this file closes it), so it is meant
// to run in its own goroutine.
func (cm *ConnectionManager) BroadcastLoop() { // {{{
	for msg := range cm.broadcastQueue {
		// Iterate over a snapshot so the lock is not held across (possibly
		// slow) network writes and NewConnection/Prune are not blocked.
		for _, wsConn := range cm.liveConnections() {
			cm.Send(wsConn, msg)
		}
	}
} // }}}

// liveConnections returns a snapshot of the currently registered connections.
func (cm *ConnectionManager) liveConnections() []*WsConnection { // {{{
	cm.mu.RLock()
	defer cm.mu.RUnlock()

	conns := make([]*WsConnection, 0, len(cm.connections))
	for _, wsConn := range cm.connections {
		conns = append(conns, wsConn)
	}
	return conns
} // }}}

// vim: foldmethod=marker