router/dns: set all domains on Windows (#672)

Signed-off-by: Dmytro Shynkevych <dmytro@tailscale.com>
reviewable/pr696/r1
Dmytro Shynkevych 2020-08-19 14:16:57 -04:00 committed by GitHub
parent 287522730d
commit a583e498b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 93 additions and 21 deletions

View File

@ -5,6 +5,7 @@
package dns package dns
import ( import (
"errors"
"fmt" "fmt"
"strings" "strings"
@ -13,6 +14,12 @@ import (
"tailscale.com/types/logger" "tailscale.com/types/logger"
) )
const (
ipv4RegBase = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters`
ipv6RegBase = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters`
tsRegBase = `SOFTWARE\Tailscale IPN`
)
type windowsManager struct { type windowsManager struct {
logf logger.Logf logf logger.Logf
guid string guid string
@ -25,29 +32,88 @@ func newManager(mconfig ManagerConfig) managerImpl {
} }
} }
func setRegistry(path, nameservers, domains string) error { func setRegistryString(path, name, value string) error {
key, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.SET_VALUE)
if err != nil {
return fmt.Errorf("opening %s: %w", path, err)
}
defer key.Close()
err = key.SetStringValue(name, value)
if err != nil {
return fmt.Errorf("setting %s[%s]: %w", path, name, err)
}
return nil
}
func getRegistryString(path, name string) (string, error) {
key, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.READ)
if err != nil {
return "", fmt.Errorf("opening %s: %w", path, err)
}
defer key.Close()
value, _, err := key.GetStringValue(name)
if err != nil {
return "", fmt.Errorf("getting %s[%s]: %w", path, name, err)
}
return value, nil
}
func (m windowsManager) setNameservers(basePath string, nameservers []string) error {
path := fmt.Sprintf(`%s\Interfaces\%s`, basePath, m.guid)
value := strings.Join(nameservers, ",")
return setRegistryString(path, "NameServer", value)
}
func (m windowsManager) setDomains(path string, oldDomains, newDomains []string) error {
// We reimplement setRegistryString to ensure that we hold the key for the whole operation.
key, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.READ|registry.SET_VALUE) key, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.READ|registry.SET_VALUE)
if err != nil { if err != nil {
return fmt.Errorf("opening %s: %w", path, err) return fmt.Errorf("opening %s: %w", path, err)
} }
defer key.Close() defer key.Close()
err = key.SetStringValue("NameServer", nameservers) searchList, _, err := key.GetStringValue("SearchList")
if err != nil { if err != nil && err != registry.ErrNotExist {
return fmt.Errorf("setting %s/NameServer: %w", path, err) return fmt.Errorf("getting %s[SearchList]: %w", path, err)
}
currentDomains := strings.Split(searchList, ",")
var domainsToSet []string
for _, domain := range currentDomains {
inOld, inNew := false, false
// The number of domains should be small,
// so this is probaly faster than constructing a map.
for _, oldDomain := range oldDomains {
if domain == oldDomain {
inOld = true
}
}
for _, newDomain := range newDomains {
if domain == newDomain {
inNew = true
}
} }
err = key.SetStringValue("Domain", domains) if !inNew && !inOld {
if err != nil { domainsToSet = append(domainsToSet, domain)
return fmt.Errorf("setting %s/Domain: %w", path, err)
} }
}
domainsToSet = append(domainsToSet, newDomains...)
searchList = strings.Join(domainsToSet, ",")
if err := key.SetStringValue("SearchList", searchList); err != nil {
return fmt.Errorf("setting %s[SearchList]: %w", path, err)
}
return nil return nil
} }
func (m windowsManager) Up(config Config) error { func (m windowsManager) Up(config Config) error {
var ipsv4 []string var ipsv4 []string
var ipsv6 []string var ipsv6 []string
for _, ip := range config.Nameservers { for _, ip := range config.Nameservers {
if ip.Is4() { if ip.Is4() {
ipsv4 = append(ipsv4, ip.String()) ipsv4 = append(ipsv4, ip.String())
@ -55,23 +121,29 @@ func (m windowsManager) Up(config Config) error {
ipsv6 = append(ipsv6, ip.String()) ipsv6 = append(ipsv6, ip.String())
} }
} }
nsv4 := strings.Join(ipsv4, ",")
nsv6 := strings.Join(ipsv6, ",")
var domains string lastSearchList, err := getRegistryString(tsRegBase, "SearchList")
if len(config.Domains) > 0 { if err != nil && !errors.Is(err, registry.ErrNotExist) {
if len(config.Domains) > 1 {
m.logf("only a single search domain is supported")
}
domains = config.Domains[0]
}
v4Path := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + m.guid
if err := setRegistry(v4Path, nsv4, domains); err != nil {
return err return err
} }
v6Path := `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\` + m.guid lastDomains := strings.Split(lastSearchList, ",")
if err := setRegistry(v6Path, nsv6, domains); err != nil {
if err := m.setNameservers(ipv4RegBase, ipsv4); err != nil {
return err
}
if err := m.setDomains(ipv4RegBase, lastDomains, config.Domains); err != nil {
return err
}
if err := m.setNameservers(ipv6RegBase, ipsv6); err != nil {
return err
}
if err := m.setDomains(ipv6RegBase, lastDomains, config.Domains); err != nil {
return err
}
newSearchList := strings.Join(config.Domains, ",")
if err := setRegistryString(tsRegBase, "SearchList", newSearchList); err != nil {
return err return err
} }