From 964d723aba1057280399d37e6e3c4ac7ad5a217c Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Thu, 1 Dec 2022 14:39:03 -0800 Subject: [PATCH] ipn/{ipnserver,localapi}: fix InUseOtherUser handling with WatchIPNBus Updates tailscale/corp#8222 Change-Id: I2d6fa6514c7b8d0f89fded35a2d44e7df27e6fb1 Signed-off-by: Brad Fitzpatrick --- ipn/ipnserver/server.go | 90 ++++++++++++++++++++++++++++++------ ipn/ipnserver/server_test.go | 47 +++++++++++++++++++ ipn/localapi/localapi.go | 27 +++++++++++ 3 files changed, 149 insertions(+), 15 deletions(-) create mode 100644 ipn/ipnserver/server_test.go diff --git a/ipn/ipnserver/server.go b/ipn/ipnserver/server.go index b7bfaee2c..06944371c 100644 --- a/ipn/ipnserver/server.go +++ b/ipn/ipnserver/server.go @@ -52,7 +52,8 @@ type Server struct { mu sync.Mutex lastUserID ipn.WindowsUserID // tracks last userid; on change, Reset state for paranoia activeReqs map[*http.Request]*ipnauth.ConnIdentity - backendWaiter set.HandleSet[context.CancelFunc] // values are wake-up funcs of lb waiters + backendWaiter waiterSet // of LocalBackend waiters + zeroReqWaiter waiterSet // of blockUntilZeroConnections waiters } func (s *Server) mustBackend() *ipnlocal.LocalBackend { @@ -63,22 +64,47 @@ func (s *Server) mustBackend() *ipnlocal.LocalBackend { return lb } +// waiterSet is a set of callers waiting on something. Each item (map value) in +// the set is a func that wakes up that waiter's context. The waiter is responsible +// for removing itself from the set when woken up. The (*waiterSet).add method +// returns a cleanup method which does that removal. The caller than defers that +// cleanup. +// +// TODO(bradfitz): this is a generally useful pattern. Move elsewhere? +type waiterSet set.HandleSet[context.CancelFunc] + +// add registers a new waiter in the set. +// It aquires mu to add the waiter, and does so again when cleanup is called to remove it. +// ready is closed when the waiter is ready (or ctx is done). +func (s *waiterSet) add(mu *sync.Mutex, ctx context.Context) (ready <-chan struct{}, cleanup func()) { + ctx, cancel := context.WithCancel(ctx) + hs := (*set.HandleSet[context.CancelFunc])(s) // change method set + mu.Lock() + h := hs.Add(cancel) + mu.Unlock() + return ctx.Done(), func() { + mu.Lock() + delete(*hs, h) + mu.Unlock() + cancel() + } +} + +// wakeAll wakes up all waiters in the set. +func (w waiterSet) wakeAll() { + for _, cancel := range w { + cancel() // they'll remove themselves + } +} + func (s *Server) awaitBackend(ctx context.Context) (_ *ipnlocal.LocalBackend, ok bool) { lb := s.lb.Load() if lb != nil { return lb, true } - ctx, cancel := context.WithCancel(ctx) - defer cancel() - s.mu.Lock() - h := s.backendWaiter.Add(cancel) - s.mu.Unlock() - defer func() { - s.mu.Lock() - delete(s.backendWaiter, h) - s.mu.Unlock() - }() + ready, cleanup := s.backendWaiter.add(&s.mu, ctx) + defer cleanup() // Try again, now that we've registered, in case there was a // race. @@ -87,7 +113,7 @@ func (s *Server) awaitBackend(ctx context.Context) (_ *ipnlocal.LocalBackend, ok return lb, true } - <-ctx.Done() + <-ready lb = s.lb.Load() return lb, lb != nil } @@ -160,6 +186,11 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) { onDone, err := s.addActiveHTTPRequest(r, ci) if err != nil { + if ou, ok := err.(inUseOtherUserError); ok && localapi.InUseOtherUserIPNStream(w, r, ou.Unwrap()) { + w.(http.Flusher).Flush() + s.blockWhileIdentityInUse(ctx, ci) + return + } http.Error(w, err.Error(), http.StatusUnauthorized) return } @@ -219,6 +250,30 @@ func (s *Server) checkConnIdentityLocked(ci *ipnauth.ConnIdentity) error { return nil } +// blockWhileIdentityInUse blocks while ci can't connect to the server because +// the server is in use by a different user. +// +// This is primarily used for the Windows GUI, to block until one user's done +// controlling the tailscaled process. +func (s *Server) blockWhileIdentityInUse(ctx context.Context, ci *ipnauth.ConnIdentity) error { + inUse := func() bool { + s.mu.Lock() + defer s.mu.Unlock() + _, ok := s.checkConnIdentityLocked(ci).(inUseOtherUserError) + return ok + } + for inUse() { + // Check whenever the connection count drops down to zero. + ready, cleanup := s.zeroReqWaiter.add(&s.mu, ctx) + <-ready + cleanup() + if err := ctx.Err(); err != nil { + return err + } + } + return nil +} + // localAPIPermissions returns the permissions for the given identity accessing // the Tailscale local daemon API. // @@ -340,6 +395,13 @@ func (s *Server) addActiveHTTPRequest(req *http.Request, ci *ipnauth.ConnIdentit lb.ResetForClientDisconnect() } } + + // Wake up callers waiting for the server to be idle: + if remain == 0 { + s.mu.Lock() + s.zeroReqWaiter.wakeAll() + s.mu.Unlock() + } } return onDone, nil @@ -373,9 +435,7 @@ func (s *Server) SetLocalBackend(lb *ipnlocal.LocalBackend) { s.startBackendIfNeeded() s.mu.Lock() - for _, wake := range s.backendWaiter { - wake() // they'll remove themselves when woken - } + s.backendWaiter.wakeAll() s.mu.Unlock() // TODO(bradfitz): send status update to GUI long poller waiter. See diff --git a/ipn/ipnserver/server_test.go b/ipn/ipnserver/server_test.go new file mode 100644 index 000000000..4eba917ab --- /dev/null +++ b/ipn/ipnserver/server_test.go @@ -0,0 +1,47 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ipnserver + +import ( + "context" + "sync" + "testing" +) + +func TestWaiterSet(t *testing.T) { + var s waiterSet + + wantLen := func(want int, when string) { + t.Helper() + if got := len(s); got != want { + t.Errorf("%s: len = %v; want %v", when, got, want) + } + } + wantLen(0, "initial") + var mu sync.Mutex + ctx, cancel := context.WithCancel(context.Background()) + + ready, cleanup := s.add(&mu, ctx) + wantLen(1, "after add") + + select { + case <-ready: + t.Fatal("should not be ready") + default: + } + s.wakeAll() + <-ready + + wantLen(1, "after fire") + cleanup() + wantLen(0, "after cleanup") + + // And again but on an already-expired ctx. + cancel() + ready, cleanup = s.add(&mu, ctx) + <-ready // shouldn't block + cleanup() + wantLen(0, "at end") +} diff --git a/ipn/localapi/localapi.go b/ipn/localapi/localapi.go index b5c5a364a..7e4296d2b 100644 --- a/ipn/localapi/localapi.go +++ b/ipn/localapi/localapi.go @@ -40,6 +40,7 @@ import ( "tailscale.com/tka" "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/types/ptr" "tailscale.com/util/clientmetric" "tailscale.com/util/mak" "tailscale.com/util/strs" @@ -607,6 +608,32 @@ func (h *Handler) serveStatus(w http.ResponseWriter, r *http.Request) { e.Encode(st) } +// InUseOtherUserIPNStream reports whether r is a request for the watch-ipn-bus +// handler. If so, it writes an ipn.Notify InUseOtherUser message to the user +// and returns true. Otherwise it returns false, in which case it doesn't write +// to w. +// +// Unlike the regular watch-ipn-bus handler, this one doesn't block. The caller +// (in ipnserver.Server) provides the blocking until the connection is no longer +// in use. +func InUseOtherUserIPNStream(w http.ResponseWriter, r *http.Request, err error) (handled bool) { + if r.Method != "GET" || r.URL.Path != "/localapi/v0/watch-ipn-bus" { + return false + } + js, err := json.Marshal(&ipn.Notify{ + Version: version.Long, + State: ptr.To(ipn.InUseOtherUser), + ErrMessage: ptr.To(err.Error()), + }) + if err != nil { + return false + } + js = append(js, '\n') + w.Header().Set("Content-Type", "application/json") + w.Write(js) + return true +} + func (h *Handler) serveWatchIPNBus(w http.ResponseWriter, r *http.Request) { if !h.PermitWrite { http.Error(w, "denied", http.StatusForbidden)