Compare commits

...

1 Commits

Author SHA1 Message Date
Andrew Dunham ff989a9a3a control/controlclient: cache control key
This can allow us to continue to communicate with control when moving to
a network that does TLS MiTM, which would otherwise prevent us from
being able to fetch the Noise key and establish a connection.

Updates #3198 (sorta)

Change-Id: I52caf5079de744874a2bdd0c9ffb9e8f087ff8e0
Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
2023-03-07 19:31:23 -05:00
8 changed files with 460 additions and 34 deletions

View File

@ -286,6 +286,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
tailscale.com/types/structs from tailscale.com/control/controlclient+
tailscale.com/types/tkatype from tailscale.com/tka+
tailscale.com/types/views from tailscale.com/ipn/ipnlocal+
tailscale.com/util/cache from tailscale.com/control/controlclient+
tailscale.com/util/clientmetric from tailscale.com/control/controlclient+
tailscale.com/util/cloudenv from tailscale.com/net/dns/resolver+
LW tailscale.com/util/cmpver from tailscale.com/net/dns+

View File

@ -49,6 +49,7 @@ import (
"tailscale.com/types/opt"
"tailscale.com/types/persist"
"tailscale.com/types/tkatype"
"tailscale.com/util/cache"
"tailscale.com/util/clientmetric"
"tailscale.com/util/multierr"
"tailscale.com/util/singleflight"
@ -80,6 +81,9 @@ type Direct struct {
dialPlan ControlDialPlanner // can be nil
controlKeyMu sync.Mutex // guards controlKeyCache
controlKeyCache cache.Cache[string, *tailcfg.OverTLSPublicKeyResponse]
mu sync.Mutex // mutex guards the following fields
serverKey key.MachinePublic // original ("legacy") nacl crypto_box-based public key
serverNoiseKey key.MachinePublic
@ -143,6 +147,10 @@ type Options struct {
// If we receive a new DialPlan from the server, this value will be
// updated.
DialPlan ControlDialPlanner
// ControlKeyCache caches Noise keys returned from control; if nil, no
// cache will be used.
ControlKeyCache cache.Cache[string, *tailcfg.OverTLSPublicKeyResponse]
}
// ControlDialPlanner is the interface optionally supplied when creating a
@ -227,6 +235,11 @@ func NewDirect(opts Options) (*Direct, error) {
httpc = &http.Client{Transport: tr}
}
ckcache := opts.ControlKeyCache
if ckcache == nil {
ckcache = cache.None[string, *tailcfg.OverTLSPublicKeyResponse]{}
}
c := &Direct{
httpc: httpc,
getMachinePrivKey: opts.GetMachinePrivateKey,
@ -249,6 +262,7 @@ func NewDirect(opts Options) (*Direct, error) {
c2nHandler: opts.C2NHandler,
dialer: opts.Dialer,
dialPlan: opts.DialPlan,
controlKeyCache: ckcache,
}
if opts.Hostinfo == nil {
c.SetHostinfo(hostinfo.New())
@ -455,11 +469,12 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new
c.logf("doLogin(regen=%v, hasUrl=%v)", regen, opt.URL != "")
if serverKey.IsZero() {
keys, err := loadServerPubKeys(ctx, c.httpc, c.serverURL)
keys, cached, err := c.loadServerPubKeys(ctx)
if err != nil {
c.logf("error fetching control server key (serverURL=%q): %v", c.serverURL, err)
return regen, opt.URL, nil, err
}
c.logf("control server key from %s: ts2021=%s, legacy=%v", c.serverURL, keys.PublicKey.ShortString(), keys.LegacyPublicKey.ShortString())
c.logf("control server key from %s: cached=%v ts2021=%s, legacy=%v", c.serverURL, cached, keys.PublicKey.ShortString(), keys.LegacyPublicKey.ShortString())
c.mu.Lock()
c.serverKey = keys.LegacyPublicKey
@ -1225,39 +1240,51 @@ func encode(v any, serverKey, serverNoiseKey key.MachinePublic, mkey key.Machine
return mkey.SealTo(serverKey, b), nil
}
func loadServerPubKeys(ctx context.Context, httpc *http.Client, serverURL string) (*tailcfg.OverTLSPublicKeyResponse, error) {
keyURL := fmt.Sprintf("%v/key?v=%d", serverURL, tailcfg.CurrentCapabilityVersion)
req, err := http.NewRequestWithContext(ctx, "GET", keyURL, nil)
if err != nil {
return nil, fmt.Errorf("create control key request: %v", err)
}
res, err := httpc.Do(req)
if err != nil {
return nil, fmt.Errorf("fetch control key: %v", err)
}
defer res.Body.Close()
b, err := io.ReadAll(io.LimitReader(res.Body, 64<<10))
if err != nil {
return nil, fmt.Errorf("fetch control key response: %v", err)
}
if res.StatusCode != 200 {
return nil, fmt.Errorf("fetch control key: %d", res.StatusCode)
}
var out tailcfg.OverTLSPublicKeyResponse
jsonErr := json.Unmarshal(b, &out)
if jsonErr == nil {
return &out, nil
}
func (c *Direct) loadServerPubKeys(ctx context.Context) (ret *tailcfg.OverTLSPublicKeyResponse, cached bool, err error) {
c.controlKeyMu.Lock()
defer c.controlKeyMu.Unlock()
cached = true
// Some old control servers might not be updated to send the new format.
// Accept the old pre-JSON format too.
out = tailcfg.OverTLSPublicKeyResponse{}
k, err := key.ParseMachinePublicUntyped(mem.B(b))
if err != nil {
return nil, multierr.New(jsonErr, err)
}
out.LegacyPublicKey = k
return &out, nil
keyURL := fmt.Sprintf("%v/key?v=%d", c.serverURL, tailcfg.CurrentCapabilityVersion)
ret, err = c.controlKeyCache.Get(keyURL, func() (*tailcfg.OverTLSPublicKeyResponse, time.Time, error) {
cached = false
req, err := http.NewRequestWithContext(ctx, "GET", keyURL, nil)
if err != nil {
return nil, time.Time{}, fmt.Errorf("create control key request: %v", err)
}
res, err := c.httpc.Do(req)
if err != nil {
return nil, time.Time{}, fmt.Errorf("fetch control key: %v", err)
}
defer res.Body.Close()
b, err := io.ReadAll(io.LimitReader(res.Body, 64<<10))
if err != nil {
return nil, time.Time{}, fmt.Errorf("fetch control key response: %v", err)
}
if res.StatusCode != 200 {
return nil, time.Time{}, fmt.Errorf("fetch control key: %d", res.StatusCode)
}
// Cache keys for one minute at most.
expiry := c.timeNow().Add(1 * time.Minute)
var out tailcfg.OverTLSPublicKeyResponse
jsonErr := json.Unmarshal(b, &out)
if jsonErr == nil {
return &out, expiry, nil
}
// Some old control servers might not be updated to send the new format.
// Accept the old pre-JSON format too.
out = tailcfg.OverTLSPublicKeyResponse{}
k, err := key.ParseMachinePublicUntyped(mem.B(b))
if err != nil {
return nil, time.Time{}, multierr.New(jsonErr, err)
}
out.LegacyPublicKey = k
return &out, expiry, nil
})
return
}
// DevKnob contains temporary internal-only debug knobs.

View File

@ -65,6 +65,7 @@ import (
"tailscale.com/types/preftype"
"tailscale.com/types/ptr"
"tailscale.com/types/views"
"tailscale.com/util/cache"
"tailscale.com/util/deephash"
"tailscale.com/util/dnsname"
"tailscale.com/util/mak"
@ -211,6 +212,7 @@ type LocalBackend struct {
directFileRoot string
directFileDoFinalRename bool // false on macOS, true on several NAS platforms
componentLogUntil map[string]componentLogState
controlKeyCache cache.Cache[string, *tailcfg.OverTLSPublicKeyResponse]
// ServeConfig fields. (also guarded by mu)
lastServeConfJSON mem.RO // last JSON that was parsed into serveConfig
@ -286,6 +288,16 @@ func NewLocalBackend(logf logger.Logf, logid string, store ipn.StateStore, diale
em: newExpiryManager(logf),
gotPortPollRes: make(chan struct{}),
loginFlags: loginFlags,
// TODO(andrew): can we cache this on-disk? If so, we'll need
// to Forget() the value if the tailcfg.CapabilityVersion
// changes between executions, and ensure we're handling
// profile switches/shutdowns.
controlKeyCache: &cache.Memory[string, *tailcfg.OverTLSPublicKeyResponse]{
// If we can't reach the control server, allow returning
// an expired value from the cache.
ServeExpired: true,
},
}
// Default filter blocks everything and logs nothing, until Start() is called.
@ -1382,6 +1394,11 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
C2NHandler: http.HandlerFunc(b.handleC2N),
DialPlan: &b.dialPlan, // pointer because it can't be copied
// Cache control key in-memory; this helps when moving to a
// network that does TLS MiTM, and allows us to continue using
// the previously-cached control key.
ControlKeyCache: b.controlKeyCache,
// Don't warn about broken Linux IP forwarding when
// netstack is being used.
SkipIPForwardingCheck: isNetstack,

155
util/cache/cache_test.go vendored 100644
View File

@ -0,0 +1,155 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package cache
import (
"errors"
"path/filepath"
"testing"
"time"
)
// startTime is the fixed instant that the fake test clocks in this file start
// from (and are reset to between ServeExpired modes).
var startTime = time.Date(2023, time.March, 1, 0, 0, 0, 0, time.UTC)
// TestMemoryCache runs the shared cache test suite against a Memory cache
// driven by a fake clock: first with ServeExpired disabled, then again with it
// enabled after the cache and the clock have been reset.
func TestMemoryCache(t *testing.T) {
	clock := startTime
	mem := &Memory[string, int]{
		timeNow: func() time.Time { return clock },
	}

	t.Run("NoServeExpired", func(t *testing.T) {
		testCacheImpl(t, mem, &clock, false)
	})

	t.Run("ServeExpired", func(t *testing.T) {
		// Drop whatever the previous run cached and rewind the fake
		// clock so this run starts from the same state.
		mem.Forget()
		mem.ServeExpired = true
		clock = startTime

		testCacheImpl(t, mem, &clock, true)
	})
}
// TestDiskCache runs the shared cache test suite against a Disk cache backed
// by a file in a per-test temporary directory and driven by a fake clock:
// first with ServeExpired disabled, then again with it enabled after the cache
// and the clock have been reset.
func TestDiskCache(t *testing.T) {
	clock := startTime

	cachePath := filepath.Join(t.TempDir(), "cache.json")
	disk, err := NewDisk[string, int](cachePath)
	if err != nil {
		t.Fatal(err)
	}
	disk.timeNow = func() time.Time { return clock }

	t.Run("NoServeExpired", func(t *testing.T) {
		testCacheImpl(t, disk, &clock, false)
	})

	t.Run("ServeExpired", func(t *testing.T) {
		// Drop whatever the previous run cached and rewind the fake
		// clock so this run starts from the same state.
		disk.Forget()
		disk.ServeExpired = true
		clock = startTime

		testCacheImpl(t, disk, &clock, true)
	})
}
// testCacheImpl is the implementation-independent test suite shared by the
// cache tests in this file.
//
// c is the cache under test; it must read the current time from *testTime,
// which the subtests advance to move entries in and out of validity.
// serveExpired must match the cache's ServeExpired setting: it selects whether
// a failing fill is expected to return the stale value or the fill error.
//
// The subtests build on each other's state (the same "key" entry is filled,
// re-read, re-filled and then expired), so they must run in order.
func testCacheImpl(t *testing.T, c Cache[string, int], testTime *time.Time, serveExpired bool) {
	// fillTime is the expiry handed to the cache by the most recent
	// successful fill of "key".
	var fillTime time.Time

	t.Run("InitialFill", func(t *testing.T) {
		fillTime = testTime.Add(time.Hour)
		val, err := c.Get("key", func() (int, time.Time, error) {
			return 123, fillTime, nil
		})
		if err != nil {
			t.Fatal(err)
		}
		if val != 123 {
			t.Fatalf("got val=%d; want 123", val)
		}
	})

	// Fetching again won't call our fill function
	t.Run("SecondFetch", func(t *testing.T) {
		// One second before expiry: the entry is still valid.
		*testTime = fillTime.Add(-1 * time.Second)

		called := false
		val, err := c.Get("key", func() (int, time.Time, error) {
			called = true
			return -1, fillTime, nil
		})
		if called {
			t.Fatal("wanted no call to fill function")
		}
		if err != nil {
			t.Fatal(err)
		}
		if val != 123 {
			t.Fatalf("got val=%d; want 123", val)
		}
	})

	// Fetching after the expiry time will re-fill
	t.Run("ReFill", func(t *testing.T) {
		// One nanosecond past expiry.
		*testTime = fillTime.Add(1)
		fillTime = fillTime.Add(time.Hour)

		val, err := c.Get("key", func() (int, time.Time, error) {
			return 999, fillTime, nil
		})
		if err != nil {
			t.Fatal(err)
		}
		if val != 999 {
			t.Fatalf("got val=%d; want 999", val)
		}
	})

	// An error on fetch of an expired entry either serves the expired
	// value (ServeExpired) or is returned to the caller (otherwise).
	t.Run("FetchError", func(t *testing.T) {
		*testTime = fillTime.Add(time.Hour + 1)

		fillErr := errors.New("some error")
		val, err := c.Get("key", func() (int, time.Time, error) {
			return 0, time.Time{}, fillErr
		})

		if !serveExpired {
			// Without ServeExpired the stale value must not be
			// returned; the fill error has to surface instead.
			if !errors.Is(err, fillErr) {
				t.Fatalf("got err=%v; want %v", err, fillErr)
			}
			return
		}

		if err != nil {
			t.Fatal(err)
		}
		if val != 999 {
			t.Fatalf("got val=%d; want 999", val)
		}
	})

	// Fetching a different key re-fills
	t.Run("DifferentKey", func(t *testing.T) {
		*testTime = fillTime.Add(time.Hour + 1)

		var calls int
		val, err := c.Get("key1", func() (int, time.Time, error) {
			calls++
			return 123, fillTime, nil
		})
		if err != nil {
			t.Fatal(err)
		}
		if val != 123 {
			t.Fatalf("got val=%d; want 123", val)
		}
		if calls != 1 {
			t.Errorf("got %d, want 1 call", calls)
		}

		val, err = c.Get("key2", func() (int, time.Time, error) {
			calls++
			return 456, fillTime, nil
		})
		if err != nil {
			t.Fatal(err)
		}
		if val != 456 {
			t.Fatalf("got val=%d; want 456", val)
		}
		if calls != 2 {
			t.Errorf("got %d, want 2 calls", calls)
		}
	})
}

120
util/cache/disk.go vendored 100644
View File

@ -0,0 +1,120 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package cache
import (
"encoding/json"
"os"
"time"
)
// Disk is a cache that stores data in a file on-disk. It also supports
// returning a previously-expired value if refreshing the value in the cache
// fails.
//
// A Disk holds a single entry: filling it for a new key replaces the entry
// for the previous key. Its methods do no locking of their own.
type Disk[K comparable, V any] struct {
	key       K                // key of the currently cached entry
	val       V                // currently cached value
	goodUntil time.Time        // instant before which val is served without re-filling; zero when empty
	path      string           // file the entry is written to on every fill and Forget
	timeNow   func() time.Time // for tests; time.Now is used when nil

	// ServeExpired indicates that if an error occurs when filling the
	// cache, an expired value can be returned instead of an error.
	ServeExpired bool
}
// diskValue is the JSON representation of a Disk cache entry, as written to
// the cache file by Disk.write and read back by NewDisk.
type diskValue[K comparable, V any] struct {
	Key   K         // key the value was cached under
	Value V         // the cached value
	Until time.Time // Always UTC
}
// NewDisk returns a Disk cache that persists its entry to the file at path.
//
// If the file exists and contains a valid entry, the returned cache starts out
// holding that entry (key, value and expiry). A missing file, or one whose
// contents cannot be decoded, yields an empty cache; the file is overwritten
// the next time the cache is filled. Any other error opening the file is
// returned.
func NewDisk[K comparable, V any](path string) (*Disk[K, V], error) {
	d := &Disk[K, V]{path: path}

	file, err := os.Open(path)
	if err != nil {
		if os.IsNotExist(err) {
			// Nothing cached yet; start empty.
			return d, nil
		}
		return nil, err
	}
	defer file.Close()

	var stored diskValue[K, V]
	if json.NewDecoder(file).Decode(&stored) != nil {
		// Unreadable contents are not fatal: start empty and let the
		// next fill replace the file.
		return d, nil
	}

	d.key = stored.Key
	d.val = stored.Value
	d.goodUntil = stored.Until
	return d, nil
}
// Get returns the value cached for key if it has not yet expired. Otherwise it
// calls f to obtain a fresh value; on success the new value and its expiry
// replace the cached entry, are written to the configured path on-disk, and
// the value is returned. Failing to write the file is non-fatal.
//
// If f returns an error and d.ServeExpired is true, a previously cached (but
// expired) value for the same key is returned with no error. In every other
// case the zero value and f's error are returned.
func (d *Disk[K, V]) Get(key K, f FillFunc[V]) (V, error) {
	clock := d.timeNow
	if clock == nil {
		clock = time.Now
	}
	now := clock()

	// Fast path: same key and the entry is still good.
	if d.key == key && d.goodUntil.After(now) {
		return d.val, nil
	}

	// Re-fill cached entry
	fresh, expiry, err := f()
	switch {
	case err == nil:
		d.key, d.val, d.goodUntil = key, fresh, expiry
		d.write()
		return fresh, nil

	// Never serve an expired entry for the wrong key, and never serve
	// from a cache that has no entry at all (zero expiry).
	case d.key == key && d.ServeExpired && !d.goodUntil.IsZero():
		return d.val, nil
	}

	var none V
	return none, err
}
// write persists the current entry (key, value and expiry, with the expiry
// converted to UTC) as JSON to d.path. It is best-effort: encoding and file
// errors are ignored.
func (d *Disk[K, V]) write() {
	entry := diskValue[K, V]{
		Key:   d.key,
		Value: d.val,
		Until: d.goodUntil.UTC(),
	}

	encoded, err := json.Marshal(entry)
	if err != nil {
		// Nothing sensible to write; leave the file as it is.
		return
	}

	// Try writing to the file on-disk, but ignore errors.
	_ = os.WriteFile(d.path, encoded, 0600)
}
// Forget implements Cache. It resets the cached key, value and expiry to
// their zero values and writes that empty entry to the cache file, so the next
// call to Get invokes the fill function.
func (d *Disk[K, V]) Forget() {
	var (
		noKey K
		noVal V
	)
	d.key, d.val, d.goodUntil = noKey, noVal, time.Time{}
	d.write()
}

26
util/cache/interface.go vendored 100644
View File

@ -0,0 +1,26 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package cache contains an interface for a cache around a typed value, and
// various cache implementations that implement that interface.
package cache
import "time"
// Cache is the interface for the cache types in this package. K is the type
// of the lookup key and V the type of the cached value.
type Cache[K comparable, V any] interface {
	// Get should return a previously-cached value or call the provided
	// FillFunc to obtain a new one. The provided key can be used either to
	// allow multiple cached values, or to drop the cache if the key
	// changes; either is valid.
	Get(K, FillFunc[V]) (V, error)

	// Forget should empty the cache such that the next call to Get should
	// call the provided FillFunc.
	Forget()
}
// FillFunc is the signature of a function for filling a cache. On success it
// returns the value to be cached and the time until which that value is
// valid, with a nil error; on failure it returns a non-nil error.
type FillFunc[T any] func() (T, time.Time, error)

64
util/cache/memory.go vendored 100644
View File

@ -0,0 +1,64 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package cache
import "time"
// Memory is a simple in-memory cache that holds a value until its expiry
// time, after which the next Get re-fetches it. It also supports returning a
// previously-expired value if refreshing the value in the cache fails.
//
// A Memory holds a single entry: filling it for a new key replaces the entry
// for the previous key. Its methods do no locking of their own.
type Memory[K comparable, V any] struct {
	key       K                // key of the currently cached entry
	val       V                // currently cached value
	goodUntil time.Time        // instant before which val is served without re-filling; zero when empty
	timeNow   func() time.Time // for tests; time.Now is used when nil

	// ServeExpired indicates that if an error occurs when filling the
	// cache, an expired value can be returned instead of an error.
	ServeExpired bool
}
// Get returns the value cached for key if it has not yet expired. Otherwise it
// calls f to obtain a fresh value; on success the new value and its expiry
// replace the cached entry and the value is returned.
//
// If f returns an error and c.ServeExpired is true, a previously cached (but
// expired) value for the same key is returned with no error. In every other
// case the zero value and f's error are returned.
func (c *Memory[K, V]) Get(key K, f FillFunc[V]) (V, error) {
	clock := c.timeNow
	if clock == nil {
		clock = time.Now
	}
	now := clock()

	// Fast path: same key and the entry is still good.
	if c.key == key && c.goodUntil.After(now) {
		return c.val, nil
	}

	// Re-fill cached entry
	fresh, expiry, err := f()
	switch {
	case err == nil:
		c.key, c.val, c.goodUntil = key, fresh, expiry
		return fresh, nil

	// Never serve an expired entry for the wrong key, and never serve
	// from a cache that has no entry at all (zero expiry).
	case c.key == key && c.ServeExpired && !c.goodUntil.IsZero():
		return c.val, nil
	}

	var none V
	return none, err
}
// Forget implements Cache. It resets the cached key, value and expiry to
// their zero values, so the next call to Get invokes the fill function.
func (c *Memory[K, V]) Forget() {
	var (
		noKey K
		noVal V
	)
	c.key, c.val, c.goodUntil = noKey, noVal, time.Time{}
}

16
util/cache/none.go vendored 100644
View File

@ -0,0 +1,16 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package cache
// None provides no caching and always calls the provided FillFunc. It holds
// no state, so the key type K is only part of the method signatures.
type None[K comparable, V any] struct{}
// Get always calls the provided FillFunc and returns its value and error. The
// key and the expiry time reported by f are ignored, since nothing is stored.
func (c None[K, V]) Get(_ K, f FillFunc[V]) (V, error) {
	val, _, err := f()
	return val, err
}
// Forget implements Cache. It is a no-op because None never stores anything.
func (c None[K, V]) Forget() {}