Added websocket code

This commit is contained in:
Magnus Åhall 2024-01-07 11:58:06 +01:00
parent c3571a7fdc
commit 2938da6d5a
4 changed files with 100 additions and 34 deletions

22
js_library.go Normal file
View File

@ -0,0 +1,22 @@
package webservice
import (
// Standard
"embed"
"fmt"
"net/http"
)
var (
//go:embed static/js
embedded embed.FS
)
func staticJSWebsocket(w http.ResponseWriter, r *http.Request) {
contents, err := embedded.ReadFile("static/js/websocket.js")
if err != nil {
fmt.Println(err)
return
}
w.Write(contents)
}

31
pkg.go
View File

@ -33,9 +33,9 @@ package webservice
import ( import (
// Internal // Internal
"git.gibonuddevalla.se/go/webservice/config" "git.gibonuddevalla.se/go/webservice/config"
"git.gibonuddevalla.se/go/webservice/ws_conn_manager"
"git.gibonuddevalla.se/go/webservice/database" "git.gibonuddevalla.se/go/webservice/database"
"git.gibonuddevalla.se/go/webservice/session" "git.gibonuddevalla.se/go/webservice/session"
"git.gibonuddevalla.se/go/webservice/ws_conn_manager"
// Standard // Standard
"embed" "embed"
@ -82,27 +82,32 @@ type Service struct {
type ServiceHandler func(http.ResponseWriter, *http.Request, *session.T) type ServiceHandler func(http.ResponseWriter, *http.Request, *session.T)
func New(configFilename, version string) (service *Service, err error) { // {{{ func New(configFilename, version string, logger *slog.Logger) (service *Service, err error) { // {{{
service = new(Service) service = new(Service)
service.config, err = config.New(configFilename) service.config, err = config.New(configFilename)
if err != nil { if err != nil {
return return
} }
logger.Debug("config", "config", service.config)
opts := slog.HandlerOptions{}
service.Version = version service.Version = version
service.logger = slog.New(slog.NewJSONHandler(os.Stdout, &opts)) service.logger = logger
service.sessions = make(map[string]*session.T, 128) service.sessions = make(map[string]*session.T, 128)
service.errorHandler = service.defaultErrorHandler service.errorHandler = service.defaultErrorHandler
service.authenticationHandler = service.defaultAuthenticationHandler service.authenticationHandler = service.defaultAuthenticationHandler
service.authorizationHandler = service.defaultAuthorizationHandler service.authorizationHandler = service.defaultAuthorizationHandler
service.WsConnectionManager = ws_conn_manager.NewConnectionManager(service.logger) service.WsConnectionManager = ws_conn_manager.NewConnectionManager(service.logger, service.config.Websocket.Domains)
service.Register("/_session/new", false, false, service.sessionNew) service.Register("/_session/new", false, false, service.sessionNew)
service.Register("/_session/authenticate", true, false, service.sessionAuthenticate) service.Register("/_session/authenticate", true, false, service.sessionAuthenticate)
service.Register("/_session/retrieve", true, false, service.sessionRetrieve) service.Register("/_session/retrieve", true, false, service.sessionRetrieve)
http.HandleFunc("/_ws", service.websocketHandler)
http.HandleFunc("/_ws/css_update", service.cssUpdateHandler)
http.HandleFunc("/_js/websocket.js", staticJSWebsocket)
return return
} // }}} } // }}}
@ -251,6 +256,8 @@ func (service *Service) Start() (err error) { // {{{
} }
} }
go service.WsConnectionManager.BroadcastLoop()
listen := fmt.Sprintf("%s:%d", service.config.Network.Address, service.config.Network.Port) listen := fmt.Sprintf("%s:%d", service.config.Network.Address, service.config.Network.Port)
service.logger.Info("webserver", "listen", listen) service.logger.Info("webserver", "listen", listen)
err = http.ListenAndServe(listen, nil) err = http.ListenAndServe(listen, nil)
@ -299,7 +306,8 @@ func (service *Service) StaticHandler(w http.ResponseWriter, r *http.Request, se
w.Write([]byte(err.Error())) w.Write([]byte(err.Error()))
} }
} // }}} } // }}}
func (service *Service) WebsocketHandler(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ func (service *Service) websocketHandler(w http.ResponseWriter, r *http.Request) { // {{{
service.logger.Debug("websocket", "op", "connect")
var err error var err error
_, err = service.WsConnectionManager.NewConnection(w, r) _, err = service.WsConnectionManager.NewConnection(w, r)
@ -308,6 +316,17 @@ func (service *Service) WebsocketHandler(w http.ResponseWriter, r *http.Request,
return return
} }
} // }}} } // }}}
func (service *Service) cssUpdateHandler(w http.ResponseWriter, r *http.Request) { // {{{
service.logger.Debug("websocket", "css", "updated")
service.WsConnectionManager.Broadcast(struct {
OK bool
ID string
Op string
}{
OK: true,
Op: "css_reload",
})
} // }}}
func (service *Service) newTemplate(requestPath string) (tmpl *template.Template, err error) { // {{{ func (service *Service) newTemplate(requestPath string) (tmpl *template.Template, err error) { // {{{
// Append index.html if needed for further reading of the file // Append index.html if needed for further reading of the file

22
static/js/websocket.js Normal file
View File

@ -0,0 +1,22 @@
package webservice
import (
// Standard
"embed"
"fmt"
"net/http"
)
var (
//go:embed foo.txt
embedded embed.FS
)
func staticJSWebsocket(w http.ResponseWriter, r *http.Request) {
contents, err := embedded.ReadFile("foo.txt")
if err != nil {
fmt.Println(err)
return
}
w.Write(contents)
}

View File

@ -8,6 +8,8 @@ import (
// Standard // Standard
"log/slog" "log/slog"
"net/http" "net/http"
"strings"
"slices"
) )
type WsConnection struct { type WsConnection struct {
@ -21,39 +23,19 @@ type ConnectionManager struct {
broadcastQueue chan interface{} broadcastQueue chan interface{}
sendQueue chan SendRequest sendQueue chan SendRequest
logger *slog.Logger logger *slog.Logger
domains []string
} }
type SendRequest struct { type SendRequest struct {
WsConn *WsConnection WsConn *WsConnection
Msg interface{} Msg interface{}
} }
func validateOrigin(r *http.Request) bool { // {{{ func NewConnectionManager(logger *slog.Logger, domains []string) (cm ConnectionManager) { // {{{
/*
host := r.Header.Get("X-Forwarded-Host")
if host == "" {
components := strings.Split(r.Host, ":")
host = components[0]
}
*/
return true
} // }}}
var (
upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
// CheckOrigin is to match DOMAIN constant.
// Use X-Forwarded-Server if behind proxy.
CheckOrigin: validateOrigin,
}
)
func NewConnectionManager(logger *slog.Logger) (cm ConnectionManager) { // {{{
cm.connections = make(map[string]*WsConnection) cm.connections = make(map[string]*WsConnection)
cm.sendQueue = make(chan SendRequest, 65536) cm.sendQueue = make(chan SendRequest, 65536)
cm.broadcastQueue = make(chan interface{}, 65536) cm.broadcastQueue = make(chan interface{}, 65536)
cm.logger = logger cm.logger = logger
cm.domains = domains
return return
} // }}} } // }}}
@ -65,6 +47,16 @@ func (cm *ConnectionManager) NewConnection(w http.ResponseWriter, r *http.Reques
UUID: uuid.NewString(), UUID: uuid.NewString(),
ConnectionManager: cm, ConnectionManager: cm,
} }
upgrader := websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
// CheckOrigin is to match DOMAIN constant.
// Use X-Forwarded-Server if behind proxy.
CheckOrigin: cm.validateOrigin,
}
wsConn.Conn, err = upgrader.Upgrade(w, r, nil) wsConn.Conn, err = upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
return nil, err return nil, err
@ -81,6 +73,17 @@ func (cm *ConnectionManager) NewConnection(w http.ResponseWriter, r *http.Reques
return &wsConn, nil return &wsConn, nil
} // }}} } // }}}
// validateOrigin matches host from X-Forwarded-Host or request host against a list of configured domains.
func (cm *ConnectionManager) validateOrigin(r *http.Request) bool { // {{{
host := r.Header.Get("X-Forwarded-Host")
if host == "" {
components := strings.Split(r.Host, ":")
host = components[0]
}
cm.logger.Debug("websocket", "op", "new connection", "allowed", cm.domains, "host", host)
return slices.Contains(cm.domains, host)
} // }}}
// Prune closes an deletes connections. If this happened to be non-fatal, the // Prune closes an deletes connections. If this happened to be non-fatal, the
// user will just have to reconnect. // user will just have to reconnect.
func (cm *ConnectionManager) Prune(wsConn *WsConnection, err error) { // {{{ func (cm *ConnectionManager) Prune(wsConn *WsConnection, err error) { // {{{