diff --git a/net/dns/manager.go b/net/dns/manager.go index 82dd5d47b..ac65e4bd8 100644 --- a/net/dns/manager.go +++ b/net/dns/manager.go @@ -370,8 +370,8 @@ type dnsTCPSession struct { conn net.Conn srcAddr netaddr.IPPort - readClosing chan struct{} - responses chan []byte // DNS replies pending writing + readClosing chan struct{} + responses chan []byte // DNS replies pending writing ctx context.Context closeCtx context.CancelFunc @@ -457,11 +457,11 @@ func (s *dnsTCPSession) handleReads() { // servicing DNS requests sent down it. func (m *Manager) HandleTCPConn(conn net.Conn, srcAddr netaddr.IPPort) { s := dnsTCPSession{ - m: m, - conn: conn, - srcAddr: srcAddr, - responses: make(chan []byte), - readClosing: make(chan struct{}), + m: m, + conn: conn, + srcAddr: srcAddr, + responses: make(chan []byte), + readClosing: make(chan struct{}), } s.ctx, s.closeCtx = context.WithCancel(m.ctx) go s.handleReads() diff --git a/net/dns/manager_windows.go b/net/dns/manager_windows.go index 951e06444..f740a11e0 100644 --- a/net/dns/manager_windows.go +++ b/net/dns/manager_windows.go @@ -297,7 +297,11 @@ func (m windowsManager) SupportsSplitDNS() bool { } func (m windowsManager) Close() error { - return m.SetDNS(OSConfig{}) + err := m.SetDNS(OSConfig{}) + if m.nrptDB != nil { + m.nrptDB.Close() + } + return err } // disableDynamicUpdates sets the appropriate registry values to prevent the diff --git a/net/dns/manager_windows_test.go b/net/dns/manager_windows_test.go index ac9370587..e2dfdf3e8 100644 --- a/net/dns/manager_windows_test.go +++ b/net/dns/manager_windows_test.go @@ -5,6 +5,7 @@ package dns import ( + "context" "fmt" "math/rand" "strings" @@ -20,11 +21,6 @@ import ( const testGPRuleID = "{7B1B6151-84E6-41A3-8967-62F7F7B45687}" -var ( - procRegisterGPNotification = libUserenv.NewProc("RegisterGPNotification") - procUnregisterGPNotification = libUserenv.NewProc("UnregisterGPNotification") -) - func TestManagerWindowsLocal(t *testing.T) { if !isWindows10OrBetter() || !winutil.IsCurrentProcessElevated() { t.Skipf("test requires running as elevated user on Windows 10+") @@ -53,6 +49,121 @@ func TestManagerWindowsGP(t *testing.T) { runTest(t, false) } +func TestManagerWindowsGPMove(t *testing.T) { + if !isWindows10OrBetter() || !winutil.IsCurrentProcessElevated() { + t.Skipf("test requires running as elevated user on Windows 10+") + } + + checkGPNotificationsWork(t) + + logf := func(format string, args ...any) { + t.Logf(format, args...) + } + + fakeInterface, err := windows.GenerateGUID() + if err != nil { + t.Fatalf("windows.GenerateGUID: %v\n", err) + } + + delIfKey, err := createFakeInterfaceKey(t, fakeInterface) + if err != nil { + t.Fatalf("createFakeInterfaceKey: %v\n", err) + } + defer delIfKey() + + cfg, err := NewOSConfigurator(logf, fakeInterface.String()) + if err != nil { + t.Fatalf("NewOSConfigurator: %v\n", err) + } + mgr := cfg.(windowsManager) + defer mgr.Close() + + usingGP := mgr.nrptDB.writeAsGP + if usingGP { + t.Fatalf("usingGP %v, want %v\n", usingGP, false) + } + + regWatcher, err := newRegKeyWatcher() + if err != nil { + t.Fatalf("newRegKeyWatcher error %v\n", err) + } + + // Upon initialization of cfg, we should not have any NRPT rules + ensureNoRules(t) + + resolvers := []netaddr.IP{netaddr.MustParseIP("1.1.1.1")} + domains := genRandomSubdomains(t, 1) + + // 1. Populate local NRPT + err = mgr.setSplitDNS(resolvers, domains) + if err != nil { + t.Fatalf("setSplitDNS: %v\n", err) + } + + t.Logf("Validating that local NRPT is populated...\n") + validateRegistry(t, nrptBaseLocal, domains) + ensureNoRulesInSubkey(t, nrptBaseGP) + + // 2. Create fake GP key and refresh + t.Logf("Creating fake group policy key and refreshing...\n") + err = createFakeGPKey() + if err != nil { + t.Fatalf("createFakeGPKey: %v\n", err) + } + + err = regWatcher.watch() + if err != nil { + t.Fatalf("regWatcher.watch: %v\n", err) + } + + err = testDoRefresh() + if err != nil { + t.Fatalf("testDoRefresh: %v\n", err) + } + + err = regWatcher.wait() + if err != nil { + t.Fatalf("regWatcher.wait: %v\n", err) + } + + // 3. Check that local NRPT is empty and GP is populated + t.Logf("Validating that group policy NRPT is populated...\n") + validateRegistry(t, nrptBaseGP, domains) + ensureNoRulesInSubkey(t, nrptBaseLocal) + + // 4. Delete fake GP key and refresh + t.Logf("Deleting fake group policy key and refreshing...\n") + deleteFakeGPKey(t) + + err = regWatcher.watch() + if err != nil { + t.Fatalf("regWatcher.watch: %v\n", err) + } + + err = testDoRefresh() + if err != nil { + t.Fatalf("testDoRefresh: %v\n", err) + } + + err = regWatcher.wait() + if err != nil { + t.Fatalf("regWatcher.wait: %v\n", err) + } + + // 5. Check that local NRPT is populated and GP is empty + t.Logf("Validating that local NRPT is populated...\n") + validateRegistry(t, nrptBaseLocal, domains) + ensureNoRulesInSubkey(t, nrptBaseGP) + + // 6. Cleanup + t.Logf("Cleaning up...\n") + err = mgr.setSplitDNS(nil, domains) + if err != nil { + t.Fatalf("setSplitDNS: %v\n", err) + } + ensureNoRules(t) +} + func checkGPNotificationsWork(t *testing.T) { // Test to ensure that RegisterGPNotification work on this machine, // otherwise this test will fail. @@ -83,11 +194,18 @@ func runTest(t *testing.T, isLocal bool) { t.Fatalf("windows.GenerateGUID: %v\n", err) } + delIfKey, err := createFakeInterfaceKey(t, fakeInterface) + if err != nil { + t.Fatalf("createFakeInterfaceKey: %v\n", err) + } + defer delIfKey() + cfg, err := NewOSConfigurator(logf, fakeInterface.String()) if err != nil { t.Fatalf("NewOSConfigurator: %v\n", err) } mgr := cfg.(windowsManager) + defer mgr.Close() usingGP := mgr.nrptDB.writeAsGP if isLocal == usingGP { @@ -99,25 +217,7 @@ func runTest(t *testing.T, isLocal bool) { resolvers := []netaddr.IP{netaddr.MustParseIP("1.1.1.1")} - domains := make([]dnsname.FQDN, 0, 2*nrptMaxDomainsPerRule+1) - - r := rand.New(rand.NewSource(time.Now().UnixNano())) - const charset = "abcdefghijklmnopqrstuvwxyz" - - // Just generate a bunch of random subdomains - for len(domains) < cap(domains) { - l := r.Intn(19) + 1 - b := make([]byte, l) - for i, _ := range b { - b[i] = charset[r.Intn(len(charset))] - } - d := string(b) + ".example.com" - fqdn, err := dnsname.ToFQDN(d) - if err != nil { - t.Fatalf("dnsname.ToFQDN: %v\n", err) - } - domains = append(domains, fqdn) - } + domains := genRandomSubdomains(t, 2*nrptMaxDomainsPerRule+1) cases := []int{ 1, @@ -238,6 +338,32 @@ func deleteFakeGPKey(t *testing.T) { } } +func createFakeInterfaceKey(t *testing.T, guid windows.GUID) (func(), error) { + basePaths := []string{ipv4RegBase, ipv6RegBase} + keyPaths := make([]string, 0, len(basePaths)) + + for _, basePath := range basePaths { + keyPath := fmt.Sprintf(`%s\Interfaces\%s`, basePath, guid) + key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, keyPath, registry.SET_VALUE) + if err != nil { + return nil, err + } + key.Close() + + keyPaths = append(keyPaths, keyPath) + } + + result := func() { + for _, keyPath := range keyPaths { + if err := registry.DeleteKey(registry.LOCAL_MACHINE, keyPath); err != nil { + t.Fatalf("deleting fake interface key \"%s\": %v\n", keyPath, err) + } + } + } + + return result, nil +} + func ensureNoRules(t *testing.T) { ruleIDs := winutil.GetRegStrings(nrptRuleIDValueName, nil) if ruleIDs != nil { @@ -263,11 +389,29 @@ func ensureNoRulesInSubkey(t *testing.T, base string) { key, err := registry.OpenKey(registry.LOCAL_MACHINE, keyName, registry.READ) if err == nil { key.Close() - } - if err != registry.ErrNotExist { + } else if err != registry.ErrNotExist { t.Fatalf("%s: %q, want %q\n", keyName, err, registry.ErrNotExist) } } + + if base == nrptBaseGP { + // When dealing with the group policy subkey, we want the base key to + // also be absent. + key, err := registry.OpenKey(registry.LOCAL_MACHINE, base, registry.READ) + if err == nil { + key.Close() + + isEmpty, err := isPolicyConfigSubkeyEmpty() + if err != nil { + t.Fatalf("isPolicyConfigSubkeyEmpty: %v", err) + } + if isEmpty { + t.Errorf("Unexpectedly found group policy key\n") + } + } else if err != registry.ErrNotExist { + t.Errorf("Group policy key error: %q, want %q\n", err, registry.ErrNotExist) + } + } } func ensureNoSingleRule(t *testing.T, base string) { @@ -332,6 +476,40 @@ func getSavedDomainsForRule(base, ruleID string) ([]string, error) { return result, err } +func genRandomSubdomains(t *testing.T, n int) []dnsname.FQDN { + domains := make([]dnsname.FQDN, 0, n) + + seed := time.Now().UnixNano() + t.Logf("genRandomSubdomains(%d) seed: %v\n", n, seed) + + r := rand.New(rand.NewSource(seed)) + const charset = "abcdefghijklmnopqrstuvwxyz" + + for len(domains) < cap(domains) { + l := r.Intn(19) + 1 + b := make([]byte, l) + for i, _ := range b { + b[i] = charset[r.Intn(len(charset))] + } + d := string(b) + ".example.com" + fqdn, err := dnsname.ToFQDN(d) + if err != nil { + t.Fatalf("dnsname.ToFQDN: %v\n", err) + } + domains = append(domains, fqdn) + } + + return domains +} + +func testDoRefresh() (err error) { + r, _, e := procRefreshPolicyEx.Call(uintptr(1), uintptr(_RP_FORCE)) + if r == 0 { + err = e + } + return err +} + // gpNotificationTracker registers with the Windows policy engine and receives // notifications when policy refreshes occur. type gpNotificationTracker struct { @@ -384,3 +562,103 @@ func (trk *gpNotificationTracker) Close() error { trk.event = 0 return nil } + +type regKeyWatcher struct { + keyLocal registry.Key + keyGP registry.Key + evtLocal windows.Handle + evtGP windows.Handle +} + +func newRegKeyWatcher() (*regKeyWatcher, error) { + var err error + + keyLocal, _, err := registry.CreateKey(registry.LOCAL_MACHINE, nrptBaseLocal, registry.READ) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + keyLocal.Close() + } + }() + + // Monitor dnsBaseGP instead of nrptBaseGP, since the latter will be + // repeatedly created and destroyed throughout the course of the test. + keyGP, _, err := registry.CreateKey(registry.LOCAL_MACHINE, dnsBaseGP, registry.READ) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + keyGP.Close() + } + }() + + evtLocal, err := windows.CreateEvent(nil, 0, 0, nil) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + windows.CloseHandle(evtLocal) + } + }() + + evtGP, err := windows.CreateEvent(nil, 0, 0, nil) + if err != nil { + return nil, err + } + + result := ®KeyWatcher{ + keyLocal: keyLocal, + keyGP: keyGP, + evtLocal: evtLocal, + evtGP: evtGP, + } + + return result, nil +} + +func (rw *regKeyWatcher) watch() error { + // We can make these waits thread-agnostic because the tests that use this code must already run on Windows 10+ + err := windows.RegNotifyChangeKeyValue(windows.Handle(rw.keyLocal), true, + windows.REG_NOTIFY_CHANGE_NAME|windows.REG_NOTIFY_THREAD_AGNOSTIC, rw.evtLocal, true) + if err != nil { + return err + } + + return windows.RegNotifyChangeKeyValue(windows.Handle(rw.keyGP), true, + windows.REG_NOTIFY_CHANGE_NAME|windows.REG_NOTIFY_THREAD_AGNOSTIC, rw.evtGP, true) +} + +func (rw *regKeyWatcher) wait() error { + handles := []windows.Handle{ + rw.evtLocal, + rw.evtGP, + } + + waitCode, err := windows.WaitForMultipleObjects( + handles, + true, // Wait for both events to signal before resuming. + 10000, // 10 seconds (as milliseconds) + ) + + const WAIT_TIMEOUT = 0x102 + switch waitCode { + case WAIT_TIMEOUT: + return context.DeadlineExceeded + case windows.WAIT_FAILED: + return err + default: + return nil + } +} + +func (rw *regKeyWatcher) Close() error { + rw.keyLocal.Close() + rw.keyGP.Close() + windows.CloseHandle(rw.evtLocal) + windows.CloseHandle(rw.evtGP) + return nil +} diff --git a/net/dns/nrpt_windows.go b/net/dns/nrpt_windows.go index ceef38107..4d2b641ff 100644 --- a/net/dns/nrpt_windows.go +++ b/net/dns/nrpt_windows.go @@ -7,6 +7,8 @@ package dns import ( "fmt" "strings" + "sync" + "sync/atomic" "golang.org/x/sys/windows" "golang.org/x/sys/windows/registry" @@ -33,11 +35,25 @@ const ( // This is the name of the registry value we use to save Rule IDs under // the Tailscale registry key. nrptRuleIDValueName = `NRPTRuleIDs` + + // This is the name of the registry value the NRPT uses for storing a rule's version number. + nrptRuleVersionName = `Version` + + // This is the name of the registry value the NRPT uses for storing a rule's list of domains. + nrptRuleDomsName = `Name` + + // This is the name of the registry value the NRPT uses for storing a rule's list of DNS servers. + nrptRuleServersName = `GenericDNSServers` + + // This is the name of the registry value the NRPT uses for storing a rule's flags. + nrptRuleFlagsName = `ConfigOptions` ) var ( - libUserenv = windows.NewLazySystemDLL("userenv.dll") - procRefreshPolicyEx = libUserenv.NewProc("RefreshPolicyEx") + libUserenv = windows.NewLazySystemDLL("userenv.dll") + procRefreshPolicyEx = libUserenv.NewProc("RefreshPolicyEx") + procRegisterGPNotification = libUserenv.NewProc("RegisterGPNotification") + procUnregisterGPNotification = libUserenv.NewProc("UnregisterGPNotification") ) const _RP_FORCE = 1 // Flag for RefreshPolicyEx @@ -45,17 +61,20 @@ const _RP_FORCE = 1 // Flag for RefreshPolicyEx // nrptRuleDatabase ensapsulates access to the Windows Name Resolution Policy // Table (NRPT). type nrptRuleDatabase struct { - logf logger.Logf - ruleIDs []string - writeAsGP bool - isGPDirty bool + logf logger.Logf + watcher *gpNotificationWatcher + isGPRefreshPending atomic.Value // of bool + mu sync.Mutex // protects the fields below + ruleIDs []string + isGPDirty bool + writeAsGP bool } func newNRPTRuleDatabase(logf logger.Logf) *nrptRuleDatabase { ret := &nrptRuleDatabase{logf: logf} ret.loadRuleSubkeyNames() - ret.initWriteAsGP() - logf("nrptRuleDatabase using group policy: %v\n", ret.writeAsGP) + ret.detectWriteAsGP() + ret.watchForGPChanges() // Best-effort: if our NRPT rule exists, try to delete it. Unlike // per-interface configuration, NRPT rules survive the unclean // termination of the Tailscale process, and depending on the @@ -75,14 +94,28 @@ func (db *nrptRuleDatabase) loadRuleSubkeyNames() { db.ruleIDs = result } -// initWriteAsGP determines which registry path should be used for writing +// detectWriteAsGP determines which registry path should be used for writing // NRPT rules. If there are rules in the GP path that don't belong to us, then -// we should use the GP path. -func (db *nrptRuleDatabase) initWriteAsGP() { +// we should use the GP path. When detectWriteAsGP determines that the desired +// path has changed, it moves the NRPT policies as appropriate. +func (db *nrptRuleDatabase) detectWriteAsGP() { + db.mu.Lock() + defer db.mu.Unlock() + + writeAsGP := false var err error + defer func() { if err != nil { - db.writeAsGP = false + return + } + prev := db.writeAsGP + db.writeAsGP = writeAsGP + db.logf("nrptRuleDatabase using group policy: %v, was %v\n", writeAsGP, prev) + // When db.watcher == nil, prev != writeAsGP because we're initializing, not + // because anything has changed. We do not invoke db.movePolicies in that case. + if db.watcher != nil && prev != writeAsGP { + db.movePolicies(writeAsGP) } }() @@ -101,14 +134,13 @@ func (db *nrptRuleDatabase) initWriteAsGP() { // If the dnsKey contains any values, then we need to use the GP key. if ki.ValueCount > 0 { - db.writeAsGP = true + writeAsGP = true return } if ki.SubKeyCount == 0 { // If dnsKey contains no values and no subkeys, then we definitely don't // need to use the GP key. - db.writeAsGP = false return } @@ -139,11 +171,14 @@ func (db *nrptRuleDatabase) initWriteAsGP() { // Any leftover rules do not belong to us. When group policy is being used // by something else, we must also use the GP path. - db.writeAsGP = len(gpSubkeyMap) > 0 + writeAsGP = len(gpSubkeyMap) > 0 } // DelAllRuleKeys removes any and all NRPT rules that are owned by Tailscale. func (db *nrptRuleDatabase) DelAllRuleKeys() error { + db.mu.Lock() + defer db.mu.Unlock() + if err := db.delRuleKeys(db.ruleIDs); err != nil { return err } @@ -212,6 +247,9 @@ func isPolicyConfigSubkeyEmpty() (bool, error) { } func (db *nrptRuleDatabase) WriteSplitDNSConfig(servers []string, domains []dnsname.FQDN) error { + db.mu.Lock() + defer db.mu.Unlock() + // NRPT has an undocumented restriction that each rule may only be associated // with a maximum of 50 domains. If we are setting rules for more domains // than that, we need to split domains into chunks and write out a rule per chunk. @@ -224,6 +262,7 @@ func (db *nrptRuleDatabase) WriteSplitDNSConfig(servers []string, domains []dnsn } db.loadRuleSubkeyNames() + for len(db.ruleIDs) < domainRulesLen { guid, err := windows.GenerateGUID() if err != nil { @@ -280,9 +319,22 @@ func (db *nrptRuleDatabase) WriteSplitDNSConfig(servers []string, domains []dnsn // Refresh notifies the Windows group policy engine when policies have changed. func (db *nrptRuleDatabase) Refresh() { + db.mu.Lock() + defer db.mu.Unlock() + + db.refreshLocked() +} + +func (db *nrptRuleDatabase) refreshLocked() { if !db.isGPDirty { return } + + // Record that we are about to initiate a refresh. + // (*nrptRuleDatabase).watchForGPChanges() checks this value to avoid false + // positives. + db.isGPRefreshPending.Store(true) + ok, _, err := procRefreshPolicyEx.Call( uintptr(1), // Win32 TRUE: Refresh computer policy, not user policy. uintptr(_RP_FORCE), @@ -291,6 +343,7 @@ func (db *nrptRuleDatabase) Refresh() { db.logf("RefreshPolicyEx failed: %v", err) return } + db.isGPDirty = false } @@ -310,22 +363,256 @@ func (db *nrptRuleDatabase) writeNRPTRule(ruleID string, servers, doms []string) return fmt.Errorf("opening %s: %w", keyStr, err) } defer key.Close() - if err := key.SetDWordValue("Version", 1); err != nil { - return err - } - if err := key.SetStringsValue("Name", doms); err != nil { - return err - } - if err := key.SetStringValue("GenericDNSServers", strings.Join(servers, "; ")); err != nil { - return err - } - if err := key.SetDWordValue("ConfigOptions", nrptOverrideDNS); err != nil { + + if err := writeNRPTValues(key, strings.Join(servers, "; "), doms); err != nil { return err } - if db.writeAsGP { - db.isGPDirty = true - } + db.isGPDirty = db.writeAsGP + + return nil +} + +func readNRPTValues(key registry.Key) (servers string, doms []string, err error) { + doms, _, err = key.GetStringsValue(nrptRuleDomsName) + if err != nil { + return servers, doms, err + } + + servers, _, err = key.GetStringValue(nrptRuleServersName) + return servers, doms, err +} + +func writeNRPTValues(key registry.Key, servers string, doms []string) error { + if err := key.SetDWordValue(nrptRuleVersionName, 1); err != nil { + return err + } + + if err := key.SetStringsValue(nrptRuleDomsName, doms); err != nil { + return err + } + + if err := key.SetStringValue(nrptRuleServersName, servers); err != nil { + return err + } + + return key.SetDWordValue(nrptRuleFlagsName, nrptOverrideDNS) +} + +func (db *nrptRuleDatabase) watchForGPChanges() { + db.isGPRefreshPending.Store(false) + + watchHandler := func() { + // Do not invoke detectWriteAsGP when we ourselves were responsible for + // initiating the group policy refresh. + if db.isGPRefreshPending.CompareAndSwap(true, false) { + return + } + db.logf("Computer group policies refreshed, reconfiguring NRPT rule database.") + db.detectWriteAsGP() + } + + watcher, err := newGPNotificationWatcher(watchHandler) + if err != nil { + return + } + + db.watcher = watcher +} + +// movePolicies moves each NRPT rule depending on the value of writeAsGP. +// When writeAsGP is true, each NRPT rule is moved from the local NRPT table +// to the group policy NRPT table. When writeAsGP is false, the move is +// executed in the opposite direction. db.mu should already be locked. +func (db *nrptRuleDatabase) movePolicies(writeAsGP bool) { + // Since we're moving either in or out of the group policy NRPT table, we need + // to refresh once this movePolicies is done. + defer db.refreshLocked() + + var fromBase string + var toBase string + if writeAsGP { + fromBase = nrptBaseLocal + toBase = nrptBaseGP + } else { + fromBase = nrptBaseGP + toBase = nrptBaseLocal + } + fromBase += `\` + toBase += `\` + + for _, id := range db.ruleIDs { + fromStr := fromBase + id + toStr := toBase + id + + if err := executeMove(fromStr, toStr); err != nil { + db.logf("movePolicies: executeMove(\"%s\", \"%s\") failed with error %v", fromStr, toStr, err) + return + } + + db.isGPDirty = true + } + + if writeAsGP { + return + } + + // Now that we have moved our rules out of the group policy subkey, it should + // now be empty. Let's verify that. + isEmpty, err := isPolicyConfigSubkeyEmpty() + if err != nil { + db.logf("movePolicies: isPolicyConfigSubkeyEmpty error %v", err) + return + } + if !isEmpty { + db.logf("movePolicies: policy config subkey should be empty, but isn't!") + return + } + + // Delete the subkey itself. Group policy will continue to override local + // settings unless we do so. + if err := registry.DeleteKey(registry.LOCAL_MACHINE, nrptBaseGP); err != nil { + db.logf("movePolicies DeleteKey error %v", err) + } + + db.isGPDirty = true +} + +func executeMove(subKeyFrom, subKeyTo string) error { + err := func() error { + // Move the NRPT registry values from subKeyFrom to subKeyTo. + fromKey, err := registry.OpenKey(registry.LOCAL_MACHINE, subKeyFrom, registry.QUERY_VALUE) + if err != nil { + return err + } + defer fromKey.Close() + + toKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, subKeyTo, registry.WRITE) + if err != nil { + return err + } + defer toKey.Close() + + servers, doms, err := readNRPTValues(fromKey) + if err != nil { + return err + } + + return writeNRPTValues(toKey, servers, doms) + }() + if err != nil { + return err + } + + // This is a move operation, so we must delete subKeyFrom. + return registry.DeleteKey(registry.LOCAL_MACHINE, subKeyFrom) +} + +func (db *nrptRuleDatabase) Close() error { + if db.watcher == nil { + return nil + } + err := db.watcher.Close() + db.watcher = nil + return err +} + +type gpNotificationWatcher struct { + gpWaitEvents [2]windows.Handle + handler func() + done chan struct{} +} + +// newGPNotificationWatcher creates an instance of gpNotificationWatcher that +// invokes handler every time Windows notifies it of a group policy change. +func newGPNotificationWatcher(handler func()) (*gpNotificationWatcher, error) { + var err error + + // evtDone is signaled by (*gpNotificationWatcher).Close() to indicate that + // the doWatch goroutine should exit. + evtDone, err := windows.CreateEvent(nil, 0, 0, nil) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + windows.CloseHandle(evtDone) + } + }() + + // evtChanged is registered with the Windows policy engine to become + // signalled any time group policy has been refreshed. + evtChanged, err := windows.CreateEvent(nil, 0, 0, nil) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + windows.CloseHandle(evtChanged) + } + }() + + // Tell Windows to signal evtChanged whenever group policies are refreshed. + ok, _, e := procRegisterGPNotification.Call( + uintptr(evtChanged), + uintptr(1), // Win32 TRUE: We want to monitor computer policy changes, not user policy changes. + ) + if ok == 0 { + err = e + return nil, err + } + + result := &gpNotificationWatcher{ + // Ordering of the event handles in gpWaitEvents is important: + // When calling windows.WaitForMultipleObjects and multiple objects are + // signalled simultaneously, it always returns the wait code for the + // lowest-indexed handle in its input array. evtDone is higher priority for + // us than evtChanged, so the former must be placed into the array ahead of + // the latter. + gpWaitEvents: [2]windows.Handle{ + evtDone, + evtChanged, + }, + handler: handler, + done: make(chan struct{}), + } + + go result.doWatch() + + return result, nil +} + +func (w *gpNotificationWatcher) doWatch() { + // The wait code corresponding to the event that is signalled when a group + // policy change occurs. + const expectedWaitCode = windows.WAIT_OBJECT_0 + 1 + for { + if waitCode, _ := windows.WaitForMultipleObjects(w.gpWaitEvents[:], false, windows.INFINITE); waitCode != expectedWaitCode { + break + } + w.handler() + } + close(w.done) +} + +func (w *gpNotificationWatcher) Close() error { + // Notify doWatch that we're done and it should exit. + if err := windows.SetEvent(w.gpWaitEvents[0]); err != nil { + return err + } + + procUnregisterGPNotification.Call(uintptr(w.gpWaitEvents[1])) + + // Wait for doWatch to complete. + <-w.done + + // Now we may safely clean up all the things. + for i, evt := range w.gpWaitEvents { + windows.CloseHandle(evt) + w.gpWaitEvents[i] = 0 + } + + w.handler = nil return nil }