From 6cc6c70d704fedb7ee5c99e2c8fd9c8946f90efe Mon Sep 17 00:00:00 2001 From: Anton Tolchanov Date: Tue, 22 Nov 2022 16:13:53 +0000 Subject: [PATCH] derp: prevent concurrent access to multiForwarder map Instead of iterating over the map to determine the preferred forwarder on every packet (which could happen concurrently with map mutations), store it separately in an atomic variable. Fixes #6445 Signed-off-by: Anton Tolchanov --- derp/derp_server.go | 94 ++++++++++++++++++++++++++++++--------------- derp/derp_test.go | 81 ++++++++++++++++++++++++++++++++++---- 2 files changed, 136 insertions(+), 39 deletions(-) diff --git a/derp/derp_server.go b/derp/derp_server.go index cf7f6fe7e..924192477 100644 --- a/derp/derp_server.go +++ b/derp/derp_server.go @@ -40,6 +40,7 @@ import ( "tailscale.com/disco" "tailscale.com/envknob" "tailscale.com/metrics" + "tailscale.com/syncs" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/version" @@ -1560,22 +1561,20 @@ func (s *Server) AddPacketForwarder(dst key.NodePublic, fwd PacketForwarder) { // Duplicate registration of same forwarder. Ignore. return } - if m, ok := prev.(multiForwarder); ok { - if _, ok := m[fwd]; ok { + if m, ok := prev.(*multiForwarder); ok { + if _, ok := m.all[fwd]; ok { // Duplicate registration of same forwarder in set; ignore. return } - m[fwd] = m.maxVal() + 1 + m.add(fwd) return } if prev != nil { // Otherwise, the existing value is not a set, // not a dup, and not local-only (nil) so make - // it a set. - fwd = multiForwarder{ - prev: 1, // existed 1st, higher priority - fwd: 2, // the passed in fwd is in 2nd place - } + // it a set. `prev` existed first, so will have higher + // priority. + fwd = newMultiForwarder(prev, fwd) s.multiForwarderCreated.Add(1) } } @@ -1591,19 +1590,14 @@ func (s *Server) RemovePacketForwarder(dst key.NodePublic, fwd PacketForwarder) if !ok { return } - if m, ok := v.(multiForwarder); ok { - if len(m) < 2 { + if m, ok := v.(*multiForwarder); ok { + if len(m.all) < 2 { panic("unexpected") } - delete(m, fwd) - // If fwd was in m and we no longer need to be a - // multiForwarder, replace the entry with the - // remaining PacketForwarder. - if len(m) == 1 { - var remain PacketForwarder - for k := range m { - remain = k - } + if remain, isLast := m.deleteLocked(fwd); isLast { + // If fwd was in m and we no longer need to be a + // multiForwarder, replace the entry with the + // remaining PacketForwarder. s.clientsMesh[dst] = remain s.multiForwarderDeleted.Add(1) } @@ -1635,27 +1629,65 @@ func (s *Server) RemovePacketForwarder(dst key.NodePublic, fwd PacketForwarder) // client is. The map value is unique connection number; the lowest // one has been seen the longest. It's used to make sure we forward // packets consistently to the same node and don't pick randomly. -type multiForwarder map[PacketForwarder]uint8 +type multiForwarder struct { + fwd syncs.AtomicValue[PacketForwarder] // preferred forwarder. + all map[PacketForwarder]uint8 // all forwarders, protected by s.mu. +} -func (m multiForwarder) maxVal() (max uint8) { - for _, v := range m { +// newMultiForwarder creates a new multiForwarder. +// The first PacketForwarder passed to this function will be the preferred one. +func newMultiForwarder(fwds ...PacketForwarder) *multiForwarder { + f := &multiForwarder{all: make(map[PacketForwarder]uint8)} + f.fwd.Store(fwds[0]) + for idx, fwd := range fwds { + f.all[fwd] = uint8(idx) + } + return f +} + +// add adds a new forwarder to the map with a connection number that +// is higher than the existing ones. +func (f *multiForwarder) add(fwd PacketForwarder) { + var max uint8 + for _, v := range f.all { if v > max { max = v } } - return + f.all[fwd] = max + 1 } -func (m multiForwarder) ForwardPacket(src, dst key.NodePublic, payload []byte) error { - var fwd PacketForwarder - var lowest uint8 - for k, v := range m { - if fwd == nil || v < lowest { - fwd = k - lowest = v +// deleteLocked removes a packet forwarder from the map. It expects Server.mu to be held. +// If only one forwarder remains after the removal, it will be returned alongside a `true` boolean value. +func (f *multiForwarder) deleteLocked(fwd PacketForwarder) (_ PacketForwarder, isLast bool) { + delete(f.all, fwd) + + if fwd == f.fwd.Load() { + // The preferred forwarder has been removed, choose a new one + // based on the lowest index. + var lowestfwd PacketForwarder + var lowest uint8 + for k, v := range f.all { + if lowestfwd == nil || v < lowest { + lowestfwd = k + lowest = v + } + } + if lowestfwd != nil { + f.fwd.Store(lowestfwd) } } - return fwd.ForwardPacket(src, dst, payload) + + if len(f.all) == 1 { + for k := range f.all { + return k, true + } + } + return nil, false +} + +func (f *multiForwarder) ForwardPacket(src, dst key.NodePublic, payload []byte) error { + return f.fwd.Load().ForwardPacket(src, dst, payload) } func (s *Server) expVarFunc(f func() any) expvar.Func { diff --git a/derp/derp_test.go b/derp/derp_test.go index 6da11197a..2edcb057a 100644 --- a/derp/derp_test.go +++ b/derp/derp_test.go @@ -19,6 +19,7 @@ import ( "net" "os" "reflect" + "strconv" "sync" "testing" "time" @@ -723,20 +724,14 @@ func TestForwarderRegistration(t *testing.T) { s.AddPacketForwarder(u1, testFwd(100)) s.AddPacketForwarder(u1, testFwd(100)) // dup to trigger dup path want(map[key.NodePublic]PacketForwarder{ - u1: multiForwarder{ - testFwd(1): 1, - testFwd(100): 2, - }, + u1: newMultiForwarder(testFwd(1), testFwd(100)), }) wantCounter(&s.multiForwarderCreated, 1) // Removing a forwarder in a multi set that doesn't exist; does nothing. s.RemovePacketForwarder(u1, testFwd(55)) want(map[key.NodePublic]PacketForwarder{ - u1: multiForwarder{ - testFwd(1): 1, - testFwd(100): 2, - }, + u1: newMultiForwarder(testFwd(1), testFwd(100)), }) // Removing a forwarder in a multi set that does exist should collapse it away @@ -785,6 +780,76 @@ func TestForwarderRegistration(t *testing.T) { }) } +type channelFwd struct { + // id is to ensure that different instances that reference the + // same channel are not equal, as they are used as keys in the + // multiForwarder map. + id int + c chan []byte +} + +func (f channelFwd) ForwardPacket(_ key.NodePublic, _ key.NodePublic, packet []byte) error { + f.c <- packet + return nil +} + +func TestMultiForwarder(t *testing.T) { + received := 0 + var wg sync.WaitGroup + ch := make(chan []byte) + ctx, cancel := context.WithCancel(context.Background()) + + s := &Server{ + clients: make(map[key.NodePublic]clientSet), + clientsMesh: map[key.NodePublic]PacketForwarder{}, + } + u := pubAll(1) + s.AddPacketForwarder(u, channelFwd{1, ch}) + + wg.Add(2) + go func() { + defer wg.Done() + for { + select { + case <-ch: + received += 1 + case <-ctx.Done(): + return + } + } + }() + go func() { + defer wg.Done() + for { + s.AddPacketForwarder(u, channelFwd{2, ch}) + s.AddPacketForwarder(u, channelFwd{3, ch}) + s.RemovePacketForwarder(u, channelFwd{2, ch}) + s.RemovePacketForwarder(u, channelFwd{1, ch}) + s.AddPacketForwarder(u, channelFwd{1, ch}) + s.RemovePacketForwarder(u, channelFwd{3, ch}) + if ctx.Err() != nil { + return + } + } + }() + + // Number of messages is chosen arbitrarily, just for this loop to + // run long enough concurrently with {Add,Remove}PacketForwarder loop above. + numMsgs := 5000 + var fwd PacketForwarder + for i := 0; i < numMsgs; i++ { + s.mu.Lock() + fwd = s.clientsMesh[u] + s.mu.Unlock() + fwd.ForwardPacket(u, u, []byte(strconv.Itoa(i))) + } + + cancel() + wg.Wait() + if received != numMsgs { + t.Errorf("expected %d messages to be forwarded; got %d", numMsgs, received) + } +} func TestMetaCert(t *testing.T) { priv := key.NewNode() pub := priv.Public()