diff --git a/control/controlbase/conn_test.go b/control/controlbase/conn_test.go index 827d5b9a1..04f3f69b8 100644 --- a/control/controlbase/conn_test.go +++ b/control/controlbase/conn_test.go @@ -206,7 +206,7 @@ func TestConnStd(t *testing.T) { serverErr := make(chan error, 1) go func() { var err error - c2, err = Server(context.Background(), s2, controlKey, testProtocolVersion, nil) + c2, err = Server(context.Background(), s2, controlKey, nil) serverErr <- err }() c1, err = Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion) @@ -398,7 +398,7 @@ func pairWithConns(t *testing.T, clientConn, serverConn net.Conn) (*Conn, *Conn) ) go func() { var err error - server, err = Server(context.Background(), serverConn, controlKey, testProtocolVersion, nil) + server, err = Server(context.Background(), serverConn, controlKey, nil) serverErr <- err }() diff --git a/control/controlbase/handshake.go b/control/controlbase/handshake.go index 0fb2859b6..b18e08a37 100644 --- a/control/controlbase/handshake.go +++ b/control/controlbase/handshake.go @@ -193,19 +193,13 @@ func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricSta // Server initiates a control server handshake, returning the resulting // control connection. // -// maxSupportedVersion is the highest handshake version the server is -// willing to handshake with. The server will handshake with any -// version from 0 to maxSupportedVersion inclusive, the caller should -// inspect conn.Version() to determine what version of the handshake -// was executed. -// // optionalInit can be the client's initial handshake message as // returned by ClientDeferred, or nil in which case the initial // message is read from conn. // // The context deadline, if any, covers the entire handshaking // process. -func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, maxSupportedVersion uint16, optionalInit []byte) (*Conn, error) { +func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) { if deadline, ok := ctx.Deadline(); ok { if err := conn.SetDeadline(deadline); err != nil { return nil, fmt.Errorf("setting conn deadline: %w", err) @@ -245,15 +239,11 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, m } else if _, err := io.ReadFull(conn, init.Header()); err != nil { return nil, err } - // Currently, these versions exclusively indicate what the upper - // RPC protocol understands, the Noise handshake is exactly the - // same in all versions. If that ever changes, this check will - // need to become more complex to handle different kinds of - // handshake. - if init.Version() > maxSupportedVersion { - return nil, sendErr("unsupported handshake version") - } - // Just a rename to make it more obvious what the value is + // Just a rename to make it more obvious what the value is. In the + // current implementation we don't need to block any protocol + // versions at this layer, it's safe to let the handshake proceed + // and then let the caller make decisions based on the agreed-upon + // protocol version. clientVersion := init.Version() if init.Type() != msgTypeInitiation { return nil, sendErr("unexpected handshake message type") diff --git a/control/controlbase/handshake_test.go b/control/controlbase/handshake_test.go index ce28f12e8..755454c1c 100644 --- a/control/controlbase/handshake_test.go +++ b/control/controlbase/handshake_test.go @@ -26,7 +26,7 @@ func TestHandshake(t *testing.T) { ) go func() { var err error - server, err = Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil) + server, err = Server(context.Background(), serverConn, serverKey, nil) serverErr <- err }() @@ -78,7 +78,7 @@ func TestNoReuse(t *testing.T) { ) go func() { var err error - server, err = Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil) + server, err = Server(context.Background(), serverConn, serverKey, nil) serverErr <- err }() @@ -172,7 +172,7 @@ func TestTampering(t *testing.T) { serverErr = make(chan error, 1) ) go func() { - _, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil) + _, err := Server(context.Background(), serverConn, serverKey, nil) // If the server failed, we have to close the Conn to // unblock the client. if err != nil { @@ -200,7 +200,7 @@ func TestTampering(t *testing.T) { serverErr = make(chan error, 1) ) go func() { - _, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil) + _, err := Server(context.Background(), serverConn, serverKey, nil) serverErr <- err }() @@ -225,7 +225,7 @@ func TestTampering(t *testing.T) { serverErr = make(chan error, 1) ) go func() { - server, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil) + server, err := Server(context.Background(), serverConn, serverKey, nil) serverErr <- err _, err = io.WriteString(server, strings.Repeat("a", 14)) serverErr <- err @@ -266,7 +266,7 @@ func TestTampering(t *testing.T) { serverErr = make(chan error, 1) ) go func() { - server, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil) + server, err := Server(context.Background(), serverConn, serverKey, nil) serverErr <- err var bs [100]byte // The server needs a timeout if the tampering is hitting the length header. diff --git a/control/controlbase/interop_test.go b/control/controlbase/interop_test.go index b7e7d15e8..133db8bc5 100644 --- a/control/controlbase/interop_test.go +++ b/control/controlbase/interop_test.go @@ -29,7 +29,7 @@ func TestInteropClient(t *testing.T) { ) go func() { - server, err := Server(context.Background(), s2, controlKey, testProtocolVersion, nil) + server, err := Server(context.Background(), s2, controlKey, nil) serverErr <- err if err != nil { return diff --git a/control/controlhttp/http_test.go b/control/controlhttp/http_test.go index c4b8ddc36..1d2adf124 100644 --- a/control/controlhttp/http_test.go +++ b/control/controlhttp/http_test.go @@ -107,7 +107,7 @@ func testControlHTTP(t *testing.T, proxy proxy) { const testProtocolVersion = 1 sch := make(chan serverResult, 1) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := AcceptHTTP(context.Background(), w, r, server, testProtocolVersion) + conn, err := AcceptHTTP(context.Background(), w, r, server) if err != nil { log.Print(err) } diff --git a/control/controlhttp/server.go b/control/controlhttp/server.go index 8d7073ffe..0e38da860 100644 --- a/control/controlhttp/server.go +++ b/control/controlhttp/server.go @@ -21,7 +21,7 @@ import ( // // AcceptHTTP always writes an HTTP response to w. The caller must not // attempt their own response after calling AcceptHTTP. -func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate, maxSupportedVersion uint16) (*controlbase.Conn, error) { +func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate) (*controlbase.Conn, error) { next := r.Header.Get("Upgrade") if next == "" { http.Error(w, "missing next protocol", http.StatusBadRequest) @@ -63,7 +63,7 @@ func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri } conn = netutil.NewDrainBufConn(conn, brw.Reader) - nc, err := controlbase.Server(ctx, conn, private, maxSupportedVersion, init) + nc, err := controlbase.Server(ctx, conn, private, init) if err != nil { conn.Close() return nil, fmt.Errorf("noise handshake failed: %w", err)