parent
8864112a0c
commit
d31a4d92e6
|
@ -41,6 +41,17 @@ var debug = envknob.RegisterBool("TS_DEBUG_TLS_DIAL")
|
|||
// Headscale, etc.
|
||||
var tlsdialWarningPrinted sync.Map // map[string]bool
|
||||
|
||||
var (
|
||||
// rootCAOverride creates environment variable config TS_TLS_DIAL_ROOT_CA which
|
||||
// will override the certificate authority used to verify the server instead
|
||||
// of the system default
|
||||
rootCAOverride = envknob.RegisterString("TS_TLS_DIAL_ROOT_CA")
|
||||
// serverHostOverride creates environment variable TS_TLS_DIAL_CONNECT_TO which
|
||||
// will override the server name the certificate is validated against AND the SNI
|
||||
// name presented to the server, which may affect virtual hosts
|
||||
serverHostOverride = envknob.RegisterString("TS_TLS_DIAL_CONNECT_TO")
|
||||
)
|
||||
|
||||
// Config returns a tls.Config for connecting to a server.
|
||||
// If base is non-nil, it's cloned as the base config before
|
||||
// being configured and returned.
|
||||
|
@ -52,6 +63,9 @@ func Config(host string, base *tls.Config) *tls.Config {
|
|||
conf = base.Clone()
|
||||
}
|
||||
conf.ServerName = host
|
||||
if len(serverHostOverride()) != 0 {
|
||||
conf.ServerName = serverHostOverride()
|
||||
}
|
||||
|
||||
if n := sslKeyLogFile; n != "" {
|
||||
f, err := os.OpenFile(n, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
|
||||
|
@ -93,6 +107,21 @@ func Config(host string, base *tls.Config) *tls.Config {
|
|||
for _, cert := range cs.PeerCertificates[1:] {
|
||||
opts.Intermediates.AddCert(cert)
|
||||
}
|
||||
|
||||
// Check against user overriden root CA if provided
|
||||
if overrideRoots() != nil {
|
||||
opts.Roots = overrideRoots()
|
||||
_, err := cs.PeerCertificates[0].Verify(opts)
|
||||
if debug() {
|
||||
log.Printf("tlsdial(override %q): %v", host, err)
|
||||
}
|
||||
if err == nil {
|
||||
atomic.AddInt32(&counterFallbackOK, 1)
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
_, errSys := cs.PeerCertificates[0].Verify(opts)
|
||||
if debug() {
|
||||
log.Printf("tlsdial(sys %q): %v", host, errSys)
|
||||
|
@ -272,3 +301,27 @@ func bakedInRoots() *x509.CertPool {
|
|||
})
|
||||
return bakedInRootsOnce.p
|
||||
}
|
||||
|
||||
var overrideRootsOnce struct {
|
||||
sync.Once
|
||||
p *x509.CertPool
|
||||
}
|
||||
|
||||
func overrideRoots() *x509.CertPool {
|
||||
if len(rootCAOverride()) == 0 {
|
||||
return nil
|
||||
}
|
||||
overrideRootsOnce.Do(func() {
|
||||
pem, err := os.ReadFile(rootCAOverride())
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Error loading custom root CA %s: %v", rootCAOverride(), err))
|
||||
}
|
||||
|
||||
p := x509.NewCertPool()
|
||||
if !p.AppendCertsFromPEM(pem) {
|
||||
panic(fmt.Sprintf("Invalid PEM in custom root CA %s", rootCAOverride()))
|
||||
}
|
||||
overrideRootsOnce.p = p
|
||||
})
|
||||
return overrideRootsOnce.p
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue