Compare commits

...

1 Commits

Author SHA1 Message Date
Maisem Ali 457d279ace ssh/tailssh: make local port forwarding work w/o shell session
Fixes #5091

Signed-off-by: Maisem Ali <maisem@tailscale.com>
2022-10-06 11:30:03 -07:00
2 changed files with 72 additions and 26 deletions

View File

@ -331,6 +331,17 @@ func (srv *server) newConn() (*conn, error) {
// to the specified host and port.
// TODO(bradfitz/maisem): should we have more checks on host/port?
func (c *conn) mayForwardLocalPortTo(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
c.mu.Lock()
defer c.mu.Unlock()
if c.finalAction == nil {
// We haven't yet authenticated the user, this probably means that they
// are just requesting port forwarding and not a shell.
// We also do not have a reader or writer in this case so we can't
// read "Ctrl-C" or write messages to the user to prompt for check mode.
if _, err := c.authorizeSessionLocked(ctx, nil, nil); err != nil {
return false
}
}
if c.finalAction != nil && c.finalAction.AllowLocalPortForwarding {
metricLocalPortForward.Add(1)
return true
@ -554,31 +565,43 @@ func (srv *server) fetchPublicKeysURL(url string) ([]string, error) {
return lines, err
}
func (c *conn) authorizeSession(s ssh.Session) (_ *contextReader, ok bool) {
// authorizeSession authorizes the SSH session, returning an error if
// authorization fails.
// If a reader is provided, it will be monitored for a "Ctrl+C" sequence.
// If a writer is provided, it will be used to write any messages from
// the authorization process.
func (c *conn) authorizeSession(sctx ssh.Context, reader io.Reader, writer io.Writer) (*contextReader, error) {
c.mu.Lock()
defer c.mu.Unlock()
idH := s.Context().(ssh.Context).SessionID()
return c.authorizeSessionLocked(sctx, reader, writer)
}
// authorizeSessionLocked is like authorizeSession but requires c.mu to be held.
func (c *conn) authorizeSessionLocked(sctx ssh.Context, reader io.Reader, writer io.Writer) (*contextReader, error) {
idH := sctx.SessionID()
if c.idH == "" {
c.idH = idH
} else if c.idH != idH {
c.logf("ssh: session ID mismatch: %q != %q", c.idH, idH)
s.Exit(1)
return nil, false
return nil, fmt.Errorf("session ID mismatch")
}
cr := &contextReader{r: s}
action, err := c.resolveTerminalActionLocked(s, cr)
var cr *contextReader
if reader != nil {
cr = &contextReader{r: reader}
}
action, err := c.resolveTerminalActionLocked(sctx, cr, writer)
if err != nil {
c.logf("resolveTerminalAction: %v", err)
io.WriteString(s.Stderr(), "Access Denied: failed during authorization check.\r\n")
s.Exit(1)
return nil, false
if writer != nil {
io.WriteString(writer, "Access Denied: failed during authorization check.\r\n")
}
return nil, err
}
if action.Reject || !action.Accept {
c.logf("access denied for %v", c.info.uprof.LoginName)
s.Exit(1)
return nil, false
return nil, err
}
return cr, true
return cr, nil
}
// handleSessionPostSSHAuth runs an SSH session after the SSH-level authentication,
@ -588,8 +611,9 @@ func (c *conn) handleSessionPostSSHAuth(s ssh.Session) {
// Now that we have passed the SSH-level authentication, we can start the
// Tailscale-level extra verification. This means that we are going to
// evaluate the policy provided by control against the incoming SSH session.
cr, ok := c.authorizeSession(s)
if !ok {
cr, err := c.authorizeSession(s.Context().(ssh.Context), s, s.Stderr())
if err != nil {
s.Exit(1)
return
}
if cr.HasOutstandingRead() {
@ -624,18 +648,18 @@ func (c *conn) handleSessionPostSSHAuth(s ssh.Session) {
// The returned SSHAction will be either Reject or Accept.
//
// c.mu must be held.
func (c *conn) resolveTerminalActionLocked(s ssh.Session, cr *contextReader) (action *tailcfg.SSHAction, err error) {
func (c *conn) resolveTerminalActionLocked(sctx ssh.Context, cr *contextReader, stderr io.Writer) (action *tailcfg.SSHAction, err error) {
if c.finalAction != nil || c.finalActionErr != nil {
return c.finalAction, c.finalActionErr
}
if s.PublicKey() != nil {
if sctx.PublicKey() != nil {
metricPublicKeyConnections.Add(1)
}
defer func() {
c.finalAction = action
c.finalActionErr = err
c.pubKey = s.PublicKey()
c.pubKey = sctx.PublicKey()
if c.pubKey != nil && action.Accept {
metricPublicKeyAccepts.Add(1)
}
@ -644,10 +668,18 @@ func (c *conn) resolveTerminalActionLocked(s ssh.Session, cr *contextReader) (ac
var awaitReadOnce sync.Once // to start Reads on cr
var sawInterrupt atomic.Bool
var wg sync.WaitGroup
defer wg.Wait() // wait for awaitIntrOnce's goroutine to exit
var readError chan error
if cr != nil {
readError = make(chan error)
defer func() {
rerr := <-readError
if rerr != nil && err == nil {
err = rerr
}
}()
}
ctx, cancel := context.WithCancel(s.Context())
ctx, cancel := context.WithCancel(sctx)
defer cancel()
// Loop processing/fetching Actions until one reaches a
@ -657,8 +689,8 @@ func (c *conn) resolveTerminalActionLocked(s ssh.Session, cr *contextReader) (ac
// (Which is a long time for somebody to see login
// instructions and go to a URL to do something.)
for {
if action.Message != "" {
io.WriteString(s.Stderr(), strings.Replace(action.Message, "\n", "\r\n", -1))
if action.Message != "" && stderr != nil {
io.WriteString(stderr, strings.Replace(action.Message, "\n", "\r\n", -1))
}
if action.Accept || action.Reject {
if action.Reject {
@ -675,19 +707,23 @@ func (c *conn) resolveTerminalActionLocked(s ssh.Session, cr *contextReader) (ac
}
metricHolds.Add(1)
awaitReadOnce.Do(func() {
wg.Add(1)
if cr == nil {
return
}
go func() {
defer wg.Done()
buf := make([]byte, 1)
for {
n, err := cr.ReadContext(ctx, buf)
if err != nil {
readError <- nil
return
}
if n > 0 && buf[0] == 0x03 { // Ctrl-C
sawInterrupt.Store(true)
s.Stderr().Write([]byte("Canceled.\r\n"))
s.Exit(1)
if stderr != nil {
stderr.Write([]byte("Canceled.\r\n"))
}
readError <- context.Canceled
return
}
}

View File

@ -89,6 +89,8 @@ type Context interface {
// SetValue allows you to easily write new values into the underlying context.
SetValue(key, value interface{})
PublicKey() PublicKey
}
type sshContext struct {
@ -139,6 +141,14 @@ func (ctx *sshContext) ServerVersion() string {
return ctx.Value(ContextKeyServerVersion).(string)
}
func (ctx *sshContext) PublicKey() PublicKey {
sessionkey := ctx.Value(ContextKeyPublicKey)
if sessionkey == nil {
return nil
}
return sessionkey.(PublicKey)
}
func (ctx *sshContext) RemoteAddr() net.Addr {
if addr, ok := ctx.Value(ContextKeyRemoteAddr).(net.Addr); ok {
return addr