diff --git a/netcheck/netcheck.go b/netcheck/netcheck.go index 5338c2d8c..fa97e8995 100644 --- a/netcheck/netcheck.go +++ b/netcheck/netcheck.go @@ -13,6 +13,7 @@ import ( "io" "log" "net" + "sort" "sync" "time" @@ -74,6 +75,7 @@ type Client struct { mu sync.Mutex // guards following prev map[time.Time]*Report // some previous reports + last *Report // most recent report s4 *stunner.Stunner s6 *stunner.Stunner hairTX stun.TxID @@ -140,6 +142,71 @@ func (c *Client) ReceiveSTUNPacket(pkt []byte, src *net.UDPAddr) { } } +// pickSubset selects a subset of IPv4 and IPv6 STUN server addresses +// to hit based on history. +// +// maxTries is the max number of tries per server. +// +// The caller owns the returned values. +func (c *Client) pickSubset() (stuns4, stuns6 []string, maxTries map[string]int, err error) { + c.mu.Lock() + defer c.mu.Unlock() + + const defaultMaxTries = 2 + maxTries = map[string]int{} + + var prev4, prev6 []string // sorted fastest to slowest + if c.last != nil { + condAppend := func(dst []string, server string) []string { + if server != "" && c.last.DERPLatency[server] != 0 { + return append(dst, server) + } + return dst + } + c.DERP.ForeachServer(func(s *derpmap.Server) { + prev4 = condAppend(prev4, s.STUN4) + prev6 = condAppend(prev6, s.STUN6) + }) + sort.Slice(prev4, func(i, j int) bool { return c.last.DERPLatency[prev4[i]] < c.last.DERPLatency[prev4[j]] }) + sort.Slice(prev6, func(i, j int) bool { return c.last.DERPLatency[prev6[i]] < c.last.DERPLatency[prev6[j]] }) + } + + c.DERP.ForeachServer(func(s *derpmap.Server) { + if s.STUN4 == "" { + return + } + // STUN against all DERP's IPv4 endpoints, but + // if the previous report had results from + // more than 2 servers, only do 1 try against + // all but the first two. + stuns4 = append(stuns4, s.STUN4) + tries := defaultMaxTries + if len(prev4) > 2 && !stringsContains(prev4[:2], s.STUN4) { + tries = 1 + } + maxTries[s.STUN4] = tries + if s.STUN6 != "" && tries == defaultMaxTries { + // For IPv6, we mostly care whether the user has IPv6 at all, + // so we don't need to send to all servers. The IPv4 timing + // information is enough for now. (We don't yet support IPv6-only) + // So only add the two fastest ones, or all if this is a fresh one. + stuns6 = append(stuns6, s.STUN6) + maxTries[s.STUN6] = 1 + } + }) + + if len(stuns4) == 0 { + // TODO: make this work? if we ever need it + // to. Requirement for self-hosted Tailscale might be + // to run a DERP+STUN server co-resident with the + // Control server. + return nil, nil, nil, errors.New("netcheck: GetReport: no STUN servers, no Report") + } + sort.Strings(stuns4) + sort.Strings(stuns6) + return stuns4, stuns6, maxTries, nil +} + // GetReport gets a report. // // It may not be called concurrently with itself. @@ -173,24 +240,9 @@ func (c *Client) GetReport(ctx context.Context) (*Report, error) { c.gotHairSTUN = nil }() - stuns4 := c.DERP.STUN4() - stuns6 := c.DERP.STUN6() - if len(stuns4) == 0 { - // TODO: make this work? if we ever need it - // to. Requirement for self-hosted Tailscale might be - // to run a DERP+STUN server co-resident with the - // Control server. - return nil, errors.New("netcheck: GetReport: no STUN servers, no Report") - } - for _, s := range stuns4 { - if _, _, err := net.SplitHostPort(s); err != nil { - return nil, fmt.Errorf("netcheck: GetReport: bogus STUN4 server %q", s) - } - } - for _, s := range stuns6 { - if _, _, err := net.SplitHostPort(s); err != nil { - return nil, fmt.Errorf("netcheck: GetReport: bogus STUN6 server %q", s) - } + stuns4, stuns6, maxTries, err := c.pickSubset() + if err != nil { + return nil, err } closeOnCtx := func(c io.Closer) { @@ -330,6 +382,7 @@ func (c *Client) GetReport(ctx context.Context) (*Report, error) { Servers: stuns4, Logf: c.logf, DNSCache: dnscache.Get(), + MaxTries: maxTries, } c.mu.Lock() @@ -358,6 +411,7 @@ func (c *Client) GetReport(ctx context.Context) (*Report, error) { Logf: c.logf, OnlyIPv6: true, DNSCache: dnscache.Get(), + MaxTries: maxTries, } c.mu.Lock() @@ -469,6 +523,7 @@ func (c *Client) addReportHistoryAndSetPreferredDERP(r *Report) { } now := c.timeNow() c.prev[now] = r + c.last = r const maxAge = 5 * time.Minute @@ -498,3 +553,12 @@ func (c *Client) addReportHistoryAndSetPreferredDERP(r *Report) { } } } + +func stringsContains(ss []string, s string) bool { + for _, v := range ss { + if s == v { + return true + } + } + return false +} diff --git a/netcheck/netcheck_test.go b/netcheck/netcheck_test.go index 110ace79a..6e9894df9 100644 --- a/netcheck/netcheck_test.go +++ b/netcheck/netcheck_test.go @@ -211,3 +211,82 @@ func TestAddReportHistoryAndSetPreferredDERP(t *testing.T) { }) } } + +func TestPickSubset(t *testing.T) { + derps := derpmap.NewTestWorldWith( + &derpmap.Server{ + ID: 1, + STUN4: "d1:4", + STUN6: "d1:6", + }, + &derpmap.Server{ + ID: 2, + STUN4: "d2:4", + STUN6: "d2:6", + }, + &derpmap.Server{ + ID: 3, + STUN4: "d3:4", + STUN6: "d3:6", + }, + ) + tests := []struct { + name string + last *Report + want4 []string + want6 []string + wantTries map[string]int + }{ + { + name: "fresh", + last: nil, + want4: []string{"d1:4", "d2:4", "d3:4"}, + want6: []string{"d1:6", "d2:6", "d3:6"}, + wantTries: map[string]int{ + "d1:4": 2, + "d2:4": 2, + "d3:4": 2, + "d1:6": 1, + "d2:6": 1, + "d3:6": 1, + }, + }, + { + name: "1_and_3_closest", + last: &Report{ + DERPLatency: map[string]time.Duration{ + "d1:4": 15 * time.Millisecond, + "d2:4": 300 * time.Millisecond, + "d3:4": 25 * time.Millisecond, + }, + }, + want4: []string{"d1:4", "d2:4", "d3:4"}, + want6: []string{"d1:6", "d3:6"}, + wantTries: map[string]int{ + "d1:4": 2, + "d3:4": 2, + "d2:4": 1, + "d1:6": 1, + "d3:6": 1, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Client{DERP: derps, last: tt.last} + got4, got6, gotTries, err := c.pickSubset() + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got4, tt.want4) { + t.Errorf("stuns4 = %q; want %q", got4, tt.want4) + } + if !reflect.DeepEqual(got6, tt.want6) { + t.Errorf("stuns6 = %q; want %q", got6, tt.want6) + } + if !reflect.DeepEqual(gotTries, tt.wantTries) { + t.Errorf("tries = %v; want %v", gotTries, tt.wantTries) + } + }) + } +}