net/dns/resolver: teach the forwarder to do per-domain routing.
Given a DNS route map, the forwarder selects the right set of upstreams for a given name. Signed-off-by: David Anderson <danderson@tailscale.com>pull/1644/head
parent
4ed111281b
commit
9f105d3968
|
@ -17,10 +17,12 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
dns "golang.org/x/net/dns/dnsmessage"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/logtail/backoff"
|
||||
"tailscale.com/net/netns"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/util/dnsname"
|
||||
)
|
||||
|
||||
// headerBytes is the number of bytes in a DNS message header.
|
||||
|
@ -100,6 +102,11 @@ func getTxID(packet []byte) txid {
|
|||
return (txid(hash) << 32) | txid(dnsid)
|
||||
}
|
||||
|
||||
type route struct {
|
||||
suffix string
|
||||
resolvers []netaddr.IPPort
|
||||
}
|
||||
|
||||
// forwarder forwards DNS packets to a number of upstream nameservers.
|
||||
type forwarder struct {
|
||||
logf logger.Logf
|
||||
|
@ -116,10 +123,9 @@ type forwarder struct {
|
|||
conns []*fwdConn
|
||||
|
||||
mu sync.Mutex
|
||||
// upstreams are the nameserver addresses that should be used for forwarding.
|
||||
upstreams []net.Addr
|
||||
// txMap maps DNS txids to active forwarding records.
|
||||
txMap map[txid]forwardingRecord
|
||||
// routes are per-suffix resolvers to use.
|
||||
routes []route // most specific routes first
|
||||
txMap map[txid]forwardingRecord // txids to in-flight requests
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
@ -127,24 +133,22 @@ func init() {
|
|||
}
|
||||
|
||||
func newForwarder(logf logger.Logf, responses chan packet) *forwarder {
|
||||
return &forwarder{
|
||||
ret := &forwarder{
|
||||
logf: logger.WithPrefix(logf, "forward: "),
|
||||
responses: responses,
|
||||
closed: make(chan struct{}),
|
||||
conns: make([]*fwdConn, connCount),
|
||||
txMap: make(map[txid]forwardingRecord),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *forwarder) Start() error {
|
||||
f.wg.Add(connCount + 1)
|
||||
for idx := range f.conns {
|
||||
f.conns[idx] = newFwdConn(f.logf, idx)
|
||||
go f.recv(f.conns[idx])
|
||||
ret.wg.Add(connCount + 1)
|
||||
for idx := range ret.conns {
|
||||
ret.conns[idx] = newFwdConn(ret.logf, idx)
|
||||
go ret.recv(ret.conns[idx])
|
||||
}
|
||||
go f.cleanMap()
|
||||
go ret.cleanMap()
|
||||
|
||||
return nil
|
||||
return ret
|
||||
}
|
||||
|
||||
func (f *forwarder) Close() {
|
||||
|
@ -171,14 +175,15 @@ func (f *forwarder) rebindFromNetworkChange() {
|
|||
}
|
||||
}
|
||||
|
||||
func (f *forwarder) setUpstreams(upstreams []net.Addr) {
|
||||
func (f *forwarder) setRoutes(routes []route) {
|
||||
fmt.Println(routes)
|
||||
f.mu.Lock()
|
||||
f.upstreams = upstreams
|
||||
f.routes = routes
|
||||
f.mu.Unlock()
|
||||
}
|
||||
|
||||
// send sends packet to dst. It is best effort.
|
||||
func (f *forwarder) send(packet []byte, dst net.Addr) {
|
||||
func (f *forwarder) send(packet []byte, dst netaddr.IPPort) {
|
||||
connIdx := rand.Intn(connCount)
|
||||
conn := f.conns[connIdx]
|
||||
conn.send(packet, dst)
|
||||
|
@ -256,24 +261,38 @@ func (f *forwarder) cleanMap() {
|
|||
|
||||
// forward forwards the query to all upstream nameservers and returns the first response.
|
||||
func (f *forwarder) forward(query packet) error {
|
||||
domain, err := nameFromQuery(query.bs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
txid := getTxID(query.bs)
|
||||
|
||||
f.mu.Lock()
|
||||
routes := f.routes
|
||||
f.mu.Unlock()
|
||||
|
||||
upstreams := f.upstreams
|
||||
if len(upstreams) == 0 {
|
||||
f.mu.Unlock()
|
||||
var resolvers []netaddr.IPPort
|
||||
for _, route := range routes {
|
||||
if route.suffix != "." && !dnsname.HasSuffix(domain, route.suffix) {
|
||||
continue
|
||||
}
|
||||
resolvers = route.resolvers
|
||||
break
|
||||
}
|
||||
if len(resolvers) == 0 {
|
||||
return errNoUpstreams
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
f.txMap[txid] = forwardingRecord{
|
||||
src: query.addr,
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
|
||||
f.mu.Unlock()
|
||||
|
||||
for _, upstream := range upstreams {
|
||||
f.send(query.bs, upstream)
|
||||
for _, resolver := range resolvers {
|
||||
f.send(query.bs, resolver)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -309,7 +328,7 @@ func newFwdConn(logf logger.Logf, idx int) *fwdConn {
|
|||
|
||||
// send sends packet to dst using c's connection.
|
||||
// It is best effort. It is UDP, after all. Failures are logged.
|
||||
func (c *fwdConn) send(packet []byte, dst net.Addr) {
|
||||
func (c *fwdConn) send(packet []byte, dst netaddr.IPPort) {
|
||||
var b *backoff.Backoff // lazily initialized, since it is not needed in the common case
|
||||
backOff := func(err error) {
|
||||
if b == nil {
|
||||
|
@ -335,8 +354,9 @@ func (c *fwdConn) send(packet []byte, dst net.Addr) {
|
|||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
a := dst.UDPAddr()
|
||||
c.wg.Add(1)
|
||||
_, err := conn.WriteTo(packet, dst)
|
||||
_, err := conn.WriteTo(packet, a)
|
||||
c.wg.Done()
|
||||
if err == nil {
|
||||
// Success
|
||||
|
@ -469,3 +489,24 @@ func (c *fwdConn) close() {
|
|||
// Unblock any remaining readers.
|
||||
c.change.Broadcast()
|
||||
}
|
||||
|
||||
// nameFromQuery extracts the normalized query name from bs.
|
||||
func nameFromQuery(bs []byte) (string, error) {
|
||||
var parser dns.Parser
|
||||
|
||||
hdr, err := parser.Start(bs)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if hdr.Response {
|
||||
return "", errNotQuery
|
||||
}
|
||||
|
||||
q, err := parser.Question()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
n := q.Name.Data[:q.Name.Length]
|
||||
return rawNameToLower(n), nil
|
||||
}
|
||||
|
|
|
@ -10,7 +10,6 @@ import (
|
|||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -68,11 +67,6 @@ type Config struct {
|
|||
LocalDomains []string
|
||||
}
|
||||
|
||||
type route struct {
|
||||
suffix string
|
||||
resolvers []netaddr.IPPort
|
||||
}
|
||||
|
||||
// Resolver is a DNS resolver for nodes on the Tailscale network,
|
||||
// associating them with domain names of the form <mynode>.<mydomain>.<root>.
|
||||
// If it is asked to resolve a domain that is not of that form,
|
||||
|
@ -100,7 +94,6 @@ type Resolver struct {
|
|||
localDomains []string
|
||||
hostToIP map[string][]netaddr.IP
|
||||
ipToHost map[netaddr.IP]string
|
||||
routes []route // most specific routes first
|
||||
}
|
||||
|
||||
// New returns a new resolver.
|
||||
|
@ -121,10 +114,6 @@ func New(logf logger.Logf, linkMon *monitor.Mon) (*Resolver, error) {
|
|||
r.unregLinkMon = r.linkMon.RegisterChangeCallback(r.onLinkMonitorChange)
|
||||
}
|
||||
|
||||
if err := r.forwarder.Start(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.wg.Add(1)
|
||||
go r.poll()
|
||||
|
||||
|
@ -138,7 +127,6 @@ func isFQDN(s string) bool {
|
|||
func (r *Resolver) SetConfig(cfg Config) error {
|
||||
routes := make([]route, 0, len(cfg.Routes))
|
||||
reverse := make(map[netaddr.IP]string, len(cfg.Hosts))
|
||||
var defaultUpstream []net.Addr
|
||||
|
||||
for host, ips := range cfg.Hosts {
|
||||
if !isFQDN(host) {
|
||||
|
@ -162,32 +150,19 @@ func (r *Resolver) SetConfig(cfg Config) error {
|
|||
suffix: suffix,
|
||||
resolvers: ips,
|
||||
})
|
||||
if suffix == "." {
|
||||
// TODO: this is a temporary hack to forward upstream
|
||||
// resolvers to the forwarder, which doesn't yet
|
||||
// understand per-domain resolvers. Effectively, SetConfig
|
||||
// currently ignores all routes except for ".", which it
|
||||
// sets as the only resolver.
|
||||
for _, ip := range ips {
|
||||
up := ip.UDPAddr()
|
||||
defaultUpstream = append(defaultUpstream, up)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Sort from longest prefix to shortest.
|
||||
sort.Slice(routes, func(i, j int) bool {
|
||||
return strings.Count(routes[i].suffix, ".") > strings.Count(routes[j].suffix, ".")
|
||||
return dnsname.NumLabels(routes[i].suffix) > dnsname.NumLabels(routes[j].suffix)
|
||||
})
|
||||
|
||||
r.forwarder.setUpstreams(defaultUpstream)
|
||||
r.forwarder.setRoutes(routes)
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.localDomains = cfg.LocalDomains
|
||||
r.hostToIP = cfg.Hosts
|
||||
r.ipToHost = reverse
|
||||
r.routes = routes
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -386,6 +361,8 @@ type response struct {
|
|||
}
|
||||
|
||||
// parseQuery parses the query in given packet into a response struct.
|
||||
// if the parse is successful, resp.Name contains the normalized name being queried.
|
||||
// TODO: stuffing the query name in resp.Name temporarily is a hack. Clean it up.
|
||||
func parseQuery(query []byte, resp *response) error {
|
||||
var parser dns.Parser
|
||||
var err error
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"log"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
@ -16,8 +16,6 @@ import (
|
|||
// that depends on github.com/miekg/dns
|
||||
// from the rest, which only depends on dnsmessage.
|
||||
|
||||
var dnsHandleFunc = dns.HandleFunc
|
||||
|
||||
// resolveToIP returns a handler function which responds
|
||||
// to queries of type A it receives with an A record containing ipv4,
|
||||
// to queries of type AAAA with an AAAA record containing ipv6,
|
||||
|
@ -68,28 +66,38 @@ func resolveToIP(ipv4, ipv6 netaddr.IP, ns string) dns.HandlerFunc {
|
|||
}
|
||||
}
|
||||
|
||||
func resolveToNXDOMAIN(w dns.ResponseWriter, req *dns.Msg) {
|
||||
var resolveToNXDOMAIN = dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetRcode(req, dns.RcodeNameError)
|
||||
w.WriteMsg(m)
|
||||
}
|
||||
|
||||
func serveDNS(tb testing.TB, addr string) (*dns.Server, chan error) {
|
||||
server := &dns.Server{Addr: addr, Net: "udp"}
|
||||
})
|
||||
|
||||
func serveDNS(tb testing.TB, addr string, records ...interface{}) *dns.Server {
|
||||
if len(records)%2 != 0 {
|
||||
panic("must have an even number of record values")
|
||||
}
|
||||
mux := dns.NewServeMux()
|
||||
for i := 0; i < len(records); i += 2 {
|
||||
name := records[i].(string)
|
||||
handler := records[i+1].(dns.Handler)
|
||||
mux.Handle(name, handler)
|
||||
}
|
||||
waitch := make(chan struct{})
|
||||
server.NotifyStartedFunc = func() { close(waitch) }
|
||||
server := &dns.Server{
|
||||
Addr: addr,
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
NotifyStartedFunc: func() { close(waitch) },
|
||||
ReusePort: true,
|
||||
}
|
||||
|
||||
errch := make(chan error, 1)
|
||||
go func() {
|
||||
err := server.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Printf("ListenAndServe(%q): %v", addr, err)
|
||||
panic(fmt.Sprintf("ListenAndServe(%q): %v", addr, err))
|
||||
}
|
||||
errch <- err
|
||||
close(errch)
|
||||
}()
|
||||
|
||||
<-waitch
|
||||
return server, errch
|
||||
return server
|
||||
}
|
||||
|
|
|
@ -15,13 +15,8 @@ import (
|
|||
"tailscale.com/tstest"
|
||||
)
|
||||
|
||||
var testipv4 = netaddr.IPv4(1, 2, 3, 4)
|
||||
var testipv6 = netaddr.IPv6Raw([16]byte{
|
||||
0x00, 0x01, 0x02, 0x03,
|
||||
0x04, 0x05, 0x06, 0x07,
|
||||
0x08, 0x09, 0x0a, 0x0b,
|
||||
0x0c, 0x0d, 0x0e, 0x0f,
|
||||
})
|
||||
var testipv4 = netaddr.MustParseIP("1.2.3.4")
|
||||
var testipv6 = netaddr.MustParseIP("0001:0203:0405:0607:0809:0a0b:0c0d:0e0f")
|
||||
|
||||
var dnsCfg = Config{
|
||||
Hosts: map[string][]netaddr.IP{
|
||||
|
@ -283,32 +278,14 @@ func TestDelegate(t *testing.T) {
|
|||
t.Skip("skipping test that requires localhost IPv6")
|
||||
}
|
||||
|
||||
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
|
||||
dnsHandleFunc("nxdomain.site.", resolveToNXDOMAIN)
|
||||
|
||||
v4server, v4errch := serveDNS(t, "127.0.0.1:0")
|
||||
v6server, v6errch := serveDNS(t, "[::1]:0")
|
||||
|
||||
defer func() {
|
||||
if err := <-v4errch; err != nil {
|
||||
t.Errorf("v4 server error: %v", err)
|
||||
}
|
||||
if err := <-v6errch; err != nil {
|
||||
t.Errorf("v6 server error: %v", err)
|
||||
}
|
||||
}()
|
||||
if v4server != nil {
|
||||
defer v4server.Shutdown()
|
||||
}
|
||||
if v6server != nil {
|
||||
defer v6server.Shutdown()
|
||||
}
|
||||
|
||||
if v4server == nil || v6server == nil {
|
||||
// There is an error in at least one of the channels
|
||||
// and we cannot proceed; return to see it.
|
||||
return
|
||||
}
|
||||
v4server := serveDNS(t, "127.0.0.1:0",
|
||||
"test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."),
|
||||
"nxdomain.site.", resolveToNXDOMAIN)
|
||||
defer v4server.Shutdown()
|
||||
v6server := serveDNS(t, "[::1]:0",
|
||||
"test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."),
|
||||
"nxdomain.site.", resolveToNXDOMAIN)
|
||||
defer v6server.Shutdown()
|
||||
|
||||
r, err := New(t.Logf, nil)
|
||||
if err != nil {
|
||||
|
@ -377,19 +354,75 @@ func TestDelegate(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestDelegateCollision(t *testing.T) {
|
||||
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
|
||||
func TestDelegateSplitRoute(t *testing.T) {
|
||||
test4 := netaddr.MustParseIP("2.3.4.5")
|
||||
test6 := netaddr.MustParseIP("ff::1")
|
||||
|
||||
server, errch := serveDNS(t, "127.0.0.1:0")
|
||||
defer func() {
|
||||
if err := <-errch; err != nil {
|
||||
t.Errorf("server error: %v", err)
|
||||
}
|
||||
}()
|
||||
server1 := serveDNS(t, "127.0.0.1:0",
|
||||
"test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
|
||||
defer server1.Shutdown()
|
||||
server2 := serveDNS(t, "127.0.0.1:0",
|
||||
"test.other.", resolveToIP(test4, test6, "dns.other."))
|
||||
defer server2.Shutdown()
|
||||
|
||||
if server == nil {
|
||||
return
|
||||
r, err := New(t.Logf, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("start: %v", err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
cfg := dnsCfg
|
||||
cfg.Routes = map[string][]netaddr.IPPort{
|
||||
".": {netaddr.MustParseIPPort(server1.PacketConn.LocalAddr().String())},
|
||||
"other.": {netaddr.MustParseIPPort(server2.PacketConn.LocalAddr().String())},
|
||||
}
|
||||
r.SetConfig(cfg)
|
||||
|
||||
tests := []struct {
|
||||
title string
|
||||
query []byte
|
||||
response dnsResponse
|
||||
}{
|
||||
{
|
||||
"general",
|
||||
dnspacket("test.site.", dns.TypeA),
|
||||
dnsResponse{ip: testipv4, rcode: dns.RCodeSuccess},
|
||||
},
|
||||
{
|
||||
"override",
|
||||
dnspacket("test.other.", dns.TypeA),
|
||||
dnsResponse{ip: test4, rcode: dns.RCodeSuccess},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.title, func(t *testing.T) {
|
||||
payload, err := syncRespond(r, tt.query)
|
||||
if err != nil {
|
||||
t.Errorf("err = %v; want nil", err)
|
||||
return
|
||||
}
|
||||
response, err := unpackResponse(payload)
|
||||
if err != nil {
|
||||
t.Errorf("extract: err = %v; want nil (in %x)", err, payload)
|
||||
return
|
||||
}
|
||||
if response.rcode != tt.response.rcode {
|
||||
t.Errorf("rcode = %v; want %v", response.rcode, tt.response.rcode)
|
||||
}
|
||||
if response.ip != tt.response.ip {
|
||||
t.Errorf("ip = %v; want %v", response.ip, tt.response.ip)
|
||||
}
|
||||
if response.name != tt.response.name {
|
||||
t.Errorf("name = %v; want %v", response.name, tt.response.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDelegateCollision(t *testing.T) {
|
||||
server := serveDNS(t, "127.0.0.1:0",
|
||||
"test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
|
||||
defer server.Shutdown()
|
||||
|
||||
r, err := New(t.Logf, nil)
|
||||
|
@ -628,8 +661,8 @@ func TestFull(t *testing.T) {
|
|||
{"ipv6", dnspacket("test2.ipn.dev.", dns.TypeAAAA), ipv6Response},
|
||||
{"no-ipv6", dnspacket("test1.ipn.dev.", dns.TypeAAAA), emptyResponse},
|
||||
{"upper", dnspacket("TEST1.IPN.DEV.", dns.TypeA), ipv4UppercaseResponse},
|
||||
{"ptr", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR), ptrResponse},
|
||||
{"ptr", dnspacket("f.0.e.0.d.0.c.0.b.0.a.0.9.0.8.0.7.0.6.0.5.0.4.0.3.0.2.0.1.0.0.0.ip6.arpa.",
|
||||
{"ptr4", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR), ptrResponse},
|
||||
{"ptr6", dnspacket("f.0.e.0.d.0.c.0.b.0.a.0.9.0.8.0.7.0.6.0.5.0.4.0.3.0.2.0.1.0.0.0.ip6.arpa.",
|
||||
dns.TypePTR), ptrResponse6},
|
||||
{"nxdomain", dnspacket("test3.ipn.dev.", dns.TypeA), nxdomainResponse},
|
||||
}
|
||||
|
@ -702,18 +735,8 @@ func TestTrimRDNSBonjourPrefix(t *testing.T) {
|
|||
}
|
||||
|
||||
func BenchmarkFull(b *testing.B) {
|
||||
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
|
||||
|
||||
server, errch := serveDNS(b, "127.0.0.1:0")
|
||||
defer func() {
|
||||
if err := <-errch; err != nil {
|
||||
b.Errorf("server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if server == nil {
|
||||
return
|
||||
}
|
||||
server := serveDNS(b, "127.0.0.1:0",
|
||||
"test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
|
||||
defer server.Shutdown()
|
||||
|
||||
r, err := New(b.Logf, nil)
|
||||
|
|
|
@ -124,3 +124,12 @@ func SanitizeHostname(hostname string) string {
|
|||
hostname = TrimCommonSuffixes(hostname)
|
||||
return SanitizeLabel(hostname)
|
||||
}
|
||||
|
||||
// NumLabels returns the number of DNS labels in hostname.
|
||||
// If hostname is empty or the top-level name ".", returns 0.
|
||||
func NumLabels(hostname string) int {
|
||||
if hostname == "" || hostname == "." {
|
||||
return 0
|
||||
}
|
||||
return strings.Count(hostname, ".")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue