diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index f315e88b6..b11f30b3b 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -2765,6 +2765,7 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) { sentPing: map[stun.TxID]sentPing{}, endpointState: map[netip.AddrPort]*endpointState{}, heartbeatDisabled: heartbeatDisabled, + isWireguardOnly: n.IsWireGuardOnly, } if len(n.Addresses) > 0 { ep.nodeAddr = n.Addresses[0].Addr() @@ -4084,9 +4085,14 @@ type endpointDisco struct { short string // ShortString of discoKey. } -// endpoint is a wireguard/conn.Endpoint that picks the best -// available path to communicate with a peer, based on network -// conditions and what the peer supports. +// endpoint is a wireguard/conn.Endpoint. In wireguard-go and kernel WireGuard +// there is only one endpoint for a peer, but in Tailscale we distribute a +// number of possible endpoints for a peer which would include the all the +// likely addresses at which a peer may be reachable. This endpoint type holds +// the information required that when WiregGuard-Go wants to send to a +// particular peer (essentally represented by this endpoint type), the send +// function can use the currnetly best known Tailscale endpoint to send packets +// to the peer. type endpoint struct { // atomically accessed; declared first for alignment reasons lastRecv mono.Time @@ -4126,7 +4132,8 @@ type endpoint struct { heartbeatDisabled bool pathFinderRunning bool - expired bool // whether the node has expired + expired bool // whether the node has expired + isWireguardOnly bool // whether the node is a pure wireguard node } type pendingCLIPing struct { @@ -4238,6 +4245,15 @@ func (st *endpointState) shouldDeleteLocked() bool { } } +// latencyLocked returns the most recent latency measurement, if any. +// endpoint.mu must be held. +func (st *endpointState) latencyLocked() (lat time.Duration, ok bool) { + if len(st.recentPongs) == 0 { + return 0, false + } + return st.recentPongs[st.recentPong].latency, true +} + func (de *endpoint) deleteEndpointLocked(why string, ep netip.AddrPort) { de.debugUpdates.Add(EndpointChange{ When: time.Now(), @@ -4326,12 +4342,64 @@ func (de *endpoint) DstToBytes() []byte { return packIPPort(de.fakeWGAddr) } // de.mu must be held. func (de *endpoint) addrForSendLocked(now mono.Time) (udpAddr, derpAddr netip.AddrPort) { udpAddr = de.bestAddr.AddrPort - if !udpAddr.IsValid() || now.After(de.trustBestAddrUntil) { - // We had a bestAddr but it expired so send both to it - // and DERP. - derpAddr = de.derpAddr + + if udpAddr.IsValid() && !now.After(de.trustBestAddrUntil) { + return udpAddr, netip.AddrPort{} } - return + + if de.isWireguardOnly { + // Attempt to find the endpoint with the lowest recorded latency, + // but if all endpoint latencies are zero, then pick one at random + // instead. + + // lowestLatency is a high duration initially, so we + // can be sure we're going to have a duration lower than this + // for the first latency retrieved. + lowestLatency := time.Hour + for ipp, state := range de.endpointState { + if latency, ok := state.latencyLocked(); ok { + if latency < lowestLatency || latency == lowestLatency && ipp.Addr().Is6() { + // If we have the same latency, we should priortize IPv6. + // TODO(catzkorn): We should consider a buffer time where + // we accept a small increase in latency in return for + // IPv6 priority. + lowestLatency = latency + udpAddr = ipp + } + } + } + + if udpAddr.IsValid() { + // QUESTION: Are we setting the long timed trustBestAddrUntil here or somewhere else? + // Think its here, just wanna check + de.trustBestAddrUntil = now.Add(1 * time.Hour) + return udpAddr, netip.AddrPort{} + } + + // If we have not yet retrieved latency information for an address, + // we should choose the first one we expect to work. + for ipp := range de.endpointState { + if ipp.Addr().Is4() && de.c.noV4.Load() { + continue + } + if ipp.Addr().Is6() && de.c.noV6.Load() { + continue + } + if ipp.Addr().Is6() { + // If we can use IPv6, we should, so break early. + udpAddr = ipp + break + } + udpAddr = ipp + } + + de.bestAddr.AddrPort = udpAddr + return udpAddr, netip.AddrPort{} + } + + // We had a bestAddr but it expired so send both to it + // and DERP. + return udpAddr, de.derpAddr } // heartbeat is called every heartbeatInterval to keep the best UDP path alive, diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index c050cfc20..eccdf358a 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -23,6 +23,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "testing" "time" "unsafe" @@ -47,6 +48,7 @@ import ( "tailscale.com/tailcfg" "tailscale.com/tstest" "tailscale.com/tstest/natlab" + "tailscale.com/tstime/mono" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/netlogtype" @@ -2393,3 +2395,240 @@ func TestEndpointTracker(t *testing.T) { } } } + +func TestAddrForSendLockedForWireGuardOnly(t *testing.T) { + t.Run("choose lowest latency for useable IPv4 and IPv6", func(t *testing.T) { + testTime := mono.Now() + + endpointDetails := []struct { + addrPort netip.AddrPort + latency time.Duration + }{ + { + addrPort: netip.MustParseAddrPort("1.1.1.1:111"), + latency: 100 * time.Millisecond, + }, + { + addrPort: netip.MustParseAddrPort("[2345:0425:2CA1:0000:0000:0567:5673:23b5]:222"), + latency: 10 * time.Millisecond, + }, + } + want := endpointDetails[1].addrPort + + endpoint := &endpoint{ + isWireguardOnly: true, + endpointState: map[netip.AddrPort]*endpointState{}, + c: &Conn{ + noV4: atomic.Bool{}, + noV6: atomic.Bool{}, + }, + } + + for _, epd := range endpointDetails { + endpoint.endpointState[epd.addrPort] = &endpointState{} + } + + udpAddr, _ := endpoint.addrForSendLocked(testTime) + if udpAddr != want { + t.Errorf("returned udpAddr chosen without latency info is not expected: got %v, want %v", udpAddr, want) + } + + for _, epd := range endpointDetails { + state, ok := endpoint.endpointState[epd.addrPort] + if !ok { + t.Errorf("addr does not exist in endpoint state map") + } + + latency, ok := state.latencyLocked() + if ok { + t.Errorf("latency was set for %v: %v", epd.addrPort, latency) + } + state.recentPongs = append(state.recentPongs, pongReply{ + latency: epd.latency, + }) + state.recentPong = 0 + } + + udpAddr, _ = endpoint.addrForSendLocked(testTime.Add(2 * time.Minute)) + if udpAddr != want { + t.Errorf("udpAddr returned is not expected: got %v, want %v", udpAddr, want) + } + }) + + t.Run("choose IPv4 when IPv6 is not useable", func(t *testing.T) { + testTime := mono.Now() + endpointDetails := []struct { + addrPort netip.AddrPort + latency time.Duration + }{ + { + addrPort: netip.MustParseAddrPort("1.1.1.1:111"), + latency: 100 * time.Millisecond, + }, + { + addrPort: netip.MustParseAddrPort("[1::1]:567"), + }, + } + want := endpointDetails[0].addrPort + + endpoint := &endpoint{ + isWireguardOnly: true, + endpointState: map[netip.AddrPort]*endpointState{}, + c: &Conn{ + noV4: atomic.Bool{}, + noV6: atomic.Bool{}, + }, + } + endpoint.c.noV6.Store(true) + + for _, epd := range endpointDetails { + endpoint.endpointState[epd.addrPort] = &endpointState{} + } + + udpAddr, _ := endpoint.addrForSendLocked(testTime) + if udpAddr != want { + t.Errorf("returned udpAddr chosen without latency info is not expected: got %v, want %v", udpAddr, want) + } + + for _, epd := range endpointDetails { + state, ok := endpoint.endpointState[epd.addrPort] + if !ok { + t.Errorf("addr does not exist in endpoint state map") + } + + latency, ok := state.latencyLocked() + if ok { + t.Errorf("latency was set for %v: %v", epd.addrPort, latency) + } + if epd.latency != 0 { + state.recentPongs = append(state.recentPongs, pongReply{ + latency: epd.latency, + }) + state.recentPong = 0 + } + } + + udpAddr, _ = endpoint.addrForSendLocked(testTime.Add(2 * time.Minute)) + if udpAddr != want { + t.Errorf("udpAddr returned is not expected: got %v, want %v", udpAddr, want) + } + }) + + t.Run("choose IPv6 when IPv4 is not useable", func(t *testing.T) { + testTime := mono.Now() + endpointDetails := []struct { + addrPort netip.AddrPort + latency time.Duration + }{ + { + addrPort: netip.MustParseAddrPort("1.1.1.1:111"), + }, + { + addrPort: netip.MustParseAddrPort("[1::1]:567"), + latency: 100 * time.Millisecond, + }, + } + want := endpointDetails[1].addrPort + + endpoint := &endpoint{ + isWireguardOnly: true, + endpointState: map[netip.AddrPort]*endpointState{}, + c: &Conn{ + noV4: atomic.Bool{}, + noV6: atomic.Bool{}, + }, + } + endpoint.c.noV4.Store(true) + + for _, epd := range endpointDetails { + endpoint.endpointState[epd.addrPort] = &endpointState{} + } + + udpAddr, _ := endpoint.addrForSendLocked(testTime) + if udpAddr != want { + t.Errorf("returned udpAddr chosen without latency info is not expected: got %v, want %v", udpAddr, want) + } + + for _, epd := range endpointDetails { + state, ok := endpoint.endpointState[epd.addrPort] + if !ok { + t.Errorf("addr does not exist in endpoint state map") + } + + latency, ok := state.latencyLocked() + if ok { + t.Errorf("latency was set for %v: %v", epd.addrPort, latency) + } + if epd.latency != 0 { + state.recentPongs = append(state.recentPongs, pongReply{ + latency: epd.latency, + }) + state.recentPong = 0 + } + } + + udpAddr, _ = endpoint.addrForSendLocked(testTime.Add(2 * time.Minute)) + if udpAddr != want { + t.Errorf("udpAddr returned is not expected: got %v, want %v", udpAddr, want) + } + }) + + t.Run("choose IPv6 address when latency is the same for v4 and v6", func(t *testing.T) { + testTime := mono.Now() + endpointDetails := []struct { + addrPort netip.AddrPort + latency time.Duration + }{ + { + addrPort: netip.MustParseAddrPort("1.1.1.1:111"), + latency: 100 * time.Millisecond, + }, + { + addrPort: netip.MustParseAddrPort("[1::1]:567"), + latency: 100 * time.Millisecond, + }, + } + want := endpointDetails[1].addrPort + + endpoint := &endpoint{ + isWireguardOnly: true, + endpointState: map[netip.AddrPort]*endpointState{}, + c: &Conn{ + noV4: atomic.Bool{}, + noV6: atomic.Bool{}, + }, + } + + for _, epd := range endpointDetails { + endpoint.endpointState[epd.addrPort] = &endpointState{} + } + + udpAddr, _ := endpoint.addrForSendLocked(testTime) + if udpAddr != want { + t.Errorf("returned udpAddr chosen without latency info is not expected: got %v, want %v", udpAddr, want) + } + + for _, epd := range endpointDetails { + state, ok := endpoint.endpointState[epd.addrPort] + if !ok { + t.Errorf("addr does not exist in endpoint state map") + } + + latency, ok := state.latencyLocked() + if ok { + t.Errorf("latency was set for %v: %v", epd.addrPort, latency) + } + if epd.latency != 0 { + state.recentPongs = append(state.recentPongs, pongReply{ + latency: epd.latency, + }) + state.recentPong = 0 + } + } + + udpAddr, _ = endpoint.addrForSendLocked(testTime.Add(2 * time.Minute)) + if udpAddr != want { + t.Errorf("udpAddr returned is not expected: got %v, want %v", udpAddr, want) + } + }) +}