From 0b1b2e5ed56300a3f82335a8d575f1ad5bdd3253 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 8 Mar 2021 15:48:49 -0800 Subject: [PATCH] wgengine/magicsock: fix Conn.Rebind race that let ErrClosed errors be read There was a logical race where Conn.Rebind could acquire the RebindingUDPConn mutex, close the connection, fail to rebind, release the mutex, and then because the mutex was no longer held, ReceiveIPv4 wouldn't retry reads that failed with net.ErrClosed, letting that error back to wireguard-go, which would then stop running that receive IP goroutine. Instead, keep the RebindingUDPConn mutex held for the entirety of the replacement in all cases. Updates tailscale/corp#1289 Signed-off-by: Brad Fitzpatrick (cherry picked from commit 387e83c8fe7e3ad446ee0905828275cb42ae58d5) --- wgengine/magicsock/magicsock.go | 40 ++++++++++----- wgengine/magicsock/magicsock_test.go | 76 +++++++++++++++++++++++++--- 2 files changed, 97 insertions(+), 19 deletions(-) diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 5628161e9..5fd305633 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -2616,28 +2616,38 @@ func (c *Conn) Rebind() { host = "127.0.0.1" } listenCtx := context.Background() // unused without DNS name to resolve + if c.port != 0 { c.pconn4.mu.Lock() + oldPort := c.pconn4.localAddrLocked().Port if err := c.pconn4.pconn.Close(); err != nil { c.logf("magicsock: link change close failed: %v", err) } - packetConn, err := c.listenPacket(listenCtx, "udp4", fmt.Sprintf("%s:%d", host, c.port)) - if err == nil { + packetConn, err := c.listenPacket(listenCtx, "udp4", net.JoinHostPort(host, fmt.Sprint(c.port))) + if err != nil { + c.logf("magicsock: link change unable to bind fixed port %d: %v, falling back to random port", c.port, err) + packetConn, err = c.listenPacket(listenCtx, "udp4", net.JoinHostPort(host, "0")) + if err != nil { + c.logf("magicsock: link change failed to bind random port: %v", err) + c.pconn4.mu.Unlock() + return + } + newPort := c.pconn4.localAddrLocked().Port + c.logf("magicsock: link change rebound port: from %v to %v (failed to get %v)", oldPort, newPort, c.port) + } else { c.logf("magicsock: link change rebound port: %d", c.port) - c.pconn4.pconn = packetConn.(*net.UDPConn) - c.pconn4.mu.Unlock() + } + c.pconn4.pconn = packetConn.(*net.UDPConn) + c.pconn4.mu.Unlock() + } else { + c.logf("magicsock: link change, binding new port") + packetConn, err := c.listenPacket(listenCtx, "udp4", host+":0") + if err != nil { + c.logf("magicsock: link change failed to bind new port: %v", err) return } - c.logf("magicsock: link change unable to bind fixed port %d: %v, falling back to random port", c.port, err) - c.pconn4.mu.Unlock() + c.pconn4.Reset(packetConn.(*net.UDPConn)) } - c.logf("magicsock: link change, binding new port") - packetConn, err := c.listenPacket(listenCtx, "udp4", host+":0") - if err != nil { - c.logf("magicsock: link change failed to bind new port: %v", err) - return - } - c.pconn4.Reset(packetConn.(*net.UDPConn)) c.mu.Lock() c.closeAllDerpLocked("rebind") @@ -2764,6 +2774,10 @@ func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { func (c *RebindingUDPConn) LocalAddr() *net.UDPAddr { c.mu.Lock() defer c.mu.Unlock() + return c.localAddrLocked() +} + +func (c *RebindingUDPConn) localAddrLocked() *net.UDPAddr { return c.pconn.LocalAddr().(*net.UDPAddr) } diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index c764aeb89..cacd223cc 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -1602,6 +1602,15 @@ func BenchmarkReceiveFrom_Native(b *testing.B) { } } +func logBufWriter(buf *bytes.Buffer) logger.Logf { + return func(format string, a ...interface{}) { + fmt.Fprintf(buf, format, a...) + if !bytes.HasSuffix(buf.Bytes(), []byte("\n")) { + buf.WriteByte('\n') + } + } +} + // Test that a netmap update where node changes its node key but // doesn't change its disco key doesn't result in a broken state. // @@ -1610,12 +1619,7 @@ func TestSetNetworkMapChangingNodeKey(t *testing.T) { conn := newNonLegacyTestConn(t) t.Cleanup(func() { conn.Close() }) var logBuf bytes.Buffer - conn.logf = func(format string, a ...interface{}) { - fmt.Fprintf(&logBuf, format, a...) - if !bytes.HasSuffix(logBuf.Bytes(), []byte("\n")) { - logBuf.WriteByte('\n') - } - } + conn.logf = logBufWriter(&logBuf) conn.SetPrivateKey(wgkey.Private{0: 1}) @@ -1669,3 +1673,63 @@ func TestSetNetworkMapChangingNodeKey(t *testing.T) { t.Logf("log output: %s", log) } } + +func TestRebindStress(t *testing.T) { + conn := newNonLegacyTestConn(t) + + var logBuf bytes.Buffer + conn.logf = logBufWriter(&logBuf) + + closed := false + t.Cleanup(func() { + if !closed { + conn.Close() + } + }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + errc := make(chan error, 1) + go func() { + buf := make([]byte, 1500) + for { + _, _, err := conn.ReceiveIPv4(buf) + if ctx.Err() != nil { + errc <- nil + return + } + if err != nil { + errc <- err + return + } + } + }() + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + for i := 0; i < 2000; i++ { + conn.Rebind() + } + }() + go func() { + defer wg.Done() + for i := 0; i < 2000; i++ { + conn.Rebind() + } + }() + wg.Wait() + + cancel() + if err := conn.Close(); err != nil { + t.Fatal(err) + } + closed = true + + err := <-errc + if err != nil { + t.Fatalf("Got ReceiveIPv4 error: %v. Log:\n%s", err, logBuf.Bytes()) + } +}