Compare commits

...

1 Commits

Author SHA1 Message Date
Anton Tolchanov 654247715c prober: allow probes to export gauge metrics
This allows each probe to easily export gauge metrics with a consistent
naming scheme without integrating with expvar directly.

Note that this is a backwards-incompatible change to ProbeFunc. Given the
minimal usage of this module, I don't think this is problematic, but I
can add a separate function here if you would prefer that.

Signed-off-by: Anton Tolchanov <anton@tailscale.com>
2022-10-17 11:12:29 +01:00
6 changed files with 92 additions and 50 deletions

View File

@ -20,15 +20,15 @@ const maxHTTPBody = 4 << 20 // MiB
// response, and verifies that want is present in the response // response, and verifies that want is present in the response
// body. // body.
func HTTP(url, wantText string) ProbeFunc { func HTTP(url, wantText string) ProbeFunc {
return func(ctx context.Context) error { return func(ctx context.Context) (*ProbeResponse, error) {
return probeHTTP(ctx, url, []byte(wantText)) return probeHTTP(ctx, url, []byte(wantText))
} }
} }
func probeHTTP(ctx context.Context, url string, want []byte) error { func probeHTTP(ctx context.Context, url string, want []byte) (*ProbeResponse, error) {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil) req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil { if err != nil {
return fmt.Errorf("constructing request: %w", err) return nil, fmt.Errorf("constructing request: %w", err)
} }
// Get a completely new transport each time, so we don't reuse a // Get a completely new transport each time, so we don't reuse a
@ -41,21 +41,21 @@ func probeHTTP(ctx context.Context, url string, want []byte) error {
resp, err := c.Do(req) resp, err := c.Do(req)
if err != nil { if err != nil {
return fmt.Errorf("fetching %q: %w", url, err) return nil, fmt.Errorf("fetching %q: %w", url, err)
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
return fmt.Errorf("fetching %q: status code %d, want 200", url, resp.StatusCode) return nil, fmt.Errorf("fetching %q: status code %d, want 200", url, resp.StatusCode)
} }
bs, err := io.ReadAll(io.LimitReader(resp.Body, maxHTTPBody)) bs, err := io.ReadAll(io.LimitReader(resp.Body, maxHTTPBody))
if err != nil { if err != nil {
return fmt.Errorf("reading body of %q: %w", url, err) return nil, fmt.Errorf("reading body of %q: %w", url, err)
} }
if !bytes.Contains(bs, want) { if !bytes.Contains(bs, want) {
return fmt.Errorf("body of %q does not contain %q", url, want) return nil, fmt.Errorf("body of %q does not contain %q", url, want)
} }
return nil return nil, nil
} }

View File

