diff --git a/ipn/ipn_clone.go b/ipn/ipn_clone.go index 34b7bc5a7..8f9bc0567 100644 --- a/ipn/ipn_clone.go +++ b/ipn/ipn_clone.go @@ -24,10 +24,7 @@ func (src *Prefs) Clone() *Prefs { *dst = *src dst.AdvertiseTags = append(src.AdvertiseTags[:0:0], src.AdvertiseTags...) dst.AdvertiseRoutes = append(src.AdvertiseRoutes[:0:0], src.AdvertiseRoutes...) - if dst.Persist != nil { - dst.Persist = new(persist.Persist) - *dst.Persist = *src.Persist - } + dst.Persist = src.Persist.Clone() return dst } diff --git a/ipn/ipn_view.go b/ipn/ipn_view.go index 1cfa0eee9..2209cb0a1 100644 --- a/ipn/ipn_view.go +++ b/ipn/ipn_view.go @@ -87,13 +87,7 @@ func (v PrefsView) NoSNAT() bool { return v.ж.NoSNAT } func (v PrefsView) NetfilterMode() preftype.NetfilterMode { return v.ж.NetfilterMode } func (v PrefsView) OperatorUser() string { return v.ж.OperatorUser } func (v PrefsView) ProfileName() string { return v.ж.ProfileName } -func (v PrefsView) Persist() *persist.Persist { - if v.ж.Persist == nil { - return nil - } - x := *v.ж.Persist - return &x -} +func (v PrefsView) Persist() persist.PersistView { return v.ж.Persist.View() } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _PrefsViewNeedsRegeneration = Prefs(struct { diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index b35c1585a..0e6c5da05 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -517,7 +517,7 @@ func (b *LocalBackend) Shutdown() { } func stripKeysFromPrefs(p ipn.PrefsView) ipn.PrefsView { - if !p.Valid() || p.Persist() == nil { + if !p.Valid() || !p.Persist().Valid() { return p } @@ -816,7 +816,7 @@ func (b *LocalBackend) setClientStatus(st controlclient.Status) { b.mu.Lock() if st.LogoutFinished != nil { - if p := b.pm.CurrentPrefs(); p.Persist() == nil || p.Persist().LoginName == "" { + if p := b.pm.CurrentPrefs(); !p.Persist().Valid() || p.Persist().LoginName() == "" { b.mu.Unlock() return } @@ -1203,7 +1203,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error { if opts.UpdatePrefs != nil { oldPrefs := b.pm.CurrentPrefs() newPrefs := opts.UpdatePrefs.Clone() - newPrefs.Persist = oldPrefs.Persist() + newPrefs.Persist = oldPrefs.Persist().AsStruct() pv := newPrefs.View() if err := b.pm.SetPrefs(pv); err != nil { b.logf("failed to save UpdatePrefs state: %v", err) @@ -1228,7 +1228,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error { b.applyPrefsToHostinfoLocked(hostinfo, prefs) b.setNetMapLocked(nil) - persistv := prefs.Persist() + persistv := prefs.Persist().AsStruct() if persistv == nil { persistv = new(persist.Persist) } @@ -1947,8 +1947,8 @@ func (b *LocalBackend) initMachineKeyLocked() (err error) { } var legacyMachineKey key.MachinePrivate - if p := b.pm.CurrentPrefs().Persist(); p != nil { - legacyMachineKey = p.LegacyFrontendPrivateMachineKey + if p := b.pm.CurrentPrefs().Persist(); p.Valid() { + legacyMachineKey = p.LegacyFrontendPrivateMachineKey() } keyText, err := b.store.ReadState(ipn.MachineKeyStateKey) @@ -2481,7 +2481,7 @@ func (b *LocalBackend) setPrefsLockedOnEntry(caller string, newp *ipn.Prefs) ipn oldp := b.pm.CurrentPrefs() if oldp.Valid() { - newp.Persist = oldp.Persist().Clone() // caller isn't allowed to override this + newp.Persist = oldp.Persist().AsStruct() // caller isn't allowed to override this } // findExitNodeIDLocked returns whether it updated b.prefs, but // everything in this function treats b.prefs as completely new @@ -3338,7 +3338,7 @@ func (b *LocalBackend) hasNodeKey() bool { b.mu.Lock() defer b.mu.Unlock() p := b.pm.CurrentPrefs() - return p.Valid() && p.Persist() != nil && !p.Persist().PrivateNodeKey.IsZero() + return p.Valid() && p.Persist().Valid() && !p.Persist().PrivateNodeKey().IsZero() } // nextState returns the state the backend seems to be in, based on @@ -3926,8 +3926,8 @@ func (b *LocalBackend) SetDNS(ctx context.Context, name, value string) error { b.mu.Lock() cc := b.ccAuto - if prefs := b.pm.CurrentPrefs(); prefs.Valid() { - req.NodeKey = prefs.Persist().PrivateNodeKey.Public() + if prefs := b.pm.CurrentPrefs(); prefs.Valid() && prefs.Persist().Valid() { + req.NodeKey = prefs.Persist().PrivateNodeKey().Public() } b.mu.Unlock() if cc == nil { diff --git a/ipn/ipnlocal/network-lock.go b/ipn/ipnlocal/network-lock.go index 8346ed7cf..10c05b206 100644 --- a/ipn/ipnlocal/network-lock.go +++ b/ipn/ipnlocal/network-lock.go @@ -345,10 +345,10 @@ func (b *LocalBackend) NetworkLockStatus() *ipnstate.NetworkLockStatus { nodeKey *key.NodePublic nlPriv key.NLPrivate ) - if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist() != nil && !p.Persist().PrivateNodeKey.IsZero() { + if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist().Valid() && !p.Persist().PrivateNodeKey().IsZero() { nkp := p.Persist().PublicNodeKey() nodeKey = &nkp - nlPriv = p.Persist().NetworkLockKey + nlPriv = p.Persist().NetworkLockKey() } if nlPriv.IsZero() { @@ -411,9 +411,9 @@ func (b *LocalBackend) NetworkLockInit(keys []tka.Key, disablementValues [][]byt var ourNodeKey key.NodePublic var nlPriv key.NLPrivate b.mu.Lock() - if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist() != nil && !p.Persist().PrivateNodeKey.IsZero() { + if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist().Valid() && !p.Persist().PrivateNodeKey().IsZero() { ourNodeKey = p.Persist().PublicNodeKey() - nlPriv = p.Persist().NetworkLockKey + nlPriv = p.Persist().NetworkLockKey() } b.mu.Unlock() if ourNodeKey.IsZero() || nlPriv.IsZero() { @@ -503,8 +503,8 @@ func (b *LocalBackend) NetworkLockSign(nodeKey key.NodePublic, rotationPublic [] defer b.mu.Unlock() var nlPriv key.NLPrivate - if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist() != nil { - nlPriv = p.Persist().NetworkLockKey + if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist().Valid() { + nlPriv = p.Persist().NetworkLockKey() } if nlPriv.IsZero() { return key.NodePublic{}, tka.NodeKeySignature{}, errMissingNetmap @@ -557,7 +557,7 @@ func (b *LocalBackend) NetworkLockModify(addKeys, removeKeys []tka.Key) (err err defer b.mu.Unlock() var ourNodeKey key.NodePublic - if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist() != nil && !p.Persist().PrivateNodeKey.IsZero() { + if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist().Valid() && !p.Persist().PrivateNodeKey().IsZero() { ourNodeKey = p.Persist().PublicNodeKey() } if ourNodeKey.IsZero() { @@ -568,8 +568,8 @@ func (b *LocalBackend) NetworkLockModify(addKeys, removeKeys []tka.Key) (err err return err } var nlPriv key.NLPrivate - if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist() != nil { - nlPriv = p.Persist().NetworkLockKey + if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist().Valid() { + nlPriv = p.Persist().NetworkLockKey() } if nlPriv.IsZero() { return errMissingNetmap @@ -634,7 +634,7 @@ func (b *LocalBackend) NetworkLockDisable(secret []byte) error { ) b.mu.Lock() - if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist() != nil && !p.Persist().PrivateNodeKey.IsZero() { + if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist().Valid() && !p.Persist().PrivateNodeKey().IsZero() { ourNodeKey = p.Persist().PublicNodeKey() } if b.tka == nil { diff --git a/ipn/ipnlocal/profiles.go b/ipn/ipnlocal/profiles.go index cd6d64168..917ebf027 100644 --- a/ipn/ipnlocal/profiles.go +++ b/ipn/ipnlocal/profiles.go @@ -179,7 +179,7 @@ func init() { // provided prefs, which may be accessed via CurrentPrefs. func (pm *profileManager) SetPrefs(prefsIn ipn.PrefsView) error { prefs := prefsIn.AsStruct().View() - newPersist := prefs.Persist() + newPersist := prefs.Persist().AsStruct() if newPersist == nil || newPersist.LoginName == "" { return pm.setPrefsLocked(prefs) } diff --git a/ipn/ipnlocal/state_test.go b/ipn/ipnlocal/state_test.go index 3ff1fcf72..a66ec46f0 100644 --- a/ipn/ipnlocal/state_test.go +++ b/ipn/ipnlocal/state_test.go @@ -489,7 +489,7 @@ func TestStateMachine(t *testing.T) { c.Assert(nn[0].LoginFinished, qt.IsNotNil) c.Assert(nn[1].Prefs, qt.IsNotNil) c.Assert(nn[2].State, qt.IsNotNil) - c.Assert(nn[1].Prefs.Persist().LoginName, qt.Equals, "user1") + c.Assert(nn[1].Prefs.Persist().LoginName(), qt.Equals, "user1") c.Assert(ipn.NeedsMachineAuth, qt.Equals, *nn[2].State) c.Assert(ipn.NeedsMachineAuth, qt.Equals, b.State()) } @@ -711,7 +711,7 @@ func TestStateMachine(t *testing.T) { c.Assert(nn[1].Prefs.Persist(), qt.IsNotNil) c.Assert(nn[2].State, qt.IsNotNil) // Prefs after finishing the login, so LoginName updated. - c.Assert(nn[1].Prefs.Persist().LoginName, qt.Equals, "user2") + c.Assert(nn[1].Prefs.Persist().LoginName(), qt.Equals, "user2") c.Assert(nn[1].Prefs.LoggedOut(), qt.IsFalse) c.Assert(nn[1].Prefs.WantRunning(), qt.IsTrue) c.Assert(ipn.Starting, qt.Equals, *nn[2].State) @@ -852,7 +852,7 @@ func TestStateMachine(t *testing.T) { c.Assert(nn[1].Prefs, qt.IsNotNil) c.Assert(nn[2].State, qt.IsNotNil) // Prefs after finishing the login, so LoginName updated. - c.Assert(nn[1].Prefs.Persist().LoginName, qt.Equals, "user3") + c.Assert(nn[1].Prefs.Persist().LoginName(), qt.Equals, "user3") c.Assert(nn[1].Prefs.LoggedOut(), qt.IsFalse) c.Assert(nn[1].Prefs.WantRunning(), qt.IsTrue) c.Assert(ipn.Starting, qt.Equals, *nn[2].State) @@ -957,7 +957,7 @@ func TestEditPrefsHasNoKeys(t *testing.T) { LegacyFrontendPrivateMachineKey: key.NewMachine(), }, }).View()) - if b.pm.CurrentPrefs().Persist().PrivateNodeKey.IsZero() { + if p := b.pm.CurrentPrefs().Persist(); !p.Valid() || p.PrivateNodeKey().IsZero() { t.Fatalf("PrivateNodeKey not set") } p, err := b.EditPrefs(&ipn.MaskedPrefs{ @@ -973,20 +973,20 @@ func TestEditPrefsHasNoKeys(t *testing.T) { t.Errorf("Hostname = %q; want foo", p.Hostname()) } - if !p.Persist().PrivateNodeKey.IsZero() { - t.Errorf("PrivateNodeKey = %v; want zero", p.Persist().PrivateNodeKey) + if !p.Persist().PrivateNodeKey().IsZero() { + t.Errorf("PrivateNodeKey = %v; want zero", p.Persist().PrivateNodeKey()) } - if !p.Persist().OldPrivateNodeKey.IsZero() { - t.Errorf("OldPrivateNodeKey = %v; want zero", p.Persist().OldPrivateNodeKey) + if !p.Persist().OldPrivateNodeKey().IsZero() { + t.Errorf("OldPrivateNodeKey = %v; want zero", p.Persist().OldPrivateNodeKey()) } - if !p.Persist().LegacyFrontendPrivateMachineKey.IsZero() { - t.Errorf("LegacyFrontendPrivateMachineKey = %v; want zero", p.Persist().LegacyFrontendPrivateMachineKey) + if !p.Persist().LegacyFrontendPrivateMachineKey().IsZero() { + t.Errorf("LegacyFrontendPrivateMachineKey = %v; want zero", p.Persist().LegacyFrontendPrivateMachineKey()) } - if !p.Persist().NetworkLockKey.IsZero() { - t.Errorf("NetworkLockKey= %v; want zero", p.Persist().NetworkLockKey) + if !p.Persist().NetworkLockKey().IsZero() { + t.Errorf("NetworkLockKey= %v; want zero", p.Persist().NetworkLockKey()) } } diff --git a/types/persist/persist.go b/types/persist/persist.go index b128c5f70..b21a5e26d 100644 --- a/types/persist/persist.go +++ b/types/persist/persist.go @@ -7,6 +7,7 @@ package persist import ( "fmt" + "reflect" "tailscale.com/tailcfg" "tailscale.com/types/key" @@ -39,6 +40,12 @@ type Persist struct { UserProfile tailcfg.UserProfile NetworkLockKey key.NLPrivate NodeID tailcfg.StableNodeID + + // DisallowedTKAStateIDs stores the tka.State.StateID values which + // this node will not operate network lock on. This is used to + // prevent bootstrapping TKA onto a key authority which was forcibly + // disabled. + DisallowedTKAStateIDs []string `json:",omitempty"` } // PublicNodeKey returns the public key for the node key. @@ -55,6 +62,13 @@ func (p PersistView) Equals(p2 PersistView) bool { return p.ж.Equals(p2.ж) } +func nilIfEmpty[E any](s []E) []E { + if len(s) == 0 { + return nil + } + return s +} + func (p *Persist) Equals(p2 *Persist) bool { if p == nil && p2 == nil { return true @@ -70,7 +84,8 @@ func (p *Persist) Equals(p2 *Persist) bool { p.LoginName == p2.LoginName && p.UserProfile == p2.UserProfile && p.NetworkLockKey.Equal(p2.NetworkLockKey) && - p.NodeID == p2.NodeID + p.NodeID == p2.NodeID && + reflect.DeepEqual(nilIfEmpty(p.DisallowedTKAStateIDs), nilIfEmpty(p2.DisallowedTKAStateIDs)) } func (p *Persist) Pretty() string { diff --git a/types/persist/persist_clone.go b/types/persist/persist_clone.go index aeb40afe5..82db9c52b 100644 --- a/types/persist/persist_clone.go +++ b/types/persist/persist_clone.go @@ -20,6 +20,7 @@ func (src *Persist) Clone() *Persist { } dst := new(Persist) *dst = *src + dst.DisallowedTKAStateIDs = append(src.DisallowedTKAStateIDs[:0:0], src.DisallowedTKAStateIDs...) return dst } @@ -34,4 +35,5 @@ var _PersistCloneNeedsRegeneration = Persist(struct { UserProfile tailcfg.UserProfile NetworkLockKey key.NLPrivate NodeID tailcfg.StableNodeID + DisallowedTKAStateIDs []string }{}) diff --git a/types/persist/persist_test.go b/types/persist/persist_test.go index 7651fe02a..f4be69304 100644 --- a/types/persist/persist_test.go +++ b/types/persist/persist_test.go @@ -22,7 +22,7 @@ func fieldsOf(t reflect.Type) (fields []string) { } func TestPersistEqual(t *testing.T) { - persistHandles := []string{"LegacyFrontendPrivateMachineKey", "PrivateNodeKey", "OldPrivateNodeKey", "Provider", "LoginName", "UserProfile", "NetworkLockKey", "NodeID"} + persistHandles := []string{"LegacyFrontendPrivateMachineKey", "PrivateNodeKey", "OldPrivateNodeKey", "Provider", "LoginName", "UserProfile", "NetworkLockKey", "NodeID", "DisallowedTKAStateIDs"} if have := fieldsOf(reflect.TypeOf(Persist{})); !reflect.DeepEqual(have, persistHandles) { t.Errorf("Persist.Equal check might be out of sync\nfields: %q\nhandled: %q\n", have, persistHandles) @@ -133,6 +133,21 @@ func TestPersistEqual(t *testing.T) { &Persist{NodeID: "abc"}, false, }, + { + &Persist{DisallowedTKAStateIDs: nil}, + &Persist{DisallowedTKAStateIDs: []string{"0:0"}}, + false, + }, + { + &Persist{DisallowedTKAStateIDs: []string{"0:1"}}, + &Persist{DisallowedTKAStateIDs: []string{"0:1"}}, + true, + }, + { + &Persist{DisallowedTKAStateIDs: []string{}}, + &Persist{DisallowedTKAStateIDs: nil}, + true, + }, } for i, test := range tests { if got := test.a.Equals(test.b); got != test.want { diff --git a/types/persist/persist_view.go b/types/persist/persist_view.go index b961c07c9..15355abf4 100644 --- a/types/persist/persist_view.go +++ b/types/persist/persist_view.go @@ -13,6 +13,7 @@ import ( "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/structs" + "tailscale.com/types/views" ) //go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=Persist @@ -72,6 +73,9 @@ func (v PersistView) LoginName() string { return v.ж.LoginName func (v PersistView) UserProfile() tailcfg.UserProfile { return v.ж.UserProfile } func (v PersistView) NetworkLockKey() key.NLPrivate { return v.ж.NetworkLockKey } func (v PersistView) NodeID() tailcfg.StableNodeID { return v.ж.NodeID } +func (v PersistView) DisallowedTKAStateIDs() views.Slice[string] { + return views.SliceOf(v.ж.DisallowedTKAStateIDs) +} // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _PersistViewNeedsRegeneration = Persist(struct { @@ -84,4 +88,5 @@ var _PersistViewNeedsRegeneration = Persist(struct { UserProfile tailcfg.UserProfile NetworkLockKey key.NLPrivate NodeID tailcfg.StableNodeID + DisallowedTKAStateIDs []string }{})