From 447aec742c963bcb7c3e9945dca58ddb28f4dc7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Magnus=20=C3=85hall?= Date: Fri, 5 Jan 2024 19:59:18 +0100 Subject: [PATCH] Added static handler and websocket connection manager --- pkg.go | 128 +++++++++++++++++++++++++++++++++++++++- session.go | 18 +++++- ws_conn_manager/pkg.go | 131 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 272 insertions(+), 5 deletions(-) create mode 100644 ws_conn_manager/pkg.go diff --git a/pkg.go b/pkg.go index 6941b64..ce321d7 100644 --- a/pkg.go +++ b/pkg.go @@ -33,16 +33,21 @@ package webservice import ( // Internal "git.gibonuddevalla.se/go/webservice/config" + "git.gibonuddevalla.se/go/webservice/ws_conn_manager" "git.gibonuddevalla.se/go/webservice/database" "git.gibonuddevalla.se/go/webservice/session" // Standard + "embed" "encoding/json" "errors" "fmt" + "html/template" + "io/fs" "log/slog" "net/http" "os" + "regexp" ) const VERSION = "v0.1.0" @@ -60,15 +65,23 @@ type Service struct { sessions map[string]*session.T config config.Config Db *database.T + Version string + WsConnectionManager ws_conn_manager.ConnectionManager errorHandler ErrorHandler authenticationHandler AuthenticationHandler authorizationHandler AuthorizationHandler + + staticSubFs fs.FS + useStaticDirectory bool + staticDirectory string + staticEmbeddedFileserver http.Handler + staticLocalFileserver http.Handler } type ServiceHandler func(http.ResponseWriter, *http.Request, *session.T) -func New(configFilename string) (service *Service, err error) { // {{{ +func New(configFilename, version string) (service *Service, err error) { // {{{ service = new(Service) service.config, err = config.New(configFilename) @@ -77,14 +90,17 @@ func New(configFilename string) (service *Service, err error) { // {{{ } opts := slog.HandlerOptions{} + service.Version = version service.logger = slog.New(slog.NewJSONHandler(os.Stdout, &opts)) service.sessions = make(map[string]*session.T, 128) service.errorHandler = service.defaultErrorHandler service.authenticationHandler = service.defaultAuthenticationHandler service.authorizationHandler = service.defaultAuthorizationHandler + service.WsConnectionManager = ws_conn_manager.NewConnectionManager(service.logger) service.Register("/_session/new", false, false, service.sessionNew) service.Register("/_session/authenticate", true, false, service.sessionAuthenticate) + service.Register("/_session/retrieve", true, false, service.sessionRetrieve) return } // }}} @@ -95,7 +111,8 @@ func (service *Service) defaultAuthenticationHandler(req AuthenticationRequest, return } // }}} func (service *Service) defaultAuthorizationHandler(sess *session.T, r *http.Request) (resp bool, err error) { // {{{ - service.logger.Error("webservice", "op", "authorization", "session", sess.UUID, "request", r, "authorized", false) + resp = true + service.logger.Error("webservice", "op", "authorization", "session", sess.UUID, "request", r, "authorized", resp) return } // }}} func (service *Service) defaultErrorHandler(err error, w http.ResponseWriter) { // {{{ @@ -119,6 +136,19 @@ func (service *Service) SetAuthenticationHandler(h AuthenticationHandler) { // { func (service *Service) SetAuthorizationHandler(h AuthorizationHandler) { // {{{ service.authorizationHandler = h } // }}} +func (service *Service) SetStaticFS(staticFS embed.FS, directory string) (err error) { // {{{ + service.staticSubFs, err = fs.Sub(staticFS, directory) + if err != nil { + return + } + service.staticEmbeddedFileserver = http.FileServer(http.FS(service.staticSubFs)) + return +} // }}} +func (service *Service) SetStaticDirectory(directory string, useDirectory bool) { // {{{ + service.useStaticDirectory = useDirectory + service.staticDirectory = directory + service.staticLocalFileserver = http.FileServer(http.Dir(directory)) +} // }}} func (service *Service) SetDatabase(sqlProv database.SqlProvider) { // {{{ service.Db = database.New(service.config.Database) @@ -144,7 +174,7 @@ func (service *Service) Register(path string, requireSession, requireAuthenticat return } - session, found = service.sessionRetrieve(headerSessionUUID) + session, found = service.retrieveSession(headerSessionUUID) if !found { service.errorHandler(fmt.Errorf("Session '%s' not found", headerSessionUUID), w) return @@ -194,6 +224,23 @@ func (service *Service) CreateUser(username, password, name string) (err error) err = service.Db.CreateUser(username, password, name) return } // }}} +func (service *Service) CreateUserPrompt() { // {{{ + var err error + var username, name, password string + + fmt.Printf("Username: ") + fmt.Scanln(&username) + fmt.Printf("Name: ") + fmt.Scanln(&name) + fmt.Printf("Password: ") + fmt.Scanln(&password) + + err = service.CreateUser(username, password, name) + if err != nil { + service.logger.Error("application", "error", err) + os.Exit(1) + } +} // }}} func (service *Service) Start() (err error) { // {{{ if service.Db != nil { err = service.InitDatabaseConnection() @@ -208,6 +255,81 @@ func (service *Service) Start() (err error) { // {{{ return } // }}} +func (service *Service) StaticHandler(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ + var err error + + data := struct{ VERSION string }{service.Version} + + // URLs with pattern /(css|images)/v1.0.0/foobar are stripped of the version. + // To get rid of problems with cached content in browser on a new version release, + // while also not disabling cache altogether. + if r.URL.Path == "/favicon.ico" { + service.staticEmbeddedFileserver.ServeHTTP(w, r) + return + } + + rxp := regexp.MustCompile("^/(css|images|js|fonts)/v[0-9]+/(.*)$") + if comp := rxp.FindStringSubmatch(r.URL.Path); comp != nil { + r.URL.Path = fmt.Sprintf("/%s/%s", comp[1], comp[2]) + p := fmt.Sprintf(service.staticDirectory+"/%s/%s", comp[1], comp[2]) + _, err = os.Stat(p) + if err == nil { + service.staticLocalFileserver.ServeHTTP(w, r) + } else { + service.staticEmbeddedFileserver.ServeHTTP(w, r) + } + return + } + + // Everything else is run through the template system. + // For now to get VERSION into files to fix caching. + //log.Printf("template: %s", r.URL.Path) + tmpl, err := service.newTemplate(r.URL.Path) + if err != nil { + if os.IsNotExist(err) { + w.WriteHeader(404) + } + w.Write([]byte(err.Error())) + return + } + + if err = tmpl.Execute(w, data); err != nil { + w.Write([]byte(err.Error())) + } +} // }}} +func (service *Service) WebsocketHandler(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{ + var err error + + _, err = service.WsConnectionManager.NewConnection(w, r) + if err != nil { + service.logger.Error("websocket", "error", err) + return + } +} // }}} + +func (service *Service) newTemplate(requestPath string) (tmpl *template.Template, err error) { // {{{ + // Append index.html if needed for further reading of the file + p := requestPath + if p[len(p)-1] == '/' { + p += "index.html" + } + + if p[0:1] == "/" { + p = p[1:] + } + + // Try local disk files for faster testing + if service.useStaticDirectory { + _, err = os.Stat(service.staticDirectory + "/" + p) + if err == nil { + tmpl, err = template.ParseFiles(service.staticDirectory + "/" + p) + return + } + } + + tmpl, err = template.ParseFS(service.staticSubFs, p) + return +} // }}} func sessionUUID(r *http.Request) (string, error) { // {{{ headers := r.Header["X-Session-Id"] if len(headers) > 0 { diff --git a/session.go b/session.go index 49ce467..08365bf 100644 --- a/session.go +++ b/session.go @@ -21,6 +21,7 @@ type AuthenticationRequest struct { type AuthenticationResponse struct { Authenticated bool + UserID int } type AuthenticationHandler func(AuthenticationRequest, bool) (AuthenticationResponse, error) @@ -44,7 +45,7 @@ func (service *Service) sessionNew(w http.ResponseWriter, r *http.Request, sess break } else { - if _, found = service.sessionRetrieve(session.UUID); found { + if _, found = service.retrieveSession(session.UUID); found { continue } @@ -112,13 +113,26 @@ func (service *Service) sessionAuthenticate(w http.ResponseWriter, r *http.Reque service.errorHandler(err, w) return } + authResponse.UserID = sess.UserID sess.Authenticated = authResponse.Authenticated authResp, _ := json.Marshal(authResponse) w.Write(authResp) } // }}} -func (service *Service) sessionRetrieve(uuid string) (session *session.T, found bool) { // {{{ +func (service *Service) sessionRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) {// {{{ + response := struct { + OK bool + Session *session.T + }{ + true, + sess, + } + out, _ := json.Marshal(response) + w.Write(out) +}// }}} + +func (service *Service) retrieveSession(uuid string) (session *session.T, found bool) { // {{{ var err error if service.Db == nil { diff --git a/ws_conn_manager/pkg.go b/ws_conn_manager/pkg.go new file mode 100644 index 0000000..e366d52 --- /dev/null +++ b/ws_conn_manager/pkg.go @@ -0,0 +1,131 @@ +package ws_conn_manager + +import ( + // External + "github.com/google/uuid" + "github.com/gorilla/websocket" + + // Standard + "log/slog" + "net/http" +) + +type WsConnection struct { + ConnectionManager *ConnectionManager + UUID string + Conn *websocket.Conn + Pruned bool +} +type ConnectionManager struct { + connections map[string]*WsConnection + broadcastQueue chan interface{} + sendQueue chan SendRequest + logger *slog.Logger +} +type SendRequest struct { + WsConn *WsConnection + Msg interface{} +} + +func validateOrigin(r *http.Request) bool { // {{{ + /* + host := r.Header.Get("X-Forwarded-Host") + if host == "" { + components := strings.Split(r.Host, ":") + host = components[0] + } + */ + return true +} // }}} + +var ( + upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + + // CheckOrigin is to match DOMAIN constant. + // Use X-Forwarded-Server if behind proxy. + CheckOrigin: validateOrigin, + } +) + +func NewConnectionManager(logger *slog.Logger) (cm ConnectionManager) { // {{{ + cm.connections = make(map[string]*WsConnection) + cm.sendQueue = make(chan SendRequest, 65536) + cm.broadcastQueue = make(chan interface{}, 65536) + cm.logger = logger + return +} // }}} + +// NewConnection creates a new connection, which is assigned a UUIDv4 for +// identification. This is then put into the connection collection. +func (cm *ConnectionManager) NewConnection(w http.ResponseWriter, r *http.Request) (*WsConnection, error) { // {{{ + var err error + wsConn := WsConnection{ + UUID: uuid.NewString(), + ConnectionManager: cm, + } + wsConn.Conn, err = upgrader.Upgrade(w, r, nil) + if err != nil { + return nil, err + } + + // Keep track of all connections. + cm.connections[wsConn.UUID] = &wsConn + + // Successfully upgraded to a websocket connection. + cm.logger.Info("websocket", "uuid", wsConn.UUID, "remote_addr", r.RemoteAddr) + + go cm.ReadLoop(&wsConn) + + return &wsConn, nil +} // }}} + +// Prune closes an deletes connections. If this happened to be non-fatal, the +// user will just have to reconnect. +func (cm *ConnectionManager) Prune(wsConn *WsConnection, err error) { // {{{ + cm.logger.Info("websocket", "op", "prune", "uuid", wsConn.UUID) + wsConn.Conn.Close() + wsConn.Pruned = true + delete(cm.connections, wsConn.UUID) +} // }}} +func (cm *ConnectionManager) ReadLoop(wsConn *WsConnection) { // {{{ + var data []byte + var ok bool + + for { + if data, ok = cm.Read(wsConn); !ok { + break + } + + cm.logger.Debug("websocket", "op", "read", "data", data) + //cm.Send(wsConn, response) + } +} // }}} +func (cm *ConnectionManager) Read(wsConn *WsConnection) ([]byte, bool) { // {{{ + var err error + var requestData []byte + _, requestData, err = wsConn.Conn.ReadMessage() + if err != nil { + cm.Prune(wsConn, err) + return nil, false + } + return requestData, true +} // }}} +func (cm *ConnectionManager) Send(wsConn *WsConnection, msg interface{}) { // {{{ + wsConn.Conn.WriteJSON(msg) +} // }}} +func (cm *ConnectionManager) Broadcast(msg interface{}) { // {{{ + cm.broadcastQueue <- msg +} // }}} + +func (cm *ConnectionManager) BroadcastLoop() { // {{{ + for { + msg := <-cm.broadcastQueue + for _, wsConn := range cm.connections { + cm.Send(wsConn, msg) + } + } +} // }}} + +// vim: foldmethod=marker