diff --git a/cmd/tailscaled/tailscaled.go b/cmd/tailscaled/tailscaled.go index e0aaa72e8..c2dffaefc 100644 --- a/cmd/tailscaled/tailscaled.go +++ b/cmd/tailscaled/tailscaled.go @@ -317,7 +317,7 @@ func run() error { panic("internal error: exit node resolver not wired up") } - ns, err := newNetstack(logf, e) + ns, err := newNetstack(logf, dialer, e) if err != nil { return fmt.Errorf("newNetstack: %w", err) } @@ -525,12 +525,12 @@ func runDebugServer(mux *http.ServeMux, addr string) { } } -func newNetstack(logf logger.Logf, e wgengine.Engine) (*netstack.Impl, error) { +func newNetstack(logf logger.Logf, dialer *tsdial.Dialer, e wgengine.Engine) (*netstack.Impl, error) { tunDev, magicConn, ok := e.(wgengine.InternalsGetter).GetInternals() if !ok { return nil, fmt.Errorf("%T is not a wgengine.InternalsGetter", e) } - return netstack.Create(logf, tunDev, e, magicConn) + return netstack.Create(logf, tunDev, e, magicConn, dialer) } // mustStartProxyListeners creates listeners for local SOCKS and HTTP diff --git a/cmd/tailscaled/tailscaled_windows.go b/cmd/tailscaled/tailscaled_windows.go index 73f918c64..0645796c1 100644 --- a/cmd/tailscaled/tailscaled_windows.go +++ b/cmd/tailscaled/tailscaled_windows.go @@ -32,6 +32,7 @@ import ( "tailscale.com/ipn/ipnserver" "tailscale.com/logpolicy" "tailscale.com/net/dns" + "tailscale.com/net/tsdial" "tailscale.com/net/tstun" "tailscale.com/safesocket" "tailscale.com/types/logger" @@ -182,6 +183,7 @@ func startIPNServer(ctx context.Context, logid string) error { if err != nil { return err } + dialer := new(tsdial.Dialer) getEngineRaw := func() (wgengine.Engine, error) { dev, devName, err := tstun.New(logf, "Tailscale") @@ -208,13 +210,14 @@ func startIPNServer(ctx context.Context, logid string) error { DNS: d, ListenPort: 41641, LinkMonitor: linkMon, + Dialer: dialer, }) if err != nil { r.Close() dev.Close() return nil, fmt.Errorf("engine: %w", err) } - ns, err := newNetstack(logf, eng) + ns, err := newNetstack(logf, dialer, eng) if err != nil { return nil, fmt.Errorf("newNetstack: %w", err) } @@ -294,7 +297,7 @@ func startIPNServer(ctx context.Context, logid string) error { return fmt.Errorf("safesocket.Listen: %v", err) } - err = ipnserver.Run(ctx, logf, ln, store, linkMon, logid, getEngine, ipnServerOpts()) + err = ipnserver.Run(ctx, logf, ln, store, linkMon, dialer, logid, getEngine, ipnServerOpts()) if err != nil { logf("ipnserver.Run: %v", err) } diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 19ce17821..c21da1912 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -165,9 +165,6 @@ func NewLocalBackend(logf logger.Logf, logid string, store ipn.StateStore, diale if dialer == nil { dialer = new(tsdial.Dialer) } - e.AddNetworkMapCallback(func(nm *netmap.NetworkMap) { - dialer.SetDNSMap(tsdial.DNSMapFromNetworkMap(nm)) - }) osshare.SetFileSharingEnabled(false, logf) @@ -2679,6 +2676,7 @@ func hasCapability(nm *netmap.NetworkMap, cap string) bool { } func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) { + b.dialer.SetNetMap(nm) var login string if nm != nil { login = nm.UserProfiles[nm.User].LoginName diff --git a/ipn/ipnserver/server.go b/ipn/ipnserver/server.go index 6c5f5c173..ae4088dc7 100644 --- a/ipn/ipnserver/server.go +++ b/ipn/ipnserver/server.go @@ -43,7 +43,6 @@ import ( "tailscale.com/safesocket" "tailscale.com/smallzstd" "tailscale.com/types/logger" - "tailscale.com/types/netmap" "tailscale.com/util/groupmember" "tailscale.com/util/pidowner" "tailscale.com/util/systemd" @@ -654,7 +653,7 @@ func StateStore(path string, logf logger.Logf) (ipn.StateStore, error) { // The getEngine func is called repeatedly, once per connection, until it returns an engine successfully. // // Deprecated: use New and Server.Run instead. -func Run(ctx context.Context, logf logger.Logf, ln net.Listener, store ipn.StateStore, linkMon *monitor.Mon, logid string, getEngine func() (wgengine.Engine, error), opts Options) error { +func Run(ctx context.Context, logf logger.Logf, ln net.Listener, store ipn.StateStore, linkMon *monitor.Mon, dialer *tsdial.Dialer, logid string, getEngine func() (wgengine.Engine, error), opts Options) error { getEngine = getEngineUntilItWorksWrapper(getEngine) runDone := make(chan struct{}) defer close(runDone) @@ -738,11 +737,6 @@ func Run(ctx context.Context, logf logger.Logf, ln net.Listener, store ipn.State } } - dialer := new(tsdial.Dialer) - eng.AddNetworkMapCallback(func(nm *netmap.NetworkMap) { - dialer.SetDNSMap(tsdial.DNSMapFromNetworkMap(nm)) - }) - server, err := New(logf, logid, store, eng, dialer, serverModeUser, opts) if err != nil { return err diff --git a/ipn/ipnserver/server_test.go b/ipn/ipnserver/server_test.go index 790d720f0..aa66c7040 100644 --- a/ipn/ipnserver/server_test.go +++ b/ipn/ipnserver/server_test.go @@ -13,6 +13,7 @@ import ( "tailscale.com/ipn" "tailscale.com/ipn/ipnserver" + "tailscale.com/net/tsdial" "tailscale.com/safesocket" "tailscale.com/wgengine" ) @@ -72,6 +73,6 @@ func TestRunMultipleAccepts(t *testing.T) { } defer ln.Close() - err = ipnserver.Run(ctx, logTriggerTestf, ln, store, nil /* mon */, "dummy_logid", ipnserver.FixedEngine(eng), opts) + err = ipnserver.Run(ctx, logTriggerTestf, ln, store, nil /* mon */, new(tsdial.Dialer), "dummy_logid", ipnserver.FixedEngine(eng), opts) t.Logf("ipnserver.Run = %v", err) } diff --git a/net/tsdial/dnsmap.go b/net/tsdial/dnsmap.go index 8541d795b..553c2e765 100644 --- a/net/tsdial/dnsmap.go +++ b/net/tsdial/dnsmap.go @@ -6,6 +6,7 @@ package tsdial import ( "context" + "errors" "fmt" "net" "strconv" @@ -16,15 +17,18 @@ import ( "tailscale.com/util/dnsname" ) -// DNSMap maps MagicDNS names (both base + FQDN) to their first IP. +// dnsMap maps MagicDNS names (both base + FQDN) to their first IP. // It must not be mutated once created. // // Example keys are "foo.domain.tld.beta.tailscale.net" and "foo", // both without trailing dots. -type DNSMap map[string]netaddr.IP +type dnsMap map[string]netaddr.IP -func DNSMapFromNetworkMap(nm *netmap.NetworkMap) DNSMap { - ret := make(DNSMap) +func dnsMapFromNetworkMap(nm *netmap.NetworkMap) dnsMap { + if nm == nil { + return nil + } + ret := make(dnsMap) suffix := nm.MagicDNSSuffix() have4 := false if nm.Name != "" && len(nm.Addresses) > 0 { @@ -68,26 +72,35 @@ func DNSMapFromNetworkMap(nm *netmap.NetworkMap) DNSMap { return ret } +// errUnresolved is a sentinel error returned by dnsMap.resolveMemory. +var errUnresolved = errors.New("address well formed but not resolved") + +func splitHostPort(addr string) (host string, port uint16, err error) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + return "", 0, err + } + port16, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return "", 0, fmt.Errorf("invalid port in address %q", addr) + } + return host, uint16(port16), nil +} + // Resolve resolves addr into an IP:port using first the MagicDNS contents // of m, else using the system resolver. -func (m DNSMap) Resolve(ctx context.Context, addr string) (netaddr.IPPort, error) { - ipp, pippErr := netaddr.ParseIPPort(addr) - if pippErr == nil { - return ipp, nil - } - host, port, err := net.SplitHostPort(addr) +// +// The error is [exactly] errUnresolved if the addr is a name that isn't known +// in the map. +func (m dnsMap) resolveMemory(ctx context.Context, network, addr string) (_ netaddr.IPPort, err error) { + host, port, err := splitHostPort(addr) if err != nil { - // addr is malformed. + // addr malformed or invalid port. return netaddr.IPPort{}, err } - if _, err := netaddr.ParseIP(host); err == nil { - // The host part of addr was an IP, so the netaddr.ParseIPPort above should've - // passed. Must've been a bad port number. Return the original error. - return netaddr.IPPort{}, pippErr - } - port16, err := strconv.ParseUint(port, 10, 16) - if err != nil { - return netaddr.IPPort{}, fmt.Errorf("invalid port in address %q", addr) + if ip, err := netaddr.ParseIP(host); err == nil { + // addr was literal ip:port. + return netaddr.IPPortFrom(ip, port), nil } // Host is not an IP, so assume it's a DNS name. @@ -95,20 +108,8 @@ func (m DNSMap) Resolve(ctx context.Context, addr string) (netaddr.IPPort, error // Try MagicDNS first, otherwise a real DNS lookup. ip := m[host] if !ip.IsZero() { - return netaddr.IPPortFrom(ip, uint16(port16)), nil + return netaddr.IPPortFrom(ip, port), nil } - // TODO(bradfitz): wire up net/dnscache too. - - // No MagicDNS name so try real DNS. - var r net.Resolver - ips, err := r.LookupIP(ctx, "ip", host) - if err != nil { - return netaddr.IPPort{}, err - } - if len(ips) == 0 { - return netaddr.IPPort{}, fmt.Errorf("DNS lookup returned no results for %q", host) - } - ip, _ = netaddr.FromStdIP(ips[0]) - return netaddr.IPPortFrom(ip, uint16(port16)), nil + return netaddr.IPPort{}, errUnresolved } diff --git a/net/tsdial/dnsmap_test.go b/net/tsdial/dnsmap_test.go index e3cfc5e76..58520f801 100644 --- a/net/tsdial/dnsmap_test.go +++ b/net/tsdial/dnsmap_test.go @@ -19,7 +19,7 @@ func TestDNSMapFromNetworkMap(t *testing.T) { tests := []struct { name string nm *netmap.NetworkMap - want DNSMap + want dnsMap }{ { name: "self", @@ -30,7 +30,7 @@ func TestDNSMapFromNetworkMap(t *testing.T) { pfx("100::123/128"), }, }, - want: DNSMap{ + want: dnsMap{ "foo": ip("100.102.103.104"), "foo.tailnet": ip("100.102.103.104"), }, @@ -59,7 +59,7 @@ func TestDNSMapFromNetworkMap(t *testing.T) { }, }, }, - want: DNSMap{ + want: dnsMap{ "foo": ip("100.102.103.104"), "foo.tailnet": ip("100.102.103.104"), "a": ip("100.0.0.201"), @@ -91,7 +91,7 @@ func TestDNSMapFromNetworkMap(t *testing.T) { }, }, }, - want: DNSMap{ + want: dnsMap{ "foo": ip("100::123"), "foo.tailnet": ip("100::123"), "a": ip("100::201"), @@ -103,7 +103,7 @@ func TestDNSMapFromNetworkMap(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := DNSMapFromNetworkMap(tt.nm) + got := dnsMapFromNetworkMap(tt.nm) if !reflect.DeepEqual(got, tt.want) { t.Errorf("mismatch:\n got %v\nwant %v\n", got, tt.want) } diff --git a/net/tsdial/tsdial.go b/net/tsdial/tsdial.go index 3c32b07ba..3a27b2bed 100644 --- a/net/tsdial/tsdial.go +++ b/net/tsdial/tsdial.go @@ -18,6 +18,7 @@ import ( "inet.af/netaddr" "tailscale.com/net/netknob" + "tailscale.com/types/netmap" "tailscale.com/wgengine/monitor" ) @@ -43,7 +44,7 @@ type Dialer struct { peerDialer *net.Dialer mu sync.Mutex - dns DNSMap + dns dnsMap tunName string // tun device name linkMon *monitor.Mon } @@ -102,26 +103,55 @@ func (d *Dialer) PeerDialControlFunc() func(network, address string, c syscall.R return peerDialControlFunc(d) } -// SetDNSMap sets the current map of DNS names learned from the netmap. -func (d *Dialer) SetDNSMap(m DNSMap) { - // TODO(bradfitz): update this to be aware of DNSConfig - // ExtraNames and CertDomains. +// SetNetMap sets the current network map and notably, the DNS names +// in its DNS configuration. +func (d *Dialer) SetNetMap(nm *netmap.NetworkMap) { + m := dnsMapFromNetworkMap(nm) + d.mu.Lock() defer d.mu.Unlock() d.dns = m } -func (d *Dialer) resolve(ctx context.Context, addr string) (netaddr.IPPort, error) { +func (d *Dialer) Resolve(ctx context.Context, network, addr string) (netaddr.IPPort, error) { d.mu.Lock() dns := d.dns d.mu.Unlock() - return dns.Resolve(ctx, addr) + + // MagicDNS or otherwise baked in to the NetworkMap? Try that first. + ipp, err := dns.resolveMemory(ctx, network, addr) + + if err != errUnresolved { + return ipp, err + } + + // Otherwise, hit the network. + + // TODO(bradfitz): use ExitDNS (Issue 3475) + // TODO(bradfitz): wire up net/dnscache too. + + host, port, err := splitHostPort(addr) + if err != nil { + // addr is malformed. + return netaddr.IPPort{}, err + } + + var r net.Resolver + ips, err := r.LookupIP(ctx, network, host) + if err != nil { + return netaddr.IPPort{}, err + } + if len(ips) == 0 { + return netaddr.IPPort{}, fmt.Errorf("DNS lookup returned no results for %q", host) + } + ip, _ := netaddr.FromStdIP(ips[0]) + return netaddr.IPPortFrom(ip, port), nil } // UserDial connects to the provided network address as if a user were initiating the dial. // (e.g. from a SOCKS or HTTP outbound proxy) func (d *Dialer) UserDial(ctx context.Context, network, addr string) (net.Conn, error) { - ipp, err := d.resolve(ctx, addr) + ipp, err := d.Resolve(ctx, network, addr) if err != nil { return nil, err } diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index dedbc26a6..fd59b51da 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -132,7 +132,7 @@ func (s *Server) start() error { return fmt.Errorf("%T is not a wgengine.InternalsGetter", eng) } - ns, err := netstack.Create(logf, tunDev, eng, magicConn) + ns, err := netstack.Create(logf, tunDev, eng, magicConn, dialer) if err != nil { return fmt.Errorf("netstack.Create: %w", err) } diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index 86f61de65..f30c06ab1 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -71,6 +71,7 @@ type Impl struct { e wgengine.Engine mc *magicsock.Conn logf logger.Logf + dialer *tsdial.Dialer // atomicIsLocalIPFunc holds a func that reports whether an IP // is a local (non-subnet) Tailscale IP address of this @@ -78,8 +79,7 @@ type Impl struct { // updates. atomicIsLocalIPFunc atomic.Value // of func(netaddr.IP) bool - mu sync.Mutex - dns tsdial.DNSMap + mu sync.Mutex // connsOpenBySubnetIP keeps track of number of connections open // for each subnet IP temporarily registered on netstack for active // TCP connections, so they can be unregistered when connections are @@ -91,7 +91,7 @@ const nicID = 1 const mtu = 1500 // Create creates and populates a new Impl. -func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magicsock.Conn) (*Impl, error) { +func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magicsock.Conn, dialer *tsdial.Dialer) (*Impl, error) { if mc == nil { return nil, errors.New("nil magicsock.Conn") } @@ -104,6 +104,9 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi if e == nil { return nil, errors.New("nil Engine") } + if dialer == nil { + return nil, errors.New("nil Dialer") + } ipstack := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6}, @@ -139,6 +142,7 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi tundev: tundev, e: e, mc: mc, + dialer: dialer, connsOpenBySubnetIP: make(map[netaddr.IP]int), } ns.atomicIsLocalIPFunc.Store(tsaddr.NewContainsIPFunc(nil)) @@ -179,12 +183,6 @@ func (ns *Impl) Start() error { return nil } -func (ns *Impl) updateDNS(nm *netmap.NetworkMap) { - ns.mu.Lock() - defer ns.mu.Unlock() - ns.dns = tsdial.DNSMapFromNetworkMap(nm) -} - func (ns *Impl) addSubnetAddress(ip netaddr.IP) { ns.mu.Lock() ns.connsOpenBySubnetIP[ip]++ @@ -230,7 +228,6 @@ func ipPrefixToAddressWithPrefix(ipp netaddr.IPPrefix) tcpip.AddressWithPrefix { func (ns *Impl) updateIPs(nm *netmap.NetworkMap) { ns.atomicIsLocalIPFunc.Store(tsaddr.NewContainsIPFunc(nm.Addresses)) - ns.updateDNS(nm) oldIPs := make(map[tcpip.AddressWithPrefix]bool) for _, protocolAddr := range ns.ipstack.AllAddresses()[nicID] { @@ -299,11 +296,7 @@ func (ns *Impl) updateIPs(nm *netmap.NetworkMap) { } func (ns *Impl) DialContextTCP(ctx context.Context, addr string) (*gonet.TCPConn, error) { - ns.mu.Lock() - dnsMap := ns.dns - ns.mu.Unlock() - - remoteIPPort, err := dnsMap.Resolve(ctx, addr) + remoteIPPort, err := ns.dialer.Resolve(ctx, "tcp", addr) if err != nil { return nil, err } @@ -323,11 +316,7 @@ func (ns *Impl) DialContextTCP(ctx context.Context, addr string) (*gonet.TCPConn } func (ns *Impl) DialContextUDP(ctx context.Context, addr string) (*gonet.UDPConn, error) { - ns.mu.Lock() - dnsMap := ns.dns - ns.mu.Unlock() - - remoteIPPort, err := dnsMap.Resolve(ctx, addr) + remoteIPPort, err := ns.dialer.Resolve(ctx, "udp", addr) if err != nil { return nil, err }