Added static handler and websocket connection manager
This commit is contained in:
		
							parent
							
								
									61a36b87bb
								
							
						
					
					
						commit
						447aec742c
					
				
					 3 changed files with 272 additions and 5 deletions
				
			
		
							
								
								
									
										128
									
								
								pkg.go
									
										
									
									
									
								
							
							
						
						
									
										128
									
								
								pkg.go
									
										
									
									
									
								
							| 
						 | 
				
			
			@ -33,16 +33,21 @@ package webservice
 | 
			
		|||
import (
 | 
			
		||||
	// Internal
 | 
			
		||||
	"git.gibonuddevalla.se/go/webservice/config"
 | 
			
		||||
	"git.gibonuddevalla.se/go/webservice/ws_conn_manager"
 | 
			
		||||
	"git.gibonuddevalla.se/go/webservice/database"
 | 
			
		||||
	"git.gibonuddevalla.se/go/webservice/session"
 | 
			
		||||
 | 
			
		||||
	// Standard
 | 
			
		||||
	"embed"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"html/template"
 | 
			
		||||
	"io/fs"
 | 
			
		||||
	"log/slog"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
	"regexp"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const VERSION = "v0.1.0"
 | 
			
		||||
| 
						 | 
				
			
			@ -60,15 +65,23 @@ type Service struct {
 | 
			
		|||
	sessions map[string]*session.T
 | 
			
		||||
	config   config.Config
 | 
			
		||||
	Db       *database.T
 | 
			
		||||
	Version  string
 | 
			
		||||
	WsConnectionManager ws_conn_manager.ConnectionManager
 | 
			
		||||
 | 
			
		||||
	errorHandler          ErrorHandler
 | 
			
		||||
	authenticationHandler AuthenticationHandler
 | 
			
		||||
	authorizationHandler  AuthorizationHandler
 | 
			
		||||
 | 
			
		||||
	staticSubFs              fs.FS
 | 
			
		||||
	useStaticDirectory       bool
 | 
			
		||||
	staticDirectory          string
 | 
			
		||||
	staticEmbeddedFileserver http.Handler
 | 
			
		||||
	staticLocalFileserver    http.Handler
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ServiceHandler func(http.ResponseWriter, *http.Request, *session.T)
 | 
			
		||||
 | 
			
		||||
func New(configFilename string) (service *Service, err error) { // {{{
 | 
			
		||||
func New(configFilename, version string) (service *Service, err error) { // {{{
 | 
			
		||||
	service = new(Service)
 | 
			
		||||
 | 
			
		||||
	service.config, err = config.New(configFilename)
 | 
			
		||||
| 
						 | 
				
			
			@ -77,14 +90,17 @@ func New(configFilename string) (service *Service, err error) { // {{{
 | 
			
		|||
	}
 | 
			
		||||
 | 
			
		||||
	opts := slog.HandlerOptions{}
 | 
			
		||||
	service.Version = version
 | 
			
		||||
	service.logger = slog.New(slog.NewJSONHandler(os.Stdout, &opts))
 | 
			
		||||
	service.sessions = make(map[string]*session.T, 128)
 | 
			
		||||
	service.errorHandler = service.defaultErrorHandler
 | 
			
		||||
	service.authenticationHandler = service.defaultAuthenticationHandler
 | 
			
		||||
	service.authorizationHandler = service.defaultAuthorizationHandler
 | 
			
		||||
	service.WsConnectionManager = ws_conn_manager.NewConnectionManager(service.logger)
 | 
			
		||||
 | 
			
		||||
	service.Register("/_session/new", false, false, service.sessionNew)
 | 
			
		||||
	service.Register("/_session/authenticate", true, false, service.sessionAuthenticate)
 | 
			
		||||
	service.Register("/_session/retrieve", true, false, service.sessionRetrieve)
 | 
			
		||||
 | 
			
		||||
	return
 | 
			
		||||
} // }}}
 | 
			
		||||
| 
						 | 
				
			
			@ -95,7 +111,8 @@ func (service *Service) defaultAuthenticationHandler(req AuthenticationRequest,
 | 
			
		|||
	return
 | 
			
		||||
} // }}}
 | 
			
		||||
func (service *Service) defaultAuthorizationHandler(sess *session.T, r *http.Request) (resp bool, err error) { // {{{
 | 
			
		||||
	service.logger.Error("webservice", "op", "authorization", "session", sess.UUID, "request", r, "authorized", false)
 | 
			
		||||
	resp = true
 | 
			
		||||
	service.logger.Error("webservice", "op", "authorization", "session", sess.UUID, "request", r, "authorized", resp)
 | 
			
		||||
	return
 | 
			
		||||
} // }}}
 | 
			
		||||
func (service *Service) defaultErrorHandler(err error, w http.ResponseWriter) { // {{{
 | 
			
		||||
| 
						 | 
				
			
			@ -119,6 +136,19 @@ func (service *Service) SetAuthenticationHandler(h AuthenticationHandler) { // {
 | 
			
		|||
func (service *Service) SetAuthorizationHandler(h AuthorizationHandler) { // {{{
 | 
			
		||||
	service.authorizationHandler = h
 | 
			
		||||
} // }}}
 | 
			
		||||
func (service *Service) SetStaticFS(staticFS embed.FS, directory string) (err error) { // {{{
 | 
			
		||||
	service.staticSubFs, err = fs.Sub(staticFS, directory)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	service.staticEmbeddedFileserver = http.FileServer(http.FS(service.staticSubFs))
 | 
			
		||||
	return
 | 
			
		||||
} // }}}
 | 
			
		||||
func (service *Service) SetStaticDirectory(directory string, useDirectory bool) { // {{{
 | 
			
		||||
	service.useStaticDirectory = useDirectory
 | 
			
		||||
	service.staticDirectory = directory
 | 
			
		||||
	service.staticLocalFileserver = http.FileServer(http.Dir(directory))
 | 
			
		||||
} // }}}
 | 
			
		||||
 | 
			
		||||
func (service *Service) SetDatabase(sqlProv database.SqlProvider) { // {{{
 | 
			
		||||
	service.Db = database.New(service.config.Database)
 | 
			
		||||
| 
						 | 
				
			
			@ -144,7 +174,7 @@ func (service *Service) Register(path string, requireSession, requireAuthenticat
 | 
			
		|||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			session, found = service.sessionRetrieve(headerSessionUUID)
 | 
			
		||||
			session, found = service.retrieveSession(headerSessionUUID)
 | 
			
		||||
			if !found {
 | 
			
		||||
				service.errorHandler(fmt.Errorf("Session '%s' not found", headerSessionUUID), w)
 | 
			
		||||
				return
 | 
			
		||||
| 
						 | 
				
			
			@ -194,6 +224,23 @@ func (service *Service) CreateUser(username, password, name string) (err error)
 | 
			
		|||
	err = service.Db.CreateUser(username, password, name)
 | 
			
		||||
	return
 | 
			
		||||
} // }}}
 | 
			
		||||
func (service *Service) CreateUserPrompt() { // {{{
 | 
			
		||||
	var err error
 | 
			
		||||
	var username, name, password string
 | 
			
		||||
 | 
			
		||||
	fmt.Printf("Username: ")
 | 
			
		||||
	fmt.Scanln(&username)
 | 
			
		||||
	fmt.Printf("Name: ")
 | 
			
		||||
	fmt.Scanln(&name)
 | 
			
		||||
	fmt.Printf("Password: ")
 | 
			
		||||
	fmt.Scanln(&password)
 | 
			
		||||
 | 
			
		||||
	err = service.CreateUser(username, password, name)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		service.logger.Error("application", "error", err)
 | 
			
		||||
		os.Exit(1)
 | 
			
		||||
	}
 | 
			
		||||
} // }}}
 | 
			
		||||
func (service *Service) Start() (err error) { // {{{
 | 
			
		||||
	if service.Db != nil {
 | 
			
		||||
		err = service.InitDatabaseConnection()
 | 
			
		||||
| 
						 | 
				
			
			@ -208,6 +255,81 @@ func (service *Service) Start() (err error) { // {{{
 | 
			
		|||
	return
 | 
			
		||||
} // }}}
 | 
			
		||||
 | 
			
		||||
func (service *Service) StaticHandler(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
	data := struct{ VERSION string }{service.Version}
 | 
			
		||||
 | 
			
		||||
	// URLs with pattern /(css|images)/v1.0.0/foobar are stripped of the version.
 | 
			
		||||
	// To get rid of problems with cached content in browser on a new version release,
 | 
			
		||||
	// while also not disabling cache altogether.
 | 
			
		||||
	if r.URL.Path == "/favicon.ico" {
 | 
			
		||||
		service.staticEmbeddedFileserver.ServeHTTP(w, r)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rxp := regexp.MustCompile("^/(css|images|js|fonts)/v[0-9]+/(.*)$")
 | 
			
		||||
	if comp := rxp.FindStringSubmatch(r.URL.Path); comp != nil {
 | 
			
		||||
		r.URL.Path = fmt.Sprintf("/%s/%s", comp[1], comp[2])
 | 
			
		||||
		p := fmt.Sprintf(service.staticDirectory+"/%s/%s", comp[1], comp[2])
 | 
			
		||||
		_, err = os.Stat(p)
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			service.staticLocalFileserver.ServeHTTP(w, r)
 | 
			
		||||
		} else {
 | 
			
		||||
			service.staticEmbeddedFileserver.ServeHTTP(w, r)
 | 
			
		||||
		}
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Everything else is run through the template system.
 | 
			
		||||
	// For now to get VERSION into files to fix caching.
 | 
			
		||||
	//log.Printf("template: %s", r.URL.Path)
 | 
			
		||||
	tmpl, err := service.newTemplate(r.URL.Path)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if os.IsNotExist(err) {
 | 
			
		||||
			w.WriteHeader(404)
 | 
			
		||||
		}
 | 
			
		||||
		w.Write([]byte(err.Error()))
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err = tmpl.Execute(w, data); err != nil {
 | 
			
		||||
		w.Write([]byte(err.Error()))
 | 
			
		||||
	}
 | 
			
		||||
} // }}}
 | 
			
		||||
func (service *Service) WebsocketHandler(w http.ResponseWriter, r *http.Request, sess *session.T) { // {{{
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
	_, err = service.WsConnectionManager.NewConnection(w, r)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		service.logger.Error("websocket", "error", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
} // }}}
 | 
			
		||||
 | 
			
		||||
func (service *Service) newTemplate(requestPath string) (tmpl *template.Template, err error) { // {{{
 | 
			
		||||
	// Append index.html if needed for further reading of the file
 | 
			
		||||
	p := requestPath
 | 
			
		||||
	if p[len(p)-1] == '/' {
 | 
			
		||||
		p += "index.html"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if p[0:1] == "/" {
 | 
			
		||||
		p = p[1:]
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Try local disk files for faster testing
 | 
			
		||||
	if service.useStaticDirectory {
 | 
			
		||||
		_, err = os.Stat(service.staticDirectory + "/" + p)
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			tmpl, err = template.ParseFiles(service.staticDirectory + "/" + p)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tmpl, err = template.ParseFS(service.staticSubFs, p)
 | 
			
		||||
	return
 | 
			
		||||
} // }}}
 | 
			
		||||
func sessionUUID(r *http.Request) (string, error) { // {{{
 | 
			
		||||
	headers := r.Header["X-Session-Id"]
 | 
			
		||||
	if len(headers) > 0 {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										18
									
								
								session.go
									
										
									
									
									
								
							
							
						
						
									
										18
									
								
								session.go
									
										
									
									
									
								
							| 
						 | 
				
			
			@ -21,6 +21,7 @@ type AuthenticationRequest struct {
 | 
			
		|||
 | 
			
		||||
type AuthenticationResponse struct {
 | 
			
		||||
	Authenticated bool
 | 
			
		||||
	UserID int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AuthenticationHandler func(AuthenticationRequest, bool) (AuthenticationResponse, error)
 | 
			
		||||
| 
						 | 
				
			
			@ -44,7 +45,7 @@ func (service *Service) sessionNew(w http.ResponseWriter, r *http.Request, sess
 | 
			
		|||
			break
 | 
			
		||||
 | 
			
		||||
		} else {
 | 
			
		||||
			if _, found = service.sessionRetrieve(session.UUID); found {
 | 
			
		||||
			if _, found = service.retrieveSession(session.UUID); found {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -112,13 +113,26 @@ func (service *Service) sessionAuthenticate(w http.ResponseWriter, r *http.Reque
 | 
			
		|||
		service.errorHandler(err, w)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	authResponse.UserID = sess.UserID
 | 
			
		||||
 | 
			
		||||
	sess.Authenticated = authResponse.Authenticated
 | 
			
		||||
 | 
			
		||||
	authResp, _ := json.Marshal(authResponse)
 | 
			
		||||
	w.Write(authResp)
 | 
			
		||||
} // }}}
 | 
			
		||||
func (service *Service) sessionRetrieve(uuid string) (session *session.T, found bool) { // {{{
 | 
			
		||||
func (service *Service) sessionRetrieve(w http.ResponseWriter, r *http.Request, sess *session.T) {// {{{
 | 
			
		||||
	response := struct {
 | 
			
		||||
		OK bool
 | 
			
		||||
		Session *session.T
 | 
			
		||||
	}{
 | 
			
		||||
		true,
 | 
			
		||||
		sess,
 | 
			
		||||
	}
 | 
			
		||||
	out, _ := json.Marshal(response)
 | 
			
		||||
	w.Write(out)
 | 
			
		||||
}// }}}
 | 
			
		||||
 | 
			
		||||
func (service *Service) retrieveSession(uuid string) (session *session.T, found bool) { // {{{
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
	if service.Db == nil {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										131
									
								
								ws_conn_manager/pkg.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										131
									
								
								ws_conn_manager/pkg.go
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,131 @@
 | 
			
		|||
package ws_conn_manager
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	// External
 | 
			
		||||
	"github.com/google/uuid"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
 | 
			
		||||
	// Standard
 | 
			
		||||
	"log/slog"
 | 
			
		||||
	"net/http"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type WsConnection struct {
 | 
			
		||||
	ConnectionManager *ConnectionManager
 | 
			
		||||
	UUID              string
 | 
			
		||||
	Conn              *websocket.Conn
 | 
			
		||||
	Pruned            bool
 | 
			
		||||
}
 | 
			
		||||
type ConnectionManager struct {
 | 
			
		||||
	connections    map[string]*WsConnection
 | 
			
		||||
	broadcastQueue chan interface{}
 | 
			
		||||
	sendQueue      chan SendRequest
 | 
			
		||||
	logger         *slog.Logger
 | 
			
		||||
}
 | 
			
		||||
type SendRequest struct {
 | 
			
		||||
	WsConn *WsConnection
 | 
			
		||||
	Msg    interface{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func validateOrigin(r *http.Request) bool { // {{{
 | 
			
		||||
	/*
 | 
			
		||||
		host := r.Header.Get("X-Forwarded-Host")
 | 
			
		||||
		if host == "" {
 | 
			
		||||
			components := strings.Split(r.Host, ":")
 | 
			
		||||
			host = components[0]
 | 
			
		||||
		}
 | 
			
		||||
	*/
 | 
			
		||||
	return true
 | 
			
		||||
} // }}}
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	upgrader = websocket.Upgrader{
 | 
			
		||||
		ReadBufferSize:  1024,
 | 
			
		||||
		WriteBufferSize: 1024,
 | 
			
		||||
 | 
			
		||||
		// CheckOrigin is to match DOMAIN constant.
 | 
			
		||||
		// Use X-Forwarded-Server if behind proxy.
 | 
			
		||||
		CheckOrigin: validateOrigin,
 | 
			
		||||
	}
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func NewConnectionManager(logger *slog.Logger) (cm ConnectionManager) { // {{{
 | 
			
		||||
	cm.connections = make(map[string]*WsConnection)
 | 
			
		||||
	cm.sendQueue = make(chan SendRequest, 65536)
 | 
			
		||||
	cm.broadcastQueue = make(chan interface{}, 65536)
 | 
			
		||||
	cm.logger = logger
 | 
			
		||||
	return
 | 
			
		||||
} // }}}
 | 
			
		||||
 | 
			
		||||
// NewConnection creates a new connection, which is assigned a UUIDv4 for
 | 
			
		||||
// identification. This is then put into the connection collection.
 | 
			
		||||
func (cm *ConnectionManager) NewConnection(w http.ResponseWriter, r *http.Request) (*WsConnection, error) { // {{{
 | 
			
		||||
	var err error
 | 
			
		||||
	wsConn := WsConnection{
 | 
			
		||||
		UUID:              uuid.NewString(),
 | 
			
		||||
		ConnectionManager: cm,
 | 
			
		||||
	}
 | 
			
		||||
	wsConn.Conn, err = upgrader.Upgrade(w, r, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Keep track of all connections.
 | 
			
		||||
	cm.connections[wsConn.UUID] = &wsConn
 | 
			
		||||
 | 
			
		||||
	// Successfully upgraded to a websocket connection.
 | 
			
		||||
	cm.logger.Info("websocket", "uuid", wsConn.UUID, "remote_addr", r.RemoteAddr)
 | 
			
		||||
 | 
			
		||||
	go cm.ReadLoop(&wsConn)
 | 
			
		||||
 | 
			
		||||
	return &wsConn, nil
 | 
			
		||||
} // }}}
 | 
			
		||||
 | 
			
		||||
// Prune closes an deletes connections. If this happened to be non-fatal, the
 | 
			
		||||
// user will just have to reconnect.
 | 
			
		||||
func (cm *ConnectionManager) Prune(wsConn *WsConnection, err error) { // {{{
 | 
			
		||||
	cm.logger.Info("websocket", "op", "prune", "uuid", wsConn.UUID)
 | 
			
		||||
	wsConn.Conn.Close()
 | 
			
		||||
	wsConn.Pruned = true
 | 
			
		||||
	delete(cm.connections, wsConn.UUID)
 | 
			
		||||
} // }}}
 | 
			
		||||
func (cm *ConnectionManager) ReadLoop(wsConn *WsConnection) { // {{{
 | 
			
		||||
	var data []byte
 | 
			
		||||
	var ok bool
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		if data, ok = cm.Read(wsConn); !ok {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		cm.logger.Debug("websocket", "op", "read", "data", data)
 | 
			
		||||
		//cm.Send(wsConn, response)
 | 
			
		||||
	}
 | 
			
		||||
} // }}}
 | 
			
		||||
func (cm *ConnectionManager) Read(wsConn *WsConnection) ([]byte, bool) { // {{{
 | 
			
		||||
	var err error
 | 
			
		||||
	var requestData []byte
 | 
			
		||||
	_, requestData, err = wsConn.Conn.ReadMessage()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		cm.Prune(wsConn, err)
 | 
			
		||||
		return nil, false
 | 
			
		||||
	}
 | 
			
		||||
	return requestData, true
 | 
			
		||||
} // }}}
 | 
			
		||||
func (cm *ConnectionManager) Send(wsConn *WsConnection, msg interface{}) { // {{{
 | 
			
		||||
	wsConn.Conn.WriteJSON(msg)
 | 
			
		||||
} // }}}
 | 
			
		||||
func (cm *ConnectionManager) Broadcast(msg interface{}) { // {{{
 | 
			
		||||
	cm.broadcastQueue <- msg
 | 
			
		||||
} // }}}
 | 
			
		||||
 | 
			
		||||
func (cm *ConnectionManager) BroadcastLoop() { // {{{
 | 
			
		||||
	for {
 | 
			
		||||
		msg := <-cm.broadcastQueue
 | 
			
		||||
		for _, wsConn := range cm.connections {
 | 
			
		||||
			cm.Send(wsConn, msg)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
} // }}}
 | 
			
		||||
 | 
			
		||||
// vim: foldmethod=marker
 | 
			
		||||
		Loading…
	
	Add table
		
		Reference in a new issue