diff --git a/ipn/ipnserver/server.go b/ipn/ipnserver/server.go index 4c7f5b561..4ee6a554e 100644 --- a/ipn/ipnserver/server.go +++ b/ipn/ipnserver/server.go @@ -60,6 +60,10 @@ type Options struct { // frontend disconnecting. If true, the server keeps running on // its existing state, and accepts new frontend connections. If // false, the server dumps its state and becomes idle. + // + // To support CLI connections (notably, "tailscale status"), + // the actual definition of "disconnect" is when the + // connection count transitions from 1 to 0. SurviveDisconnects bool // DebugMux, if non-nil, specifies an HTTP ServeMux in which @@ -71,24 +75,84 @@ type Options struct { ErrorMessage string } -func pump(logf logger.Logf, ctx context.Context, bs *ipn.BackendServer, s net.Conn) { - defer logf("Control connection done.") +// server is an IPN backend and its set of 0 or more active connections +// talking to an IPN backend. +type server struct { + resetOnZero bool // call bs.Reset on transition from 1->0 connections - for ctx.Err() == nil && !bs.GotQuit { - msg, err := ipn.ReadMsg(s) + bsMu sync.Mutex // lock order: bsMu, then mu + bs *ipn.BackendServer + + mu sync.Mutex + clients map[net.Conn]bool +} + +func (s *server) serveConn(ctx context.Context, c net.Conn, logf logger.Logf) { + s.addConn(c) + logf("incoming control connection") + defer s.removeAndCloseConn(c) + for ctx.Err() == nil { + msg, err := ipn.ReadMsg(c) if err != nil { - logf("ReadMsg: %v", err) - break + if ctx.Err() == nil { + logf("ReadMsg: %v", err) + } + return } - err = bs.GotCommandMsg(msg) - if err != nil { + s.bsMu.Lock() + if err := s.bs.GotCommandMsg(msg); err != nil { logf("GotCommandMsg: %v", err) - break + } + gotQuit := s.bs.GotQuit + s.bsMu.Unlock() + if gotQuit { + return } } } -func Run(rctx context.Context, logf logger.Logf, logid string, opts Options, e wgengine.Engine) error { +func (s *server) addConn(c net.Conn) { + s.mu.Lock() + defer s.mu.Unlock() + if s.clients == nil { + s.clients = map[net.Conn]bool{} + } + s.clients[c] = true +} + +func (s *server) removeAndCloseConn(c net.Conn) { + s.mu.Lock() + delete(s.clients, c) + remain := len(s.clients) + s.mu.Unlock() + + if remain == 0 && s.resetOnZero { + s.bsMu.Lock() + s.bs.Reset() + s.bsMu.Unlock() + } + c.Close() +} + +func (s *server) stopAll() { + s.mu.Lock() + defer s.mu.Unlock() + for c := range s.clients { + safesocket.ConnCloseRead(c) + safesocket.ConnCloseWrite(c) + } + s.clients = nil +} + +func (s *server) writeToClients(b []byte) { + s.mu.Lock() + defer s.mu.Unlock() + for c := range s.clients { + ipn.WriteMsg(c, b) + } +} + +func Run(ctx context.Context, logf logger.Logf, logid string, opts Options, e wgengine.Engine) error { runDone := make(chan struct{}) defer close(runDone) @@ -97,12 +161,18 @@ func Run(rctx context.Context, logf logger.Logf, logid string, opts Options, e w return fmt.Errorf("safesocket.Listen: %v", err) } - // Go listeners can't take a context, close it instead. + server := &server{ + resetOnZero: !opts.SurviveDisconnects, + } + + // When the context is closed or when we return, whichever is first, close our listner + // and all open connections. go func() { select { - case <-rctx.Done(): + case <-ctx.Done(): case <-runDone: } + server.stopAll() listen.Close() }() logf("Listening on %v", listen.Addr()) @@ -110,11 +180,11 @@ func Run(rctx context.Context, logf logger.Logf, logid string, opts Options, e w bo := backoff.NewBackoff("ipnserver", logf) if opts.ErrorMessage != "" { - for i := 1; rctx.Err() == nil; i++ { + for i := 1; ctx.Err() == nil; i++ { s, err := listen.Accept() if err != nil { logf("%d: Accept: %v", i, err) - bo.BackOff(rctx, err) + bo.BackOff(ctx, err) continue } serverToClient := func(b []byte) { @@ -127,7 +197,7 @@ func Run(rctx context.Context, logf logger.Logf, logid string, opts Options, e w s.Read(make([]byte, 1)) }() } - return rctx.Err() + return ctx.Err() } var store ipn.StateStore @@ -144,6 +214,7 @@ func Run(rctx context.Context, logf logger.Logf, logid string, opts Options, e w if err != nil { return fmt.Errorf("NewLocalBackend: %v", err) } + defer b.Shutdown() b.SetDecompressor(func() (controlclient.Decompressor, error) { return smallzstd.NewDecoder(nil) }) @@ -157,17 +228,10 @@ func Run(rctx context.Context, logf logger.Logf, logid string, opts Options, e w }) } - var s net.Conn - serverToClient := func(b []byte) { - if s != nil { // TODO: racy access to s? - ipn.WriteMsg(s, b) - } - } - - bs := ipn.NewBackendServer(logf, b, serverToClient) + server.bs = ipn.NewBackendServer(logf, b, server.writeToClients) if opts.AutostartStateKey != "" { - bs.GotCommand(&ipn.Command{ + server.bs.GotCommand(&ipn.Command{ Version: version.LONG, Start: &ipn.StartArgs{ Opts: ipn.Options{ @@ -178,54 +242,18 @@ func Run(rctx context.Context, logf logger.Logf, logid string, opts Options, e w }) } - var ( - oldS net.Conn - ctx context.Context - cancel context.CancelFunc - ) - stopAll := func() { - // Currently we only support one client connection at a time. - // Theoretically we could allow multiple clients, by passing - // notifications to all of them and accepting commands from - // any of them, but there doesn't seem to be much need for - // that right now. - if oldS != nil { - cancel() - safesocket.ConnCloseRead(oldS) - safesocket.ConnCloseWrite(oldS) - } - } - - for i := 1; rctx.Err() == nil; i++ { - s, err = listen.Accept() + for i := 1; ctx.Err() == nil; i++ { + c, err := listen.Accept() if err != nil { - logf("%d: Accept: %v", i, err) - bo.BackOff(rctx, err) + if ctx.Err() == nil { + logf("ipnserver: Accept: %v", err) + bo.BackOff(ctx, err) + } continue } - logf("%d: Incoming control connection.", i) - stopAll() - - ctx, cancel = context.WithCancel(rctx) - oldS = s - - go func(ctx context.Context, s net.Conn, i int) { - logf := logger.WithPrefix(logf, fmt.Sprintf("%d: ", i)) - pump(logf, ctx, bs, s) - if !opts.SurviveDisconnects || bs.GotQuit { - bs.Reset() - s.Close() - } - // Quitting not allowed, just keep going. - bs.GotQuit = false - }(ctx, s, i) - - bo.BackOff(ctx, nil) + go server.serveConn(ctx, c, logger.WithPrefix(logf, fmt.Sprintf("ipnserver: conn%d: ", i))) } - stopAll() - - b.Shutdown() - return rctx.Err() + return ctx.Err() } func BabysitProc(ctx context.Context, args []string, logf logger.Logf) {