Compare commits
1 Commits
main
...
crawshaw/l
Author | SHA1 | Date |
---|---|---|
![]() |
b360ab1c6b |
|
@ -349,12 +349,17 @@ func (s *Server) Close() error {
|
||||||
s.loopbackListener.Close()
|
s.loopbackListener.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var lns []*listener
|
||||||
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
|
||||||
for _, ln := range s.listeners {
|
for _, ln := range s.listeners {
|
||||||
|
lns = append(lns, ln)
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
for _, ln := range lns {
|
||||||
ln.Close()
|
ln.Close()
|
||||||
}
|
}
|
||||||
s.listeners = nil
|
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
return nil
|
return nil
|
||||||
|
@ -997,10 +1002,11 @@ type listenKey struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type listener struct {
|
type listener struct {
|
||||||
s *Server
|
s *Server
|
||||||
keys []listenKey
|
keys []listenKey
|
||||||
addr string
|
addr string
|
||||||
conn chan net.Conn
|
conn chan net.Conn
|
||||||
|
closed bool // guarded by s.mu
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ln *listener) Accept() (net.Conn, error) {
|
func (ln *listener) Accept() (net.Conn, error) {
|
||||||
|
@ -1015,12 +1021,16 @@ func (ln *listener) Addr() net.Addr { return addr{ln} }
|
||||||
func (ln *listener) Close() error {
|
func (ln *listener) Close() error {
|
||||||
ln.s.mu.Lock()
|
ln.s.mu.Lock()
|
||||||
defer ln.s.mu.Unlock()
|
defer ln.s.mu.Unlock()
|
||||||
|
if ln.closed {
|
||||||
|
return fmt.Errorf("tsnet: %w", net.ErrClosed)
|
||||||
|
}
|
||||||
for _, key := range ln.keys {
|
for _, key := range ln.keys {
|
||||||
if v, ok := ln.s.listeners[key]; ok && v == ln {
|
if v, ok := ln.s.listeners[key]; ok && v == ln {
|
||||||
delete(ln.s.listeners, key)
|
delete(ln.s.listeners, key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
close(ln.conn)
|
close(ln.conn)
|
||||||
|
ln.closed = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
@ -344,3 +345,24 @@ func TestTailscaleIPs(t *testing.T) {
|
||||||
sIp4, upIp4, sIp6, upIp6)
|
sIp4, upIp4, sIp6, upIp6)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestListenerCleanup(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
controlURL := startControl(t)
|
||||||
|
s1, _ := startServer(t, ctx, controlURL, "s1")
|
||||||
|
|
||||||
|
ln, err := s1.Listen("tcp", ":8081")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s1.Close(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ln.Close(); !errors.Is(err, net.ErrClosed) {
|
||||||
|
t.Fatalf("second ln.Close error: %v, want net.ErrClosed", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue