portlist: Accept Options for NewPoller

This is a follow up on PR #8172 and a breaking change that allows NewPoller to take an options struct.
The issue with the previous PR was that NewPoller immediately initializes the underlying os implementation
and therefore setting IncludeLocalhost as an exported field happened too late and cannot happen early enough.
Using the zero value of Poller was also not an option from outside of the package because we need to set initial
private fields

Fixes #8171

Signed-off-by: Marwan Sulaiman <marwan@tailscale.com>
marwan/polleropts
Marwan Sulaiman 2023-05-24 16:35:00 -04:00
parent e32e5c0d0c
commit aa528bb7bf
3 changed files with 31 additions and 13 deletions

View File

@ -292,7 +292,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
osshare.SetFileSharingEnabled(false, logf) osshare.SetFileSharingEnabled(false, logf)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
portpoll, err := portlist.NewPoller() portpoll, err := portlist.NewPoller(portlist.PollerOptions{})
if err != nil { if err != nil {
logf("skipping portlist: %s", err) logf("skipping portlist: %s", err)
} }

View File

@ -24,11 +24,6 @@ var debugDisablePortlist = envknob.RegisterBool("TS_DEBUG_DISABLE_PORTLIST")
// Poller scans the systems for listening ports periodically and sends // Poller scans the systems for listening ports periodically and sends
// the results to C. // the results to C.
type Poller struct { type Poller struct {
// IncludeLocalhost controls whether services bound to localhost are included.
//
// This field should only be changed before calling Run.
IncludeLocalhost bool
c chan List // unbuffered c chan List // unbuffered
// os, if non-nil, is an OS-specific implementation of the portlist getting // os, if non-nil, is an OS-specific implementation of the portlist getting
@ -50,6 +45,10 @@ type Poller struct {
scratch []Port scratch []Port
prev List // most recent data, not aliasing scratch prev List // most recent data, not aliasing scratch
// caller options fields
includeLocalhost bool
pollInterval time.Duration
} }
// osImpl is the OS-specific implementation of getting the open listening ports. // osImpl is the OS-specific implementation of getting the open listening ports.
@ -71,15 +70,34 @@ var newOSImpl func(includeLocalhost bool) osImpl
var errUnimplemented = errors.New("portlist poller not implemented on " + runtime.GOOS) var errUnimplemented = errors.New("portlist poller not implemented on " + runtime.GOOS)
// PollerOptions for customizing the behavior
// of the Poller. The zero value uses each
// of the options' defaults.
type PollerOptions struct {
// IncludeLocalhost controls whether services bound to localhost are included.
//
// This field should only be changed before calling Run.
IncludeLocalhost bool
// PollInterval sets the interval for checking the underlying OS
// for port updates.
PollInterval time.Duration
}
// NewPoller returns a new portlist Poller. It returns an error // NewPoller returns a new portlist Poller. It returns an error
// if the portlist couldn't be obtained. // if the portlist couldn't be obtained.
func NewPoller() (*Poller, error) { func NewPoller(opts PollerOptions) (*Poller, error) {
if debugDisablePortlist() { if debugDisablePortlist() {
return nil, errors.New("portlist disabled by envknob") return nil, errors.New("portlist disabled by envknob")
} }
if opts.PollInterval == 0 {
opts.PollInterval = pollInterval
}
p := &Poller{ p := &Poller{
c: make(chan List), c: make(chan List),
runDone: make(chan struct{}), runDone: make(chan struct{}),
includeLocalhost: opts.IncludeLocalhost,
pollInterval: opts.PollInterval,
} }
p.closeCtx, p.closeCtxCancel = context.WithCancel(context.Background()) p.closeCtx, p.closeCtxCancel = context.WithCancel(context.Background())
p.osOnce.Do(p.initOSField) p.osOnce.Do(p.initOSField)
@ -105,7 +123,7 @@ func (p *Poller) setPrev(pl List) {
func (p *Poller) initOSField() { func (p *Poller) initOSField() {
if newOSImpl != nil { if newOSImpl != nil {
p.os = newOSImpl(p.IncludeLocalhost) p.os = newOSImpl(p.includeLocalhost)
} }
} }
@ -142,7 +160,7 @@ func (p *Poller) send(ctx context.Context, pl List) (sent bool, err error) {
// //
// Run may only be called once. // Run may only be called once.
func (p *Poller) Run(ctx context.Context) error { func (p *Poller) Run(ctx context.Context) error {
tick := time.NewTicker(pollInterval) tick := time.NewTicker(p.pollInterval)
defer tick.Stop() defer tick.Stop()
return p.runWithTickChan(ctx, tick.C) return p.runWithTickChan(ctx, tick.C)
} }

View File

@ -51,7 +51,7 @@ func TestIgnoreLocallyBoundPorts(t *testing.T) {
func TestChangesOverTime(t *testing.T) { func TestChangesOverTime(t *testing.T) {
var p Poller var p Poller
p.IncludeLocalhost = true p.includeLocalhost = true
get := func(t *testing.T) []Port { get := func(t *testing.T) []Port {
t.Helper() t.Helper()
s, err := p.getList() s, err := p.getList()
@ -176,7 +176,7 @@ func TestEqualLessThan(t *testing.T) {
} }
func TestPoller(t *testing.T) { func TestPoller(t *testing.T) {
p, err := NewPoller() p, err := NewPoller(PollerOptions{})
if err != nil { if err != nil {
t.Skipf("not running test: %v", err) t.Skipf("not running test: %v", err)
} }