diff --git a/ssh/tailssh/incubator.go b/ssh/tailssh/incubator.go index 4de3e2b88..694f4e0dd 100644 --- a/ssh/tailssh/incubator.go +++ b/ssh/tailssh/incubator.go @@ -476,7 +476,7 @@ func (ss *sshSession) launchProcess() error { } go resizeWindow(ptyDup /* arbitrary fd */, winCh) - ss.tty = tty + ss.childPipes = []io.Closer{tty} ss.stdin = pty ss.stdout = os.NewFile(uintptr(ptyDup), pty.Name()) ss.stderr = nil // not available for pty @@ -658,11 +658,16 @@ func (ss *sshSession) startWithPTY() (ptyFile, tty *os.File, err error) { // startWithStdPipes starts cmd with os.Pipe for Stdin, Stdout and Stderr. func (ss *sshSession) startWithStdPipes() (err error) { - var stdin io.WriteCloser - var stdout, stderr io.ReadCloser + var ( + stdin io.WriteCloser + stdout, stderr io.ReadCloser + + cStdin io.ReadCloser + cStdout, cStderr io.WriteCloser + ) defer func() { if err != nil { - for _, c := range []io.Closer{stdin, stdout, stderr} { + for _, c := range []io.Closer{stdin, stdout, stderr, cStdin, cStdout, cStderr} { if c != nil { c.Close() } @@ -673,24 +678,28 @@ func (ss *sshSession) startWithStdPipes() (err error) { if cmd == nil { return errors.New("nil cmd") } - stdin, err = cmd.StdinPipe() + ss.stdin, cStdin, err = os.Pipe() if err != nil { return err } - stdout, err = cmd.StdoutPipe() + ss.stdout, cStdout, err = os.Pipe() if err != nil { + cStdin.Close() return err } - stderr, err = cmd.StderrPipe() + ss.stderr, cStderr, err = os.Pipe() if err != nil { + cStdin.Close() + cStdout.Close() return err } + cmd.Stdin = cStdin + cmd.Stdout = cStdout + cmd.Stderr = cStderr + ss.childPipes = []io.Closer{cStdin, cStdout, cStderr} if err := cmd.Start(); err != nil { return err } - ss.stdin = stdin - ss.stdout = stdout - ss.stderr = stderr return nil } diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 4253b2471..881cb4de0 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -826,9 +826,14 @@ type sshSession struct { cmd *exec.Cmd stdin io.WriteCloser stdout io.ReadCloser - stderr io.Reader // nil for pty sessions - ptyReq *ssh.Pty // non-nil for pty sessions - tty *os.File // non-nil for pty sessions, must be closed after process exits + stderr io.ReadCloser // nil for pty sessions + ptyReq *ssh.Pty // non-nil for pty sessions + + // childPipes is a list of pipes that need to be closed when the + // process exits. + // For pty sessions, this is the tty fd. + // For non-pty sessions, this is the stdin,stdout,stderr fds. + childPipes []io.Closer // We use this sync.Once to ensure that we only terminate the process once, // either it exits itself or is terminated @@ -1146,10 +1151,13 @@ func (ss *sshSession) run() { }() } - if ss.tty != nil { - // If running a tty session, close the tty when the session is done. - defer ss.tty.Close() - } + defer func() { + // It is our responsibility to close the FDs that were passed down to + // the child process. + for _, c := range ss.childPipes { + c.Close() + } + }() err = ss.cmd.Wait() processDone.Store(true) // This will either make the SSH Termination goroutine be a no-op,