diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index 866b4f045..db533d195 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -88,6 +88,57 @@ func getTxID(packet []byte) txid { return (txid(hash) << 32) | txid(dnsid) } +// clampEDNSSize attempts to limit the maximum EDNS response size. This is not +// an exhaustive solution, instead only easy cases are currently handled in the +// interest of speed and reduced complexity. Only OPT records at the very end of +// the message with no option codes are addressed. +// TODO: handle more situations if we discover that they happen often +func clampEDNSSize(packet []byte, maxSize uint16) { + // optFixedBytes is the size of an OPT record with no option codes. + const optFixedBytes = 11 + const edns0Version = 0 + + if len(packet) < headerBytes+optFixedBytes { + return + } + + arCount := binary.BigEndian.Uint16(packet[10:12]) + if arCount == 0 { + // OPT shows up in an AR, so there must be no OPT + return + } + + opt := packet[len(packet)-optFixedBytes:] + + if opt[0] != 0 { + // OPT NAME must be 0 (root domain) + return + } + if dns.Type(binary.BigEndian.Uint16(opt[1:3])) != dns.TypeOPT { + // Not an OPT record + return + } + requestedSize := binary.BigEndian.Uint16(opt[3:5]) + // Ignore extended RCODE in opt[5] + if opt[6] != edns0Version { + // Be conservative and don't touch unknown versions. + return + } + // Ignore flags in opt[7:9] + if binary.BigEndian.Uint16(opt[10:12]) != 0 { + // RDLEN must be 0 (no variable length data). We're at the end of the + // packet so this should be 0 anyway).. + return + } + + if requestedSize <= maxSize { + return + } + + // Clamp the maximum size + binary.BigEndian.PutUint16(opt[3:5], maxSize) +} + type route struct { Suffix dnsname.FQDN Resolvers []netaddr.IPPort @@ -233,6 +284,8 @@ func (f *forwarder) send(ctx context.Context, txidOut txid, closeOnCtxDone *clos // best we can do. } + clampEDNSSize(out, maxResponseBytes) + return out, nil } @@ -257,6 +310,7 @@ func (f *forwarder) forward(query packet) error { } txid := getTxID(query.bs) + clampEDNSSize(query.bs, maxResponseBytes) resolvers := f.resolvers(domain) if len(resolvers) == 0 { diff --git a/net/dns/resolver/tsdns_server_test.go b/net/dns/resolver/tsdns_server_test.go index 71a31eba2..e0d8c40c5 100644 --- a/net/dns/resolver/tsdns_server_test.go +++ b/net/dns/resolver/tsdns_server_test.go @@ -68,7 +68,7 @@ func resolveToIP(ipv4, ipv6 netaddr.IP, ns string) dns.HandlerFunc { // resolveToTXT returns a handler function which responds to queries of type TXT // it receives with the strings in txts. -func resolveToTXT(txts []string) dns.HandlerFunc { +func resolveToTXT(txts []string, ednsMaxSize uint16) dns.HandlerFunc { return func(w dns.ResponseWriter, req *dns.Msg) { m := new(dns.Msg) m.SetReply(req) @@ -93,6 +93,27 @@ func resolveToTXT(txts []string) dns.HandlerFunc { } m.Answer = append(m.Answer, ans) + + queryInfo := &dns.TXT{ + Hdr: dns.RR_Header{ + Name: "query-info.test.", + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + }, + } + + if edns := req.IsEdns0(); edns == nil { + queryInfo.Txt = []string{"EDNS=false"} + } else { + queryInfo.Txt = []string{"EDNS=true", fmt.Sprintf("maxSize=%v", edns.UDPSize())} + } + + m.Extra = append(m.Extra, queryInfo) + + if ednsMaxSize > 0 { + m.SetEdns0(ednsMaxSize, false) + } + if err := w.WriteMsg(m); err != nil { panic(err) } diff --git a/net/dns/resolver/tsdns_test.go b/net/dns/resolver/tsdns_test.go index 857a8ba08..353327b35 100644 --- a/net/dns/resolver/tsdns_test.go +++ b/net/dns/resolver/tsdns_test.go @@ -12,6 +12,8 @@ import ( "math/rand" "net" "runtime" + "strconv" + "strings" "testing" dns "golang.org/x/net/dns/dnsmessage" @@ -32,7 +34,9 @@ var dnsCfg = Config{ LocalDomains: []dnsname.FQDN{"ipn.dev."}, } -func dnspacket(domain dnsname.FQDN, tp dns.Type) []byte { +const noEdns = 0 + +func dnspacket(domain dnsname.FQDN, tp dns.Type, ednsSize uint16) []byte { var dnsHeader dns.Header question := dns.Question{ Name: dns.MustNewName(domain.WithTrailingDot()), @@ -41,19 +45,44 @@ func dnspacket(domain dnsname.FQDN, tp dns.Type) []byte { } builder := dns.NewBuilder(nil, dnsHeader) - builder.StartQuestions() - builder.Question(question) + if err := builder.StartQuestions(); err != nil { + panic(err) + } + if err := builder.Question(question); err != nil { + panic(err) + } + + if ednsSize != noEdns { + if err := builder.StartAdditionals(); err != nil { + panic(err) + } + + ednsHeader := dns.ResourceHeader{ + Name: dns.MustNewName("."), + Type: dns.TypeOPT, + Class: dns.Class(ednsSize), + } + + if err := builder.OPTResource(ednsHeader, dns.OPTResource{}); err != nil { + panic(err) + } + } + payload, _ := builder.Finish() return payload } type dnsResponse struct { - ip netaddr.IP - txt []string - name dnsname.FQDN - rcode dns.RCode - truncated bool + ip netaddr.IP + txt []string + name dnsname.FQDN + rcode dns.RCode + truncated bool + requestEdns bool + requestEdnsSize uint16 + responseEdns bool + responseEdnsSize uint16 } func unpackResponse(payload []byte) (dnsResponse, error) { @@ -89,41 +118,98 @@ func unpackResponse(payload []byte) (dnsResponse, error) { return response, err } - ah, err := parser.AnswerHeader() + for { + ah, err := parser.AnswerHeader() + if err == dns.ErrSectionDone { + break + } + if err != nil { + return response, err + } + + switch ah.Type { + case dns.TypeA: + res, err := parser.AResource() + if err != nil { + return response, err + } + response.ip = netaddr.IPv4(res.A[0], res.A[1], res.A[2], res.A[3]) + case dns.TypeAAAA: + res, err := parser.AAAAResource() + if err != nil { + return response, err + } + response.ip = netaddr.IPv6Raw(res.AAAA) + case dns.TypeTXT: + res, err := parser.TXTResource() + if err != nil { + return response, err + } + response.txt = res.TXT + case dns.TypeNS: + res, err := parser.NSResource() + if err != nil { + return response, err + } + response.name, err = dnsname.ToFQDN(res.NS.String()) + if err != nil { + return response, err + } + default: + return response, errors.New("type not in {A, AAAA, NS}") + } + } + + err = parser.SkipAllAuthorities() if err != nil { return response, err } - switch ah.Type { - case dns.TypeA: - res, err := parser.AResource() + for { + ah, err := parser.AdditionalHeader() + if err == dns.ErrSectionDone { + break + } if err != nil { return response, err } - response.ip = netaddr.IPv4(res.A[0], res.A[1], res.A[2], res.A[3]) - case dns.TypeAAAA: - res, err := parser.AAAAResource() - if err != nil { - return response, err + + switch ah.Type { + case dns.TypeOPT: + _, err := parser.OPTResource() + if err != nil { + return response, err + } + response.responseEdns = true + response.responseEdnsSize = uint16(ah.Class) + case dns.TypeTXT: + res, err := parser.TXTResource() + if err != nil { + return response, err + } + switch ah.Name.String() { + case "query-info.test.": + for _, msg := range res.TXT { + s := strings.SplitN(msg, "=", 2) + if len(s) != 2 { + continue + } + switch s[0] { + case "EDNS": + response.requestEdns, err = strconv.ParseBool(s[1]) + if err != nil { + return response, err + } + case "maxSize": + sz, err := strconv.ParseUint(s[1], 10, 16) + if err != nil { + return response, err + } + response.requestEdnsSize = uint16(sz) + } + } + } } - response.ip = netaddr.IPv6Raw(res.AAAA) - case dns.TypeTXT: - res, err := parser.TXTResource() - if err != nil { - return response, err - } - response.txt = res.TXT - case dns.TypeNS: - res, err := parser.NSResource() - if err != nil { - return response, err - } - response.name, err = dnsname.ToFQDN(res.NS.String()) - if err != nil { - return response, err - } - default: - return response, errors.New("type not in {A, AAAA, NS}") } return response, nil @@ -340,7 +426,7 @@ func TestDelegate(t *testing.T) { // support these sizes of response without truncation because they are // moderately common. medTXT := generateTXT(1200, randSource) - largeTXT := generateTXT(4000, randSource) + largeTXT := generateTXT(3900, randSource) // xlargeTXT is slightly above the maximum response size that we support, // so there should be truncation. @@ -351,23 +437,20 @@ func TestDelegate(t *testing.T) { // intend to handle responses this large, so there should be truncation. hugeTXT := generateTXT(64000, randSource) - v4server := serveDNS(t, "127.0.0.1:0", - "test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."), + records := []interface{}{ + "test.site.", + resolveToIP(testipv4, testipv6, "dns.test.site."), "nxdomain.site.", resolveToNXDOMAIN, - "small.txt.", resolveToTXT(smallTXT), - "med.txt.", resolveToTXT(medTXT), - "large.txt.", resolveToTXT(largeTXT), - "xlarge.txt.", resolveToTXT(xlargeTXT), - "huge.txt.", resolveToTXT(hugeTXT)) + "small.txt.", resolveToTXT(smallTXT, noEdns), + "smalledns.txt.", resolveToTXT(smallTXT, 512), + "med.txt.", resolveToTXT(medTXT, 1500), + "large.txt.", resolveToTXT(largeTXT, maxResponseBytes), + "xlarge.txt.", resolveToTXT(xlargeTXT, 8000), + "huge.txt.", resolveToTXT(hugeTXT, 65527), + } + v4server := serveDNS(t, "127.0.0.1:0", records...) defer v4server.Shutdown() - v6server := serveDNS(t, "[::1]:0", - "test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."), - "nxdomain.site.", resolveToNXDOMAIN, - "small.txt.", resolveToTXT(smallTXT), - "med.txt.", resolveToTXT(medTXT), - "large.txt.", resolveToTXT(largeTXT), - "xlarge.txt.", resolveToTXT(xlargeTXT), - "huge.txt.", resolveToTXT(hugeTXT)) + v6server := serveDNS(t, "[::1]:0", records...) defer v6server.Shutdown() r := newResolver(t) @@ -389,48 +472,84 @@ func TestDelegate(t *testing.T) { }{ { "ipv4", - dnspacket("test.site.", dns.TypeA), + dnspacket("test.site.", dns.TypeA, noEdns), dnsResponse{ip: testipv4, rcode: dns.RCodeSuccess}, }, { "ipv6", - dnspacket("test.site.", dns.TypeAAAA), + dnspacket("test.site.", dns.TypeAAAA, noEdns), dnsResponse{ip: testipv6, rcode: dns.RCodeSuccess}, }, { "ns", - dnspacket("test.site.", dns.TypeNS), + dnspacket("test.site.", dns.TypeNS, noEdns), dnsResponse{name: "dns.test.site.", rcode: dns.RCodeSuccess}, }, { "nxdomain", - dnspacket("nxdomain.site.", dns.TypeA), + dnspacket("nxdomain.site.", dns.TypeA, noEdns), dnsResponse{rcode: dns.RCodeNameError}, }, { "smalltxt", - dnspacket("small.txt.", dns.TypeTXT), - dnsResponse{txt: smallTXT, rcode: dns.RCodeSuccess}, + dnspacket("small.txt.", dns.TypeTXT, 8000), + dnsResponse{txt: smallTXT, rcode: dns.RCodeSuccess, requestEdns: true, requestEdnsSize: maxResponseBytes}, + }, + { + "smalltxtedns", + dnspacket("smalledns.txt.", dns.TypeTXT, 512), + dnsResponse{ + txt: smallTXT, + rcode: dns.RCodeSuccess, + requestEdns: true, + requestEdnsSize: 512, + responseEdns: true, + responseEdnsSize: 512, + }, }, { "medtxt", - dnspacket("med.txt.", dns.TypeTXT), - dnsResponse{txt: medTXT, rcode: dns.RCodeSuccess}, + dnspacket("med.txt.", dns.TypeTXT, 2000), + dnsResponse{ + txt: medTXT, + rcode: dns.RCodeSuccess, + requestEdns: true, + requestEdnsSize: 2000, + responseEdns: true, + responseEdnsSize: 1500, + }, }, { "largetxt", - dnspacket("large.txt.", dns.TypeTXT), - dnsResponse{txt: largeTXT, rcode: dns.RCodeSuccess}, + dnspacket("large.txt.", dns.TypeTXT, maxResponseBytes), + dnsResponse{ + txt: largeTXT, + rcode: dns.RCodeSuccess, + requestEdns: true, + requestEdnsSize: maxResponseBytes, + responseEdns: true, + responseEdnsSize: maxResponseBytes, + }, }, { "xlargetxt", - dnspacket("xlarge.txt.", dns.TypeTXT), - dnsResponse{rcode: dns.RCodeSuccess, truncated: true}, + dnspacket("xlarge.txt.", dns.TypeTXT, 8000), + dnsResponse{ + rcode: dns.RCodeSuccess, + truncated: true, + // request/response EDNS fields will be unset because of + // they were truncated away + }, }, { "hugetxt", - dnspacket("huge.txt.", dns.TypeTXT), - dnsResponse{rcode: dns.RCodeSuccess, truncated: true}, + dnspacket("huge.txt.", dns.TypeTXT, 8000), + dnsResponse{ + rcode: dns.RCodeSuccess, + truncated: true, + // request/response EDNS fields will be unset because of + // they were truncated away + }, }, } @@ -467,6 +586,18 @@ func TestDelegate(t *testing.T) { } } } + if response.requestEdns != tt.response.requestEdns { + t.Errorf("requestEdns = %v; want %v", response.requestEdns, tt.response.requestEdns) + } + if response.requestEdnsSize != tt.response.requestEdnsSize { + t.Errorf("requestEdnsSize = %v; want %v", response.requestEdnsSize, tt.response.requestEdnsSize) + } + if response.responseEdns != tt.response.responseEdns { + t.Errorf("responseEdns = %v; want %v", response.requestEdns, tt.response.requestEdns) + } + if response.responseEdnsSize != tt.response.responseEdnsSize { + t.Errorf("responseEdnsSize = %v; want %v", response.responseEdnsSize, tt.response.responseEdnsSize) + } }) } } @@ -499,12 +630,12 @@ func TestDelegateSplitRoute(t *testing.T) { }{ { "general", - dnspacket("test.site.", dns.TypeA), + dnspacket("test.site.", dns.TypeA, noEdns), dnsResponse{ip: testipv4, rcode: dns.RCodeSuccess}, }, { "override", - dnspacket("test.other.", dns.TypeA), + dnspacket("test.other.", dns.TypeA, noEdns), dnsResponse{ip: test4, rcode: dns.RCodeSuccess}, }, } @@ -561,7 +692,7 @@ func TestDelegateCollision(t *testing.T) { // packets will have the same dns txid. for _, p := range packets { - payload := dnspacket(p.qname, p.qtype) + payload := dnspacket(p.qname, p.qtype, noEdns) err := r.EnqueueRequest(payload, p.addr) if err != nil { t.Error(err) @@ -764,15 +895,15 @@ func TestFull(t *testing.T) { request []byte response []byte }{ - {"all", dnspacket("test1.ipn.dev.", dns.TypeALL), allResponse}, - {"ipv4", dnspacket("test1.ipn.dev.", dns.TypeA), ipv4Response}, - {"ipv6", dnspacket("test2.ipn.dev.", dns.TypeAAAA), ipv6Response}, - {"no-ipv6", dnspacket("test1.ipn.dev.", dns.TypeAAAA), emptyResponse}, - {"upper", dnspacket("TEST1.IPN.DEV.", dns.TypeA), ipv4UppercaseResponse}, - {"ptr4", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR), ptrResponse}, + {"all", dnspacket("test1.ipn.dev.", dns.TypeALL, noEdns), allResponse}, + {"ipv4", dnspacket("test1.ipn.dev.", dns.TypeA, noEdns), ipv4Response}, + {"ipv6", dnspacket("test2.ipn.dev.", dns.TypeAAAA, noEdns), ipv6Response}, + {"no-ipv6", dnspacket("test1.ipn.dev.", dns.TypeAAAA, noEdns), emptyResponse}, + {"upper", dnspacket("TEST1.IPN.DEV.", dns.TypeA, noEdns), ipv4UppercaseResponse}, + {"ptr4", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR, noEdns), ptrResponse}, {"ptr6", dnspacket("f.0.e.0.d.0.c.0.b.0.a.0.9.0.8.0.7.0.6.0.5.0.4.0.3.0.2.0.1.0.0.0.ip6.arpa.", - dns.TypePTR), ptrResponse6}, - {"nxdomain", dnspacket("test3.ipn.dev.", dns.TypeA), nxdomainResponse}, + dns.TypePTR, noEdns), ptrResponse6}, + {"nxdomain", dnspacket("test3.ipn.dev.", dns.TypeA, noEdns), nxdomainResponse}, } for _, tt := range tests { @@ -801,9 +932,9 @@ func TestAllocs(t *testing.T) { want int }{ // Name lowercasing and response slice created by dns.NewBuilder. - {"forward", dnspacket("test1.ipn.dev.", dns.TypeA), 2}, + {"forward", dnspacket("test1.ipn.dev.", dns.TypeA, noEdns), 2}, // 3 extra allocs in rdnsNameToIPv4 and one in marshalPTRRecord (dns.NewName). - {"reverse", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR), 5}, + {"reverse", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR, noEdns), 5}, } for _, tt := range tests { @@ -857,9 +988,9 @@ func BenchmarkFull(b *testing.B) { name string request []byte }{ - {"forward", dnspacket("test1.ipn.dev.", dns.TypeA)}, - {"reverse", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR)}, - {"delegated", dnspacket("test.site.", dns.TypeA)}, + {"forward", dnspacket("test1.ipn.dev.", dns.TypeA, noEdns)}, + {"reverse", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR, noEdns)}, + {"delegated", dnspacket("test.site.", dns.TypeA, noEdns)}, } for _, tt := range tests {