From b360ab1c6bb449492c4e336d303e4aad481772fd Mon Sep 17 00:00:00 2001 From: David Crawshaw Date: Sun, 12 Mar 2023 22:14:38 -0700 Subject: [PATCH] tsnet: avoid deadlock on close tsnet.Server.Close was calling listener.Close with the server mutex held, but the listener close method tries to grab that mutex, resulting in a deadlock. Signed-off-by: David Crawshaw --- tsnet/tsnet.go | 22 ++++++++++++++++------ tsnet/tsnet_test.go | 22 ++++++++++++++++++++++ 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index a4a4bda04..f1ad0180e 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -349,12 +349,17 @@ func (s *Server) Close() error { s.loopbackListener.Close() } + var lns []*listener + s.mu.Lock() - defer s.mu.Unlock() for _, ln := range s.listeners { + lns = append(lns, ln) + } + s.mu.Unlock() + + for _, ln := range lns { ln.Close() } - s.listeners = nil wg.Wait() return nil @@ -997,10 +1002,11 @@ type listenKey struct { } type listener struct { - s *Server - keys []listenKey - addr string - conn chan net.Conn + s *Server + keys []listenKey + addr string + conn chan net.Conn + closed bool // guarded by s.mu } func (ln *listener) Accept() (net.Conn, error) { @@ -1015,12 +1021,16 @@ func (ln *listener) Addr() net.Addr { return addr{ln} } func (ln *listener) Close() error { ln.s.mu.Lock() defer ln.s.mu.Unlock() + if ln.closed { + return fmt.Errorf("tsnet: %w", net.ErrClosed) + } for _, key := range ln.keys { if v, ok := ln.s.listeners[key]; ok && v == ln { delete(ln.s.listeners, key) } } close(ln.conn) + ln.closed = true return nil } diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index ab55b7b60..0ab75406b 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -9,6 +9,7 @@ import ( "flag" "fmt" "io" + "net" "net/http" "net/http/httptest" "net/netip" @@ -344,3 +345,24 @@ func TestTailscaleIPs(t *testing.T) { sIp4, upIp4, sIp6, upIp6) } } + +func TestListenerCleanup(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + controlURL := startControl(t) + s1, _ := startServer(t, ctx, controlURL, "s1") + + ln, err := s1.Listen("tcp", ":8081") + if err != nil { + t.Fatal(err) + } + + if err := s1.Close(); err != nil { + t.Fatal(err) + } + + if err := ln.Close(); !errors.Is(err, net.ErrClosed) { + t.Fatalf("second ln.Close error: %v, want net.ErrClosed", err) + } +}