diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 36ab52979..eb5355a30 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -293,7 +293,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo ctx, cancel := context.WithCancel(context.Background()) portpoll := new(portlist.Poller) - err = portpoll.Check() + err = portpoll.Init() if err != nil { logf("skipping portlist: %s", err) portpoll = nil @@ -1379,8 +1379,12 @@ func (b *LocalBackend) Start(opts ipn.Options) error { if b.portpoll != nil { b.portpollOnce.Do(func() { - go b.portpoll.Run(b.ctx) - go b.readPoller() + updates, err := b.portpoll.Run(b.ctx) + if err != nil { + b.logf("error running port poller: %v", err) + return + } + go b.readPoller(updates) // Give the poller a second to get results to // prevent it from restarting our map poll @@ -1813,15 +1817,20 @@ func dnsMapsEqual(new, old *netmap.NetworkMap) bool { // readPoller is a goroutine that receives service lists from // b.portpoll and propagates them into the controlclient's HostInfo. -func (b *LocalBackend) readPoller() { +func (b *LocalBackend) readPoller(updates chan portlist.Update) { + defer b.portpoll.Close() n := 0 for { - ports, ok := <-b.portpoll.Updates() + update, ok := <-updates if !ok { return } + if update.Error != nil { + b.logf("error polling os ports: %v", update.Error) + return // preserve all behavior, though we can just continue + } sl := []tailcfg.Service{} - for _, p := range ports { + for _, p := range update.List { s := tailcfg.Service{ Proto: tailcfg.ServiceProto(p.Proto), Port: p.Port, diff --git a/portlist/poller.go b/portlist/poller.go index 6e39bb872..6cf6e1d6c 100644 --- a/portlist/poller.go +++ b/portlist/poller.go @@ -9,6 +9,7 @@ package portlist import ( "context" "errors" + "fmt" "runtime" "sync" "time" @@ -29,9 +30,14 @@ type Poller struct { // This field should only be changed before calling Run. IncludeLocalhost bool - c chan List // unbuffered + // Interval sets the polling interval for probing the underlying + // os for port updates. + Interval time.Duration + + c chan Update // unbuffered initOnce sync.Once // guards init of private fields + initErr error // os, if non-nil, is an OS-specific implementation of the portlist getting // code. When non-nil, it's responsible for getting the complete list of @@ -53,6 +59,19 @@ type Poller struct { prev List // most recent data, not aliasing scratch } +// Update is a container for a portlist update event. +// When Poller polls the underlying OS for an update, +// it either returns a new list of open ports, +// or an error that happened in the process. +// +// Note that it is up to the caller to act upon the error, +// such as closing the Poller. Otherwise, the Poller will continue +// to try and get a list for every interval. +type Update struct { + List List + Error error +} + // osImpl is the OS-specific implementation of getting the open listening ports. type osImpl interface { Close() error @@ -78,37 +97,54 @@ func (p *Poller) setPrev(pl List) { p.prev = slices.Clone(pl) } -// init sets the os implementation if exists. It also sets -// all private fields. All exported methods must call this in a -// Once, otherwise they may panic. -func (p *Poller) init() { - if debugDisablePortlist() { - return - } - if newOSImpl != nil { - p.os = newOSImpl(p.IncludeLocalhost) - } - p.closeCtx, p.closeCtxCancel = context.WithCancel(context.Background()) - p.c = make(chan List) - p.runDone = make(chan struct{}) +// Init is an optional method that makes sure the Poller is enabled +// and the undelrying OS implementation is working properly. +// +// An error returned from Init is non-fatal and means +// that it's been administratively disabled or the underlying +// OS is not implemented. +func (p *Poller) Init() error { + p.initOnce.Do(func() { + p.initErr = p.init() + }) + return p.initErr } -// Updates return the channel that receives port list updates. -// -// The channel is closed when the Poller is closed. -func (p *Poller) Updates() <-chan List { - p.initOnce.Do(p.init) - return p.c +func (p *Poller) init() error { + if debugDisablePortlist() { + return errors.New("portlist disabled by envknob") + } + if newOSImpl == nil { + return errUnimplemented + } + p.os = newOSImpl(p.IncludeLocalhost) + + // Do one initial poll synchronously so we can return an error + // early. + if pl, err := p.getList(); err != nil { + return err + } else { + p.setPrev(pl) + } + + if p.Interval == 0 { + p.Interval = pollInterval + } + + p.closeCtx, p.closeCtxCancel = context.WithCancel(context.Background()) + p.c = make(chan Update) + p.runDone = make(chan struct{}) + + return nil } // Close closes the Poller. // Run will return with a nil error. func (p *Poller) Close() error { - p.initOnce.Do(p.init) - p.closeCtxCancel() if p.os == nil { return nil } + p.closeCtxCancel() <-p.runDone // if caller of Close never called Run, this can hang. if p.os != nil { p.os.Close() @@ -117,14 +153,14 @@ func (p *Poller) Close() error { } // send sends pl to p.c and returns whether it was successfully sent. -func (p *Poller) send(ctx context.Context, pl List) (sent bool, err error) { +func (p *Poller) send(ctx context.Context, pl List, plErr error) (sent bool) { select { - case p.c <- pl: - return true, nil + case p.c <- Update{pl, plErr}: + return true case <-ctx.Done(): - return false, ctx.Err() + return false case <-p.closeCtx.Done(): - return false, nil + return false } } @@ -132,45 +168,26 @@ func (p *Poller) send(ctx context.Context, pl List) (sent bool, err error) { // is done, or the Close is called. // // Run may only be called once. -func (p *Poller) Run(ctx context.Context) error { - tick := time.NewTicker(pollInterval) +func (p *Poller) Run(ctx context.Context) (chan Update, error) { + if p.os == nil { + err := p.Init() + if err != nil { + return nil, fmt.Errorf("error initializing poller: %w", err) + } + } + tick := time.NewTicker(p.Interval) defer tick.Stop() - return p.runWithTickChan(ctx, tick.C) + go p.runWithTickChan(ctx, tick.C) + return p.c, nil } -// Check makes sure that the Poller is enabled and -// the undelrying OS implementation is working properly. -// -// An error returned from Check is non-fatal and means -// that it's been administratively disabled or the underlying -// OS is not implemented. -func (p *Poller) Check() error { - p.initOnce.Do(p.init) - if p.os == nil { - return errUnimplemented - } - // Do one initial poll synchronously so we can return an error - // early. - if pl, err := p.getList(); err != nil { - return err - } else { - p.setPrev(pl) - } - return nil -} - -func (p *Poller) runWithTickChan(ctx context.Context, tickChan <-chan time.Time) error { - p.initOnce.Do(p.init) - if p.os == nil { - return errUnimplemented - } - +func (p *Poller) runWithTickChan(ctx context.Context, tickChan <-chan time.Time) { defer close(p.runDone) defer close(p.c) // Send out the pre-generated initial value. - if sent, err := p.send(ctx, p.prev); !sent { - return err + if sent := p.send(ctx, p.prev, nil); !sent { + return } for { @@ -178,28 +195,27 @@ func (p *Poller) runWithTickChan(ctx context.Context, tickChan <-chan time.Time) case <-tickChan: pl, err := p.getList() if err != nil { - return err + if !p.send(ctx, nil, err) { + return + } + continue } if pl.equal(p.prev) { continue } p.setPrev(pl) - if sent, err := p.send(ctx, p.prev); !sent { - return err + if !p.send(ctx, p.prev, nil) { + return } case <-ctx.Done(): - return ctx.Err() + return case <-p.closeCtx.Done(): - return nil + return } } } func (p *Poller) getList() (List, error) { - if debugDisablePortlist() { - return nil, nil - } - p.initOnce.Do(p.init) var err error p.scratch, err = p.os.AppendListeningPorts(p.scratch[:0]) return p.scratch, err diff --git a/portlist/portlist_test.go b/portlist/portlist_test.go index 1910f1a0f..7f7128b48 100644 --- a/portlist/portlist_test.go +++ b/portlist/portlist_test.go @@ -17,6 +17,7 @@ func TestGetList(t *testing.T) { tstest.ResourceCheck(t) var p Poller + p.os = newOSImpl(false) pl, err := p.getList() if err != nil { t.Fatal(err) @@ -38,6 +39,7 @@ func TestIgnoreLocallyBoundPorts(t *testing.T) { ta := ln.Addr().(*net.TCPAddr) port := ta.Port var p Poller + p.os = newOSImpl(false) pl, err := p.getList() if err != nil { t.Fatal(err) @@ -51,7 +53,7 @@ func TestIgnoreLocallyBoundPorts(t *testing.T) { func TestChangesOverTime(t *testing.T) { var p Poller - p.IncludeLocalhost = true + p.os = newOSImpl(true) get := func(t *testing.T) []Port { t.Helper() s, err := p.getList() @@ -177,7 +179,7 @@ func TestEqualLessThan(t *testing.T) { func TestPoller(t *testing.T) { var p Poller - err := p.Check() + err := p.Init() if err != nil { t.Skipf("not running test: %v", err) } @@ -190,10 +192,14 @@ func TestPoller(t *testing.T) { go func() { defer wg.Done() - for pl := range p.Updates() { + for update := range p.c { + if update.Error != nil { + t.Errorf("error polling ports: %v", err) + return + } // Look at all the pl slice memory to maximize // chance of race detector seeing violations. - for _, v := range pl { + for _, v := range update.List { if v == (Port{}) { // Force use panic("empty port") @@ -209,9 +215,7 @@ func TestPoller(t *testing.T) { tick := make(chan time.Time, 16) go func() { defer wg.Done() - if err := p.runWithTickChan(context.Background(), tick); err != nil { - t.Error("runWithTickChan:", err) - } + p.runWithTickChan(context.Background(), tick) }() for i := 0; i < 10; i++ { ln, err := net.Listen("tcp", ":0")