diff --git a/js_library.go b/js_library.go new file mode 100644 index 0000000..a02bda2 --- /dev/null +++ b/js_library.go @@ -0,0 +1,22 @@ +package webservice + +import ( + // Standard + "embed" + "fmt" + "net/http" +) + +var ( + //go:embed static/js + embedded embed.FS +) + +func staticJSWebsocket(w http.ResponseWriter, r *http.Request) { + contents, err := embedded.ReadFile("static/js/websocket.js") + if err != nil { + fmt.Println(err) + return + } + w.Write(contents) +} diff --git a/pkg.go b/pkg.go index 4eebce1..d87d33f 100644 --- a/pkg.go +++ b/pkg.go @@ -33,9 +33,9 @@ package webservice import ( // Internal "git.gibonuddevalla.se/go/webservice/config" - "git.gibonuddevalla.se/go/webservice/ws_conn_manager" "git.gibonuddevalla.se/go/webservice/database" "git.gibonuddevalla.se/go/webservice/session" + "git.gibonuddevalla.se/go/webservice/ws_conn_manager" // Standard "embed" @@ -62,11 +62,11 @@ type ServiceError struct { } type Service struct { - logger *slog.Logger - sessions map[string]*session.T - config config.Config - Db *database.T - Version string + logger *slog.Logger + sessions map[string]*session.T + config config.Config + Db *database.T + Version string WsConnectionManager ws_conn_manager.ConnectionManager errorHandler ErrorHandler @@ -82,27 +82,32 @@ type Service struct { type ServiceHandler func(http.ResponseWriter, *http.Request, *session.T) -func New(configFilename, version string) (service *Service, err error) { // {{{ +func New(configFilename, version string, logger *slog.Logger) (service *Service, err error) { // {{{ service = new(Service) service.config, err = config.New(configFilename) if err != nil { return } + logger.Debug("config", "config", service.config) - opts := slog.HandlerOptions{} service.Version = version - service.logger = slog.New(slog.NewJSONHandler(os.Stdout, &opts)) + service.logger = logger service.sessions = make(map[string]*session.T, 128) service.errorHandler = service.defaultErrorHandler service.authenticationHandler = service.defaultAuthenticationHandler service.authorizationHandler = service.defaultAuthorizationHandler - service.WsConnectionManager = ws_conn_manager.NewConnectionManager(service.logger) + service.WsConnectionManager = ws_conn_manager.NewConnectionManager(service.logger, service.config.Websocket.Domains) service.Register("/_session/new", false, false, service.sessionNew) service.Register("/_session/authenticate", true, false, service.sessionAuthenticate) service.Register("/_session/retrieve", true, false, service.sessionRetrieve) + http.HandleFunc("/_ws", service.websocketHandler) + http.HandleFunc("/_ws/css_update", service.cssUpdateHandler) + + http.HandleFunc("/_js/websocket.js", staticJSWebsocket) + return } // }}} @@ -251,6 +256,8 @@ func (service *Service) Start() (err error) { // {{{ } } + go service.WsConnectionManager.BroadcastLoop() + listen := fmt.Sprintf("%s:%d", service.config.Network.Address, service.config.Network.Port) service.logger.Info("webserver", "listen", listen) err = http.ListenAndServe(listen, nil) @@ -299,7 +306,8 @@ func (service *Service) StaticHandler(w http.ResponseWriter, r *http.Request, se w.Write([]byte(err.Error())) } } // }}} -func (service *Service) WebsocketHandler(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ +func (service *Service) websocketHandler(w http.ResponseWriter, r *http.Request) { // {{{ + service.logger.Debug("websocket", "op", "connect") var err error _, err = service.WsConnectionManager.NewConnection(w, r) @@ -308,6 +316,17 @@ func (service *Service) WebsocketHandler(w http.ResponseWriter, r *http.Request, return } } // }}} +func (service *Service) cssUpdateHandler(w http.ResponseWriter, r *http.Request) { // {{{ + service.logger.Debug("websocket", "css", "updated") + service.WsConnectionManager.Broadcast(struct { + OK bool + ID string + Op string + }{ + OK: true, + Op: "css_reload", + }) +} // }}} func (service *Service) newTemplate(requestPath string) (tmpl *template.Template, err error) { // {{{ // Append index.html if needed for further reading of the file diff --git a/static/js/websocket.js b/static/js/websocket.js new file mode 100644 index 0000000..b3e92d7 --- /dev/null +++ b/static/js/websocket.js @@ -0,0 +1,22 @@ +package webservice + +import ( + // Standard + "embed" + "fmt" + "net/http" +) + +var ( + //go:embed foo.txt + embedded embed.FS +) + +func staticJSWebsocket(w http.ResponseWriter, r *http.Request) { + contents, err := embedded.ReadFile("foo.txt") + if err != nil { + fmt.Println(err) + return + } + w.Write(contents) +} diff --git a/ws_conn_manager/pkg.go b/ws_conn_manager/pkg.go index e366d52..f41aab6 100644 --- a/ws_conn_manager/pkg.go +++ b/ws_conn_manager/pkg.go @@ -8,6 +8,8 @@ import ( // Standard "log/slog" "net/http" + "strings" + "slices" ) type WsConnection struct { @@ -21,39 +23,19 @@ type ConnectionManager struct { broadcastQueue chan interface{} sendQueue chan SendRequest logger *slog.Logger + domains []string } type SendRequest struct { WsConn *WsConnection Msg interface{} } -func validateOrigin(r *http.Request) bool { // {{{ - /* - host := r.Header.Get("X-Forwarded-Host") - if host == "" { - components := strings.Split(r.Host, ":") - host = components[0] - } - */ - return true -} // }}} - -var ( - upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - - // CheckOrigin is to match DOMAIN constant. - // Use X-Forwarded-Server if behind proxy. - CheckOrigin: validateOrigin, - } -) - -func NewConnectionManager(logger *slog.Logger) (cm ConnectionManager) { // {{{ +func NewConnectionManager(logger *slog.Logger, domains []string) (cm ConnectionManager) { // {{{ cm.connections = make(map[string]*WsConnection) cm.sendQueue = make(chan SendRequest, 65536) cm.broadcastQueue = make(chan interface{}, 65536) cm.logger = logger + cm.domains = domains return } // }}} @@ -65,6 +47,16 @@ func (cm *ConnectionManager) NewConnection(w http.ResponseWriter, r *http.Reques UUID: uuid.NewString(), ConnectionManager: cm, } + + upgrader := websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + + // CheckOrigin is to match DOMAIN constant. + // Use X-Forwarded-Server if behind proxy. + CheckOrigin: cm.validateOrigin, + } + wsConn.Conn, err = upgrader.Upgrade(w, r, nil) if err != nil { return nil, err @@ -81,6 +73,17 @@ func (cm *ConnectionManager) NewConnection(w http.ResponseWriter, r *http.Reques return &wsConn, nil } // }}} +// validateOrigin matches host from X-Forwarded-Host or request host against a list of configured domains. +func (cm *ConnectionManager) validateOrigin(r *http.Request) bool { // {{{ + host := r.Header.Get("X-Forwarded-Host") + if host == "" { + components := strings.Split(r.Host, ":") + host = components[0] + } + cm.logger.Debug("websocket", "op", "new connection", "allowed", cm.domains, "host", host) + return slices.Contains(cm.domains, host) +} // }}} + // Prune closes an deletes connections. If this happened to be non-fatal, the // user will just have to reconnect. func (cm *ConnectionManager) Prune(wsConn *WsConnection, err error) { // {{{