net/tsnet: add UDP support.

Add tsnet server.ListenPacket(), which returns a
net.PacketConn.

Support udp in the tsnet Dial().

Additionally:
1. Add a tsnet TailscaleIPs() function to return the IP
   addresses, both IPv4 and IPv6.
2. Fix a deadlock in tsnet server.Close() where it
   acquired s.mu.Lock() and then called listener.Close()
   which also tried to acquire s.mu.Lock(). This caused a
   hang in the new TestPackets test case.

Updates https://github.com/tailscale/tailscale/issues/5871

Signed-off-by: Denton Gentry <dgentry@tailscale.com>
dgentry/sniproxy-dns
Denton Gentry 2023-03-04 19:15:29 -08:00
parent 93042b4407
commit a518e818cd
No known key found for this signature in database
3 changed files with 173 additions and 18 deletions

View File

@ -40,6 +40,7 @@ type Dialer struct {
// NetstackDialTCP dials the provided IPPort using netstack. // NetstackDialTCP dials the provided IPPort using netstack.
// If nil, it's not used. // If nil, it's not used.
NetstackDialTCP func(context.Context, netip.AddrPort) (net.Conn, error) NetstackDialTCP func(context.Context, netip.AddrPort) (net.Conn, error)
NetstackDialUDP func(context.Context, netip.AddrPort) (net.Conn, error)
peerClientOnce sync.Once peerClientOnce sync.Once
peerClient *http.Client peerClient *http.Client
@ -306,10 +307,19 @@ func (d *Dialer) UserDial(ctx context.Context, network, addr string) (net.Conn,
return nil, err return nil, err
} }
if d.UseNetstackForIP != nil && d.UseNetstackForIP(ipp.Addr()) { if d.UseNetstackForIP != nil && d.UseNetstackForIP(ipp.Addr()) {
if d.NetstackDialTCP == nil { switch network {
return nil, errors.New("Dialer not initialized correctly") case "udp", "udp4", "udp6":
if d.NetstackDialUDP == nil {
return nil, errors.New("Dialer not initialized correctly")
}
return d.NetstackDialUDP(ctx, ipp)
default:
if d.NetstackDialTCP == nil {
return nil, errors.New("Dialer not initialized correctly")
}
return d.NetstackDialTCP(ctx, ipp)
} }
return d.NetstackDialTCP(ctx, ipp)
} }
// TODO(bradfitz): netns, etc // TODO(bradfitz): netns, etc
var stdDialer net.Dialer var stdDialer net.Dialer

View File

