diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index e78270ac0..1d9e87003 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -286,6 +286,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/types/structs from tailscale.com/control/controlclient+ tailscale.com/types/tkatype from tailscale.com/tka+ tailscale.com/types/views from tailscale.com/ipn/ipnlocal+ + tailscale.com/util/cache from tailscale.com/control/controlclient+ tailscale.com/util/clientmetric from tailscale.com/control/controlclient+ tailscale.com/util/cloudenv from tailscale.com/net/dns/resolver+ LW tailscale.com/util/cmpver from tailscale.com/net/dns+ diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index aa0e84293..40e557fe8 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -49,6 +49,7 @@ import ( "tailscale.com/types/opt" "tailscale.com/types/persist" "tailscale.com/types/tkatype" + "tailscale.com/util/cache" "tailscale.com/util/clientmetric" "tailscale.com/util/multierr" "tailscale.com/util/singleflight" @@ -80,6 +81,9 @@ type Direct struct { dialPlan ControlDialPlanner // can be nil + controlKeyMu sync.Mutex // guards controlKeyCache + controlKeyCache cache.Cache[string, *tailcfg.OverTLSPublicKeyResponse] + mu sync.Mutex // mutex guards the following fields serverKey key.MachinePublic // original ("legacy") nacl crypto_box-based public key serverNoiseKey key.MachinePublic @@ -143,6 +147,10 @@ type Options struct { // If we receive a new DialPlan from the server, this value will be // updated. DialPlan ControlDialPlanner + + // ControlKeyCache caches Noise keys returned from control; if nil, no + // cache will be used. + ControlKeyCache cache.Cache[string, *tailcfg.OverTLSPublicKeyResponse] } // ControlDialPlanner is the interface optionally supplied when creating a @@ -227,6 +235,11 @@ func NewDirect(opts Options) (*Direct, error) { httpc = &http.Client{Transport: tr} } + ckcache := opts.ControlKeyCache + if ckcache == nil { + ckcache = cache.None[string, *tailcfg.OverTLSPublicKeyResponse]{} + } + c := &Direct{ httpc: httpc, getMachinePrivKey: opts.GetMachinePrivateKey, @@ -249,6 +262,7 @@ func NewDirect(opts Options) (*Direct, error) { c2nHandler: opts.C2NHandler, dialer: opts.Dialer, dialPlan: opts.DialPlan, + controlKeyCache: ckcache, } if opts.Hostinfo == nil { c.SetHostinfo(hostinfo.New()) @@ -455,11 +469,12 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new c.logf("doLogin(regen=%v, hasUrl=%v)", regen, opt.URL != "") if serverKey.IsZero() { - keys, err := loadServerPubKeys(ctx, c.httpc, c.serverURL) + keys, cached, err := c.loadServerPubKeys(ctx) if err != nil { + c.logf("error fetching control server key (serverURL=%q): %v", c.serverURL, err) return regen, opt.URL, nil, err } - c.logf("control server key from %s: ts2021=%s, legacy=%v", c.serverURL, keys.PublicKey.ShortString(), keys.LegacyPublicKey.ShortString()) + c.logf("control server key from %s: cached=%v ts2021=%s, legacy=%v", c.serverURL, cached, keys.PublicKey.ShortString(), keys.LegacyPublicKey.ShortString()) c.mu.Lock() c.serverKey = keys.LegacyPublicKey @@ -1225,39 +1240,51 @@ func encode(v any, serverKey, serverNoiseKey key.MachinePublic, mkey key.Machine return mkey.SealTo(serverKey, b), nil } -func loadServerPubKeys(ctx context.Context, httpc *http.Client, serverURL string) (*tailcfg.OverTLSPublicKeyResponse, error) { - keyURL := fmt.Sprintf("%v/key?v=%d", serverURL, tailcfg.CurrentCapabilityVersion) - req, err := http.NewRequestWithContext(ctx, "GET", keyURL, nil) - if err != nil { - return nil, fmt.Errorf("create control key request: %v", err) - } - res, err := httpc.Do(req) - if err != nil { - return nil, fmt.Errorf("fetch control key: %v", err) - } - defer res.Body.Close() - b, err := io.ReadAll(io.LimitReader(res.Body, 64<<10)) - if err != nil { - return nil, fmt.Errorf("fetch control key response: %v", err) - } - if res.StatusCode != 200 { - return nil, fmt.Errorf("fetch control key: %d", res.StatusCode) - } - var out tailcfg.OverTLSPublicKeyResponse - jsonErr := json.Unmarshal(b, &out) - if jsonErr == nil { - return &out, nil - } +func (c *Direct) loadServerPubKeys(ctx context.Context) (ret *tailcfg.OverTLSPublicKeyResponse, cached bool, err error) { + c.controlKeyMu.Lock() + defer c.controlKeyMu.Unlock() + cached = true - // Some old control servers might not be updated to send the new format. - // Accept the old pre-JSON format too. - out = tailcfg.OverTLSPublicKeyResponse{} - k, err := key.ParseMachinePublicUntyped(mem.B(b)) - if err != nil { - return nil, multierr.New(jsonErr, err) - } - out.LegacyPublicKey = k - return &out, nil + keyURL := fmt.Sprintf("%v/key?v=%d", c.serverURL, tailcfg.CurrentCapabilityVersion) + ret, err = c.controlKeyCache.Get(keyURL, func() (*tailcfg.OverTLSPublicKeyResponse, time.Time, error) { + cached = false + req, err := http.NewRequestWithContext(ctx, "GET", keyURL, nil) + if err != nil { + return nil, time.Time{}, fmt.Errorf("create control key request: %v", err) + } + res, err := c.httpc.Do(req) + if err != nil { + return nil, time.Time{}, fmt.Errorf("fetch control key: %v", err) + } + defer res.Body.Close() + b, err := io.ReadAll(io.LimitReader(res.Body, 64<<10)) + if err != nil { + return nil, time.Time{}, fmt.Errorf("fetch control key response: %v", err) + } + if res.StatusCode != 200 { + return nil, time.Time{}, fmt.Errorf("fetch control key: %d", res.StatusCode) + } + + // Cache keys for one minute at most. + expiry := c.timeNow().Add(1 * time.Minute) + + var out tailcfg.OverTLSPublicKeyResponse + jsonErr := json.Unmarshal(b, &out) + if jsonErr == nil { + return &out, expiry, nil + } + + // Some old control servers might not be updated to send the new format. + // Accept the old pre-JSON format too. + out = tailcfg.OverTLSPublicKeyResponse{} + k, err := key.ParseMachinePublicUntyped(mem.B(b)) + if err != nil { + return nil, time.Time{}, multierr.New(jsonErr, err) + } + out.LegacyPublicKey = k + return &out, expiry, nil + }) + return } // DevKnob contains temporary internal-only debug knobs. diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index c887c3490..ec8195974 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -65,6 +65,7 @@ import ( "tailscale.com/types/preftype" "tailscale.com/types/ptr" "tailscale.com/types/views" + "tailscale.com/util/cache" "tailscale.com/util/deephash" "tailscale.com/util/dnsname" "tailscale.com/util/mak" @@ -211,6 +212,7 @@ type LocalBackend struct { directFileRoot string directFileDoFinalRename bool // false on macOS, true on several NAS platforms componentLogUntil map[string]componentLogState + controlKeyCache cache.Cache[string, *tailcfg.OverTLSPublicKeyResponse] // ServeConfig fields. (also guarded by mu) lastServeConfJSON mem.RO // last JSON that was parsed into serveConfig @@ -286,6 +288,16 @@ func NewLocalBackend(logf logger.Logf, logid string, store ipn.StateStore, diale em: newExpiryManager(logf), gotPortPollRes: make(chan struct{}), loginFlags: loginFlags, + + // TODO(andrew): can we cache this on-disk? If so, we'll need + // to Forget() the value if the tailcfg.CapabilityVersion + // changes between executions, and ensure we're handling + // profile switches/shutdowns. + controlKeyCache: &cache.Memory[string, *tailcfg.OverTLSPublicKeyResponse]{ + // If we can't reach the control server, allow returning + // an expired value from the cache. + ServeExpired: true, + }, } // Default filter blocks everything and logs nothing, until Start() is called. @@ -1382,6 +1394,11 @@ func (b *LocalBackend) Start(opts ipn.Options) error { C2NHandler: http.HandlerFunc(b.handleC2N), DialPlan: &b.dialPlan, // pointer because it can't be copied + // Cache control key in-memory; this helps when moving to a + // network that does TLS MiTM, and allows us to continue using + // the previously-cached control key. + ControlKeyCache: b.controlKeyCache, + // Don't warn about broken Linux IP forwarding when // netstack is being used. SkipIPForwardingCheck: isNetstack, diff --git a/util/cache/cache_test.go b/util/cache/cache_test.go new file mode 100644 index 000000000..31ed265de --- /dev/null +++ b/util/cache/cache_test.go @@ -0,0 +1,155 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cache + +import ( + "errors" + "path/filepath" + "testing" + "time" +) + +var startTime = time.Date(2023, time.March, 1, 0, 0, 0, 0, time.UTC) + +func TestMemoryCache(t *testing.T) { + testTime := startTime + timeNow := func() time.Time { return testTime } + c := &Memory[string, int]{ + timeNow: timeNow, + } + + t.Run("NoServeExpired", func(t *testing.T) { + testCacheImpl(t, c, &testTime, false) + }) + + t.Run("ServeExpired", func(t *testing.T) { + c.Forget() + c.ServeExpired = true + testTime = startTime + testCacheImpl(t, c, &testTime, true) + }) +} + +func TestDiskCache(t *testing.T) { + testTime := startTime + timeNow := func() time.Time { return testTime } + dc, err := NewDisk[string, int](filepath.Join(t.TempDir(), "cache.json")) + if err != nil { + t.Fatal(err) + } + dc.timeNow = timeNow + + t.Run("NoServeExpired", func(t *testing.T) { + testCacheImpl(t, dc, &testTime, false) + }) + + t.Run("ServeExpired", func(t *testing.T) { + dc.Forget() + dc.ServeExpired = true + testTime = startTime + testCacheImpl(t, dc, &testTime, true) + }) +} + +func testCacheImpl(t *testing.T, c Cache[string, int], testTime *time.Time, serveExpired bool) { + var fillTime time.Time + t.Run("InitialFill", func(t *testing.T) { + fillTime = testTime.Add(time.Hour) + val, err := c.Get("key", func() (int, time.Time, error) { + return 123, fillTime, nil + }) + if err != nil { + t.Fatal(err) + } + if val != 123 { + t.Fatalf("got val=%d; want 123", val) + } + }) + + // Fetching again won't call our fill function + t.Run("SecondFetch", func(t *testing.T) { + *testTime = fillTime.Add(-1 * time.Second) + called := false + val, err := c.Get("key", func() (int, time.Time, error) { + called = true + return -1, fillTime, nil + }) + if called { + t.Fatal("wanted no call to fill function") + } + if err != nil { + t.Fatal(err) + } + if val != 123 { + t.Fatalf("got val=%d; want 123", val) + } + }) + + // Fetching after the expiry time will re-fill + t.Run("ReFill", func(t *testing.T) { + *testTime = fillTime.Add(1) + fillTime = fillTime.Add(time.Hour) + val, err := c.Get("key", func() (int, time.Time, error) { + return 999, fillTime, nil + }) + if err != nil { + t.Fatal(err) + } + if val != 999 { + t.Fatalf("got val=%d; want 999", val) + } + }) + + // An error on fetch will serve the expired value. + t.Run("FetchError", func(t *testing.T) { + if !serveExpired { + t.Skipf("not testing ServeExpired") + } + + *testTime = fillTime.Add(time.Hour + 1) + val, err := c.Get("key", func() (int, time.Time, error) { + return 0, time.Time{}, errors.New("some error") + }) + if err != nil { + t.Fatal(err) + } + if val != 999 { + t.Fatalf("got val=%d; want 999", val) + } + }) + + // Fetching a different key re-fills + t.Run("DifferentKey", func(t *testing.T) { + *testTime = fillTime.Add(time.Hour + 1) + + var calls int + val, err := c.Get("key1", func() (int, time.Time, error) { + calls++ + return 123, fillTime, nil + }) + if err != nil { + t.Fatal(err) + } + if val != 123 { + t.Fatalf("got val=%d; want 123", val) + } + if calls != 1 { + t.Errorf("got %d, want 1 call", calls) + } + + val, err = c.Get("key2", func() (int, time.Time, error) { + calls++ + return 456, fillTime, nil + }) + if err != nil { + t.Fatal(err) + } + if val != 456 { + t.Fatalf("got val=%d; want 456", val) + } + if calls != 2 { + t.Errorf("got %d, want 2 call", calls) + } + }) +} diff --git a/util/cache/disk.go b/util/cache/disk.go new file mode 100644 index 000000000..83beafa6e --- /dev/null +++ b/util/cache/disk.go @@ -0,0 +1,120 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cache + +import ( + "encoding/json" + "os" + "time" +) + +// Disk is a cache that stores data in a file on-disk. It also supports +// returning a previously-expired value if refreshing the value in the cache +// fails. +type Disk[K comparable, V any] struct { + key K + val V + goodUntil time.Time + path string + timeNow func() time.Time // for tests + + // ServeExpired indicates that if an error occurs when filling the + // cache, an expired value can be returned instead of an error. + ServeExpired bool +} + +type diskValue[K comparable, V any] struct { + Key K + Value V + Until time.Time // Always UTC +} + +func NewDisk[K comparable, V any](path string) (*Disk[K, V], error) { + f, err := os.Open(path) + if err != nil { + if !os.IsNotExist(err) { + return nil, err + } + + // Ignore "does not exist" errors + return &Disk[K, V]{path: path}, nil + } + defer f.Close() + + var dv diskValue[K, V] + if err := json.NewDecoder(f).Decode(&dv); err != nil { + // Ignore errors; we'll overwrite when filling. + return &Disk[K, V]{path: path}, nil + } + + return &Disk[K, V]{ + key: dv.Key, + val: dv.Value, + goodUntil: dv.Until, + path: path, + }, nil +} + +// Get will return the cached value, if any, or fill the cache by calling f and +// return the corresponding value. When the cache is filled, the value will be +// written to the configured path on-disk, along with the expiry time. Writing +// to the path on-disk is non-fatal. +// +// If f returns an error and c.ServeExpired is true, then a previous expired +// value can be returned with no error. +func (d *Disk[K, V]) Get(key K, f FillFunc[V]) (V, error) { + var now time.Time + if d.timeNow != nil { + now = d.timeNow() + } else { + now = time.Now() + } + + if d.key == key && now.Before(d.goodUntil) { + return d.val, nil + } + + // Re-fill cached entry + val, until, err := f() + if err == nil { + d.key = key + d.val = val + d.goodUntil = until + d.write() + return val, nil + } + + // Never serve an expired entry for the wrong key. + if d.key == key && d.ServeExpired && !d.goodUntil.IsZero() { + return d.val, nil + } + + var zero V + return zero, err +} + +func (d *Disk[K, V]) write() { + // Try writing to the file on-disk, but ignore errors. + b, err := json.Marshal(diskValue[K, V]{ + Key: d.key, + Value: d.val, + Until: d.goodUntil.UTC(), + }) + if err == nil { + os.WriteFile(d.path, b, 0600) + } +} + +// Forget implements Cache. +func (d *Disk[K, V]) Forget() { + d.goodUntil = time.Time{} + + var zeroKey K + d.key = zeroKey + + var zeroVal V + d.val = zeroVal + + d.write() +} diff --git a/util/cache/interface.go b/util/cache/interface.go new file mode 100644 index 000000000..b35cd3160 --- /dev/null +++ b/util/cache/interface.go @@ -0,0 +1,26 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package cache contains an interface for a cache around a typed value, and +// various cache implementations that implement that interface. +package cache + +import "time" + +// Cache is the interface for the cache types in this package. +type Cache[K comparable, V any] interface { + // Get should return a previously-cached value or call the provided + // FillFunc to obtain a new one. The provided key can be used either to + // allow multiple cached values, or to drop the cache if the key + // changes; either is valid. + Get(K, FillFunc[V]) (V, error) + + // Forget should empty the cache such that the next call to Get should + // call the provided FillFunc. + Forget() +} + +// FillFunc is the signature of a function for filling a cache. It should +// return the value to be cached, the time that the cached value is valid +// until, or an error +type FillFunc[T any] func() (T, time.Time, error) diff --git a/util/cache/memory.go b/util/cache/memory.go new file mode 100644 index 000000000..5395f2b1a --- /dev/null +++ b/util/cache/memory.go @@ -0,0 +1,64 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cache + +import "time" + +// Memory is a simple in-memory cache that stores a value until a defined time +// before it is re-fetched. It also supports returning a previously-expired +// value if refreshing the value in the cache fails. +type Memory[K comparable, V any] struct { + key K + val V + goodUntil time.Time + timeNow func() time.Time // for tests + + // ServeExpired indicates that if an error occurs when filling the + // cache, an expired value can be returned instead of an error. + ServeExpired bool +} + +// Get will return the cached value, if any, or fill the cache by calling f and +// return the corresponding value. If f returns an error and c.ServeExpired is +// true, then a previous expired value can be returned with no error. +func (c *Memory[K, V]) Get(key K, f FillFunc[V]) (V, error) { + var now time.Time + if c.timeNow != nil { + now = c.timeNow() + } else { + now = time.Now() + } + + if c.key == key && now.Before(c.goodUntil) { + return c.val, nil + } + + // Re-fill cached entry + val, until, err := f() + if err == nil { + c.key = key + c.val = val + c.goodUntil = until + return val, nil + } + + // Never serve an expired entry for the wrong key. + if c.key == key && c.ServeExpired && !c.goodUntil.IsZero() { + return c.val, nil + } + + var zero V + return zero, err +} + +// Forget implements Cache. +func (c *Memory[K, V]) Forget() { + c.goodUntil = time.Time{} + + var zeroKey K + c.key = zeroKey + + var zeroVal V + c.val = zeroVal +} diff --git a/util/cache/none.go b/util/cache/none.go new file mode 100644 index 000000000..1ec5265b1 --- /dev/null +++ b/util/cache/none.go @@ -0,0 +1,16 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cache + +// None provides no caching and always calls the provided FillFunc. +type None[K comparable, V any] struct{} + +// Get always calls the provided FillFunc and returns what it does. +func (c None[K, V]) Get(_ K, f FillFunc[V]) (V, error) { + v, _, e := f() + return v, e +} + +// Forget implements Cache. +func (c None[K, V]) Forget() {}