diff --git a/cmd/derper/depaware.txt b/cmd/derper/depaware.txt index 84ddb6a5f..30c3ba77d 100644 --- a/cmd/derper/depaware.txt +++ b/cmd/derper/depaware.txt @@ -76,6 +76,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa tailscale.com/util/dnsname from tailscale.com/hostinfo+ W tailscale.com/util/endian from tailscale.com/net/netns tailscale.com/util/lineread from tailscale.com/hostinfo+ + tailscale.com/util/mak from tailscale.com/syncs tailscale.com/util/singleflight from tailscale.com/net/dnscache L tailscale.com/util/strs from tailscale.com/hostinfo W 💣 tailscale.com/util/winutil from tailscale.com/hostinfo+ diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index 0bc8504b9..f6b9778be 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -100,7 +100,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep W tailscale.com/util/endian from tailscale.com/net/netns tailscale.com/util/groupmember from tailscale.com/cmd/tailscale/cli tailscale.com/util/lineread from tailscale.com/net/interfaces+ - tailscale.com/util/mak from tailscale.com/net/netcheck + tailscale.com/util/mak from tailscale.com/net/netcheck+ tailscale.com/util/multierr from tailscale.com/control/controlhttp tailscale.com/util/singleflight from tailscale.com/net/dnscache L tailscale.com/util/strs from tailscale.com/hostinfo diff --git a/syncs/syncs.go b/syncs/syncs.go index 1728a5df4..fd565fbbe 100644 --- a/syncs/syncs.go +++ b/syncs/syncs.go @@ -7,7 +7,10 @@ package syncs import ( "context" + "sync" "sync/atomic" + + "tailscale.com/util/mak" ) // ClosedChan returns a channel that's already closed. @@ -152,3 +155,66 @@ func (s Semaphore) TryAcquire() bool { func (s Semaphore) Release() { <-s.c } + +// Map is a Go map protected by a [sync.RWMutex]. +// It is preferred over [sync.Map] for maps with entries that change +// at a relatively high frequency. +// This must not be shallow copied. +type Map[K comparable, V any] struct { + mu sync.RWMutex + m map[K]V +} + +func (m *Map[K, V]) Load(key K) (value V, ok bool) { + m.mu.RLock() + defer m.mu.RUnlock() + value, ok = m.m[key] + return value, ok +} + +func (m *Map[K, V]) Store(key K, value V) { + m.mu.Lock() + defer m.mu.Unlock() + mak.Set(&m.m, key, value) +} + +func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { + if actual, loaded = m.Load(key); loaded { + return actual, loaded + } + + m.mu.Lock() + defer m.mu.Unlock() + actual, loaded = m.m[key] + if !loaded { + actual = value + mak.Set(&m.m, key, value) + } + return actual, loaded +} + +func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) { + m.mu.Lock() + defer m.mu.Unlock() + value, loaded = m.m[key] + if loaded { + delete(m.m, key) + } + return value, loaded +} + +func (m *Map[K, V]) Delete(key K) { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.m, key) +} + +func (m *Map[K, V]) Range(f func(key K, value V) bool) { + m.mu.RLock() + defer m.mu.RUnlock() + for k, v := range m.m { + if !f(k, v) { + return + } + } +} diff --git a/syncs/syncs_test.go b/syncs/syncs_test.go index a6768e90b..632cae64f 100644 --- a/syncs/syncs_test.go +++ b/syncs/syncs_test.go @@ -6,7 +6,10 @@ package syncs import ( "context" + "sync" "testing" + + "github.com/google/go-cmp/cmp" ) func TestWaitGroupChan(t *testing.T) { @@ -73,3 +76,66 @@ func TestSemaphore(t *testing.T) { s.Release() s.Release() } + +func TestMap(t *testing.T) { + var m Map[string, int] + if v, ok := m.Load("noexist"); v != 0 || ok { + t.Errorf(`Load("noexist") = (%v, %v), want (0, false)`, v, ok) + } + m.Store("one", 1) + if v, ok := m.LoadOrStore("one", -1); v != 1 || !ok { + t.Errorf(`LoadOrStore("one", 1) = (%v, %v), want (1, true)`, v, ok) + } + if v, ok := m.Load("one"); v != 1 || !ok { + t.Errorf(`Load("one") = (%v, %v), want (1, true)`, v, ok) + } + if v, ok := m.LoadOrStore("two", 2); v != 2 || ok { + t.Errorf(`LoadOrStore("two", 2) = (%v, %v), want (2, false)`, v, ok) + } + got := map[string]int{} + want := map[string]int{"one": 1, "two": 2} + m.Range(func(k string, v int) bool { + got[k] = v + return true + }) + if d := cmp.Diff(got, want); d != "" { + t.Errorf("Range mismatch (-got +want):\n%s", d) + } + if v, ok := m.LoadAndDelete("two"); v != 2 || !ok { + t.Errorf(`LoadAndDelete("two) = (%v, %v), want (2, true)`, v, ok) + } + if v, ok := m.LoadAndDelete("two"); v != 0 || ok { + t.Errorf(`LoadAndDelete("two) = (%v, %v), want (0, false)`, v, ok) + } + m.Delete("one") + m.Delete("noexist") + got = map[string]int{} + want = map[string]int{} + m.Range(func(k string, v int) bool { + got[k] = v + return true + }) + if d := cmp.Diff(got, want); d != "" { + t.Errorf("Range mismatch (-got +want):\n%s", d) + } + + t.Run("LoadOrStore", func(t *testing.T) { + var m Map[string, string] + var wg sync.WaitGroup + wg.Add(2) + var ok1, ok2 bool + go func() { + defer wg.Done() + _, ok1 = m.LoadOrStore("", "") + }() + go func() { + defer wg.Done() + _, ok2 = m.LoadOrStore("", "") + }() + wg.Wait() + + if ok1 == ok2 { + t.Errorf("exactly one LoadOrStore should load") + } + }) +}