From 011115f9d693ca736c943f62380e1e57f6e3b7f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Magnus=20=C3=85hall?= Date: Thu, 22 Feb 2024 12:18:02 +0100 Subject: [PATCH] Initial release --- .gitignore | 1 + Makefile | 2 + README.md | 8 ++ go.mod | 10 +++ go.sum | 7 ++ main.go | 138 ++++++++++++++++++++++++++++++++++ native.go | 216 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 382 insertions(+) create mode 100644 .gitignore create mode 100644 Makefile create mode 100644 README.md create mode 100644 go.mod create mode 100644 go.sum create mode 100644 main.go create mode 100644 native.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..13275d8 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +wireguard-mfa.exe diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..c4b6971 --- /dev/null +++ b/Makefile @@ -0,0 +1,2 @@ +wireguard-mfa.exe: main.go native.go + GOOS=windows GOARCH=amd64 go build -ldflags -H=windowsgui diff --git a/README.md b/README.md new file mode 100644 index 0000000..dbd9fa2 --- /dev/null +++ b/README.md @@ -0,0 +1,8 @@ +Installation +============ +1. Copy `wireguard-mfa.exe` to `C:\Program Files\Gibon\` +2. Create the `C:\Program Files\Gibon\wg-spool\` directory and make sure users activating Wireguard tunnels have read- and write permissions to it. + +Log file +======== +If anything goes wrong, look in `C:\Program Files\Gibon\wg-spool\log.txt`. diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..7268da8 --- /dev/null +++ b/go.mod @@ -0,0 +1,10 @@ +module wireguard-mfa + +go 1.19 + +require ( + golang.org/x/sys v0.7.0 + golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 +) + +require golang.org/x/crypto v0.8.0 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..f71e551 --- /dev/null +++ b/go.sum @@ -0,0 +1,7 @@ +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +golang.org/x/crypto v0.8.0 h1:pd9TJtTueMTVQXzk8E2XESSMQDj/U7OUu0PqJqPXQjQ= +golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= +golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= +golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE= +golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= diff --git a/main.go b/main.go new file mode 100644 index 0000000..40ae61c --- /dev/null +++ b/main.go @@ -0,0 +1,138 @@ +package main + +import ( + // External + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + // Standard + "errors" + "flag" + "fmt" + "io/fs" + "net/url" + "os" + "os/exec" + "regexp" + "time" +) + +const VERSION = "v1" +const SPOOLDIR = "C:\\Program Files\\Gibon\\wg-spool" +const DOMAIN = "https://vpn.gibonuddevalla.se" +// const DOMAIN = "http://192.168.122.1:8000" + +var ( + publicKeyPath string + runInteractiveAuth bool + connected bool + printVersion bool +) + +func init() { + flag.BoolVar(&printVersion, "version", false, "print version and exit") + flag.BoolVar(&runInteractiveAuth, "interactive", false, "run interactive authentication") + flag.BoolVar(&connected, "connected", false, "postop, is connected") + flag.Parse() + + if printVersion { + fmt.Println(VERSION) + os.Exit(0) + } +} + +func logError(where, e string) { + fmt.Printf("ERROR: %s\n", e) + fname := fmt.Sprintf("%s\\log.txt", SPOOLDIR) + file, err := os.OpenFile(fname, os.O_CREATE|os.O_APPEND, 0644) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + defer file.Close() + + now := time.Now().String() + file.Write([]byte(now)) + file.Write([]byte("\t")) + file.Write([]byte(where)) + file.Write([]byte("\t")) + file.Write([]byte(e)) + file.Write([]byte("\n")) +} + +func main() { + publicKeyPath = SPOOLDIR + "\\wireguard-mfa.url" + iface := os.Getenv("WIREGUARD_TUNNEL_NAME") + + // Run from PreUp, step 1. + // Starts an interactive process, dropping privileges, + // in order to have permission to open a browser in step 2. + // Doesn't have permission to get public key. + if !runInteractiveAuth && !connected { + ex, _ := os.Executable() + if err := StartProcessAsCurrentUser("", ex+" -interactive", ""); err != nil { + logError("start_interactive", err.Error()) + } + os.Exit(0) + } + + // Started from PreUp, step 2. + // Doesn't have permission to get public key. + if runInteractiveAuth { + var err error + start := time.Now() + for { + // Abort after 60 seconds to not have an eternal loop. + if time.Since(start).Seconds() > 60 { + os.Exit(1) + } + + _, err = os.ReadFile(publicKeyPath) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + time.Sleep(time.Millisecond * 200) + continue + } + + if !errors.Is(err, fs.ErrNotExist) { + logError("read_publickey_file", err.Error()) + os.Exit(1) + } + } + break + } + + cmd := exec.Command("rundll32.exe", "url.dll,OpenURL", publicKeyPath) + cmd.Output() + os.Remove(publicKeyPath) + } + + // Run from PostUp, step 3. + // Has only permission to get public key, can't run browser. + if connected { + getPublicKey(iface) + } +} + +func getPublicKey(iface string) { + rxpPrivateKey := regexp.MustCompile("PrivateKey\\s*=\\s*(.*)") + cmd := exec.Command("wg", "showconf", iface) + out, _ := cmd.Output() + privateKey := rxpPrivateKey.FindStringSubmatch(string(out)) + if len(privateKey) == 2 { + pubkey, err := wgtypes.ParseKey(privateKey[1]) + if err != nil { + logError("parse_privatekey", fmt.Sprintf("%s [%s]", err.Error(), privateKey[1])) + return + } + + urlData := fmt.Sprintf( + "[InternetShortcut]\r\nURL=%s/?key=%s\r\n", + DOMAIN, + url.QueryEscape(pubkey.PublicKey().String()), + ) + os.WriteFile(publicKeyPath, []byte(urlData), 0644) + logError("publickey", pubkey.PublicKey().String()) + } else { + logError("get_publickey", string(out)) + } +} diff --git a/native.go b/native.go new file mode 100644 index 0000000..1b3488c --- /dev/null +++ b/native.go @@ -0,0 +1,216 @@ +package main + +import ( + "fmt" + "unsafe" + + "golang.org/x/sys/windows" +) + +var ( + modwtsapi32 *windows.LazyDLL = windows.NewLazySystemDLL("wtsapi32.dll") + modkernel32 *windows.LazyDLL = windows.NewLazySystemDLL("kernel32.dll") + modadvapi32 *windows.LazyDLL = windows.NewLazySystemDLL("advapi32.dll") + moduserenv *windows.LazyDLL = windows.NewLazySystemDLL("userenv.dll") + + procWTSEnumerateSessionsW *windows.LazyProc = modwtsapi32.NewProc("WTSEnumerateSessionsW") + procWTSGetActiveConsoleSessionId *windows.LazyProc = modkernel32.NewProc("WTSGetActiveConsoleSessionId") + procWTSQueryUserToken *windows.LazyProc = modwtsapi32.NewProc("WTSQueryUserToken") + procDuplicateTokenEx *windows.LazyProc = modadvapi32.NewProc("DuplicateTokenEx") + procCreateEnvironmentBlock *windows.LazyProc = moduserenv.NewProc("CreateEnvironmentBlock") + procCreateProcessAsUser *windows.LazyProc = modadvapi32.NewProc("CreateProcessAsUserW") +) + +const ( + WTS_CURRENT_SERVER_HANDLE uintptr = 0 +) + +type WTS_CONNECTSTATE_CLASS int + +const ( + WTSActive WTS_CONNECTSTATE_CLASS = iota + WTSConnected + WTSConnectQuery + WTSShadow + WTSDisconnected + WTSIdle + WTSListen + WTSReset + WTSDown + WTSInit +) + +type SECURITY_IMPERSONATION_LEVEL int + +const ( + SecurityAnonymous SECURITY_IMPERSONATION_LEVEL = iota + SecurityIdentification + SecurityImpersonation + SecurityDelegation +) + +type TOKEN_TYPE int + +const ( + TokenPrimary TOKEN_TYPE = iota + 1 + TokenImpersonazion +) + +type SW int + +const ( + SW_HIDE SW = 0 + SW_SHOWNORMAL = 1 + SW_NORMAL = 1 + SW_SHOWMINIMIZED = 2 + SW_SHOWMAXIMIZED = 3 + SW_MAXIMIZE = 3 + SW_SHOWNOACTIVATE = 4 + SW_SHOW = 5 + SW_MINIMIZE = 6 + SW_SHOWMINNOACTIVE = 7 + SW_SHOWNA = 8 + SW_RESTORE = 9 + SW_SHOWDEFAULT = 10 + SW_MAX = 1 +) + +type WTS_SESSION_INFO struct { + SessionID windows.Handle + WinStationName *uint16 + State WTS_CONNECTSTATE_CLASS +} + +const ( + CREATE_UNICODE_ENVIRONMENT uint16 = 0x00000400 + CREATE_NO_WINDOW = 0x08000000 + CREATE_NEW_CONSOLE = 0x00000010 +) + +// GetCurrentUserSessionId will attempt to resolve +// the session ID of the user currently active on +// the system. +func GetCurrentUserSessionId() (windows.Handle, error) { + sessionList, err := WTSEnumerateSessions() + if err != nil { + return 0xFFFFFFFF, fmt.Errorf("get current user session token: %s", err) + } + + for i := range sessionList { + if sessionList[i].State == WTSActive { + return sessionList[i].SessionID, nil + } + } + + if sessionId, _, err := procWTSGetActiveConsoleSessionId.Call(); sessionId == 0xFFFFFFFF { + return 0xFFFFFFFF, fmt.Errorf("get current user session token: call native WTSGetActiveConsoleSessionId: %s", err) + } else { + return windows.Handle(sessionId), nil + } +} + +// WTSEnumerateSession will call the native +// version for Windows and parse the result +// to a Golang friendly version +func WTSEnumerateSessions() ([]*WTS_SESSION_INFO, error) { + var ( + sessionInformation windows.Handle = windows.Handle(0) + sessionCount int = 0 + sessionList []*WTS_SESSION_INFO = make([]*WTS_SESSION_INFO, 0) + ) + + if returnCode, _, err := procWTSEnumerateSessionsW.Call(WTS_CURRENT_SERVER_HANDLE, 0, 1, uintptr(unsafe.Pointer(&sessionInformation)), uintptr(unsafe.Pointer(&sessionCount))); returnCode == 0 { + return nil, fmt.Errorf("call native WTSEnumerateSessionsW: %s", err) + } + + structSize := unsafe.Sizeof(WTS_SESSION_INFO{}) + current := uintptr(sessionInformation) + for i := 0; i < sessionCount; i++ { + sessionList = append(sessionList, (*WTS_SESSION_INFO)(unsafe.Pointer(current))) + current += structSize + } + + return sessionList, nil +} + +// DuplicateUserTokenFromSessionID will attempt +// to duplicate the user token for the user logged +// into the provided session ID +func DuplicateUserTokenFromSessionID(sessionId windows.Handle) (windows.Token, error) { + var ( + impersonationToken windows.Handle = 0 + userToken windows.Token = 0 + ) + + if returnCode, _, err := procWTSQueryUserToken.Call(uintptr(sessionId), uintptr(unsafe.Pointer(&impersonationToken))); returnCode == 0 { + return 0xFFFFFFFF, fmt.Errorf("call native WTSQueryUserToken: %s", err) + } + + if returnCode, _, err := procDuplicateTokenEx.Call(uintptr(impersonationToken), 0, 0, uintptr(SecurityImpersonation), uintptr(TokenPrimary), uintptr(unsafe.Pointer(&userToken))); returnCode == 0 { + return 0xFFFFFFFF, fmt.Errorf("call native DuplicateTokenEx: %s", err) + } + + if err := windows.CloseHandle(impersonationToken); err != nil { + return 0xFFFFFFFF, fmt.Errorf("close windows handle used for token duplication: %s", err) + } + + return userToken, nil +} + +func StartProcessAsCurrentUser(appPath, cmdLine, workDir string) error { + var ( + sessionId windows.Handle + userToken windows.Token + envInfo windows.Handle + + startupInfo windows.StartupInfo + processInfo windows.ProcessInformation + + commandLine uintptr = 0 + workingDir uintptr = 0 + + err error + ) + + if sessionId, err = GetCurrentUserSessionId(); err != nil { + return err + } + + if userToken, err = DuplicateUserTokenFromSessionID(sessionId); err != nil { + return fmt.Errorf("get duplicate user token for current user session: %s", err) + } + + if returnCode, _, err := procCreateEnvironmentBlock.Call(uintptr(unsafe.Pointer(&envInfo)), uintptr(userToken), 0); returnCode == 0 { + return fmt.Errorf("create environment details for process: %s", err) + } + + creationFlags := CREATE_UNICODE_ENVIRONMENT | CREATE_NEW_CONSOLE + startupInfo.ShowWindow = SW_SHOW + startupInfo.Desktop = windows.StringToUTF16Ptr("winsta0\\default") + + if len(cmdLine) > 0 { + commandLine = uintptr(unsafe.Pointer(windows.StringToUTF16Ptr(cmdLine))) + } + if len(workDir) > 0 { + workingDir = uintptr(unsafe.Pointer(windows.StringToUTF16Ptr(workDir))) + } + + if returnCode, _, err := procCreateProcessAsUser.Call( + uintptr(userToken), // hToken + //uintptr(unsafe.Pointer(windows.StringToUTF16Ptr(appPath))), // lpApplicationName + uintptr(unsafe.Pointer(nil)), // lpApplicationName + commandLine, // lpCommandLine + 0, // lpProcessAttributes + 0, // lpThreadAttributes + 0, // lpInheritHandles + uintptr(creationFlags), // dwCreationFlags + uintptr(envInfo), // lpEnvironment + workingDir, // lpCurrentDirectory + uintptr(unsafe.Pointer(&startupInfo)), // lpStartupInfo + uintptr(unsafe.Pointer(&processInfo)), // lpProcessInformation + ); returnCode == 0 { + return fmt.Errorf("create process as user: %s", err) + } + + return nil +}