@ -24,7 +24,7 @@ import (
// ProbeFunc is a function that probes something and reports whether // ProbeFunc is a function that probes something and reports whether
// the probe succeeded. The provided context's deadline must be obeyed // the probe succeeded. The provided context's deadline must be obeyed
// for correct probe scheduling. // for correct probe scheduling.
type ProbeFunc func(context.Context) error type ProbeFunc func(context.Context) (*ProbeResponse, error)
// a Prober manages a set of probes and keeps track of their results. // a Prober manages a set of probes and keeps track of their results.
type Prober struct { type Prober struct {
@ -97,6 +97,14 @@ func (p *Prober) activeProbes() int {
return len(p.probes) return len(p.probes)
} }
// ProbeResponse holds the data a ProbeFunc can return alongside its
// error.
type ProbeResponse struct {
// Gauges maps a gauge name to its value. WritePrometheus emits each
// entry as a metric named <prefix>_result_<name>, carrying the
// probe's labels.
Gauges map[string]float64
}
// NewResponse returns a new, empty ProbeResponse whose Gauges map is
// already allocated, so callers can assign gauge values to it directly.
func NewResponse() *ProbeResponse {
	resp := new(ProbeResponse)
	resp.Gauges = map[string]float64{}
	return resp
}
// Probe is a probe that healthchecks something and updates Prometheus // Probe is a probe that healthchecks something and updates Prometheus
// metrics with the results. // metrics with the results.
type Probe struct { type Probe struct {
@ -111,10 +119,12 @@ type Probe struct {
tick ticker tick ticker
labels map[string]string labels map[string]string
mu sync.Mutex mu sync.Mutex
start time.Time // last time doProbe started start time.Time // last time doProbe started
end time.Time // last time doProbe returned end time.Time // last time doProbe returned
result bool // whether the last doProbe call succeeded result bool // whether the last doProbe call succeeded
lastErr error
lastResponse *ProbeResponse
} }
// Close shuts down the Probe and unregisters it from its Prober. // Close shuts down the Probe and unregisters it from its Prober.
@ -158,15 +168,15 @@ func (p *Probe) run() {
// alert for debugging. // alert for debugging.
if r := recover(); r != nil { if r := recover(); r != nil {
log.Printf("probe %s panicked: %v", p.name, r) log.Printf("probe %s panicked: %v", p.name, r)
p.recordEnd(start, errors.New("panic")) p.recordEnd(start, errors.New("panic"), nil)
} }
}() }()
timeout := time.Duration(float64(p.interval) * 0.8) timeout := time.Duration(float64(p.interval) * 0.8)
ctx, cancel := context.WithTimeout(p.ctx, timeout) ctx, cancel := context.WithTimeout(p.ctx, timeout)
defer cancel() defer cancel()
err := p.doProbe(ctx) resp, err := p.doProbe(ctx)
p.recordEnd(start, err) p.recordEnd(start, err, resp)
if err != nil { if err != nil {
log.Printf("probe %s: %v", p.name, err) log.Printf("probe %s: %v", p.name, err)
} }
@ -180,12 +190,14 @@ func (p *Probe) recordStart() time.Time {
return st return st
} }
func (p *Probe) recordEnd(start time.Time, err error) { func (p *Probe) recordEnd(start time.Time, err error, resp *ProbeResponse) {
end := p.prober.now() end := p.prober.now()
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()
p.end = end p.end = end
p.result = err == nil p.result = err == nil
p.lastErr = err
p.lastResponse = resp
} }
type varExporter struct { type varExporter struct {
@ -195,11 +207,13 @@ type varExporter struct {
// probeInfo is the state of a Probe. Used in expvar-format debug // probeInfo is the state of a Probe. Used in expvar-format debug
// data. // data.
type probeInfo struct { type probeInfo struct {
Labels map[string]string Labels map[string]string
Start time.Time Start time.Time
End time.Time End time.Time
Latency string // as a string because time.Duration doesn't encode readably to JSON Latency string // as a string because time.Duration doesn't encode readably to JSON
Result bool Result bool
Err string
Response *ProbeResponse
} }
// String implements expvar.Var, returning the prober's state as an // String implements expvar.Var, returning the prober's state as an
@ -217,14 +231,18 @@ func (v varExporter) String() string {
for _, probe := range probes { for _, probe := range probes {
probe.mu.Lock() probe.mu.Lock()
inf := probeInfo{ inf := probeInfo{
Labels: probe.labels, Labels: probe.labels,
Start: probe.start, Start: probe.start,
End: probe.end, End: probe.end,
Result: probe.result, Result: probe.result,
Response: probe.lastResponse,
} }
if probe.end.After(probe.start) { if probe.end.After(probe.start) {
inf.Latency = probe.end.Sub(probe.start).String() inf.Latency = probe.end.Sub(probe.start).String()
} }
if probe.lastErr != nil {
inf.Err = probe.lastErr.Error()
}
out[probe.name] = inf out[probe.name] = inf
probe.mu.Unlock() probe.mu.Unlock()
} }
@ -289,6 +307,11 @@ func (v varExporter) WritePrometheus(w io.Writer, prefix string) {
fmt.Fprintf(w, "%s_result{%s} 0\n", prefix, labels) fmt.Fprintf(w, "%s_result{%s} 0\n", prefix, labels)
} }
} }
if probe.lastResponse != nil {
for n, v := range probe.lastResponse.Gauges {
fmt.Fprintf(w, "%s_result_%s{%s} %f\n", prefix, n, labels, v)
}
}
probe.mu.Unlock() probe.mu.Unlock()
} }
} }

