Compare commits
4 Commits
main
...
dgentry/sn
Author | SHA1 | Date |
---|---|---|
![]() |
0c4fd8ef3a | |
![]() |
5b4a35e7f1 | |
![]() |
a518e818cd | |
![]() |
93042b4407 |
|
@ -11,16 +11,22 @@ import (
|
|||
"flag"
|
||||
"log"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"inet.af/tcpproxy"
|
||||
"tailscale.com/client/tailscale"
|
||||
"tailscale.com/net/netutil"
|
||||
"tailscale.com/tsnet"
|
||||
)
|
||||
|
||||
var ports = flag.String("ports", "443", "comma-separated list of ports to proxy")
|
||||
var (
|
||||
ports = flag.String("ports", "443", "comma-separated list of ports to proxy")
|
||||
dnsserv = flag.Bool("dns", true, "run a small DNS server to reply to any query with its own address")
|
||||
tsMBox = dnsmessage.MustNewName("support.tailscale.com.")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
@ -45,6 +51,9 @@ func main() {
|
|||
log.Printf("Serving on port %v ...", portStr)
|
||||
go s.serve(ln)
|
||||
}
|
||||
if *dnsserv {
|
||||
go s.serveDns()
|
||||
}
|
||||
select {}
|
||||
}
|
||||
|
||||
|
@ -88,3 +97,116 @@ func (s *server) serveConn(c net.Conn) {
|
|||
})
|
||||
p.Start()
|
||||
}
|
||||
|
||||
// getAddresses returns the tsnet IP addresses of this process
|
||||
func (s *server) getAddresses() (ip4, ip6 netip.Addr) {
|
||||
for _, ip := range s.ts.TailscaleIPs() {
|
||||
if ip.Is6() {
|
||||
ip6 = ip
|
||||
}
|
||||
if ip.Is4() {
|
||||
ip4 = ip
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (s *server) serveDns() {
|
||||
buf := make([]byte, 1024)
|
||||
pconn, err := s.ts.ListenPacket("udp", ":53")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
for {
|
||||
_, addr, err := pconn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
log.Printf("pconn.ReadFrom failed: %v\n ", err)
|
||||
continue
|
||||
}
|
||||
|
||||
var msg dnsmessage.Message
|
||||
err = msg.Unpack(buf)
|
||||
if err != nil {
|
||||
log.Printf("dnsmessage.Message unpack failed: %v\n ", err)
|
||||
continue
|
||||
}
|
||||
|
||||
buf, err := s.dnsResponse(&msg)
|
||||
if err != nil {
|
||||
log.Printf("s.dnsResponse failed: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = pconn.WriteTo(buf, addr)
|
||||
if err != nil {
|
||||
log.Printf("pconn.WriteTo failed: %v\n", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *server) dnsResponse(req *dnsmessage.Message) (buf []byte, err error) {
|
||||
resp := dnsmessage.NewBuilder(buf,
|
||||
dnsmessage.Header{
|
||||
ID: req.Header.ID,
|
||||
Response: true,
|
||||
Authoritative: true,
|
||||
})
|
||||
resp.EnableCompression()
|
||||
|
||||
if len(req.Questions) == 0 {
|
||||
buf, _ = resp.Finish()
|
||||
return
|
||||
}
|
||||
|
||||
q := req.Questions[0]
|
||||
err = resp.StartQuestions()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
resp.Question(q)
|
||||
|
||||
ip4, ip6 := s.getAddresses()
|
||||
err = resp.StartAnswers()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch q.Type {
|
||||
case dnsmessage.TypeAAAA:
|
||||
err = resp.AAAAResource(
|
||||
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
|
||||
dnsmessage.AAAAResource{AAAA: ip6.As16()},
|
||||
)
|
||||
|
||||
case dnsmessage.TypeA:
|
||||
err = resp.AResource(
|
||||
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
|
||||
dnsmessage.AResource{A: ip4.As4()},
|
||||
)
|
||||
case dnsmessage.TypeSOA:
|
||||
err = resp.SOAResource(
|
||||
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
|
||||
dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600,
|
||||
Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60},
|
||||
)
|
||||
case dnsmessage.TypeNS:
|
||||
err = resp.NSResource(
|
||||
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
|
||||
dnsmessage.NSResource{NS: tsMBox},
|
||||
)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
buf, err = resp.Finish()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
|
|
@ -40,6 +40,7 @@ type Dialer struct {
|
|||
// NetstackDialTCP dials the provided IPPort using netstack.
|
||||
// If nil, it's not used.
|
||||
NetstackDialTCP func(context.Context, netip.AddrPort) (net.Conn, error)
|
||||
NetstackDialUDP func(context.Context, netip.AddrPort) (net.Conn, error)
|
||||
|
||||
peerClientOnce sync.Once
|
||||
peerClient *http.Client
|
||||
|
@ -306,10 +307,19 @@ func (d *Dialer) UserDial(ctx context.Context, network, addr string) (net.Conn,
|
|||
return nil, err
|
||||
}
|
||||
if d.UseNetstackForIP != nil && d.UseNetstackForIP(ipp.Addr()) {
|
||||
if d.NetstackDialTCP == nil {
|
||||
return nil, errors.New("Dialer not initialized correctly")
|
||||
switch network {
|
||||
case "udp", "udp4", "udp6":
|
||||
if d.NetstackDialUDP == nil {
|
||||
return nil, errors.New("Dialer not initialized correctly")
|
||||
}
|
||||
return d.NetstackDialUDP(ctx, ipp)
|
||||
|
||||
default:
|
||||
if d.NetstackDialTCP == nil {
|
||||
return nil, errors.New("Dialer not initialized correctly")
|
||||
}
|
||||
return d.NetstackDialTCP(ctx, ipp)
|
||||
}
|
||||
return d.NetstackDialTCP(ctx, ipp)
|
||||
}
|
||||
// TODO(bradfitz): netns, etc
|
||||
var stdDialer net.Dialer
|
||||
|
|
|
@ -307,7 +307,7 @@ func (s *Server) Close() error {
|
|||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for _, ln := range s.listeners {
|
||||
ln.Close()
|
||||
ln.closeUnlocked()
|
||||
}
|
||||
s.listeners = nil
|
||||
|
||||
|
@ -322,6 +322,17 @@ func (s *Server) doInit() {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Server) TailscaleIPs() []netip.Addr {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
st, err := s.localClient.Status(ctx)
|
||||
if err != nil {
|
||||
return []netip.Addr{}
|
||||
}
|
||||
return st.TailscaleIPs
|
||||
}
|
||||
|
||||
func (s *Server) getAuthKey() string {
|
||||
if v := s.AuthKey; v != "" {
|
||||
return v
|
||||
|
@ -440,6 +451,7 @@ func (s *Server) start() (reterr error) {
|
|||
}
|
||||
ns.ProcessLocalIPs = true
|
||||
ns.ForwardTCPIn = s.forwardTCP
|
||||
ns.ForwardUDPIn = s.forwardUDP
|
||||
s.netstack = ns
|
||||
s.dialer.UseNetstackForIP = func(ip netip.Addr) bool {
|
||||
_, ok := eng.PeerForIP(ip)
|
||||
|
@ -448,6 +460,9 @@ func (s *Server) start() (reterr error) {
|
|||
s.dialer.NetstackDialTCP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) {
|
||||
return ns.DialContextTCP(ctx, dst)
|
||||
}
|
||||
s.dialer.NetstackDialUDP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) {
|
||||
return ns.DialContextUDP(ctx, dst)
|
||||
}
|
||||
|
||||
if s.Store == nil {
|
||||
stateFile := filepath.Join(s.rootPath, "tailscaled.state")
|
||||
|
@ -579,6 +594,24 @@ func (s *Server) forwardTCP(c net.Conn, port uint16) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Server) forwardUDP(p net.PacketConn, port uint16) {
|
||||
s.mu.Lock()
|
||||
ln, ok := s.listeners[listenKey{"udp", "", port}]
|
||||
s.mu.Unlock()
|
||||
if !ok {
|
||||
p.Close()
|
||||
return
|
||||
}
|
||||
|
||||
t := time.NewTimer(time.Second)
|
||||
defer t.Stop()
|
||||
select {
|
||||
case ln.pkt <- p:
|
||||
case <-t.C:
|
||||
p.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// getTSNetDir usually just returns filepath.Join(confDir, "tsnet-"+prog)
|
||||
// with no error.
|
||||
//
|
||||
|
@ -640,9 +673,12 @@ func (s *Server) APIClient() (*tailscale.Client, error) {
|
|||
|
||||
// Listen announces only on the Tailscale network.
|
||||
// It will start the server if it has not been started yet.
|
||||
func (s *Server) Listen(network, addr string) (net.Listener, error) {
|
||||
func (s *Server) listen(network, addr string) (*listener, error) {
|
||||
isPacket := false
|
||||
switch network {
|
||||
case "", "tcp", "tcp4", "tcp6":
|
||||
case "udp", "udp4", "udp6":
|
||||
isPacket = true
|
||||
default:
|
||||
return nil, errors.New("unsupported network type")
|
||||
}
|
||||
|
@ -660,11 +696,13 @@ func (s *Server) Listen(network, addr string) (net.Listener, error) {
|
|||
|
||||
key := listenKey{network, host, uint16(port)}
|
||||
ln := &listener{
|
||||
s: s,
|
||||
key: key,
|
||||
addr: addr,
|
||||
s: s,
|
||||
key: key,
|
||||
addr: addr,
|
||||
isPacket: isPacket,
|
||||
|
||||
conn: make(chan net.Conn),
|
||||
pkt: make(chan net.PacketConn, 1),
|
||||
}
|
||||
s.mu.Lock()
|
||||
if _, ok := s.listeners[key]; ok {
|
||||
|
@ -676,6 +714,19 @@ func (s *Server) Listen(network, addr string) (net.Listener, error) {
|
|||
return ln, nil
|
||||
}
|
||||
|
||||
func (s *Server) Listen(network, addr string) (net.Listener, error) {
|
||||
return s.listen(network, addr)
|
||||
}
|
||||
|
||||
func (s *Server) ListenPacket(network, addr string) (net.PacketConn, error) {
|
||||
ln, err := s.listen(network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ln.GetPacketConn()
|
||||
}
|
||||
|
||||
type listenKey struct {
|
||||
network string
|
||||
host string
|
||||
|
@ -683,13 +734,18 @@ type listenKey struct {
|
|||
}
|
||||
|
||||
type listener struct {
|
||||
s *Server
|
||||
key listenKey
|
||||
addr string
|
||||
conn chan net.Conn
|
||||
s *Server
|
||||
key listenKey
|
||||
addr string
|
||||
isPacket bool
|
||||
conn chan net.Conn
|
||||
pkt chan net.PacketConn
|
||||
}
|
||||
|
||||
func (ln *listener) Accept() (net.Conn, error) {
|
||||
if ln.isPacket {
|
||||
return nil, fmt.Errorf("tsnet: listener is for packets (UDP, not TCP)")
|
||||
}
|
||||
c, ok := <-ln.conn
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tsnet: %w", net.ErrClosed)
|
||||
|
@ -697,13 +753,29 @@ func (ln *listener) Accept() (net.Conn, error) {
|
|||
return c, nil
|
||||
}
|
||||
|
||||
func (ln *listener) GetPacketConn() (net.PacketConn, error) {
|
||||
if !ln.isPacket {
|
||||
return nil, fmt.Errorf("tsnet: listener is for connections (TCP, not UDP)")
|
||||
}
|
||||
|
||||
p, ok := <-ln.pkt
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tsnet: %w", net.ErrClosed)
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (ln *listener) Addr() net.Addr { return addr{ln} }
|
||||
func (ln *listener) Close() error {
|
||||
ln.s.mu.Lock()
|
||||
defer ln.s.mu.Unlock()
|
||||
return ln.closeUnlocked()
|
||||
}
|
||||
func (ln *listener) closeUnlocked() error {
|
||||
if v, ok := ln.s.listeners[ln.key]; ok && v == ln {
|
||||
delete(ln.s.listeners, ln.key)
|
||||
close(ln.conn)
|
||||
close(ln.pkt)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -87,31 +88,31 @@ func startControl(t *testing.T) (controlURL string) {
|
|||
return controlURL
|
||||
}
|
||||
|
||||
func TestConn(t *testing.T) {
|
||||
func setupTwoNodes(t *testing.T) (s1, s2 *Server, ctx context.Context) {
|
||||
controlURL := startControl(t)
|
||||
|
||||
tmp := t.TempDir()
|
||||
tmps1 := filepath.Join(tmp, "s1")
|
||||
os.MkdirAll(tmps1, 0755)
|
||||
s1 := &Server{
|
||||
s1 = &Server{
|
||||
Dir: tmps1,
|
||||
ControlURL: controlURL,
|
||||
Hostname: "s1",
|
||||
Store: new(mem.Store),
|
||||
Ephemeral: true,
|
||||
}
|
||||
defer s1.Close()
|
||||
t.Cleanup(func() { s1.Close() })
|
||||
|
||||
tmps2 := filepath.Join(tmp, "s1")
|
||||
os.MkdirAll(tmps2, 0755)
|
||||
s2 := &Server{
|
||||
s2 = &Server{
|
||||
Dir: tmps2,
|
||||
ControlURL: controlURL,
|
||||
Hostname: "s2",
|
||||
Store: new(mem.Store),
|
||||
Ephemeral: true,
|
||||
}
|
||||
defer s2.Close()
|
||||
t.Cleanup(func() { s2.Close() })
|
||||
|
||||
if !*verboseNodes {
|
||||
s1.Logf = logger.Discard
|
||||
|
@ -119,7 +120,7 @@ func TestConn(t *testing.T) {
|
|||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
t.Cleanup(cancel)
|
||||
|
||||
s1status, err := s1.Up(ctx)
|
||||
if err != nil {
|
||||
|
@ -142,6 +143,12 @@ func TestConn(t *testing.T) {
|
|||
}
|
||||
t.Logf("ping success: %#+v", res)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func TestConn(t *testing.T) {
|
||||
s1, s2, ctx := setupTwoNodes(t)
|
||||
|
||||
// pass some data through TCP.
|
||||
ln, err := s1.Listen("tcp", ":8081")
|
||||
if err != nil {
|
||||
|
@ -149,6 +156,7 @@ func TestConn(t *testing.T) {
|
|||
}
|
||||
defer ln.Close()
|
||||
|
||||
s1ip := s1.TailscaleIPs()[0]
|
||||
w, err := s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", s1ip))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -174,6 +182,42 @@ func TestConn(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestPackets(t *testing.T) {
|
||||
s1, s2, ctx := setupTwoNodes(t)
|
||||
|
||||
want := "PACKET!!"
|
||||
received := make(chan []byte)
|
||||
go func() {
|
||||
p, err := s1.ListenPacket("udp", ":5000")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer p.Close()
|
||||
|
||||
buf := make([]byte, len(want))
|
||||
if _, _, err := p.ReadFrom(buf); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
received <- buf
|
||||
}()
|
||||
|
||||
s1ip := s1.TailscaleIPs()[0]
|
||||
w, err := s2.Dial(ctx, "udp", fmt.Sprintf("%s:5000", s1ip))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := io.WriteString(w, want); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := <-received
|
||||
t.Logf("got: %q", got)
|
||||
if string(got) != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoopbackLocalAPI(t *testing.T) {
|
||||
controlURL := startControl(t)
|
||||
|
||||
|
@ -258,3 +302,32 @@ func TestLoopbackLocalAPI(t *testing.T) {
|
|||
t.Errorf("GET /status returned %d, want 200", res.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTailscaleIPs(t *testing.T) {
|
||||
controlURL := startControl(t)
|
||||
|
||||
tmp := t.TempDir()
|
||||
tmps1 := filepath.Join(tmp, "s1")
|
||||
os.MkdirAll(tmps1, 0755)
|
||||
s1 := &Server{
|
||||
Dir: tmps1,
|
||||
ControlURL: controlURL,
|
||||
Hostname: "s1",
|
||||
Store: new(mem.Store),
|
||||
Ephemeral: true,
|
||||
}
|
||||
defer s1.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
s1status, err := s1.Up(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ips := s1.TailscaleIPs()
|
||||
if !reflect.DeepEqual(ips, s1status.TailscaleIPs) {
|
||||
t.Errorf("s1.TailscaleIPs returned a different result than S1.Up, %v != %v", ips, s1status.TailscaleIPs)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -84,6 +84,10 @@ type Impl struct {
|
|||
// port other than accepting it and closing it.
|
||||
ForwardTCPIn func(c net.Conn, port uint16)
|
||||
|
||||
// ForwardUDPIn, if non-nil, handles forwarding inbound UDP
|
||||
// packets.
|
||||
ForwardUDPIn func(c net.PacketConn, port uint16)
|
||||
|
||||
// ProcessLocalIPs is whether netstack should handle incoming
|
||||
// traffic directed at the Node.Addresses (local IPs).
|
||||
// It can only be set before calling Start.
|
||||
|
@ -1021,6 +1025,12 @@ func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) {
|
|||
}
|
||||
|
||||
c := gonet.NewUDPConn(ns.ipstack, &wq, ep)
|
||||
|
||||
if ns.ForwardUDPIn != nil {
|
||||
ns.ForwardUDPIn(c, r.ID().LocalPort)
|
||||
return
|
||||
}
|
||||
|
||||
go ns.forwardUDP(c, &wq, srcAddr, dstAddr)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue