derp/derphttp: pass `*tls.Config` to WebSocket connections

Signed-off-by: Kyle Carberry <kyle@carberry.com>
pull/7401/head
Kyle Carberry 2023-02-28 15:02:10 -06:00
parent e3211ff88b
commit 1bf230348c
2 changed files with 21 additions and 8 deletions

View File

@ -188,6 +188,9 @@ func (c *Client) tlsServerName(node *tailcfg.DERPNode) string {
if c.url != nil { if c.url != nil {
return c.url.Host return c.url.Host
} }
if node == nil {
return ""
}
return node.HostName return node.HostName
} }
@ -225,7 +228,7 @@ func (c *Client) preferIPv6() bool {
} }
// dialWebsocketFunc is non-nil (set by websocket.go's init) when compiled in. // dialWebsocketFunc is non-nil (set by websocket.go's init) when compiled in.
var dialWebsocketFunc func(ctx context.Context, urlStr string) (net.Conn, error) var dialWebsocketFunc func(ctx context.Context, urlStr string, tlsConfig *tls.Config) (net.Conn, error)
func useWebsockets() bool { func useWebsockets() bool {
if runtime.GOOS == "js" { if runtime.GOOS == "js" {
@ -292,13 +295,16 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien
switch { switch {
case useWebsockets(): case useWebsockets():
var urlStr string var urlStr string
var tlsConfig *tls.Config
if c.url != nil { if c.url != nil {
urlStr = c.url.String() urlStr = c.url.String()
tlsConfig = c.tlsConfig(nil)
} else { } else {
urlStr = c.urlString(reg.Nodes[0]) urlStr = c.urlString(reg.Nodes[0])
tlsConfig = c.tlsConfig(reg.Nodes[0])
} }
c.logf("%s: connecting websocket to %v", caller, urlStr) c.logf("%s: connecting websocket to %v", caller, urlStr)
conn, err := dialWebsocketFunc(ctx, urlStr) conn, err := dialWebsocketFunc(ctx, urlStr, tlsConfig)
if err != nil { if err != nil {
c.logf("%s: websocket to %v error: %v", caller, urlStr, err) c.logf("%s: websocket to %v error: %v", caller, urlStr, err)
return nil, 0, err return nil, 0, err
@ -363,7 +369,7 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien
var serverProtoVersion int var serverProtoVersion int
var tlsState *tls.ConnectionState var tlsState *tls.ConnectionState
if c.useHTTPS() { if c.useHTTPS() {
tlsConn := c.tlsClient(tcpConn, node) tlsConn := tls.Client(tcpConn, c.tlsConfig(node))
httpConn = tlsConn httpConn = tlsConn
// Force a handshake now (instead of waiting for it to // Force a handshake now (instead of waiting for it to
@ -522,7 +528,7 @@ func (c *Client) dialRegion(ctx context.Context, reg *tailcfg.DERPRegion) (net.C
return nil, nil, firstErr return nil, nil, firstErr
} }
func (c *Client) tlsClient(nc net.Conn, node *tailcfg.DERPNode) *tls.Conn { func (c *Client) tlsConfig(node *tailcfg.DERPNode) *tls.Config {
tlsConf := tlsdial.Config(c.tlsServerName(node), c.TLSConfig) tlsConf := tlsdial.Config(c.tlsServerName(node), c.TLSConfig)
if node != nil { if node != nil {
if node.InsecureForTests { if node.InsecureForTests {
@ -533,7 +539,7 @@ func (c *Client) tlsClient(nc net.Conn, node *tailcfg.DERPNode) *tls.Conn {
tlsdial.SetConfigExpectedCert(tlsConf, node.CertName) tlsdial.SetConfigExpectedCert(tlsConf, node.CertName)
} }
} }
return tls.Client(nc, tlsConf) return tlsConf
} }
// DialRegionTLS returns a TLS connection to a DERP node in the given region. // DialRegionTLS returns a TLS connection to a DERP node in the given region.
@ -549,7 +555,7 @@ func (c *Client) DialRegionTLS(ctx context.Context, reg *tailcfg.DERPRegion) (tl
done := make(chan bool) // unbuffered done := make(chan bool) // unbuffered
defer close(done) defer close(done)
tlsConn = c.tlsClient(tcpConn, node) tlsConn = tls.Client(tcpConn, c.tlsConfig(node))
go func() { go func() {
select { select {
case <-done: case <-done:

View File

@ -1,14 +1,16 @@
// Copyright (c) Tailscale Inc & AUTHORS // Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause // SPDX-License-Identifier: BSD-3-Clause
//go:build linux || js //go:build linux || windows || darwin || js
package derphttp package derphttp
import ( import (
"context" "context"
"crypto/tls"
"log" "log"
"net" "net"
"net/http"
"nhooyr.io/websocket" "nhooyr.io/websocket"
"tailscale.com/net/wsconn" "tailscale.com/net/wsconn"
@ -18,8 +20,13 @@ func init() {
dialWebsocketFunc = dialWebsocket dialWebsocketFunc = dialWebsocket
} }
func dialWebsocket(ctx context.Context, urlStr string) (net.Conn, error) { func dialWebsocket(ctx context.Context, urlStr string, tlsConfig *tls.Config) (net.Conn, error) {
c, res, err := websocket.Dial(ctx, urlStr, &websocket.DialOptions{ c, res, err := websocket.Dial(ctx, urlStr, &websocket.DialOptions{
HTTPClient: &http.Client{
Transport: &http.Transport{
TLSClientConfig: tlsConfig,
},
},
Subprotocols: []string{"derp"}, Subprotocols: []string{"derp"},
}) })
if err != nil { if err != nil {