From a583e498b0615bb02fc14ded7c200d6a3cc0da45 Mon Sep 17 00:00:00 2001 From: Dmytro Shynkevych Date: Wed, 19 Aug 2020 14:16:57 -0400 Subject: [PATCH] router/dns: set all domains on Windows (#672) Signed-off-by: Dmytro Shynkevych --- wgengine/router/dns/manager_windows.go | 114 ++++++++++++++++++++----- 1 file changed, 93 insertions(+), 21 deletions(-) diff --git a/wgengine/router/dns/manager_windows.go b/wgengine/router/dns/manager_windows.go index 4196d1f70..a768b4312 100644 --- a/wgengine/router/dns/manager_windows.go +++ b/wgengine/router/dns/manager_windows.go @@ -5,6 +5,7 @@ package dns import ( + "errors" "fmt" "strings" @@ -13,6 +14,12 @@ import ( "tailscale.com/types/logger" ) +const ( + ipv4RegBase = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters` + ipv6RegBase = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters` + tsRegBase = `SOFTWARE\Tailscale IPN` +) + type windowsManager struct { logf logger.Logf guid string @@ -25,29 +32,88 @@ func newManager(mconfig ManagerConfig) managerImpl { } } -func setRegistry(path, nameservers, domains string) error { +func setRegistryString(path, name, value string) error { + key, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.SET_VALUE) + if err != nil { + return fmt.Errorf("opening %s: %w", path, err) + } + defer key.Close() + + err = key.SetStringValue(name, value) + if err != nil { + return fmt.Errorf("setting %s[%s]: %w", path, name, err) + } + return nil +} + +func getRegistryString(path, name string) (string, error) { + key, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.READ) + if err != nil { + return "", fmt.Errorf("opening %s: %w", path, err) + } + defer key.Close() + + value, _, err := key.GetStringValue(name) + if err != nil { + return "", fmt.Errorf("getting %s[%s]: %w", path, name, err) + } + return value, nil +} + +func (m windowsManager) setNameservers(basePath string, nameservers []string) error { + path := fmt.Sprintf(`%s\Interfaces\%s`, basePath, m.guid) + value := strings.Join(nameservers, ",") + return setRegistryString(path, "NameServer", value) +} + +func (m windowsManager) setDomains(path string, oldDomains, newDomains []string) error { + // We reimplement setRegistryString to ensure that we hold the key for the whole operation. key, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.READ|registry.SET_VALUE) if err != nil { return fmt.Errorf("opening %s: %w", path, err) } defer key.Close() - err = key.SetStringValue("NameServer", nameservers) - if err != nil { - return fmt.Errorf("setting %s/NameServer: %w", path, err) + searchList, _, err := key.GetStringValue("SearchList") + if err != nil && err != registry.ErrNotExist { + return fmt.Errorf("getting %s[SearchList]: %w", path, err) } + currentDomains := strings.Split(searchList, ",") - err = key.SetStringValue("Domain", domains) - if err != nil { - return fmt.Errorf("setting %s/Domain: %w", path, err) + var domainsToSet []string + for _, domain := range currentDomains { + inOld, inNew := false, false + + // The number of domains should be small, + // so this is probaly faster than constructing a map. + for _, oldDomain := range oldDomains { + if domain == oldDomain { + inOld = true + } + } + for _, newDomain := range newDomains { + if domain == newDomain { + inNew = true + } + } + + if !inNew && !inOld { + domainsToSet = append(domainsToSet, domain) + } } + domainsToSet = append(domainsToSet, newDomains...) + searchList = strings.Join(domainsToSet, ",") + if err := key.SetStringValue("SearchList", searchList); err != nil { + return fmt.Errorf("setting %s[SearchList]: %w", path, err) + } return nil } func (m windowsManager) Up(config Config) error { var ipsv4 []string var ipsv6 []string + for _, ip := range config.Nameservers { if ip.Is4() { ipsv4 = append(ipsv4, ip.String()) @@ -55,23 +121,29 @@ func (m windowsManager) Up(config Config) error { ipsv6 = append(ipsv6, ip.String()) } } - nsv4 := strings.Join(ipsv4, ",") - nsv6 := strings.Join(ipsv6, ",") - var domains string - if len(config.Domains) > 0 { - if len(config.Domains) > 1 { - m.logf("only a single search domain is supported") - } - domains = config.Domains[0] - } - - v4Path := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + m.guid - if err := setRegistry(v4Path, nsv4, domains); err != nil { + lastSearchList, err := getRegistryString(tsRegBase, "SearchList") + if err != nil && !errors.Is(err, registry.ErrNotExist) { return err } - v6Path := `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\` + m.guid - if err := setRegistry(v6Path, nsv6, domains); err != nil { + lastDomains := strings.Split(lastSearchList, ",") + + if err := m.setNameservers(ipv4RegBase, ipsv4); err != nil { + return err + } + if err := m.setDomains(ipv4RegBase, lastDomains, config.Domains); err != nil { + return err + } + + if err := m.setNameservers(ipv6RegBase, ipsv6); err != nil { + return err + } + if err := m.setDomains(ipv6RegBase, lastDomains, config.Domains); err != nil { + return err + } + + newSearchList := strings.Join(config.Domains, ",") + if err := setRegistryString(tsRegBase, "SearchList", newSearchList); err != nil { return err }