@ -307,7 +307,7 @@ func (s *Server) Close() error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
for _, ln := range s.listeners { for _, ln := range s.listeners {
ln.Close() ln.closeUnlocked()
} }
s.listeners = nil s.listeners = nil
@ -322,6 +322,17 @@ func (s *Server) doInit() {
} }
} }
func (s *Server) TailscaleIPs() []netip.Addr {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
st, err := s.localClient.Status(ctx)
if err != nil {
return []netip.Addr{}
}
return st.TailscaleIPs
}
func (s *Server) getAuthKey() string { func (s *Server) getAuthKey() string {
if v := s.AuthKey; v != "" { if v := s.AuthKey; v != "" {
return v return v
@ -440,6 +451,7 @@ func (s *Server) start() (reterr error) {
} }
ns.ProcessLocalIPs = true ns.ProcessLocalIPs = true
ns.ForwardTCPIn = s.forwardTCP ns.ForwardTCPIn = s.forwardTCP
ns.ForwardUDPIn = s.forwardUDP
s.netstack = ns s.netstack = ns
s.dialer.UseNetstackForIP = func(ip netip.Addr) bool { s.dialer.UseNetstackForIP = func(ip netip.Addr) bool {
_, ok := eng.PeerForIP(ip) _, ok := eng.PeerForIP(ip)
@ -448,6 +460,9 @@ func (s *Server) start() (reterr error) {
s.dialer.NetstackDialTCP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) { s.dialer.NetstackDialTCP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) {
return ns.DialContextTCP(ctx, dst) return ns.DialContextTCP(ctx, dst)
} }
s.dialer.NetstackDialUDP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) {
return ns.DialContextUDP(ctx, dst)
}
if s.Store == nil { if s.Store == nil {
stateFile := filepath.Join(s.rootPath, "tailscaled.state") stateFile := filepath.Join(s.rootPath, "tailscaled.state")
@ -579,6 +594,24 @@ func (s *Server) forwardTCP(c net.Conn, port uint16) {
} }
} }
func (s *Server) forwardUDP(p net.PacketConn, port uint16) {
s.mu.Lock()
ln, ok := s.listeners[listenKey{"udp", "", port}]
s.mu.Unlock()
if !ok {
p.Close()
return
}
t := time.NewTimer(time.Second)
defer t.Stop()
select {
case ln.pkt <- p:
case <-t.C:
p.Close()
}
}
// getTSNetDir usually just returns filepath.Join(confDir, "tsnet-"+prog) // getTSNetDir usually just returns filepath.Join(confDir, "tsnet-"+prog)
// with no error. // with no error.
// //
@ -640,9 +673,12 @@ func (s *Server) APIClient() (*tailscale.Client, error) {
// Listen announces only on the Tailscale network. // Listen announces only on the Tailscale network.
// It will start the server if it has not been started yet. // It will start the server if it has not been started yet.
func (s *Server) Listen(network, addr string) (net.Listener, error) { func (s *Server) listen(network, addr string) (*listener, error) {
isPacket := false
switch network { switch network {
case "", "tcp", "tcp4", "tcp6": case "", "tcp", "tcp4", "tcp6":
case "udp", "udp4", "udp6":
isPacket = true
default: default:
return nil, errors.New("unsupported network type") return nil, errors.New("unsupported network type")
} }
@ -660,11 +696,13 @@ func (s *Server) Listen(network, addr string) (net.Listener, error) {
key := listenKey{network, host, uint16(port)} key := listenKey{network, host, uint16(port)}
ln := &listener{ ln := &listener{
s: s, s: s,
key: key, key: key,
addr: addr, addr: addr,
isPacket: isPacket,
conn: make(chan net.Conn), conn: make(chan net.Conn),
pkt: make(chan net.PacketConn, 1),
} }
s.mu.Lock() s.mu.Lock()
if _, ok := s.listeners[key]; ok { if _, ok := s.listeners[key]; ok {
@ -676,6 +714,19 @@ func (s *Server) Listen(network, addr string) (net.Listener, error) {
return ln, nil return ln, nil
} }
func (s *Server) Listen(network, addr string) (net.Listener, error) {
return s.listen(network, addr)
}
func (s *Server) ListenPacket(network, addr string) (net.PacketConn, error) {
ln, err := s.listen(network, addr)
if err != nil {
return nil, err
}
return ln.GetPacketConn()
}
type listenKey struct { type listenKey struct {
network string network string
host string host string
@ -683,13 +734,18 @@ type listenKey struct {
} }
type listener struct { type listener struct {
s *Server s *Server
key listenKey key listenKey
addr string addr string
conn chan net.Conn isPacket bool
conn chan net.Conn
pkt chan net.PacketConn
} }
func (ln *listener) Accept() (net.Conn, error) { func (ln *listener) Accept() (net.Conn, error) {
if ln.isPacket {
return nil, fmt.Errorf("tsnet: listener is for packets (UDP, not TCP)")
}
c, ok := <-ln.conn c, ok := <-ln.conn
if !ok { if !ok {
return nil, fmt.Errorf("tsnet: %w", net.ErrClosed) return nil, fmt.Errorf("tsnet: %w", net.ErrClosed)
@ -697,13 +753,29 @@ func (ln *listener) Accept() (net.Conn, error) {
return c, nil return c, nil
} }
func (ln *listener) GetPacketConn() (net.PacketConn, error) {
if !ln.isPacket {
return nil, fmt.Errorf("tsnet: listener is for connections (TCP, not UDP)")
}
p, ok := <-ln.pkt
if !ok {
return nil, fmt.Errorf("tsnet: %w", net.ErrClosed)
}
return p, nil
}
func (ln *listener) Addr() net.Addr { return addr{ln} } func (ln *listener) Addr() net.Addr { return addr{ln} }
func (ln *listener) Close() error { func (ln *listener) Close() error {
ln.s.mu.Lock() ln.s.mu.Lock()
defer ln.s.mu.Unlock() defer ln.s.mu.Unlock()
return ln.closeUnlocked()
}
func (ln *listener) closeUnlocked() error {
if v, ok := ln.s.listeners[ln.key]; ok && v == ln { if v, ok := ln.s.listeners[ln.key]; ok && v == ln {
delete(ln.s.listeners, ln.key) delete(ln.s.listeners, ln.key)
close(ln.conn) close(ln.conn)
close(ln.pkt)
} }
return nil return nil
} }

View File

@ -13,6 +13,7 @@ import (
"net/http/httptest" "net/http/httptest"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"testing" "testing"
"time" "time"
@ -87,31 +88,31 @@ func startControl(t *testing.T) (controlURL string) {
return controlURL return controlURL
} }
func TestConn(t *testing.T) { func setupTwoNodes(t *testing.T) (s1, s2 *Server, ctx context.Context) {
controlURL := startControl(t) controlURL := startControl(t)
tmp := t.TempDir() tmp := t.TempDir()
tmps1 := filepath.Join(tmp, "s1") tmps1 := filepath.Join(tmp, "s1")
os.MkdirAll(tmps1, 0755) os.MkdirAll(tmps1, 0755)
s1 := &Server{ s1 = &Server{
Dir: tmps1, Dir: tmps1,
ControlURL: controlURL, ControlURL: controlURL,
Hostname: "s1", Hostname: "s1",
Store: new(mem.Store), Store: new(mem.Store),
Ephemeral: true, Ephemeral: true,
} }
defer s1.Close() t.Cleanup(func() { s1.Close() })
tmps2 := filepath.Join(tmp, "s1") tmps2 := filepath.Join(tmp, "s1")
os.MkdirAll(tmps2, 0755) os.MkdirAll(tmps2, 0755)
s2 := &Server{ s2 = &Server{
Dir: tmps2, Dir: tmps2,
ControlURL: controlURL, ControlURL: controlURL,
Hostname: "s2", Hostname: "s2",
Store: new(mem.Store), Store: new(mem.Store),
Ephemeral: true, Ephemeral: true,
} }
defer s2.Close() t.Cleanup(func() { s2.Close() })
if !*verboseNodes { if !*verboseNodes {
s1.Logf = logger.Discard s1.Logf = logger.Discard
@ -119,7 +120,7 @@ func TestConn(t *testing.T) {
} }
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() t.Cleanup(cancel)
s1status, err := s1.Up(ctx) s1status, err := s1.Up(ctx)
if err != nil { if err != nil {
@ -142,6 +143,12 @@ func TestConn(t *testing.T) {
} }
t.Logf("ping success: %#+v", res) t.Logf("ping success: %#+v", res)
return
}
func TestConn(t *testing.T) {
s1, s2, ctx := setupTwoNodes(t)
// pass some data through TCP. // pass some data through TCP.
ln, err := s1.Listen("tcp", ":8081") ln, err := s1.Listen("tcp", ":8081")
if err != nil { if err != nil {
@ -149,6 +156,7 @@ func TestConn(t *testing.T) {
} }
defer ln.Close() defer ln.Close()
s1ip := s1.TailscaleIPs()[0]
w, err := s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", s1ip)) w, err := s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", s1ip))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -174,6 +182,42 @@ func TestConn(t *testing.T) {
} }
} }
func TestPackets(t *testing.T) {
s1, s2, ctx := setupTwoNodes(t)
want := "PACKET!!"
received := make(chan []byte)
go func() {
p, err := s1.ListenPacket("udp", ":5000")
if err != nil {
t.Fatal(err)
}
defer p.Close()
buf := make([]byte, len(want))
if _, _, err := p.ReadFrom(buf); err != nil {
t.Fatal(err)
}
received <- buf
}()
s1ip := s1.TailscaleIPs()[0]
w, err := s2.Dial(ctx, "udp", fmt.Sprintf("%s:5000", s1ip))
if err != nil {
t.Fatal(err)
}
if _, err := io.WriteString(w, want); err != nil {
t.Fatal(err)
}
got := <-received
t.Logf("got: %q", got)
if string(got) != want {
t.Errorf("got %q, want %q", got, want)
}
}
func TestLoopbackLocalAPI(t *testing.T) { func TestLoopbackLocalAPI(t *testing.T) {
controlURL := startControl(t) controlURL := startControl(t)
@ -258,3 +302,32 @@ func TestLoopbackLocalAPI(t *testing.T) {
t.Errorf("GET /status returned %d, want 200", res.StatusCode) t.Errorf("GET /status returned %d, want 200", res.StatusCode)
} }
} }
func TestTailscaleIPs(t *testing.T) {
controlURL := startControl(t)
tmp := t.TempDir()
tmps1 := filepath.Join(tmp, "s1")
os.MkdirAll(tmps1, 0755)
s1 := &Server{
Dir: tmps1,
ControlURL: controlURL,
Hostname: "s1",
Store: new(mem.Store),
Ephemeral: true,
}
defer s1.Close()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
s1status, err := s1.Up(ctx)
if err != nil {
t.Fatal(err)
}
ips := s1.TailscaleIPs()
if !reflect.DeepEqual(ips, s1status.TailscaleIPs) {
t.Errorf("s1.TailscaleIPs returned a different result than S1.Up, %v != %v", ips, s1status.TailscaleIPs)
}
}