From 0b95ace349a361ee38f2448e1cab2aac052e52d0 Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Tue, 20 Jun 2023 18:06:57 -0700 Subject: [PATCH] ssh/tailssh: take explicit ownership of stdin,stderr,stdout We were using `cmd.Std{in,err,out}Pipe` and would then spin off goroutines to copy that data back and forth between the child process and the session. We would then concurrently call `cmd.Wait()` to wait for the process to exit. However, this interaction is documented as unsafe. ``` StdoutPipe returns a pipe that will be connected to the command's standard output when the command starts. Wait will close the pipe after seeing the command exit, so most callers need not close the pipe themselves. It is thus incorrect to call Wait before all reads from the pipe have completed. ``` Instead of doing that, take ownership of the pipe and explicitly close them when the session completes. Fixes #7601 Co-authored-by: Joe Tsai Signed-off-by: Maisem Ali --- ssh/tailssh/incubator.go | 29 +++++++++++++++++++---------- ssh/tailssh/tailssh.go | 22 +++++++++++++++------- 2 files changed, 34 insertions(+), 17 deletions(-) diff --git a/ssh/tailssh/incubator.go b/ssh/tailssh/incubator.go index 4de3e2b88..694f4e0dd 100644 --- a/ssh/tailssh/incubator.go +++ b/ssh/tailssh/incubator.go @@ -476,7 +476,7 @@ func (ss *sshSession) launchProcess() error { } go resizeWindow(ptyDup /* arbitrary fd */, winCh) - ss.tty = tty + ss.childPipes = []io.Closer{tty} ss.stdin = pty ss.stdout = os.NewFile(uintptr(ptyDup), pty.Name()) ss.stderr = nil // not available for pty @@ -658,11 +658,16 @@ func (ss *sshSession) startWithPTY() (ptyFile, tty *os.File, err error) { // startWithStdPipes starts cmd with os.Pipe for Stdin, Stdout and Stderr. func (ss *sshSession) startWithStdPipes() (err error) { - var stdin io.WriteCloser - var stdout, stderr io.ReadCloser + var ( + stdin io.WriteCloser + stdout, stderr io.ReadCloser + + cStdin io.ReadCloser + cStdout, cStderr io.WriteCloser + ) defer func() { if err != nil { - for _, c := range []io.Closer{stdin, stdout, stderr} { + for _, c := range []io.Closer{stdin, stdout, stderr, cStdin, cStdout, cStderr} { if c != nil { c.Close() } @@ -673,24 +678,28 @@ func (ss *sshSession) startWithStdPipes() (err error) { if cmd == nil { return errors.New("nil cmd") } - stdin, err = cmd.StdinPipe() + ss.stdin, cStdin, err = os.Pipe() if err != nil { return err } - stdout, err = cmd.StdoutPipe() + ss.stdout, cStdout, err = os.Pipe() if err != nil { + cStdin.Close() return err } - stderr, err = cmd.StderrPipe() + ss.stderr, cStderr, err = os.Pipe() if err != nil { + cStdin.Close() + cStdout.Close() return err } + cmd.Stdin = cStdin + cmd.Stdout = cStdout + cmd.Stderr = cStderr + ss.childPipes = []io.Closer{cStdin, cStdout, cStderr} if err := cmd.Start(); err != nil { return err } - ss.stdin = stdin - ss.stdout = stdout - ss.stderr = stderr return nil } diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 4253b2471..881cb4de0 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -826,9 +826,14 @@ type sshSession struct { cmd *exec.Cmd stdin io.WriteCloser stdout io.ReadCloser - stderr io.Reader // nil for pty sessions - ptyReq *ssh.Pty // non-nil for pty sessions - tty *os.File // non-nil for pty sessions, must be closed after process exits + stderr io.ReadCloser // nil for pty sessions + ptyReq *ssh.Pty // non-nil for pty sessions + + // childPipes is a list of pipes that need to be closed when the + // process exits. + // For pty sessions, this is the tty fd. + // For non-pty sessions, this is the stdin,stdout,stderr fds. + childPipes []io.Closer // We use this sync.Once to ensure that we only terminate the process once, // either it exits itself or is terminated @@ -1146,10 +1151,13 @@ func (ss *sshSession) run() { }() } - if ss.tty != nil { - // If running a tty session, close the tty when the session is done. - defer ss.tty.Close() - } + defer func() { + // It is our responsibility to close the FDs that were passed down to + // the child process. + for _, c := range ss.childPipes { + c.Close() + } + }() err = ss.cmd.Wait() processDone.Store(true) // This will either make the SSH Termination goroutine be a no-op,