ssh/tailssh: use context.WithCancelCause
It was using a custom implmentation of the context.WithCancelCause, replace usage with stdlib. Signed-off-by: Maisem Ali <maisem@tailscale.com>pull/7518/head
parent
a2be1aabfa
commit
e69682678f
|
@ -1,63 +0,0 @@
|
||||||
// Copyright (c) Tailscale Inc & AUTHORS
|
|
||||||
// SPDX-License-Identifier: BSD-3-Clause
|
|
||||||
|
|
||||||
package tailssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// sshContext is the context.Context implementation we use for SSH
|
|
||||||
// that adds a CloseWithError method. Otherwise it's just a normalish
|
|
||||||
// Context.
|
|
||||||
type sshContext struct {
|
|
||||||
underlying context.Context
|
|
||||||
cancel context.CancelFunc // cancels underlying
|
|
||||||
mu sync.Mutex
|
|
||||||
closed bool
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSSHContext(ctx context.Context) *sshContext {
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
return &sshContext{underlying: ctx, cancel: cancel}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ctx *sshContext) CloseWithError(err error) {
|
|
||||||
ctx.mu.Lock()
|
|
||||||
defer ctx.mu.Unlock()
|
|
||||||
if ctx.closed {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ctx.closed = true
|
|
||||||
ctx.err = err
|
|
||||||
ctx.cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ctx *sshContext) Err() error {
|
|
||||||
ctx.mu.Lock()
|
|
||||||
defer ctx.mu.Unlock()
|
|
||||||
return ctx.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ctx *sshContext) Done() <-chan struct{} { return ctx.underlying.Done() }
|
|
||||||
func (ctx *sshContext) Deadline() (deadline time.Time, ok bool) { return }
|
|
||||||
func (ctx *sshContext) Value(k any) any { return ctx.underlying.Value(k) }
|
|
||||||
|
|
||||||
// userVisibleError is a wrapper around an error that implements
|
|
||||||
// SSHTerminationError, so msg is written to their session.
|
|
||||||
type userVisibleError struct {
|
|
||||||
msg string
|
|
||||||
error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ue userVisibleError) SSHTerminationMessage() string { return ue.msg }
|
|
||||||
|
|
||||||
// SSHTerminationError is implemented by errors that terminate an SSH
|
|
||||||
// session and should be written to user's sessions.
|
|
||||||
type SSHTerminationError interface {
|
|
||||||
error
|
|
||||||
SSHTerminationMessage() string
|
|
||||||
}
|
|
|
@ -787,7 +787,8 @@ type sshSession struct {
|
||||||
sharedID string // ID that's shared with control
|
sharedID string // ID that's shared with control
|
||||||
logf logger.Logf
|
logf logger.Logf
|
||||||
|
|
||||||
ctx *sshContext // implements context.Context
|
ctx context.Context
|
||||||
|
cancelCtx context.CancelCauseFunc
|
||||||
conn *conn
|
conn *conn
|
||||||
agentListener net.Listener // non-nil if agent-forwarding requested+allowed
|
agentListener net.Listener // non-nil if agent-forwarding requested+allowed
|
||||||
|
|
||||||
|
@ -812,10 +813,12 @@ func (ss *sshSession) vlogf(format string, args ...interface{}) {
|
||||||
func (c *conn) newSSHSession(s ssh.Session) *sshSession {
|
func (c *conn) newSSHSession(s ssh.Session) *sshSession {
|
||||||
sharedID := fmt.Sprintf("sess-%s-%02x", c.srv.now().UTC().Format("20060102T150405"), randBytes(5))
|
sharedID := fmt.Sprintf("sess-%s-%02x", c.srv.now().UTC().Format("20060102T150405"), randBytes(5))
|
||||||
c.logf("starting session: %v", sharedID)
|
c.logf("starting session: %v", sharedID)
|
||||||
|
ctx, cancel := context.WithCancelCause(s.Context())
|
||||||
return &sshSession{
|
return &sshSession{
|
||||||
Session: s,
|
Session: s,
|
||||||
sharedID: sharedID,
|
sharedID: sharedID,
|
||||||
ctx: newSSHContext(s.Context()),
|
ctx: ctx,
|
||||||
|
cancelCtx: cancel,
|
||||||
conn: c,
|
conn: c,
|
||||||
logf: logger.WithPrefix(c.srv.logf, "ssh-session("+sharedID+"): "),
|
logf: logger.WithPrefix(c.srv.logf, "ssh-session("+sharedID+"): "),
|
||||||
}
|
}
|
||||||
|
@ -844,7 +847,7 @@ func (c *conn) checkStillValid() {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
for _, s := range c.sessions {
|
for _, s := range c.sessions {
|
||||||
s.ctx.CloseWithError(userVisibleError{
|
s.cancelCtx(userVisibleError{
|
||||||
fmt.Sprintf("Access revoked.\r\n"),
|
fmt.Sprintf("Access revoked.\r\n"),
|
||||||
context.Canceled,
|
context.Canceled,
|
||||||
})
|
})
|
||||||
|
@ -897,7 +900,7 @@ func (ss *sshSession) killProcessOnContextDone() {
|
||||||
// Either the process has already exited, in which case this does nothing.
|
// Either the process has already exited, in which case this does nothing.
|
||||||
// Or, the process is still running in which case this will kill it.
|
// Or, the process is still running in which case this will kill it.
|
||||||
ss.exitOnce.Do(func() {
|
ss.exitOnce.Do(func() {
|
||||||
err := ss.ctx.Err()
|
err := context.Cause(ss.ctx)
|
||||||
if serr, ok := err.(SSHTerminationError); ok {
|
if serr, ok := err.(SSHTerminationError); ok {
|
||||||
msg := serr.SSHTerminationMessage()
|
msg := serr.SSHTerminationMessage()
|
||||||
if msg != "" {
|
if msg != "" {
|
||||||
|
@ -997,7 +1000,7 @@ var recordSSH = envknob.RegisterBool("TS_DEBUG_LOG_SSH")
|
||||||
func (ss *sshSession) run() {
|
func (ss *sshSession) run() {
|
||||||
metricActiveSessions.Add(1)
|
metricActiveSessions.Add(1)
|
||||||
defer metricActiveSessions.Add(-1)
|
defer metricActiveSessions.Add(-1)
|
||||||
defer ss.ctx.CloseWithError(errSessionDone)
|
defer ss.cancelCtx(errSessionDone)
|
||||||
|
|
||||||
if attached := ss.conn.srv.attachSessionToConnIfNotShutdown(ss); !attached {
|
if attached := ss.conn.srv.attachSessionToConnIfNotShutdown(ss); !attached {
|
||||||
fmt.Fprintf(ss, "Tailscale SSH is shutting down\r\n")
|
fmt.Fprintf(ss, "Tailscale SSH is shutting down\r\n")
|
||||||
|
@ -1011,7 +1014,7 @@ func (ss *sshSession) run() {
|
||||||
|
|
||||||
if ss.conn.finalAction.SessionDuration != 0 {
|
if ss.conn.finalAction.SessionDuration != 0 {
|
||||||
t := time.AfterFunc(ss.conn.finalAction.SessionDuration, func() {
|
t := time.AfterFunc(ss.conn.finalAction.SessionDuration, func() {
|
||||||
ss.ctx.CloseWithError(userVisibleError{
|
ss.cancelCtx(userVisibleError{
|
||||||
fmt.Sprintf("Session timeout of %v elapsed.", ss.conn.finalAction.SessionDuration),
|
fmt.Sprintf("Session timeout of %v elapsed.", ss.conn.finalAction.SessionDuration),
|
||||||
context.DeadlineExceeded,
|
context.DeadlineExceeded,
|
||||||
})
|
})
|
||||||
|
@ -1066,7 +1069,7 @@ func (ss *sshSession) run() {
|
||||||
defer ss.stdin.Close()
|
defer ss.stdin.Close()
|
||||||
if _, err := io.Copy(rec.writer("i", ss.stdin), ss); err != nil {
|
if _, err := io.Copy(rec.writer("i", ss.stdin), ss); err != nil {
|
||||||
logf("stdin copy: %v", err)
|
logf("stdin copy: %v", err)
|
||||||
ss.ctx.CloseWithError(err)
|
ss.cancelCtx(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
var openOutputStreams atomic.Int32
|
var openOutputStreams atomic.Int32
|
||||||
|
@ -1080,7 +1083,7 @@ func (ss *sshSession) run() {
|
||||||
_, err := io.Copy(rec.writer("o", ss), ss.stdout)
|
_, err := io.Copy(rec.writer("o", ss), ss.stdout)
|
||||||
if err != nil && !errors.Is(err, io.EOF) {
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
logf("stdout copy: %v", err)
|
logf("stdout copy: %v", err)
|
||||||
ss.ctx.CloseWithError(err)
|
ss.cancelCtx(err)
|
||||||
}
|
}
|
||||||
if openOutputStreams.Add(-1) == 0 {
|
if openOutputStreams.Add(-1) == 0 {
|
||||||
ss.CloseWrite()
|
ss.CloseWrite()
|
||||||
|
@ -1489,3 +1492,19 @@ var (
|
||||||
metricSFTP = clientmetric.NewCounter("ssh_sftp_requests")
|
metricSFTP = clientmetric.NewCounter("ssh_sftp_requests")
|
||||||
metricLocalPortForward = clientmetric.NewCounter("ssh_local_port_forward_requests")
|
metricLocalPortForward = clientmetric.NewCounter("ssh_local_port_forward_requests")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// userVisibleError is a wrapper around an error that implements
|
||||||
|
// SSHTerminationError, so msg is written to their session.
|
||||||
|
type userVisibleError struct {
|
||||||
|
msg string
|
||||||
|
error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ue userVisibleError) SSHTerminationMessage() string { return ue.msg }
|
||||||
|
|
||||||
|
// SSHTerminationError is implemented by errors that terminate an SSH
|
||||||
|
// session and should be written to user's sessions.
|
||||||
|
type SSHTerminationError interface {
|
||||||
|
error
|
||||||
|
SSHTerminationMessage() string
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue