diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 31881cdb6..a9f62a88b 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -35,6 +35,7 @@ import ( "tailscale.com/ipn/ipnlocal" "tailscale.com/logtail/backoff" "tailscale.com/net/tsaddr" + "tailscale.com/net/tsdial" "tailscale.com/tailcfg" "tailscale.com/tempfork/gliderlabs/ssh" "tailscale.com/types/logger" @@ -62,6 +63,7 @@ type ipnLocalBackend interface { NetMap() *netmap.NetworkMap WhoIs(ipp netip.AddrPort) (n *tailcfg.Node, u tailcfg.UserProfile, ok bool) DoNoiseRequest(req *http.Request) (*http.Response, error) + Dialer() *tsdial.Dialer } type server struct { @@ -76,11 +78,33 @@ type server struct { // mu protects the following mu sync.Mutex + httpc *http.Client // for calling out to peers. activeConns map[*conn]bool // set; value is always true fetchPublicKeysCache map[string]pubKeyCacheEntry // by https URL shutdownCalled bool } +// sessionRecordingClient returns an http.Client that uses srv.lb.Dialer() to +// dial connections. This is used to make requests to the session recording +// server to upload session recordings. +func (srv *server) sessionRecordingClient() *http.Client { + srv.mu.Lock() + defer srv.mu.Unlock() + if srv.httpc != nil { + return srv.httpc + } + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + return srv.lb.Dialer().UserDial(ctx, network, addr) + } + srv.httpc = &http.Client{ + Transport: tr, + } + return srv.httpc +} + func (srv *server) now() time.Time { if srv != nil && srv.timeNow != nil { return srv.timeNow() @@ -1404,9 +1428,8 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) { go func() { defer pw.Close() ss.logf("starting asciinema recording to %s", recorder) - - // We just use the default client here, which has a 30s dial timeout. - resp, err := http.DefaultClient.Do(req) + hc := ss.conn.srv.sessionRecordingClient() + resp, err := hc.Do(req) if err != nil { ss.cancelCtx(err) ss.logf("recording: error sending recording to %s: %v", recorder, err) diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index faa3af645..fc9260e77 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -237,6 +237,10 @@ var ( testSignerOnce sync.Once ) +func (ts *localState) Dialer() *tsdial.Dialer { + return nil +} + func (ts *localState) GetSSH_HostKeys() ([]gossh.Signer, error) { testSignerOnce.Do(func() { _, priv, err := ed25519.GenerateKey(rand.Reader)