From d31a4d92e659db8a9afbea7747a69826d552f0d0 Mon Sep 17 00:00:00 2001 From: Kevin Allen Date: Wed, 26 Apr 2023 15:03:48 -0400 Subject: [PATCH] Add custom TLS options Signed-off-by: Kevin Allen kallen@bostondynamics.com --- net/tlsdial/tlsdial.go | 53 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/net/tlsdial/tlsdial.go b/net/tlsdial/tlsdial.go index d571d38a6..33f8255ec 100644 --- a/net/tlsdial/tlsdial.go +++ b/net/tlsdial/tlsdial.go @@ -41,6 +41,17 @@ var debug = envknob.RegisterBool("TS_DEBUG_TLS_DIAL") // Headscale, etc. var tlsdialWarningPrinted sync.Map // map[string]bool +var ( + // rootCAOverride creates environment variable config TS_TLS_DIAL_ROOT_CA which + // will override the certificate authority used to verify the server instead + // of the system default + rootCAOverride = envknob.RegisterString("TS_TLS_DIAL_ROOT_CA") + // serverHostOverride creates environment variable TS_TLS_DIAL_CONNECT_TO which + // will override the server name the certificate is validated against AND the SNI + // name presented to the server, which may affect virtual hosts + serverHostOverride = envknob.RegisterString("TS_TLS_DIAL_CONNECT_TO") +) + // Config returns a tls.Config for connecting to a server. // If base is non-nil, it's cloned as the base config before // being configured and returned. @@ -52,6 +63,9 @@ func Config(host string, base *tls.Config) *tls.Config { conf = base.Clone() } conf.ServerName = host + if len(serverHostOverride()) != 0 { + conf.ServerName = serverHostOverride() + } if n := sslKeyLogFile; n != "" { f, err := os.OpenFile(n, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600) @@ -93,6 +107,21 @@ func Config(host string, base *tls.Config) *tls.Config { for _, cert := range cs.PeerCertificates[1:] { opts.Intermediates.AddCert(cert) } + + // Check against user overriden root CA if provided + if overrideRoots() != nil { + opts.Roots = overrideRoots() + _, err := cs.PeerCertificates[0].Verify(opts) + if debug() { + log.Printf("tlsdial(override %q): %v", host, err) + } + if err == nil { + atomic.AddInt32(&counterFallbackOK, 1) + return nil + } + return err + } + _, errSys := cs.PeerCertificates[0].Verify(opts) if debug() { log.Printf("tlsdial(sys %q): %v", host, errSys) @@ -272,3 +301,27 @@ func bakedInRoots() *x509.CertPool { }) return bakedInRootsOnce.p } + +var overrideRootsOnce struct { + sync.Once + p *x509.CertPool +} + +func overrideRoots() *x509.CertPool { + if len(rootCAOverride()) == 0 { + return nil + } + overrideRootsOnce.Do(func() { + pem, err := os.ReadFile(rootCAOverride()) + if err != nil { + panic(fmt.Sprintf("Error loading custom root CA %s: %v", rootCAOverride(), err)) + } + + p := x509.NewCertPool() + if !p.AppendCertsFromPEM(pem) { + panic(fmt.Sprintf("Invalid PEM in custom root CA %s", rootCAOverride())) + } + overrideRootsOnce.p = p + }) + return overrideRootsOnce.p +}