Compare commits
6 Commits
main
...
crawshaw/d
Author | SHA1 | Date |
---|---|---|
![]() |
efbb4f2b66 | |
![]() |
9cd899f83e | |
![]() |
656b6f3fd3 | |
![]() |
ef5b09563f | |
![]() |
4ad608ac92 | |
![]() |
7221b8eff5 |
|
@ -119,6 +119,7 @@ func main() {
|
||||||
letsEncrypt := tsweb.IsProd443(*addr)
|
letsEncrypt := tsweb.IsProd443(*addr)
|
||||||
|
|
||||||
s := derp.NewServer(key.Private(cfg.PrivateKey), log.Printf)
|
s := derp.NewServer(key.Private(cfg.PrivateKey), log.Printf)
|
||||||
|
s.WriteTimeout = 2 * time.Second
|
||||||
if *mbps != 0 {
|
if *mbps != 0 {
|
||||||
s.BytesPerSecond = (*mbps << 20) / 8
|
s.BytesPerSecond = (*mbps << 20) / 8
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,7 +11,6 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -26,7 +25,7 @@ type Client struct {
|
||||||
publicKey key.Public // of privateKey
|
publicKey key.Public // of privateKey
|
||||||
protoVersion int // min of server+client
|
protoVersion int // min of server+client
|
||||||
logf logger.Logf
|
logf logger.Logf
|
||||||
nc net.Conn
|
nc Conn
|
||||||
br *bufio.Reader
|
br *bufio.Reader
|
||||||
|
|
||||||
wmu sync.Mutex // hold while writing to bw
|
wmu sync.Mutex // hold while writing to bw
|
||||||
|
@ -34,7 +33,7 @@ type Client struct {
|
||||||
readErr error // sticky read error
|
readErr error // sticky read error
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(privateKey key.Private, nc net.Conn, brw *bufio.ReadWriter, logf logger.Logf) (*Client, error) {
|
func NewClient(privateKey key.Private, nc Conn, brw *bufio.ReadWriter, logf logger.Logf) (*Client, error) {
|
||||||
c := &Client{
|
c := &Client{
|
||||||
privateKey: privateKey,
|
privateKey: privateKey,
|
||||||
publicKey: privateKey.Public(),
|
publicKey: privateKey.Public(),
|
||||||
|
@ -138,7 +137,7 @@ func (c *Client) Send(dstKey key.Public, pkt []byte) error { return c.send(dstKe
|
||||||
func (c *Client) send(dstKey key.Public, pkt []byte) (ret error) {
|
func (c *Client) send(dstKey key.Public, pkt []byte) (ret error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if ret != nil {
|
if ret != nil {
|
||||||
ret = fmt.Errorf("derp.Send: %v", ret)
|
ret = fmt.Errorf("derp.Send: %w", ret)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -215,7 +214,7 @@ func (c *Client) Recv(b []byte) (m ReceivedMessage, err error) {
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("derp.Recv: %v", err)
|
err = fmt.Errorf("derp.Recv: %w", err)
|
||||||
c.readErr = err
|
c.readErr = err
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
|
@ -18,7 +18,6 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net"
|
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -39,6 +38,10 @@ type Server struct {
|
||||||
// second to cap per-client reads at.
|
// second to cap per-client reads at.
|
||||||
BytesPerSecond int
|
BytesPerSecond int
|
||||||
|
|
||||||
|
// WriteTimeout, if non-zero, specifies how long to wait
|
||||||
|
// before failing when writing to a client.
|
||||||
|
WriteTimeout time.Duration
|
||||||
|
|
||||||
privateKey key.Private
|
privateKey key.Private
|
||||||
publicKey key.Public
|
publicKey key.Public
|
||||||
logf logger.Logf
|
logf logger.Logf
|
||||||
|
@ -57,11 +60,23 @@ type Server struct {
|
||||||
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
closed bool
|
closed bool
|
||||||
netConns map[net.Conn]chan struct{} // chan is closed when conn closes
|
netConns map[Conn]chan struct{} // chan is closed when conn closes
|
||||||
clients map[key.Public]*sclient
|
clients map[key.Public]*sclient
|
||||||
clientsEver map[key.Public]bool // never deleted from, for stats; fine for now
|
clientsEver map[key.Public]bool // never deleted from, for stats; fine for now
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Conn is the subset of the underlying net.Conn the DERP Server needs.
|
||||||
|
// It is a defined type so that non-net connections can be used.
|
||||||
|
type Conn interface {
|
||||||
|
io.Closer
|
||||||
|
|
||||||
|
// The *Deadline methods follow the semantics of net.Conn.
|
||||||
|
|
||||||
|
SetDeadline(time.Time) error
|
||||||
|
SetReadDeadline(time.Time) error
|
||||||
|
SetWriteDeadline(time.Time) error
|
||||||
|
}
|
||||||
|
|
||||||
// NewServer returns a new DERP server. It doesn't listen on its own.
|
// NewServer returns a new DERP server. It doesn't listen on its own.
|
||||||
// Connections are given to it via Server.Accept.
|
// Connections are given to it via Server.Accept.
|
||||||
func NewServer(privateKey key.Private, logf logger.Logf) *Server {
|
func NewServer(privateKey key.Private, logf logger.Logf) *Server {
|
||||||
|
@ -71,7 +86,7 @@ func NewServer(privateKey key.Private, logf logger.Logf) *Server {
|
||||||
logf: logf,
|
logf: logf,
|
||||||
clients: make(map[key.Public]*sclient),
|
clients: make(map[key.Public]*sclient),
|
||||||
clientsEver: make(map[key.Public]bool),
|
clientsEver: make(map[key.Public]bool),
|
||||||
netConns: make(map[net.Conn]chan struct{}),
|
netConns: make(map[Conn]chan struct{}),
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
@ -115,7 +130,7 @@ func (s *Server) isClosed() bool {
|
||||||
// on its own.
|
// on its own.
|
||||||
//
|
//
|
||||||
// Accept closes nc.
|
// Accept closes nc.
|
||||||
func (s *Server) Accept(nc net.Conn, brw *bufio.ReadWriter) {
|
func (s *Server) Accept(nc Conn, brw *bufio.ReadWriter, remoteAddr string) {
|
||||||
closed := make(chan struct{})
|
closed := make(chan struct{})
|
||||||
|
|
||||||
s.accepts.Add(1)
|
s.accepts.Add(1)
|
||||||
|
@ -132,8 +147,8 @@ func (s *Server) Accept(nc net.Conn, brw *bufio.ReadWriter) {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err := s.accept(nc, brw); err != nil && !s.isClosed() {
|
if err := s.accept(nc, brw, remoteAddr); err != nil && !s.isClosed() {
|
||||||
s.logf("derp: %s: %v", nc.RemoteAddr(), err)
|
s.logf("derp: %s: %v", remoteAddr, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -147,8 +162,8 @@ func (s *Server) registerClient(c *sclient) {
|
||||||
c.logf("adding connection")
|
c.logf("adding connection")
|
||||||
} else {
|
} else {
|
||||||
s.clientsReplaced.Add(1)
|
s.clientsReplaced.Add(1)
|
||||||
old.nc.Close()
|
c.logf("adding connection, replacing %s", old.remoteAddr)
|
||||||
c.logf("adding connection, replacing %s", old.nc.RemoteAddr())
|
go old.nc.Close()
|
||||||
}
|
}
|
||||||
s.clients[c.key] = c
|
s.clients[c.key] = c
|
||||||
s.clientsEver[c.key] = true
|
s.clientsEver[c.key] = true
|
||||||
|
@ -171,7 +186,7 @@ func (s *Server) unregisterClient(c *sclient) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) accept(nc net.Conn, brw *bufio.ReadWriter) error {
|
func (s *Server) accept(nc Conn, brw *bufio.ReadWriter, remoteAddr string) error {
|
||||||
br, bw := brw.Reader, brw.Writer
|
br, bw := brw.Reader, brw.Writer
|
||||||
nc.SetDeadline(time.Now().Add(10 * time.Second))
|
nc.SetDeadline(time.Now().Add(10 * time.Second))
|
||||||
if err := s.sendServerKey(bw); err != nil {
|
if err := s.sendServerKey(bw); err != nil {
|
||||||
|
@ -203,7 +218,8 @@ func (s *Server) accept(nc net.Conn, brw *bufio.ReadWriter) error {
|
||||||
br: br,
|
br: br,
|
||||||
bw: bw,
|
bw: bw,
|
||||||
limiter: limiter,
|
limiter: limiter,
|
||||||
logf: logger.WithPrefix(s.logf, fmt.Sprintf("derp client %v/%x: ", nc.RemoteAddr(), clientKey)),
|
logf: logger.WithPrefix(s.logf, fmt.Sprintf("derp client %v/%x: ", remoteAddr, clientKey)),
|
||||||
|
remoteAddr: remoteAddr,
|
||||||
connectedAt: time.Now(),
|
connectedAt: time.Now(),
|
||||||
}
|
}
|
||||||
if clientInfo != nil {
|
if clientInfo != nil {
|
||||||
|
@ -294,6 +310,9 @@ func (c *sclient) handleFrameSendPacket(ctx context.Context, ft frameType, fl ui
|
||||||
}
|
}
|
||||||
|
|
||||||
dst.mu.Lock()
|
dst.mu.Lock()
|
||||||
|
if s.WriteTimeout != 0 {
|
||||||
|
dst.nc.SetWriteDeadline(time.Now().Add(s.WriteTimeout))
|
||||||
|
}
|
||||||
err = s.sendPacket(dst.bw, &dst.info, c.key, contents)
|
err = s.sendPacket(dst.bw, &dst.info, c.key, contents)
|
||||||
dst.mu.Unlock()
|
dst.mu.Unlock()
|
||||||
|
|
||||||
|
@ -308,7 +327,9 @@ func (c *sclient) handleFrameSendPacket(ctx context.Context, ft frameType, fl ui
|
||||||
}
|
}
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
}
|
}
|
||||||
return err
|
|
||||||
|
// Do not treat a send error as an error with this transmitting client.
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) sendClientKeepAlives(ctx context.Context, c *sclient) {
|
func (s *Server) sendClientKeepAlives(ctx context.Context, c *sclient) {
|
||||||
|
@ -450,11 +471,12 @@ func (s *Server) recvPacket(ctx context.Context, br *bufio.Reader, frameLen uint
|
||||||
// (The "s" prefix is to more explicitly distinguish it from Client in derp_client.go)
|
// (The "s" prefix is to more explicitly distinguish it from Client in derp_client.go)
|
||||||
type sclient struct {
|
type sclient struct {
|
||||||
s *Server
|
s *Server
|
||||||
nc net.Conn
|
nc Conn
|
||||||
key key.Public
|
key key.Public
|
||||||
info clientInfo
|
info clientInfo
|
||||||
logf logger.Logf
|
logf logger.Logf
|
||||||
limiter *rate.Limiter
|
limiter *rate.Limiter
|
||||||
|
remoteAddr string // usually ip:port from net.Conn.RemoteAddr().String()
|
||||||
connectedAt time.Time
|
connectedAt time.Time
|
||||||
|
|
||||||
keepAliveTimer *time.Timer
|
keepAliveTimer *time.Timer
|
||||||
|
@ -512,6 +534,9 @@ func (c *sclient) keepAliveLoop(ctx context.Context) error {
|
||||||
c.keepAliveTimer.Reset(keepAlive + jitter)
|
c.keepAliveTimer.Reset(keepAlive + jitter)
|
||||||
case <-c.keepAliveTimer.C:
|
case <-c.keepAliveTimer.C:
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
|
if c.s.WriteTimeout != 0 {
|
||||||
|
c.nc.SetWriteDeadline(time.Now().Add(c.s.WriteTimeout))
|
||||||
|
}
|
||||||
err := writeFrame(c.bw, frameKeepAlive, nil)
|
err := writeFrame(c.bw, frameKeepAlive, nil)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = c.bw.Flush()
|
err = c.bw.Flush()
|
||||||
|
|
|
@ -6,15 +6,22 @@ package derp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"context"
|
||||||
crand "crypto/rand"
|
crand "crypto/rand"
|
||||||
|
"errors"
|
||||||
|
"expvar"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"tailscale.com/net/nettest"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newPrivateKey(t *testing.T) (k key.Private) {
|
func newPrivateKey(t *testing.T) (k key.Private) {
|
||||||
|
t.Helper()
|
||||||
if _, err := crand.Read(k[:]); err != nil {
|
if _, err := crand.Read(k[:]); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -42,7 +49,7 @@ func TestSendRecv(t *testing.T) {
|
||||||
defer ln.Close()
|
defer ln.Close()
|
||||||
|
|
||||||
var clients []*Client
|
var clients []*Client
|
||||||
var connsOut []net.Conn
|
var connsOut []Conn
|
||||||
var recvChs []chan []byte
|
var recvChs []chan []byte
|
||||||
errCh := make(chan error, 3)
|
errCh := make(chan error, 3)
|
||||||
|
|
||||||
|
@ -60,7 +67,8 @@ func TestSendRecv(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer cin.Close()
|
defer cin.Close()
|
||||||
go s.Accept(cin, bufio.NewReadWriter(bufio.NewReader(cin), bufio.NewWriter(cin)))
|
brwServer := bufio.NewReadWriter(bufio.NewReader(cin), bufio.NewWriter(cin))
|
||||||
|
go s.Accept(cin, brwServer, fmt.Sprintf("test-client-%d", i))
|
||||||
|
|
||||||
key := clientPrivateKeys[i]
|
key := clientPrivateKeys[i]
|
||||||
brw := bufio.NewReadWriter(bufio.NewReader(cout), bufio.NewWriter(cout))
|
brw := bufio.NewReadWriter(bufio.NewReader(cout), bufio.NewWriter(cout))
|
||||||
|
@ -170,3 +178,168 @@ func TestSendRecv(t *testing.T) {
|
||||||
t.Logf("passed")
|
t.Logf("passed")
|
||||||
s.Close()
|
s.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSendFreeze(t *testing.T) {
|
||||||
|
serverPrivateKey := newPrivateKey(t)
|
||||||
|
s := NewServer(serverPrivateKey, t.Logf)
|
||||||
|
defer s.Close()
|
||||||
|
s.WriteTimeout = 100 * time.Millisecond
|
||||||
|
|
||||||
|
// We send two streams of messages:
|
||||||
|
//
|
||||||
|
// alice --> bob
|
||||||
|
// alice --> cathy
|
||||||
|
//
|
||||||
|
// Then cathy stops processing messsages.
|
||||||
|
// That should not interfere with alice talking to bob.
|
||||||
|
|
||||||
|
newClient := func(name string, k key.Private) (c *Client, clientConn nettest.Conn) {
|
||||||
|
t.Helper()
|
||||||
|
c1, c2 := nettest.NewConn(name, 1024)
|
||||||
|
go s.Accept(c1, bufio.NewReadWriter(bufio.NewReader(c1), bufio.NewWriter(c1)), name)
|
||||||
|
|
||||||
|
brw := bufio.NewReadWriter(bufio.NewReader(c2), bufio.NewWriter(c2))
|
||||||
|
c, err := NewClient(k, c2, brw, t.Logf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return c, c2
|
||||||
|
}
|
||||||
|
|
||||||
|
aliceKey := newPrivateKey(t)
|
||||||
|
aliceClient, aliceConn := newClient("alice", aliceKey)
|
||||||
|
|
||||||
|
bobKey := newPrivateKey(t)
|
||||||
|
bobClient, bobConn := newClient("bob", bobKey)
|
||||||
|
|
||||||
|
cathyKey := newPrivateKey(t)
|
||||||
|
cathyClient, cathyConn := newClient("cathy", cathyKey)
|
||||||
|
|
||||||
|
var aliceCount, bobCount, cathyCount expvar.Int
|
||||||
|
|
||||||
|
errCh := make(chan error, 4)
|
||||||
|
recvAndCount := func(count *expvar.Int, name string, client *Client) {
|
||||||
|
for {
|
||||||
|
b := make([]byte, 1<<9)
|
||||||
|
m, err := client.Recv(b)
|
||||||
|
if err != nil {
|
||||||
|
errCh <- fmt.Errorf("%s: %w", name, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch m := m.(type) {
|
||||||
|
default:
|
||||||
|
errCh <- fmt.Errorf("%s: unexpected message type %T", name, m)
|
||||||
|
return
|
||||||
|
case ReceivedPacket:
|
||||||
|
if m.Source.IsZero() {
|
||||||
|
errCh <- fmt.Errorf("%s: zero Source address in ReceivedPacket", name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
count.Add(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
go recvAndCount(&aliceCount, "alice", aliceClient)
|
||||||
|
go recvAndCount(&bobCount, "bob", bobClient)
|
||||||
|
go recvAndCount(&cathyCount, "cathy", cathyClient)
|
||||||
|
|
||||||
|
var cancel func()
|
||||||
|
go func() {
|
||||||
|
t := time.NewTicker(2 * time.Millisecond)
|
||||||
|
defer t.Stop()
|
||||||
|
var ctx context.Context
|
||||||
|
ctx, cancel = context.WithCancel(context.Background())
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-t.C:
|
||||||
|
case <-ctx.Done():
|
||||||
|
errCh <- nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
msg1 := []byte("hello alice->bob\n")
|
||||||
|
if err := aliceClient.Send(bobKey.Public(), msg1); err != nil {
|
||||||
|
errCh <- fmt.Errorf("alice send to bob: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
msg2 := []byte("hello alice->cathy\n")
|
||||||
|
|
||||||
|
// TODO: an error is expected here.
|
||||||
|
// We ignore it, maybe we should log it somehow?
|
||||||
|
aliceClient.Send(cathyKey.Public(), msg2)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var countSnapshot [3]int64
|
||||||
|
loadCounts := func() (adiff, bdiff, cdiff int64) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
atotal := aliceCount.Value()
|
||||||
|
btotal := bobCount.Value()
|
||||||
|
ctotal := cathyCount.Value()
|
||||||
|
|
||||||
|
adiff = atotal - countSnapshot[0]
|
||||||
|
bdiff = btotal - countSnapshot[1]
|
||||||
|
cdiff = ctotal - countSnapshot[2]
|
||||||
|
|
||||||
|
countSnapshot[0] = atotal
|
||||||
|
countSnapshot[1] = btotal
|
||||||
|
countSnapshot[2] = ctotal
|
||||||
|
|
||||||
|
t.Logf("count diffs: alice=%d, bob=%d, cathy=%d", adiff, bdiff, cdiff)
|
||||||
|
return adiff, bdiff, cdiff
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("initial send", func(t *testing.T) {
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
a, b, c := loadCounts()
|
||||||
|
if a != 0 {
|
||||||
|
t.Errorf("alice diff=%d, want 0", a)
|
||||||
|
}
|
||||||
|
if b == 0 {
|
||||||
|
t.Errorf("no bob diff, want positive value")
|
||||||
|
}
|
||||||
|
if c == 0 {
|
||||||
|
t.Errorf("no cathy diff, want positive value")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("block cathy", func(t *testing.T) {
|
||||||
|
// Block cathy. Now the cathyConn buffer will fill up quickly,
|
||||||
|
// and the derp server will back up.
|
||||||
|
cathyConn.SetReadBlock(true)
|
||||||
|
time.Sleep(2 * s.WriteTimeout)
|
||||||
|
|
||||||
|
a, b, _ := loadCounts()
|
||||||
|
if a != 0 {
|
||||||
|
t.Errorf("alice diff=%d, want 0", a)
|
||||||
|
}
|
||||||
|
if b == 0 {
|
||||||
|
t.Errorf("no bob diff, want positive value")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now wait a little longer, and ensure packets still flow to bob
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
if _, b, _ := loadCounts(); b == 0 {
|
||||||
|
t.Errorf("connection alice->bob frozen by alice->cathy")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Cleanup, make sure we process all errors.
|
||||||
|
t.Logf("TEST COMPLETE, cancelling sender")
|
||||||
|
cancel()
|
||||||
|
t.Logf("closing connections")
|
||||||
|
aliceConn.Close()
|
||||||
|
bobConn.Close()
|
||||||
|
cathyConn.Close()
|
||||||
|
|
||||||
|
for i := 0; i < cap(errCh); i++ {
|
||||||
|
err := <-errCh
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -32,6 +32,6 @@ func Handler(s *derp.Server) http.Handler {
|
||||||
http.Error(w, "HTTP does not support general TCP support", 500)
|
http.Error(w, "HTTP does not support general TCP support", 500)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.Accept(netConn, conn)
|
s.Accept(netConn, conn, netConn.RemoteAddr().String())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,86 @@
|
||||||
|
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package nettest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Conn is a bi-directional in-memory stream that looks like a TCP net.Conn.
|
||||||
|
type Conn interface {
|
||||||
|
io.Reader
|
||||||
|
io.Writer
|
||||||
|
io.Closer
|
||||||
|
|
||||||
|
// The *Deadline methods follow the semantics of net.Conn.
|
||||||
|
|
||||||
|
SetDeadline(t time.Time) error
|
||||||
|
SetReadDeadline(t time.Time) error
|
||||||
|
SetWriteDeadline(t time.Time) error
|
||||||
|
|
||||||
|
// SetReadBlock blocks or unblocks the Read method of this Conn.
|
||||||
|
// It reports an error if the existing value matches the new value,
|
||||||
|
// or if the Conn has been Closed.
|
||||||
|
SetReadBlock(bool) error
|
||||||
|
|
||||||
|
// SetWriteBlock blocks or unblocks the Write method of this Conn.
|
||||||
|
// It reports an error if the existing value matches the new value,
|
||||||
|
// or if the Conn has been Closed.
|
||||||
|
SetWriteBlock(bool) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewConn creates a pair of Conns that are wired together by pipes.
|
||||||
|
func NewConn(name string, maxBuf int) (Conn, Conn) {
|
||||||
|
r := NewPipe(name+"|0", maxBuf)
|
||||||
|
w := NewPipe(name+"|1", maxBuf)
|
||||||
|
|
||||||
|
return &connHalf{r: r, w: w}, &connHalf{r: w, w: r}
|
||||||
|
}
|
||||||
|
|
||||||
|
type connHalf struct {
|
||||||
|
r, w *Pipe
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connHalf) Read(b []byte) (n int, err error) {
|
||||||
|
return c.r.Read(b)
|
||||||
|
}
|
||||||
|
func (c *connHalf) Write(b []byte) (n int, err error) {
|
||||||
|
return c.w.Write(b)
|
||||||
|
}
|
||||||
|
func (c *connHalf) Close() error {
|
||||||
|
err1 := c.r.Close()
|
||||||
|
err2 := c.w.Close()
|
||||||
|
if err1 != nil {
|
||||||
|
return err1
|
||||||
|
}
|
||||||
|
return err2
|
||||||
|
}
|
||||||
|
func (c *connHalf) SetDeadline(t time.Time) error {
|
||||||
|
err1 := c.SetReadDeadline(t)
|
||||||
|
err2 := c.SetWriteDeadline(t)
|
||||||
|
if err1 != nil {
|
||||||
|
return err1
|
||||||
|
}
|
||||||
|
return err2
|
||||||
|
}
|
||||||
|
func (c *connHalf) SetReadDeadline(t time.Time) error {
|
||||||
|
return c.r.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
func (c *connHalf) SetWriteDeadline(t time.Time) error {
|
||||||
|
return c.w.SetWriteDeadline(t)
|
||||||
|
}
|
||||||
|
func (c *connHalf) SetReadBlock(b bool) error {
|
||||||
|
if b {
|
||||||
|
return c.r.Block()
|
||||||
|
}
|
||||||
|
return c.r.Unblock()
|
||||||
|
}
|
||||||
|
func (c *connHalf) SetWriteBlock(b bool) error {
|
||||||
|
if b {
|
||||||
|
return c.w.Block()
|
||||||
|
}
|
||||||
|
return c.w.Unblock()
|
||||||
|
}
|
|
@ -0,0 +1,261 @@
|
||||||
|
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package nettest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const debugPipe = false
|
||||||
|
|
||||||
|
// Pipe implements an in-memory FIFO with timeouts.
|
||||||
|
type Pipe struct {
|
||||||
|
name string
|
||||||
|
maxBuf int
|
||||||
|
rCh chan struct{}
|
||||||
|
wCh chan struct{}
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
closed bool
|
||||||
|
blocked bool
|
||||||
|
buf []byte
|
||||||
|
readTimeout time.Time
|
||||||
|
writeTimeout time.Time
|
||||||
|
cancelReadTimer func()
|
||||||
|
cancelWriteTimer func()
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPipe creates a Pipe with a buffer size fixed at maxBuf.
|
||||||
|
func NewPipe(name string, maxBuf int) *Pipe {
|
||||||
|
return &Pipe{
|
||||||
|
name: name,
|
||||||
|
maxBuf: maxBuf,
|
||||||
|
rCh: make(chan struct{}, 1),
|
||||||
|
wCh: make(chan struct{}, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrTimeout = errors.New("timeout")
|
||||||
|
ErrReadTimeout = fmt.Errorf("read %w", ErrTimeout)
|
||||||
|
ErrWriteTimeout = fmt.Errorf("write %w", ErrTimeout)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Read implements io.Reader.
|
||||||
|
func (p *Pipe) Read(b []byte) (n int, err error) {
|
||||||
|
if debugPipe {
|
||||||
|
orig := b
|
||||||
|
defer func() {
|
||||||
|
log.Printf("Pipe(%q).Read( %q) n=%d, err=%v", p.name, string(orig[:n]), n, err)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
p.mu.Lock()
|
||||||
|
closed := p.closed
|
||||||
|
timedout := !p.readTimeout.IsZero() && time.Now().After(p.readTimeout)
|
||||||
|
blocked := p.blocked
|
||||||
|
if !closed && !timedout && len(p.buf) > 0 {
|
||||||
|
n2 := copy(b, p.buf)
|
||||||
|
p.buf = p.buf[n2:]
|
||||||
|
b = b[n2:]
|
||||||
|
n += n2
|
||||||
|
}
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
if closed {
|
||||||
|
return 0, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF)
|
||||||
|
}
|
||||||
|
if timedout {
|
||||||
|
return 0, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrReadTimeout)
|
||||||
|
}
|
||||||
|
if blocked {
|
||||||
|
<-p.rCh
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if n > 0 {
|
||||||
|
p.signalWrite()
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
<-p.rCh
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write implements io.Writer.
|
||||||
|
func (p *Pipe) Write(b []byte) (n int, err error) {
|
||||||
|
if debugPipe {
|
||||||
|
orig := b
|
||||||
|
defer func() {
|
||||||
|
log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(orig), n, err)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
p.mu.Lock()
|
||||||
|
closed := p.closed
|
||||||
|
timedout := !p.writeTimeout.IsZero() && time.Now().After(p.writeTimeout)
|
||||||
|
blocked := p.blocked
|
||||||
|
if !closed && !timedout {
|
||||||
|
n2 := len(b)
|
||||||
|
if limit := p.maxBuf - len(p.buf); limit < n2 {
|
||||||
|
n2 = limit
|
||||||
|
}
|
||||||
|
p.buf = append(p.buf, b[:n2]...)
|
||||||
|
b = b[n2:]
|
||||||
|
n += n2
|
||||||
|
}
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
if closed {
|
||||||
|
return n, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF)
|
||||||
|
}
|
||||||
|
if timedout {
|
||||||
|
return n, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrWriteTimeout)
|
||||||
|
}
|
||||||
|
if blocked {
|
||||||
|
<-p.wCh
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if n > 0 {
|
||||||
|
p.signalRead()
|
||||||
|
}
|
||||||
|
if len(b) == 0 {
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
<-p.wCh
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close implements io.Closer.
|
||||||
|
func (p *Pipe) Close() error {
|
||||||
|
p.mu.Lock()
|
||||||
|
closed := p.closed
|
||||||
|
p.closed = true
|
||||||
|
if p.cancelWriteTimer != nil {
|
||||||
|
p.cancelWriteTimer()
|
||||||
|
p.cancelWriteTimer = nil
|
||||||
|
}
|
||||||
|
if p.cancelReadTimer != nil {
|
||||||
|
p.cancelReadTimer()
|
||||||
|
p.cancelReadTimer = nil
|
||||||
|
}
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
if closed {
|
||||||
|
return fmt.Errorf("nettest.Pipe(%q).Close: already closed", p.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.signalRead()
|
||||||
|
p.signalWrite()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReadDeadline sets the deadline for future Read calls.
|
||||||
|
func (p *Pipe) SetReadDeadline(t time.Time) error {
|
||||||
|
p.mu.Lock()
|
||||||
|
p.readTimeout = t
|
||||||
|
if p.cancelReadTimer != nil {
|
||||||
|
p.cancelReadTimer()
|
||||||
|
p.cancelReadTimer = nil
|
||||||
|
}
|
||||||
|
if d := time.Until(t); !t.IsZero() && d > 0 {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
p.cancelReadTimer = cancel
|
||||||
|
go func() {
|
||||||
|
t := time.NewTimer(d)
|
||||||
|
defer t.Stop()
|
||||||
|
select {
|
||||||
|
case <-t.C:
|
||||||
|
p.signalRead()
|
||||||
|
case <-ctx.Done():
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
p.signalRead()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetWriteDeadline sets the deadline for future Write calls.
|
||||||
|
func (p *Pipe) SetWriteDeadline(t time.Time) error {
|
||||||
|
p.mu.Lock()
|
||||||
|
p.writeTimeout = t
|
||||||
|
if p.cancelWriteTimer != nil {
|
||||||
|
p.cancelWriteTimer()
|
||||||
|
p.cancelWriteTimer = nil
|
||||||
|
}
|
||||||
|
if d := time.Until(t); !t.IsZero() && d > 0 {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
p.cancelWriteTimer = cancel
|
||||||
|
go func() {
|
||||||
|
t := time.NewTimer(d)
|
||||||
|
defer t.Stop()
|
||||||
|
select {
|
||||||
|
case <-t.C:
|
||||||
|
p.signalWrite()
|
||||||
|
case <-ctx.Done():
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
p.signalWrite()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pipe) Block() error {
|
||||||
|
p.mu.Lock()
|
||||||
|
closed := p.closed
|
||||||
|
blocked := p.blocked
|
||||||
|
p.blocked = true
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
if closed {
|
||||||
|
return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name)
|
||||||
|
}
|
||||||
|
if blocked {
|
||||||
|
return fmt.Errorf("nettest.Pipe(%q).Block: already blocked", p.name)
|
||||||
|
}
|
||||||
|
p.signalRead()
|
||||||
|
p.signalWrite()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pipe) Unblock() error {
|
||||||
|
p.mu.Lock()
|
||||||
|
closed := p.closed
|
||||||
|
blocked := p.blocked
|
||||||
|
p.blocked = false
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
if closed {
|
||||||
|
return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name)
|
||||||
|
}
|
||||||
|
if !blocked {
|
||||||
|
return fmt.Errorf("nettest.Pipe(%q).Block: already unblocked", p.name)
|
||||||
|
}
|
||||||
|
p.signalRead()
|
||||||
|
p.signalWrite()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pipe) signalRead() {
|
||||||
|
select {
|
||||||
|
case p.rCh <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pipe) signalWrite() {
|
||||||
|
select {
|
||||||
|
case p.wCh <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,116 @@
|
||||||
|
package nettest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPipeHello(t *testing.T) {
|
||||||
|
p := NewPipe("p1", 1<<16)
|
||||||
|
msg := "Hello, World!"
|
||||||
|
if n, err := p.Write([]byte(msg)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
} else if n != len(msg) {
|
||||||
|
t.Errorf("p.Write(%q) n=%d, want %d", msg, n, len(msg))
|
||||||
|
}
|
||||||
|
b := make([]byte, len(msg))
|
||||||
|
if n, err := p.Read(b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
} else if n != len(b) {
|
||||||
|
t.Errorf("p.Read(%q) n=%d, want %d", string(b[:n]), n, len(b))
|
||||||
|
}
|
||||||
|
if got := string(b); got != msg {
|
||||||
|
t.Errorf("p.Read: %q, want %q", got, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPipeTimeout(t *testing.T) {
|
||||||
|
t.Run("write", func(t *testing.T) {
|
||||||
|
p := NewPipe("p1", 1<<16)
|
||||||
|
p.SetWriteDeadline(time.Now().Add(-1 * time.Second))
|
||||||
|
n, err := p.Write([]byte{'h'})
|
||||||
|
if err == nil || !errors.Is(err, ErrWriteTimeout) || !errors.Is(err, ErrTimeout) {
|
||||||
|
t.Errorf("missing write timeout got err: %v", err)
|
||||||
|
}
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("n=%d on timeout", n)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("read", func(t *testing.T) {
|
||||||
|
p := NewPipe("p1", 1<<16)
|
||||||
|
p.Write([]byte{'h'})
|
||||||
|
|
||||||
|
p.SetReadDeadline(time.Now().Add(-1 * time.Second))
|
||||||
|
b := make([]byte, 1)
|
||||||
|
n, err := p.Read(b)
|
||||||
|
if err == nil || !errors.Is(err, ErrReadTimeout) || !errors.Is(err, ErrTimeout) {
|
||||||
|
t.Errorf("missing read timeout got err: %v", err)
|
||||||
|
}
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("n=%d on timeout", n)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("block-write", func(t *testing.T) {
|
||||||
|
p := NewPipe("p1", 1<<16)
|
||||||
|
p.SetWriteDeadline(time.Now().Add(10 * time.Millisecond))
|
||||||
|
if _, err := p.Write([]byte{'h'}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := p.Block(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := p.Write([]byte{'h'}); err == nil || !errors.Is(err, ErrWriteTimeout) {
|
||||||
|
t.Fatalf("want write timeout got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("block-read", func(t *testing.T) {
|
||||||
|
p := NewPipe("p1", 1<<16)
|
||||||
|
p.Write([]byte{'h', 'i'})
|
||||||
|
p.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
|
||||||
|
b := make([]byte, 1)
|
||||||
|
if _, err := p.Read(b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := p.Block(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := p.Read(b); err == nil || !errors.Is(err, ErrReadTimeout) {
|
||||||
|
t.Fatalf("want read timeout got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLimit(t *testing.T) {
|
||||||
|
p := NewPipe("p1", 1)
|
||||||
|
errCh := make(chan error)
|
||||||
|
go func() {
|
||||||
|
n, err := p.Write([]byte{'a', 'b', 'c'})
|
||||||
|
if err != nil {
|
||||||
|
errCh <- err
|
||||||
|
} else if n != 3 {
|
||||||
|
errCh <- fmt.Errorf("p.Write n=%d, want 3", n)
|
||||||
|
} else {
|
||||||
|
errCh <- nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
b := make([]byte, 3)
|
||||||
|
|
||||||
|
if n, err := p.Read(b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
} else if n != 1 {
|
||||||
|
t.Errorf("Read(%q): n=%d want 1", string(b), n)
|
||||||
|
}
|
||||||
|
if n, err := p.Read(b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
} else if n != 1 {
|
||||||
|
t.Errorf("Read(%q): n=%d want 1", string(b), n)
|
||||||
|
}
|
||||||
|
if n, err := p.Read(b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
} else if n != 1 {
|
||||||
|
t.Errorf("Read(%q): n=%d want 1", string(b), n)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue