ssh/tailssh: handle output matching better in tests (#7799)

release-branch/1.38
Maisem Ali 2023-04-05 08:35:02 -07:00 committed by Shayne Sweeney
parent 61f36aa1cd
commit 00205f0ab6
No known key found for this signature in database
GPG Key ID: 69DA13E86BF403B0
1 changed files with 14 additions and 2 deletions

View File

@ -364,6 +364,8 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
handler func(w http.ResponseWriter, r *http.Request) handler func(w http.ResponseWriter, r *http.Request)
sshCommand string sshCommand string
wantClientOutput string wantClientOutput string
clientOutputMustNotContain []string
}{ }{
{ {
name: "upload-denied", name: "upload-denied",
@ -372,6 +374,8 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
}, },
sshCommand: "echo hello", sshCommand: "echo hello",
wantClientOutput: "recording: server responded with 403 Forbidden\r\n", wantClientOutput: "recording: server responded with 403 Forbidden\r\n",
clientOutputMustNotContain: []string{"hello"},
}, },
{ {
name: "upload-fails-after-starting", name: "upload-fails-after-starting",
@ -381,7 +385,9 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
}, },
sshCommand: "echo hello && sleep 1 && echo world", sshCommand: "echo hello && sleep 1 && echo world",
wantClientOutput: "hello\n\r\n\r\nrecording server responded with: 500 Internal Server Error\r\n\r\n", wantClientOutput: "\r\n\r\nrecording server responded with: 500 Internal Server Error\r\n\r\n",
clientOutputMustNotContain: []string{"world"},
}, },
} }
@ -415,9 +421,15 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
} else { } else {
t.Errorf("client did not get kicked out: %q", got) t.Errorf("client did not get kicked out: %q", got)
} }
if string(got) != tt.wantClientOutput { gotStr := string(got)
if !strings.HasSuffix(gotStr, tt.wantClientOutput) {
t.Errorf("client got %q, want %q", got, tt.wantClientOutput) t.Errorf("client got %q, want %q", got, tt.wantClientOutput)
} }
for _, x := range tt.clientOutputMustNotContain {
if strings.Contains(gotStr, x) {
t.Errorf("client output must not contain %q", x)
}
}
}() }()
if err := s.HandleSSHConn(dc); err != nil { if err := s.HandleSSHConn(dc); err != nil {
t.Errorf("unexpected error: %v", err) t.Errorf("unexpected error: %v", err)