From fe392d0dff0089c70f6addf83c122100af7d24be Mon Sep 17 00:00:00 2001
From: qmuntal <quimmuntal@gmail.com>
Date: Thu, 1 Aug 2024 16:26:00 +0200
Subject: [PATCH] os/user: support calling Current on impersonated threads

The syscall.OpenCurrentProcessToken call in user.Current fails
when called from an impersonated thread, as the process token is
normally in that case.

This change ensures that the current thread is not impersonated
when calling OpenCurrentProcessToken, and then restores the
impersonation state, if any.

Fixes #68647

Change-Id: I3197535dd8355d21029a42f7aa3936d8fb021202
Reviewed-on: https://go-review.googlesource.com/c/go/+/602415
Reviewed-by: David Chase <drchase@google.com>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
Reviewed-by: Alex Brainman <alex.brainman@gmail.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
---
 .../syscall/windows/security_windows.go       |  24 +++
 .../syscall/windows/syscall_windows.go        |   1 +
 .../syscall/windows/zsyscall_windows.go       |  36 +++++
 src/os/user/lookup_windows.go                 | 113 ++++++++++----
 src/os/user/user_windows_test.go              | 145 ++++++++++++++++++
 5 files changed, 291 insertions(+), 28 deletions(-)
 create mode 100644 src/os/user/user_windows_test.go

