diff --git a/net/dns/resolver/doh_test.go b/net/dns/resolver/doh_test.go index e283f32e0..621ccd1fd 100644 --- a/net/dns/resolver/doh_test.go +++ b/net/dns/resolver/doh_test.go @@ -81,3 +81,16 @@ func TestDoH(t *testing.T) { }) } } + +func TestDoHV6Fallback(t *testing.T) { + for ip, base := range knownDoH { + if ip.Is4() { + ip6, ok := dohV6(base) + if !ok { + t.Errorf("no v6 DoH known for %v", ip) + } else if !ip6.Is6() { + t.Errorf("dohV6(%q) returned non-v6 address %v", base, ip6) + } + } + } +} diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index c836ff672..557f4522f 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -243,7 +243,16 @@ func (f *forwarder) getDoHClient(ip netaddr.IP) (urlBase string, c *http.Client, if !strings.HasPrefix(netw, "tcp") { return nil, fmt.Errorf("unexpected network %q", netw) } - return nsDialer.DialContext(ctx, "tcp", net.JoinHostPort(ip.String(), "443")) + c, err := nsDialer.DialContext(ctx, "tcp", net.JoinHostPort(ip.String(), "443")) + // If v4 failed, try an equivalent v6 also in the time remaining. + if err != nil && ctx.Err() == nil { + if ip6, ok := dohV6(urlBase); ok && ip.Is4() { + if c6, err := nsDialer.DialContext(ctx, "tcp", net.JoinHostPort(ip6.String(), "443")); err == nil { + return c6, nil + } + } + } + return c, err }, }, } @@ -509,7 +518,22 @@ func (p *closePool) Close() error { var knownDoH = map[netaddr.IP]string{} -func addDoH(ip, base string) { knownDoH[netaddr.MustParseIP(ip)] = base } +var dohIPsOfBase = map[string][]netaddr.IP{} + +func addDoH(ipStr, base string) { + ip := netaddr.MustParseIP(ipStr) + knownDoH[ip] = base + dohIPsOfBase[base] = append(dohIPsOfBase[base], ip) +} + +func dohV6(base string) (ip netaddr.IP, ok bool) { + for _, ip := range dohIPsOfBase[base] { + if ip.Is6() { + return ip, true + } + } + return ip, false +} func init() { // Cloudflare