ssh/tailssh: cache public keys fetched from URLs
Updates #3802 Change-Id: I96715bae02bce6ea19f16b1736d1bbcd7bcf3534 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>pull/4431/head
parent
3ffd88a84a
commit
93221b4535
|
@ -53,10 +53,21 @@ type server struct {
|
||||||
logf logger.Logf
|
logf logger.Logf
|
||||||
tailscaledPath string
|
tailscaledPath string
|
||||||
|
|
||||||
// mu protects activeSessions.
|
pubKeyHTTPClient *http.Client // or nil for http.DefaultClient
|
||||||
|
timeNow func() time.Time // or nil for time.Now
|
||||||
|
|
||||||
|
// mu protects the following
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
activeSessionByH map[string]*sshSession // ssh.SessionID (DH H) => that session
|
activeSessionByH map[string]*sshSession // ssh.SessionID (DH H) => session
|
||||||
activeSessionBySharedID map[string]*sshSession // yyymmddThhmmss-XXXXX => session
|
activeSessionBySharedID map[string]*sshSession // yyymmddThhmmss-XXXXX => session
|
||||||
|
fetchPublicKeysCache map[string]pubKeyCacheEntry // by https URL
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *server) now() time.Time {
|
||||||
|
if srv.timeNow != nil {
|
||||||
|
return srv.timeNow()
|
||||||
|
}
|
||||||
|
return time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -264,7 +275,7 @@ func (srv *server) evaluatePolicy(sshUser string, localAddr, remoteAddr netaddr.
|
||||||
return nil, nil, "", fmt.Errorf("unknown Tailscale identity from src %v", remoteAddr)
|
return nil, nil, "", fmt.Errorf("unknown Tailscale identity from src %v", remoteAddr)
|
||||||
}
|
}
|
||||||
ci := &sshConnInfo{
|
ci := &sshConnInfo{
|
||||||
now: time.Now(),
|
now: srv.now(),
|
||||||
fetchPublicKeysURL: srv.fetchPublicKeysURL,
|
fetchPublicKeysURL: srv.fetchPublicKeysURL,
|
||||||
sshUser: sshUser,
|
sshUser: sshUser,
|
||||||
src: remoteAddr,
|
src: remoteAddr,
|
||||||
|
@ -280,11 +291,58 @@ func (srv *server) evaluatePolicy(sshUser string, localAddr, remoteAddr netaddr.
|
||||||
return a, ci, localUser, nil
|
return a, ci, localUser, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// pubKeyCacheEntry is the cache value for an HTTPS URL of public keys (like
|
||||||
|
// "https://github.com/foo.keys")
|
||||||
|
type pubKeyCacheEntry struct {
|
||||||
|
lines []string
|
||||||
|
etag string // if sent by server
|
||||||
|
at time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
pubKeyCacheDuration = time.Minute // how long to cache non-empty public keys
|
||||||
|
pubKeyCacheEmptyDuration = 15 * time.Second // how long to cache empty responses
|
||||||
|
)
|
||||||
|
|
||||||
|
func (srv *server) fetchPublicKeysURLCached(url string) (ce pubKeyCacheEntry, ok bool) {
|
||||||
|
srv.mu.Lock()
|
||||||
|
defer srv.mu.Unlock()
|
||||||
|
// Mostly don't care about the size of this cache. Clean rarely.
|
||||||
|
if m := srv.fetchPublicKeysCache; len(m) > 50 {
|
||||||
|
tooOld := srv.now().Add(pubKeyCacheDuration * 10)
|
||||||
|
for k, ce := range m {
|
||||||
|
if ce.at.Before(tooOld) {
|
||||||
|
delete(m, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ce, ok = srv.fetchPublicKeysCache[url]
|
||||||
|
if !ok {
|
||||||
|
return ce, false
|
||||||
|
}
|
||||||
|
maxAge := pubKeyCacheDuration
|
||||||
|
if len(ce.lines) == 0 {
|
||||||
|
maxAge = pubKeyCacheEmptyDuration
|
||||||
|
}
|
||||||
|
return ce, srv.now().Sub(ce.at) < maxAge
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *server) pubKeyClient() *http.Client {
|
||||||
|
if srv.pubKeyHTTPClient != nil {
|
||||||
|
return srv.pubKeyHTTPClient
|
||||||
|
}
|
||||||
|
return http.DefaultClient
|
||||||
|
}
|
||||||
|
|
||||||
func (srv *server) fetchPublicKeysURL(url string) ([]string, error) {
|
func (srv *server) fetchPublicKeysURL(url string) ([]string, error) {
|
||||||
if !strings.HasPrefix(url, "https://") {
|
if !strings.HasPrefix(url, "https://") {
|
||||||
return nil, errors.New("invalid URL scheme")
|
return nil, errors.New("invalid URL scheme")
|
||||||
}
|
}
|
||||||
// TODO(bradfitz): add caching
|
|
||||||
|
ce, ok := srv.fetchPublicKeysURLCached(url)
|
||||||
|
if ok {
|
||||||
|
return ce.lines, nil
|
||||||
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
@ -292,16 +350,40 @@ func (srv *server) fetchPublicKeysURL(url string) ([]string, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
res, err := http.DefaultClient.Do(req)
|
if ce.etag != "" {
|
||||||
|
req.Header.Add("If-None-Match", ce.etag)
|
||||||
|
}
|
||||||
|
res, err := srv.pubKeyClient().Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
if res.StatusCode != http.StatusOK {
|
var lines []string
|
||||||
return nil, errors.New(res.Status)
|
var etag string
|
||||||
|
switch res.StatusCode {
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("unexpected status %v", res.Status)
|
||||||
|
srv.logf("fetching public keys from %s: %v", url, err)
|
||||||
|
case http.StatusNotModified:
|
||||||
|
lines = ce.lines
|
||||||
|
etag = ce.etag
|
||||||
|
case http.StatusOK:
|
||||||
|
var all []byte
|
||||||
|
all, err = io.ReadAll(io.LimitReader(res.Body, 4<<10))
|
||||||
|
if s := strings.TrimSpace(string(all)); s != "" {
|
||||||
|
lines = strings.Split(s, "\n")
|
||||||
}
|
}
|
||||||
all, err := io.ReadAll(io.LimitReader(res.Body, 4<<10))
|
etag = res.Header.Get("Etag")
|
||||||
return strings.Split(string(all), "\n"), err
|
}
|
||||||
|
|
||||||
|
srv.mu.Lock()
|
||||||
|
defer srv.mu.Unlock()
|
||||||
|
mapSet(&srv.fetchPublicKeysCache, url, pubKeyCacheEntry{
|
||||||
|
at: srv.now(),
|
||||||
|
lines: lines,
|
||||||
|
etag: etag,
|
||||||
|
})
|
||||||
|
return lines, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleSSH is invoked when a new SSH connection attempt is made.
|
// handleSSH is invoked when a new SSH connection attempt is made.
|
||||||
|
@ -523,26 +605,20 @@ func (srv *server) getSessionForContext(sctx ssh.Context) (ss *sshSession, ok bo
|
||||||
func (srv *server) startSession(ss *sshSession) {
|
func (srv *server) startSession(ss *sshSession) {
|
||||||
srv.mu.Lock()
|
srv.mu.Lock()
|
||||||
defer srv.mu.Unlock()
|
defer srv.mu.Unlock()
|
||||||
if srv.activeSessionByH == nil {
|
|
||||||
srv.activeSessionByH = make(map[string]*sshSession)
|
|
||||||
}
|
|
||||||
if srv.activeSessionBySharedID == nil {
|
|
||||||
srv.activeSessionBySharedID = make(map[string]*sshSession)
|
|
||||||
}
|
|
||||||
if ss.idH == "" {
|
if ss.idH == "" {
|
||||||
panic("empty idH")
|
panic("empty idH")
|
||||||
}
|
}
|
||||||
if _, dup := srv.activeSessionByH[ss.idH]; dup {
|
|
||||||
panic("dup idH")
|
|
||||||
}
|
|
||||||
if ss.sharedID == "" {
|
if ss.sharedID == "" {
|
||||||
panic("empty sharedID")
|
panic("empty sharedID")
|
||||||
}
|
}
|
||||||
|
if _, dup := srv.activeSessionByH[ss.idH]; dup {
|
||||||
|
panic("dup idH")
|
||||||
|
}
|
||||||
if _, dup := srv.activeSessionBySharedID[ss.sharedID]; dup {
|
if _, dup := srv.activeSessionBySharedID[ss.sharedID]; dup {
|
||||||
panic("dup sharedID")
|
panic("dup sharedID")
|
||||||
}
|
}
|
||||||
srv.activeSessionByH[ss.idH] = ss
|
mapSet(&srv.activeSessionByH, ss.idH, ss)
|
||||||
srv.activeSessionBySharedID[ss.sharedID] = ss
|
mapSet(&srv.activeSessionBySharedID, ss.sharedID, ss)
|
||||||
}
|
}
|
||||||
|
|
||||||
// endSession unregisters s from the list of active sessions.
|
// endSession unregisters s from the list of active sessions.
|
||||||
|
@ -1057,3 +1133,11 @@ func envEq(a, b string) bool {
|
||||||
}
|
}
|
||||||
return a == b
|
return a == b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mapSet assigns m[k] = v, making m if necessary.
|
||||||
|
func mapSet[K comparable, V any](m *map[K]V, k K, v V) {
|
||||||
|
if *m == nil {
|
||||||
|
*m = make(map[K]V)
|
||||||
|
}
|
||||||
|
(*m)[k] = v
|
||||||
|
}
|
||||||
|
|
|
@ -9,13 +9,19 @@ package tailssh
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"crypto/sha256"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"os/user"
|
"os/user"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -25,6 +31,7 @@ import (
|
||||||
"tailscale.com/net/tsdial"
|
"tailscale.com/net/tsdial"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/tempfork/gliderlabs/ssh"
|
"tailscale.com/tempfork/gliderlabs/ssh"
|
||||||
|
"tailscale.com/tstest"
|
||||||
"tailscale.com/types/logger"
|
"tailscale.com/types/logger"
|
||||||
"tailscale.com/util/cibuild"
|
"tailscale.com/util/cibuild"
|
||||||
"tailscale.com/util/lineread"
|
"tailscale.com/util/lineread"
|
||||||
|
@ -336,3 +343,63 @@ func parseEnv(out []byte) map[string]string {
|
||||||
})
|
})
|
||||||
return e
|
return e
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPublicKeyFetching(t *testing.T) {
|
||||||
|
var reqsTotal, reqsIfNoneMatchHit, reqsIfNoneMatchMiss int32
|
||||||
|
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
atomic.AddInt32((&reqsTotal), 1)
|
||||||
|
etag := fmt.Sprintf("W/%q", sha256.Sum256([]byte(r.URL.Path)))
|
||||||
|
w.Header().Set("Etag", etag)
|
||||||
|
if v := r.Header.Get("If-None-Match"); v != "" {
|
||||||
|
if v == etag {
|
||||||
|
atomic.AddInt32(&reqsIfNoneMatchHit, 1)
|
||||||
|
w.WriteHeader(304)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
atomic.AddInt32(&reqsIfNoneMatchMiss, 1)
|
||||||
|
}
|
||||||
|
io.WriteString(w, "foo\nbar\n"+string(r.URL.Path)+"\n")
|
||||||
|
}))
|
||||||
|
ts.StartTLS()
|
||||||
|
defer ts.Close()
|
||||||
|
keys := ts.URL
|
||||||
|
|
||||||
|
clock := &tstest.Clock{}
|
||||||
|
srv := &server{
|
||||||
|
pubKeyHTTPClient: ts.Client(),
|
||||||
|
timeNow: clock.Now,
|
||||||
|
}
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
got, err := srv.fetchPublicKeysURL(keys + "/alice.keys")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if want := []string{"foo", "bar", "/alice.keys"}; !reflect.DeepEqual(got, want) {
|
||||||
|
t.Errorf("got %q; want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if got, want := atomic.LoadInt32(&reqsTotal), int32(1); got != want {
|
||||||
|
t.Errorf("got %d requests; want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := atomic.LoadInt32(&reqsIfNoneMatchHit), int32(0); got != want {
|
||||||
|
t.Errorf("got %d etag hits; want %d", got, want)
|
||||||
|
}
|
||||||
|
clock.Advance(5 * time.Minute)
|
||||||
|
got, err := srv.fetchPublicKeysURL(keys + "/alice.keys")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if want := []string{"foo", "bar", "/alice.keys"}; !reflect.DeepEqual(got, want) {
|
||||||
|
t.Errorf("got %q; want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := atomic.LoadInt32(&reqsTotal), int32(2); got != want {
|
||||||
|
t.Errorf("got %d requests; want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := atomic.LoadInt32(&reqsIfNoneMatchHit), int32(1); got != want {
|
||||||
|
t.Errorf("got %d etag hits; want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := atomic.LoadInt32(&reqsIfNoneMatchMiss), int32(0); got != want {
|
||||||
|
t.Errorf("got %d etag misses; want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue