Make it even simpler by removing Init

marwan/portlistrefactor
Marwan Sulaiman 2023-05-24 23:57:09 -04:00
parent 4b807ca95e
commit 9bbb2b0911
3 changed files with 39 additions and 50 deletions

View File

@ -146,9 +146,8 @@ type LocalBackend struct {
backendLogID logid.PublicID
unregisterNetMon func()
unregisterHealthWatch func()
portpoll *portlist.Poller // may be nil
portpollOnce sync.Once // guards starting readPoller
gotPortPollRes chan struct{} // closed upon first readPoller result
portpollOnce sync.Once // guards starting readPoller
gotPortPollRes chan struct{} // closed upon first readPoller result
newDecompressor func() (controlclient.Decompressor, error)
varRoot string // or empty if SetVarRoot never called
logFlushFunc func() // or nil if SetLogFlusher wasn't called
@ -292,12 +291,6 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
osshare.SetFileSharingEnabled(false, logf)
ctx, cancel := context.WithCancel(context.Background())
portpoll := new(portlist.Poller)
err = portpoll.Init()
if err != nil {
logf("skipping portlist: %s", err)
portpoll = nil
}
b := &LocalBackend{
ctx: ctx,
@ -312,7 +305,6 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
pm: pm,
backendLogID: logID,
state: ipn.NoState,
portpoll: portpoll,
em: newExpiryManager(logf),
gotPortPollRes: make(chan struct{}),
loginFlags: loginFlags,
@ -1377,30 +1369,32 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
b.updateFilterLocked(nil, ipn.PrefsView{})
b.mu.Unlock()
if b.portpoll != nil {
b.portpollOnce.Do(func() {
updates, err := b.portpoll.Run(b.ctx)
if err != nil {
b.logf("error running port poller: %v", err)
return
}
go b.readPoller(updates)
b.portpollOnce.Do(func() {
var p portlist.Poller
updates, err := p.Run(b.ctx)
if err != nil {
b.logf("skipping portlist: %s", err)
return
}
go func() {
defer p.Close()
b.readPoller(updates)
}()
// Give the poller a second to get results to
// prevent it from restarting our map poll
// HTTP request (via doSetHostinfoFilterServices >
// cli.SetHostinfo). In practice this is very quick.
t0 := time.Now()
timer := time.NewTimer(time.Second)
select {
case <-b.gotPortPollRes:
b.logf("[v1] got initial portlist info in %v", time.Since(t0).Round(time.Millisecond))
timer.Stop()
case <-timer.C:
b.logf("timeout waiting for initial portlist")
}
})
}
// Give the poller a second to get results to
// prevent it from restarting our map poll
// HTTP request (via doSetHostinfoFilterServices >
// cli.SetHostinfo). In practice this is very quick.
t0 := time.Now()
timer := time.NewTimer(time.Second)
select {
case <-b.gotPortPollRes:
b.logf("[v1] got initial portlist info in %v", time.Since(t0).Round(time.Millisecond))
timer.Stop()
case <-timer.C:
b.logf("timeout waiting for initial portlist")
}
})
discoPublic := b.e.DiscoPublicKey()
@ -1818,16 +1812,11 @@ func dnsMapsEqual(new, old *netmap.NetworkMap) bool {
// readPoller is a goroutine that receives service lists from
// b.portpoll and propagates them into the controlclient's HostInfo.
func (b *LocalBackend) readPoller(updates chan portlist.Update) {
defer b.portpoll.Close()
n := 0
for {
update, ok := <-updates
if !ok {
return
}
firstResults := true
for update := range updates {
if update.Error != nil {
b.logf("error polling os ports: %v", update.Error)
return // preserve all behavior, though we can just continue
return // preserve old behavior, though we can just continue and try again?
}
sl := []tailcfg.Service{}
for _, p := range update.List {
@ -1851,8 +1840,8 @@ func (b *LocalBackend) readPoller(updates chan portlist.Update) {
b.doSetHostinfoFilterServices(hi)
n++
if n == 1 {
if firstResults {
firstResults = false
close(b.gotPortPollRes)
}
}

View File

@ -97,20 +97,20 @@ func (p *Poller) setPrev(pl List) {
p.prev = slices.Clone(pl)
}
// Init is an optional method that makes sure the Poller is enabled
// init makes sure the Poller is enabled
// and the undelrying OS implementation is working properly.
//
// An error returned from Init is non-fatal and means
// An error returned from init is non-fatal and means
// that it's been administratively disabled or the underlying
// OS is not implemented.
func (p *Poller) Init() error {
func (p *Poller) init() error {
p.initOnce.Do(func() {
p.initErr = p.init()
p.initErr = p.initWithErr()
})
return p.initErr
}
func (p *Poller) init() error {
func (p *Poller) initWithErr() error {
if debugDisablePortlist() {
return errors.New("portlist disabled by envknob")
}
@ -170,7 +170,7 @@ func (p *Poller) send(ctx context.Context, pl List, plErr error) (sent bool) {
// Run may only be called once.
func (p *Poller) Run(ctx context.Context) (chan Update, error) {
if p.os == nil {
err := p.Init()
err := p.init()
if err != nil {
return nil, fmt.Errorf("error initializing poller: %w", err)
}

View File

@ -179,7 +179,7 @@ func TestEqualLessThan(t *testing.T) {
func TestPoller(t *testing.T) {
var p Poller
err := p.Init()
err := p.init()
if err != nil {
t.Skipf("not running test: %v", err)
}