diff --git a/src/internal/syscall/windows/security_windows.go b/src/internal/syscall/windows/security_windows.go
index c8c5cbed747..95694c368a5 100644
--- a/src/internal/syscall/windows/security_windows.go
+++ b/src/internal/syscall/windows/security_windows.go
@@ -18,6 +18,8 @@ const (
 
 //sys	ImpersonateSelf(impersonationlevel uint32) (err error) = advapi32.ImpersonateSelf
 //sys	RevertToSelf() (err error) = advapi32.RevertToSelf
+//sys	ImpersonateLoggedOnUser(token syscall.Token) (err error) = advapi32.ImpersonateLoggedOnUser
+//sys	LogonUser(username *uint16, domain *uint16, password *uint16, logonType uint32, logonProvider uint32, token *syscall.Token) (err error) = advapi32.LogonUserW
 
 const (
 	TOKEN_ADJUST_PRIVILEGES = 0x0020
@@ -93,6 +95,26 @@ type LocalGroupUserInfo0 struct {
 	Name *uint16
 }
 
+const (
+	NERR_UserNotFound syscall.Errno = 2221
+	NERR_UserExists   syscall.Errno = 2224
+)
+
+const (
+	USER_PRIV_USER = 1
+)
+
+type UserInfo1 struct {
+	Name        *uint16
+	Password    *uint16
+	PasswordAge uint32
+	Priv        uint32
+	HomeDir     *uint16
+	Comment     *uint16
+	Flags       uint32
+	ScriptPath  *uint16
+}
+
 type UserInfo4 struct {
 	Name            *uint16
 	Password        *uint16
@@ -125,6 +147,8 @@ type UserInfo4 struct {
 	PasswordExpired uint32
 }
 
+//sys	NetUserAdd(serverName *uint16, level uint32, buf *byte, parmErr *uint32) (neterr error) = netapi32.NetUserAdd
+//sys	NetUserDel(serverName *uint16, userName *uint16) (neterr error) = netapi32.NetUserDel
 //sys	NetUserGetLocalGroups(serverName *uint16, userName *uint16, level uint32, flags uint32, buf **byte, prefMaxLen uint32, entriesRead *uint32, totalEntries *uint32) (neterr error) = netapi32.NetUserGetLocalGroups
 
 // GetSystemDirectory retrieves the path to current location of the system
diff --git a/src/internal/syscall/windows/syscall_windows.go b/src/internal/syscall/windows/syscall_windows.go
index cc26a50bb0a..944e4e24505 100644
--- a/src/internal/syscall/windows/syscall_windows.go
+++ b/src/internal/syscall/windows/syscall_windows.go
@@ -39,6 +39,7 @@ const (
 	ERROR_CALL_NOT_IMPLEMENTED   syscall.Errno = 120
 	ERROR_INVALID_NAME           syscall.Errno = 123
 	ERROR_LOCK_FAILED            syscall.Errno = 167
+	ERROR_NO_TOKEN               syscall.Errno = 1008
 	ERROR_NO_UNICODE_TRANSLATION syscall.Errno = 1113
 )
 
diff --git a/src/internal/syscall/windows/zsyscall_windows.go b/src/internal/syscall/windows/zsyscall_windows.go
index 414ad2647d1..7e4d91112b1 100644
--- a/src/internal/syscall/windows/zsyscall_windows.go
+++ b/src/internal/syscall/windows/zsyscall_windows.go
@@ -49,7 +49,9 @@ var (
 
 	procAdjustTokenPrivileges             = modadvapi32.NewProc("AdjustTokenPrivileges")
 	procDuplicateTokenEx                  = modadvapi32.NewProc("DuplicateTokenEx")
+	procImpersonateLoggedOnUser           = modadvapi32.NewProc("ImpersonateLoggedOnUser")
 	procImpersonateSelf                   = modadvapi32.NewProc("ImpersonateSelf")
+	procLogonUserW                        = modadvapi32.NewProc("LogonUserW")
 	procLookupPrivilegeValueW             = modadvapi32.NewProc("LookupPrivilegeValueW")
 	procOpenSCManagerW                    = modadvapi32.NewProc("OpenSCManagerW")
 	procOpenServiceW                      = modadvapi32.NewProc("OpenServiceW")
@@ -82,6 +84,8 @@ var (
 	procVirtualQuery                      = modkernel32.NewProc("VirtualQuery")
 	procNetShareAdd                       = modnetapi32.NewProc("NetShareAdd")
 	procNetShareDel                       = modnetapi32.NewProc("NetShareDel")
+	procNetUserAdd                        = modnetapi32.NewProc("NetUserAdd")
+	procNetUserDel                        = modnetapi32.NewProc("NetUserDel")
 	procNetUserGetLocalGroups             = modnetapi32.NewProc("NetUserGetLocalGroups")
 	procRtlGetVersion                     = modntdll.NewProc("RtlGetVersion")
 	procGetProcessMemoryInfo              = modpsapi.NewProc("GetProcessMemoryInfo")
@@ -113,6 +117,14 @@ func DuplicateTokenEx(hExistingToken syscall.Token, dwDesiredAccess uint32, lpTo
 	return
 }
 
+func ImpersonateLoggedOnUser(token syscall.Token) (err error) {
+	r1, _, e1 := syscall.Syscall(procImpersonateLoggedOnUser.Addr(), 1, uintptr(token), 0, 0)
+	if r1 == 0 {
+		err = errnoErr(e1)
+	}
+	return
+}
+
 func ImpersonateSelf(impersonationlevel uint32) (err error) {
 	r1, _, e1 := syscall.Syscall(procImpersonateSelf.Addr(), 1, uintptr(impersonationlevel), 0, 0)
 	if r1 == 0 {
@@ -121,6 +133,14 @@ func ImpersonateSelf(impersonationlevel uint32) (err error) {
 	return
 }
 
+func LogonUser(username *uint16, domain *uint16, password *uint16, logonType uint32, logonProvider uint32, token *syscall.Token) (err error) {
+	r1, _, e1 := syscall.Syscall6(procLogonUserW.Addr(), 6, uintptr(unsafe.Pointer(username)), uintptr(unsafe.Pointer(domain)), uintptr(unsafe.Pointer(password)), uintptr(logonType), uintptr(logonProvider), uintptr(unsafe.Pointer(token)))
+	if r1 == 0 {
+		err = errnoErr(e1)
+	}
+	return
+}
+
 func LookupPrivilegeValue(systemname *uint16, name *uint16, luid *LUID) (err error) {
 	r1, _, e1 := syscall.Syscall(procLookupPrivilegeValueW.Addr(), 3, uintptr(unsafe.Pointer(systemname)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(luid)))
 	if r1 == 0 {
@@ -385,6 +405,22 @@ func NetShareDel(serverName *uint16, netName *uint16, reserved uint32) (neterr e
 	return
 }
 
+func NetUserAdd(serverName *uint16, level uint32, buf *byte, parmErr *uint32) (neterr error) {
+	r0, _, _ := syscall.Syscall6(procNetUserAdd.Addr(), 4, uintptr(unsafe.Pointer(serverName)), uintptr(level), uintptr(unsafe.Pointer(buf)), uintptr(unsafe.Pointer(parmErr)), 0, 0)
+	if r0 != 0 {
+		neterr = syscall.Errno(r0)
+	}
+	return
+}
+
+func NetUserDel(serverName *uint16, userName *uint16) (neterr error) {
+	r0, _, _ := syscall.Syscall(procNetUserDel.Addr(), 2, uintptr(unsafe.Pointer(serverName)), uintptr(unsafe.Pointer(userName)), 0)
+	if r0 != 0 {
+		neterr = syscall.Errno(r0)
+	}
+	return
+}
+
 func NetUserGetLocalGroups(serverName *uint16, userName *uint16, level uint32, flags uint32, buf **byte, prefMaxLen uint32, entriesRead *uint32, totalEntries *uint32) (neterr error) {
 	r0, _, _ := syscall.Syscall9(procNetUserGetLocalGroups.Addr(), 8, uintptr(unsafe.Pointer(serverName)), uintptr(unsafe.Pointer(userName)), uintptr(level), uintptr(flags), uintptr(unsafe.Pointer(buf)), uintptr(prefMaxLen), uintptr(unsafe.Pointer(entriesRead)), uintptr(unsafe.Pointer(totalEntries)), 0)
 	if r0 != 0 {
diff --git a/src/os/user/lookup_windows.go b/src/os/user/lookup_windows.go
index a48fc89720b..f259269a539 100644
--- a/src/os/user/lookup_windows.go
+++ b/src/os/user/lookup_windows.go
@@ -5,9 +5,11 @@
 package user
 
 import (
+	"errors"
 	"fmt"
 	"internal/syscall/windows"
 	"internal/syscall/windows/registry"
+	"runtime"
 	"syscall"
 	"unsafe"
 )
@@ -200,36 +202,91 @@ var (
 )
 
 func current() (*User, error) {
-	t, e := syscall.OpenCurrentProcessToken()
-	if e != nil {
-		return nil, e
-	}
-	defer t.Close()
-	u, e := t.GetTokenUser()
-	if e != nil {
-		return nil, e
-	}
-	pg, e := t.GetTokenPrimaryGroup()
-	if e != nil {
-		return nil, e
-	}
-	uid, e := u.User.Sid.String()
-	if e != nil {
-		return nil, e
-	}
-	gid, e := pg.PrimaryGroup.String()
-	if e != nil {
-		return nil, e
-	}
-	dir, e := t.GetUserProfileDirectory()
-	if e != nil {
-		return nil, e
+	// Use runAsProcessOwner to ensure that we can access the process token
+	// when calling syscall.OpenCurrentProcessToken if the current thread
+	// is impersonating a different user. See https://go.dev/issue/68647.
+	var usr *User
+	err := runAsProcessOwner(func() error {
+		t, e := syscall.OpenCurrentProcessToken()
+		if e != nil {
+			return e
+		}
+		defer t.Close()
+		u, e := t.GetTokenUser()
+		if e != nil {
+			return e
+		}
+		pg, e := t.GetTokenPrimaryGroup()
+		if e != nil {
+			return e
+		}
+		uid, e := u.User.Sid.String()
+		if e != nil {
+			return e
+		}
+		gid, e := pg.PrimaryGroup.String()
+		if e != nil {
+			return e
+		}
+		dir, e := t.GetUserProfileDirectory()
+		if e != nil {
+			return e
+		}
+		username, domain, e := lookupUsernameAndDomain(u.User.Sid)
+		if e != nil {
+			return e
+		}
+		usr, e = newUser(uid, gid, dir, username, domain)
+		return e
+	})
+	return usr, err
+}
+
+// runAsProcessOwner runs f in the context of the current process owner,
+// that is, removing any impersonation that may be in effect before calling f,
+// and restoring the impersonation afterwards.
+func runAsProcessOwner(f func() error) error {
+	var impersonationRollbackErr error
+	runtime.LockOSThread()
+	defer func() {
+		// If impersonation failed, the thread is running with the wrong token,
+		// so it's better to terminate it.
+		// This is achieved by not calling runtime.UnlockOSThread.
+		if impersonationRollbackErr != nil {
+			println("os/user: failed to revert to previous token:", impersonationRollbackErr.Error())
+			runtime.Goexit()
+		} else {
+			runtime.UnlockOSThread()
+		}
+	}()
+	prevToken, isProcessToken, err := getCurrentToken()
+	if err != nil {
+		return fmt.Errorf("os/user: failed to get current token: %w", err)
 	}
-	username, domain, e := lookupUsernameAndDomain(u.User.Sid)
-	if e != nil {
-		return nil, e
+	defer prevToken.Close()
+	if !isProcessToken {
+		if err = windows.RevertToSelf(); err != nil {
+			return fmt.Errorf("os/user: failed to revert to self: %w", err)
+		}
+		defer func() {
+			impersonationRollbackErr = windows.ImpersonateLoggedOnUser(prevToken)
+		}()
 	}
-	return newUser(uid, gid, dir, username, domain)
+	return f()
+}
+
+// getCurrentToken returns the current thread token, or
+// the process token if the thread doesn't have a token.
+func getCurrentToken() (t syscall.Token, isProcessToken bool, err error) {
+	thread, _ := windows.GetCurrentThread()
+	// Need TOKEN_DUPLICATE and TOKEN_IMPERSONATE to use the token in ImpersonateLoggedOnUser.
+	err = windows.OpenThreadToken(thread, syscall.TOKEN_QUERY|syscall.TOKEN_DUPLICATE|syscall.TOKEN_IMPERSONATE, true, &t)
+	if errors.Is(err, windows.ERROR_NO_TOKEN) {
+		// Not impersonating, use the process token.
+		isProcessToken = true
+		t, err = syscall.OpenCurrentProcessToken()
+	}
+	return t, isProcessToken, err
 }
 
 // lookupUserPrimaryGroup obtains the primary group SID for a user using this method:
diff --git a/src/os/user/user_windows_test.go b/src/os/user/user_windows_test.go
new file mode 100644
index 00000000000..3364d7c9eac
--- /dev/null
+++ b/src/os/user/user_windows_test.go
@@ -0,0 +1,145 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package user
+
+import (
+	"crypto/rand"
+	"encoding/base64"
+	"errors"
+	"internal/syscall/windows"
+	"runtime"
+	"strconv"
+	"syscall"
+	"testing"
+	"unsafe"
+)
+
+// windowsTestAcount creates a test user and returns a token for that user.
+// If the user already exists, it will be deleted and recreated.
+// The caller is responsible for closing the token.
+func windowsTestAcount(t *testing.T) syscall.Token {
+	var password [33]byte
+	rand.Read(password[:])
+	// Add special chars to ensure it satisfies password requirements.
+	pwd := base64.StdEncoding.EncodeToString(password[:]) + "_-As@!%*(1)4#2"
+	name, err := syscall.UTF16PtrFromString("GoStdTestUser01")
+	if err != nil {
+		t.Fatal(err)
+	}
+	pwd16, err := syscall.UTF16PtrFromString(pwd)
+	if err != nil {
+		t.Fatal(err)
+	}
+	userInfo := windows.UserInfo1{
+		Name:     name,
+		Password: pwd16,
+		Priv:     windows.USER_PRIV_USER,
+	}
+	// Create user.
+	err = windows.NetUserAdd(nil, 1, (*byte)(unsafe.Pointer(&userInfo)), nil)
+	if errors.Is(err, syscall.ERROR_ACCESS_DENIED) {
+		t.Skip("skipping test; don't have permission to create user")
+	}
+	if errors.Is(err, windows.NERR_UserExists) {
+		// User already exists, delete and recreate.
+		if err = windows.NetUserDel(nil, name); err != nil {
+			t.Fatal(err)
+		}
+		if err = windows.NetUserAdd(nil, 1, (*byte)(unsafe.Pointer(&userInfo)), nil); err != nil {
+			t.Fatal(err)
+		}
+	} else if err != nil {
+		t.Fatal(err)
+	}
+	domain, err := syscall.UTF16PtrFromString(".")
+	if err != nil {
+		t.Fatal(err)
+	}
+	const LOGON32_PROVIDER_DEFAULT = 0
+	const LOGON32_LOGON_INTERACTIVE = 2
+	var token syscall.Token
+	if err = windows.LogonUser(name, domain, pwd16, LOGON32_LOGON_INTERACTIVE, LOGON32_PROVIDER_DEFAULT, &token); err != nil {
+		t.Fatal(err)
+	}
+	t.Cleanup(func() {
+		token.Close()
+		if err = windows.NetUserDel(nil, name); err != nil {
+			if !errors.Is(err, windows.NERR_UserNotFound) {
+				t.Fatal(err)
+			}
+		}
+	})
+	return token
+}
+
+func TestImpersonatedSelf(t *testing.T) {
+	runtime.LockOSThread()
+	defer runtime.UnlockOSThread()
+
+	want, err := current()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	levels := []uint32{
+		windows.SecurityAnonymous,
+		windows.SecurityIdentification,
+		windows.SecurityImpersonation,
+		windows.SecurityDelegation,
+	}
+	for _, level := range levels {
+		t.Run(strconv.Itoa(int(level)), func(t *testing.T) {
+			if err = windows.ImpersonateSelf(level); err != nil {
+				t.Fatal(err)
+			}
+			defer windows.RevertToSelf()
+
+			got, err := current()
+			if level == windows.SecurityAnonymous {
+				// We can't get the process token when using an anonymous token,
+				// so we expect an error here.
+				if err == nil {
+					t.Fatal("expected error")
+				}
+				return
+			}
+			if err != nil {
+				t.Fatal(err)
+			}
+			compare(t, want, got)
+		})
+	}
+}
+
+func TestImpersonated(t *testing.T) {
+	runtime.LockOSThread()
+	defer runtime.UnlockOSThread()
+
+	want, err := current()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Create a test user and log in as that user.
+	token := windowsTestAcount(t)
+
+	// Impersonate the test user.
+	if err = windows.ImpersonateLoggedOnUser(token); err != nil {
+		t.Fatal(err)
+	}
+	defer func() {
+		err = windows.RevertToSelf()
+		if err != nil {
+			// If we can't revert to self, we can't continue testing.
+			panic(err)
+		}
+	}()
+
+	got, err := current()
+	if err != nil {
+		t.Fatal(err)
+	}
+	compare(t, want, got)
+}
-- 
GitLab