ssh/tailssh: make local port forwarding work w/o shell session
Fixes #5091 Signed-off-by: Maisem Ali <maisem@tailscale.com>pull/5858/head
parent
dd045a3767
commit
457d279ace
|
@ -331,6 +331,17 @@ func (srv *server) newConn() (*conn, error) {
|
|||
// to the specified host and port.
|
||||
// TODO(bradfitz/maisem): should we have more checks on host/port?
|
||||
func (c *conn) mayForwardLocalPortTo(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.finalAction == nil {
|
||||
// We haven't yet authenticated the user, this probably means that they
|
||||
// are just requesting port forwarding and not a shell.
|
||||
// We also do not have a reader or writer in this case so we can't
|
||||
// read "Ctrl-C" or write messages to the user to prompt for check mode.
|
||||
if _, err := c.authorizeSessionLocked(ctx, nil, nil); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if c.finalAction != nil && c.finalAction.AllowLocalPortForwarding {
|
||||
metricLocalPortForward.Add(1)
|
||||
return true
|
||||
|
@ -554,31 +565,43 @@ func (srv *server) fetchPublicKeysURL(url string) ([]string, error) {
|
|||
return lines, err
|
||||
}
|
||||
|
||||
func (c *conn) authorizeSession(s ssh.Session) (_ *contextReader, ok bool) {
|
||||
// authorizeSession authorizes the SSH session, returning an error if
|
||||
// authorization fails.
|
||||
// If a reader is provided, it will be monitored for a "Ctrl+C" sequence.
|
||||
// If a writer is provided, it will be used to write any messages from
|
||||
// the authorization process.
|
||||
func (c *conn) authorizeSession(sctx ssh.Context, reader io.Reader, writer io.Writer) (*contextReader, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
idH := s.Context().(ssh.Context).SessionID()
|
||||
return c.authorizeSessionLocked(sctx, reader, writer)
|
||||
}
|
||||
|
||||
// authorizeSessionLocked is like authorizeSession but requires c.mu to be held.
|
||||
func (c *conn) authorizeSessionLocked(sctx ssh.Context, reader io.Reader, writer io.Writer) (*contextReader, error) {
|
||||
idH := sctx.SessionID()
|
||||
if c.idH == "" {
|
||||
c.idH = idH
|
||||
} else if c.idH != idH {
|
||||
c.logf("ssh: session ID mismatch: %q != %q", c.idH, idH)
|
||||
s.Exit(1)
|
||||
return nil, false
|
||||
return nil, fmt.Errorf("session ID mismatch")
|
||||
}
|
||||
cr := &contextReader{r: s}
|
||||
action, err := c.resolveTerminalActionLocked(s, cr)
|
||||
var cr *contextReader
|
||||
if reader != nil {
|
||||
cr = &contextReader{r: reader}
|
||||
}
|
||||
action, err := c.resolveTerminalActionLocked(sctx, cr, writer)
|
||||
if err != nil {
|
||||
c.logf("resolveTerminalAction: %v", err)
|
||||
io.WriteString(s.Stderr(), "Access Denied: failed during authorization check.\r\n")
|
||||
s.Exit(1)
|
||||
return nil, false
|
||||
if writer != nil {
|
||||
io.WriteString(writer, "Access Denied: failed during authorization check.\r\n")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if action.Reject || !action.Accept {
|
||||
c.logf("access denied for %v", c.info.uprof.LoginName)
|
||||
s.Exit(1)
|
||||
return nil, false
|
||||
return nil, err
|
||||
}
|
||||
return cr, true
|
||||
return cr, nil
|
||||
}
|
||||
|
||||
// handleSessionPostSSHAuth runs an SSH session after the SSH-level authentication,
|
||||
|
@ -588,8 +611,9 @@ func (c *conn) handleSessionPostSSHAuth(s ssh.Session) {
|
|||
// Now that we have passed the SSH-level authentication, we can start the
|
||||
// Tailscale-level extra verification. This means that we are going to
|
||||
// evaluate the policy provided by control against the incoming SSH session.
|
||||
cr, ok := c.authorizeSession(s)
|
||||
if !ok {
|
||||
cr, err := c.authorizeSession(s.Context().(ssh.Context), s, s.Stderr())
|
||||
if err != nil {
|
||||
s.Exit(1)
|
||||
return
|
||||
}
|
||||
if cr.HasOutstandingRead() {
|
||||
|
@ -624,18 +648,18 @@ func (c *conn) handleSessionPostSSHAuth(s ssh.Session) {
|
|||
// The returned SSHAction will be either Reject or Accept.
|
||||
//
|
||||
// c.mu must be held.
|
||||
func (c *conn) resolveTerminalActionLocked(s ssh.Session, cr *contextReader) (action *tailcfg.SSHAction, err error) {
|
||||
func (c *conn) resolveTerminalActionLocked(sctx ssh.Context, cr *contextReader, stderr io.Writer) (action *tailcfg.SSHAction, err error) {
|
||||
if c.finalAction != nil || c.finalActionErr != nil {
|
||||
return c.finalAction, c.finalActionErr
|
||||
}
|
||||
|
||||
if s.PublicKey() != nil {
|
||||
if sctx.PublicKey() != nil {
|
||||
metricPublicKeyConnections.Add(1)
|
||||
}
|
||||
defer func() {
|
||||
c.finalAction = action
|
||||
c.finalActionErr = err
|
||||
c.pubKey = s.PublicKey()
|
||||
c.pubKey = sctx.PublicKey()
|
||||
if c.pubKey != nil && action.Accept {
|
||||
metricPublicKeyAccepts.Add(1)
|
||||
}
|
||||
|
@ -644,10 +668,18 @@ func (c *conn) resolveTerminalActionLocked(s ssh.Session, cr *contextReader) (ac
|
|||
|
||||
var awaitReadOnce sync.Once // to start Reads on cr
|
||||
var sawInterrupt atomic.Bool
|
||||
var wg sync.WaitGroup
|
||||
defer wg.Wait() // wait for awaitIntrOnce's goroutine to exit
|
||||
var readError chan error
|
||||
if cr != nil {
|
||||
readError = make(chan error)
|
||||
defer func() {
|
||||
rerr := <-readError
|
||||
if rerr != nil && err == nil {
|
||||
err = rerr
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(s.Context())
|
||||
ctx, cancel := context.WithCancel(sctx)
|
||||
defer cancel()
|
||||
|
||||
// Loop processing/fetching Actions until one reaches a
|
||||
|
@ -657,8 +689,8 @@ func (c *conn) resolveTerminalActionLocked(s ssh.Session, cr *contextReader) (ac
|
|||
// (Which is a long time for somebody to see login
|
||||
// instructions and go to a URL to do something.)
|
||||
for {
|
||||
if action.Message != "" {
|
||||
io.WriteString(s.Stderr(), strings.Replace(action.Message, "\n", "\r\n", -1))
|
||||
if action.Message != "" && stderr != nil {
|
||||
io.WriteString(stderr, strings.Replace(action.Message, "\n", "\r\n", -1))
|
||||
}
|
||||
if action.Accept || action.Reject {
|
||||
if action.Reject {
|
||||
|
@ -675,19 +707,23 @@ func (c *conn) resolveTerminalActionLocked(s ssh.Session, cr *contextReader) (ac
|
|||
}
|
||||
metricHolds.Add(1)
|
||||
awaitReadOnce.Do(func() {
|
||||
wg.Add(1)
|
||||
if cr == nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
buf := make([]byte, 1)
|
||||
for {
|
||||
n, err := cr.ReadContext(ctx, buf)
|
||||
if err != nil {
|
||||
readError <- nil
|
||||
return
|
||||
}
|
||||
if n > 0 && buf[0] == 0x03 { // Ctrl-C
|
||||
sawInterrupt.Store(true)
|
||||
s.Stderr().Write([]byte("Canceled.\r\n"))
|
||||
s.Exit(1)
|
||||
if stderr != nil {
|
||||
stderr.Write([]byte("Canceled.\r\n"))
|
||||
}
|
||||
readError <- context.Canceled
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
|
@ -89,6 +89,8 @@ type Context interface {
|
|||
|
||||
// SetValue allows you to easily write new values into the underlying context.
|
||||
SetValue(key, value interface{})
|
||||
|
||||
PublicKey() PublicKey
|
||||
}
|
||||
|
||||
type sshContext struct {
|
||||
|
@ -139,6 +141,14 @@ func (ctx *sshContext) ServerVersion() string {
|
|||
return ctx.Value(ContextKeyServerVersion).(string)
|
||||
}
|
||||
|
||||
func (ctx *sshContext) PublicKey() PublicKey {
|
||||
sessionkey := ctx.Value(ContextKeyPublicKey)
|
||||
if sessionkey == nil {
|
||||
return nil
|
||||
}
|
||||
return sessionkey.(PublicKey)
|
||||
}
|
||||
|
||||
func (ctx *sshContext) RemoteAddr() net.Addr {
|
||||
if addr, ok := ctx.Value(ContextKeyRemoteAddr).(net.Addr); ok {
|
||||
return addr
|
||||
|
|
Loading…
Reference in New Issue