View File

@ -55,9 +55,9 @@ func TestProberTiming(t *testing.T) {
} }
} }
p.Run("test-probe", probeInterval, nil, func(context.Context) error { p.Run("test-probe", probeInterval, nil, func(context.Context) (*ProbeResponse, error) {
invoked <- struct{}{} invoked <- struct{}{}
return nil return nil, nil
}) })
waitActiveProbes(t, p, 1) waitActiveProbes(t, p, 1)
@ -87,11 +87,11 @@ func TestProberRun(t *testing.T) {
var probes []*Probe var probes []*Probe
for i := 0; i < startingProbes; i++ { for i := 0; i < startingProbes; i++ {
probes = append(probes, p.Run(fmt.Sprintf("probe%d", i), probeInterval, nil, func(context.Context) error { probes = append(probes, p.Run(fmt.Sprintf("probe%d", i), probeInterval, nil, func(context.Context) (*ProbeResponse, error) {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
cnt++ cnt++
return nil return nil, nil
})) }))
} }
@ -132,12 +132,12 @@ func TestExpvar(t *testing.T) {
p := newForTest(clk.Now, clk.NewTicker) p := newForTest(clk.Now, clk.NewTicker)
var succeed atomic.Bool var succeed atomic.Bool
p.Run("probe", probeInterval, map[string]string{"label": "value"}, func(context.Context) error { p.Run("probe", probeInterval, map[string]string{"label": "value"}, func(context.Context) (*ProbeResponse, error) {
clk.Advance(aFewMillis) clk.Advance(aFewMillis)
if succeed.Load() { if succeed.Load() {
return nil return nil, nil
} }
return errors.New("failing, as instructed by test") return nil, errors.New("failing, as instructed by test")
}) })
waitActiveProbes(t, p, 1) waitActiveProbes(t, p, 1)
@ -170,6 +170,7 @@ func TestExpvar(t *testing.T) {
End: epoch.Add(aFewMillis), End: epoch.Add(aFewMillis),
Latency: aFewMillis.String(), Latency: aFewMillis.String(),
Result: false, Result: false,
Err: "failing, as instructed by test",
}) })
succeed.Store(true) succeed.Store(true)
@ -190,12 +191,12 @@ func TestPrometheus(t *testing.T) {
p := newForTest(clk.Now, clk.NewTicker) p := newForTest(clk.Now, clk.NewTicker)
var succeed atomic.Bool var succeed atomic.Bool
p.Run("testprobe", probeInterval, map[string]string{"label": "value"}, func(context.Context) error { p.Run("testprobe", probeInterval, map[string]string{"label": "value"}, func(context.Context) (*ProbeResponse, error) {
clk.Advance(aFewMillis) clk.Advance(aFewMillis)
if succeed.Load() { if succeed.Load() {
return nil return nil, nil
} }
return errors.New("failing, as instructed by test") return nil, errors.New("failing, as instructed by test")
}) })
waitActiveProbes(t, p, 1) waitActiveProbes(t, p, 1)

View File

@ -14,17 +14,17 @@ import (
// //
// The ProbeFunc reports whether it can successfully connect to addr. // The ProbeFunc reports whether it can successfully connect to addr.
func TCP(addr string) ProbeFunc { func TCP(addr string) ProbeFunc {
return func(ctx context.Context) error { return func(ctx context.Context) (*ProbeResponse, error) {
return probeTCP(ctx, addr) return probeTCP(ctx, addr)
} }
} }
func probeTCP(ctx context.Context, addr string) error { func probeTCP(ctx context.Context, addr string) (*ProbeResponse, error) {
var d net.Dialer var d net.Dialer
conn, err := d.DialContext(ctx, "tcp", addr) conn, err := d.DialContext(ctx, "tcp", addr)
if err != nil { if err != nil {
return fmt.Errorf("dialing %q: %v", addr, err) return nil, fmt.Errorf("dialing %q: %v", addr, err)
} }
conn.Close() conn.Close()
return nil return nil, nil
} }

View File

@ -21,6 +21,7 @@ import (
) )
const expiresSoon = 7 * 24 * time.Hour // 7 days from now const expiresSoon = 7 * 24 * time.Hour // 7 days from now
// earliestExpiration is the gauge name under which validateConnState
// reports the earliest NotAfter time, as a Unix timestamp in seconds,
// among the certificates it iterates over.
const earliestExpiration = "earliest_cert_expiration_secs"
// TLS returns a Probe that healthchecks a TLS endpoint. // TLS returns a Probe that healthchecks a TLS endpoint.
// //
@ -28,21 +29,21 @@ const expiresSoon = 7 * 24 * time.Hour // 7 days from now
// handshake, verifies that the hostname matches the presented certificate, // handshake, verifies that the hostname matches the presented certificate,
// checks certificate validity time and OCSP revocation status. // checks certificate validity time and OCSP revocation status.
func TLS(hostname string) ProbeFunc { func TLS(hostname string) ProbeFunc {
return func(ctx context.Context) error { return func(ctx context.Context) (*ProbeResponse, error) {
return probeTLS(ctx, hostname) return probeTLS(ctx, hostname)
} }
} }
func probeTLS(ctx context.Context, hostname string) error { func probeTLS(ctx context.Context, hostname string) (*ProbeResponse, error) {
host, _, err := net.SplitHostPort(hostname) host, _, err := net.SplitHostPort(hostname)
if err != nil { if err != nil {
return err return nil, err
} }
dialer := &tls.Dialer{Config: &tls.Config{ServerName: host}} dialer := &tls.Dialer{Config: &tls.Config{ServerName: host}}
conn, err := dialer.DialContext(ctx, "tcp", hostname) conn, err := dialer.DialContext(ctx, "tcp", hostname)
if err != nil { if err != nil {
return fmt.Errorf("connecting to %q: %w", hostname, err) return nil, fmt.Errorf("connecting to %q: %w", hostname, err)
} }
defer conn.Close() defer conn.Close()
@ -53,13 +54,15 @@ func probeTLS(ctx context.Context, hostname string) error {
// validateConnState verifies certificate validity time in all certificates // validateConnState verifies certificate validity time in all certificates
// returned by the TLS server and checks OCSP revocation status for the // returned by the TLS server and checks OCSP revocation status for the
// leaf cert. // leaf cert.
func validateConnState(ctx context.Context, cs *tls.ConnectionState) (returnerr error) { func validateConnState(ctx context.Context, cs *tls.ConnectionState) (resp *ProbeResponse, returnerr error) {
var errs []error var errs []error
defer func() { defer func() {
returnerr = multierr.New(errs...) returnerr = multierr.New(errs...)
}() }()
latestAllowedExpiration := time.Now().Add(expiresSoon) latestAllowedExpiration := time.Now().Add(expiresSoon)
resp = NewResponse()
var leafCert *x509.Certificate var leafCert *x509.Certificate
var issuerCert *x509.Certificate var issuerCert *x509.Certificate
var leafAuthorityKeyID string var leafAuthorityKeyID string
@ -68,6 +71,7 @@ func validateConnState(ctx context.Context, cs *tls.ConnectionState) (returnerr
if i == 0 { if i == 0 {
leafCert = cert leafCert = cert
leafAuthorityKeyID = string(cert.AuthorityKeyId) leafAuthorityKeyID = string(cert.AuthorityKeyId)
resp.Gauges[earliestExpiration] = float64(cert.NotAfter.Unix())
} }
if i > 0 { if i > 0 {
if leafAuthorityKeyID == string(cert.SubjectKeyId) { if leafAuthorityKeyID == string(cert.SubjectKeyId) {
@ -90,6 +94,10 @@ func validateConnState(ctx context.Context, cs *tls.ConnectionState) (returnerr
left := cert.NotAfter.Sub(time.Now()) left := cert.NotAfter.Sub(time.Now())
errs = append(errs, fmt.Errorf("one of the certs expires in %v: %v", left, cert.Subject)) errs = append(errs, fmt.Errorf("one of the certs expires in %v: %v", left, cert.Subject))
} }
if float64(cert.NotAfter.Unix()) < resp.Gauges[earliestExpiration] {
resp.Gauges[earliestExpiration] = float64(cert.NotAfter.Unix())
}
} }
if len(leafCert.OCSPServer) == 0 { if len(leafCert.OCSPServer) == 0 {

View File

@ -48,7 +48,7 @@ var issuerCertTpl = x509.Certificate{
Version: 3, Version: 3,
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback}, IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
NotBefore: time.Now().Add(-5 * time.Minute), NotBefore: time.Now().Add(-5 * time.Minute),
NotAfter: time.Now().Add(60 * 24 * time.Hour), NotAfter: time.Now().Add(55 * 24 * time.Hour),
SubjectKeyId: []byte{1, 2, 3, 4, 5}, SubjectKeyId: []byte{1, 2, 3, 4, 5},
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature, KeyUsage: x509.KeyUsageDigitalSignature,
@ -86,7 +86,7 @@ func TestTLSConnection(t *testing.T) {
srv.StartTLS() srv.StartTLS()
defer srv.Close() defer srv.Close()
err = probeTLS(context.Background(), srv.Listener.Addr().String()) _, err = probeTLS(context.Background(), srv.Listener.Addr().String())
// The specific error message here is platform-specific ("certificate is not trusted" // The specific error message here is platform-specific ("certificate is not trusted"
// on macOS and "certificate signed by unknown authority" on Linux), so only check // on macOS and "certificate signed by unknown authority" on Linux), so only check
// that it contains the word 'certificate'. // that it contains the word 'certificate'.
@ -126,11 +126,21 @@ func TestCertExpiration(t *testing.T) {
}, },
} { } {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
cs := &tls.ConnectionState{PeerCertificates: []*x509.Certificate{tt.cert()}} leaf := tt.cert()
err := validateConnState(context.Background(), cs) cs := &tls.ConnectionState{PeerCertificates: []*x509.Certificate{leaf, &issuerCertTpl}}
resp, err := validateConnState(context.Background(), cs)
if err == nil || !strings.Contains(err.Error(), tt.wantErr) { if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
t.Errorf("unexpected error %q; want %q", err, tt.wantErr) t.Errorf("unexpected error %q; want %q", err, tt.wantErr)
} }
wantExpiration := issuerCertTpl.NotAfter.Unix()
if leaf.NotAfter.Unix() < wantExpiration {
wantExpiration = leaf.NotAfter.Unix()
}
if int64(resp.Gauges[earliestExpiration]) != wantExpiration {
t.Errorf("unexpected cert expiration metric: %f; want %d", resp.Gauges[earliestExpiration], wantExpiration)
}
}) })
} }
} }
@ -222,7 +232,7 @@ func TestOCSP(t *testing.T) {
handler.template.SerialNumber = big.NewInt(1337) handler.template.SerialNumber = big.NewInt(1337)
} }
cs := &tls.ConnectionState{PeerCertificates: []*x509.Certificate{parsed, issuerCert}} cs := &tls.ConnectionState{PeerCertificates: []*x509.Certificate{parsed, issuerCert}}
err := validateConnState(context.Background(), cs) _, err := validateConnState(context.Background(), cs)
if err == nil && tt.wantErr == "" { if err == nil && tt.wantErr == "" {
return return