From f205d232331579102ae22ff2eab727c6d7e42aba Mon Sep 17 00:00:00 2001 From: Tom DNetto Date: Wed, 21 Jun 2023 14:57:15 -0700 Subject: [PATCH] disco,types,wgengine: implement Knock,KnockReply disco messages EXTREME WIP, DO NOT SUBMIT Updates #1227 --- disco/disco.go | 59 +++++++++++ disco/disco_test.go | 15 +++ types/key/node.go | 20 ++++ wgengine/magicsock/magicsock.go | 141 +++++++++++++++++++++++++++ wgengine/magicsock/magicsock_test.go | 61 ++++++++++++ 5 files changed, 296 insertions(+) diff --git a/disco/disco.go b/disco/disco.go index 0e7c3f7e5..7c92507d6 100644 --- a/disco/disco.go +++ b/disco/disco.go @@ -27,6 +27,7 @@ import ( "net/netip" "go4.org/mem" + "golang.org/x/crypto/nacl/box" "tailscale.com/types/key" ) @@ -44,6 +45,8 @@ const ( TypePing = MessageType(0x01) TypePong = MessageType(0x02) TypeCallMeMaybe = MessageType(0x03) + TypeKnock = MessageType(0x04) + TypeKnockReply = MessageType(0x05) ) const v0 = byte(0) @@ -83,6 +86,10 @@ func Parse(p []byte) (Message, error) { return parsePong(ver, p) case TypeCallMeMaybe: return parseCallMeMaybe(ver, p) + case TypeKnock: + return parseKnock(ver, p) + case TypeKnockReply: + return parseKnockReply(ver, p) default: return nil, fmt.Errorf("unknown message type 0x%02x", byte(t)) } @@ -240,6 +247,54 @@ func parsePong(ver uint8, p []byte) (m *Pong, err error) { return m, nil } +type Knock struct { + // SealedNonce is the random client-generated per-knock nonce, + // which is NaCl-box sealed to the node key of the destination. + // The unencrypted nonce is 8 bytes. 
+ SealedNonce [box.AnonymousOverhead + 8]byte +} + +func (m *Knock) AppendMarshal(b []byte) []byte { + dataLen := box.AnonymousOverhead + 8 + ret, d := appendMsgHeader(b, TypeKnock, v0, dataLen) + copy(d, m.SealedNonce[:]) + return ret +} + +func parseKnock(ver uint8, p []byte) (m *Knock, err error) { + if len(p) < (box.AnonymousOverhead + 8) { + return nil, errShort + } + m = new(Knock) + p = p[copy(m.SealedNonce[:], p):] + // Deliberately lax on longer-than-expected messages, for future + // compatibility. + return m, nil +} + +type KnockReply struct { + // Nonce is the nonce value from the Knock request. + Nonce [8]byte +} + +func (m *KnockReply) AppendMarshal(b []byte) []byte { + dataLen := 8 + ret, d := appendMsgHeader(b, TypeKnockReply, v0, dataLen) + copy(d, m.Nonce[:]) + return ret +} + +func parseKnockReply(ver uint8, p []byte) (m *KnockReply, err error) { + if len(p) < 8 { + return nil, errShort + } + m = new(KnockReply) + p = p[copy(m.Nonce[:], p):] + // Deliberately lax on longer-than-expected messages, for future + // compatibility. + return m, nil +} + // MessageSummary returns a short summary of m for logging purposes. 
func MessageSummary(m Message) string { switch m := m.(type) { @@ -249,6 +304,10 @@ func MessageSummary(m Message) string { return fmt.Sprintf("pong tx=%x", m.TxID[:6]) case *CallMeMaybe: return "call-me-maybe" + case *Knock: + return fmt.Sprintf("knock") + case *KnockReply: + return fmt.Sprintf("knock reply nonce=%x", m.Nonce[:]) default: return fmt.Sprintf("%#v", m) } diff --git a/disco/disco_test.go b/disco/disco_test.go index 67bd1561a..6d8a2be13 100644 --- a/disco/disco_test.go +++ b/disco/disco_test.go @@ -4,6 +4,7 @@ package disco import ( + "bytes" "fmt" "net/netip" "reflect" @@ -66,6 +67,20 @@ func TestMarshalAndParse(t *testing.T) { }, want: "03 00 00 00 00 00 00 00 00 00 00 00 ff ff 01 02 03 04 02 37 20 01 00 00 00 00 00 00 00 00 00 00 00 00 34 56 03 15", }, + { + name: "knock", + m: &Knock{ + SealedNonce: [16 + 32 + 8]byte(bytes.Repeat([]byte{1, 2}, 28)), + }, + want: "04 00 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02 01 02", + }, + { + name: "knock_reply", + m: &KnockReply{ + Nonce: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}, + }, + want: "05 00 01 02 03 04 05 06 07 08", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/types/key/node.go b/types/key/node.go index a84057231..8d7b99d59 100644 --- a/types/key/node.go +++ b/types/key/node.go @@ -142,6 +142,26 @@ func (k NodePrivate) OpenFrom(p NodePublic, ciphertext []byte) (cleartext []byte return box.Open(nil, ciphertext[len(nonce):], nonce, &p.k, &k.k) } +// SealAnonymous seals the cleartext to the node key k. 
+func (k NodePublic) SealAnonymous(cleartext []byte) (ciphertext []byte, err error) { + if k.IsZero() { + panic("can't seal with zero keys") + } + return box.SealAnonymous(nil, cleartext, &k.k, nil) +} + +// OpenAnonymous opens the anonymous NaCl box ciphertext, which must be a value +// created by SealAnonymous, and returns the inner cleartext if ciphertext is +// a valid box to k. +func (k NodePrivate) OpenAnonymous(ciphertext []byte) (cleartext []byte, ok bool) { + if k.IsZero() { + panic("can't open with zero keys") + } + + p := k.Public() + return box.OpenAnonymous(nil, ciphertext, &p.k, &k.k) +} + func (k NodePrivate) UntypedHexString() string { return hex.EncodeToString(k.k[:]) } diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index dea8b2d97..ee7f755a6 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -2306,6 +2306,25 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke ep.publicKey.ShortString(), derpStr(src.String()), len(dm.MyNumber)) go ep.handleCallMeMaybe(dm) + case *disco.Knock: + metricRecvDiscoKnock.Add(1) + if isDERP { + metricRecvDiscoKnockBadDisco.Add(1) + c.logf("[unexpected] Knock packets should only come via LAN") + return + } + c.handleKnockLocked(dm, src, di) + case *disco.KnockReply: + metricRecvDiscoKnockReply.Add(1) + if isDERP { + metricRecvDiscoKnockReplyBadDisco.Add(1) + c.logf("[unexpected] Knock reply packets should only come via LAN") + return + } + c.logf("magicsock: disco: got knock reply %v from %v", dm, src) + c.peerMap.forEachEndpointWithDiscoKey(sender, func(ep *endpoint) (keepGoing bool) { + return !ep.handleKnockReplyLocked(dm, src, di) + }) } return } @@ -2348,6 +2367,114 @@ func (c *Conn) unambiguousNodeKeyOfPingLocked(dm *disco.Ping, dk key.DiscoPublic return nk, false } +// handleKnockReplyLocked handles a DISCO Knock Reply message. If the nonce is +// correct, the callback for the pending knock is invoked. 
+// +// True is returned if this endpoint handled the nonce. +// +// di is the discoInfo of the source of the knock packet. +func (de *endpoint) handleKnockReplyLocked(dm *disco.KnockReply, src netip.AddrPort, di *discoInfo) bool { + de.mu.Lock() + defer de.mu.Unlock() + + if de.pendingKnock == nil || !bytes.Equal(dm.Nonce[:], de.pendingKnock.nonce[:]) { + return false + } + + // From this point on, nonce is correct + cb := de.pendingKnock.cb + de.pendingKnock = nil + go cb(nil) + return true +} + +// handleKnockLocked handles a DISCO Knock message. If the received packet +// is in order, a response is sent containing the unwrapped nonce. +// +// di is the discoInfo of the source of the knock packet. +func (c *Conn) handleKnockLocked(dm *disco.Knock, src netip.AddrPort, di *discoInfo) { + // TODO(tom): Filter to LAN-only sources + + nonceBytes, ok := c.privateKey.OpenAnonymous(dm.SealedNonce[:]) + if !ok { + metricRecvDiscoKnockBadSeal.Add(1) + c.logf("magicsock: disco: dropping bad knock from %v", src) + return + } + + var nonce [8]byte + copy(nonce[:], nonceBytes) + + c.peerMap.forEachEndpointWithDiscoKey(di.discoKey, func(ep *endpoint) (keepGoing bool) { + go c.sendDiscoMessage(src, ep.publicKey, di.discoKey, &disco.KnockReply{ + Nonce: nonce, + }, discoVerboseLog) + return true + }) +} + +// Knock handles a request to knock a specific peer. 
+func (c *Conn) Knock(addr netip.AddrPort, peer *tailcfg.Node, cb func(error)) { + if runtime.GOOS == "js" { + cb(errors.New("no direct over tsconnect")) + return + } + + c.mu.Lock() + defer c.mu.Unlock() + if c.privateKey.IsZero() { + cb(errNetworkDown) + return + } + + ep, ok := c.peerMap.endpointForNodeKey(peer.Key) + if !ok { + cb(errors.New("unknown peer")) + return + } + ep.knock(addr, cb) +} + +func (de *endpoint) knock(addr netip.AddrPort, cb func(error)) { + de.mu.Lock() + defer de.mu.Unlock() + + if de.expired { + cb(errExpired) + return + } + epDisco := de.disco.Load() + if epDisco == nil { + cb(errors.New("no disco key")) + return + } + + var nonce [8]byte + if _, err := crand.Read(nonce[:]); err != nil { + panic(err) // worth dying for + } + sealed, err := de.publicKey.SealAnonymous(nonce[:]) + if err != nil { + cb(err) + return + } + + if de.pendingKnock != nil { + de.pendingKnock.cb(errors.New("superceded")) + } + de.pendingKnock = &pendingKnock{addr, cb, nonce} + + go func() { + knock := disco.Knock{} + copy(knock.SealedNonce[:], sealed) + sent, _ := de.c.sendDiscoMessage(addr, de.publicKey, epDisco.key, &knock, discoVerboseLog) + if !sent { + panic("not sent") + } + }() + de.noteActiveLocked() +} + // di is the discoInfo of the source of the ping. // derpNodeSrc is non-zero if the ping arrived via DERP. func (c *Conn) handlePingLocked(dm *disco.Ping, src netip.AddrPort, di *discoInfo, derpNodeSrc key.NodePublic) { @@ -4141,6 +4268,8 @@ type endpoint struct { pendingCLIPings []pendingCLIPing // any outstanding "tailscale ping" commands running + pendingKnock *pendingKnock // any outstanding knock challenge, if any + // The following fields are related to the new "silent disco" // implementation that's a WIP as of 2022-10-20. // See #540 for background. 
@@ -4156,6 +4285,12 @@ type pendingCLIPing struct { cb func(*ipnstate.PingResult) } +type pendingKnock struct { + addr netip.AddrPort + cb func(error) + nonce [8]byte +} + const ( // sessionActiveTimeout is how long since the last activity we // try to keep an established endpoint peering alive. @@ -5269,6 +5404,7 @@ func (de *endpoint) stopAndReset() { de.heartBeatTimer = nil } de.pendingCLIPings = nil + de.pendingKnock = nil } // resetLocked clears all the endpoint's p2p state, reverting it to a @@ -5468,6 +5604,11 @@ var ( metricRecvDiscoCallMeMaybe = clientmetric.NewCounter("magicsock_disco_recv_callmemaybe") metricRecvDiscoCallMeMaybeBadNode = clientmetric.NewCounter("magicsock_disco_recv_callmemaybe_bad_node") metricRecvDiscoCallMeMaybeBadDisco = clientmetric.NewCounter("magicsock_disco_recv_callmemaybe_bad_disco") + metricRecvDiscoKnock = clientmetric.NewCounter("magicsock_disco_recv_knock") + metricRecvDiscoKnockBadDisco = clientmetric.NewCounter("magicsock_disco_recv_knock_bad_disco") + metricRecvDiscoKnockBadSeal = clientmetric.NewCounter("magicsock_disco_recv_knock_bad_seal") + metricRecvDiscoKnockReply = clientmetric.NewCounter("magicsock_disco_recv_knock_reply") + metricRecvDiscoKnockReplyBadDisco = clientmetric.NewCounter("magicsock_disco_recv_knock_reply_bad_disco") metricRecvDiscoDERPPeerNotHere = clientmetric.NewCounter("magicsock_disco_recv_derp_peer_not_here") metricRecvDiscoDERPPeerGoneUnknown = clientmetric.NewCounter("magicsock_disco_recv_derp_peer_gone_unknown") // metricDERPHomeChange is how many times our DERP home region DI has diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 78e5bb232..0423447c8 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -2908,3 +2908,64 @@ func TestAddrForSendLockedForWireGuardOnly(t *testing.T) { } } } + +func TestDiscoKnock(t *testing.T) { + mstun := &natlab.Machine{Name: "stun"} + m1 := &natlab.Machine{Name: "m1"} + m2 
:= &natlab.Machine{Name: "m2"} + inet := natlab.NewInternet() + sif := mstun.Attach("eth0", inet) + m1if := m1.Attach("eth0", inet) + m2if := m2.Attach("eth0", inet) + + d := &devices{ + m1: m1, + m1IP: m1if.V4(), + m2: m2, + m2IP: m2if.V4(), + stun: mstun, + stunIP: sif.V4(), + } + + logf, closeLogf := logger.LogfCloser(t.Logf) + defer closeLogf() + + derpMap, cleanup := runDERPAndStun(t, logf, d.stun, d.stunIP) + defer cleanup() + + ms1 := newMagicStack(t, logger.WithPrefix(logf, "conn1: "), d.m1, derpMap) + defer ms1.Close() + ms2 := newMagicStack(t, logger.WithPrefix(logf, "conn2: "), d.m2, derpMap) + defer ms2.Close() + + cleanup = meshStacks(t.Logf, nil, ms1, ms2) + defer cleanup() + + // Wait for both peers to know about each other. + for { + if s1 := ms1.Status(); len(s1.Peer) != 1 { + time.Sleep(10 * time.Millisecond) + continue + } + if s2 := ms2.Status(); len(s2.Peer) != 1 { + time.Sleep(10 * time.Millisecond) + continue + } + break + } + + cbErr := make(chan error, 1) + ms1.conn.Knock(netip.AddrPortFrom(m2if.V4(), ms2.conn.pconn4.LocalAddr().AddrPort().Port()), &tailcfg.Node{Key: ms2.privateKey.Public()}, func(err error) { + cbErr <- err + }) + + select { + case err := <-cbErr: + if err != nil { + t.Errorf("Knock failed: %v", err) + } + + case <-time.After(2 * time.Second): + t.Error("timeout waiting for knock callback") + } +}