diff --git a/ssh/tailssh/incubator.go b/ssh/tailssh/incubator.go index 5f5423fe9..8ecd279a8 100644 --- a/ssh/tailssh/incubator.go +++ b/ssh/tailssh/incubator.go @@ -310,15 +310,25 @@ func (ss *sshSession) launchProcess() error { if err != nil { return err } - go resizeWindow(pty, winCh) - ss.stdout = pty // no stderr for a pty + + // We need to be able to close stdin and stdout separately later so make a + // dup. + ptyDup, err := syscall.Dup(int(pty.Fd())) + if err != nil { + return err + } + go resizeWindow(ptyDup /* arbitrary fd */, winCh) + ss.stdin = pty + ss.stdout = os.NewFile(uintptr(ptyDup), pty.Name()) + ss.stderr = nil // not available for pty + return nil } -func resizeWindow(f *os.File, winCh <-chan ssh.Window) { +func resizeWindow(fd int, winCh <-chan ssh.Window) { for win := range winCh { - unix.IoctlSetWinsize(int(f.Fd()), syscall.TIOCSWINSZ, &unix.Winsize{ + unix.IoctlSetWinsize(fd, syscall.TIOCSWINSZ, &unix.Winsize{ Row: uint16(win.Height), Col: uint16(win.Width), }) diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index c68fc134b..67ad8acb9 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -732,7 +732,7 @@ type sshSession struct { // initialized by launchProcess: cmd *exec.Cmd stdin io.WriteCloser - stdout io.Reader + stdout io.ReadCloser stderr io.Reader // nil for pty sessions ptyReq *ssh.Pty // non-nil for pty sessions @@ -843,6 +843,8 @@ func (ss *sshSession) killProcessOnContextDone() { ss.logf("terminating SSH session from %v: %v", ss.conn.info.src.IP(), err) // We don't need to Process.Wait here, sshSession.run() does // the waiting regardless of termination reason. + + // TODO(maisem): should this be a SIGTERM followed by a SIGKILL? ss.cmd.Process.Kill() }) } @@ -1004,20 +1006,23 @@ func (ss *sshSession) run() { go ss.killProcessOnContextDone() go func() { - _, err := io.Copy(rec.writer("i", ss.stdin), ss) - if err != nil { - // TODO: don't log in the success case. + defer ss.stdin.Close() + if _, err := io.Copy(rec.writer("i", ss.stdin), ss); err != nil { logf("stdin copy: %v", err) + ss.ctx.CloseWithError(err) + } else if ss.ptyReq != nil { + const EOT = 4 // https://en.wikipedia.org/wiki/End-of-Transmission_character + ss.stdin.Write([]byte{EOT}) } - ss.stdin.Close() }() go func() { + defer ss.stdout.Close() _, err := io.Copy(rec.writer("o", ss), ss.stdout) - if err != nil { + if err != nil && !errors.Is(err, io.EOF) { logf("stdout copy: %v", err) - // If we got an error here, it's probably because the client has - // disconnected. ss.ctx.CloseWithError(err) + } else { + ss.CloseWrite() } }() // stderr is nil for ptys. @@ -1029,6 +1034,7 @@ func (ss *sshSession) run() { } }() } + err = ss.cmd.Wait() // This will either make the SSH Termination goroutine be a no-op, // or itself will be a no-op because the process was killed by the @@ -1036,6 +1042,7 @@ func (ss *sshSession) run() { ss.exitOnce.Do(func() {}) if err == nil { + ss.logf("Session complete") ss.Exit(0) return }