From 9bbb2b0911f3067920b53811a30145393d6e73b1 Mon Sep 17 00:00:00 2001 From: Marwan Sulaiman Date: Wed, 24 May 2023 23:57:09 -0400 Subject: [PATCH] Make it even simpler by removing Init --- ipn/ipnlocal/local.go | 75 +++++++++++++++++---------------------- portlist/poller.go | 12 +++---- portlist/portlist_test.go | 2 +- 3 files changed, 39 insertions(+), 50 deletions(-) diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index eb5355a30..fe7bd9df3 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -146,9 +146,8 @@ type LocalBackend struct { backendLogID logid.PublicID unregisterNetMon func() unregisterHealthWatch func() - portpoll *portlist.Poller // may be nil - portpollOnce sync.Once // guards starting readPoller - gotPortPollRes chan struct{} // closed upon first readPoller result + portpollOnce sync.Once // guards starting readPoller + gotPortPollRes chan struct{} // closed upon first readPoller result newDecompressor func() (controlclient.Decompressor, error) varRoot string // or empty if SetVarRoot never called logFlushFunc func() // or nil if SetLogFlusher wasn't called @@ -292,12 +291,6 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo osshare.SetFileSharingEnabled(false, logf) ctx, cancel := context.WithCancel(context.Background()) - portpoll := new(portlist.Poller) - err = portpoll.Init() - if err != nil { - logf("skipping portlist: %s", err) - portpoll = nil - } b := &LocalBackend{ ctx: ctx, @@ -312,7 +305,6 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo pm: pm, backendLogID: logID, state: ipn.NoState, - portpoll: portpoll, em: newExpiryManager(logf), gotPortPollRes: make(chan struct{}), loginFlags: loginFlags, @@ -1377,30 +1369,32 @@ func (b *LocalBackend) Start(opts ipn.Options) error { b.updateFilterLocked(nil, ipn.PrefsView{}) b.mu.Unlock() - if b.portpoll != nil { - b.portpollOnce.Do(func() { - updates, err := b.portpoll.Run(b.ctx) - if err != nil { - b.logf("error running port poller: %v", err) - return - } - go b.readPoller(updates) + b.portpollOnce.Do(func() { + var p portlist.Poller + updates, err := p.Run(b.ctx) + if err != nil { + b.logf("skipping portlist: %s", err) + return + } + go func() { + defer p.Close() + b.readPoller(updates) + }() - // Give the poller a second to get results to - // prevent it from restarting our map poll - // HTTP request (via doSetHostinfoFilterServices > - // cli.SetHostinfo). In practice this is very quick. - t0 := time.Now() - timer := time.NewTimer(time.Second) - select { - case <-b.gotPortPollRes: - b.logf("[v1] got initial portlist info in %v", time.Since(t0).Round(time.Millisecond)) - timer.Stop() - case <-timer.C: - b.logf("timeout waiting for initial portlist") - } - }) - } + // Give the poller a second to get results to + // prevent it from restarting our map poll + // HTTP request (via doSetHostinfoFilterServices > + // cli.SetHostinfo). In practice this is very quick. + t0 := time.Now() + timer := time.NewTimer(time.Second) + select { + case <-b.gotPortPollRes: + b.logf("[v1] got initial portlist info in %v", time.Since(t0).Round(time.Millisecond)) + timer.Stop() + case <-timer.C: + b.logf("timeout waiting for initial portlist") + } + }) discoPublic := b.e.DiscoPublicKey() @@ -1818,16 +1812,11 @@ func dnsMapsEqual(new, old *netmap.NetworkMap) bool { // readPoller is a goroutine that receives service lists from // b.portpoll and propagates them into the controlclient's HostInfo. func (b *LocalBackend) readPoller(updates chan portlist.Update) { - defer b.portpoll.Close() - n := 0 - for { - update, ok := <-updates - if !ok { - return - } + firstResults := true + for update := range updates { if update.Error != nil { b.logf("error polling os ports: %v", update.Error) - return // preserve all behavior, though we can just continue + return // preserve old behavior, though we can just continue and try again? } sl := []tailcfg.Service{} for _, p := range update.List { @@ -1851,8 +1840,8 @@ func (b *LocalBackend) readPoller(updates chan portlist.Update) { b.doSetHostinfoFilterServices(hi) - n++ - if n == 1 { + if firstResults { + firstResults = false close(b.gotPortPollRes) } } diff --git a/portlist/poller.go b/portlist/poller.go index 6cf6e1d6c..944560b7f 100644 --- a/portlist/poller.go +++ b/portlist/poller.go @@ -97,20 +97,20 @@ func (p *Poller) setPrev(pl List) { p.prev = slices.Clone(pl) } -// Init is an optional method that makes sure the Poller is enabled +// init makes sure the Poller is enabled // and the undelrying OS implementation is working properly. // -// An error returned from Init is non-fatal and means +// An error returned from init is non-fatal and means // that it's been administratively disabled or the underlying // OS is not implemented. -func (p *Poller) Init() error { +func (p *Poller) init() error { p.initOnce.Do(func() { - p.initErr = p.init() + p.initErr = p.initWithErr() }) return p.initErr } -func (p *Poller) init() error { +func (p *Poller) initWithErr() error { if debugDisablePortlist() { return errors.New("portlist disabled by envknob") } @@ -170,7 +170,7 @@ func (p *Poller) send(ctx context.Context, pl List, plErr error) (sent bool) { // Run may only be called once. func (p *Poller) Run(ctx context.Context) (chan Update, error) { if p.os == nil { - err := p.Init() + err := p.init() if err != nil { return nil, fmt.Errorf("error initializing poller: %w", err) } diff --git a/portlist/portlist_test.go b/portlist/portlist_test.go index 7f7128b48..4e4b966e2 100644 --- a/portlist/portlist_test.go +++ b/portlist/portlist_test.go @@ -179,7 +179,7 @@ func TestEqualLessThan(t *testing.T) { func TestPoller(t *testing.T) { var p Poller - err := p.Init() + err := p.init() if err != nil { t.Skipf("not running test: %v", err) }