Compare commits
6 Commits
main
...
crawshaw/d
Author | SHA1 | Date |
---|---|---|
![]() |
efbb4f2b66 | |
![]() |
9cd899f83e | |
![]() |
656b6f3fd3 | |
![]() |
ef5b09563f | |
![]() |
4ad608ac92 | |
![]() |
7221b8eff5 |
|
@ -119,6 +119,7 @@ func main() {
|
|||
letsEncrypt := tsweb.IsProd443(*addr)
|
||||
|
||||
s := derp.NewServer(key.Private(cfg.PrivateKey), log.Printf)
|
||||
s.WriteTimeout = 2 * time.Second
|
||||
if *mbps != 0 {
|
||||
s.BytesPerSecond = (*mbps << 20) / 8
|
||||
}
|
||||
|
|
|
@ -11,7 +11,6 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -26,7 +25,7 @@ type Client struct {
|
|||
publicKey key.Public // of privateKey
|
||||
protoVersion int // min of server+client
|
||||
logf logger.Logf
|
||||
nc net.Conn
|
||||
nc Conn
|
||||
br *bufio.Reader
|
||||
|
||||
wmu sync.Mutex // hold while writing to bw
|
||||
|
@ -34,7 +33,7 @@ type Client struct {
|
|||
readErr error // sticky read error
|
||||
}
|
||||
|
||||
func NewClient(privateKey key.Private, nc net.Conn, brw *bufio.ReadWriter, logf logger.Logf) (*Client, error) {
|
||||
func NewClient(privateKey key.Private, nc Conn, brw *bufio.ReadWriter, logf logger.Logf) (*Client, error) {
|
||||
c := &Client{
|
||||
privateKey: privateKey,
|
||||
publicKey: privateKey.Public(),
|
||||
|
@ -138,7 +137,7 @@ func (c *Client) Send(dstKey key.Public, pkt []byte) error { return c.send(dstKe
|
|||
func (c *Client) send(dstKey key.Public, pkt []byte) (ret error) {
|
||||
defer func() {
|
||||
if ret != nil {
|
||||
ret = fmt.Errorf("derp.Send: %v", ret)
|
||||
ret = fmt.Errorf("derp.Send: %w", ret)
|
||||
}
|
||||
}()
|
||||
|
||||
|
@ -215,7 +214,7 @@ func (c *Client) Recv(b []byte) (m ReceivedMessage, err error) {
|
|||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = fmt.Errorf("derp.Recv: %v", err)
|
||||
err = fmt.Errorf("derp.Recv: %w", err)
|
||||
c.readErr = err
|
||||
}
|
||||
}()
|
||||
|
|
|
@ -18,7 +18,6 @@ import (
|
|||
"io"
|
||||
"io/ioutil"
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
@ -39,6 +38,10 @@ type Server struct {
|
|||
// second to cap per-client reads at.
|
||||
BytesPerSecond int
|
||||
|
||||
// WriteTimeout, if non-zero, specifies how long to wait
|
||||
// before failing when writing to a client.
|
||||
WriteTimeout time.Duration
|
||||
|
||||
privateKey key.Private
|
||||
publicKey key.Public
|
||||
logf logger.Logf
|
||||
|
@ -57,11 +60,23 @@ type Server struct {
|
|||
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
netConns map[net.Conn]chan struct{} // chan is closed when conn closes
|
||||
netConns map[Conn]chan struct{} // chan is closed when conn closes
|
||||
clients map[key.Public]*sclient
|
||||
clientsEver map[key.Public]bool // never deleted from, for stats; fine for now
|
||||
}
|
||||
|
||||
// Conn is the subset of the underlying net.Conn the DERP Server needs.
|
||||
// It is a defined type so that non-net connections can be used.
|
||||
type Conn interface {
|
||||
io.Closer
|
||||
|
||||
// The *Deadline methods follow the semantics of net.Conn.
|
||||
|
||||
SetDeadline(time.Time) error
|
||||
SetReadDeadline(time.Time) error
|
||||
SetWriteDeadline(time.Time) error
|
||||
}
|
||||
|
||||
// NewServer returns a new DERP server. It doesn't listen on its own.
|
||||
// Connections are given to it via Server.Accept.
|
||||
func NewServer(privateKey key.Private, logf logger.Logf) *Server {
|
||||
|
@ -71,7 +86,7 @@ func NewServer(privateKey key.Private, logf logger.Logf) *Server {
|
|||
logf: logf,
|
||||
clients: make(map[key.Public]*sclient),
|
||||
clientsEver: make(map[key.Public]bool),
|
||||
netConns: make(map[net.Conn]chan struct{}),
|
||||
netConns: make(map[Conn]chan struct{}),
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
@ -115,7 +130,7 @@ func (s *Server) isClosed() bool {
|
|||
// on its own.
|
||||
//
|
||||
// Accept closes nc.
|
||||
func (s *Server) Accept(nc net.Conn, brw *bufio.ReadWriter) {
|
||||
func (s *Server) Accept(nc Conn, brw *bufio.ReadWriter, remoteAddr string) {
|
||||
closed := make(chan struct{})
|
||||
|
||||
s.accepts.Add(1)
|
||||
|
@ -132,8 +147,8 @@ func (s *Server) Accept(nc net.Conn, brw *bufio.ReadWriter) {
|
|||
s.mu.Unlock()
|
||||
}()
|
||||
|
||||
if err := s.accept(nc, brw); err != nil && !s.isClosed() {
|
||||
s.logf("derp: %s: %v", nc.RemoteAddr(), err)
|
||||
if err := s.accept(nc, brw, remoteAddr); err != nil && !s.isClosed() {
|
||||
s.logf("derp: %s: %v", remoteAddr, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -147,8 +162,8 @@ func (s *Server) registerClient(c *sclient) {
|
|||
c.logf("adding connection")
|
||||
} else {
|
||||
s.clientsReplaced.Add(1)
|
||||
old.nc.Close()
|
||||
c.logf("adding connection, replacing %s", old.nc.RemoteAddr())
|
||||
c.logf("adding connection, replacing %s", old.remoteAddr)
|
||||
go old.nc.Close()
|
||||
}
|
||||
s.clients[c.key] = c
|
||||
s.clientsEver[c.key] = true
|
||||
|
@ -171,7 +186,7 @@ func (s *Server) unregisterClient(c *sclient) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Server) accept(nc net.Conn, brw *bufio.ReadWriter) error {
|
||||
func (s *Server) accept(nc Conn, brw *bufio.ReadWriter, remoteAddr string) error {
|
||||
br, bw := brw.Reader, brw.Writer
|
||||
nc.SetDeadline(time.Now().Add(10 * time.Second))
|
||||
if err := s.sendServerKey(bw); err != nil {
|
||||
|
@ -203,7 +218,8 @@ func (s *Server) accept(nc net.Conn, brw *bufio.ReadWriter) error {
|
|||
br: br,
|
||||
bw: bw,
|
||||
limiter: limiter,
|
||||
logf: logger.WithPrefix(s.logf, fmt.Sprintf("derp client %v/%x: ", nc.RemoteAddr(), clientKey)),
|
||||
logf: logger.WithPrefix(s.logf, fmt.Sprintf("derp client %v/%x: ", remoteAddr, clientKey)),
|
||||
remoteAddr: remoteAddr,
|
||||
connectedAt: time.Now(),
|
||||
}
|
||||
if clientInfo != nil {
|
||||
|
@ -294,6 +310,9 @@ func (c *sclient) handleFrameSendPacket(ctx context.Context, ft frameType, fl ui
|
|||
}
|
||||
|
||||
dst.mu.Lock()
|
||||
if s.WriteTimeout != 0 {
|
||||
dst.nc.SetWriteDeadline(time.Now().Add(s.WriteTimeout))
|
||||
}
|
||||
err = s.sendPacket(dst.bw, &dst.info, c.key, contents)
|
||||
dst.mu.Unlock()
|
||||
|
||||
|
@ -308,7 +327,9 @@ func (c *sclient) handleFrameSendPacket(ctx context.Context, ft frameType, fl ui
|
|||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
return err
|
||||
|
||||
// Do not treat a send error as an error with this transmitting client.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) sendClientKeepAlives(ctx context.Context, c *sclient) {
|
||||
|
@ -450,11 +471,12 @@ func (s *Server) recvPacket(ctx context.Context, br *bufio.Reader, frameLen uint
|
|||
// (The "s" prefix is to more explicitly distinguish it from Client in derp_client.go)
|
||||
type sclient struct {
|
||||
s *Server
|
||||
nc net.Conn
|
||||
nc Conn
|
||||
key key.Public
|
||||
info clientInfo
|
||||
logf logger.Logf
|
||||
limiter *rate.Limiter
|
||||
remoteAddr string // usually ip:port from net.Conn.RemoteAddr().String()
|
||||
connectedAt time.Time
|
||||
|
||||
keepAliveTimer *time.Timer
|
||||
|
@ -512,6 +534,9 @@ func (c *sclient) keepAliveLoop(ctx context.Context) error {
|
|||
c.keepAliveTimer.Reset(keepAlive + jitter)
|
||||
case <-c.keepAliveTimer.C:
|
||||
c.mu.Lock()
|
||||
if c.s.WriteTimeout != 0 {
|
||||
c.nc.SetWriteDeadline(time.Now().Add(c.s.WriteTimeout))
|
||||
}
|
||||
err := writeFrame(c.bw, frameKeepAlive, nil)
|
||||
if err == nil {
|
||||
err = c.bw.Flush()
|
||||
|
|
|
@ -6,15 +6,22 @@ package derp
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
crand "crypto/rand"
|
||||
"errors"
|
||||
"expvar"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"tailscale.com/net/nettest"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
func newPrivateKey(t *testing.T) (k key.Private) {
|
||||
t.Helper()
|
||||
if _, err := crand.Read(k[:]); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -42,7 +49,7 @@ func TestSendRecv(t *testing.T) {
|
|||
defer ln.Close()
|
||||
|
||||
var clients []*Client
|
||||
var connsOut []net.Conn
|
||||
var connsOut []Conn
|
||||
var recvChs []chan []byte
|
||||
errCh := make(chan error, 3)
|
||||
|
||||
|
@ -60,7 +67,8 @@ func TestSendRecv(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
defer cin.Close()
|
||||
go s.Accept(cin, bufio.NewReadWriter(bufio.NewReader(cin), bufio.NewWriter(cin)))
|
||||
brwServer := bufio.NewReadWriter(bufio.NewReader(cin), bufio.NewWriter(cin))
|
||||
go s.Accept(cin, brwServer, fmt.Sprintf("test-client-%d", i))
|
||||
|
||||
key := clientPrivateKeys[i]
|
||||
brw := bufio.NewReadWriter(bufio.NewReader(cout), bufio.NewWriter(cout))
|
||||
|
@ -170,3 +178,168 @@ func TestSendRecv(t *testing.T) {
|
|||
t.Logf("passed")
|
||||
s.Close()
|
||||
}
|
||||
|
||||
func TestSendFreeze(t *testing.T) {
|
||||
serverPrivateKey := newPrivateKey(t)
|
||||
s := NewServer(serverPrivateKey, t.Logf)
|
||||
defer s.Close()
|
||||
s.WriteTimeout = 100 * time.Millisecond
|
||||
|
||||
// We send two streams of messages:
|
||||
//
|
||||
// alice --> bob
|
||||
// alice --> cathy
|
||||
//
|
||||
// Then cathy stops processing messsages.
|
||||
// That should not interfere with alice talking to bob.
|
||||
|
||||
newClient := func(name string, k key.Private) (c *Client, clientConn nettest.Conn) {
|
||||
t.Helper()
|
||||
c1, c2 := nettest.NewConn(name, 1024)
|
||||
go s.Accept(c1, bufio.NewReadWriter(bufio.NewReader(c1), bufio.NewWriter(c1)), name)
|
||||
|
||||
brw := bufio.NewReadWriter(bufio.NewReader(c2), bufio.NewWriter(c2))
|
||||
c, err := NewClient(k, c2, brw, t.Logf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return c, c2
|
||||
}
|
||||
|
||||
aliceKey := newPrivateKey(t)
|
||||
aliceClient, aliceConn := newClient("alice", aliceKey)
|
||||
|
||||
bobKey := newPrivateKey(t)
|
||||
bobClient, bobConn := newClient("bob", bobKey)
|
||||
|
||||
cathyKey := newPrivateKey(t)
|
||||
cathyClient, cathyConn := newClient("cathy", cathyKey)
|
||||
|
||||
var aliceCount, bobCount, cathyCount expvar.Int
|
||||
|
||||
errCh := make(chan error, 4)
|
||||
recvAndCount := func(count *expvar.Int, name string, client *Client) {
|
||||
for {
|
||||
b := make([]byte, 1<<9)
|
||||
m, err := client.Recv(b)
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("%s: %w", name, err)
|
||||
return
|
||||
}
|
||||
switch m := m.(type) {
|
||||
default:
|
||||
errCh <- fmt.Errorf("%s: unexpected message type %T", name, m)
|
||||
return
|
||||
case ReceivedPacket:
|
||||
if m.Source.IsZero() {
|
||||
errCh <- fmt.Errorf("%s: zero Source address in ReceivedPacket", name)
|
||||
return
|
||||
}
|
||||
count.Add(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
go recvAndCount(&aliceCount, "alice", aliceClient)
|
||||
go recvAndCount(&bobCount, "bob", bobClient)
|
||||
go recvAndCount(&cathyCount, "cathy", cathyClient)
|
||||
|
||||
var cancel func()
|
||||
go func() {
|
||||
t := time.NewTicker(2 * time.Millisecond)
|
||||
defer t.Stop()
|
||||
var ctx context.Context
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
for {
|
||||
select {
|
||||
case <-t.C:
|
||||
case <-ctx.Done():
|
||||
errCh <- nil
|
||||
return
|
||||
}
|
||||
|
||||
msg1 := []byte("hello alice->bob\n")
|
||||
if err := aliceClient.Send(bobKey.Public(), msg1); err != nil {
|
||||
errCh <- fmt.Errorf("alice send to bob: %w", err)
|
||||
return
|
||||
}
|
||||
msg2 := []byte("hello alice->cathy\n")
|
||||
|
||||
// TODO: an error is expected here.
|
||||
// We ignore it, maybe we should log it somehow?
|
||||
aliceClient.Send(cathyKey.Public(), msg2)
|
||||
}
|
||||
}()
|
||||
|
||||
var countSnapshot [3]int64
|
||||
loadCounts := func() (adiff, bdiff, cdiff int64) {
|
||||
t.Helper()
|
||||
|
||||
atotal := aliceCount.Value()
|
||||
btotal := bobCount.Value()
|
||||
ctotal := cathyCount.Value()
|
||||
|
||||
adiff = atotal - countSnapshot[0]
|
||||
bdiff = btotal - countSnapshot[1]
|
||||
cdiff = ctotal - countSnapshot[2]
|
||||
|
||||
countSnapshot[0] = atotal
|
||||
countSnapshot[1] = btotal
|
||||
countSnapshot[2] = ctotal
|
||||
|
||||
t.Logf("count diffs: alice=%d, bob=%d, cathy=%d", adiff, bdiff, cdiff)
|
||||
return adiff, bdiff, cdiff
|
||||
}
|
||||
|
||||
t.Run("initial send", func(t *testing.T) {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
a, b, c := loadCounts()
|
||||
if a != 0 {
|
||||
t.Errorf("alice diff=%d, want 0", a)
|
||||
}
|
||||
if b == 0 {
|
||||
t.Errorf("no bob diff, want positive value")
|
||||
}
|
||||
if c == 0 {
|
||||
t.Errorf("no cathy diff, want positive value")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("block cathy", func(t *testing.T) {
|
||||
// Block cathy. Now the cathyConn buffer will fill up quickly,
|
||||
// and the derp server will back up.
|
||||
cathyConn.SetReadBlock(true)
|
||||
time.Sleep(2 * s.WriteTimeout)
|
||||
|
||||
a, b, _ := loadCounts()
|
||||
if a != 0 {
|
||||
t.Errorf("alice diff=%d, want 0", a)
|
||||
}
|
||||
if b == 0 {
|
||||
t.Errorf("no bob diff, want positive value")
|
||||
}
|
||||
|
||||
// Now wait a little longer, and ensure packets still flow to bob
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
if _, b, _ := loadCounts(); b == 0 {
|
||||
t.Errorf("connection alice->bob frozen by alice->cathy")
|
||||
}
|
||||
})
|
||||
|
||||
// Cleanup, make sure we process all errors.
|
||||
t.Logf("TEST COMPLETE, cancelling sender")
|
||||
cancel()
|
||||
t.Logf("closing connections")
|
||||
aliceConn.Close()
|
||||
bobConn.Close()
|
||||
cathyConn.Close()
|
||||
|
||||
for i := 0; i < cap(errCh); i++ {
|
||||
err := <-errCh
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
continue
|
||||
}
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -32,6 +32,6 @@ func Handler(s *derp.Server) http.Handler {
|
|||
http.Error(w, "HTTP does not support general TCP support", 500)
|
||||
return
|
||||
}
|
||||
s.Accept(netConn, conn)
|
||||
s.Accept(netConn, conn, netConn.RemoteAddr().String())
|
||||
})
|
||||
}
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package nettest
|
||||
|
||||
import (
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Conn is a bi-directional in-memory stream that looks like a TCP net.Conn.
|
||||
type Conn interface {
|
||||
io.Reader
|
||||
io.Writer
|
||||
io.Closer
|
||||
|
||||
// The *Deadline methods follow the semantics of net.Conn.
|
||||
|
||||
SetDeadline(t time.Time) error
|
||||
SetReadDeadline(t time.Time) error
|
||||
SetWriteDeadline(t time.Time) error
|
||||
|
||||
// SetReadBlock blocks or unblocks the Read method of this Conn.
|
||||
// It reports an error if the existing value matches the new value,
|
||||
// or if the Conn has been Closed.
|
||||
SetReadBlock(bool) error
|
||||
|
||||
// SetWriteBlock blocks or unblocks the Write method of this Conn.
|
||||
// It reports an error if the existing value matches the new value,
|
||||
// or if the Conn has been Closed.
|
||||
SetWriteBlock(bool) error
|
||||
}
|
||||
|
||||
// NewConn creates a pair of Conns that are wired together by pipes.
|
||||
func NewConn(name string, maxBuf int) (Conn, Conn) {
|
||||
r := NewPipe(name+"|0", maxBuf)
|
||||
w := NewPipe(name+"|1", maxBuf)
|
||||
|
||||
return &connHalf{r: r, w: w}, &connHalf{r: w, w: r}
|
||||
}
|
||||
|
||||
type connHalf struct {
|
||||
r, w *Pipe
|
||||
}
|
||||
|
||||
func (c *connHalf) Read(b []byte) (n int, err error) {
|
||||
return c.r.Read(b)
|
||||
}
|
||||
func (c *connHalf) Write(b []byte) (n int, err error) {
|
||||
return c.w.Write(b)
|
||||
}
|
||||
func (c *connHalf) Close() error {
|
||||
err1 := c.r.Close()
|
||||
err2 := c.w.Close()
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
return err2
|
||||
}
|
||||
func (c *connHalf) SetDeadline(t time.Time) error {
|
||||
err1 := c.SetReadDeadline(t)
|
||||
err2 := c.SetWriteDeadline(t)
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
return err2
|
||||
}
|
||||
func (c *connHalf) SetReadDeadline(t time.Time) error {
|
||||
return c.r.SetReadDeadline(t)
|
||||
}
|
||||
func (c *connHalf) SetWriteDeadline(t time.Time) error {
|
||||
return c.w.SetWriteDeadline(t)
|
||||
}
|
||||
func (c *connHalf) SetReadBlock(b bool) error {
|
||||
if b {
|
||||
return c.r.Block()
|
||||
}
|
||||
return c.r.Unblock()
|
||||
}
|
||||
func (c *connHalf) SetWriteBlock(b bool) error {
|
||||
if b {
|
||||
return c.w.Block()
|
||||
}
|
||||
return c.w.Unblock()
|
||||
}
|
|
@ -0,0 +1,261 @@
|
|||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package nettest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const debugPipe = false
|
||||
|
||||
// Pipe implements an in-memory FIFO with timeouts.
|
||||
type Pipe struct {
|
||||
name string
|
||||
maxBuf int
|
||||
rCh chan struct{}
|
||||
wCh chan struct{}
|
||||
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
blocked bool
|
||||
buf []byte
|
||||
readTimeout time.Time
|
||||
writeTimeout time.Time
|
||||
cancelReadTimer func()
|
||||
cancelWriteTimer func()
|
||||
}
|
||||
|
||||
// NewPipe creates a Pipe with a buffer size fixed at maxBuf.
|
||||
func NewPipe(name string, maxBuf int) *Pipe {
|
||||
return &Pipe{
|
||||
name: name,
|
||||
maxBuf: maxBuf,
|
||||
rCh: make(chan struct{}, 1),
|
||||
wCh: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
ErrTimeout = errors.New("timeout")
|
||||
ErrReadTimeout = fmt.Errorf("read %w", ErrTimeout)
|
||||
ErrWriteTimeout = fmt.Errorf("write %w", ErrTimeout)
|
||||
)
|
||||
|
||||
// Read implements io.Reader.
|
||||
func (p *Pipe) Read(b []byte) (n int, err error) {
|
||||
if debugPipe {
|
||||
orig := b
|
||||
defer func() {
|
||||
log.Printf("Pipe(%q).Read( %q) n=%d, err=%v", p.name, string(orig[:n]), n, err)
|
||||
}()
|
||||
}
|
||||
for {
|
||||
p.mu.Lock()
|
||||
closed := p.closed
|
||||
timedout := !p.readTimeout.IsZero() && time.Now().After(p.readTimeout)
|
||||
blocked := p.blocked
|
||||
if !closed && !timedout && len(p.buf) > 0 {
|
||||
n2 := copy(b, p.buf)
|
||||
p.buf = p.buf[n2:]
|
||||
b = b[n2:]
|
||||
n += n2
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
if closed {
|
||||
return 0, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF)
|
||||
}
|
||||
if timedout {
|
||||
return 0, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrReadTimeout)
|
||||
}
|
||||
if blocked {
|
||||
<-p.rCh
|
||||
continue
|
||||
}
|
||||
if n > 0 {
|
||||
p.signalWrite()
|
||||
return n, nil
|
||||
}
|
||||
<-p.rCh
|
||||
}
|
||||
}
|
||||
|
||||
// Write implements io.Writer.
|
||||
func (p *Pipe) Write(b []byte) (n int, err error) {
|
||||
if debugPipe {
|
||||
orig := b
|
||||
defer func() {
|
||||
log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(orig), n, err)
|
||||
}()
|
||||
}
|
||||
for {
|
||||
p.mu.Lock()
|
||||
closed := p.closed
|
||||
timedout := !p.writeTimeout.IsZero() && time.Now().After(p.writeTimeout)
|
||||
blocked := p.blocked
|
||||
if !closed && !timedout {
|
||||
n2 := len(b)
|
||||
if limit := p.maxBuf - len(p.buf); limit < n2 {
|
||||
n2 = limit
|
||||
}
|
||||
p.buf = append(p.buf, b[:n2]...)
|
||||
b = b[n2:]
|
||||
n += n2
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
if closed {
|
||||
return n, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF)
|
||||
}
|
||||
if timedout {
|
||||
return n, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrWriteTimeout)
|
||||
}
|
||||
if blocked {
|
||||
<-p.wCh
|
||||
continue
|
||||
}
|
||||
if n > 0 {
|
||||
p.signalRead()
|
||||
}
|
||||
if len(b) == 0 {
|
||||
return n, nil
|
||||
}
|
||||
<-p.wCh
|
||||
}
|
||||
}
|
||||
|
||||
// Close implements io.Closer.
|
||||
func (p *Pipe) Close() error {
|
||||
p.mu.Lock()
|
||||
closed := p.closed
|
||||
p.closed = true
|
||||
if p.cancelWriteTimer != nil {
|
||||
p.cancelWriteTimer()
|
||||
p.cancelWriteTimer = nil
|
||||
}
|
||||
if p.cancelReadTimer != nil {
|
||||
p.cancelReadTimer()
|
||||
p.cancelReadTimer = nil
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
if closed {
|
||||
return fmt.Errorf("nettest.Pipe(%q).Close: already closed", p.name)
|
||||
}
|
||||
|
||||
p.signalRead()
|
||||
p.signalWrite()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadDeadline sets the deadline for future Read calls.
|
||||
func (p *Pipe) SetReadDeadline(t time.Time) error {
|
||||
p.mu.Lock()
|
||||
p.readTimeout = t
|
||||
if p.cancelReadTimer != nil {
|
||||
p.cancelReadTimer()
|
||||
p.cancelReadTimer = nil
|
||||
}
|
||||
if d := time.Until(t); !t.IsZero() && d > 0 {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
p.cancelReadTimer = cancel
|
||||
go func() {
|
||||
t := time.NewTimer(d)
|
||||
defer t.Stop()
|
||||
select {
|
||||
case <-t.C:
|
||||
p.signalRead()
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}()
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
p.signalRead()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWriteDeadline sets the deadline for future Write calls.
|
||||
func (p *Pipe) SetWriteDeadline(t time.Time) error {
|
||||
p.mu.Lock()
|
||||
p.writeTimeout = t
|
||||
if p.cancelWriteTimer != nil {
|
||||
p.cancelWriteTimer()
|
||||
p.cancelWriteTimer = nil
|
||||
}
|
||||
if d := time.Until(t); !t.IsZero() && d > 0 {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
p.cancelWriteTimer = cancel
|
||||
go func() {
|
||||
t := time.NewTimer(d)
|
||||
defer t.Stop()
|
||||
select {
|
||||
case <-t.C:
|
||||
p.signalWrite()
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}()
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
p.signalWrite()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Pipe) Block() error {
|
||||
p.mu.Lock()
|
||||
closed := p.closed
|
||||
blocked := p.blocked
|
||||
p.blocked = true
|
||||
p.mu.Unlock()
|
||||
|
||||
if closed {
|
||||
return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name)
|
||||
}
|
||||
if blocked {
|
||||
return fmt.Errorf("nettest.Pipe(%q).Block: already blocked", p.name)
|
||||
}
|
||||
p.signalRead()
|
||||
p.signalWrite()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Pipe) Unblock() error {
|
||||
p.mu.Lock()
|
||||
closed := p.closed
|
||||
blocked := p.blocked
|
||||
p.blocked = false
|
||||
p.mu.Unlock()
|
||||
|
||||
if closed {
|
||||
return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name)
|
||||
}
|
||||
if !blocked {
|
||||
return fmt.Errorf("nettest.Pipe(%q).Block: already unblocked", p.name)
|
||||
}
|
||||
p.signalRead()
|
||||
p.signalWrite()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Pipe) signalRead() {
|
||||
select {
|
||||
case p.rCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pipe) signalWrite() {
|
||||
select {
|
||||
case p.wCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
|
@ -0,0 +1,116 @@
|
|||
package nettest
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPipeHello(t *testing.T) {
|
||||
p := NewPipe("p1", 1<<16)
|
||||
msg := "Hello, World!"
|
||||
if n, err := p.Write([]byte(msg)); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if n != len(msg) {
|
||||
t.Errorf("p.Write(%q) n=%d, want %d", msg, n, len(msg))
|
||||
}
|
||||
b := make([]byte, len(msg))
|
||||
if n, err := p.Read(b); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if n != len(b) {
|
||||
t.Errorf("p.Read(%q) n=%d, want %d", string(b[:n]), n, len(b))
|
||||
}
|
||||
if got := string(b); got != msg {
|
||||
t.Errorf("p.Read: %q, want %q", got, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPipeTimeout(t *testing.T) {
|
||||
t.Run("write", func(t *testing.T) {
|
||||
p := NewPipe("p1", 1<<16)
|
||||
p.SetWriteDeadline(time.Now().Add(-1 * time.Second))
|
||||
n, err := p.Write([]byte{'h'})
|
||||
if err == nil || !errors.Is(err, ErrWriteTimeout) || !errors.Is(err, ErrTimeout) {
|
||||
t.Errorf("missing write timeout got err: %v", err)
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("n=%d on timeout", n)
|
||||
}
|
||||
})
|
||||
t.Run("read", func(t *testing.T) {
|
||||
p := NewPipe("p1", 1<<16)
|
||||
p.Write([]byte{'h'})
|
||||
|
||||
p.SetReadDeadline(time.Now().Add(-1 * time.Second))
|
||||
b := make([]byte, 1)
|
||||
n, err := p.Read(b)
|
||||
if err == nil || !errors.Is(err, ErrReadTimeout) || !errors.Is(err, ErrTimeout) {
|
||||
t.Errorf("missing read timeout got err: %v", err)
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("n=%d on timeout", n)
|
||||
}
|
||||
})
|
||||
t.Run("block-write", func(t *testing.T) {
|
||||
p := NewPipe("p1", 1<<16)
|
||||
p.SetWriteDeadline(time.Now().Add(10 * time.Millisecond))
|
||||
if _, err := p.Write([]byte{'h'}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := p.Block(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := p.Write([]byte{'h'}); err == nil || !errors.Is(err, ErrWriteTimeout) {
|
||||
t.Fatalf("want write timeout got: %v", err)
|
||||
}
|
||||
})
|
||||
t.Run("block-read", func(t *testing.T) {
|
||||
p := NewPipe("p1", 1<<16)
|
||||
p.Write([]byte{'h', 'i'})
|
||||
p.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
|
||||
b := make([]byte, 1)
|
||||
if _, err := p.Read(b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := p.Block(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := p.Read(b); err == nil || !errors.Is(err, ErrReadTimeout) {
|
||||
t.Fatalf("want read timeout got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestLimit(t *testing.T) {
|
||||
p := NewPipe("p1", 1)
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
n, err := p.Write([]byte{'a', 'b', 'c'})
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
} else if n != 3 {
|
||||
errCh <- fmt.Errorf("p.Write n=%d, want 3", n)
|
||||
} else {
|
||||
errCh <- nil
|
||||
}
|
||||
}()
|
||||
b := make([]byte, 3)
|
||||
|
||||
if n, err := p.Read(b); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if n != 1 {
|
||||
t.Errorf("Read(%q): n=%d want 1", string(b), n)
|
||||
}
|
||||
if n, err := p.Read(b); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if n != 1 {
|
||||
t.Errorf("Read(%q): n=%d want 1", string(b), n)
|
||||
}
|
||||
if n, err := p.Read(b); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if n != 1 {
|
||||
t.Errorf("Read(%q): n=%d want 1", string(b), n)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue