diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 4890ea70a..26c157d16 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -11,6 +11,7 @@ import ( "crypto/tls" "encoding/binary" "encoding/json" + "errors" "fmt" "io/ioutil" "net" @@ -927,30 +928,45 @@ func testTwoDevicePing(t *testing.T, d *devices) { // Retries take 5s each. Add 1s for some processing time. pingTimeout := 5*time.Second*time.Duration(allowedRetries) + time.Second + // sendWithTimeout sends msg using send, checking that it is received unchanged from in. + // It resends once per second until the send succeeds, or pingTimeout time has elapsed. + sendWithTimeout := func(msg []byte, in chan []byte, send func()) error { + start := time.Now() + for time.Since(start) < pingTimeout { + send() + select { + case recv := <-in: + if !bytes.Equal(msg, recv) { + return errors.New("ping did not transit correctly") + } + return nil + case <-time.After(time.Second): + // try again + } + } + return errors.New("ping timed out") + } + ping1 := func(t *testing.T) { msg2to1 := tuntest.Ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2")) - m2.tun.Outbound <- msg2to1 - t.Log("ping1 sent") - select { - case msgRecv := <-m1.tun.Inbound: - if !bytes.Equal(msg2to1, msgRecv) { - t.Error("ping did not transit correctly") - } - case <-time.After(pingTimeout): - t.Error("ping did not transit") + send := func() { + m2.tun.Outbound <- msg2to1 + t.Log("ping1 sent") + } + in := m1.tun.Inbound + if err := sendWithTimeout(msg2to1, in, send); err != nil { + t.Error(err) } } ping2 := func(t *testing.T) { msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1")) - m1.tun.Outbound <- msg1to2 - t.Log("ping2 sent") - select { - case msgRecv := <-m2.tun.Inbound: - if !bytes.Equal(msg1to2, msgRecv) { - t.Error("return ping did not transit correctly") - } - case <-time.After(pingTimeout): - t.Error("return ping did not transit") + send := func() { + m1.tun.Outbound <- msg1to2 + t.Log("ping2 sent") + } + in := m2.tun.Inbound + if err := sendWithTimeout(msg1to2, in, send); err != nil { + t.Error(err) } } @@ -971,17 +987,15 @@ func testTwoDevicePing(t *testing.T, d *devices) { setT(t) defer setT(outerT) msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1")) - if err := m1.tsTun.InjectOutbound(msg1to2); err != nil { - t.Fatal(err) - } - t.Log("SendPacket sent") - select { - case msgRecv := <-m2.tun.Inbound: - if !bytes.Equal(msg1to2, msgRecv) { - t.Error("return ping did not transit correctly") + send := func() { + if err := m1.tsTun.InjectOutbound(msg1to2); err != nil { + t.Fatal(err) } - case <-time.After(pingTimeout): - t.Error("return ping did not transit") + t.Log("SendPacket sent") + } + in := m2.tun.Inbound + if err := sendWithTimeout(msg1to2, in, send); err != nil { + t.Error(err) } })