From cd6ccb5104c42ba08233327bf42824fa20af02b3 Mon Sep 17 00:00:00 2001 From: Val Date: Tue, 6 Jun 2023 12:38:52 +0200 Subject: [PATCH] magicsock: add multiple sizes of ping/pong packets to probe peer mtu Proof of concept, doesn't actually use the results of the probe. Signed-off-by: Val --- wgengine/magicsock/magicsock.go | 130 ++++++++++++++++++--------- wgengine/magicsock/magicsock_test.go | 18 ++-- wgengine/netstack/netstack.go | 1 + 3 files changed, 103 insertions(+), 46 deletions(-) diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index dea8b2d97..5874b9f35 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -2055,13 +2055,39 @@ const ( // speeds. var debugIPv4DiscoPingPenalty = envknob.RegisterDuration("TS_DISCO_PONG_IPV4_DELAY") +// Peer MTU notes: +// +// The general concept is to send ~5 different sizes of ping/pong +// packets between clients. The largest one that comes back defines +// our MTU for communicating with this peer via this endpoint. This is +// accomplished by teaching betterAddr() about pong packet length and +// having it prefer the larger mtu for the same addr/port +// endpoint. This prototype adds a msgLen parameter to the relevant +// disco routines because it was easy. +// +// TODO: Actually use the MTU when sending packets to our peer via +// this endpoint. +// +// TODO: Do the math to convert MTU <-> the msgLen parameter, if +// necessary. I think it is necessary? + +// usefulMtus are the set of likely on-the-wire MTUs (including all the +// layers of protocal headers above link layer) +var usefulMtus = [...]int{ + 576, // Smallest MTU for IPv4, probably useless? + 1124, // An observed max mtu in the wild, maybe 1100 instead? + 1280, // Smallest MTU for IPv6, current default + 1500, // Most common real world MTU + 9000, // Most jumbo frames are this size or slightly larger +} + // sendDiscoMessage sends discovery message m to dstDisco at dst. // // If dst is a DERP IP:port, then dstKey must be non-zero. // // The dstKey should only be non-zero if the dstDisco key // unambiguously maps to exactly one peer. -func (c *Conn) sendDiscoMessage(dst netip.AddrPort, dstKey key.NodePublic, dstDisco key.DiscoPublic, m disco.Message, logLevel discoLogLevel) (sent bool, err error) { +func (c *Conn) sendDiscoMessage(dst netip.AddrPort, dstKey key.NodePublic, dstDisco key.DiscoPublic, m disco.Message, logLevel discoLogLevel, msgLen int) (sent bool, err error) { isDERP := dst.Addr() == derpMagicIPAddr if _, isPong := m.(*disco.Pong); isPong && !isDERP && dst.Addr().Is4() { time.Sleep(debugIPv4DiscoPingPenalty()) @@ -2076,7 +2102,11 @@ func (c *Conn) sendDiscoMessage(dst netip.AddrPort, dstKey key.NodePublic, dstDi if _, err := crand.Read(nonce[:]); err != nil { panic(err) // worth dying for } - pkt := make([]byte, 0, 512) // TODO: size it correctly? pool? if it matters. + // This is the previous default size + if msgLen == 0 { + msgLen = 512 + } + pkt := make([]byte, 0, msgLen) pkt = append(pkt, disco.Magic...) pkt = c.discoPublic.AppendTo(pkt) di := c.discoInfoLocked(dstDisco) @@ -2097,7 +2127,7 @@ func (c *Conn) sendDiscoMessage(dst netip.AddrPort, dstKey key.NodePublic, dstDi if !dstKey.IsZero() { node = dstKey.ShortString() } - c.dlogf("[v1] magicsock: disco: %v->%v (%v, %v) sent %v", c.discoShort, dstDisco.ShortString(), node, derpStr(dst.String()), disco.MessageSummary(m)) + c.dlogf("[v1] magicsock: disco: %v->%v (%v, %v) sent %v len %v", c.discoShort, dstDisco.ShortString(), node, derpStr(dst.String()), disco.MessageSummary(m), msgLen) } if isDERP { metricSentDiscoDERP.Add(1) @@ -2174,7 +2204,8 @@ const ( // over UDP. func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc key.NodePublic, via discoRXPath) (isDiscoMsg bool) { const headerLen = len(disco.Magic) + key.DiscoPublicRawLen - if len(msg) < headerLen || string(msg[:len(disco.Magic)]) != disco.Magic { + msgLen := len(msg) + if msgLen < headerLen || string(msg[:len(disco.Magic)]) != disco.Magic { return false } @@ -2193,7 +2224,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke return } if debugDisco() { - c.logf("magicsock: disco: got disco-looking frame from %v via %s", sender.ShortString(), via) + c.logf("magicsock: disco: got disco-looking frame from %v via %s len %v", sender.ShortString(), via, msgLen) } if c.privateKey.IsZero() { // Ignore disco messages when we're stopped. @@ -2266,14 +2297,14 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke switch dm := dm.(type) { case *disco.Ping: metricRecvDiscoPing.Add(1) - c.handlePingLocked(dm, src, di, derpNodeSrc) + c.handlePingLocked(dm, src, di, derpNodeSrc, msgLen) case *disco.Pong: metricRecvDiscoPong.Add(1) // There might be multiple nodes for the sender's DiscoKey. // Ask each to handle it, stopping once one reports that // the Pong's TxID was theirs. c.peerMap.forEachEndpointWithDiscoKey(sender, func(ep *endpoint) (keepGoing bool) { - if ep.handlePongConnLocked(dm, di, src) { + if ep.handlePongConnLocked(dm, di, src, msgLen) { return false } return true @@ -2350,7 +2381,7 @@ func (c *Conn) unambiguousNodeKeyOfPingLocked(dm *disco.Ping, dk key.DiscoPublic // di is the discoInfo of the source of the ping. // derpNodeSrc is non-zero if the ping arrived via DERP. -func (c *Conn) handlePingLocked(dm *disco.Ping, src netip.AddrPort, di *discoInfo, derpNodeSrc key.NodePublic) { +func (c *Conn) handlePingLocked(dm *disco.Ping, src netip.AddrPort, di *discoInfo, derpNodeSrc key.NodePublic, msgLen int) { likelyHeartBeat := src == di.lastPingFrom && time.Since(di.lastPingTime) < 5*time.Second di.lastPingFrom = src di.lastPingTime = time.Now() @@ -2381,14 +2412,14 @@ func (c *Conn) handlePingLocked(dm *disco.Ping, src netip.AddrPort, di *discoInf var dup bool if isDerp { if ep, ok := c.peerMap.endpointForNodeKey(derpNodeSrc); ok { - if ep.addCandidateEndpoint(src, dm.TxID) { + if ep.addCandidateEndpoint(src, dm.TxID, msgLen) { return } numNodes = 1 } } else { c.peerMap.forEachEndpointWithDiscoKey(di.discoKey, func(ep *endpoint) (keepGoing bool) { - if ep.addCandidateEndpoint(src, dm.TxID) { + if ep.addCandidateEndpoint(src, dm.TxID, msgLen) { dup = true return false } @@ -2426,7 +2457,7 @@ func (c *Conn) handlePingLocked(dm *disco.Ping, src netip.AddrPort, di *discoInf go c.sendDiscoMessage(ipDst, dstKey, discoDest, &disco.Pong{ TxID: dm.TxID, Src: src, - }, discoVerboseLog) + }, discoVerboseLog, msgLen) } // enqueueCallMeMaybe schedules a send of disco.CallMeMaybe to de via derpAddr @@ -2468,12 +2499,12 @@ func (c *Conn) enqueueCallMeMaybe(derpAddr netip.AddrPort, de *endpoint) { for _, ep := range c.lastEndpoints { eps = append(eps, ep.Addr) } - go de.c.sendDiscoMessage(derpAddr, de.publicKey, epDisco.key, &disco.CallMeMaybe{MyNumber: eps}, discoLog) + go de.c.sendDiscoMessage(derpAddr, de.publicKey, epDisco.key, &disco.CallMeMaybe{MyNumber: eps}, discoLog, 0) if debugSendCallMeUnknownPeer() { // Send a callMeMaybe packet to a non-existent peer unknownKey := key.NewNode().Public() c.logf("magicsock: sending CallMeMaybe to unknown peer per TS_DEBUG_SEND_CALLME_UNKNOWN_PEER") - go de.c.sendDiscoMessage(derpAddr, unknownKey, epDisco.key, &disco.CallMeMaybe{MyNumber: eps}, discoLog) + go de.c.sendDiscoMessage(derpAddr, unknownKey, epDisco.key, &disco.CallMeMaybe{MyNumber: eps}, discoLog, 0) } } @@ -4132,7 +4163,7 @@ type endpoint struct { lastFullPing mono.Time // last time we pinged all disco endpoints derpAddr netip.AddrPort // fallback/bootstrap path, if non-zero (non-zero for well-behaved clients) - bestAddr addrLatency // best non-DERP path; zero if none + bestAddr addrQuality // best non-DERP path; zero if none bestAddrAt mono.Time // time best address re-confirmed trustBestAddrUntil mono.Time // time when bestAddr expires sentPing map[stun.TxID]sentPing @@ -4239,7 +4270,8 @@ type endpointState struct { recentPongs []pongReply // ring buffer up to pongHistoryCount entries recentPong uint16 // index into recentPongs of most recent; older before, wrapped - index int16 // index in nodecfg.Node.Endpoints; meaningless if lastGotPing non-zero + index int16 // index in nodecfg.Node.Endpoints; meaningless if lastGotPing non-zero + msgLen int // max message that got through, for choosing mtu } // indexSentinelDeleted is the temporary value that endpointState.index takes while @@ -4282,7 +4314,7 @@ func (de *endpoint) deleteEndpointLocked(why string, ep netip.AddrPort) { What: "deleteEndpointLocked-bestAddr-" + why, From: de.bestAddr, }) - de.bestAddr = addrLatency{} + de.bestAddr = addrQuality{} } } @@ -4294,6 +4326,7 @@ type pongReply struct { pongAt mono.Time // when we received the pong from netip.AddrPort // the pong's src (usually same as endpoint map key) pongSrc netip.AddrPort // what they reported they heard + msgLen int // length of the pong packet, for choosing MTU } type sentPing struct { @@ -4301,6 +4334,7 @@ type sentPing struct { at mono.Time timer *time.Timer // timeout timer purpose discoPingPurpose + msgLen int } // initFakeUDPAddr populates fakeWGAddr with a globally unique fake UDPAddr. @@ -4642,11 +4676,11 @@ func (de *endpoint) removeSentDiscoPingLocked(txid stun.TxID, sp sentPing) { // // The caller should use de.discoKey as the discoKey argument. // It is passed in so that sendDiscoPing doesn't need to lock de.mu. -func (de *endpoint) sendDiscoPing(ep netip.AddrPort, discoKey key.DiscoPublic, txid stun.TxID, logLevel discoLogLevel) { +func (de *endpoint) sendDiscoPing(ep netip.AddrPort, discoKey key.DiscoPublic, txid stun.TxID, logLevel discoLogLevel, msgLen int) { sent, _ := de.c.sendDiscoMessage(ep, de.publicKey, discoKey, &disco.Ping{ TxID: [12]byte(txid), NodeKey: de.c.publicKeyAtomic.Load(), - }, logLevel) + }, logLevel, 0) if !sent { de.forgetDiscoPing(txid) } @@ -4689,18 +4723,24 @@ func (de *endpoint) startDiscoPingLocked(ep netip.AddrPort, now mono.Time, purpo st.lastPing = now } - txid := stun.NewTxID() - de.sentPing[txid] = sentPing{ - to: ep, - at: now, - timer: time.AfterFunc(pingTimeoutDuration, func() { de.discoPingTimeout(txid) }), - purpose: purpose, + // Send a bouquet of pings in different sizes to probe peer mtu + for mtu := range usefulMtus { + + txid := stun.NewTxID() + de.sentPing[txid] = sentPing{ + to: ep, + at: now, + timer: time.AfterFunc(pingTimeoutDuration, func() { de.discoPingTimeout(txid) }), + purpose: purpose, + msgLen: mtu, + } + logLevel := discoLog + if purpose == pingHeartbeat { + logLevel = discoVerboseLog + } + // XXX do math to convert mtu to msgLen somewhere + go de.sendDiscoPing(ep, epDisco.key, txid, logLevel, mtu) } - logLevel := discoLog - if purpose == pingHeartbeat { - logLevel = discoVerboseLog - } - go de.sendDiscoPing(ep, epDisco.key, txid, logLevel) } func (de *endpoint) sendDiscoPingsLocked(now mono.Time, sendCallMeMaybe bool) { @@ -4807,6 +4847,7 @@ func (de *endpoint) sendWireGuardOnlyPing(ipp netip.AddrPort, now mono.Time) { if !ok { return } + // TODO: figure out how to ignore msgLen for non-mtu discovery packets state.addPongReplyLocked(pongReply{ latency: latency, pongAt: now, @@ -4927,7 +4968,7 @@ func (de *endpoint) updateFromNode(n *tailcfg.Node, heartbeatDisabled bool) { // // This is called once we've already verified that we got a valid // discovery message from de via ep. -func (de *endpoint) addCandidateEndpoint(ep netip.AddrPort, forRxPingTxID stun.TxID) (duplicatePing bool) { +func (de *endpoint) addCandidateEndpoint(ep netip.AddrPort, forRxPingTxID stun.TxID, msgLen int) (duplicatePing bool) { de.mu.Lock() defer de.mu.Unlock() @@ -4936,6 +4977,7 @@ func (de *endpoint) addCandidateEndpoint(ep netip.AddrPort, forRxPingTxID stun.T if !duplicatePing { st.lastGotPingTxID = forRxPingTxID } + // TODO: update the max MTU instead if st.lastGotPing.IsZero() { // Already-known endpoint from the network map. return duplicatePing @@ -4945,10 +4987,11 @@ func (de *endpoint) addCandidateEndpoint(ep netip.AddrPort, forRxPingTxID stun.T } // Newly discovered endpoint. Exciting! - de.c.dlogf("[v1] magicsock: disco: adding %v as candidate endpoint for %v (%s)", ep, de.discoShort(), de.publicKey.ShortString()) + de.c.dlogf("[v1] magicsock: disco: adding %v as candidate endpoint for %v (%s) len %v", ep, de.discoShort(), de.publicKey.ShortString(), msgLen) de.endpointState[ep] = &endpointState{ lastGotPing: time.Now(), lastGotPingTxID: forRxPingTxID, + msgLen: msgLen, // TODO calculate MTU } // If for some reason this gets very large, do some cleanup. @@ -4978,7 +5021,7 @@ func (de *endpoint) noteConnectivityChange() { // It should be called with the Conn.mu held. // // It reports whether m.TxID corresponds to a ping that this endpoint sent. -func (de *endpoint) handlePongConnLocked(m *disco.Pong, di *discoInfo, src netip.AddrPort) (knownTxID bool) { +func (de *endpoint) handlePongConnLocked(m *disco.Pong, di *discoInfo, src netip.AddrPort, msgLen int) (knownTxID bool) { de.mu.Lock() defer de.mu.Unlock() @@ -5009,15 +5052,16 @@ func (de *endpoint) handlePongConnLocked(m *disco.Pong, di *discoInfo, src netip pongAt: now, from: src, pongSrc: m.Src, + msgLen: msgLen, }) } if sp.purpose != pingHeartbeat { - de.c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got pong tx=%x latency=%v pong.src=%v%v", de.c.discoShort, de.discoShort(), de.publicKey.ShortString(), src, m.TxID[:6], latency.Round(time.Millisecond), m.Src, logger.ArgWriter(func(bw *bufio.Writer) { + de.c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got pong tx=%x latency=%v pong.src=%v%v len %v", de.c.discoShort, de.discoShort(), de.publicKey.ShortString(), src, m.TxID[:6], latency.Round(time.Millisecond), m.Src, logger.ArgWriter(func(bw *bufio.Writer) { if sp.to != src { fmt.Fprintf(bw, " ping.to=%v", sp.to) } - })) + }), msgLen) } for _, pp := range de.pendingCLIPings { @@ -5029,9 +5073,9 @@ func (de *endpoint) handlePongConnLocked(m *disco.Pong, di *discoInfo, src netip // Promote this pong response to our current best address if it's lower latency. // TODO(bradfitz): decide how latency vs. preference order affects decision if !isDerp { - thisPong := addrLatency{sp.to, latency} + thisPong := addrQuality{sp.to, latency, msgLen} if betterAddr(thisPong, de.bestAddr) { - de.c.logf("magicsock: disco: node %v %v now using %v", de.publicKey.ShortString(), de.discoShort(), sp.to) + de.c.logf("magicsock: disco: node %v %v now using %v len %v", de.publicKey.ShortString(), de.discoShort(), sp.to, msgLen) de.debugUpdates.Add(EndpointChange{ When: time.Now(), What: "handlePingLocked-bestAddr-update", @@ -5069,19 +5113,23 @@ func portableTrySetSocketBuffer(pconn nettype.PacketConn, logf logger.Logf) { } } -// addrLatency is an IPPort with an associated latency. -type addrLatency struct { +// addrQuality is an IPPort with an associated latency. +type addrQuality struct { netip.AddrPort latency time.Duration + mtu int } -func (a addrLatency) String() string { +func (a addrQuality) String() string { return a.AddrPort.String() + "@" + a.latency.String() } // betterAddr reports whether a is a better addr to use than b. -func betterAddr(a, b addrLatency) bool { +func betterAddr(a, b addrQuality) bool { if a.AddrPort == b.AddrPort { + if a.mtu > b.mtu { + return true + } return false } if !b.IsValid() { @@ -5277,7 +5325,7 @@ func (de *endpoint) stopAndReset() { func (de *endpoint) resetLocked() { de.lastSend = 0 de.lastFullPing = 0 - de.bestAddr = addrLatency{} + de.bestAddr = addrQuality{} de.bestAddrAt = 0 de.trustBestAddrUntil = 0 for _, es := range de.endpointState { diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 78e5bb232..053fdce00 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -1627,10 +1627,13 @@ func TestEndpointSetsEqual(t *testing.T) { func TestBetterAddr(t *testing.T) { const ms = time.Millisecond - al := func(ipps string, d time.Duration) addrLatency { - return addrLatency{netip.MustParseAddrPort(ipps), d} + al := func(ipps string, d time.Duration) addrQuality { + return addrQuality{AddrPort: netip.MustParseAddrPort(ipps), latency: d, mtu: 0} } - zero := addrLatency{} + almtu := func(ipps string, d time.Duration, mtu int) addrQuality { + return addrQuality{AddrPort: netip.MustParseAddrPort(ipps), latency: d, mtu: mtu} + } + zero := addrQuality{} const ( publicV4 = "1.2.3.4:555" @@ -1641,7 +1644,7 @@ func TestBetterAddr(t *testing.T) { ) tests := []struct { - a, b addrLatency + a, b addrQuality want bool // whether a is better than b }{ {a: zero, b: zero, want: false}, @@ -1703,7 +1706,12 @@ func TestBetterAddr(t *testing.T) { b: al(publicV6, 100*ms), want: true, }, - + // If addresses are equal, prefer larger MTU + { + a: almtu(publicV4, 30*ms, 1500), + b: almtu(publicV4, 30*ms, 0), + want: true, + }, // Private IPs are preferred over public IPs even if the public // IP is IPv6. { diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index 54e991801..0a27f97f2 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -150,6 +150,7 @@ const nicID = 1 // maxUDPPacketSize is the maximum size of a UDP packet we copy in startPacketCopy // when relaying UDP packets. We don't use the 'mtu' const in anticipation of // one day making the MTU more dynamic. +// TODO: make this bigger const maxUDPPacketSize = 1500 // Create creates and populates a new Impl.