diff --git a/ipn/localapi/cert.go b/ipn/localapi/cert.go index 05906331d..b08117562 100644 --- a/ipn/localapi/cert.go +++ b/ipn/localapi/cert.go @@ -31,6 +31,7 @@ import ( "path/filepath" "strconv" "strings" + "sync" "time" "golang.org/x/crypto/acme" @@ -39,6 +40,18 @@ import ( "tailscale.com/types/logger" ) +// Process-wide cache. (A new *Handler is created per connection, +// effectively per request) +var ( + // acmeMu guards all ACME operations, so concurrent requests + // for certs don't slam ACME. The first will go through and + // populate the on-disk cache and the rest should use that. + acmeMu sync.Mutex + + renewMu sync.Mutex // lock order: don't hold acmeMu and renewMu at the same time + lastRenewCheck = map[string]time.Time{} +) + func (h *Handler) certDir() (string, error) { base := paths.DefaultTailscaledStateFile() if base == "" { @@ -65,17 +78,13 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) { return } - suff := strings.TrimPrefix(r.URL.Path, "/localapi/v0/cert/") - if suff == r.URL.Path { + domain := strings.TrimPrefix(r.URL.Path, "/localapi/v0/cert/") + if domain == r.URL.Path { http.Error(w, "internal handler config wired wrong", 500) return } - domain := suff - - mu := &h.certMu - mu.Lock() - defer mu.Unlock() + now := time.Now() logf := logger.WithPrefix(h.logf, fmt.Sprintf("cert(%q): ", domain)) traceACME := func(v interface{}) { if !acmeDebug { @@ -85,22 +94,50 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) { log.Printf("acme %T: %s", v, j) } - pair, err := h.getCertPEM(r.Context(), logf, traceACME, dir, domain, time.Now()) + if pair, ok := h.getCertPEMCached(dir, domain, now); ok { + future := now.AddDate(0, 0, 14) + if h.shouldStartDomainRenewal(dir, domain, future) { + logf("starting async renewal") + // Start renewal in the background. + go h.getCertPEM(context.Background(), logf, traceACME, dir, domain, future) + } + serveKeyPair(w, r, pair) + return + } + + pair, err := h.getCertPEM(r.Context(), logf, traceACME, dir, domain, now) if err != nil { logf("getCertPEM: %v", err) http.Error(w, fmt.Sprint(err), 500) return } + serveKeyPair(w, r, pair) +} +func (h *Handler) shouldStartDomainRenewal(dir, domain string, future time.Time) bool { + renewMu.Lock() + defer renewMu.Unlock() + now := time.Now() + if last, ok := lastRenewCheck[domain]; ok && now.Sub(last) < time.Minute { + // We checked very recently. Don't bother reparsing & + // validating the x509 cert. + return false + } + lastRenewCheck[domain] = now + _, ok := h.getCertPEMCached(dir, domain, future) + return !ok +} + +func serveKeyPair(w http.ResponseWriter, r *http.Request, p *keyPair) { w.Header().Set("Content-Type", "text/plain") switch r.URL.Query().Get("type") { case "", "crt", "cert": - w.Write(pair.certPEM) + w.Write(p.certPEM) case "key": - w.Write(pair.keyPEM) + w.Write(p.keyPEM) case "pair": - w.Write(pair.keyPEM) - w.Write(pair.certPEM) + w.Write(p.keyPEM) + w.Write(p.certPEM) default: http.Error(w, `invalid type; want "cert" (default), "key", or "pair"`, 400) } @@ -112,16 +149,29 @@ type keyPair struct { cached bool } -func (h *Handler) getCertPEM(ctx context.Context, logf logger.Logf, traceACME func(interface{}), dir, domain string, now time.Time) (*keyPair, error) { - keyFile := filepath.Join(dir, domain+".key") - certFile := filepath.Join(dir, domain+".crt") +func keyFile(dir, domain string) string { return filepath.Join(dir, domain+".key") } +func certFile(dir, domain string) string { return filepath.Join(dir, domain+".crt") } - if keyPEM, err := os.ReadFile(keyFile); err == nil { - certPEM, _ := os.ReadFile(certFile) +// getCertPEMCached returns a non-nil keyPair and true if a cached +// keypair for domain exists on disk in dir that is valid at the +// provided now time. +func (h *Handler) getCertPEMCached(dir, domain string, now time.Time) (p *keyPair, ok bool) { + if keyPEM, err := os.ReadFile(keyFile(dir, domain)); err == nil { + certPEM, _ := os.ReadFile(certFile(dir, domain)) if validCertPEM(domain, keyPEM, certPEM, now) { - return &keyPair{certPEM: certPEM, keyPEM: keyPEM, cached: true}, nil + return &keyPair{certPEM: certPEM, keyPEM: keyPEM, cached: true}, true } } + return nil, false +} + +func (h *Handler) getCertPEM(ctx context.Context, logf logger.Logf, traceACME func(interface{}), dir, domain string, now time.Time) (*keyPair, error) { + acmeMu.Lock() + defer acmeMu.Unlock() + + if p, ok := h.getCertPEMCached(dir, domain, now); ok { + return p, nil + } key, err := acmeKey(dir) if err != nil { @@ -238,7 +288,7 @@ func (h *Handler) getCertPEM(ctx context.Context, logf logger.Logf, traceACME fu if err := encodeECDSAKey(&privPEM, certPrivKey); err != nil { return nil, err } - if err := ioutil.WriteFile(keyFile, privPEM.Bytes(), 0600); err != nil { + if err := ioutil.WriteFile(keyFile(dir, domain), privPEM.Bytes(), 0600); err != nil { return nil, err } @@ -259,7 +309,7 @@ func (h *Handler) getCertPEM(ctx context.Context, logf logger.Logf, traceACME fu return nil, err } } - if err := ioutil.WriteFile(certFile, certPEM.Bytes(), 0644); err != nil { + if err := ioutil.WriteFile(certFile(dir, domain), certPEM.Bytes(), 0644); err != nil { return nil, err } diff --git a/ipn/localapi/localapi.go b/ipn/localapi/localapi.go index 76ff67732..6a5116d90 100644 --- a/ipn/localapi/localapi.go +++ b/ipn/localapi/localapi.go @@ -57,12 +57,6 @@ type Handler struct { b *ipnlocal.LocalBackend logf logger.Logf backendLogID string - - // certMu guards all cert/ACME operations, so concurrent - // requests for certs don't slam ACME. The first will go - // through and populate the on-disk cache and the rest should - // use that. - certMu sync.Mutex } func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {