From 949f968667c8c9b44b6dff6bda17b4a40233e99b Mon Sep 17 00:00:00 2001 From: 4yn Date: Wed, 5 May 2021 11:58:59 +0800 Subject: [PATCH] update uwebsockets version --- src/Vendor/uwebsockets/README.md | 2 +- .../include/{uws => }/libusockets.h | 53 +- src/Vendor/uwebsockets/include/uws/App.h | 341 ++-- .../uwebsockets/include/uws/AsyncSocket.h | 76 +- .../uwebsockets/include/uws/BloomFilter.h | 65 + .../uwebsockets/include/uws/HttpContext.h | 89 +- .../uwebsockets/include/uws/HttpContextData.h | 10 +- .../uwebsockets/include/uws/HttpParser.h | 127 +- .../uwebsockets/include/uws/HttpResponse.h | 208 +- .../include/uws/HttpResponseData.h | 19 +- .../uwebsockets/include/uws/HttpRouter.h | 56 +- src/Vendor/uwebsockets/include/uws/Loop.h | 38 +- src/Vendor/uwebsockets/include/uws/LoopData.h | 15 +- .../uwebsockets/include/uws/MessageParser.h | 64 + .../include/uws/MoveOnlyFunction.h | 377 ++++ .../uwebsockets/include/uws/Multipart.h | 231 +++ .../include/uws/PerMessageDeflate.h | 117 +- .../uwebsockets/include/uws/ProxyParser.h | 163 ++ .../uwebsockets/include/uws/QueryParser.h | 120 ++ .../uwebsockets/include/uws/TopicTree.h | 301 ++- .../uwebsockets/include/uws/Utilities.h | 6 +- .../uwebsockets/include/uws/WebSocket.h | 163 +- .../include/uws/WebSocketContext.h | 134 +- .../include/uws/WebSocketContextData.h | 242 ++- .../uwebsockets/include/uws/WebSocketData.h | 19 +- .../include/uws/WebSocketExtensions.h | 189 +- .../include/uws/WebSocketHandshake.h | 97 +- .../include/uws/WebSocketProtocol.h | 29 +- .../uwebsockets/include/uws/f2/LICENSE.txt | 23 - .../uwebsockets/include/uws/f2/function2.hpp | 1764 ----------------- src/Vendor/uwebsockets/src/bsd.c | 308 +++ src/Vendor/uwebsockets/src/context.c | 111 +- src/Vendor/uwebsockets/src/crypto/openssl.c | 328 ++- .../uwebsockets/src/crypto/sni_tree.cpp | 218 ++ src/Vendor/uwebsockets/src/crypto/wolfssl.c | 6 +- .../uwebsockets/src/eventing/epoll_kqueue.c | 7 +- src/Vendor/uwebsockets/src/eventing/gcd.c | 2 +- src/Vendor/uwebsockets/src/eventing/libuv.c | 58 +- .../src/internal/eventing/epoll_kqueue.h | 2 +- .../uwebsockets/src/internal/eventing/libuv.h | 4 +- .../uwebsockets/src/internal/internal.h | 31 +- .../uwebsockets/src/internal/networking/bsd.h | 260 +-- src/Vendor/uwebsockets/src/loop.c | 87 +- src/Vendor/uwebsockets/src/socket.c | 99 +- src/src/BrokenithmServer.cpp | 49 +- 45 files changed, 3870 insertions(+), 2838 deletions(-) rename src/Vendor/uwebsockets/include/{uws => }/libusockets.h (77%) create mode 100644 src/Vendor/uwebsockets/include/uws/BloomFilter.h create mode 100644 src/Vendor/uwebsockets/include/uws/MessageParser.h create mode 100644 src/Vendor/uwebsockets/include/uws/MoveOnlyFunction.h create mode 100644 src/Vendor/uwebsockets/include/uws/Multipart.h create mode 100644 src/Vendor/uwebsockets/include/uws/ProxyParser.h create mode 100644 src/Vendor/uwebsockets/include/uws/QueryParser.h delete mode 100644 src/Vendor/uwebsockets/include/uws/f2/LICENSE.txt delete mode 100644 src/Vendor/uwebsockets/include/uws/f2/function2.hpp create mode 100644 src/Vendor/uwebsockets/src/bsd.c create mode 100644 src/Vendor/uwebsockets/src/crypto/sni_tree.cpp diff --git a/src/Vendor/uwebsockets/README.md b/src/Vendor/uwebsockets/README.md index 9849e7b..850e2a1 100644 --- a/src/Vendor/uwebsockets/README.md +++ b/src/Vendor/uwebsockets/README.md @@ -1,6 +1,6 @@ # Repackaged uWebsockets library -Source was obtained from [uSockets](https://github.com/uNetworking/uSockets) and [uWebSockets](https://github.com/uNetworking/uWebSockets) and repackaged with a build system. +Source was obtained from [uSockets v0.7.1](https://github.com/uNetworking/uSockets/tree/v0.7.1) and [uWebSockets v19.2.0](https://github.com/uNetworking/uWebSockets/tree/v19.2.0) and repackaged with build system configuration. # Original uWebsockets README.md diff --git a/src/Vendor/uwebsockets/include/uws/libusockets.h b/src/Vendor/uwebsockets/include/libusockets.h similarity index 77% rename from src/Vendor/uwebsockets/include/uws/libusockets.h rename to src/Vendor/uwebsockets/include/libusockets.h index 7c0a2ae..f25762e 100644 --- a/src/Vendor/uwebsockets/include/uws/libusockets.h +++ b/src/Vendor/uwebsockets/include/libusockets.h @@ -29,13 +29,13 @@ /* Define what a socket descriptor is based on platform */ #ifdef _WIN32 +#ifndef NOMINMAX #define NOMINMAX -#include +#endif +#include #define LIBUS_SOCKET_DESCRIPTOR SOCKET #define WIN32_EXPORT __declspec(dllexport) -#define alignas(x) __declspec(align(x)) #else -#include #define LIBUS_SOCKET_DESCRIPTOR int #define WIN32_EXPORT #endif @@ -84,9 +84,20 @@ struct us_socket_context_options_t { const char *passphrase; const char *dh_params_file_name; const char *ca_file_name; - int ssl_prefer_low_memory_usage; + int ssl_prefer_low_memory_usage; /* Todo: rename to prefer_low_memory_usage and apply for TCP as well */ }; +/* Return 15-bit timestamp for this context */ +WIN32_EXPORT unsigned short us_socket_context_timestamp(int ssl, struct us_socket_context_t *context); + +/* Adds SNI domain and cert in asn1 format */ +WIN32_EXPORT void us_socket_context_add_server_name(int ssl, struct us_socket_context_t *context, const char *hostname_pattern, struct us_socket_context_options_t options); +WIN32_EXPORT void us_socket_context_remove_server_name(int ssl, struct us_socket_context_t *context, const char *hostname_pattern); +WIN32_EXPORT void us_socket_context_on_server_name(int ssl, struct us_socket_context_t *context, void (*cb)(struct us_socket_context_t *, const char *hostname)); + +/* Returns the underlying SSL native handle, such as SSL_CTX or nullptr */ +WIN32_EXPORT void *us_socket_context_get_native_handle(int ssl, struct us_socket_context_t *context); + /* A socket context holds shared callbacks and user data extension for associated sockets */ WIN32_EXPORT struct us_socket_context_t *us_create_socket_context(int ssl, struct us_loop_t *loop, int ext_size, struct us_socket_context_options_t options); @@ -98,13 +109,16 @@ WIN32_EXPORT void us_socket_context_free(int ssl, struct us_socket_context_t *co WIN32_EXPORT void us_socket_context_on_open(int ssl, struct us_socket_context_t *context, struct us_socket_t *(*on_open)(struct us_socket_t *s, int is_client, char *ip, int ip_length)); WIN32_EXPORT void us_socket_context_on_close(int ssl, struct us_socket_context_t *context, - struct us_socket_t *(*on_close)(struct us_socket_t *s)); + struct us_socket_t *(*on_close)(struct us_socket_t *s, int code, void *reason)); WIN32_EXPORT void us_socket_context_on_data(int ssl, struct us_socket_context_t *context, struct us_socket_t *(*on_data)(struct us_socket_t *s, char *data, int length)); WIN32_EXPORT void us_socket_context_on_writable(int ssl, struct us_socket_context_t *context, struct us_socket_t *(*on_writable)(struct us_socket_t *s)); WIN32_EXPORT void us_socket_context_on_timeout(int ssl, struct us_socket_context_t *context, struct us_socket_t *(*on_timeout)(struct us_socket_t *s)); +/* This one is only used for when a connecting socket fails in a late stage. */ +WIN32_EXPORT void us_socket_context_on_connect_error(int ssl, struct us_socket_context_t *context, + struct us_socket_t *(*on_connect_error)(struct us_socket_t *s, int code)); /* Emitted when a socket has been half-closed */ WIN32_EXPORT void us_socket_context_on_end(int ssl, struct us_socket_context_t *context, struct us_socket_t *(*on_end)(struct us_socket_t *s)); @@ -119,9 +133,18 @@ WIN32_EXPORT struct us_listen_socket_t *us_socket_context_listen(int ssl, struct /* listen_socket.c/.h */ WIN32_EXPORT void us_listen_socket_close(int ssl, struct us_listen_socket_t *ls); -/* Land in on_open or on_close or return null or return socket */ +/* Land in on_open or on_connection_error or return null or return socket */ WIN32_EXPORT struct us_socket_t *us_socket_context_connect(int ssl, struct us_socket_context_t *context, - const char *host, int port, int options, int socket_ext_size); + const char *host, int port, const char *source_host, int options, int socket_ext_size); + +/* Is this socket established? Can be used to check if a connecting socket has fired the on_open event yet. + * Can also be used to determine if a socket is a listen_socket or not, but you probably know that already. */ +WIN32_EXPORT int us_socket_is_established(int ssl, struct us_socket_t *s); + +/* Cancel a connecting socket. Can be used together with us_socket_timeout to limit connection times. + * Entirely destroys the socket - this function works like us_socket_close but does not trigger on_close event since + * you never got the on_open event first. */ +WIN32_EXPORT struct us_socket_t *us_socket_close_connecting(int ssl, struct us_socket_t *s); /* Returns the loop for this socket context. */ WIN32_EXPORT struct us_loop_t *us_socket_context_loop(int ssl, struct us_socket_context_t *context); @@ -189,6 +212,10 @@ WIN32_EXPORT struct us_poll_t *us_poll_resize(struct us_poll_t *p, struct us_loo /* Public interfaces for sockets */ +/* Returns the underlying native handle for a socket, such as SSL or file descriptor. + * In the case of file descriptor, the value of pointer is fd. */ +WIN32_EXPORT void *us_socket_get_native_handle(int ssl, struct us_socket_t *s); + /* Write up to length bytes of data. Returns actual bytes written. * Will call the on_writable callback of active socket context on failure to write everything off in one go. * Set hint msg_more if you have more immediate data to write. */ @@ -210,6 +237,11 @@ WIN32_EXPORT void us_socket_flush(int ssl, struct us_socket_t *s); /* Shuts down the connection by sending FIN and/or close_notify */ WIN32_EXPORT void us_socket_shutdown(int ssl, struct us_socket_t *s); +/* Shuts down the connection in terms of read, meaning next event loop + * iteration will catch the socket being closed. Can be used to defer closing + * to next event loop iteration. */ +WIN32_EXPORT void us_socket_shutdown_read(int ssl, struct us_socket_t *s); + /* Returns whether the socket has been shut down or not */ WIN32_EXPORT int us_socket_is_shut_down(int ssl, struct us_socket_t *s); @@ -217,7 +249,10 @@ WIN32_EXPORT int us_socket_is_shut_down(int ssl, struct us_socket_t *s); WIN32_EXPORT int us_socket_is_closed(int ssl, struct us_socket_t *s); /* Immediately closes the socket */ -WIN32_EXPORT struct us_socket_t *us_socket_close(int ssl, struct us_socket_t *s); +WIN32_EXPORT struct us_socket_t *us_socket_close(int ssl, struct us_socket_t *s, int code, void *reason); + +/* Returns local port or -1 on failure. */ +WIN32_EXPORT int us_socket_local_port(int ssl, struct us_socket_t *s); /* Copy remote (IP) address of socket, or fail with zero length. */ WIN32_EXPORT void us_socket_remote_address(int ssl, struct us_socket_t *s, char *buf, int *length); @@ -230,7 +265,7 @@ WIN32_EXPORT void us_socket_remote_address(int ssl, struct us_socket_t *s, char #if !defined(LIBUS_USE_EPOLL) && !defined(LIBUS_USE_LIBUV) && !defined(LIBUS_USE_GCD) && !defined(LIBUS_USE_KQUEUE) #if defined(_WIN32) #define LIBUS_USE_LIBUV -#elif defined(__APPLE__) +#elif defined(__APPLE__) || defined(__FreeBSD__) #define LIBUS_USE_KQUEUE #else #define LIBUS_USE_EPOLL diff --git a/src/Vendor/uwebsockets/include/uws/App.h b/src/Vendor/uwebsockets/include/uws/App.h index c80679b..aa1f0a5 100644 --- a/src/Vendor/uwebsockets/include/uws/App.h +++ b/src/Vendor/uwebsockets/include/uws/App.h @@ -1,5 +1,5 @@ /* - * Authored by Alex Hultman, 2018-2019. + * Authored by Alex Hultman, 2018-2020. * Intellectual property of third-party. * Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,42 +25,104 @@ #include "HttpResponse.h" #include "WebSocketContext.h" #include "WebSocket.h" -#include "WebSocketExtensions.h" -#include "WebSocketHandshake.h" +#include "PerMessageDeflate.h" namespace uWS { -/* Compress options (really more like PerMessageDeflateOptions) */ -enum CompressOptions { - /* Compression disabled */ - DISABLED = 0, - /* We compress using a shared non-sliding window. No added memory usage, worse compression. */ - SHARED_COMPRESSOR = 1, - /* We compress using a dedicated sliding window. Major memory usage added, better compression of similarly repeated messages. */ - DEDICATED_COMPRESSOR = 2 -}; + /* This one matches us_socket_context_options_t but has default values */ + struct SocketContextOptions { + const char *key_file_name = nullptr; + const char *cert_file_name = nullptr; + const char *passphrase = nullptr; + const char *dh_params_file_name = nullptr; + const char *ca_file_name = nullptr; + int ssl_prefer_low_memory_usage = 0; + + /* Conversion operator used internally */ + operator struct us_socket_context_options_t() const { + struct us_socket_context_options_t socket_context_options; + memcpy(&socket_context_options, this, sizeof(SocketContextOptions)); + return socket_context_options; + } + }; + + static_assert(sizeof(struct us_socket_context_options_t) == sizeof(SocketContextOptions), "Mismatching uSockets/uWebSockets ABI"); template struct TemplatedApp { private: /* The app always owns at least one http context, but creates websocket contexts on demand */ HttpContext *httpContext; - std::vector *> webSocketContexts; + std::vector *> webSocketContexts; public: + /* Server name */ + TemplatedApp &&addServerName(std::string hostname_pattern, SocketContextOptions options = {}) { + + us_socket_context_add_server_name(SSL, (struct us_socket_context_t *) httpContext, hostname_pattern.c_str(), options); + return std::move(*this); + } + + TemplatedApp &&removeServerName(std::string hostname_pattern) { + + us_socket_context_remove_server_name(SSL, (struct us_socket_context_t *) httpContext, hostname_pattern.c_str()); + return std::move(*this); + } + + TemplatedApp &&missingServerName(MoveOnlyFunction handler) { + + if (!constructorFailed()) { + httpContext->getSocketContextData()->missingServerNameHandler = std::move(handler); + + us_socket_context_on_server_name(SSL, (struct us_socket_context_t *) httpContext, [](struct us_socket_context_t *context, const char *hostname) { + + /* This is the only requirements of being friends with HttpContextData */ + HttpContext *httpContext = (HttpContext *) context; + httpContext->getSocketContextData()->missingServerNameHandler(hostname); + }); + } + + return std::move(*this); + } + + /* Returns the SSL_CTX of this app, or nullptr. */ + void *getNativeHandle() { + return us_socket_context_get_native_handle(SSL, (struct us_socket_context_t *) httpContext); + } + /* Attaches a "filter" function to track socket connections/disconnections */ - void filter(fu2::unique_function *, int)> &&filterHandler) { + void filter(MoveOnlyFunction *, int)> &&filterHandler) { httpContext->filter(std::move(filterHandler)); } - /* Publishes a message to all websocket contexts */ + /* Publishes a message to all websocket contexts - conceptually as if publishing to the one single + * TopicTree of this app (technically there are many TopicTrees, however the concept is that one + * app has one conceptual Topic tree) */ void publish(std::string_view topic, std::string_view message, OpCode opCode, bool compress = false) { for (auto *webSocketContext : webSocketContexts) { webSocketContext->getExt()->publish(topic, message, opCode, compress); } } + /* Returns number of subscribers for this topic, or 0 for failure. + * This function should probably be optimized a lot in future releases, + * it could be O(1) with a hash map of fullnames and their counts. */ + unsigned int numSubscribers(std::string_view topic) { + unsigned int subscribers = 0; + + for (auto *webSocketContext : webSocketContexts) { + auto *webSocketContextData = webSocketContext->getExt(); + + Topic *t = webSocketContextData->topicTree.lookupTopic(topic); + if (t) { + subscribers += t->subs.size(); + } + } + + return subscribers; + } + ~TemplatedApp() { /* Let's just put everything here */ if (httpContext) { @@ -84,42 +146,79 @@ public: webSocketContexts = std::move(other.webSocketContexts); } - TemplatedApp(us_socket_context_options_t options = {}) { - httpContext = uWS::HttpContext::create(uWS::Loop::get(), options); + TemplatedApp(SocketContextOptions options = {}) { + httpContext = HttpContext::create(Loop::get(), options); } bool constructorFailed() { return !httpContext; } + template struct WebSocketBehavior { + /* Disabled compression by default - probably a bad default */ CompressOptions compression = DISABLED; - int maxPayloadLength = 16 * 1024; - int idleTimeout = 120; - int maxBackpressure = 1 * 1024 * 1204; - fu2::unique_function *, HttpRequest *)> open = nullptr; - fu2::unique_function *, std::string_view, uWS::OpCode)> message = nullptr; - fu2::unique_function *)> drain = nullptr; - fu2::unique_function *)> ping = nullptr; - fu2::unique_function *)> pong = nullptr; - fu2::unique_function *, int, std::string_view)> close = nullptr; + /* Maximum message size we can receive */ + unsigned int maxPayloadLength = 16 * 1024; + /* 2 minutes timeout is good */ + unsigned short idleTimeout = 120; + /* 64kb backpressure is probably good */ + unsigned int maxBackpressure = 64 * 1024; + bool closeOnBackpressureLimit = false; + /* This one depends on kernel timeouts and is a bad default */ + bool resetIdleTimeoutOnSend = false; + /* A good default, esp. for newcomers */ + bool sendPingsAutomatically = true; + /* Maximum socket lifetime in seconds before forced closure (defaults to disabled) */ + unsigned short maxLifetime = 0; + MoveOnlyFunction *, HttpRequest *, struct us_socket_context_t *)> upgrade = nullptr; + MoveOnlyFunction *)> open = nullptr; + MoveOnlyFunction *, std::string_view, OpCode)> message = nullptr; + MoveOnlyFunction *)> drain = nullptr; + MoveOnlyFunction *, std::string_view)> ping = nullptr; + MoveOnlyFunction *, std::string_view)> pong = nullptr; + MoveOnlyFunction *, int, std::string_view)> close = nullptr; }; template - TemplatedApp &&ws(std::string pattern, WebSocketBehavior &&behavior) { + TemplatedApp &&ws(std::string pattern, WebSocketBehavior &&behavior) { /* Don't compile if alignment rules cannot be satisfied */ static_assert(alignof(UserData) <= LIBUS_EXT_ALIGNMENT, "µWebSockets cannot satisfy UserData alignment requirements. You need to recompile µSockets with LIBUS_EXT_ALIGNMENT adjusted accordingly."); + if (!httpContext) { + return std::move(*this); + } + + /* Terminate on misleading idleTimeout values */ + if (behavior.idleTimeout && behavior.idleTimeout < 8) { + std::cerr << "Error: idleTimeout must be either 0 or greater than 8!" << std::endl; + std::terminate(); + } + + if (behavior.idleTimeout % 4) { + std::cerr << "Warning: idleTimeout should be a multiple of 4!" << std::endl; + } + /* Every route has its own websocket context with its own behavior and user data type */ - auto *webSocketContext = WebSocketContext::create(Loop::get(), (us_socket_context_t *) httpContext); + auto *webSocketContext = WebSocketContext::create(Loop::get(), (us_socket_context_t *) httpContext); + + /* Add all other WebSocketContextData to this new WebSocketContextData */ + for (WebSocketContext *adjacentWebSocketContext : webSocketContexts) { + webSocketContext->getExt()->adjacentWebSocketContextDatas.push_back(adjacentWebSocketContext->getExt()); + } + + /* Add this WebSocketContextData to all other WebSocketContextData */ + for (WebSocketContext *adjacentWebSocketContext : webSocketContexts) { + adjacentWebSocketContext->getExt()->adjacentWebSocketContextDatas.push_back((WebSocketContextData *) webSocketContext->getExt()); + } /* We need to clear this later on */ - webSocketContexts.push_back(webSocketContext); + webSocketContexts.push_back((WebSocketContext *) webSocketContext); /* Quick fix to disable any compression if set */ #ifdef UWS_NO_ZLIB - behavior.compression = uWS::DISABLED; + behavior.compression = DISABLED; #endif /* If we are the first one to use compression, initialize it */ @@ -130,14 +229,15 @@ public: if (!loopData->zlibContext) { loopData->zlibContext = new ZlibContext; loopData->inflationStream = new InflationStream; - loopData->deflationStream = new DeflationStream; + loopData->deflationStream = new DeflationStream(CompressOptions::DEDICATED_COMPRESSOR); } } /* Copy all handlers */ + webSocketContext->getExt()->openHandler = std::move(behavior.open); webSocketContext->getExt()->messageHandler = std::move(behavior.message); webSocketContext->getExt()->drainHandler = std::move(behavior.drain); - webSocketContext->getExt()->closeHandler = std::move([closeHandler = std::move(behavior.close)](WebSocket *ws, int code, std::string_view message) mutable { + webSocketContext->getExt()->closeHandler = std::move([closeHandler = std::move(behavior.close)](WebSocket *ws, int code, std::string_view message) mutable { if (closeHandler) { closeHandler(ws, code, message); } @@ -150,99 +250,30 @@ public: /* Copy settings */ webSocketContext->getExt()->maxPayloadLength = behavior.maxPayloadLength; - webSocketContext->getExt()->idleTimeout = behavior.idleTimeout; webSocketContext->getExt()->maxBackpressure = behavior.maxBackpressure; + webSocketContext->getExt()->closeOnBackpressureLimit = behavior.closeOnBackpressureLimit; + webSocketContext->getExt()->resetIdleTimeoutOnSend = behavior.resetIdleTimeoutOnSend; + webSocketContext->getExt()->sendPingsAutomatically = behavior.sendPingsAutomatically; + webSocketContext->getExt()->compression = behavior.compression; - httpContext->onHttp("get", pattern, [webSocketContext, httpContext = this->httpContext, behavior = std::move(behavior)](auto *res, auto *req) mutable { + /* Calculate idleTimeoutCompnents */ + webSocketContext->getExt()->calculateIdleTimeoutCompnents(behavior.idleTimeout); + + httpContext->onHttp("get", pattern, [webSocketContext, behavior = std::move(behavior)](auto *res, auto *req) mutable { /* If we have this header set, it's a websocket */ std::string_view secWebSocketKey = req->getHeader("sec-websocket-key"); if (secWebSocketKey.length() == 24) { - /* Note: OpenSSL can be used here to speed this up somewhat */ - char secWebSocketAccept[29] = {}; - WebSocketHandshake::generate(secWebSocketKey.data(), secWebSocketAccept); - res->writeStatus("101 Switching Protocols") - ->writeHeader("Upgrade", "websocket") - ->writeHeader("Connection", "Upgrade") - ->writeHeader("Sec-WebSocket-Accept", secWebSocketAccept); + /* Emit upgrade handler */ + if (behavior.upgrade) { + behavior.upgrade(res, req, (struct us_socket_context_t *) webSocketContext); + } else { + /* Default handler upgrades to WebSocket */ + std::string_view secWebSocketProtocol = req->getHeader("sec-websocket-protocol"); + std::string_view secWebSocketExtensions = req->getHeader("sec-websocket-extensions"); - /* Select first subprotocol if present */ - std::string_view secWebSocketProtocol = req->getHeader("sec-websocket-protocol"); - if (secWebSocketProtocol.length()) { - res->writeHeader("Sec-WebSocket-Protocol", secWebSocketProtocol.substr(0, secWebSocketProtocol.find(','))); - } - - /* Negotiate compression */ - bool perMessageDeflate = false; - bool slidingDeflateWindow = false; - if (behavior.compression != DISABLED) { - std::string_view extensions = req->getHeader("sec-websocket-extensions"); - if (extensions.length()) { - /* We never support client context takeover (the client cannot compress with a sliding window). */ - int wantedOptions = PERMESSAGE_DEFLATE | CLIENT_NO_CONTEXT_TAKEOVER; - - /* Shared compressor is the default */ - if (behavior.compression == SHARED_COMPRESSOR) { - /* Disable per-socket compressor */ - wantedOptions |= SERVER_NO_CONTEXT_TAKEOVER; - } - - /* isServer = true */ - ExtensionsNegotiator extensionsNegotiator(wantedOptions); - extensionsNegotiator.readOffer(extensions); - - /* Todo: remove these mid string copies */ - std::string offer = extensionsNegotiator.generateOffer(); - if (offer.length()) { - res->writeHeader("Sec-WebSocket-Extensions", offer); - } - - /* Did we negotiate permessage-deflate? */ - if (extensionsNegotiator.getNegotiatedOptions() & PERMESSAGE_DEFLATE) { - perMessageDeflate = true; - } - - /* Is the server allowed to compress with a sliding window? */ - if (!(extensionsNegotiator.getNegotiatedOptions() & SERVER_NO_CONTEXT_TAKEOVER)) { - slidingDeflateWindow = true; - } - } - } - - /* This will add our mark */ - res->upgrade(); - - /* Move any backpressure */ - std::string backpressure(std::move(((AsyncSocketData *) res->getHttpResponseData())->buffer)); - - /* Keep any fallback buffer alive until we returned from open event, keeping req valid */ - std::string fallback(std::move(res->getHttpResponseData()->salvageFallbackBuffer())); - - /* Destroy HttpResponseData */ - res->getHttpResponseData()->~HttpResponseData(); - - /* Adopting a socket invalidates it, do not rely on it directly to carry any data */ - WebSocket *webSocket = (WebSocket *) us_socket_context_adopt_socket(SSL, - (us_socket_context_t *) webSocketContext, (us_socket_t *) res, sizeof(WebSocketData) + sizeof(UserData)); - - /* Update corked socket in case we got a new one (assuming we always are corked in handlers). */ - webSocket->AsyncSocket::cork(); - - /* Initialize websocket with any moved backpressure intact */ - httpContext->upgradeToWebSocket( - webSocket->init(perMessageDeflate, slidingDeflateWindow, std::move(backpressure)) - ); - - /* Arm idleTimeout */ - us_socket_timeout(SSL, (us_socket_t *) webSocket, behavior.idleTimeout); - - /* Default construct the UserData right before calling open handler */ - new (webSocket->getUserData()) UserData; - - /* Emit open event and start the timeout */ - if (behavior.open) { - behavior.open(webSocket, req); + res->template upgrade({}, secWebSocketKey, secWebSocketProtocol, secWebSocketExtensions, (struct us_socket_context_t *) webSocketContext); } /* We are going to get uncorked by the Http get return */ @@ -257,84 +288,104 @@ public: return std::move(*this); } - TemplatedApp &&get(std::string pattern, fu2::unique_function *, HttpRequest *)> &&handler) { - httpContext->onHttp("get", pattern, std::move(handler)); + TemplatedApp &&get(std::string pattern, MoveOnlyFunction *, HttpRequest *)> &&handler) { + if (httpContext) { + httpContext->onHttp("get", pattern, std::move(handler)); + } return std::move(*this); } - TemplatedApp &&post(std::string pattern, fu2::unique_function *, HttpRequest *)> &&handler) { - httpContext->onHttp("post", pattern, std::move(handler)); + TemplatedApp &&post(std::string pattern, MoveOnlyFunction *, HttpRequest *)> &&handler) { + if (httpContext) { + httpContext->onHttp("post", pattern, std::move(handler)); + } return std::move(*this); } - TemplatedApp &&options(std::string pattern, fu2::unique_function *, HttpRequest *)> &&handler) { - httpContext->onHttp("options", pattern, std::move(handler)); + TemplatedApp &&options(std::string pattern, MoveOnlyFunction *, HttpRequest *)> &&handler) { + if (httpContext) { + httpContext->onHttp("options", pattern, std::move(handler)); + } return std::move(*this); } - TemplatedApp &&del(std::string pattern, fu2::unique_function *, HttpRequest *)> &&handler) { - httpContext->onHttp("delete", pattern, std::move(handler)); + TemplatedApp &&del(std::string pattern, MoveOnlyFunction *, HttpRequest *)> &&handler) { + if (httpContext) { + httpContext->onHttp("delete", pattern, std::move(handler)); + } return std::move(*this); } - TemplatedApp &&patch(std::string pattern, fu2::unique_function *, HttpRequest *)> &&handler) { - httpContext->onHttp("patch", pattern, std::move(handler)); + TemplatedApp &&patch(std::string pattern, MoveOnlyFunction *, HttpRequest *)> &&handler) { + if (httpContext) { + httpContext->onHttp("patch", pattern, std::move(handler)); + } return std::move(*this); } - TemplatedApp &&put(std::string pattern, fu2::unique_function *, HttpRequest *)> &&handler) { - httpContext->onHttp("put", pattern, std::move(handler)); + TemplatedApp &&put(std::string pattern, MoveOnlyFunction *, HttpRequest *)> &&handler) { + if (httpContext) { + httpContext->onHttp("put", pattern, std::move(handler)); + } return std::move(*this); } - TemplatedApp &&head(std::string pattern, fu2::unique_function *, HttpRequest *)> &&handler) { - httpContext->onHttp("head", pattern, std::move(handler)); + TemplatedApp &&head(std::string pattern, MoveOnlyFunction *, HttpRequest *)> &&handler) { + if (httpContext) { + httpContext->onHttp("head", pattern, std::move(handler)); + } return std::move(*this); } - TemplatedApp &&connect(std::string pattern, fu2::unique_function *, HttpRequest *)> &&handler) { - httpContext->onHttp("connect", pattern, std::move(handler)); + TemplatedApp &&connect(std::string pattern, MoveOnlyFunction *, HttpRequest *)> &&handler) { + if (httpContext) { + httpContext->onHttp("connect", pattern, std::move(handler)); + } return std::move(*this); } - TemplatedApp &&trace(std::string pattern, fu2::unique_function *, HttpRequest *)> &&handler) { - httpContext->onHttp("trace", pattern, std::move(handler)); + TemplatedApp &&trace(std::string pattern, MoveOnlyFunction *, HttpRequest *)> &&handler) { + if (httpContext) { + httpContext->onHttp("trace", pattern, std::move(handler)); + } return std::move(*this); } /* This one catches any method */ - TemplatedApp &&any(std::string pattern, fu2::unique_function *, HttpRequest *)> &&handler) { - httpContext->onHttp("*", pattern, std::move(handler)); + TemplatedApp &&any(std::string pattern, MoveOnlyFunction *, HttpRequest *)> &&handler) { + if (httpContext) { + httpContext->onHttp("*", pattern, std::move(handler)); + } return std::move(*this); } /* Host, port, callback */ - TemplatedApp &&listen(std::string host, int port, fu2::unique_function &&handler) { + TemplatedApp &&listen(std::string host, int port, MoveOnlyFunction &&handler) { if (!host.length()) { return listen(port, std::move(handler)); } - handler(httpContext->listen(host.c_str(), port, 0)); + handler(httpContext ? httpContext->listen(host.c_str(), port, 0) : nullptr); return std::move(*this); } /* Host, port, options, callback */ - TemplatedApp &&listen(std::string host, int port, int options, fu2::unique_function &&handler) { + TemplatedApp &&listen(std::string host, int port, int options, MoveOnlyFunction &&handler) { if (!host.length()) { return listen(port, options, std::move(handler)); } - handler(httpContext->listen(host.c_str(), port, options)); + handler(httpContext ? httpContext->listen(host.c_str(), port, options) : nullptr); return std::move(*this); } /* Port, callback */ - TemplatedApp &&listen(int port, fu2::unique_function &&handler) { - handler(httpContext->listen(nullptr, port, 0)); + TemplatedApp &&listen(int port, MoveOnlyFunction &&handler) { + handler(httpContext ? httpContext->listen(nullptr, port, 0) : nullptr); return std::move(*this); } /* Port, options, callback */ - TemplatedApp &&listen(int port, int options, fu2::unique_function &&handler) { - handler(httpContext->listen(nullptr, port, options)); + TemplatedApp &&listen(int port, int options, MoveOnlyFunction &&handler) { + handler(httpContext ? httpContext->listen(nullptr, port, options) : nullptr); return std::move(*this); } diff --git a/src/Vendor/uwebsockets/include/uws/AsyncSocket.h b/src/Vendor/uwebsockets/include/uws/AsyncSocket.h index 3957fed..743f406 100644 --- a/src/Vendor/uwebsockets/include/uws/AsyncSocket.h +++ b/src/Vendor/uwebsockets/include/uws/AsyncSocket.h @@ -1,5 +1,5 @@ /* - * Authored by Alex Hultman, 2018-2019. + * Authored by Alex Hultman, 2018-2020. * Intellectual property of third-party. * Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,21 +20,30 @@ /* This class implements async socket memory management strategies */ +/* NOTE: Many unsigned/signed conversion warnings could be solved by moving from int length + * to unsigned length for everything to/from uSockets - this would however remove the opportunity + * to signal error with -1 (which is how the entire UNIX syscalling is built). */ + #include "LoopData.h" #include "AsyncSocketData.h" namespace uWS { - template struct WebSocketContext; + template struct WebSocketContext; template struct AsyncSocket { template friend struct HttpContext; - template friend struct WebSocketContext; - template friend struct WebSocketContextData; + template friend struct WebSocketContext; + template friend struct WebSocketContextData; friend struct TopicTree; protected: + /* Returns SSL pointer or FD as pointer */ + void *getNativeHandle() { + return us_socket_get_native_handle(SSL, (us_socket_t *) this); + } + /* Get loop data for socket */ LoopData *getLoopData() { return (LoopData *) us_loop_ext(us_socket_context_loop(SSL, us_socket_context(SSL, (us_socket_t *) this))); @@ -57,7 +66,7 @@ protected: /* Immediately close socket */ us_socket_t *close() { - return us_socket_close(SSL, (us_socket_t *) this); + return us_socket_close(SSL, (us_socket_t *) this, 0, nullptr); } /* Cork this socket. Only one socket may ever be corked per-loop at any given time */ @@ -82,7 +91,7 @@ protected: LoopData *loopData = getLoopData(); if (loopData->corkedSocket == this && loopData->corkOffset + size < LoopData::CORK_BUFFER_SIZE) { char *sendBuffer = loopData->corkBuffer + loopData->corkOffset; - loopData->corkOffset += (int) size; + loopData->corkOffset += (unsigned int) size; return {sendBuffer, false}; } else { /* Slow path for now, we want to always be corked if possible */ @@ -91,8 +100,30 @@ protected: } /* Returns the user space backpressure. */ - int getBufferedAmount() { - return (int) getAsyncSocketData()->buffer.size(); + unsigned int getBufferedAmount() { + return (unsigned int) getAsyncSocketData()->buffer.size(); + } + + /* Returns the text representation of an IPv4 or IPv6 address */ + std::string_view addressAsText(std::string_view binary) { + static thread_local char buf[64]; + int ipLength = 0; + + if (!binary.length()) { + return {}; + } + + unsigned char *b = (unsigned char *) binary.data(); + + if (binary.length() == 4) { + ipLength = sprintf(buf, "%u.%u.%u.%u", b[0], b[1], b[2], b[3]); + } else { + ipLength = sprintf(buf, "%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x", + b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], + b[12], b[13], b[14], b[15]); + } + + return {buf, (unsigned int) ipLength}; } /* Returns the remote IP address or empty string on failure */ @@ -100,7 +131,12 @@ protected: static thread_local char buf[16]; int ipLength = 16; us_socket_remote_address(SSL, (us_socket_t *) this, buf, &ipLength); - return std::string_view(buf, ipLength); + return std::string_view(buf, (unsigned int) ipLength); + } + + /* Returns the text representation of IP */ + std::string_view getRemoteAddressAsText() { + return addressAsText(getRemoteAddress()); } /* Write in three levels of prioritization: cork-buffer, syscall, socket-buffer. Always drain if possible. @@ -124,14 +160,14 @@ protected: if ((unsigned int) written < asyncSocketData->buffer.length()) { /* Update buffering (todo: we can do better here if we keep track of what happens to this guy later on) */ - asyncSocketData->buffer = asyncSocketData->buffer.substr(written); + asyncSocketData->buffer = asyncSocketData->buffer.substr((size_t) written); if (optionally) { /* Thankfully we can exit early here */ return {0, true}; } else { /* This path is horrible and points towards erroneous usage */ - asyncSocketData->buffer.append(src, length); + asyncSocketData->buffer.append(src, (unsigned int) length); return {length, true}; } @@ -144,21 +180,21 @@ protected: if (length) { if (loopData->corkedSocket == this) { /* We are corked */ - if (LoopData::CORK_BUFFER_SIZE - loopData->corkOffset >= length) { + if (LoopData::CORK_BUFFER_SIZE - loopData->corkOffset >= (unsigned int) length) { /* If the entire chunk fits in cork buffer */ - memcpy(loopData->corkBuffer + loopData->corkOffset, src, length); - loopData->corkOffset += length; + memcpy(loopData->corkBuffer + loopData->corkOffset, src, (unsigned int) length); + loopData->corkOffset += (unsigned int) length; /* Fall through to default return */ } else { /* Strategy differences between SSL and non-SSL regarding syscall minimizing */ if constexpr (SSL) { /* Cork up as much as we can */ - int stripped = LoopData::CORK_BUFFER_SIZE - loopData->corkOffset; + unsigned int stripped = LoopData::CORK_BUFFER_SIZE - loopData->corkOffset; memcpy(loopData->corkBuffer + loopData->corkOffset, src, stripped); loopData->corkOffset = LoopData::CORK_BUFFER_SIZE; - auto [written, failed] = uncork(src + stripped, length - stripped, optionally); - return {written + stripped, failed}; + auto [written, failed] = uncork(src + stripped, length - (int) stripped, optionally); + return {written + (int) stripped, failed}; } /* For non-SSL we take the penalty of two syscalls */ @@ -178,11 +214,11 @@ protected: /* Fall back to worst possible case (should be very rare for HTTP) */ /* At least we can reserve room for next chunk if we know it up front */ if (nextLength) { - asyncSocketData->buffer.reserve(asyncSocketData->buffer.length() + length - written + nextLength); + asyncSocketData->buffer.reserve(asyncSocketData->buffer.length() + (size_t) (length - written + nextLength)); } /* Buffer this chunk */ - asyncSocketData->buffer.append(src + written, length - written); + asyncSocketData->buffer.append(src + written, (size_t) (length - written)); /* Return the failure */ return {length, true}; @@ -205,7 +241,7 @@ protected: if (loopData->corkOffset) { /* Corked data is already accounted for via its write call */ - auto [written, failed] = write(loopData->corkBuffer, loopData->corkOffset, false, length); + auto [written, failed] = write(loopData->corkBuffer, (int) loopData->corkOffset, false, length); loopData->corkOffset = 0; if (failed) { diff --git a/src/Vendor/uwebsockets/include/uws/BloomFilter.h b/src/Vendor/uwebsockets/include/uws/BloomFilter.h new file mode 100644 index 0000000..95ced77 --- /dev/null +++ b/src/Vendor/uwebsockets/include/uws/BloomFilter.h @@ -0,0 +1,65 @@ +/* + * Authored by Alex Hultman, 2018-2019. + * Intellectual property of third-party. + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef UWS_BLOOMFILTER_H +#define UWS_BLOOMFILTER_H + +/* This filter has a decently low amount of false positives for the + * standard and non-standard common request headers */ + +#include +#include + +namespace uWS { + +struct BloomFilter { +private: + std::bitset<512> filter; + + unsigned int hash1(std::string_view key) { + return ((size_t)key[key.length() - 1] - (key.length() << 3)) & 511; + } + + unsigned int hash2(std::string_view key) { + return (((size_t)key[0] + (key.length() << 4)) & 511); + } + + unsigned int hash3(std::string_view key) { + return ((unsigned int)key[key.length() - 2] - 97 - (key.length() << 5)) & 511; + } + +public: + bool mightHave(std::string_view key) { + return filter.test(hash1(key)) && filter.test(hash2(key)) && (key.length() < 2 || filter.test(hash3(key))); + } + + void add(std::string_view key) { + filter.set(hash1(key)); + filter.set(hash2(key)); + if (key.length() >= 2) { + filter.set(hash3(key)); + } + } + + void reset() { + filter.reset(); + } +}; + +} + +#endif // UWS_BLOOMFILTER_H \ No newline at end of file diff --git a/src/Vendor/uwebsockets/include/uws/HttpContext.h b/src/Vendor/uwebsockets/include/uws/HttpContext.h index 201d45d..70a74e1 100644 --- a/src/Vendor/uwebsockets/include/uws/HttpContext.h +++ b/src/Vendor/uwebsockets/include/uws/HttpContext.h @@ -1,5 +1,5 @@ /* - * Authored by Alex Hultman, 2018-2019. + * Authored by Alex Hultman, 2018-2020. * Intellectual property of third-party. * Licensed under the Apache License, Version 2.0 (the "License"); @@ -24,10 +24,11 @@ #include "HttpContextData.h" #include "HttpResponseData.h" #include "AsyncSocket.h" +#include "WebSocketData.h" #include #include -#include "f2/function2.hpp" +#include "MoveOnlyFunction.h" namespace uWS { template struct HttpResponse; @@ -35,6 +36,7 @@ template struct HttpResponse; template struct HttpContext { template friend struct TemplatedApp; + template friend struct HttpResponse; private: HttpContext() = delete; @@ -60,7 +62,7 @@ private: /* Init the HttpContext by registering libusockets event handlers */ HttpContext *init() { /* Handle socket connections */ - us_socket_context_on_open(SSL, getSocketContext(), [](us_socket_t *s, int is_client, char *ip, int ip_length) { + us_socket_context_on_open(SSL, getSocketContext(), [](us_socket_t *s, int /*is_client*/, char */*ip*/, int /*ip_length*/) { /* Any connected socket should timeout until it has a request */ us_socket_timeout(SSL, s, HTTP_IDLE_TIMEOUT_S); @@ -77,7 +79,7 @@ private: }); /* Handle socket disconnections */ - us_socket_context_on_close(SSL, getSocketContext(), [](us_socket_t *s) { + us_socket_context_on_close(SSL, getSocketContext(), [](us_socket_t *s, int /*code*/, void */*reason*/) { /* Get socket ext */ HttpResponseData *httpResponseData = (HttpResponseData *) us_socket_ext(SSL, s); @@ -119,11 +121,19 @@ private: /* Cork this socket */ ((AsyncSocket *) s)->cork(); + /* Mark that we are inside the parser now */ + httpContextData->isParsingHttp = true; + // clients need to know the cursor after http parse, not servers! // how far did we read then? we need to know to continue with websocket parsing data? or? + void *proxyParser = nullptr; +#ifdef UWS_WITH_PROXY + proxyParser = &httpResponseData->proxyParser; +#endif + /* The return value is entirely up to us to interpret. The HttpParser only care for whether the returned value is DIFFERENT or not from passed user */ - void *returnedSocket = httpResponseData->consumePostPadded(data, length, s, [httpContextData](void *s, uWS::HttpRequest *httpRequest) -> void * { + void *returnedSocket = httpResponseData->consumePostPadded(data, (unsigned int) length, s, proxyParser, [httpContextData](void *s, HttpRequest *httpRequest) -> void * { /* For every request we reset the timeout and hang until user makes action */ /* Warning: if we are in shutdown state, resetting the timer is a security issue! */ us_socket_timeout(SSL, (us_socket_t *) s, 0); @@ -134,18 +144,23 @@ private: /* Are we not ready for another request yet? Terminate the connection. */ if (httpResponseData->state & HttpResponseData::HTTP_RESPONSE_PENDING) { - us_socket_close(SSL, (us_socket_t *) s); + us_socket_close(SSL, (us_socket_t *) s, 0, nullptr); return nullptr; } /* Mark pending request and emit it */ httpResponseData->state = HttpResponseData::HTTP_RESPONSE_PENDING; + /* Mark this response as connectionClose if ancient or connection: close */ + if (httpRequest->isAncient() || httpRequest->getHeader("connection").length() == 5) { + httpResponseData->state |= HttpResponseData::HTTP_CONNECTION_CLOSE; + } + /* Route the method and URL */ httpContextData->router.getUserData() = {(HttpResponse *) s, httpRequest}; if (!httpContextData->router.route(httpRequest->getMethod(), httpRequest->getUrl())) { /* We have to force close this socket as we have no handler for it */ - us_socket_close(SSL, (us_socket_t *) s); + us_socket_close(SSL, (us_socket_t *) s, 0, nullptr); return nullptr; } @@ -205,7 +220,7 @@ private: if (us_socket_is_shut_down(SSL, (us_socket_t *) user)) { return nullptr; } - + /* If we were given the last data chunk, reset data handler to ensure following * requests on the same socket won't trigger any previously registered behavior */ if (fin) { @@ -215,10 +230,13 @@ private: return user; }, [](void *user) { /* Close any socket on HTTP errors */ - us_socket_close(SSL, (us_socket_t *) user); + us_socket_close(SSL, (us_socket_t *) user, 0, nullptr); return nullptr; }); + /* Mark that we are no longer parsing Http */ + httpContextData->isParsingHttp = false; + /* We need to uncork in all cases, except for nullptr (closed socket, or upgraded socket) */ if (returnedSocket != nullptr) { /* Timeout on uncork failure */ @@ -229,6 +247,18 @@ private: ((AsyncSocket *) s)->timeout(HTTP_IDLE_TIMEOUT_S); } + /* We need to check if we should close this socket here now */ + if (httpResponseData->state & HttpResponseData::HTTP_CONNECTION_CLOSE) { + if ((httpResponseData->state & HttpResponseData::HTTP_RESPONSE_PENDING) == 0) { + if (((AsyncSocket *) s)->getBufferedAmount() == 0) { + ((AsyncSocket *) s)->shutdown(); + /* We need to force close after sending FIN since we want to hinder + * clients from keeping to send their huge data */ + ((AsyncSocket *) s)->close(); + } + } + } + return (us_socket_t *) returnedSocket; } @@ -238,7 +268,16 @@ private: AsyncSocket *asyncSocket = (AsyncSocket *) httpContextData->upgradedWebSocket; /* Uncork here as well (note: what if we failed to uncork and we then pub/sub before we even upgraded?) */ - /*auto [written, failed] = */asyncSocket->uncork(); + auto [written, failed] = asyncSocket->uncork(); + + /* If we succeeded in uncorking, check if we have sent WebSocket FIN */ + if (!failed) { + WebSocketData *webSocketData = (WebSocketData *) asyncSocket->getAsyncSocketData(); + if (webSocketData->isShuttingDown) { + /* In that case, also send TCP FIN (this is similar to what we have in ws drain handler) */ + asyncSocket->shutdown(); + } + } /* Reset upgradedWebSocket before we return */ httpContextData->upgradedWebSocket = nullptr; @@ -283,6 +322,18 @@ private: /* Drain any socket buffer, this might empty our backpressure and thus finish the request */ /*auto [written, failed] = */asyncSocket->write(nullptr, 0, true, 0); + /* Should we close this connection after a response - and is this response really done? */ + if (httpResponseData->state & HttpResponseData::HTTP_CONNECTION_CLOSE) { + if ((httpResponseData->state & HttpResponseData::HTTP_RESPONSE_PENDING) == 0) { + if (asyncSocket->getBufferedAmount() == 0) { + asyncSocket->shutdown(); + /* We need to force close after sending FIN since we want to hinder + * clients from keeping to send their huge data */ + asyncSocket->close(); + } + } + } + /* Expect another writable event, or another request within the timeout */ asyncSocket->timeout(HTTP_IDLE_TIMEOUT_S); @@ -310,13 +361,6 @@ private: return this; } - /* Used by App in its WebSocket handler */ - void upgradeToWebSocket(void *newSocket) { - HttpContextData *httpContextData = getSocketContextData(); - - httpContextData->upgradedWebSocket = newSocket; - } - public: /* Construct a new HttpContext using specified loop */ static HttpContext *create(Loop *loop, us_socket_context_options_t options = {}) { @@ -343,12 +387,12 @@ public: us_socket_context_free(SSL, getSocketContext()); } - void filter(fu2::unique_function *, int)> &&filterHandler) { + void filter(MoveOnlyFunction *, int)> &&filterHandler) { getSocketContextData()->filterHandlers.emplace_back(std::move(filterHandler)); } /* Register an HTTP route handler acording to URL pattern */ - void onHttp(std::string method, std::string pattern, fu2::unique_function *, HttpRequest *)> &&handler, bool upgrade = false) { + void onHttp(std::string method, std::string pattern, MoveOnlyFunction *, HttpRequest *)> &&handler, bool upgrade = false) { HttpContextData *httpContextData = getSocketContextData(); /* Todo: This is ugly, fix */ @@ -363,6 +407,13 @@ public: auto user = r->getUserData(); user.httpRequest->setYield(false); user.httpRequest->setParameters(r->getParameters()); + + /* Middleware? Automatically respond to expectations */ + std::string_view expect = user.httpRequest->getHeader("expect"); + if (expect.length() && expect == "100-continue") { + user.httpResponse->writeContinue(); + } + handler(user.httpResponse, user.httpRequest); /* If any handler yielded, the router will keep looking for a suitable handler. */ diff --git a/src/Vendor/uwebsockets/include/uws/HttpContextData.h b/src/Vendor/uwebsockets/include/uws/HttpContextData.h index bfea9c8..9375994 100644 --- a/src/Vendor/uwebsockets/include/uws/HttpContextData.h +++ b/src/Vendor/uwebsockets/include/uws/HttpContextData.h @@ -1,5 +1,5 @@ /* - * Authored by Alex Hultman, 2018-2019. + * Authored by Alex Hultman, 2018-2020. * Intellectual property of third-party. * Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,7 +21,7 @@ #include "HttpRouter.h" #include -#include "f2/function2.hpp" +#include "MoveOnlyFunction.h" namespace uWS { template struct HttpResponse; @@ -31,8 +31,11 @@ template struct alignas(16) HttpContextData { template friend struct HttpContext; template friend struct HttpResponse; + template friend struct TemplatedApp; private: - std::vector *, int)>> filterHandlers; + std::vector *, int)>> filterHandlers; + + MoveOnlyFunction missingServerNameHandler; struct RouterData { HttpResponse *httpResponse; @@ -41,6 +44,7 @@ private: HttpRouter router; void *upgradedWebSocket = nullptr; + bool isParsingHttp = false; }; } diff --git a/src/Vendor/uwebsockets/include/uws/HttpParser.h b/src/Vendor/uwebsockets/include/uws/HttpParser.h index a1c21e7..8e33335 100644 --- a/src/Vendor/uwebsockets/include/uws/HttpParser.h +++ b/src/Vendor/uwebsockets/include/uws/HttpParser.h @@ -1,5 +1,5 @@ /* - * Authored by Alex Hultman, 2018-2019. + * Authored by Alex Hultman, 2018-2020. * Intellectual property of third-party. * Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,12 +25,16 @@ #include #include #include -#include "f2/function2.hpp" +#include "MoveOnlyFunction.h" + +#include "BloomFilter.h" +#include "ProxyParser.h" +#include "QueryParser.h" namespace uWS { /* We require at least this much post padding */ -static const int MINIMUM_HTTP_POST_PADDING = 32; +static const unsigned int MINIMUM_HTTP_POST_PADDING = 32; struct HttpRequest { @@ -41,12 +45,17 @@ private: struct Header { std::string_view key, value; } headers[MAX_HEADERS]; - int querySeparator; + bool ancientHttp; + unsigned int querySeparator; bool didYield; - + BloomFilter bf; std::pair currentParameters; public: + bool isAncient() { + return ancientHttp; + } + bool getYield() { return didYield; } @@ -87,9 +96,11 @@ public: } std::string_view getHeader(std::string_view lowerCasedHeader) { - for (Header *h = headers; (++h)->key.length(); ) { - if (h->key.length() == lowerCasedHeader.length() && !strncmp(h->key.data(), lowerCasedHeader.data(), lowerCasedHeader.length())) { - return h->value; + if (bf.mightHave(lowerCasedHeader)) { + for (Header *h = headers; (++h)->key.length(); ) { + if (h->key.length() == lowerCasedHeader.length() && !strncmp(h->key.data(), lowerCasedHeader.data(), lowerCasedHeader.length())) { + return h->value; + } } } return std::string_view(nullptr, 0); @@ -103,8 +114,9 @@ public: return std::string_view(headers->key.data(), headers->key.length()); } + /* Returns the raw querystring as a whole, still encoded */ std::string_view getQuery() { - if (querySeparator < (int) headers->value.length()) { + if (querySeparator < headers->value.length()) { /* Strip the initial ? */ return std::string_view(headers->value.data() + querySeparator + 1, headers->value.length() - querySeparator - 1); } else { @@ -112,11 +124,19 @@ public: } } + /* Finds and decodes the URI component. */ + std::string_view getQuery(std::string_view key) { + /* Raw querystring including initial '?' sign */ + std::string_view queryString = std::string_view(headers->value.data() + querySeparator, headers->value.length() - querySeparator); + + return getDecodedQueryValue(key, queryString); + } + void setParameters(std::pair parameters) { currentParameters = parameters; } - std::string_view getParameter(unsigned int index) { + std::string_view getParameter(unsigned short index) { if (currentParameters.first < (int) index) { return {}; } else { @@ -136,15 +156,40 @@ private: static unsigned int toUnsignedInteger(std::string_view str) { unsigned int unsignedIntegerValue = 0; - for (unsigned char c : str) { - unsignedIntegerValue = unsignedIntegerValue * 10 + (c - '0'); + for (char c : str) { + unsignedIntegerValue = unsignedIntegerValue * 10u + ((unsigned int) c - (unsigned int) '0'); } return unsignedIntegerValue; } - static unsigned int getHeaders(char *postPaddedBuffer, char *end, struct HttpRequest::Header *headers) { + static unsigned int getHeaders(char *postPaddedBuffer, char *end, struct HttpRequest::Header *headers, void *reserved) { char *preliminaryKey, *preliminaryValue, *start = postPaddedBuffer; + #ifdef UWS_WITH_PROXY + /* ProxyParser is passed as reserved parameter */ + ProxyParser *pp = (ProxyParser *) reserved; + + /* Parse PROXY protocol */ + auto [done, offset] = pp->parse({start, (size_t) (end - postPaddedBuffer)}); + if (!done) { + /* We do not reset the ProxyParser (on filure) since it is tied to this + * connection, which is really only supposed to ever get one PROXY frame + * anyways. We do however allow multiple PROXY frames to be sent (overwrites former). */ + return 0; + } else { + /* We have consumed this data so skip it */ + start += offset; + } + #else + /* This one is unused */ + (void) reserved; + #endif + + /* It is critical for fallback buffering logic that we only return with success + * if we managed to parse a complete HTTP request (minus data). Returning success + * for PROXY means we can end up succeeding, yet leaving bytes in the fallback buffer + * which is then removed, and our counters to flip due to overflow and we end up with a crash */ + for (unsigned int i = 0; i < HttpRequest::MAX_HEADERS; i++) { for (preliminaryKey = postPaddedBuffer; (*postPaddedBuffer != ':') & (*postPaddedBuffer > 32); *(postPaddedBuffer++) |= 32); if (*postPaddedBuffer == '\r') { @@ -158,7 +203,7 @@ private: headers->key = std::string_view(preliminaryKey, (size_t) (postPaddedBuffer - preliminaryKey)); for (postPaddedBuffer++; (*postPaddedBuffer == ':' || *postPaddedBuffer < 33) && *postPaddedBuffer != '\r'; postPaddedBuffer++); preliminaryValue = postPaddedBuffer; - postPaddedBuffer = (char *) memchr(postPaddedBuffer, '\r', end - postPaddedBuffer); + postPaddedBuffer = (char *) memchr(postPaddedBuffer, '\r', (size_t) (end - postPaddedBuffer)); if (postPaddedBuffer && postPaddedBuffer[1] == '\n') { headers->value = std::string_view(preliminaryValue, (size_t) (postPaddedBuffer - preliminaryValue)); postPaddedBuffer += 2; @@ -173,20 +218,34 @@ private: // the only caller of getHeaders template - std::pair fenceAndConsumePostPadded(char *data, int length, void *user, HttpRequest *req, fu2::unique_function &requestHandler, fu2::unique_function &dataHandler) { - int consumedTotal = 0; + std::pair fenceAndConsumePostPadded(char *data, unsigned int length, void *user, void *reserved, HttpRequest *req, MoveOnlyFunction &requestHandler, MoveOnlyFunction &dataHandler) { + + /* How much data we CONSUMED (to throw away) */ + unsigned int consumedTotal = 0; + + /* Fence one byte past end of our buffer (buffer has post padded margins) */ data[length] = '\r'; - for (int consumed; length && (consumed = getHeaders(data, data + length, req->headers)); ) { + for (unsigned int consumed; length && (consumed = getHeaders(data, data + length, req->headers, reserved)); ) { data += consumed; length -= consumed; consumedTotal += consumed; - req->headers->value = std::string_view(req->headers->value.data(), std::max(0, (int) req->headers->value.length() - 9)); + /* Store HTTP version (ancient 1.0 or 1.1) */ + req->ancientHttp = req->headers->value.length() && (req->headers->value[req->headers->value.length() - 1] == '0'); + + /* Strip away tail of first "header value" aka URL */ + req->headers->value = std::string_view(req->headers->value.data(), (size_t) std::max(0, (int) req->headers->value.length() - 9)); + + /* Add all headers to bloom filter */ + req->bf.reset(); + for (HttpRequest::Header *h = req->headers; (++h)->key.length(); ) { + req->bf.add(h->key); + } /* Parse query */ const char *querySeparatorPtr = (const char *) memchr(req->headers->value.data(), '?', req->headers->value.length()); - req->querySeparator = (int) ((querySeparatorPtr ? querySeparatorPtr : req->headers->value.data() + req->headers->value.length()) - req->headers->value.data()); + req->querySeparator = (unsigned int) ((querySeparatorPtr ? querySeparatorPtr : req->headers->value.data() + req->headers->value.length()) - req->headers->value.data()); /* If returned socket is not what we put in we need * to break here as we either have upgraded to @@ -225,22 +284,18 @@ private: } public: + void *consumePostPadded(char *data, unsigned int length, void *user, void *reserved, MoveOnlyFunction &&requestHandler, MoveOnlyFunction &&dataHandler, MoveOnlyFunction &&errorHandler) { - /* We do this to prolong the validity of parsed headers by keeping only the fallback buffer alive */ - std::string &&salvageFallbackBuffer() { - return std::move(fallback); - } - - void *consumePostPadded(char *data, int length, void *user, fu2::unique_function &&requestHandler, fu2::unique_function &&dataHandler, fu2::unique_function &&errorHandler) { - + /* This resets BloomFilter by construction, but later we also reset it again. + * Optimize this to skip resetting twice (req could be made global) */ HttpRequest req; if (remainingStreamingBytes) { // this is exactly the same as below! // todo: refactor this - if (remainingStreamingBytes >= (unsigned int) length) { - void *returnedUser = dataHandler(user, std::string_view(data, length), remainingStreamingBytes == (unsigned int) length); + if (remainingStreamingBytes >= length) { + void *returnedUser = dataHandler(user, std::string_view(data, length), remainingStreamingBytes == length); remainingStreamingBytes -= length; return returnedUser; } else { @@ -257,24 +312,26 @@ public: } } else if (fallback.length()) { - int had = (int) fallback.length(); + unsigned int had = (unsigned int) fallback.length(); - int maxCopyDistance = (int) std::min(MAX_FALLBACK_SIZE - fallback.length(), (size_t) length); + size_t maxCopyDistance = std::min(MAX_FALLBACK_SIZE - fallback.length(), (size_t) length); /* We don't want fallback to be short string optimized, since we want to move it */ - fallback.reserve(fallback.length() + maxCopyDistance + std::max(MINIMUM_HTTP_POST_PADDING, sizeof(std::string))); + fallback.reserve(fallback.length() + maxCopyDistance + std::max(MINIMUM_HTTP_POST_PADDING, sizeof(std::string))); fallback.append(data, maxCopyDistance); // break here on break - std::pair consumed = fenceAndConsumePostPadded(fallback.data(), (int) fallback.length(), user, &req, requestHandler, dataHandler); + std::pair consumed = fenceAndConsumePostPadded(fallback.data(), (unsigned int) fallback.length(), user, reserved, &req, requestHandler, dataHandler); if (consumed.second != user) { return consumed.second; } if (consumed.first) { + /* This logic assumes that we consumed everything in fallback buffer. + * This is critically important, as we will get an integer overflow in case + * of "had" being larger than what we consumed, and that we would drop data */ fallback.clear(); - data += consumed.first - had; length -= consumed.first - had; @@ -308,7 +365,7 @@ public: } } - std::pair consumed = fenceAndConsumePostPadded(data, length, user, &req, requestHandler, dataHandler); + std::pair consumed = fenceAndConsumePostPadded(data, length, user, reserved, &req, requestHandler, dataHandler); if (consumed.second != user) { return consumed.second; } @@ -317,7 +374,7 @@ public: length -= consumed.first; if (length) { - if ((unsigned int) length < MAX_FALLBACK_SIZE) { + if (length < MAX_FALLBACK_SIZE) { fallback.append(data, length); } else { return errorHandler(user); diff --git a/src/Vendor/uwebsockets/include/uws/HttpResponse.h b/src/Vendor/uwebsockets/include/uws/HttpResponse.h index 05a0352..3638d85 100644 --- a/src/Vendor/uwebsockets/include/uws/HttpResponse.h +++ b/src/Vendor/uwebsockets/include/uws/HttpResponse.h @@ -1,5 +1,5 @@ /* - * Authored by Alex Hultman, 2018-2019. + * Authored by Alex Hultman, 2018-2020. * Intellectual property of third-party. * Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,7 +25,12 @@ #include "HttpContextData.h" #include "Utilities.h" -#include "f2/function2.hpp" +#include "WebSocketExtensions.h" +#include "WebSocketHandshake.h" +#include "WebSocket.h" +#include "WebSocketContextData.h" + +#include "MoveOnlyFunction.h" /* todo: tryWrite is missing currently, only send smaller segments with write */ @@ -56,10 +61,10 @@ private: Super::write(buf, length); } - /* Write an unsigned 32-bit integer */ - void writeUnsigned(unsigned int value) { - char buf[10]; - int length = utils::u32toa(value, buf); + /* Write an unsigned 64-bit integer */ + void writeUnsigned64(uint64_t value) { + char buf[20]; + int length = utils::u64toa(value, buf); /* For now we do this copy */ Super::write(buf, length); @@ -77,23 +82,43 @@ private: /* Called only once per request */ void writeMark() { + /* You can disable this altogether */ #ifndef UWS_HTTPRESPONSE_NO_WRITEMARK - writeHeader("uWebSockets", "v0.17"); + if (!Super::getLoopData()->noMark) { + /* We only expose major version */ + writeHeader("uWebSockets", "19"); + } #endif } /* Returns true on success, indicating that it might be feasible to write more data. * Will start timeout if stream reaches totalSize or write failure. */ - bool internalEnd(std::string_view data, int totalSize, bool optional, bool allowContentLength = true) { + bool internalEnd(std::string_view data, size_t totalSize, bool optional, bool allowContentLength = true, bool closeConnection = false) { /* Write status if not already done */ writeStatus(HTTP_200_OK); /* If no total size given then assume this chunk is everything */ if (!totalSize) { - totalSize = (int) data.length(); + totalSize = data.length(); } HttpResponseData *httpResponseData = getHttpResponseData(); + + /* In some cases, such as when refusing huge data we want to close the connection when drained */ + if (closeConnection) { + + /* HTTP 1.1 must send this back unless the client already sent it to us. + * It is a connection close when either of the two parties say so but the + * one party must tell the other one so. + * + * This check also serves to limit writing the header only once. */ + if ((httpResponseData->state & HttpResponseData::HTTP_CONNECTION_CLOSE) == 0) { + writeHeader("Connection", "close"); + } + + httpResponseData->state |= HttpResponseData::HTTP_CONNECTION_CLOSE; + } + if (httpResponseData->state & HttpResponseData::HTTP_WRITE_CALLED) { /* We do not have tryWrite-like functionalities, so ignore optional in this path */ @@ -126,7 +151,7 @@ private: if (allowContentLength) { /* Even zero is a valid content-length */ Super::write("Content-Length: ", 16); - writeUnsigned(totalSize); + writeUnsigned64(totalSize); Super::write("\r\n\r\n", 4); } else { Super::write("\r\n", 2); @@ -140,11 +165,20 @@ private: * if it failed to drain any prior failed header writes */ /* Write as much as possible without causing backpressure */ - auto [written, failed] = Super::write(data.data(), (int) data.length(), optional); + size_t written = 0; + bool failed = false; + while (written < data.length() && !failed) { + /* uSockets only deals with int sizes, so pass chunks of max signed int size */ + auto writtenFailed = Super::write(data.data() + written, (int) std::min(data.length() - written, INT_MAX), optional); + + written += (size_t) writtenFailed.first; + failed = writtenFailed.second; + } + httpResponseData->offset += written; /* Success is when we wrote the entire thing without any failures */ - bool success = (unsigned int) written == data.length() && !failed; + bool success = written == data.length() && !failed; /* If we are now at the end, start a timeout. Also start a timeout if we failed. */ if (!success || httpResponseData->offset == totalSize) { @@ -160,20 +194,140 @@ private: } } - /* This call is identical to end, but will never write content-length and is thus suitable for upgrades */ - void upgrade() { - internalEnd({nullptr, 0}, 0, false, false); +public: + /* If we have proxy support; returns the proxed source address as reported by the proxy. */ +#ifdef UWS_WITH_PROXY + std::string_view getProxiedRemoteAddress() { + return getHttpResponseData()->proxyParser.getSourceAddress(); + } + + std::string_view getProxiedRemoteAddressAsText() { + return Super::addressAsText(getProxiedRemoteAddress()); + } +#endif + + /* Manually upgrade to WebSocket. Typically called in upgrade handler. Immediately calls open handler. + * NOTE: Will invalidate 'this' as socket might change location in memory. Throw away aftert use. */ + template + void upgrade(UserData &&userData, std::string_view secWebSocketKey, std::string_view secWebSocketProtocol, + std::string_view secWebSocketExtensions, + struct us_socket_context_t *webSocketContext) { + + /* Extract needed parameters from WebSocketContextData */ + WebSocketContextData *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, webSocketContext); + + /* Note: OpenSSL can be used here to speed this up somewhat */ + char secWebSocketAccept[29] = {}; + WebSocketHandshake::generate(secWebSocketKey.data(), secWebSocketAccept); + + writeStatus("101 Switching Protocols") + ->writeHeader("Upgrade", "websocket") + ->writeHeader("Connection", "Upgrade") + ->writeHeader("Sec-WebSocket-Accept", secWebSocketAccept); + + /* Select first subprotocol if present */ + if (secWebSocketProtocol.length()) { + writeHeader("Sec-WebSocket-Protocol", secWebSocketProtocol.substr(0, secWebSocketProtocol.find(','))); + } + + /* Negotiate compression */ + bool perMessageDeflate = false; + CompressOptions compressOptions = CompressOptions::DISABLED; + if (secWebSocketExtensions.length() && webSocketContextData->compression != DISABLED) { + + /* We always want shared inflation */ + int wantedInflationWindow = 0; + + /* Map from selected compressor */ + int wantedCompressionWindow = (webSocketContextData->compression & 0xFF00) >> 8; + + auto [negCompression, negCompressionWindow, negInflationWindow, negResponse] = + negotiateCompression(true, wantedCompressionWindow, wantedInflationWindow, + secWebSocketExtensions); + + if (negCompression) { + perMessageDeflate = true; + + /* Map from windowBits to compressor */ + if (negCompressionWindow == 0) { + compressOptions = CompressOptions::SHARED_COMPRESSOR; + } else { + compressOptions = (CompressOptions) ((uint32_t) (negCompressionWindow << 8) + | (uint32_t) (negCompressionWindow - 7)); + + /* If we are dedicated and have the 3kb then correct any 4kb to 3kb, + * (they both share the windowBits = 9) */ + if (webSocketContextData->compression == DEDICATED_COMPRESSOR_3KB) { + compressOptions = DEDICATED_COMPRESSOR_3KB; + } + } + + writeHeader("Sec-WebSocket-Extensions", negResponse); + } + } + + internalEnd({nullptr, 0}, 0, false, false); + + /* Grab the httpContext from res */ + HttpContext *httpContext = (HttpContext *) us_socket_context(SSL, (struct us_socket_t *) this); + + /* Move any backpressure out of HttpResponse */ + std::string backpressure(std::move(((AsyncSocketData *) getHttpResponseData())->buffer)); + + /* Destroy HttpResponseData */ + getHttpResponseData()->~HttpResponseData(); + + /* Before we adopt and potentially change socket, check if we are corked */ + bool wasCorked = Super::isCorked(); + + /* Adopting a socket invalidates it, do not rely on it directly to carry any data */ + WebSocket *webSocket = (WebSocket *) us_socket_context_adopt_socket(SSL, + (us_socket_context_t *) webSocketContext, (us_socket_t *) this, sizeof(WebSocketData) + sizeof(UserData)); + + /* For whatever reason we were corked, update cork to the new socket */ + if (wasCorked) { + webSocket->AsyncSocket::cork(); + } + + /* Initialize websocket with any moved backpressure intact */ + webSocket->init(perMessageDeflate, compressOptions, std::move(backpressure)); + + /* We should only mark this if inside the parser; if upgrading "async" we cannot set this */ + HttpContextData *httpContextData = httpContext->getSocketContextData(); + if (httpContextData->isParsingHttp) { + /* We need to tell the Http parser that we changed socket */ + httpContextData->upgradedWebSocket = webSocket; + } + + /* Arm idleTimeout */ + us_socket_timeout(SSL, (us_socket_t *) webSocket, webSocketContextData->idleTimeoutComponents.first); + + /* Move construct the UserData right before calling open handler */ + new (webSocket->getUserData()) UserData(std::move(userData)); + + /* Emit open event and start the timeout */ + if (webSocketContextData->openHandler) { + webSocketContextData->openHandler(webSocket); + } } -public: /* Immediately terminate this Http response */ using Super::close; + /* See AsyncSocket */ using Super::getRemoteAddress; + using Super::getRemoteAddressAsText; + using Super::getNativeHandle; /* Note: Headers are not checked in regards to timeout. * We only check when you actively push data or end the request */ + /* Write 100 Continue, can be done any amount of times */ + HttpResponse *writeContinue() { + Super::write("HTTP/1.1 100 Continue\r\n\r\n", 25); + return this; + } + /* Write the HTTP status */ HttpResponse *writeStatus(std::string_view status) { HttpResponseData *httpResponseData = getHttpResponseData(); @@ -204,22 +358,24 @@ public: } /* Write an HTTP header with unsigned int value */ - HttpResponse *writeHeader(std::string_view key, unsigned int value) { + HttpResponse *writeHeader(std::string_view key, uint64_t value) { + writeStatus(HTTP_200_OK); + Super::write(key.data(), (int) key.length()); Super::write(": ", 2); - writeUnsigned(value); + writeUnsigned64(value); Super::write("\r\n", 2); return this; } /* End the response with an optional data chunk. Always starts a timeout. */ - void end(std::string_view data = {}) { - internalEnd(data, (int) data.length(), false); + void end(std::string_view data = {}, bool closeConnection = false) { + internalEnd(data, data.length(), false, true, closeConnection); } /* Try and end the response. Returns [true, true] on success. * Starts a timeout in some cases. Returns [ok, hasResponded] */ - std::pair tryEnd(std::string_view data, int totalSize = 0) { + std::pair tryEnd(std::string_view data, size_t totalSize = 0) { return {internalEnd(data, totalSize, true), hasResponded()}; } @@ -257,7 +413,7 @@ public: } /* Get the current byte write offset for this Http response */ - int getWriteOffset() { + size_t getWriteOffset() { HttpResponseData *httpResponseData = getHttpResponseData(); return httpResponseData->offset; @@ -271,7 +427,7 @@ public: } /* Corks the response if possible. Leaves already corked socket be. */ - HttpResponse *cork(fu2::unique_function &&handler) { + HttpResponse *cork(MoveOnlyFunction &&handler) { if (!Super::isCorked() && Super::canCork()) { Super::cork(); handler(); @@ -292,7 +448,7 @@ public: } /* Attach handler for writable HTTP response */ - HttpResponse *onWritable(fu2::unique_function &&handler) { + HttpResponse *onWritable(MoveOnlyFunction &&handler) { HttpResponseData *httpResponseData = getHttpResponseData(); httpResponseData->onWritable = std::move(handler); @@ -300,7 +456,7 @@ public: } /* Attach handler for aborted HTTP request */ - HttpResponse *onAborted(fu2::unique_function &&handler) { + HttpResponse *onAborted(MoveOnlyFunction &&handler) { HttpResponseData *httpResponseData = getHttpResponseData(); httpResponseData->onAborted = std::move(handler); @@ -308,7 +464,7 @@ public: } /* Attach a read handler for data sent. Will be called with FIN set true if last segment. */ - void onData(fu2::unique_function &&handler) { + void onData(MoveOnlyFunction &&handler) { HttpResponseData *data = getHttpResponseData(); data->inStream = std::move(handler); } diff --git a/src/Vendor/uwebsockets/include/uws/HttpResponseData.h b/src/Vendor/uwebsockets/include/uws/HttpResponseData.h index 13540f7..d3483f5 100644 --- a/src/Vendor/uwebsockets/include/uws/HttpResponseData.h +++ b/src/Vendor/uwebsockets/include/uws/HttpResponseData.h @@ -1,5 +1,5 @@ /* - * Authored by Alex Hultman, 2018-2019. + * Authored by Alex Hultman, 2018-2020. * Intellectual property of third-party. * Licensed under the Apache License, Version 2.0 (the "License"); @@ -22,8 +22,9 @@ #include "HttpParser.h" #include "AsyncSocketData.h" +#include "ProxyParser.h" -#include "f2/function2.hpp" +#include "MoveOnlyFunction.h" namespace uWS { @@ -38,18 +39,22 @@ private: HTTP_WRITE_CALLED = 2, // used HTTP_END_CALLED = 4, // used HTTP_RESPONSE_PENDING = 8, // used - HTTP_ENDED_STREAM_OUT = 16 // not used + HTTP_CONNECTION_CLOSE = 16 // used }; /* Per socket event handlers */ - fu2::unique_function onWritable; - fu2::unique_function onAborted; - fu2::unique_function inStream; // onData + MoveOnlyFunction onWritable; + MoveOnlyFunction onAborted; + MoveOnlyFunction inStream; // onData /* Outgoing offset */ - int offset = 0; + size_t offset = 0; /* Current state (content-length sent, status sent, write called, etc */ int state = 0; + +#ifdef UWS_WITH_PROXY + ProxyParser proxyParser; +#endif }; } diff --git a/src/Vendor/uwebsockets/include/uws/HttpRouter.h b/src/Vendor/uwebsockets/include/uws/HttpRouter.h index dd62269..ef83654 100644 --- a/src/Vendor/uwebsockets/include/uws/HttpRouter.h +++ b/src/Vendor/uwebsockets/include/uws/HttpRouter.h @@ -1,5 +1,5 @@ /* - * Authored by Alex Hultman, 2018-2019. + * Authored by Alex Hultman, 2018-2020. * Intellectual property of third-party. * Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,8 +25,9 @@ #include #include #include +#include -#include "f2/function2.hpp" +#include "MoveOnlyFunction.h" namespace uWS { @@ -47,7 +48,7 @@ private: std::map priority; /* List of handlers */ - std::vector> handlers; + std::vector> handlers; /* Current URL cache */ std::string_view currentUrl; @@ -60,6 +61,8 @@ private: std::vector> children; std::vector handlers; bool isHighPriority; + + Node(std::string name) : name(name) {} } root = {"rootNode"}; /* Advance from parent to child, adding child if necessary */ @@ -71,7 +74,7 @@ private: } /* Insert sorted, but keep order if parent is root (we sort methods by priority elsewhere) */ - std::unique_ptr newNode(new Node({child})); + std::unique_ptr newNode(new Node(child)); newNode->isHighPriority = isHighPriority; return parent->children.emplace(std::upper_bound(parent->children.begin(), parent->children.end(), newNode, [parent, this](auto &a, auto &b) { @@ -107,19 +110,25 @@ private: /* Set URL for router. Will reset any URL cache */ inline void setUrl(std::string_view url) { - /* Remove / from input URL */ - currentUrl = url.substr(std::min((unsigned int) url.length(), 1)); + + /* Todo: URL may also start with "http://domain/" or "*", not only "/" */ + + /* We expect to stand on a slash */ + currentUrl = url; urlSegmentTop = -1; } /* Lazily parse or read from cache */ - inline std::string_view getUrlSegment(int urlSegment) { + inline std::pair getUrlSegment(int urlSegment) { if (urlSegment > urlSegmentTop) { - /* Return empty segment if we are out of URL or stack space, but never for first url segment */ + /* Signal as STOP when we have no more URL or stack space */ if (!currentUrl.length() || urlSegment > 99) { - return {}; + return {{}, true}; } + /* We always stand on a slash here, so step over it */ + currentUrl.remove_prefix(1); + auto segmentLength = currentUrl.find('/'); if (segmentLength == std::string::npos) { segmentLength = currentUrl.length(); @@ -136,19 +145,22 @@ private: urlSegmentTop++; /* Update currentUrl */ - currentUrl = currentUrl.substr(segmentLength + 1); + currentUrl = currentUrl.substr(segmentLength); } } /* In any case we return it */ - return urlSegmentVector[urlSegment]; + return {urlSegmentVector[urlSegment], false}; } /* Executes as many handlers it can */ bool executeHandlers(Node *parent, int urlSegment, USERDATA &userData) { - /* If we have no more URL and not on first round, return where we may stand */ - if (urlSegment && !getUrlSegment(urlSegment).length()) { + + auto [segment, isStop] = getUrlSegment(urlSegment); + + /* If we are on STOP, return where we may stand */ + if (isStop) { /* We have reached accross the entire URL with no stoppage, execute */ - for (int handler : parent->handlers) { + for (uint32_t handler : parent->handlers) { if (handlers[handler & HANDLER_MASK](this)) { return true; } @@ -160,19 +172,19 @@ private: for (auto &p : parent->children) { if (p->name.length() && p->name[0] == '*') { /* Wildcard match (can be seen as a shortcut) */ - for (int handler : p->handlers) { + for (uint32_t handler : p->handlers) { if (handlers[handler & HANDLER_MASK](this)) { return true; } } - } else if (p->name.length() && p->name[0] == ':' && getUrlSegment(urlSegment).length()) { + } else if (p->name.length() && p->name[0] == ':' && segment.length()) { /* Parameter match */ - routeParameters.push(getUrlSegment(urlSegment)); + routeParameters.push(segment); if (executeHandlers(p.get(), urlSegment + 1, userData)) { return true; } routeParameters.pop(); - } else if (p->name == getUrlSegment(urlSegment)) { + } else if (p->name == segment) { /* Static match */ if (executeHandlers(p.get(), urlSegment + 1, userData)) { return true; @@ -217,14 +229,14 @@ public: } /* Adds the corresponding entires in matching tree and handler list */ - void add(std::vector methods, std::string pattern, fu2::unique_function &&handler, uint32_t priority = MEDIUM_PRIORITY) { + void add(std::vector methods, std::string pattern, MoveOnlyFunction &&handler, uint32_t priority = MEDIUM_PRIORITY) { for (std::string method : methods) { /* Lookup method */ Node *node = getNode(&root, method, false); /* Iterate over all segments */ setUrl(pattern); - for (int i = 0; getUrlSegment(i).length() || i == 0; i++) { - node = getNode(node, std::string(getUrlSegment(i)), priority == HIGH_PRIORITY); + for (int i = 0; !getUrlSegment(i).second; i++) { + node = getNode(node, std::string(getUrlSegment(i).first), priority == HIGH_PRIORITY); } /* Insert handler in order sorted by priority (most significant 1 byte) */ node->handlers.insert(std::upper_bound(node->handlers.begin(), node->handlers.end(), (uint32_t) (priority | handlers.size())), (uint32_t) (priority | handlers.size())); @@ -237,4 +249,4 @@ public: } -#endif // UWS_HTTPROUTER_HPP \ No newline at end of file +#endif // UWS_HTTPROUTER_HPP diff --git a/src/Vendor/uwebsockets/include/uws/Loop.h b/src/Vendor/uwebsockets/include/uws/Loop.h index 5274dcf..305cc2d 100644 --- a/src/Vendor/uwebsockets/include/uws/Loop.h +++ b/src/Vendor/uwebsockets/include/uws/Loop.h @@ -1,5 +1,5 @@ /* - * Authored by Alex Hultman, 2018-2019. + * Authored by Alex Hultman, 2018-2020. * Intellectual property of third-party. * Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,10 +18,10 @@ #ifndef UWS_LOOP_H #define UWS_LOOP_H -/* The loop is lazily created per-thread and run with uWS::run() */ +/* The loop is lazily created per-thread and run with run() */ #include "LoopData.h" -#include "libusockets.h" +#include namespace uWS { struct Loop { @@ -80,25 +80,29 @@ private: Loop *loop = nullptr; bool cleanMe = false; }; - + + static LoopCleaner &getLazyLoop() { + static thread_local LoopCleaner lazyLoop; + return lazyLoop; + } + public: /* Lazily initializes a per-thread loop and returns it. * Will automatically free all initialized loops at exit. */ static Loop *get(void *existingNativeLoop = nullptr) { - static thread_local LoopCleaner lazyLoop; - if (!lazyLoop.loop) { + if (!getLazyLoop().loop) { /* If we are given a native loop pointer we pass that to uSockets and let it deal with it */ if (existingNativeLoop) { /* Todo: here we want to pass the pointer, not a boolean */ - lazyLoop.loop = create(existingNativeLoop); + getLazyLoop().loop = create(existingNativeLoop); /* We cannot register automatic free here, must be manually done */ } else { - lazyLoop.loop = create(nullptr); - lazyLoop.cleanMe = true; + getLazyLoop().loop = create(nullptr); + getLazyLoop().cleanMe = true; } } - return lazyLoop.loop; + return getLazyLoop().loop; } /* Freeing the default loop should be done once */ @@ -107,9 +111,12 @@ public: loopData->~LoopData(); /* uSockets will track whether this loop is owned by us or a borrowed alien loop */ us_loop_free((us_loop_t *) this); + + /* Reset lazyLoop */ + getLazyLoop().loop = nullptr; } - void addPostHandler(void *key, fu2::unique_function &&handler) { + void addPostHandler(void *key, MoveOnlyFunction &&handler) { LoopData *loopData = (LoopData *) us_loop_ext((us_loop_t *) this); loopData->postHandlers.emplace(key, std::move(handler)); @@ -122,7 +129,7 @@ public: loopData->postHandlers.erase(key); } - void addPreHandler(void *key, fu2::unique_function &&handler) { + void addPreHandler(void *key, MoveOnlyFunction &&handler) { LoopData *loopData = (LoopData *) us_loop_ext((us_loop_t *) this); loopData->preHandlers.emplace(key, std::move(handler)); @@ -136,7 +143,7 @@ public: } /* Defer this callback on Loop's thread of execution */ - void defer(fu2::unique_function &&cb) { + void defer(MoveOnlyFunction &&cb) { LoopData *loopData = (LoopData *) us_loop_ext((us_loop_t *) this); //if (std::thread::get_id() == ) // todo: add fast path for same thread id @@ -157,6 +164,11 @@ public: void integrate() { us_loop_integrate((us_loop_t *) this); } + + /* Dynamically change this */ + void setSilent(bool silent) { + ((LoopData *) us_loop_ext((us_loop_t *) this))->noMark = silent; + } }; /* Can be called from any thread to run the thread local loop */ diff --git a/src/Vendor/uwebsockets/include/uws/LoopData.h b/src/Vendor/uwebsockets/include/uws/LoopData.h index 112bae8..b4cac29 100644 --- a/src/Vendor/uwebsockets/include/uws/LoopData.h +++ b/src/Vendor/uwebsockets/include/uws/LoopData.h @@ -1,5 +1,5 @@ /* - * Authored by Alex Hultman, 2018-2019. + * Authored by Alex Hultman, 2018-2020. * Intellectual property of third-party. * Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,7 +26,7 @@ #include "PerMessageDeflate.h" -#include "f2/function2.hpp" +#include "MoveOnlyFunction.h" namespace uWS { @@ -37,10 +37,10 @@ struct alignas(16) LoopData { private: std::mutex deferMutex; int currentDeferQueue = 0; - std::vector> deferQueues[2]; + std::vector> deferQueues[2]; /* Map from void ptr to handler */ - std::map> postHandlers, preHandlers; + std::map> postHandlers, preHandlers; public: ~LoopData() { @@ -53,12 +53,15 @@ public: delete [] corkBuffer; } + /* Be silent */ + bool noMark = false; + /* Good 16k for SSL perf. */ - static const int CORK_BUFFER_SIZE = 16 * 1024; + static const unsigned int CORK_BUFFER_SIZE = 16 * 1024; /* Cork data */ char *corkBuffer = new char[CORK_BUFFER_SIZE]; - int corkOffset = 0; + unsigned int corkOffset = 0; void *corkedSocket = nullptr; /* Per message deflate data */ diff --git a/src/Vendor/uwebsockets/include/uws/MessageParser.h b/src/Vendor/uwebsockets/include/uws/MessageParser.h new file mode 100644 index 0000000..aa8d455 --- /dev/null +++ b/src/Vendor/uwebsockets/include/uws/MessageParser.h @@ -0,0 +1,64 @@ +/* + * Authored by Alex Hultman, 2018-2020. + * Intellectual property of third-party. + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Implements the common parser (RFC 822) used in both HTTP and Multipart parsing */ + +#ifndef UWS_MESSAGE_PARSER_H +#define UWS_MESSAGE_PARSER_H + +#include +#include +#include + +/* For now we have this one here */ +#define MAX_HEADERS 10 + +namespace uWS { + + // should be templated on whether it needs at lest one header (http), or not (multipart) + static inline unsigned int getHeaders(char *postPaddedBuffer, char *end, std::pair *headers) { + char *preliminaryKey, *preliminaryValue, *start = postPaddedBuffer; + + for (unsigned int i = 0; i < MAX_HEADERS; i++) { + for (preliminaryKey = postPaddedBuffer; (*postPaddedBuffer != ':') & (*postPaddedBuffer > 32); *(postPaddedBuffer++) |= 32); + if (*postPaddedBuffer == '\r') { + if ((postPaddedBuffer != end) & (postPaddedBuffer[1] == '\n') /* & (i > 0) */) { // multipart does not require any headers like http does + headers->first = std::string_view(nullptr, 0); + return (unsigned int) ((postPaddedBuffer + 2) - start); + } else { + return 0; + } + } else { + headers->first = std::string_view(preliminaryKey, (size_t) (postPaddedBuffer - preliminaryKey)); + for (postPaddedBuffer++; (*postPaddedBuffer == ':' || *postPaddedBuffer < 33) && *postPaddedBuffer != '\r'; postPaddedBuffer++); + preliminaryValue = postPaddedBuffer; + postPaddedBuffer = (char *) memchr(postPaddedBuffer, '\r', end - postPaddedBuffer); + if (postPaddedBuffer && postPaddedBuffer[1] == '\n') { + headers->second = std::string_view(preliminaryValue, (size_t) (postPaddedBuffer - preliminaryValue)); + postPaddedBuffer += 2; + headers++; + } else { + return 0; + } + } + } + return 0; + } + +} + +#endif \ No newline at end of file diff --git a/src/Vendor/uwebsockets/include/uws/MoveOnlyFunction.h b/src/Vendor/uwebsockets/include/uws/MoveOnlyFunction.h new file mode 100644 index 0000000..b1ae785 --- /dev/null +++ b/src/Vendor/uwebsockets/include/uws/MoveOnlyFunction.h @@ -0,0 +1,377 @@ +/* +MIT License + +Copyright (c) 2020 Oleg Fatkhiev + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +/* Sources fetched from https://github.com/ofats/any_invocable on 2021-02-19. */ + +#ifndef _ANY_INVOKABLE_H_ +#define _ANY_INVOKABLE_H_ + +#include +#include +#include + +// clang-format off +/* +namespace std { + template class any_invocable; // never defined + + template + class any_invocable { + public: + using result_type = R; + + // SECTION.3, construct/copy/destroy + any_invocable() noexcept; + any_invocable(nullptr_t) noexcept; + any_invocable(any_invocable&&) noexcept; + template any_invocable(F&&); + + template + explicit any_invocable(in_place_type_t, Args&&...); + template + explicit any_invocable(in_place_type_t, initializer_list, Args&&...); + + any_invocable& operator=(any_invocable&&) noexcept; + any_invocable& operator=(nullptr_t) noexcept; + template any_invocable& operator=(F&&); + template any_invocable& operator=(reference_wrapper) noexcept; + + ~any_invocable(); + + // SECTION.4, any_invocable modifiers + void swap(any_invocable&) noexcept; + + // SECTION.5, any_invocable capacity + explicit operator bool() const noexcept; + + // SECTION.6, any_invocable invocation + R operator()(ArgTypes...) cv ref noexcept(noex); + + // SECTION.7, null pointer comparisons + friend bool operator==(const any_invocable&, nullptr_t) noexcept; + + // SECTION.8, specialized algorithms + friend void swap(any_invocable&, any_invocable&) noexcept; + }; +} +*/ +// clang-format on + +namespace ofats { + +namespace any_detail { + +using buffer = std::aligned_storage_t; + +template +inline constexpr bool is_small_object_v = + sizeof(T) <= sizeof(buffer) && alignof(buffer) % alignof(T) == 0 && + std::is_nothrow_move_constructible_v; + +union storage { + void* ptr_ = nullptr; + buffer buf_; +}; + +enum class action { destroy, move }; + +template +struct handler_traits { + template + struct handler_base { + static void handle(action act, storage* current, storage* other = nullptr) { + switch (act) { + case (action::destroy): + Derived::destroy(*current); + break; + case (action::move): + Derived::move(*current, *other); + break; + } + } + }; + + template + struct small_handler : handler_base> { + template + static void create(storage& s, Args&&... args) { + new (static_cast(&s.buf_)) T(std::forward(args)...); + } + + static void destroy(storage& s) noexcept { + T& value = *static_cast(static_cast(&s.buf_)); + value.~T(); + } + + static void move(storage& dst, storage& src) noexcept { + create(dst, std::move(*static_cast(static_cast(&src.buf_)))); + destroy(src); + } + + static R call(storage& s, ArgTypes... args) { + return std::invoke(*static_cast(static_cast(&s.buf_)), + std::forward(args)...); + } + }; + + template + struct large_handler : handler_base> { + template + static void create(storage& s, Args&&... args) { + s.ptr_ = new T(std::forward(args)...); + } + + static void destroy(storage& s) noexcept { delete static_cast(s.ptr_); } + + static void move(storage& dst, storage& src) noexcept { + dst.ptr_ = src.ptr_; + } + + static R call(storage& s, ArgTypes... args) { + return std::invoke(*static_cast(s.ptr_), + std::forward(args)...); + } + }; + + template + using handler = std::conditional_t, small_handler, + large_handler>; +}; + +template +struct is_in_place_type : std::false_type {}; + +template +struct is_in_place_type> : std::true_type {}; + +template +inline constexpr auto is_in_place_type_v = is_in_place_type::value; + +template +class any_invocable_impl { + template + using handler = + typename any_detail::handler_traits::template handler; + + using storage = any_detail::storage; + using action = any_detail::action; + using handle_func = void (*)(any_detail::action, any_detail::storage*, + any_detail::storage*); + using call_func = R (*)(any_detail::storage&, ArgTypes...); + + public: + using result_type = R; + + any_invocable_impl() noexcept = default; + any_invocable_impl(std::nullptr_t) noexcept {} + any_invocable_impl(any_invocable_impl&& rhs) noexcept { + if (rhs.handle_) { + handle_ = rhs.handle_; + handle_(action::move, &storage_, &rhs.storage_); + call_ = rhs.call_; + rhs.handle_ = nullptr; + } + } + + any_invocable_impl& operator=(any_invocable_impl&& rhs) noexcept { + any_invocable_impl{std::move(rhs)}.swap(*this); + return *this; + } + any_invocable_impl& operator=(std::nullptr_t) noexcept { + destroy(); + return *this; + } + + ~any_invocable_impl() { destroy(); } + + void swap(any_invocable_impl& rhs) noexcept { + if (handle_) { + if (rhs.handle_) { + storage tmp; + handle_(action::move, &tmp, &storage_); + rhs.handle_(action::move, &storage_, &rhs.storage_); + handle_(action::move, &rhs.storage_, &tmp); + std::swap(handle_, rhs.handle_); + std::swap(call_, rhs.call_); + } else { + rhs.swap(*this); + } + } else if (rhs.handle_) { + rhs.handle_(action::move, &storage_, &rhs.storage_); + handle_ = rhs.handle_; + call_ = rhs.call_; + rhs.handle_ = nullptr; + } + } + + explicit operator bool() const noexcept { return handle_ != nullptr; } + + protected: + template + void create(Args&&... args) { + using hdl = handler; + hdl::create(storage_, std::forward(args)...); + handle_ = &hdl::handle; + call_ = &hdl::call; + } + + void destroy() noexcept { + if (handle_) { + handle_(action::destroy, &storage_, nullptr); + handle_ = nullptr; + } + } + + R call(ArgTypes... args) noexcept(is_noexcept) { + return call_(storage_, std::forward(args)...); + } + + friend bool operator==(const any_invocable_impl& f, std::nullptr_t) noexcept { + return !f; + } + friend bool operator==(std::nullptr_t, const any_invocable_impl& f) noexcept { + return !f; + } + friend bool operator!=(const any_invocable_impl& f, std::nullptr_t) noexcept { + return static_cast(f); + } + friend bool operator!=(std::nullptr_t, const any_invocable_impl& f) noexcept { + return static_cast(f); + } + + friend void swap(any_invocable_impl& lhs, any_invocable_impl& rhs) noexcept { + lhs.swap(rhs); + } + + private: + storage storage_; + handle_func handle_ = nullptr; + call_func call_; +}; + +template +using remove_cvref_t = std::remove_cv_t>; + +template +using can_convert = std::conjunction< + std::negation, AI>>, + std::negation>>, + std::is_invocable_r, + std::bool_constant<(!noex || + std::is_nothrow_invocable_r_v)>, + std::is_constructible, F>>; + +} // namespace any_detail + +template +class any_invocable; + +#define __OFATS_ANY_INVOCABLE(cv, ref, noex, inv_quals) \ + template \ + class any_invocable \ + : public any_detail::any_invocable_impl { \ + using base_type = any_detail::any_invocable_impl; \ + \ + public: \ + using base_type::base_type; \ + \ + template < \ + class F, \ + class = std::enable_if_t::value>> \ + any_invocable(F&& f) { \ + base_type::template create>(std::forward(f)); \ + } \ + \ + template , \ + class = std::enable_if_t< \ + std::is_move_constructible_v && \ + std::is_constructible_v && \ + std::is_invocable_r_v && \ + (!noex || std::is_nothrow_invocable_r_v)>> \ + explicit any_invocable(std::in_place_type_t, Args&&... args) { \ + base_type::template create(std::forward(args)...); \ + } \ + \ + template < \ + class T, class U, class... Args, class VT = std::decay_t, \ + class = std::enable_if_t< \ + std::is_move_constructible_v && \ + std::is_constructible_v&, Args...> && \ + std::is_invocable_r_v && \ + (!noex || \ + std::is_nothrow_invocable_r_v)>> \ + explicit any_invocable(std::in_place_type_t, \ + std::initializer_list il, Args&&... args) { \ + base_type::template create(il, std::forward(args)...); \ + } \ + \ + template > \ + std::enable_if_t && \ + std::is_move_constructible_v, \ + any_invocable&> \ + operator=(F&& f) { \ + any_invocable{std::forward(f)}.swap(*this); \ + return *this; \ + } \ + template \ + any_invocable& operator=(std::reference_wrapper f) { \ + any_invocable{f}.swap(*this); \ + return *this; \ + } \ + \ + R operator()(ArgTypes... args) cv ref noexcept(noex) { \ + return base_type::call(std::forward(args)...); \ + } \ + }; + +// cv -> {`empty`, const} +// ref -> {`empty`, &, &&} +// noex -> {true, false} +// inv_quals -> (is_empty(ref) ? & : ref) +__OFATS_ANY_INVOCABLE(, , false, &) // 000 +__OFATS_ANY_INVOCABLE(, , true, &) // 001 +__OFATS_ANY_INVOCABLE(, &, false, &) // 010 +__OFATS_ANY_INVOCABLE(, &, true, &) // 011 +__OFATS_ANY_INVOCABLE(, &&, false, &&) // 020 +__OFATS_ANY_INVOCABLE(, &&, true, &&) // 021 +__OFATS_ANY_INVOCABLE(const, , false, const&) // 100 +__OFATS_ANY_INVOCABLE(const, , true, const&) // 101 +__OFATS_ANY_INVOCABLE(const, &, false, const&) // 110 +__OFATS_ANY_INVOCABLE(const, &, true, const&) // 111 +__OFATS_ANY_INVOCABLE(const, &&, false, const&&) // 120 +__OFATS_ANY_INVOCABLE(const, &&, true, const&&) // 121 + +#undef __OFATS_ANY_INVOCABLE + +} // namespace ofats + +/* We, uWebSockets define our own type */ +namespace uWS { + template + using MoveOnlyFunction = ofats::any_invocable; +} + +#endif // _ANY_INVOKABLE_H_ diff --git a/src/Vendor/uwebsockets/include/uws/Multipart.h b/src/Vendor/uwebsockets/include/uws/Multipart.h new file mode 100644 index 0000000..8538d64 --- /dev/null +++ b/src/Vendor/uwebsockets/include/uws/Multipart.h @@ -0,0 +1,231 @@ +/* + * Authored by Alex Hultman, 2018-2020. + * Intellectual property of third-party. + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Implements the multipart protocol. Builds atop parts of our common http parser (not yet refactored that way). */ +/* https://www.w3.org/Protocols/rfc1341/7_2_Multipart.html */ + +#ifndef UWS_MULTIPART_H +#define UWS_MULTIPART_H + +#include "MessageParser.h" + +#include +#include +#include +#include +#include + +namespace uWS { + + /* This one could possibly be shared with ExtensionsParser to some degree */ + struct ParameterParser { + + /* Takes the line, commonly given as content-disposition header in the multipart */ + ParameterParser(std::string_view line) { + remainingLine = line; + } + + /* Returns next key/value where value can simply be empty. + * If key (first) is empty then we are at the end */ + std::pair getKeyValue() { + auto key = getToken(); + auto op = getToken(); + + if (!op.length()) { + return {key, ""}; + } + + if (op[0] != ';') { + auto value = getToken(); + /* Strip ; or if at end, nothing */ + getToken(); + return {key, value}; + } + + return {key, ""}; + } + + private: + std::string_view remainingLine; + + /* Consumes a token from the line. Will "unquote" strings */ + std::string_view getToken() { + /* Strip whitespace */ + while (remainingLine.length() && isspace(remainingLine[0])) { + remainingLine.remove_prefix(1); + } + + if (!remainingLine.length()) { + /* All we had was space */ + return {}; + } else { + /* Are we at an operator? */ + if (remainingLine[0] == ';' || remainingLine[0] == '=') { + auto op = remainingLine.substr(0, 1); + remainingLine.remove_prefix(1); + return op; + } else { + /* Are we at a quoted string? */ + if (remainingLine[0] == '\"') { + /* Remove first quote and start counting */ + remainingLine.remove_prefix(1); + auto quote = remainingLine; + int quoteLength = 0; + + /* Read anything until other double quote appears */ + while (remainingLine.length() && remainingLine[0] != '\"') { + remainingLine.remove_prefix(1); + quoteLength++; + } + + /* We can't remove_prefix if we have nothing to remove */ + if (!remainingLine.length()) { + return {}; + } + + remainingLine.remove_prefix(1); + return quote.substr(0, quoteLength); + } else { + /* Read anything until ; = space or end */ + std::string_view token = remainingLine; + + int tokenLength = 0; + while (remainingLine.length() && remainingLine[0] != ';' && remainingLine[0] != '=' && !isspace(remainingLine[0])) { + remainingLine.remove_prefix(1); + tokenLength++; + } + + return token.substr(0, tokenLength); + } + } + } + + /* Nothing */ + return ""; + } + }; + + struct MultipartParser { + + /* 2 chars of hyphen + 1 - 70 chars of boundary */ + char prependedBoundaryBuffer[72]; + std::string_view prependedBoundary; + std::string_view remainingBody; + bool first = true; + + /* I think it is more than sane to limit this to 10 per part */ + //static const int MAX_HEADERS = 10; + + /* Construct the parser based on contentType (reads boundary) */ + MultipartParser(std::string_view contentType) { + + /* We expect the form "multipart/something;somethingboundary=something" */ + if (contentType.length() < 10 || contentType.substr(0, 10) != "multipart/") { + return; + } + + /* For now we simply guess boundary will lie between = and end. This is not entirely + * standards compliant as boundary may be expressed with or without " and spaces */ + auto equalToken = contentType.find('=', 10); + if (equalToken != std::string_view::npos) { + + /* Boundary must be less than or equal to 70 chars yet 1 char or longer */ + std::string_view boundary = contentType.substr(equalToken + 1); + if (!boundary.length() || boundary.length() > 70) { + /* Invalid size */ + return; + } + + /* Prepend it with two hyphens */ + prependedBoundaryBuffer[0] = prependedBoundaryBuffer[1] = '-'; + memcpy(&prependedBoundaryBuffer[2], boundary.data(), boundary.length()); + + prependedBoundary = {prependedBoundaryBuffer, boundary.length() + 2}; + } + } + + /* Is this even a valid multipart request? */ + bool isValid() { + return prependedBoundary.length() != 0; + } + + /* Set the body once, before getting any parts */ + void setBody(std::string_view body) { + remainingBody = body; + } + + /* Parse out the next part's data, filling the headers. Returns nullopt on end or error. */ + std::optional getNextPart(std::pair *headers) { + + /* The remaining two hyphens should be shorter than the boundary */ + if (remainingBody.length() < prependedBoundary.length()) { + /* We are done now */ + return std::nullopt; + } + + if (first) { + auto nextBoundary = remainingBody.find(prependedBoundary); + if (nextBoundary == std::string_view::npos) { + /* Cannot parse */ + return std::nullopt; + } + + /* Toss away boundary and anything before it */ + remainingBody.remove_prefix(nextBoundary + prependedBoundary.length()); + first = false; + } + + auto nextEndBoundary = remainingBody.find(prependedBoundary); + if (nextEndBoundary == std::string_view::npos) { + /* Cannot parse (or simply done) */ + return std::nullopt; + } + + std::string_view part = remainingBody.substr(0, nextEndBoundary); + remainingBody.remove_prefix(nextEndBoundary + prependedBoundary.length()); + + /* Also strip rn before and rn after the part */ + if (part.length() < 4) { + /* Cannot strip */ + return std::nullopt; + } + part.remove_prefix(2); + part.remove_suffix(2); + + /* We are allowed to post pad like this because we know the boundary is at least 2 bytes */ + /* This makes parsing a second pass invalid, so you can only iterate over parts once */ + memset((char *) part.data() + part.length(), '\r', 1); + + /* For this to be a valid part, we need to consume at least 4 bytes (\r\n\r\n) */ + int consumed = getHeaders((char *) part.data(), (char *) part.data() + part.length(), headers); + + if (!consumed) { + /* This is an invalid part */ + return std::nullopt; + } + + /* Strip away the headers from the part body data */ + part.remove_prefix(consumed); + + /* Now pass whatever is remaining of the part */ + return part; + } + }; + +} + +#endif diff --git a/src/Vendor/uwebsockets/include/uws/PerMessageDeflate.h b/src/Vendor/uwebsockets/include/uws/PerMessageDeflate.h index e0ed2fb..b6ee0be 100644 --- a/src/Vendor/uwebsockets/include/uws/PerMessageDeflate.h +++ b/src/Vendor/uwebsockets/include/uws/PerMessageDeflate.h @@ -1,5 +1,5 @@ /* - * Authored by Alex Hultman, 2018-2019. + * Authored by Alex Hultman, 2018-2021. * Intellectual property of third-party. * Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,26 +20,53 @@ #ifndef UWS_PERMESSAGEDEFLATE_H #define UWS_PERMESSAGEDEFLATE_H -#ifndef UWS_NO_ZLIB +/* We always define these options no matter if ZLIB is enabled or not */ +namespace uWS { + /* Compressor mode is HIGH8(windowBits), LOW8(memLevel) */ + enum CompressOptions : uint32_t { + DISABLED = 0, + SHARED_COMPRESSOR = 1, + DEDICATED_COMPRESSOR_3KB = 9 << 8 | 1, + DEDICATED_COMPRESSOR_4KB = 9 << 8 | 2, + DEDICATED_COMPRESSOR_8KB = 10 << 8 | 3, + DEDICATED_COMPRESSOR_16KB = 11 << 8 | 4, + DEDICATED_COMPRESSOR_32KB = 12 << 8 | 5, + DEDICATED_COMPRESSOR_64KB = 13 << 8 | 6, + DEDICATED_COMPRESSOR_128KB = 14 << 8 | 7, + DEDICATED_COMPRESSOR_256KB = 15 << 8 | 8, + /* Same as 256kb */ + DEDICATED_COMPRESSOR = 15 << 8 | 8 + }; +} + +#if !defined(UWS_NO_ZLIB) && !defined(UWS_MOCK_ZLIB) #include #endif #include +#include + +#ifdef UWS_USE_LIBDEFLATE +#include "libdeflate.h" +#include +#endif namespace uWS { /* Do not compile this module if we don't want it */ -#ifdef UWS_NO_ZLIB +#if defined(UWS_NO_ZLIB) || defined(UWS_MOCK_ZLIB) struct ZlibContext {}; struct InflationStream { - std::string_view inflate(ZlibContext *zlibContext, std::string_view compressed, size_t maxPayloadLength) { - return compressed; + std::optional inflate(ZlibContext *zlibContext, std::string_view compressed, size_t maxPayloadLength) { + return compressed.substr(0, std::min(maxPayloadLength, compressed.length())); } }; struct DeflationStream { std::string_view deflate(ZlibContext *zlibContext, std::string_view raw, bool reset) { return raw; } + DeflationStream(int compressOptions) { + } }; #else @@ -54,26 +81,65 @@ struct ZlibContext { char *deflationBuffer; char *inflationBuffer; +#ifdef UWS_USE_LIBDEFLATE + libdeflate_decompressor *decompressor; + libdeflate_compressor *compressor; +#endif + ZlibContext() { deflationBuffer = (char *) malloc(LARGE_BUFFER_SIZE); inflationBuffer = (char *) malloc(LARGE_BUFFER_SIZE); + +#ifdef UWS_USE_LIBDEFLATE + decompressor = libdeflate_alloc_decompressor(); + compressor = libdeflate_alloc_compressor(6); +#endif } ~ZlibContext() { free(deflationBuffer); free(inflationBuffer); + +#ifdef UWS_USE_LIBDEFLATE + libdeflate_free_decompressor(decompressor); + libdeflate_free_compressor(compressor); +#endif } }; struct DeflationStream { z_stream deflationStream = {}; - DeflationStream() { - deflateInit2(&deflationStream, 1, Z_DEFLATED, -15, 8, Z_DEFAULT_STRATEGY); + DeflationStream(CompressOptions compressOptions) { + + /* Sliding inflator should be about 44kb by default, less than compressor */ + + /* Memory usage is given by 2 ^ (windowBits + 2) + 2 ^ (memLevel + 9) */ + int windowBits = -(int) ((compressOptions & 0xFF00) >> 8), memLevel = compressOptions & 0x00FF; + + //printf("windowBits: %d, memLevel: %d\n", windowBits, memLevel); + + deflateInit2(&deflationStream, Z_DEFAULT_COMPRESSION, Z_DEFLATED, windowBits, memLevel, Z_DEFAULT_STRATEGY); } - /* Deflate and optionally reset */ + /* Deflate and optionally reset. You must not deflate an empty string. */ std::string_view deflate(ZlibContext *zlibContext, std::string_view raw, bool reset) { + +#ifdef UWS_USE_LIBDEFLATE + /* Run a fast path in case of shared_compressor */ + if (reset) { + size_t written = 0; + static unsigned char buf[1024 + 1]; + + written = libdeflate_deflate_compress(zlibContext->compressor, raw.data(), raw.length(), buf, 1024); + + if (written) { + memcpy(&buf[written], "\x00", 1); + return std::string_view((char *) buf, written + 1); + } + } +#endif + /* Odd place to clear this one, fix */ zlibContext->dynamicDeflationBuffer.clear(); @@ -105,9 +171,11 @@ struct DeflationStream { if (zlibContext->dynamicDeflationBuffer.length()) { zlibContext->dynamicDeflationBuffer.append(zlibContext->deflationBuffer, DEFLATE_OUTPUT_CHUNK - deflationStream.avail_out); - return {(char *) zlibContext->dynamicDeflationBuffer.data(), zlibContext->dynamicDeflationBuffer.length() - 4}; + return std::string_view((char *) zlibContext->dynamicDeflationBuffer.data(), zlibContext->dynamicDeflationBuffer.length() - 4); } + /* Note: We will get an interger overflow resulting in heap buffer overflow if Z_BUF_ERROR is returned + * from passing 0 as avail_in. Therefore we must not deflate an empty string */ return { zlibContext->deflationBuffer, DEFLATE_OUTPUT_CHUNK - deflationStream.avail_out - 4 @@ -130,7 +198,26 @@ struct InflationStream { inflateEnd(&inflationStream); } - std::string_view inflate(ZlibContext *zlibContext, std::string_view compressed, size_t maxPayloadLength) { + /* Zero length inflates are possible and valid */ + std::optional inflate(ZlibContext *zlibContext, std::string_view compressed, size_t maxPayloadLength) { + +#ifdef UWS_USE_LIBDEFLATE + /* Try fast path first */ + size_t written = 0; + static char buf[1024]; + + /* We have to pad 9 bytes and restore those bytes when done since 9 is more than 6 of next WebSocket message */ + char tmp[9]; + memcpy(tmp, (char *) compressed.data() + compressed.length(), 9); + memcpy((char *) compressed.data() + compressed.length(), "\x00\x00\xff\xff\x01\x00\x00\xff\xff", 9); + libdeflate_result res = libdeflate_deflate_decompress(zlibContext->decompressor, compressed.data(), compressed.length() + 9, buf, 1024, &written); + memcpy((char *) compressed.data() + compressed.length(), tmp, 9); + + if (res == 0) { + /* Fast path wins */ + return std::string_view(buf, written); + } +#endif /* We clear this one here, could be done better */ zlibContext->dynamicInflationBuffer.clear(); @@ -156,7 +243,7 @@ struct InflationStream { inflateReset(&inflationStream); if ((err != Z_BUF_ERROR && err != Z_OK) || zlibContext->dynamicInflationBuffer.length() > maxPayloadLength) { - return {nullptr, 0}; + return std::nullopt; } if (zlibContext->dynamicInflationBuffer.length()) { @@ -164,18 +251,18 @@ struct InflationStream { /* Let's be strict about the max size */ if (zlibContext->dynamicInflationBuffer.length() > maxPayloadLength) { - return {nullptr, 0}; + return std::nullopt; } - return {zlibContext->dynamicInflationBuffer.data(), zlibContext->dynamicInflationBuffer.length()}; + return std::string_view(zlibContext->dynamicInflationBuffer.data(), zlibContext->dynamicInflationBuffer.length()); } /* Let's be strict about the max size */ if ((LARGE_BUFFER_SIZE - inflationStream.avail_out) > maxPayloadLength) { - return {nullptr, 0}; + return std::nullopt; } - return {zlibContext->inflationBuffer, LARGE_BUFFER_SIZE - inflationStream.avail_out}; + return std::string_view(zlibContext->inflationBuffer, LARGE_BUFFER_SIZE - inflationStream.avail_out); } }; diff --git a/src/Vendor/uwebsockets/include/uws/ProxyParser.h b/src/Vendor/uwebsockets/include/uws/ProxyParser.h new file mode 100644 index 0000000..95ee3d1 --- /dev/null +++ b/src/Vendor/uwebsockets/include/uws/ProxyParser.h @@ -0,0 +1,163 @@ +/* + * Authored by Alex Hultman, 2018-2020. + * Intellectual property of third-party. + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* This module implements The PROXY Protocol v2 */ + +#ifndef UWS_PROXY_PARSER_H +#define UWS_PROXY_PARSER_H + +#ifdef UWS_WITH_PROXY + +namespace uWS { + +struct proxy_hdr_v2 { + uint8_t sig[12]; /* hex 0D 0A 0D 0A 00 0D 0A 51 55 49 54 0A */ + uint8_t ver_cmd; /* protocol version and command */ + uint8_t fam; /* protocol family and address */ + uint16_t len; /* number of following bytes part of the header */ +}; + +union proxy_addr { + struct { /* for TCP/UDP over IPv4, len = 12 */ + uint32_t src_addr; + uint32_t dst_addr; + uint16_t src_port; + uint16_t dst_port; + } ipv4_addr; + struct { /* for TCP/UDP over IPv6, len = 36 */ + uint8_t src_addr[16]; + uint8_t dst_addr[16]; + uint16_t src_port; + uint16_t dst_port; + } ipv6_addr; +}; + +/* Byte swap for little-endian systems */ +/* Todo: This functions should be shared with the one in WebSocketProtocol.h! */ +template +T _cond_byte_swap(T value) { + uint32_t endian_test = 1; + if (*((char *)&endian_test)) { + union { + T i; + uint8_t b[sizeof(T)]; + } src = { value }, dst; + + for (unsigned int i = 0; i < sizeof(value); i++) { + dst.b[i] = src.b[sizeof(value) - 1 - i]; + } + + return dst.i; + } + return value; +} + +struct ProxyParser { +private: + union proxy_addr addr; + + /* Default family of 0 signals no proxy address */ + uint8_t family = 0; + +public: + /* Returns 4 or 16 bytes source address */ + std::string_view getSourceAddress() { + + // UNSPEC family and protocol + if (family == 0) { + return {}; + } + + if ((family & 0xf0) >> 4 == 1) { + /* Family 1 is INET4 */ + return {(char *) &addr.ipv4_addr.src_addr, 4}; + } else { + /* Family 2 is INET6 */ + return {(char *) &addr.ipv6_addr.src_addr, 16}; + } + } + + /* Returns [done, consumed] where done = false on failure */ + std::pair parse(std::string_view data) { + + /* We require at least four bytes to determine protocol */ + if (data.length() < 4) { + return {false, 0}; + } + + /* HTTP can never start with "\r\n\r\n", but PROXY always does */ + if (memcmp(data.data(), "\r\n\r\n", 4)) { + /* This is HTTP, so be done */ + return {true, 0}; + } + + /* We assume we are parsing PROXY V2 here */ + + /* We require 16 bytes here */ + if (data.length() < 16) { + return {false, 0}; + } + + /* Header is 16 bytes */ + struct proxy_hdr_v2 header; + memcpy(&header, data.data(), 16); + + if (memcmp(header.sig, "\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A", 12)) { + /* This is not PROXY protocol at all */ + return {false, 0}; + } + + /* We only support version 2 */ + if ((header.ver_cmd & 0xf0) >> 4 != 2) { + return {false, 0}; + } + + //printf("Version: %d\n", (header.ver_cmd & 0xf0) >> 4); + //printf("Command: %d\n", (header.ver_cmd & 0x0f)); + + /* We get length in network byte order (todo: share this function with the rest) */ + uint16_t hostLength = _cond_byte_swap(header.len); + + /* We must have all the data available */ + if (data.length() < 16u + hostLength) { + return {false, 0}; + } + + /* Payload cannot be more than sizeof proxy_addr */ + if (sizeof(proxy_addr) < hostLength) { + return {false, 0}; + } + + //printf("Family: %d\n", (header.fam & 0xf0) >> 4); + //printf("Transport: %d\n", (header.fam & 0x0f)); + + /* We have 0 family by default, and UNSPEC is 0 as well */ + family = header.fam; + + /* Copy payload */ + memcpy(&addr, data.data() + 16, hostLength); + + /* We consumed everything */ + return {true, 16 + hostLength}; + } +}; + +} + +#endif + +#endif // UWS_PROXY_PARSER_H \ No newline at end of file diff --git a/src/Vendor/uwebsockets/include/uws/QueryParser.h b/src/Vendor/uwebsockets/include/uws/QueryParser.h new file mode 100644 index 0000000..552bb56 --- /dev/null +++ b/src/Vendor/uwebsockets/include/uws/QueryParser.h @@ -0,0 +1,120 @@ +/* + * Authored by Alex Hultman, 2018-2020. + * Intellectual property of third-party. + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* This module implements URI query parsing and retrieval of value given key */ + +#ifndef UWS_QUERYPARSER_H +#define UWS_QUERYPARSER_H + +#include + +namespace uWS { + + /* Takes raw query including initial '?' sign. Will inplace decode, so input will mutate */ + static inline std::string_view getDecodedQueryValue(std::string_view key, std::string_view rawQuery) { + + /* Can't have a value without a key */ + if (!key.length()) { + return {}; + } + + /* Start with the whole querystring including initial '?' */ + std::string_view queryString = rawQuery; + + /* List of key, value could be cached for repeated fetches similar to how headers are, todo! */ + while (queryString.length()) { + /* Find boundaries of this statement */ + std::string_view statement = queryString.substr(1, queryString.find('&', 1) - 1); + + /* Only bother if first char of key match (early exit) */ + if (statement.length() && statement[0] == key[0]) { + /* Equal sign must be present and not in the end of statement */ + auto equality = statement.find('='); + if (equality != std::string_view::npos && equality != statement.length() - 1) { + + std::string_view statementKey = statement.substr(0, equality); + std::string_view statementValue = statement.substr(equality + 1); + + /* String comparison */ + if (key == statementKey) { + + /* Decode value inplace, put null at end if before length of original */ + char *in = (char *) statementValue.data(); + + /* Write offset */ + unsigned int out = 0; + + /* Walk over all chars until end or null char, decoding in place */ + for (unsigned int i = 0; i < statementValue.length() && in[i]; i++) { + /* Only bother with '%' */ + if (in[i] == '%') { + /* Do we have enough data for two bytes hex? */ + if (i + 2 >= statementValue.length()) { + return {}; + } + + /* Two bytes hex */ + int hex1 = in[i + 1] - '0'; + if (hex1 > 9) { + hex1 &= 223; + hex1 -= 7; + } + + int hex2 = in[i + 2] - '0'; + if (hex2 > 9) { + hex2 &= 223; + hex2 -= 7; + } + + *((unsigned char *) &in[out]) = (unsigned char) (hex1 * 16 + hex2); + i += 2; + } else { + /* Is this even a rule? */ + if (in[i] == '+') { + in[out] = ' '; + } else { + in[out] = in[i]; + } + } + + /* We always only write one char */ + out++; + } + + /* If decoded string is shorter than original, put null char to stop next read */ + if (out < statementValue.length()) { + in[out] = 0; + } + + return statementValue.substr(0, out); + } + } else { + /* This querystring is invalid, cannot parse it */ + return {}; + } + } + + queryString.remove_prefix(statement.length() + 1); + } + + /* Nothing found */ + return {}; + } + +} + +#endif diff --git a/src/Vendor/uwebsockets/include/uws/TopicTree.h b/src/Vendor/uwebsockets/include/uws/TopicTree.h index 22fba82..7ea06f0 100644 --- a/src/Vendor/uwebsockets/include/uws/TopicTree.h +++ b/src/Vendor/uwebsockets/include/uws/TopicTree.h @@ -1,5 +1,5 @@ /* - * Authored by Alex Hultman, 2018-2019. + * Authored by Alex Hultman, 2018-2021. * Intellectual property of third-party. * Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,6 +26,10 @@ #include #include #include +#include + +/* We use std::function here, not MoveOnlyFunction */ +#include namespace uWS { @@ -57,29 +61,123 @@ struct Topic { /* Terminating wildcard child */ Topic *terminatingWildcardChild = nullptr; - /* What we published */ - std::map messages; + /* What we published, {inflated, deflated} */ + std::map> messages; std::set subs; + + /* Locked or not, used only when iterating over a Subscriber's topics */ + bool locked = false; + + /* Full name is used when iterating topcis */ + std::string fullName; +}; + +struct Hole { + std::pair lengths; + unsigned int messageId; +}; + +struct Intersection { + std::pair dataChannels; + std::vector holes; + + void forSubscriber(std::vector &senderForMessages, std::function, bool)> cb) { + /* How far we already emitted of the two dataChannels */ + std::pair emitted = {}; + + /* Holes are global to the entire topic tree, so we are not guaranteed to find + * holes in this intersection - they are sorted, though */ + unsigned int examinedHoles = 0; + + /* This is a slow path of sorts, most subscribers will be observers, not active senders */ + for (unsigned int id : senderForMessages) { + std::pair toEmit = {}; + std::pair toIgnore = {}; + + /* This linear search is most probably very small - it could be made log2 if every hole + * knows about its previous accumulated length, which is easy to set up. However this + * log2 search will most likely never be a warranted perf. gain */ + for (; examinedHoles < holes.size(); examinedHoles++) { + if (holes[examinedHoles].messageId == id) { + toIgnore.first += holes[examinedHoles].lengths.first; + toIgnore.second += holes[examinedHoles].lengths.second; + examinedHoles++; + break; + } + /* We are not the sender of this message so we should emit it in this segment */ + toEmit.first += holes[examinedHoles].lengths.first; + toEmit.second += holes[examinedHoles].lengths.second; + } + + /* Emit this segment */ + if (toEmit.first || toEmit.second) { + std::pair cutDataChannels = { + std::string_view(dataChannels.first.data() + emitted.first, toEmit.first), + std::string_view(dataChannels.second.data() + emitted.second, toEmit.second), + }; + + /* We only need to test the first data channel for "FIN" */ + cb(cutDataChannels, emitted.first + toEmit.first + toIgnore.first == dataChannels.first.length()); + } + + emitted.first += toEmit.first + toIgnore.first; + emitted.second += toEmit.second + toIgnore.second; + } + + if (emitted.first == dataChannels.first.length() && emitted.second == dataChannels.second.length()) { + return; + } + + std::pair cutDataChannels = { + std::string_view(dataChannels.first.data() + emitted.first, dataChannels.first.length() - emitted.first), + std::string_view(dataChannels.second.data() + emitted.second, dataChannels.second.length() - emitted.second), + }; + + cb(cutDataChannels, true); + } }; struct TopicTree { + /* Returns Topic, or nullptr. Topic can be root if empty string given. */ + Topic *lookupTopic(std::string_view topic) { + /* Lookup exact Topic ptr from string */ + Topic *iterator = root; + for (size_t start = 0, stop = 0; stop != std::string::npos; start = stop + 1) { + stop = topic.find('/', start); + std::string_view segment = topic.substr(start, stop - start); + + std::map::iterator it = iterator->children.find(segment); + if (it == iterator->children.end()) { + /* This topic does not even exist */ + return nullptr; + } + + iterator = it->second; + } + + return iterator; + } + private: - std::function cb; + std::function cb; Topic *root = new Topic; /* Global messageId for deduplication of overlapping topics and ordering between topics */ unsigned int messageId = 0; + /* Sender holes */ + std::map> senderHoles; + /* The triggered topics */ Topic *triggeredTopics[64]; int numTriggeredTopics = 0; Subscriber *min = (Subscriber *) UINTPTR_MAX; - + /* Cull or trim unused Topic nodes from leaf to root */ void trimTree(Topic *topic) { - if (!topic->subs.size() && !topic->children.size() && !topic->terminatingWildcardChild && !topic->wildcardChild) { + while (!topic->subs.size() && !topic->children.size() && !topic->terminatingWildcardChild && !topic->wildcardChild) { Topic *parent = topic->parent; if (topic->length == 1) { @@ -112,43 +210,64 @@ private: delete [] topic->name; delete topic; - if (parent != root) { - trimTree(parent); + if (parent == root) { + break; } + + topic = parent; } } - /* Should be getData and commit? */ - void publish(Topic *iterator, size_t start, size_t stop, std::string_view topic, std::string_view message) { - /* If we already have 64 triggered topics make sure to drain it here */ - if (numTriggeredTopics == 64) { - drain(); - } + /* Publishes to all matching topics and wildcards. Returns whether at least one topic was a match. */ + bool publish(Topic *iterator, size_t start, size_t stop, std::string_view topic, std::pair message) { + /* Whether we matched with at least one topic */ + bool didMatch = false; + + /* Iterate over all segments in given topic */ for (; stop != std::string::npos; start = stop + 1) { stop = topic.find('/', start); std::string_view segment = topic.substr(start, stop - start); + /* It is very important to disallow wildcards when publishing. + * We will not catch EVERY misuse this lazy way, but enough to hinder + * explosive recursion. + * Terminating wildcards MAY still get triggered along the way, if for + * instace the error is found late while iterating the topic segments. */ + if (segment.length() == 1) { + if (segment[0] == '+' || segment[0] == '#') { + /* "Fail" here, but not necessarily for the entire publish */ + return didMatch; + } + } + /* Do we have a terminating wildcard child? */ if (iterator->terminatingWildcardChild) { iterator->terminatingWildcardChild->messages[messageId] = message; /* Add this topic to triggered */ if (!iterator->terminatingWildcardChild->triggered) { + /* If we already have 64 triggered topics make sure to drain it here */ + if (numTriggeredTopics == 64) { + drain(); + } + triggeredTopics[numTriggeredTopics++] = iterator->terminatingWildcardChild; iterator->terminatingWildcardChild->triggered = true; } + + didMatch = true; } /* Do we have a wildcard child? */ if (iterator->wildcardChild) { - publish(iterator->wildcardChild, stop + 1, stop, topic, message); + didMatch |= publish(iterator->wildcardChild, stop + 1, stop, topic, message); } std::map::iterator it = iterator->children.find(segment); if (it == iterator->children.end()) { /* Stop trying to match by exact string */ - return; + return didMatch; } iterator = it->second; @@ -159,14 +278,22 @@ private: /* Add this topic to triggered */ if (!iterator->triggered) { + /* If we already have 64 triggered topics make sure to drain it here */ + if (numTriggeredTopics == 64) { + drain(); + } + triggeredTopics[numTriggeredTopics++] = iterator; iterator->triggered = true; } + + /* We obviously matches exactly here */ + return true; } public: - TopicTree(std::function cb) { + TopicTree(std::function cb) { this->cb = cb; } @@ -174,7 +301,20 @@ public: delete root; } - void subscribe(std::string_view topic, Subscriber *subscriber) { + /* This is part of the fast path, so should be optimal */ + std::vector &getSenderFor(Subscriber *s) { + static thread_local std::vector emptyVector; + + auto it = senderHoles.find(s); + if (it != senderHoles.end()) { + return it->second; + } + + return emptyVector; + } + + /* Returns number of subscribers after the call and whether or not we were successful in subscribing */ + std::pair subscribe(std::string_view topic, Subscriber *subscriber, bool nonStrict = false) { /* Start iterating from the root */ Topic *iterator = root; @@ -196,7 +336,18 @@ public: newTopic->terminatingWildcardChild = nullptr; newTopic->wildcardChild = nullptr; memcpy(newTopic->name, segment.data(), segment.length()); - + + /* Set fullname as parent's name plus our name */ + newTopic->fullName.reserve(newTopic->parent->fullName.length() + 1 + segment.length()); + + /* Only append parent's name if parent is not root */ + if (newTopic->parent != root) { + newTopic->fullName.append(newTopic->parent->fullName); + newTopic->fullName.append("/"); + } + + newTopic->fullName.append(segment); + /* For simplicity we do insert wildcards with text */ iterator->children.insert(lb, {std::string_view(newTopic->name, segment.length()), newTopic}); @@ -216,22 +367,38 @@ public: } } + /* If this topic is triggered, drain the tree before we join */ + if (iterator->triggered) { + if (!nonStrict) { + drain(); + } + } + /* Add socket to Topic's Set */ auto [it, inserted] = iterator->subs.insert(subscriber); /* Add Topic to list of subscriptions only if we weren't already subscribed */ if (inserted) { subscriber->subscriptions.push_back(iterator); + return {(unsigned int) iterator->subs.size(), true}; } + return {(unsigned int) iterator->subs.size(), false}; } - void publish(std::string_view topic, std::string_view message) { - publish(root, 0, 0, topic, message); + bool publish(std::string_view topic, std::pair message, Subscriber *sender = nullptr) { + /* Add a hole for the sender if one */ + if (sender) { + senderHoles[sender].push_back(messageId); + } + + auto ret = publish(root, 0, 0, topic, message); + /* MessageIDs are reset on drain - this should be fine since messages itself are cleared on drain */ messageId++; + return ret; } - /* Returns whether we were subscribed prior */ - bool unsubscribe(std::string_view topic, Subscriber *subscriber) { + /* Returns a pair of numSubscribers after operation, and whether we were subscribed prior */ + std::pair unsubscribe(std::string_view topic, Subscriber *subscriber, bool nonStrict = false) { /* Subscribers are likely to have very few subscriptions (20 or fewer) */ if (subscriber) { /* Lookup exact Topic ptr from string */ @@ -243,32 +410,55 @@ public: std::map::iterator it = iterator->children.find(segment); if (it == iterator->children.end()) { /* This topic does not even exist */ - return false; + return {0, false}; } iterator = it->second; } + /* Is this topic locked? If so, we cannot unsubscribe from it */ + if (iterator->locked) { + return {iterator->subs.size(), false}; + } + /* Try and remove this topic from our list */ for (auto it = subscriber->subscriptions.begin(); it != subscriber->subscriptions.end(); it++) { if (*it == iterator) { + /* If this topic is triggered, drain the tree before we leave */ + if (iterator->triggered) { + if (!nonStrict) { + drain(); + } + } + /* Remove topic ptr from our list */ subscriber->subscriptions.erase(it); /* Remove us from Topic's subs */ iterator->subs.erase(subscriber); + unsigned int numSubscribers = (unsigned int) iterator->subs.size(); trimTree(iterator); - return true; + return {numSubscribers, true}; } } } - return false; + return {0, false}; } /* Can be called with nullptr, ignore it then */ - void unsubscribeAll(Subscriber *subscriber) { + void unsubscribeAll(Subscriber *subscriber, bool mayFlush = true) { if (subscriber) { for (Topic *topic : subscriber->subscriptions) { + + /* We do not want to flush when closing a socket, it makes no sense to do so */ + + /* If this topic is triggered, drain the tree before we leave */ + if (mayFlush && topic->triggered) { + /* Never mind nonStrict here (yet?) */ + drain(); + } + + /* Remove us from the topic's set */ topic->subs.erase(subscriber); trimTree(topic); } @@ -290,11 +480,18 @@ public: for (int i = 0; i < numTriggeredTopics; i++) { if (triggeredTopics[i]->subs.size()) { triggeredTopics[numFilteredTriggeredTopics++] = triggeredTopics[i]; + } else { + /* If we no longer have any subscribers, yet still keep this Topic alive (parent), + * make sure to clear its potential messages. */ + triggeredTopics[i]->messages.clear(); + triggeredTopics[i]->triggered = false; } } numTriggeredTopics = numFilteredTriggeredTopics; if (!numTriggeredTopics) { + senderHoles.clear(); + messageId = 0; return; } @@ -310,7 +507,7 @@ public: if (min != (Subscriber *)UINTPTR_MAX) { /* Up to 64 triggered Topics per batch */ - std::map intersectionCache; + std::map intersectionCache; /* Loop over these here */ std::set::iterator it[64]; @@ -319,14 +516,14 @@ public: it[i] = triggeredTopics[i]->subs.begin(); end[i] = triggeredTopics[i]->subs.end(); } - + /* Empty all sets from unique subscribers */ for (int nonEmpty = numTriggeredTopics; nonEmpty; ) { Subscriber *nextMin = (Subscriber *)UINTPTR_MAX; /* The message sets relevant for this intersection */ - std::map *perSubscriberIntersectingTopicMessages[64]; + std::map> *perSubscriberIntersectingTopicMessages[64]; int numPerSubscriberIntersectingTopicMessages = 0; uint64_t intersection = 0; @@ -357,18 +554,28 @@ public: } /* Generate cache for intersection */ - if (intersectionCache[intersection].length() == 0) { + if (intersectionCache[intersection].dataChannels.first.length() == 0) { /* Build the union in order without duplicates */ - std::map complete; + std::map> complete; for (int i = 0; i < numPerSubscriberIntersectingTopicMessages; i++) { complete.insert(perSubscriberIntersectingTopicMessages[i]->begin(), perSubscriberIntersectingTopicMessages[i]->end()); } - /* Create the linear cache */ - std::string res; + /* Create the linear cache, {inflated, deflated} */ + Intersection res; for (auto &p : complete) { - res.append(p.second); + res.dataChannels.first.append(p.second.first); + res.dataChannels.second.append(p.second.second); + + /* Appends {id, length, length} + * We could possibly append byte offset also, + * if we want to use log2 search later. */ + Hole h; + h.lengths.first = p.second.first.length(); + h.lengths.second = p.second.second.length(); + h.messageId = p.first; + res.holes.push_back(h); } cb(min, intersectionCache[intersection] = std::move(res)); @@ -379,7 +586,6 @@ public: min = nextMin; } - } /* Clear messages of triggered Topics */ @@ -388,27 +594,8 @@ public: triggeredTopics[i]->triggered = false; } numTriggeredTopics = 0; - } - - void print(Topic *root = nullptr, int indentation = 1) { - if (root == nullptr) { - std::cout << "Print of tree:" << std::endl; - root = this->root; - } - - for (auto p : root->children) { - for (int i = 0; i < indentation; i++) { - std::cout << " "; - } - std::cout << std::string_view(p.second->name, p.second->length) << " = " << p.second->messages.size() << " publishes, " << p.second->subs.size() << " subscribers {"; - - for (auto &p : p.second->subs) { - std::cout << p << " referring to socket: " << p->user << ", "; - } - std::cout << "}" << std::endl; - - print(p.second, indentation + 1); - } + senderHoles.clear(); + messageId = 0; } }; diff --git a/src/Vendor/uwebsockets/include/uws/Utilities.h b/src/Vendor/uwebsockets/include/uws/Utilities.h index c84029e..2fbea32 100644 --- a/src/Vendor/uwebsockets/include/uws/Utilities.h +++ b/src/Vendor/uwebsockets/include/uws/Utilities.h @@ -1,5 +1,5 @@ /* - * Authored by Alex Hultman, 2018-2019. + * Authored by Alex Hultman, 2018-2020. * Intellectual property of third-party. * Licensed under the Apache License, Version 2.0 (the "License"); @@ -43,8 +43,8 @@ inline int u32toaHex(uint32_t value, char *dst) { return ret; } -inline int u32toa(uint32_t value, char *dst) { - char temp[10]; +inline int u64toa(uint64_t value, char *dst) { + char temp[20]; char *p = temp; do { *p++ = (char) ((value % 10) + '0'); diff --git a/src/Vendor/uwebsockets/include/uws/WebSocket.h b/src/Vendor/uwebsockets/include/uws/WebSocket.h index ee05b8b..39ca0b2 100644 --- a/src/Vendor/uwebsockets/include/uws/WebSocket.h +++ b/src/Vendor/uwebsockets/include/uws/WebSocket.h @@ -1,5 +1,5 @@ /* - * Authored by Alex Hultman, 2018-2019. + * Authored by Alex Hultman, 2018-2021. * Intellectual property of third-party. * Licensed under the Apache License, Version 2.0 (the "License"); @@ -27,40 +27,63 @@ namespace uWS { -template +template struct WebSocket : AsyncSocket { template friend struct TemplatedApp; + template friend struct HttpResponse; private: typedef AsyncSocket Super; - void *init(bool perMessageDeflate, bool slidingCompression, std::string &&backpressure) { - new (us_socket_ext(SSL, (us_socket_t *) this)) WebSocketData(perMessageDeflate, slidingCompression, std::move(backpressure)); + void *init(bool perMessageDeflate, CompressOptions compressOptions, std::string &&backpressure) { + new (us_socket_ext(SSL, (us_socket_t *) this)) WebSocketData(perMessageDeflate, compressOptions, std::move(backpressure)); return this; } public: /* Returns pointer to the per socket user data */ - void *getUserData() { + USERDATA *getUserData() { WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this); /* We just have it overallocated by sizeof type */ - return (webSocketData + 1); + return (USERDATA *) (webSocketData + 1); } /* See AsyncSocket */ using Super::getBufferedAmount; using Super::getRemoteAddress; + using Super::getRemoteAddressAsText; + using Super::getNativeHandle; /* Simple, immediate close of the socket. Emits close event */ using Super::close; - /* Send or buffer a WebSocket frame, compressed or not. Returns false on increased user space backpressure. */ - bool send(std::string_view message, uWS::OpCode opCode = uWS::OpCode::BINARY, bool compress = false) { + enum SendStatus : int { + BACKPRESSURE, + SUCCESS, + DROPPED + }; + + /* Send or buffer a WebSocket frame, compressed or not. Returns BACKPRESSURE on increased user space backpressure, + * DROPPED on dropped message (due to backpressure) or SUCCCESS if you are free to send even more now. */ + SendStatus send(std::string_view message, OpCode opCode = OpCode::BINARY, bool compress = false) { + WebSocketContextData *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, + (us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this) + ); + + /* Skip sending and report success if we are over the limit of maxBackpressure */ + if (webSocketContextData->maxBackpressure && webSocketContextData->maxBackpressure < getBufferedAmount()) { + /* Also defer a close if we should */ + if (webSocketContextData->closeOnBackpressureLimit) { + us_socket_shutdown_read(SSL, (us_socket_t *) this); + } + return DROPPED; + } + /* Transform the message to compressed domain if requested */ if (compress) { WebSocketData *webSocketData = (WebSocketData *) Super::getAsyncSocketData(); - /* Check and correct the compress hint */ - if (opCode < 3 && webSocketData->compressionStatus == WebSocketData::ENABLED) { + /* Check and correct the compress hint. It is never valid to compress 0 bytes */ + if (message.length() && opCode < 3 && webSocketData->compressionStatus == WebSocketData::ENABLED) { LoopData *loopData = Super::getLoopData(); /* Compress using either shared or dedicated deflationStream */ if (webSocketData->deflationStream) { @@ -82,7 +105,7 @@ public: /* Get size, alloate size, write if needed */ size_t messageFrameSize = protocol::messageFrameSize(message.length()); - auto[sendBuffer, requiresWrite] = Super::getSendBuffer(messageFrameSize); + auto [sendBuffer, requiresWrite] = Super::getSendBuffer(messageFrameSize); protocol::formatMessage(sendBuffer, message.data(), message.length(), opCode, message.length(), compress); /* This is the slow path, when we couldn't cork for the user */ if (requiresWrite) { @@ -93,7 +116,7 @@ public: if (failed) { /* Return false for failure, skipping to reset the timeout below */ - return false; + return BACKPRESSURE; } } @@ -101,22 +124,24 @@ public: if (automaticallyCorked) { auto [written, failed] = Super::uncork(); if (failed) { - return false; + return BACKPRESSURE; } } /* Every successful send resets the timeout */ - WebSocketContextData *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, - (us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this) - ); - AsyncSocket::timeout(webSocketContextData->idleTimeout); + if (webSocketContextData->resetIdleTimeoutOnSend) { + Super::timeout(webSocketContextData->idleTimeoutComponents.first); + WebSocketData *webSocketData = (WebSocketData *) Super::getAsyncSocketData(); + webSocketData->hasTimedOut = false; + } /* Return success */ - return true; + return SUCCESS; } - /* Send websocket close frame, emit close event, send FIN if successful */ - void end(int code, std::string_view message = {}) { + /* Send websocket close frame, emit close event, send FIN if successful. + * Will not append a close reason if code is 0 or 1005. */ + void end(int code = 0, std::string_view message = {}) { /* Check if we already called this one */ WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this); if (webSocketData->isShuttingDown) { @@ -128,23 +153,22 @@ public: /* Format and send the close frame */ static const int MAX_CLOSE_PAYLOAD = 123; - int length = (int) std::min(MAX_CLOSE_PAYLOAD, message.length()); + size_t length = std::min(MAX_CLOSE_PAYLOAD, message.length()); char closePayload[MAX_CLOSE_PAYLOAD + 2]; - int closePayloadLength = (int) protocol::formatClosePayload(closePayload, (uint16_t) code, message.data(), length); + size_t closePayloadLength = protocol::formatClosePayload(closePayload, (uint16_t) code, message.data(), length); bool ok = send(std::string_view(closePayload, closePayloadLength), OpCode::CLOSE); /* FIN if we are ok and not corked */ - WebSocket *webSocket = (WebSocket *) this; - if (!webSocket->isCorked()) { + if (!this->isCorked()) { if (ok) { /* If we are not corked, and we just sent off everything, we need to FIN right here. * In all other cases, we need to fin either if uncork was successful, or when drainage is complete. */ - webSocket->shutdown(); + this->shutdown(); } } /* Emit close event */ - WebSocketContextData *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, + WebSocketContextData *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, (us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this) ); if (webSocketContextData->closeHandler) { @@ -152,13 +176,13 @@ public: } /* Make sure to unsubscribe from any pub/sub node at exit */ - webSocketContextData->topicTree.unsubscribeAll(webSocketData->subscriber); + webSocketContextData->topicTree.unsubscribeAll(webSocketData->subscriber, false); delete webSocketData->subscriber; webSocketData->subscriber = nullptr; } /* Corks the response if possible. Leaves already corked socket be. */ - void cork(fu2::unique_function &&handler) { + void cork(MoveOnlyFunction &&handler) { if (!Super::isCorked() && Super::canCork()) { Super::cork(); handler(); @@ -172,9 +196,9 @@ public: } } - /* Subscribe to a topic according to MQTT rules and syntax */ - void subscribe(std::string_view topic) { - WebSocketContextData *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, + /* Subscribe to a topic according to MQTT rules and syntax. Returns success */ + /*std::pair*/ bool subscribe(std::string_view topic, bool nonStrict = false) { + WebSocketContextData *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, (us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this) ); @@ -184,38 +208,87 @@ public: webSocketData->subscriber = new Subscriber(this); } - webSocketContextData->topicTree.subscribe(topic, webSocketData->subscriber); + /* Cannot return numSubscribers as this is only for this particular websocket context */ + return webSocketContextData->topicTree.subscribe(topic, webSocketData->subscriber, nonStrict).second; } - /* Unsubscribe from a topic, returns true if we were subscribed */ - bool unsubscribe(std::string_view topic) { - WebSocketContextData *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, + /* Unsubscribe from a topic, returns true if we were subscribed. */ + /*std::pair*/ bool unsubscribe(std::string_view topic, bool nonStrict = false) { + WebSocketContextData *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, (us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this) ); WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this); - return webSocketContextData->topicTree.unsubscribe(topic, webSocketData->subscriber); + /* Cannot return numSubscribers as this is only for this particular websocket context */ + return webSocketContextData->topicTree.unsubscribe(topic, webSocketData->subscriber, nonStrict).second; } - /* Unsubscribe from all topics you might be subscribed to */ - void unsubscribeAll() { - WebSocketContextData *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, + /* Returns whether this socket is subscribed to the specified topic */ + bool isSubscribed(std::string_view topic) { + WebSocketContextData *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, (us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this) ); WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this); + if (!webSocketData->subscriber) { + return false; + } - webSocketContextData->topicTree.unsubscribeAll(webSocketData->subscriber); + Topic *t = webSocketContextData->topicTree.lookupTopic(topic); + if (t) { + return t->subs.find(webSocketData->subscriber) != t->subs.end(); + } + + return false; } - /* Publish a message to a topic according to MQTT rules and syntax */ - void publish(std::string_view topic, std::string_view message, OpCode opCode = OpCode::TEXT, bool compress = false) { - WebSocketContextData *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, + /* Iterates all topics of this WebSocket. Every topic is represented by its full name. + * Can be called in close handler. It is possible to modify the subscription list while + * inside the callback ONLY IF not modifying the topic passed to the callback. + * Topic names are valid only for the duration of the callback. */ + void iterateTopics(MoveOnlyFunction cb) { + WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this); + + if (webSocketData->subscriber) { + for (Topic *t : webSocketData->subscriber->subscriptions) { + /* Lock this topic so that nobody may unsubscribe from it during this callback */ + t->locked = true; + + cb(t->fullName/*, (unsigned int) t->subs.size()*/); + + t->locked = false; + } + } + } + + /* Publish a message to a topic according to MQTT rules and syntax. Returns success. + * We, the WebSocket, must be subscribed to the topic itself and if so - no message will be sent to ourselves. + * Use App::publish for an unconditional publish that simply publishes to whomever might be subscribed. */ + bool publish(std::string_view topic, std::string_view message, OpCode opCode = OpCode::TEXT, bool compress = false) { + WebSocketContextData *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, (us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this) ); - /* Is the same as publishing per websocket context */ - webSocketContextData->publish(topic, message, opCode, compress); + + /* We cannot be a subscriber of this topic if we are not a subscriber of anything */ + WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this); + if (!webSocketData->subscriber) { + /* Failure, but still do return the number of subscribers */ + return false; + } + + /* Publish as sender, does not receive its own messages even if subscribed to relevant topics */ + bool success = webSocketContextData->publish(topic, message, opCode, compress, webSocketData->subscriber); + + /* Loop over all websocket contexts for this App */ + if (success) { + /* Success is really only determined by the first publish. We must be subscribed to the topic. */ + for (auto *adjacentWebSocketContextData : webSocketContextData->adjacentWebSocketContextDatas) { + adjacentWebSocketContextData->publish(topic, message, opCode, compress); + } + } + + return success; } }; diff --git a/src/Vendor/uwebsockets/include/uws/WebSocketContext.h b/src/Vendor/uwebsockets/include/uws/WebSocketContext.h index 080a930..84625a4 100644 --- a/src/Vendor/uwebsockets/include/uws/WebSocketContext.h +++ b/src/Vendor/uwebsockets/include/uws/WebSocketContext.h @@ -1,5 +1,5 @@ /* - * Authored by Alex Hultman, 2018-2019. + * Authored by Alex Hultman, 2018-2020. * Intellectual property of third-party. * Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,7 +25,7 @@ namespace uWS { -template +template struct WebSocketContext { template friend struct TemplatedApp; template friend struct WebSocketProtocol; @@ -36,12 +36,12 @@ private: return (us_socket_context_t *) this; } - WebSocketContextData *getExt() { - return (WebSocketContextData *) us_socket_context_ext(SSL, (us_socket_context_t *) this); + WebSocketContextData *getExt() { + return (WebSocketContextData *) us_socket_context_ext(SSL, (us_socket_context_t *) this); } /* If we have negotiated compression, set this frame compressed */ - static bool setCompressed(uWS::WebSocketState *wState, void *s) { + static bool setCompressed(WebSocketState */*wState*/, void *s) { WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) s); if (webSocketData->compressionStatus == WebSocketData::CompressionStatus::ENABLED) { @@ -52,14 +52,14 @@ private: } } - static void forceClose(uWS::WebSocketState *wState, void *s) { - us_socket_close(SSL, (us_socket_t *) s); + static void forceClose(WebSocketState */*wState*/, void *s, std::string_view reason = {}) { + us_socket_close(SSL, (us_socket_t *) s, (int) reason.length(), (void *) reason.data()); } /* Returns true on breakage */ - static bool handleFragment(char *data, size_t length, unsigned int remainingBytes, int opCode, bool fin, uWS::WebSocketState *webSocketState, void *s) { + static bool handleFragment(char *data, size_t length, unsigned int remainingBytes, int opCode, bool fin, WebSocketState *webSocketState, void *s) { /* WebSocketData and WebSocketContextData */ - WebSocketContextData *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s)); + WebSocketContextData *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s)); WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) s); /* Is this a non-control frame? */ @@ -72,25 +72,25 @@ private: webSocketData->compressionStatus = WebSocketData::CompressionStatus::ENABLED; LoopData *loopData = (LoopData *) us_loop_ext(us_socket_context_loop(SSL, us_socket_context(SSL, (us_socket_t *) s))); - std::string_view inflatedFrame = loopData->inflationStream->inflate(loopData->zlibContext, {data, length}, webSocketContextData->maxPayloadLength); - if (!inflatedFrame.length()) { - forceClose(webSocketState, s); + auto inflatedFrame = loopData->inflationStream->inflate(loopData->zlibContext, {data, length}, webSocketContextData->maxPayloadLength); + if (!inflatedFrame.has_value()) { + forceClose(webSocketState, s, ERR_TOO_BIG_MESSAGE_INFLATION); return true; } else { - data = (char *) inflatedFrame.data(); - length = inflatedFrame.length(); + data = (char *) inflatedFrame->data(); + length = inflatedFrame->length(); } } /* Check text messages for Utf-8 validity */ if (opCode == 1 && !protocol::isValidUtf8((unsigned char *) data, length)) { - forceClose(webSocketState, s); + forceClose(webSocketState, s, ERR_INVALID_TEXT); return true; } /* Emit message event & break if we are closed or shut down when returning */ if (webSocketContextData->messageHandler) { - webSocketContextData->messageHandler((WebSocket *) s, std::string_view(data, length), (uWS::OpCode) opCode); + webSocketContextData->messageHandler((WebSocket *) s, std::string_view(data, length), (OpCode) opCode); if (us_socket_is_closed(SSL, (us_socket_t *) s) || webSocketData->isShuttingDown) { return true; } @@ -102,7 +102,7 @@ private: } /* Fragments forming a big message are not caught until appending them */ if (refusePayloadLength(length + webSocketData->fragmentBuffer.length(), webSocketState, s)) { - forceClose(webSocketState, s); + forceClose(webSocketState, s, ERR_TOO_BIG_MESSAGE); return true; } webSocketData->fragmentBuffer.append(data, length); @@ -115,8 +115,8 @@ private: if (webSocketData->compressionStatus == WebSocketData::CompressionStatus::COMPRESSED_FRAME) { webSocketData->compressionStatus = WebSocketData::CompressionStatus::ENABLED; - // what's really the story here? - webSocketData->fragmentBuffer.append("...."); + /* 9 bytes of padding for libdeflate */ + webSocketData->fragmentBuffer.append("123456789"); LoopData *loopData = (LoopData *) us_loop_ext( us_socket_context_loop(SSL, @@ -124,13 +124,13 @@ private: ) ); - std::string_view inflatedFrame = loopData->inflationStream->inflate(loopData->zlibContext, {webSocketData->fragmentBuffer.data(), webSocketData->fragmentBuffer.length() - 4}, webSocketContextData->maxPayloadLength); - if (!inflatedFrame.length()) { - forceClose(webSocketState, s); + auto inflatedFrame = loopData->inflationStream->inflate(loopData->zlibContext, {webSocketData->fragmentBuffer.data(), webSocketData->fragmentBuffer.length() - 9}, webSocketContextData->maxPayloadLength); + if (!inflatedFrame.has_value()) { + forceClose(webSocketState, s, ERR_TOO_BIG_MESSAGE_INFLATION); return true; } else { - data = (char *) inflatedFrame.data(); - length = inflatedFrame.length(); + data = (char *) inflatedFrame->data(); + length = inflatedFrame->length(); } @@ -142,13 +142,13 @@ private: /* Check text messages for Utf-8 validity */ if (opCode == 1 && !protocol::isValidUtf8((unsigned char *) data, length)) { - forceClose(webSocketState, s); + forceClose(webSocketState, s, ERR_INVALID_TEXT); return true; } /* Emit message and check for shutdown or close */ if (webSocketContextData->messageHandler) { - webSocketContextData->messageHandler((WebSocket *) s, std::string_view(data, length), (uWS::OpCode) opCode); + webSocketContextData->messageHandler((WebSocket *) s, std::string_view(data, length), (OpCode) opCode); if (us_socket_is_closed(SSL, (us_socket_t *) s) || webSocketData->isShuttingDown) { return true; } @@ -160,7 +160,7 @@ private: } } else { /* Control frames need the websocket to send pings, pongs and close */ - WebSocket *webSocket = (WebSocket *) s; + WebSocket *webSocket = (WebSocket *) s; if (!remainingBytes && fin && !webSocketData->controlTipLength) { if (opCode == CLOSE) { @@ -171,14 +171,14 @@ private: if (opCode == PING) { webSocket->send(std::string_view(data, length), (OpCode) OpCode::PONG); if (webSocketContextData->pingHandler) { - webSocketContextData->pingHandler(webSocket); + webSocketContextData->pingHandler(webSocket, {data, length}); if (us_socket_is_closed(SSL, (us_socket_t *) s) || webSocketData->isShuttingDown) { return true; } } } else if (opCode == PONG) { if (webSocketContextData->pongHandler) { - webSocketContextData->pongHandler(webSocket); + webSocketContextData->pongHandler(webSocket, {data, length}); if (us_socket_is_closed(SSL, (us_socket_t *) s) || webSocketData->isShuttingDown) { return true; } @@ -188,7 +188,7 @@ private: } else { /* Here we never mind any size optimizations as we are in the worst possible path */ webSocketData->fragmentBuffer.append(data, length); - webSocketData->controlTipLength += (int) length; + webSocketData->controlTipLength += (unsigned int) length; if (!remainingBytes && fin) { char *controlBuffer = (char *) webSocketData->fragmentBuffer.data() + webSocketData->fragmentBuffer.length() - webSocketData->controlTipLength; @@ -200,14 +200,14 @@ private: if (opCode == PING) { webSocket->send(std::string_view(controlBuffer, webSocketData->controlTipLength), (OpCode) OpCode::PONG); if (webSocketContextData->pingHandler) { - webSocketContextData->pingHandler(webSocket); + webSocketContextData->pingHandler(webSocket, std::string_view(controlBuffer, webSocketData->controlTipLength)); if (us_socket_is_closed(SSL, (us_socket_t *) s) || webSocketData->isShuttingDown) { return true; } } } else if (opCode == PONG) { if (webSocketContextData->pongHandler) { - webSocketContextData->pongHandler(webSocket); + webSocketContextData->pongHandler(webSocket, std::string_view(controlBuffer, webSocketData->controlTipLength)); if (us_socket_is_closed(SSL, (us_socket_t *) s) || webSocketData->isShuttingDown) { return true; } @@ -224,32 +224,32 @@ private: return false; } - static bool refusePayloadLength(uint64_t length, uWS::WebSocketState *wState, void *s) { - auto *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s)); + static bool refusePayloadLength(uint64_t length, WebSocketState */*wState*/, void *s) { + auto *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s)); /* Return true for refuse, false for accept */ return webSocketContextData->maxPayloadLength < length; } - WebSocketContext *init() { + WebSocketContext *init() { /* Adopting a socket does not trigger open event. * We arreive as WebSocket with timeout set and * any backpressure from HTTP state kept. */ /* Handle socket disconnections */ - us_socket_context_on_close(SSL, getSocketContext(), [](auto *s) { + us_socket_context_on_close(SSL, getSocketContext(), [](auto *s, int code, void *reason) { /* For whatever reason, if we already have emitted close event, do not emit it again */ WebSocketData *webSocketData = (WebSocketData *) (us_socket_ext(SSL, s)); if (!webSocketData->isShuttingDown) { /* Emit close event */ - auto *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s)); + auto *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s)); if (webSocketContextData->closeHandler) { - webSocketContextData->closeHandler((WebSocket *) s, 1006, {}); + webSocketContextData->closeHandler((WebSocket *) s, 1006, {(char *) reason, (size_t) code}); } /* Make sure to unsubscribe from any pub/sub node at exit */ - webSocketContextData->topicTree.unsubscribeAll(webSocketData->subscriber); + webSocketContextData->topicTree.unsubscribeAll(webSocketData->subscriber, false); delete webSocketData->subscriber; webSocketData->subscriber = nullptr; } @@ -272,17 +272,18 @@ private: return s; } - auto *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s)); + auto *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s)); auto *asyncSocket = (AsyncSocket *) s; /* Every time we get data and not in shutdown state we simply reset the timeout */ - asyncSocket->timeout(webSocketContextData->idleTimeout); + asyncSocket->timeout(webSocketContextData->idleTimeoutComponents.first); + webSocketData->hasTimedOut = false; /* We always cork on data */ asyncSocket->cork(); /* This parser has virtually no overhead */ - uWS::WebSocketProtocol>::consume(data, length, (WebSocketState *) webSocketData, s); + WebSocketProtocol>::consume(data, (unsigned int) length, (WebSocketState *) webSocketData, s); /* Uncorking a closed socekt is fine, in fact it is needed */ asyncSocket->uncork(); @@ -302,6 +303,10 @@ private: /* Handle HTTP write out (note: SSL_read may trigger this spuriously, the app need to handle spurious calls) */ us_socket_context_on_writable(SSL, getSocketContext(), [](auto *s) { + /* NOTE: Are we called here corked? If so, the below write code is broken, since + * we will have 0 as getBufferedAmount due to writing to cork buffer, then sending TCP FIN before + * we actually uncorked and sent off things */ + /* It makes sense to check for us_is_shut_down here and return if so, to avoid shutting down twice */ if (us_socket_is_shut_down(SSL, (us_socket_t *) s)) { return s; @@ -310,16 +315,19 @@ private: AsyncSocket *asyncSocket = (AsyncSocket *) s; WebSocketData *webSocketData = (WebSocketData *)(us_socket_ext(SSL, s)); - /* We store old backpressure since it is unclear whether write drained anything */ - int backpressure = asyncSocket->getBufferedAmount(); + /* We store old backpressure since it is unclear whether write drained anything, + * however, in case of coming here with 0 backpressure we still need to emit drain event */ + unsigned int backpressure = asyncSocket->getBufferedAmount(); /* Drain as much as possible */ asyncSocket->write(nullptr, 0); /* Behavior: if we actively drain backpressure, always reset timeout (even if we are in shutdown) */ - if (backpressure < asyncSocket->getBufferedAmount()) { - auto *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s)); - asyncSocket->timeout(webSocketContextData->idleTimeout); + /* Also reset timeout if we came here with 0 backpressure */ + if (!backpressure || backpressure > asyncSocket->getBufferedAmount()) { + auto *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s)); + asyncSocket->timeout(webSocketContextData->idleTimeoutComponents.first); + webSocketData->hasTimedOut = false; } /* Are we in (WebSocket) shutdown mode? */ @@ -329,11 +337,11 @@ private: /* Now perform the actual TCP/TLS shutdown which was postponed due to backpressure */ asyncSocket->shutdown(); } - } else if (backpressure > asyncSocket->getBufferedAmount()) { - /* Only call drain if we actually drained backpressure */ - auto *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s)); + } else if (!backpressure || backpressure > asyncSocket->getBufferedAmount()) { + /* Only call drain if we actually drained backpressure or if we came here with 0 backpressure */ + auto *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s)); if (webSocketContextData->drainHandler) { - webSocketContextData->drainHandler((WebSocket *) s); + webSocketContextData->drainHandler((WebSocket *) s); } /* No need to check for closed here as we leave the handler immediately*/ } @@ -345,7 +353,7 @@ private: us_socket_context_on_end(SSL, getSocketContext(), [](auto *s) { /* If we get a fin, we just close I guess */ - us_socket_close(SSL, (us_socket_t *) s); + us_socket_close(SSL, (us_socket_t *) s, 0, nullptr); return s; }); @@ -353,8 +361,20 @@ private: /* Handle socket timeouts, simply close them so to not confuse client with FIN */ us_socket_context_on_timeout(SSL, getSocketContext(), [](auto *s) { + auto *webSocketData = (WebSocketData *)(us_socket_ext(SSL, s)); + auto *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s)); + + if (webSocketContextData->sendPingsAutomatically && !webSocketData->hasTimedOut) { + webSocketData->hasTimedOut = true; + us_socket_timeout(SSL, s, webSocketContextData->idleTimeoutComponents.second); + /* Send ping without being corked */ + ((AsyncSocket *) s)->write("\x89\x00", 2); + return s; + } + /* Timeout is very simple; we just close it */ - us_socket_close(SSL, (us_socket_t *) s); + /* Warning: we happen to know forceClose will not use first parameter so pass nullptr here */ + forceClose(nullptr, s, ERR_WEBSOCKET_TIMEOUT); return s; }); @@ -363,7 +383,7 @@ private: } void free() { - WebSocketContextData *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, (us_socket_context_t *) this); + WebSocketContextData *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, (us_socket_context_t *) this); webSocketContextData->~WebSocketContextData(); us_socket_context_free(SSL, (us_socket_context_t *) this); @@ -371,14 +391,14 @@ private: public: /* WebSocket contexts are always child contexts to a HTTP context so no SSL options are needed as they are inherited */ - static WebSocketContext *create(Loop *loop, us_socket_context_t *parentSocketContext) { - WebSocketContext *webSocketContext = (WebSocketContext *) us_create_child_socket_context(SSL, parentSocketContext, sizeof(WebSocketContextData)); + static WebSocketContext *create(Loop */*loop*/, us_socket_context_t *parentSocketContext) { + WebSocketContext *webSocketContext = (WebSocketContext *) us_create_child_socket_context(SSL, parentSocketContext, sizeof(WebSocketContextData)); if (!webSocketContext) { return nullptr; } /* Init socket context data */ - new ((WebSocketContextData *) us_socket_context_ext(SSL, (us_socket_context_t *)webSocketContext)) WebSocketContextData; + new ((WebSocketContextData *) us_socket_context_ext(SSL, (us_socket_context_t *)webSocketContext)) WebSocketContextData; return webSocketContext->init(); } }; diff --git a/src/Vendor/uwebsockets/include/uws/WebSocketContextData.h b/src/Vendor/uwebsockets/include/uws/WebSocketContextData.h index 1f0d1ac..4cac0da 100644 --- a/src/Vendor/uwebsockets/include/uws/WebSocketContextData.h +++ b/src/Vendor/uwebsockets/include/uws/WebSocketContextData.h @@ -1,5 +1,5 @@ /* - * Authored by Alex Hultman, 2018-2019. + * Authored by Alex Hultman, 2018-2020. * Intellectual property of third-party. * Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,83 +18,279 @@ #ifndef UWS_WEBSOCKETCONTEXTDATA_H #define UWS_WEBSOCKETCONTEXTDATA_H -#include "f2/function2.hpp" +#include "MoveOnlyFunction.h" #include +#include #include "WebSocketProtocol.h" #include "TopicTree.h" +#include "WebSocketData.h" namespace uWS { -template struct WebSocket; +template struct WebSocket; /* todo: this looks identical to WebSocketBehavior, why not just std::move that entire thing in? */ -template +template struct WebSocketContextData { +private: + /* Used for prepending unframed messages when using dedicated compressors */ + struct MessageMetadata { + unsigned int length; + OpCode opCode; + bool compress; + /* Undefined init of all members */ + MessageMetadata() {} + MessageMetadata(unsigned int length, OpCode opCode, bool compress) + : length(length), opCode(opCode), compress(compress) {} + }; + +public: + /* All WebSocketContextData holds a list to all other WebSocketContextData in this app. + * We cannot type it USERDATA since different WebSocketContextData can have different USERDATA. */ + std::vector *> adjacentWebSocketContextDatas; + /* The callbacks for this context */ - fu2::unique_function *, std::string_view, uWS::OpCode)> messageHandler = nullptr; - fu2::unique_function *)> drainHandler = nullptr; - fu2::unique_function *, int, std::string_view)> closeHandler = nullptr; + MoveOnlyFunction *)> openHandler = nullptr; + MoveOnlyFunction *, std::string_view, OpCode)> messageHandler = nullptr; + MoveOnlyFunction *)> drainHandler = nullptr; + MoveOnlyFunction *, int, std::string_view)> closeHandler = nullptr; /* Todo: these should take message also; breaking change for v0.18 */ - fu2::unique_function *)> pingHandler = nullptr; - fu2::unique_function *)> pongHandler = nullptr; + MoveOnlyFunction *, std::string_view)> pingHandler = nullptr; + MoveOnlyFunction *, std::string_view)> pongHandler = nullptr; /* Settings for this context */ size_t maxPayloadLength = 0; - int idleTimeout = 0; + + /* We do need these for async upgrade */ + CompressOptions compression; /* There needs to be a maxBackpressure which will force close everything over that limit */ size_t maxBackpressure = 0; + bool closeOnBackpressureLimit; + bool resetIdleTimeoutOnSend; + bool sendPingsAutomatically; + + /* These are calculated on creation */ + std::pair idleTimeoutComponents; /* Each websocket context has a topic tree for pub/sub */ TopicTree topicTree; + /* This is run once on start-up */ + void calculateIdleTimeoutCompnents(unsigned short idleTimeout) { + unsigned short margin = 4; + /* 4, 8 or 16 seconds margin based on idleTimeout */ + while ((int) idleTimeout - margin * 2 >= margin * 2 && margin < 16) { + margin = (unsigned short) (margin << 1); + } + /* We should have no margin if not using sendPingsAutomatically */ + if (!sendPingsAutomatically) { + margin = 0; + } + idleTimeoutComponents = { + idleTimeout - margin, + margin + }; + } + ~WebSocketContextData() { /* We must unregister any loop post handler here */ Loop::get()->removePostHandler(this); Loop::get()->removePreHandler(this); } - WebSocketContextData() : topicTree([this](Subscriber *s, std::string_view data) -> int { + WebSocketContextData() : topicTree([this](Subscriber *s, Intersection &intersection) -> int { + + /* We could potentially be called here even if we have nothing to send, since we can + * be the sender of every single message in this intersection. Also "fin" of a segment is not + * guaranteed to be set, in case remaining segments are all from us. + * Essentially, we cannot make strict assumptions here. Also, we can even come here corked, + * since publish can call drain! */ + /* We rely on writing to regular asyncSockets */ auto *asyncSocket = (AsyncSocket *) s->user; - auto [written, failed] = asyncSocket->write(data.data(), (int) data.length()); - if (!failed) { - asyncSocket->timeout(this->idleTimeout); - } else { - /* Note: this assumes we are not corked, as corking will swallow things and fail later on */ + /* If we are corked, do not uncork - otherwise if we cork in here, uncork before leaving */ + bool wasCorked = asyncSocket->isCorked(); - /* Check if we now have too much backpressure (todo: don't buffer up before check) */ - if ((unsigned int) asyncSocket->getBufferedAmount() > maxBackpressure) { - asyncSocket->close(); + /* Do we even have room for potential data? */ + if (!maxBackpressure || asyncSocket->getBufferedAmount() < maxBackpressure) { + + /* Roll over all our segments */ + intersection.forSubscriber(topicTree.getSenderFor(s), [asyncSocket, this](std::pair data, bool fin) { + + /* We have a segment that is not marked as last ("fin"). + * Cork if not already so (purely for performance reasons). Does not touch "wasCorked". */ + if (!fin && !asyncSocket->isCorked() && asyncSocket->canCork()) { + asyncSocket->cork(); + } + + /* Pick uncompressed data track */ + std::string_view selectedData = data.first; + + /* Are we using compression? Fine, pick the compressed data track */ + WebSocketData *webSocketData = (WebSocketData *) asyncSocket->getAsyncSocketData(); + if (webSocketData->compressionStatus != WebSocketData::CompressionStatus::DISABLED) { + + /* This is used for both shared and dedicated paths */ + selectedData = data.second; + + /* However, dedicated compression has its own path */ + if (compression != SHARED_COMPRESSOR) { + + WebSocket *ws = (WebSocket *) asyncSocket; + + /* For performance reasons we always cork when in dedicated mode. + * Is this really the best? We already kind of cork things in Zlib? + * Right, formatting needs a cork buffer, right. Never mind. */ + if (!ws->isCorked() && ws->canCork()) { + asyncSocket->cork(); + } + + while (selectedData.length()) { + /* Interpret the data like so, because this is how we shoved it in */ + MessageMetadata mm; + memcpy((char *) &mm, selectedData.data(), sizeof(MessageMetadata)); + std::string_view unframedMessage(selectedData.data() + sizeof(MessageMetadata), mm.length); + + /* Skip this message if our backpressure is too high */ + if (maxBackpressure && ws->getBufferedAmount() > maxBackpressure) { + break; + } + + /* Here we perform the actual compression and framing */ + ws->send(unframedMessage, mm.opCode, mm.compress); + + /* Advance until empty */ + selectedData.remove_prefix(sizeof(MessageMetadata) + mm.length); + } + + /* Continue to next segment without executing below path */ + return; + } + } + + /* Common path for SHARED and DISABLED. It is an invalid assumption that we always are + * uncorked here, however the following (invalid) assumption is not critically wrong either way */ + + /* Note: this assumes we are not corked, as corking will swallow things and fail later on */ + auto [written, failed] = asyncSocket->write(selectedData.data(), (int) selectedData.length()); + /* If we want strict check for success, we can ignore this check if corked and repeat below + * when uncorking - however this is too strict as we really care about PROGRESS rather than + * ENTIRE SUCCESS - we need minor API changes to support correct checks */ + if (!failed) { + if (this->resetIdleTimeoutOnSend) { + auto *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) asyncSocket); + webSocketData->hasTimedOut = false; + asyncSocket->timeout(this->idleTimeoutComponents.first); + } + } + }); + } + + /* We are done sending, for whatever reasons we ended up corked while not starting with "wasCorked", + * we here need to uncork to restore the state we were called in */ + if (!wasCorked && asyncSocket->isCorked()) { + /* Regarding timeout for writes; */ + auto [written, failed] = asyncSocket->uncork(); + /* Again, this check should be more like DID WE PROGRESS rather than DID WE SUCCEED ENTIRELY */ + if (!failed) { + if (this->resetIdleTimeoutOnSend) { + auto *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) asyncSocket); + webSocketData->hasTimedOut = false; + asyncSocket->timeout(this->idleTimeoutComponents.first); + } } } + /* Defer a close if we now have (or already had) too much backpressure, or simply skip */ + if (maxBackpressure && closeOnBackpressureLimit && asyncSocket->getBufferedAmount() > maxBackpressure) { + /* We must not immediately close the socket, as that could result in stack overflow, + * iterator invalidation and other TopicTree::drain bugs. We may shutdown the reading side of the socket, + * causing next iteration to error-close the socket from that context instead, if we want to */ + us_socket_shutdown_read(SSL, (us_socket_t *) asyncSocket); + } + /* Reserved, unused */ return 0; }) { /* We empty for both pre and post just to make sure */ - Loop::get()->addPostHandler(this, [this](Loop *loop) { + Loop::get()->addPostHandler(this, [this](Loop */*loop*/) { /* Commit pub/sub batches every loop iteration */ topicTree.drain(); }); - Loop::get()->addPreHandler(this, [this](Loop *loop) { + Loop::get()->addPreHandler(this, [this](Loop */*loop*/) { /* Commit pub/sub batches every loop iteration */ topicTree.drain(); }); } /* Helper for topictree publish, common path from app and ws */ - void publish(std::string_view topic, std::string_view message, OpCode opCode, bool compress) { + bool publish(std::string_view topic, std::string_view message, OpCode opCode, bool compress, Subscriber *sender = nullptr) { + bool didMatch = false; + /* We frame the message right here and only pass raw bytes to the pub/subber */ char *dst = (char *) malloc(protocol::messageFrameSize(message.size())); size_t dst_length = protocol::formatMessage(dst, message.data(), message.length(), opCode, message.length(), false); - topicTree.publish(topic, std::string_view(dst, dst_length)); + /* If compression is disabled */ + if (compression == DISABLED) { + /* Leave second field empty as nobody will ever read it */ + didMatch |= topicTree.publish(topic, {std::string_view(dst, dst_length), {}}, sender); + } else { + /* DEDICATED_COMPRESSOR always takes the same path as must always have MessageMetadata as head */ + if (compress || compression != SHARED_COMPRESSOR) { + /* Shared compression mode publishes compressed, framed data */ + if (compression == SHARED_COMPRESSOR) { + /* Loop data holds shared compressor */ + LoopData *loopData = (LoopData *) us_loop_ext((us_loop_t *) Loop::get()); + + /* Compress it */ + std::string_view compressedMessage = loopData->deflationStream->deflate(loopData->zlibContext, message, true); + + /* Frame it */ + char *dst_compressed = (char *) malloc(protocol::messageFrameSize(compressedMessage.size())); + size_t dst_compressed_length = protocol::formatMessage(dst_compressed, compressedMessage.data(), compressedMessage.length(), opCode, compressedMessage.length(), true); + + /* Always publish the shortest one in any case */ + didMatch |= topicTree.publish(topic, {std::string_view(dst, dst_length), dst_compressed_length >= dst_length ? std::string_view(dst, dst_length) : std::string_view(dst_compressed, dst_compressed_length)}, sender); + + /* We don't care for allocation here */ + ::free(dst_compressed); + } else { + /* Dedicated compression mode publishes metadata + unframed uncompressed data */ + char *dst_compressed = (char *) malloc(message.length() + sizeof(MessageMetadata)); + + MessageMetadata mm( + (unsigned int) message.length(), + opCode, + compress + ); + + memcpy(dst_compressed, (char *) &mm, sizeof(MessageMetadata)); + memcpy(dst_compressed + sizeof(MessageMetadata), message.data(), message.length()); + + /* Interpretation of compressed data depends on what compressor we use */ + didMatch |= topicTree.publish(topic, { + std::string_view(dst, dst_length), + std::string_view(dst_compressed, message.length() + sizeof(MessageMetadata)) + }, sender); + + ::free(dst_compressed); + } + } else { + /* If not compressing, put same message on both tracks (only valid for SHARED_COMPRESSOR). + * DEDICATED_COMPRESSOR_xKB must never end up here as we don't put a proper head here. */ + didMatch |= topicTree.publish(topic, {std::string_view(dst, dst_length), std::string_view(dst, dst_length)}, sender); + } + } + ::free(dst); + + return didMatch; } }; diff --git a/src/Vendor/uwebsockets/include/uws/WebSocketData.h b/src/Vendor/uwebsockets/include/uws/WebSocketData.h index 0f2e59d..d0cea6f 100644 --- a/src/Vendor/uwebsockets/include/uws/WebSocketData.h +++ b/src/Vendor/uwebsockets/include/uws/WebSocketData.h @@ -1,5 +1,5 @@ /* - * Authored by Alex Hultman, 2018-2019. + * Authored by Alex Hultman, 2018-2020. * Intellectual property of third-party. * Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,18 +21,23 @@ #include "WebSocketProtocol.h" #include "AsyncSocketData.h" #include "PerMessageDeflate.h" +#include "TopicTree.h" #include namespace uWS { struct WebSocketData : AsyncSocketData, WebSocketState { - template friend struct WebSocketContext; - template friend struct WebSocket; + /* This guy has a lot of friends - why? */ + template friend struct WebSocketContext; + template friend struct WebSocketContextData; + template friend struct WebSocket; + template friend struct HttpContext; private: std::string fragmentBuffer; - int controlTipLength = 0; + unsigned int controlTipLength = 0; bool isShuttingDown = 0; + bool hasTimedOut = false; enum CompressionStatus : char { DISABLED, ENABLED, @@ -45,12 +50,12 @@ private: /* We could be a subscriber */ Subscriber *subscriber = nullptr; public: - WebSocketData(bool perMessageDeflate, bool slidingCompression, std::string &&backpressure) : AsyncSocketData(std::move(backpressure)), WebSocketState() { + WebSocketData(bool perMessageDeflate, CompressOptions compressOptions, std::string &&backpressure) : AsyncSocketData(std::move(backpressure)), WebSocketState() { compressionStatus = perMessageDeflate ? ENABLED : DISABLED; /* Initialize the dedicated sliding window */ - if (perMessageDeflate && slidingCompression) { - deflationStream = new DeflationStream; + if (perMessageDeflate && (compressOptions != CompressOptions::SHARED_COMPRESSOR)) { + deflationStream = new DeflationStream(compressOptions); } } diff --git a/src/Vendor/uwebsockets/include/uws/WebSocketExtensions.h b/src/Vendor/uwebsockets/include/uws/WebSocketExtensions.h index 9368ec9..93fd5df 100644 --- a/src/Vendor/uwebsockets/include/uws/WebSocketExtensions.h +++ b/src/Vendor/uwebsockets/include/uws/WebSocketExtensions.h @@ -1,5 +1,5 @@ /* - * Authored by Alex Hultman, 2018-2019. + * Authored by Alex Hultman, 2018-2021. * Intellectual property of third-party. * Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,26 +18,35 @@ #ifndef UWS_WEBSOCKETEXTENSIONS_H #define UWS_WEBSOCKETEXTENSIONS_H +/* There is a new, huge bug scenario that needs to be fixed: + * pub/sub does not support being in DEDICATED_COMPRESSOR-mode while having + * some clients downgraded to SHARED_COMPRESSOR - we cannot allow the client to + * demand a downgrade to SHARED_COMPRESSOR (yet) until we fix that scenario in pub/sub */ +// #define UWS_ALLOW_SHARED_AND_DEDICATED_COMPRESSOR_MIX + +/* We forbid negotiating 8 windowBits since Zlib has a bug with this */ +// #define UWS_ALLOW_8_WINDOW_BITS + #include +#include +#include #include +#include namespace uWS { -enum Options : unsigned int { - NO_OPTIONS = 0, - PERMESSAGE_DEFLATE = 1, - SERVER_NO_CONTEXT_TAKEOVER = 2, // remove this - CLIENT_NO_CONTEXT_TAKEOVER = 4, // remove this - NO_DELAY = 8, - SLIDING_DEFLATE_WINDOW = 16 -}; - enum ExtensionTokens { + /* Standard permessage-deflate tokens */ TOK_PERMESSAGE_DEFLATE = 1838, TOK_SERVER_NO_CONTEXT_TAKEOVER = 2807, TOK_CLIENT_NO_CONTEXT_TAKEOVER = 2783, TOK_SERVER_MAX_WINDOW_BITS = 2372, - TOK_CLIENT_MAX_WINDOW_BITS = 2348 + TOK_CLIENT_MAX_WINDOW_BITS = 2348, + /* Non-standard alias for Safari */ + TOK_X_WEBKIT_DEFLATE_FRAME = 2149, + TOK_NO_CONTEXT_TAKEOVER = 2049, + TOK_MAX_WINDOW_BITS = 1614 + }; struct ExtensionsParser { @@ -45,12 +54,18 @@ private: int *lastInteger = nullptr; public: + /* Standard */ bool perMessageDeflate = false; bool serverNoContextTakeover = false; bool clientNoContextTakeover = false; int serverMaxWindowBits = 0; int clientMaxWindowBits = 0; + /* Non-standard Safari */ + bool xWebKitDeflateFrame = false; + bool noContextTakeover = false; + int maxWindowBits = 0; + int getToken(const char *&in, const char *stop) { while (in != stop && !isalnum(*in)) { in++; @@ -78,12 +93,28 @@ public: ExtensionsParser(const char *data, size_t length) { const char *stop = data + length; int token = 1; - for (; token && token != TOK_PERMESSAGE_DEFLATE; token = getToken(data, stop)); + /* Ignore anything before permessage-deflate or x-webkit-deflate-frame */ + for (; token && token != TOK_PERMESSAGE_DEFLATE && token != TOK_X_WEBKIT_DEFLATE_FRAME; token = getToken(data, stop)); + + /* What protocol are we going to use? */ perMessageDeflate = (token == TOK_PERMESSAGE_DEFLATE); + xWebKitDeflateFrame = (token == TOK_X_WEBKIT_DEFLATE_FRAME); + while ((token = getToken(data, stop))) { switch (token) { + case TOK_X_WEBKIT_DEFLATE_FRAME: + /* Duplicates not allowed/supported */ + return; + case TOK_NO_CONTEXT_TAKEOVER: + noContextTakeover = true; + break; + case TOK_MAX_WINDOW_BITS: + maxWindowBits = 1; + lastInteger = &maxWindowBits; + break; case TOK_PERMESSAGE_DEFLATE: + /* Duplicates not allowed/supported */ return; case TOK_SERVER_NO_CONTEXT_TAKEOVER: serverNoContextTakeover = true; @@ -109,60 +140,116 @@ public: } }; -template -struct ExtensionsNegotiator { -protected: - int options; +/* Takes what we (the server) wants, returns what we got */ +static inline std::tuple negotiateCompression(bool wantCompression, int wantedCompressionWindow, int wantedInflationWindow, std::string_view offer) { -public: - ExtensionsNegotiator(int wantedOptions) { - options = wantedOptions; + /* If we don't want compression then we are done here */ + if (!wantCompression) { + return {false, 0, 0, ""}; } - std::string generateOffer() { - std::string extensionsOffer; - if (options & Options::PERMESSAGE_DEFLATE) { - extensionsOffer += "permessage-deflate"; + ExtensionsParser ep(offer.data(), offer.length()); - if (options & Options::CLIENT_NO_CONTEXT_TAKEOVER) { - extensionsOffer += "; client_no_context_takeover"; + static thread_local std::string response; + response = ""; + + int compressionWindow = wantedCompressionWindow; + int inflationWindow = wantedInflationWindow; + bool compression = false; + + if (ep.xWebKitDeflateFrame) { + /* We now have compression */ + compression = true; + response = "x-webkit-deflate-frame"; + + /* If the other peer has DEMANDED us no sliding window, + * we cannot compress with anything other than shared compressor */ + if (ep.noContextTakeover) { + /* We must fail here right now (fix pub/sub) */ +#ifndef UWS_ALLOW_SHARED_AND_DEDICATED_COMPRESSOR_MIX + if (wantedCompressionWindow != 0) { + return {false, 0, 0, ""}; } +#endif - /* It is questionable sending this improves anything */ - /*if (options & Options::SERVER_NO_CONTEXT_TAKEOVER) { - extensionsOffer += "; server_no_context_takeover"; - }*/ + compressionWindow = 0; } - return extensionsOffer; - } + /* If the other peer has DEMANDED us to use a limited sliding window, + * we have to limit out compression sliding window */ + if (ep.maxWindowBits && ep.maxWindowBits < compressionWindow) { + compressionWindow = ep.maxWindowBits; +#ifndef UWS_ALLOW_8_WINDOW_BITS + /* We cannot really deny this, so we have to disable compression in this case */ + if (compressionWindow == 8) { + return {false, 0, 0, ""}; + } +#endif + } - void readOffer(std::string_view offer) { - if (isServer) { - ExtensionsParser extensionsParser(offer.data(), offer.length()); - if ((options & PERMESSAGE_DEFLATE) && extensionsParser.perMessageDeflate) { - if (extensionsParser.clientNoContextTakeover || (options & CLIENT_NO_CONTEXT_TAKEOVER)) { - options |= CLIENT_NO_CONTEXT_TAKEOVER; - } - - /* We leave this option for us to read even if the client did not send it */ - if (extensionsParser.serverNoContextTakeover) { - options |= SERVER_NO_CONTEXT_TAKEOVER; - }/* else { - options &= ~SERVER_NO_CONTEXT_TAKEOVER; - }*/ + /* We decide our own inflation sliding window (and their compression sliding window) */ + if (wantedInflationWindow < 15) { + if (!wantedInflationWindow) { + response += "; no_context_takeover"; } else { - options &= ~PERMESSAGE_DEFLATE; + response += "; max_window_bits=" + std::to_string(wantedInflationWindow); + } + } + } else if (ep.perMessageDeflate) { + /* We now have compression */ + compression = true; + response = "permessage-deflate"; + + if (ep.clientNoContextTakeover) { + inflationWindow = 0; + } else if (ep.clientMaxWindowBits && ep.clientMaxWindowBits != 1) { + inflationWindow = std::min(ep.clientMaxWindowBits, inflationWindow); + } + + /* Whatever we have now, write */ + if (inflationWindow < 15) { + if (!inflationWindow || !ep.clientMaxWindowBits) { + response += "; client_no_context_takeover"; + inflationWindow = 0; + } else { + response += "; client_max_window_bits=" + std::to_string(inflationWindow); + } + } + + /* This block basically lets the client lower it */ + if (ep.serverNoContextTakeover) { + /* This is an important (temporary) fix since we haven't allowed + * these two modes to mix, and pub/sub will not handle this case (yet) */ +#ifdef UWS_ALLOW_SHARED_AND_DEDICATED_COMPRESSOR_MIX + compressionWindow = 0; +#endif + } else if (ep.serverMaxWindowBits) { + compressionWindow = std::min(ep.serverMaxWindowBits, compressionWindow); +#ifndef UWS_ALLOW_8_WINDOW_BITS + /* Zlib cannot do windowBits=8, memLevel=1 so we raise it up to 9 minimum */ + if (compressionWindow == 8) { + compressionWindow = 9; + } +#endif + } + + /* Whatever we have now, write */ + if (compressionWindow < 15) { + if (!compressionWindow) { + response += "; server_no_context_takeover"; + } else { + response += "; server_max_window_bits=" + std::to_string(compressionWindow); } - } else { - // todo! } } - int getNegotiatedOptions() { - return options; + /* A final sanity check (this check does not actually catch too high values!) */ + if ((compressionWindow && compressionWindow < 8) || compressionWindow > 15 || (inflationWindow && inflationWindow < 8) || inflationWindow > 15) { + return {false, 0, 0, ""}; } -}; + + return {compression, compressionWindow, inflationWindow, response}; +} } diff --git a/src/Vendor/uwebsockets/include/uws/WebSocketHandshake.h b/src/Vendor/uwebsockets/include/uws/WebSocketHandshake.h index a539c75..808415e 100644 --- a/src/Vendor/uwebsockets/include/uws/WebSocketHandshake.h +++ b/src/Vendor/uwebsockets/include/uws/WebSocketHandshake.h @@ -1,5 +1,5 @@ /* - * Authored by Alex Hultman, 2018-2019. + * Authored by Alex Hultman, 2018-2020. * Intellectual property of third-party. * Licensed under the Apache License, Version 2.0 (the "License"); @@ -34,57 +34,68 @@ struct WebSocketHandshake { template struct static_for<0, T> { - void operator()(uint32_t *a, uint32_t *hash) {} + void operator()(uint32_t */*a*/, uint32_t */*hash*/) {} }; - template - struct Sha1Loop { - static inline uint32_t rol(uint32_t value, size_t bits) {return (value << bits) | (value >> (32 - bits));} - static inline uint32_t blk(uint32_t b[16], size_t i) { - return rol(b[(i + 13) & 15] ^ b[(i + 8) & 15] ^ b[(i + 2) & 15] ^ b[i], 1); - } + static inline uint32_t rol(uint32_t value, size_t bits) {return (value << bits) | (value >> (32 - bits));} + static inline uint32_t blk(uint32_t b[16], size_t i) { + return rol(b[(i + 13) & 15] ^ b[(i + 8) & 15] ^ b[(i + 2) & 15] ^ b[i], 1); + } + struct Sha1Loop1 { template static inline void f(uint32_t *a, uint32_t *b) { - switch (state) { - case 1: - a[i % 5] += ((a[(3 + i) % 5] & (a[(2 + i) % 5] ^ a[(1 + i) % 5])) ^ a[(1 + i) % 5]) + b[i] + 0x5a827999 + rol(a[(4 + i) % 5], 5); - a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30); - break; - case 2: - b[i] = blk(b, i); - a[(1 + i) % 5] += ((a[(4 + i) % 5] & (a[(3 + i) % 5] ^ a[(2 + i) % 5])) ^ a[(2 + i) % 5]) + b[i] + 0x5a827999 + rol(a[(5 + i) % 5], 5); - a[(4 + i) % 5] = rol(a[(4 + i) % 5], 30); - break; - case 3: - b[(i + 4) % 16] = blk(b, (i + 4) % 16); - a[i % 5] += (a[(3 + i) % 5] ^ a[(2 + i) % 5] ^ a[(1 + i) % 5]) + b[(i + 4) % 16] + 0x6ed9eba1 + rol(a[(4 + i) % 5], 5); - a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30); - break; - case 4: - b[(i + 8) % 16] = blk(b, (i + 8) % 16); - a[i % 5] += (((a[(3 + i) % 5] | a[(2 + i) % 5]) & a[(1 + i) % 5]) | (a[(3 + i) % 5] & a[(2 + i) % 5])) + b[(i + 8) % 16] + 0x8f1bbcdc + rol(a[(4 + i) % 5], 5); - a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30); - break; - case 5: - b[(i + 12) % 16] = blk(b, (i + 12) % 16); - a[i % 5] += (a[(3 + i) % 5] ^ a[(2 + i) % 5] ^ a[(1 + i) % 5]) + b[(i + 12) % 16] + 0xca62c1d6 + rol(a[(4 + i) % 5], 5); - a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30); - break; - case 6: - b[i] += a[4 - i]; - } + a[i % 5] += ((a[(3 + i) % 5] & (a[(2 + i) % 5] ^ a[(1 + i) % 5])) ^ a[(1 + i) % 5]) + b[i] + 0x5a827999 + rol(a[(4 + i) % 5], 5); + a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30); + } + }; + struct Sha1Loop2 { + template + static inline void f(uint32_t *a, uint32_t *b) { + b[i] = blk(b, i); + a[(1 + i) % 5] += ((a[(4 + i) % 5] & (a[(3 + i) % 5] ^ a[(2 + i) % 5])) ^ a[(2 + i) % 5]) + b[i] + 0x5a827999 + rol(a[(5 + i) % 5], 5); + a[(4 + i) % 5] = rol(a[(4 + i) % 5], 30); + } + }; + struct Sha1Loop3 { + template + static inline void f(uint32_t *a, uint32_t *b) { + b[(i + 4) % 16] = blk(b, (i + 4) % 16); + a[i % 5] += (a[(3 + i) % 5] ^ a[(2 + i) % 5] ^ a[(1 + i) % 5]) + b[(i + 4) % 16] + 0x6ed9eba1 + rol(a[(4 + i) % 5], 5); + a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30); + } + }; + struct Sha1Loop4 { + template + static inline void f(uint32_t *a, uint32_t *b) { + b[(i + 8) % 16] = blk(b, (i + 8) % 16); + a[i % 5] += (((a[(3 + i) % 5] | a[(2 + i) % 5]) & a[(1 + i) % 5]) | (a[(3 + i) % 5] & a[(2 + i) % 5])) + b[(i + 8) % 16] + 0x8f1bbcdc + rol(a[(4 + i) % 5], 5); + a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30); + } + }; + struct Sha1Loop5 { + template + static inline void f(uint32_t *a, uint32_t *b) { + b[(i + 12) % 16] = blk(b, (i + 12) % 16); + a[i % 5] += (a[(3 + i) % 5] ^ a[(2 + i) % 5] ^ a[(1 + i) % 5]) + b[(i + 12) % 16] + 0xca62c1d6 + rol(a[(4 + i) % 5], 5); + a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30); + } + }; + struct Sha1Loop6 { + template + static inline void f(uint32_t *a, uint32_t *b) { + b[i] += a[4 - i]; } }; static inline void sha1(uint32_t hash[5], uint32_t b[16]) { uint32_t a[5] = {hash[4], hash[3], hash[2], hash[1], hash[0]}; - static_for<16, Sha1Loop<1>>()(a, b); - static_for<4, Sha1Loop<2>>()(a, b); - static_for<20, Sha1Loop<3>>()(a, b); - static_for<20, Sha1Loop<4>>()(a, b); - static_for<20, Sha1Loop<5>>()(a, b); - static_for<5, Sha1Loop<6>>()(a, hash); + static_for<16, Sha1Loop1>()(a, b); + static_for<4, Sha1Loop2>()(a, b); + static_for<20, Sha1Loop3>()(a, b); + static_for<20, Sha1Loop4>()(a, b); + static_for<20, Sha1Loop5>()(a, b); + static_for<5, Sha1Loop6>()(a, hash); } static inline void base64(unsigned char *src, char *dst) { @@ -112,7 +123,7 @@ public: }; for (int i = 0; i < 6; i++) { - b_input[i] = (input[4 * i + 3] & 0xff) | (input[4 * i + 2] & 0xff) << 8 | (input[4 * i + 1] & 0xff) << 16 | (input[4 * i + 0] & 0xff) << 24; + b_input[i] = (uint32_t) ((input[4 * i + 3] & 0xff) | (input[4 * i + 2] & 0xff) << 8 | (input[4 * i + 1] & 0xff) << 16 | (input[4 * i + 0] & 0xff) << 24); } sha1(b_output, b_input); uint32_t last_b[16] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 480}; diff --git a/src/Vendor/uwebsockets/include/uws/WebSocketProtocol.h b/src/Vendor/uwebsockets/include/uws/WebSocketProtocol.h index cb9106a..db282e4 100644 --- a/src/Vendor/uwebsockets/include/uws/WebSocketProtocol.h +++ b/src/Vendor/uwebsockets/include/uws/WebSocketProtocol.h @@ -1,5 +1,5 @@ /* - * Authored by Alex Hultman, 2018-2019. + * Authored by Alex Hultman, 2018-2020. * Intellectual property of third-party. * Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,9 +21,17 @@ #include #include #include +#include namespace uWS { +/* We should not overcomplicate these */ +const std::string_view ERR_TOO_BIG_MESSAGE("Received too big message"); +const std::string_view ERR_WEBSOCKET_TIMEOUT("WebSocket timed out from inactivity"); +const std::string_view ERR_INVALID_TEXT("Received invalid UTF-8"); +const std::string_view ERR_TOO_BIG_MESSAGE_INFLATION("Received too big message, or other inflation error"); +const std::string_view ERR_INVALID_CLOSE_PAYLOAD("Received invalid close payload"); + enum OpCode : unsigned char { TEXT = 1, BINARY = 2, @@ -49,7 +57,7 @@ public: struct State { unsigned int wantsHead : 1; unsigned int spillLength : 4; - int opStack : 2; // -1, 0, 1 + signed int opStack : 2; // -1, 0, 1 unsigned int lastFin : 1; // 15 bytes @@ -151,20 +159,23 @@ struct CloseFrame { }; static inline CloseFrame parseClosePayload(char *src, size_t length) { - CloseFrame cf = {}; + /* If we get no code or message, default to reporting 1005 no status code present */ + CloseFrame cf = {1005, nullptr, 0}; if (length >= 2) { memcpy(&cf.code, src, 2); cf = {cond_byte_swap(cf.code), src + 2, length - 2}; if (cf.code < 1000 || cf.code > 4999 || (cf.code > 1011 && cf.code < 4000) || (cf.code >= 1004 && cf.code <= 1006) || !isValidUtf8((unsigned char *) cf.message, cf.length)) { - return {}; + /* Even though we got a WebSocket close frame, it in itself is abnormal */ + return {1006, nullptr, 0}; } } return cf; } static inline size_t formatClosePayload(char *dst, uint16_t code, const char *message, size_t length) { - if (code) { + /* We could have more strict checks here, but never append code 0 or 1005 or 1006 */ + if (code && code != 1005 && code != 1006) { code = cond_byte_swap(code); memcpy(dst, &code, 2); /* It is invalid to pass nullptr to memcpy, even though length is 0 */ @@ -219,7 +230,7 @@ static inline size_t formatMessage(char *dst, const char *src, size_t length, Op char mask[4]; if (!isServer) { dst[1] |= 0x80; - uint32_t random = rand(); + uint32_t random = (uint32_t) rand(); memcpy(mask, &random, 4); memcpy(dst + headerLength, &random, 4); headerLength += 4; @@ -307,7 +318,7 @@ protected: wState->state.lastFin = isFin(src); if (Impl::refusePayloadLength(payLength, wState, user)) { - Impl::forceClose(wState, user); + Impl::forceClose(wState, user, ERR_TOO_BIG_MESSAGE); return true; } @@ -351,9 +362,9 @@ protected: static inline bool consumeContinuation(char *&src, unsigned int &length, WebSocketState *wState, void *user) { if (wState->remainingBytes <= length) { if (isServer) { - int n = wState->remainingBytes >> 2; + unsigned int n = wState->remainingBytes >> 2; unmaskInplace(src, src + n * 4, wState->mask); - for (int i = 0, s = wState->remainingBytes % 4; i < s; i++) { + for (unsigned int i = 0, s = wState->remainingBytes % 4; i < s; i++) { src[n * 4 + i] ^= wState->mask[i]; } } diff --git a/src/Vendor/uwebsockets/include/uws/f2/LICENSE.txt b/src/Vendor/uwebsockets/include/uws/f2/LICENSE.txt deleted file mode 100644 index 36b7cd9..0000000 --- a/src/Vendor/uwebsockets/include/uws/f2/LICENSE.txt +++ /dev/null @@ -1,23 +0,0 @@ -Boost Software License - Version 1.0 - August 17th, 2003 - -Permission is hereby granted, free of charge, to any person or organization -obtaining a copy of the software and accompanying documentation covered by -this license (the "Software") to use, reproduce, display, distribute, -execute, and transmit the Software, and to prepare derivative works of the -Software, and to permit third-parties to whom the Software is furnished to -do so, all subject to the following: - -The copyright notices in the Software and this entire statement, including -the above license grant, this restriction and the following disclaimer, -must be included in all copies of the Software, in whole or in part, and -all derivative works of the Software, unless such copies or derivative -works are solely in the form of machine-executable object code generated by -a source language processor. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT -SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE -FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, -ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. diff --git a/src/Vendor/uwebsockets/include/uws/f2/function2.hpp b/src/Vendor/uwebsockets/include/uws/f2/function2.hpp deleted file mode 100644 index dc81c7d..0000000 --- a/src/Vendor/uwebsockets/include/uws/f2/function2.hpp +++ /dev/null @@ -1,1764 +0,0 @@ - -// Copyright 2015-2019 Denis Blank -// Distributed under the Boost Software License, Version 1.0 -// (See accompanying file LICENSE_1_0.txt or copy at -// http://www.boost.org/LICENSE_1_0.txt) - -#ifndef FU2_INCLUDED_FUNCTION2_HPP_ -#define FU2_INCLUDED_FUNCTION2_HPP_ - -#include -#include -#include -#include -#include -#include -#include - -// Defines: -// - FU2_HAS_DISABLED_EXCEPTIONS -#if defined(FU2_WITH_DISABLED_EXCEPTIONS) || \ - defined(FU2_MACRO_DISABLE_EXCEPTIONS) -#define FU2_HAS_DISABLED_EXCEPTIONS -#else // FU2_WITH_DISABLED_EXCEPTIONS -#if defined(_MSC_VER) -#if !defined(_HAS_EXCEPTIONS) || (_HAS_EXCEPTIONS == 0) -#define FU2_HAS_DISABLED_EXCEPTIONS -#endif -#elif defined(__clang__) -#if !(__EXCEPTIONS && __has_feature(cxx_exceptions)) -#define FU2_HAS_DISABLED_EXCEPTIONS -#endif -#elif defined(__GNUC__) -#if !__EXCEPTIONS -#define FU2_HAS_DISABLED_EXCEPTIONS -#endif -#endif -#endif // FU2_WITH_DISABLED_EXCEPTIONS -// - FU2_HAS_NO_FUNCTIONAL_HEADER -#if !defined(FU2_WITH_NO_FUNCTIONAL_HEADER) && \ - !defined(FU2_NO_FUNCTIONAL_HEADER) && \ - !defined(FU2_HAS_DISABLED_EXCEPTIONS) -#include -#else -#define FU2_HAS_NO_FUNCTIONAL_HEADER -#endif -// - FU2_HAS_CXX17_NOEXCEPT_FUNCTION_TYPE -#if defined(FU2_WITH_CXX17_NOEXCEPT_FUNCTION_TYPE) -#define FU2_HAS_CXX17_NOEXCEPT_FUNCTION_TYPE -#else // FU2_WITH_CXX17_NOEXCEPT_FUNCTION_TYPE -#if defined(_MSC_VER) -#if defined(_HAS_CXX17) && _HAS_CXX17 -#define FU2_HAS_CXX17_NOEXCEPT_FUNCTION_TYPE -#endif -#elif defined(__cpp_noexcept_function_type) -#define FU2_HAS_CXX17_NOEXCEPT_FUNCTION_TYPE -#elif defined(__cplusplus) && (__cplusplus >= 201703L) -#define FU2_HAS_CXX17_NOEXCEPT_FUNCTION_TYPE -#endif -#endif // FU2_WITH_CXX17_NOEXCEPT_FUNCTION_TYPE - -// - FU2_HAS_NO_EMPTY_PROPAGATION -#if defined(FU2_WITH_NO_EMPTY_PROPAGATION) -#define FU2_HAS_NO_EMPTY_PROPAGATION -#endif // FU2_WITH_NO_EMPTY_PROPAGATION - -#if !defined(FU2_HAS_DISABLED_EXCEPTIONS) -#include -#endif - -/// Hint for the compiler that this point should be unreachable -#if defined(_MSC_VER) -// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) -#define FU2_DETAIL_UNREACHABLE_INTRINSIC() __assume(false) -#elif defined(__GNUC__) -// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) -#define FU2_DETAIL_UNREACHABLE_INTRINSIC() __builtin_unreachable() -#elif defined(__has_builtin) && __has_builtin(__builtin_unreachable) -// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) -#define FU2_DETAIL_UNREACHABLE_INTRINSIC() __builtin_unreachable() -#else -// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) -#define FU2_DETAIL_UNREACHABLE_INTRINSIC() abort() -#endif - -/// Causes the application to exit abnormally -#if defined(_MSC_VER) -// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) -#define FU2_DETAIL_TRAP() __debugbreak() -#elif defined(__GNUC__) -// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) -#define FU2_DETAIL_TRAP() __builtin_trap() -#elif defined(__has_builtin) && __has_builtin(__builtin_trap) -// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) -#define FU2_DETAIL_TRAP() __builtin_trap() -#else -// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) -#define FU2_DETAIL_TRAP() *(volatile int*)0x11 = 0 -#endif - -#ifndef NDEBUG -// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) -#define FU2_DETAIL_UNREACHABLE() ::fu2::detail::unreachable_debug() -#else -// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) -#define FU2_DETAIL_UNREACHABLE() FU2_DETAIL_UNREACHABLE_INTRINSIC() -#endif - -namespace fu2 { -inline namespace abi_400 { -namespace detail { -template -class function; - -template -struct identity {}; - -// Equivalent to C++17's std::void_t which targets a bug in GCC, -// that prevents correct SFINAE behavior. -// See http://stackoverflow.com/questions/35753920 for details. -template -struct deduce_to_void : std::common_type {}; - -template -using void_t = typename deduce_to_void::type; - -template -using unrefcv_t = std::remove_cv_t>; - -// Copy enabler helper class -template -struct copyable {}; -template <> -struct copyable { - copyable() = default; - ~copyable() = default; - copyable(copyable const&) = delete; - copyable(copyable&&) = default; - copyable& operator=(copyable const&) = delete; - copyable& operator=(copyable&&) = default; -}; - -/// Configuration trait to configure the function_base class. -template -struct config { - // Is true if the function is owning. - static constexpr auto const is_owning = Owning; - - // Is true if the function is copyable. - static constexpr auto const is_copyable = Copyable; - - // The internal capacity of the function - // used in small functor optimization. - // The object shall expose the real capacity through Capacity::capacity - // and the intended alignment through Capacity::alignment. - using capacity = Capacity; -}; - -/// A config which isn't compatible to other configs -template -struct property { - // Is true when the function throws an exception on empty invocation. - static constexpr auto const is_throwing = Throws; - - // Is true when the function throws an exception on empty invocation. - static constexpr auto const is_strong_exception_guaranteed = - HasStrongExceptGuarantee; -}; - -#ifndef NDEBUG -[[noreturn]] inline void unreachable_debug() { - FU2_DETAIL_TRAP(); - std::abort(); -} -#endif - -/// Provides utilities for invocing callable objects -namespace invocation { -/// Invokes the given callable object with the given arguments -template -constexpr auto invoke(Callable&& callable, Args&&... args) noexcept( - noexcept(std::forward(callable)(std::forward(args)...))) - -> decltype(std::forward(callable)(std::forward(args)...)) { - - return std::forward(callable)(std::forward(args)...); -} -/// Invokes the given member function pointer by reference -template -constexpr auto invoke(Type T::*member, Self&& self, Args&&... args) noexcept( - noexcept((std::forward(self).*member)(std::forward(args)...))) - -> decltype((std::forward(self).* - member)(std::forward(args)...)) { - return (std::forward(self).*member)(std::forward(args)...); -} -/// Invokes the given member function pointer by pointer -template -constexpr auto invoke(Type T::*member, Self&& self, Args&&... args) noexcept( - noexcept((std::forward(self)->*member)(std::forward(args)...))) - -> decltype( - (std::forward(self)->*member)(std::forward(args)...)) { - return (std::forward(self)->*member)(std::forward(args)...); -} -/// Invokes the given pointer to a scalar member by reference -template -constexpr auto -invoke(Type T::*member, - Self&& self) noexcept(noexcept(std::forward(self).*member)) - -> decltype(std::forward(self).*member) { - return (std::forward(self).*member); -} -/// Invokes the given pointer to a scalar member by pointer -template -constexpr auto -invoke(Type T::*member, - Self&& self) noexcept(noexcept(std::forward(self)->*member)) - -> decltype(std::forward(self)->*member) { - return std::forward(self)->*member; -} - -/// Deduces to a true type if the callable object can be invoked with -/// the given arguments. -/// We don't use invoke here because MSVC can't evaluate the nested expression -/// SFINAE here. -template -struct can_invoke : std::false_type {}; -template -struct can_invoke, - decltype((void)std::declval()(std::declval()...))> - : std::true_type {}; -template -struct can_invoke, - decltype((void)((std::declval().*std::declval())( - std::declval()...)))> : std::true_type {}; -template -struct can_invoke, - decltype( - (void)((std::declval().*std::declval())( - std::declval()...)))> : std::true_type {}; -template -struct can_invoke, - decltype( - (void)((std::declval()->*std::declval())( - std::declval()...)))> : std::true_type {}; -template -struct can_invoke, - decltype((void)(std::declval().*std::declval()))> - : std::true_type {}; -template -struct can_invoke, - decltype( - (void)(std::declval().*std::declval()))> - : std::true_type {}; -template -struct can_invoke, - decltype( - (void)(std::declval()->*std::declval()))> - : std::true_type {}; - -template -struct is_noexcept_correct : std::true_type {}; -template -struct is_noexcept_correct> - : std::integral_constant(), - std::declval()...))> { -}; -} // end namespace invocation - -namespace overloading { -template -struct overload_impl; -template -struct overload_impl : Current, - overload_impl { - explicit overload_impl(Current current, Next next, Rest... rest) - : Current(std::move(current)), overload_impl( - std::move(next), std::move(rest)...) { - } - - using Current::operator(); - using overload_impl::operator(); -}; -template -struct overload_impl : Current { - explicit overload_impl(Current current) : Current(std::move(current)) { - } - - using Current::operator(); -}; - -template -constexpr auto overload(T&&... callables) { - return overload_impl...>{std::forward(callables)...}; -} -} // namespace overloading - -/// Declares the namespace which provides the functionality to work with a -/// type-erased object. -namespace type_erasure { -/// Specialization to work with addresses of callable objects -template -struct address_taker { - template - static void* take(O&& obj) { - return std::addressof(obj); - } - static T& restore(void* ptr) { - return *static_cast(ptr); - } - static T const& restore(void const* ptr) { - return *static_cast(ptr); - } - static T volatile& restore(void volatile* ptr) { - return *static_cast(ptr); - } - static T const volatile& restore(void const volatile* ptr) { - return *static_cast(ptr); - } -}; -/// Specialization to work with addresses of raw function pointers -template -struct address_taker::value>> { - template - static void* take(O&& obj) { - return reinterpret_cast(obj); - } - template - static T restore(O ptr) { - return reinterpret_cast(const_cast(ptr)); - } -}; - -template -struct box_factory; -/// Store the allocator inside the box -template -struct box : private Allocator { - friend box_factory; - - T value_; - - explicit box(T value, Allocator allocator) - : Allocator(std::move(allocator)), value_(std::move(value)) { - } - - box(box&&) = default; - box(box const&) = default; - box& operator=(box&&) = default; - box& operator=(box const&) = default; - ~box() = default; -}; -template -struct box : private Allocator { - friend box_factory; - - T value_; - - explicit box(T value, Allocator allocator) - : Allocator(std::move(allocator)), value_(std::move(value)) { - } - - box(box&&) = default; - box(box const&) = delete; - box& operator=(box&&) = default; - box& operator=(box const&) = delete; - ~box() = default; -}; - -template -struct box_factory> { - using real_allocator = - typename std::allocator_traits>:: - template rebind_alloc>; - - /// Allocates space through the boxed allocator - static box* - box_allocate(box const* me) { - real_allocator allocator(*static_cast(me)); - - return static_cast*>( - std::allocator_traits::allocate(allocator, 1U)); - } - - /// Destroys the box through the given allocator - static void box_deallocate(box* me) { - real_allocator allocator(*static_cast(me)); - - me->~box(); - std::allocator_traits::deallocate(allocator, me, 1U); - } -}; - -/// Creates a box containing the given value and allocator -template -auto make_box(std::integral_constant, T&& value, - Allocator&& allocator) { - return box, std::decay_t>( - std::forward(value), std::forward(allocator)); -} - -template -struct is_box : std::false_type {}; -template -struct is_box> : std::true_type {}; - -/// Provides access to the pointer to a heal allocated erased object -/// as well to the inplace storage. -union data_accessor { - data_accessor() = default; - explicit constexpr data_accessor(std::nullptr_t) noexcept : ptr_(nullptr) { - } - explicit constexpr data_accessor(void* ptr) noexcept : ptr_(ptr) { - } - - /// The pointer we use if the object is on the heap - void* ptr_; - /// The first field of the inplace storage - std::size_t inplace_storage_; -}; - -/// See opcode::op_fetch_empty -constexpr void write_empty(data_accessor* accessor, bool empty) noexcept { - accessor->inplace_storage_ = std::size_t(empty); -} - -template -using transfer_const_t = - std::conditional_t>::value, - std::add_const_t, To>; -template -using transfer_volatile_t = - std::conditional_t>::value, - std::add_volatile_t, To>; - -/// The retriever when the object is allocated inplace -template -constexpr auto retrieve(std::true_type /*is_inplace*/, Accessor from, - std::size_t from_capacity) { - using type = transfer_const_t>*; - - /// Process the command by using the data inside the internal capacity - auto storage = &(from->inplace_storage_); - auto inplace = const_cast(static_cast(storage)); - return type(std::align(alignof(T), sizeof(T), inplace, from_capacity)); -} - -/// The retriever which is used when the object is allocated -/// through the allocator -template -constexpr auto retrieve(std::false_type /*is_inplace*/, Accessor from, - std::size_t /*from_capacity*/) { - - return from->ptr_; -} - -namespace invocation_table { -#if !defined(FU2_HAS_DISABLED_EXCEPTIONS) -#if defined(FU2_HAS_NO_FUNCTIONAL_HEADER) -struct bad_function_call : std::exception { - bad_function_call() noexcept { - } - - char const* what() const noexcept override { - return "bad function call"; - } -}; -#else -using std::bad_function_call; -#endif -#endif - -#ifdef FU2_HAS_CXX17_NOEXCEPT_FUNCTION_TYPE -#define FU2_DETAIL_EXPAND_QUALIFIERS_NOEXCEPT(F) \ - F(, , noexcept, , &) \ - F(const, , noexcept, , &) \ - F(, volatile, noexcept, , &) \ - F(const, volatile, noexcept, , &) \ - F(, , noexcept, &, &) \ - F(const, , noexcept, &, &) \ - F(, volatile, noexcept, &, &) \ - F(const, volatile, noexcept, &, &) \ - F(, , noexcept, &&, &&) \ - F(const, , noexcept, &&, &&) \ - F(, volatile, noexcept, &&, &&) \ - F(const, volatile, noexcept, &&, &&) -#define FU2_DETAIL_EXPAND_CV_NOEXCEPT(F) \ - F(, , noexcept) \ - F(const, , noexcept) \ - F(, volatile, noexcept) \ - F(const, volatile, noexcept) -#else // FU2_HAS_CXX17_NOEXCEPT_FUNCTION_TYPE -#define FU2_DETAIL_EXPAND_QUALIFIERS_NOEXCEPT(F) -#define FU2_DETAIL_EXPAND_CV_NOEXCEPT(F) -#endif // FU2_HAS_CXX17_NOEXCEPT_FUNCTION_TYPE - -#define FU2_DETAIL_EXPAND_QUALIFIERS(F) \ - F(, , , , &) \ - F(const, , , , &) \ - F(, volatile, , , &) \ - F(const, volatile, , , &) \ - F(, , , &, &) \ - F(const, , , &, &) \ - F(, volatile, , &, &) \ - F(const, volatile, , &, &) \ - F(, , , &&, &&) \ - F(const, , , &&, &&) \ - F(, volatile, , &&, &&) \ - F(const, volatile, , &&, &&) \ - FU2_DETAIL_EXPAND_QUALIFIERS_NOEXCEPT(F) -#define FU2_DETAIL_EXPAND_CV(F) \ - F(, , ) \ - F(const, , ) \ - F(, volatile, ) \ - F(const, volatile, ) \ - FU2_DETAIL_EXPAND_CV_NOEXCEPT(F) - -/// If the function is qualified as noexcept, the call will never throw -template -[[noreturn]] void throw_or_abortnoexcept( - std::integral_constant /*is_throwing*/) noexcept { - std::abort(); -} -/// Calls std::abort on empty function calls -[[noreturn]] inline void -throw_or_abort(std::false_type /*is_throwing*/) noexcept { - std::abort(); -} -/// Throws bad_function_call on empty funciton calls -[[noreturn]] inline void throw_or_abort(std::true_type /*is_throwing*/) { -#ifdef FU2_HAS_DISABLED_EXCEPTIONS - throw_or_abort(std::false_type{}); -#else - throw bad_function_call{}; -#endif -} - -template -struct function_trait; - -using is_noexcept_ = std::false_type; -using is_noexcept_noexcept = std::true_type; - -#define FU2_DEFINE_FUNCTION_TRAIT(CONST, VOLATILE, NOEXCEPT, OVL_REF, REF) \ - template \ - struct function_trait { \ - using pointer_type = Ret (*)(data_accessor CONST VOLATILE*, \ - std::size_t capacity, Args...); \ - template \ - struct internal_invoker { \ - static Ret invoke(data_accessor CONST VOLATILE* data, \ - std::size_t capacity, Args... args) NOEXCEPT { \ - auto obj = retrieve(std::integral_constant{}, \ - data, capacity); \ - auto box = static_cast(obj); \ - return invocation::invoke( \ - static_castvalue_)> CONST VOLATILE \ - REF>(box->value_), \ - std::forward(args)...); \ - } \ - }; \ - \ - template \ - struct view_invoker { \ - static Ret invoke(data_accessor CONST VOLATILE* data, std::size_t, \ - Args... args) NOEXCEPT { \ - \ - auto ptr = static_cast(data->ptr_); \ - return invocation::invoke(address_taker::restore(ptr), \ - std::forward(args)...); \ - } \ - }; \ - \ - template \ - using callable = T CONST VOLATILE REF; \ - \ - using arguments = identity; \ - \ - using is_noexcept = is_noexcept_##NOEXCEPT; \ - \ - template \ - struct empty_invoker { \ - static Ret invoke(data_accessor CONST VOLATILE* /*data*/, \ - std::size_t /*capacity*/, Args... /*args*/) NOEXCEPT { \ - throw_or_abort##NOEXCEPT(std::integral_constant{}); \ - } \ - }; \ - }; - -FU2_DETAIL_EXPAND_QUALIFIERS(FU2_DEFINE_FUNCTION_TRAIT) -#undef FU2_DEFINE_FUNCTION_TRAIT - -/// Deduces to the function pointer to the given signature -template -using function_pointer_of = typename function_trait::pointer_type; - -template -struct invoke_table; - -/// We optimize the vtable_t in case there is a single function overload -template -struct invoke_table { - using type = function_pointer_of; - - /// Return the function pointer itself - template - static constexpr auto fetch(type pointer) noexcept { - static_assert(Index == 0U, "The index should be 0 here!"); - return pointer; - } - - /// Returns the thunk of an single overloaded callable - template - static constexpr type get_invocation_table_of() noexcept { - return &function_trait::template internal_invoker::invoke; - } - /// Returns the thunk of an single overloaded callable - template - static constexpr type get_invocation_view_table_of() noexcept { - return &function_trait::template view_invoker::invoke; - } - /// Returns the thunk of an empty single overloaded callable - template - static constexpr type get_empty_invocation_table() noexcept { - return &function_trait::template empty_invoker::invoke; - } -}; -/// We generate a table in case of multiple function overloads -template -struct invoke_table { - using type = - std::tuple, function_pointer_of, - function_pointer_of...> const*; - - /// Return the function pointer at the particular index - template - static constexpr auto fetch(type table) noexcept { - return std::get(*table); - } - - /// The invocation vtable for a present object - template - struct invocation_vtable : public std::tuple, - function_pointer_of, - function_pointer_of...> { - constexpr invocation_vtable() noexcept - : std::tuple, function_pointer_of, - function_pointer_of...>(std::make_tuple( - &function_trait::template internal_invoker< - T, IsInplace>::invoke, - &function_trait::template internal_invoker< - T, IsInplace>::invoke, - &function_trait::template internal_invoker< - T, IsInplace>::invoke...)) { - } - }; - - /// Returns the thunk of an multi overloaded callable - template - static type get_invocation_table_of() noexcept { - static invocation_vtable const table; - return &table; - } - - /// The invocation vtable for a present object - template - struct invocation_view_vtable - : public std::tuple, - function_pointer_of, - function_pointer_of...> { - constexpr invocation_view_vtable() noexcept - : std::tuple, function_pointer_of, - function_pointer_of...>(std::make_tuple( - &function_trait::template view_invoker::invoke, - &function_trait::template view_invoker::invoke, - &function_trait::template view_invoker::invoke...)) { - } - }; - - /// Returns the thunk of an multi overloaded callable - template - static type get_invocation_view_table_of() noexcept { - static invocation_view_vtable const table; - return &table; - } - - /// The invocation table for an empty wrapper - template - struct empty_vtable : public std::tuple, - function_pointer_of, - function_pointer_of...> { - constexpr empty_vtable() noexcept - : std::tuple, function_pointer_of, - function_pointer_of...>( - std::make_tuple(&function_trait::template empty_invoker< - IsThrowing>::invoke, - &function_trait::template empty_invoker< - IsThrowing>::invoke, - &function_trait::template empty_invoker< - IsThrowing>::invoke...)) { - } - }; - - /// Returns the thunk of an multi single overloaded callable - template - static type get_empty_invocation_table() noexcept { - static empty_vtable const table; - return &table; - } -}; - -template -class operator_impl; - -#define FU2_DEFINE_FUNCTION_TRAIT(CONST, VOLATILE, NOEXCEPT, OVL_REF, REF) \ - template \ - class operator_impl \ - : operator_impl { \ - \ - template \ - friend class operator_impl; \ - \ - protected: \ - operator_impl() = default; \ - ~operator_impl() = default; \ - operator_impl(operator_impl const&) = default; \ - operator_impl(operator_impl&&) = default; \ - operator_impl& operator=(operator_impl const&) = default; \ - operator_impl& operator=(operator_impl&&) = default; \ - \ - using operator_impl::operator(); \ - \ - Ret operator()(Args... args) CONST VOLATILE OVL_REF NOEXCEPT { \ - auto parent = static_cast(this); \ - using erasure_t = std::decay_terasure_)>; \ - \ - /* `std::decay_terasure_)>` is a workaround for a */ \ - /* compiler regression of MSVC 16.3.1, see #29 for details. */ \ - return std::decay_terasure_)>::template invoke( \ - static_cast(parent->erasure_), \ - std::forward(args)...); \ - } \ - }; \ - template \ - class operator_impl, \ - Ret(Args...) CONST VOLATILE OVL_REF NOEXCEPT> \ - : copyable { \ - \ - template \ - friend class operator_impl; \ - \ - protected: \ - operator_impl() = default; \ - ~operator_impl() = default; \ - operator_impl(operator_impl const&) = default; \ - operator_impl(operator_impl&&) = default; \ - operator_impl& operator=(operator_impl const&) = default; \ - operator_impl& operator=(operator_impl&&) = default; \ - \ - Ret operator()(Args... args) CONST VOLATILE OVL_REF NOEXCEPT { \ - auto parent = \ - static_cast CONST VOLATILE*>(this); \ - using erasure_t = std::decay_terasure_)>; \ - \ - /* `std::decay_terasure_)>` is a workaround for a */ \ - /* compiler regression of MSVC 16.3.1, see #29 for details. */ \ - return std::decay_terasure_)>::template invoke( \ - static_cast(parent->erasure_), \ - std::forward(args)...); \ - } \ - }; - -FU2_DETAIL_EXPAND_QUALIFIERS(FU2_DEFINE_FUNCTION_TRAIT) -#undef FU2_DEFINE_FUNCTION_TRAIT -} // namespace invocation_table - -namespace tables { -/// Identifies the action which is dispatched on the erased object -enum class opcode { - op_move, //< Move the object and set the vtable - op_copy, //< Copy the object and set the vtable - op_destroy, //< Destroy the object and reset the vtable - op_weak_destroy, //< Destroy the object without resetting the vtable - op_fetch_empty, //< Stores true or false into the to storage - //< to indicate emptiness -}; - -/// Abstraction for a vtable together with a command table -/// TODO Add optimization for a single formal argument -/// TODO Add optimization to merge both tables if the function is size -/// optimized -template -class vtable; -template -class vtable> { - using command_function_t = void (*)(vtable* /*this*/, opcode /*op*/, - data_accessor* /*from*/, - std::size_t /*from_capacity*/, - data_accessor* /*to*/, - std::size_t /*to_capacity*/); - - using invoke_table_t = invocation_table::invoke_table; - - command_function_t cmd_; - typename invoke_table_t::type vtable_; - - template - struct trait { - static_assert(is_box::value, - "The trait must be specialized with a box!"); - - /// The command table - template - static void process_cmd(vtable* to_table, opcode op, data_accessor* from, - std::size_t from_capacity, data_accessor* to, - std::size_t to_capacity) { - - switch (op) { - case opcode::op_move: { - /// Retrieve the pointer to the object - auto box = static_cast(retrieve( - std::integral_constant{}, from, from_capacity)); - assert(box && "The object must not be over aligned or null!"); - - if (!IsInplace) { - // Just swap both pointers if we allocated on the heap - to->ptr_ = from->ptr_; - -#ifndef NDEBUG - // We don't need to null the pointer since we know that - // we don't own the data anymore through the vtable - // which is set to empty. - from->ptr_ = nullptr; -#endif - - to_table->template set_allocated(); - - } - // The object is allocated inplace - else { - construct(std::true_type{}, std::move(*box), to_table, to, - to_capacity); - box->~T(); - } - return; - } - case opcode::op_copy: { - auto box = static_cast(retrieve( - std::integral_constant{}, from, from_capacity)); - assert(box && "The object must not be over aligned or null!"); - - assert(std::is_copy_constructible::value && - "The box is required to be copyable here!"); - - // Try to allocate the object inplace - construct(std::is_copy_constructible{}, *box, to_table, to, - to_capacity); - return; - } - case opcode::op_destroy: - case opcode::op_weak_destroy: { - - assert(!to && !to_capacity && "Arg overflow!"); - auto box = static_cast(retrieve( - std::integral_constant{}, from, from_capacity)); - - if (IsInplace) { - box->~T(); - } else { - box_factory::box_deallocate(box); - } - - if (op == opcode::op_destroy) { - to_table->set_empty(); - } - return; - } - case opcode::op_fetch_empty: { - write_empty(to, false); - return; - } - } - - FU2_DETAIL_UNREACHABLE(); - } - - template - static void - construct(std::true_type /*apply*/, Box&& box, vtable* to_table, - data_accessor* to, - std::size_t to_capacity) noexcept(HasStrongExceptGuarantee) { - // Try to allocate the object inplace - void* storage = retrieve(std::true_type{}, to, to_capacity); - if (storage) { - to_table->template set_inplace(); - } else { - // Allocate the object through the allocator - to->ptr_ = storage = - box_factory>::box_allocate(std::addressof(box)); - to_table->template set_allocated(); - } - new (storage) T(std::forward(box)); - } - - template - static void - construct(std::false_type /*apply*/, Box&& /*box*/, vtable* /*to_table*/, - data_accessor* /*to*/, - std::size_t /*to_capacity*/) noexcept(HasStrongExceptGuarantee) { - } - }; - - /// The command table - static void empty_cmd(vtable* to_table, opcode op, data_accessor* /*from*/, - std::size_t /*from_capacity*/, data_accessor* to, - std::size_t /*to_capacity*/) { - - switch (op) { - case opcode::op_move: - case opcode::op_copy: { - to_table->set_empty(); - break; - } - case opcode::op_destroy: - case opcode::op_weak_destroy: { - // Do nothing - break; - } - case opcode::op_fetch_empty: { - write_empty(to, true); - break; - } - default: { - FU2_DETAIL_UNREACHABLE(); - } - } - } - -public: - vtable() noexcept = default; - - /// Initialize an object at the given position - template - static void init(vtable& table, T&& object, data_accessor* to, - std::size_t to_capacity) { - - trait>::construct(std::true_type{}, std::forward(object), - &table, to, to_capacity); - } - - /// Moves the object at the given position - void move(vtable& to_table, data_accessor* from, std::size_t from_capacity, - data_accessor* to, - std::size_t to_capacity) noexcept(HasStrongExceptGuarantee) { - cmd_(&to_table, opcode::op_move, from, from_capacity, to, to_capacity); - set_empty(); - } - - /// Destroys the object at the given position - void copy(vtable& to_table, data_accessor const* from, - std::size_t from_capacity, data_accessor* to, - std::size_t to_capacity) const { - cmd_(&to_table, opcode::op_copy, const_cast(from), - from_capacity, to, to_capacity); - } - - /// Destroys the object at the given position - void destroy(data_accessor* from, - std::size_t from_capacity) noexcept(HasStrongExceptGuarantee) { - cmd_(this, opcode::op_destroy, from, from_capacity, nullptr, 0U); - } - - /// Destroys the object at the given position without invalidating the - /// vtable - void - weak_destroy(data_accessor* from, - std::size_t from_capacity) noexcept(HasStrongExceptGuarantee) { - cmd_(this, opcode::op_weak_destroy, from, from_capacity, nullptr, 0U); - } - - /// Returns true when the vtable doesn't hold any erased object - bool empty() const noexcept { - data_accessor data; - cmd_(nullptr, opcode::op_fetch_empty, nullptr, 0U, &data, 0U); - return bool(data.inplace_storage_); - } - - /// Invoke the function at the given index - template - constexpr auto invoke(Args&&... args) const { - auto thunk = invoke_table_t::template fetch(vtable_); - return thunk(std::forward(args)...); - } - /// Invoke the function at the given index - template - constexpr auto invoke(Args&&... args) const volatile { - auto thunk = invoke_table_t::template fetch(vtable_); - return thunk(std::forward(args)...); - } - - template - void set_inplace() noexcept { - using type = std::decay_t; - vtable_ = invoke_table_t::template get_invocation_table_of(); - cmd_ = &trait::template process_cmd; - } - - template - void set_allocated() noexcept { - using type = std::decay_t; - vtable_ = invoke_table_t::template get_invocation_table_of(); - cmd_ = &trait::template process_cmd; - } - - void set_empty() noexcept { - vtable_ = invoke_table_t::template get_empty_invocation_table(); - cmd_ = &empty_cmd; - } -}; -} // namespace tables - -/// A union which makes the pointer to the heap object share the -/// same space with the internal capacity. -/// The storage type is distinguished by multiple versions of the -/// control and vtable. -template -struct internal_capacity { - /// We extend the union through a technique similar to the tail object hack - typedef union { - /// Tag to access the structure in a type-safe way - data_accessor accessor_; - /// The internal capacity we use to allocate in-place - std::aligned_storage_t capacity_; - } type; -}; -template -struct internal_capacity< - Capacity, std::enable_if_t<(Capacity::capacity < sizeof(void*))>> { - typedef struct { - /// Tag to access the structure in a type-safe way - data_accessor accessor_; - } type; -}; - -template -class internal_capacity_holder { - // Tag to access the structure in a type-safe way - typename internal_capacity::type storage_; - -public: - constexpr internal_capacity_holder() = default; - - constexpr data_accessor* opaque_ptr() noexcept { - return &storage_.accessor_; - } - constexpr data_accessor const* opaque_ptr() const noexcept { - return &storage_.accessor_; - } - constexpr data_accessor volatile* opaque_ptr() volatile noexcept { - return &storage_.accessor_; - } - constexpr data_accessor const volatile* opaque_ptr() const volatile noexcept { - return &storage_.accessor_; - } - - static constexpr std::size_t capacity() noexcept { - return sizeof(storage_); - } -}; - -/// An owning erasure -template -class erasure : internal_capacity_holder { - template - friend class erasure; - template - friend class operator_impl; - - using vtable_t = tables::vtable; - - vtable_t vtable_; - -public: - /// Returns the capacity of this erasure - static constexpr std::size_t capacity() noexcept { - return internal_capacity_holder::capacity(); - } - - constexpr erasure() noexcept { - vtable_.set_empty(); - } - - constexpr erasure(std::nullptr_t) noexcept { - vtable_.set_empty(); - } - - constexpr erasure(erasure&& right) noexcept( - Property::is_strong_exception_guaranteed) { - right.vtable_.move(vtable_, right.opaque_ptr(), right.capacity(), - this->opaque_ptr(), capacity()); - } - - constexpr erasure(erasure const& right) { - right.vtable_.copy(vtable_, right.opaque_ptr(), right.capacity(), - this->opaque_ptr(), capacity()); - } - - template - constexpr erasure(erasure right) noexcept( - Property::is_strong_exception_guaranteed) { - right.vtable_.move(vtable_, right.opaque_ptr(), right.capacity(), - this->opaque_ptr(), capacity()); - } - - template >> - constexpr erasure(std::false_type /*use_bool_op*/, T&& callable, - Allocator&& allocator = Allocator{}) { - vtable_t::init(vtable_, - type_erasure::make_box( - std::integral_constant{}, - std::forward(callable), - std::forward(allocator)), - this->opaque_ptr(), capacity()); - } - template >> - constexpr erasure(std::true_type /*use_bool_op*/, T&& callable, - Allocator&& allocator = Allocator{}) { - if (bool(callable)) { - vtable_t::init(vtable_, - type_erasure::make_box( - std::integral_constant{}, - std::forward(callable), - std::forward(allocator)), - this->opaque_ptr(), capacity()); - } else { - vtable_.set_empty(); - } - } - - ~erasure() { - vtable_.weak_destroy(this->opaque_ptr(), capacity()); - } - - constexpr erasure& - operator=(std::nullptr_t) noexcept(Property::is_strong_exception_guaranteed) { - vtable_.destroy(this->opaque_ptr(), capacity()); - return *this; - } - - constexpr erasure& operator=(erasure&& right) noexcept( - Property::is_strong_exception_guaranteed) { - vtable_.weak_destroy(this->opaque_ptr(), capacity()); - right.vtable_.move(vtable_, right.opaque_ptr(), right.capacity(), - this->opaque_ptr(), capacity()); - return *this; - } - - constexpr erasure& operator=(erasure const& right) { - vtable_.weak_destroy(this->opaque_ptr(), capacity()); - right.vtable_.copy(vtable_, right.opaque_ptr(), right.capacity(), - this->opaque_ptr(), capacity()); - return *this; - } - - template - constexpr erasure& - operator=(erasure right) noexcept( - Property::is_strong_exception_guaranteed) { - vtable_.weak_destroy(this->opaque_ptr(), capacity()); - right.vtable_.move(vtable_, right.opaque_ptr(), right.capacity(), - this->opaque_ptr(), capacity()); - return *this; - } - - template >> - void assign(std::false_type /*use_bool_op*/, T&& callable, - Allocator&& allocator = {}) { - vtable_.weak_destroy(this->opaque_ptr(), capacity()); - vtable_t::init(vtable_, - type_erasure::make_box( - std::integral_constant{}, - std::forward(callable), - std::forward(allocator)), - this->opaque_ptr(), capacity()); - } - - template >> - void assign(std::true_type /*use_bool_op*/, T&& callable, - Allocator&& allocator = {}) { - if (bool(callable)) { - assign(std::false_type{}, std::forward(callable), - std::forward(allocator)); - } else { - operator=(nullptr); - } - } - - /// Returns true when the erasure doesn't hold any erased object - constexpr bool empty() const noexcept { - return vtable_.empty(); - } - - /// Invoke the function of the erasure at the given index - /// - /// We define this out of class to be able to forward the qualified - /// erasure correctly. - template - static constexpr auto invoke(Erasure&& erasure, Args&&... args) { - auto const capacity = erasure.capacity(); - return erasure.vtable_.template invoke( - std::forward(erasure).opaque_ptr(), capacity, - std::forward(args)...); - } -}; - -// A non owning erasure -template -class erasure> { - template - friend class erasure; - template - friend class operator_impl; - - using property_t = property; - - using invoke_table_t = invocation_table::invoke_table; - typename invoke_table_t::type invoke_table_; - - /// The internal pointer to the non owned object - data_accessor view_; - -public: - // NOLINTNEXTLINE(cppcoreguidlines-pro-type-member-init) - constexpr erasure() noexcept - : invoke_table_( - invoke_table_t::template get_empty_invocation_table()), - view_(nullptr) { - } - - // NOLINTNEXTLINE(cppcoreguidlines-pro-type-member-init) - constexpr erasure(std::nullptr_t) noexcept - : invoke_table_( - invoke_table_t::template get_empty_invocation_table()), - view_(nullptr) { - } - - // NOLINTNEXTLINE(cppcoreguidlines-pro-type-member-init) - constexpr erasure(erasure&& right) noexcept - : invoke_table_(right.invoke_table_), view_(right.view_) { - } - - constexpr erasure(erasure const& /*right*/) = default; - - template - // NOLINTNEXTLINE(cppcoreguidlines-pro-type-member-init) - constexpr erasure(erasure right) noexcept - : invoke_table_(right.invoke_table_), view_(right.view_) { - } - - template - // NOLINTNEXTLINE(cppcoreguidlines-pro-type-member-init) - constexpr erasure(std::false_type /*use_bool_op*/, T&& object) - : invoke_table_(invoke_table_t::template get_invocation_view_table_of< - std::decay_t>()), - view_(address_taker>::take(std::forward(object))) { - } - template - // NOLINTNEXTLINE(cppcoreguidlines-pro-type-member-init) - constexpr erasure(std::true_type use_bool_op, T&& object) { - this->assign(use_bool_op, std::forward(object)); - } - - ~erasure() = default; - - constexpr erasure& - operator=(std::nullptr_t) noexcept(HasStrongExceptGuarantee) { - invoke_table_ = - invoke_table_t::template get_empty_invocation_table(); - view_.ptr_ = nullptr; - return *this; - } - - constexpr erasure& operator=(erasure&& right) noexcept { - invoke_table_ = right.invoke_table_; - view_ = right.view_; - right = nullptr; - return *this; - } - - constexpr erasure& operator=(erasure const& /*right*/) = default; - - template - constexpr erasure& - operator=(erasure right) noexcept { - invoke_table_ = right.invoke_table_; - view_ = right.view_; - return *this; - } - - template - constexpr void assign(std::false_type /*use_bool_op*/, T&& callable) { - invoke_table_ = invoke_table_t::template get_invocation_view_table_of< - std::decay_t>(); - view_.ptr_ = - address_taker>::take(std::forward(callable)); - } - template - constexpr void assign(std::true_type /*use_bool_op*/, T&& callable) { - if (bool(callable)) { - assign(std::false_type{}, std::forward(callable)); - } else { - operator=(nullptr); - } - } - - /// Returns true when the erasure doesn't hold any erased object - constexpr bool empty() const noexcept { - return view_.ptr_ == nullptr; - } - - template - static constexpr auto invoke(Erasure&& erasure, T&&... args) { - auto thunk = invoke_table_t::template fetch(erasure.invoke_table_); - return thunk(&(erasure.view_), 0UL, std::forward(args)...); - } -}; -} // namespace type_erasure - -/// Deduces to a true_type if the type T provides the given signature and the -/// signature is noexcept correct callable. -template > -struct accepts_one - : std::integral_constant< - bool, invocation::can_invoke, - typename Trait::arguments>::value && - invocation::is_noexcept_correct< - Trait::is_noexcept::value, - typename Trait::template callable, - typename Trait::arguments>::value> {}; - -/// Deduces to a true_type if the type T provides all signatures -template -struct accepts_all : std::false_type {}; -template -struct accepts_all< - T, identity, - void_t::value>...>> - : std::true_type {}; - -/// Deduces to a true_type if the type T is implementing operator bool() -/// or if the type is convertible to bool directly, this also implements an -/// optimizations for function references `void(&)()` which are can never -/// be null and for such a conversion to bool would never return false. -#if defined(FU2_HAS_NO_EMPTY_PROPAGATION) -template -struct use_bool_op : std::false_type {}; -#else -template -struct has_bool_op : std::false_type {}; -template -struct has_bool_op()))>> - : std::true_type { -#ifndef NDEBUG - static_assert(!std::is_pointer::value, - "Missing deduction for function pointer!"); -#endif -}; - -template -struct use_bool_op : has_bool_op {}; - -#define FU2_DEFINE_USE_OP_TRAIT(CONST, VOLATILE, NOEXCEPT) \ - template \ - struct use_bool_op \ - : std::true_type {}; - -FU2_DETAIL_EXPAND_CV(FU2_DEFINE_USE_OP_TRAIT) -#undef FU2_DEFINE_USE_OP_TRAIT - -template -struct use_bool_op : std::false_type {}; - -#if defined(FU2_HAS_CXX17_NOEXCEPT_FUNCTION_TYPE) -template -struct use_bool_op : std::false_type {}; -#endif -#endif // FU2_HAS_NO_EMPTY_PROPAGATION - -template -struct assert_wrong_copy_assign { - static_assert(!Config::is_owning || !Config::is_copyable || - std::is_copy_constructible>::value, - "Can't wrap a non copyable object into a unique function!"); - - using type = void; -}; - -template -struct assert_no_strong_except_guarantee { - static_assert( - !IsStrongExceptGuaranteed || - (std::is_nothrow_move_constructible::value && - std::is_nothrow_destructible::value), - "Can't wrap a object an object that has no strong exception guarantees " - "if this is required by the wrapper!"); - - using type = void; -}; - -/// SFINAES out if the given callable is not copyable correct to the left one. -template -using enable_if_copyable_correct_t = - std::enable_if_t<(!LeftConfig::is_copyable || RightConfig::is_copyable)>; - -template -using is_owning_correct = - std::integral_constant; - -/// SFINAES out if the given function2 is not owning correct to this one -template -using enable_if_owning_correct_t = - std::enable_if_t::value>; - -template -class function> - : type_erasure::invocation_table::operator_impl< - 0U, - function>, - Args...> { - - template - friend class function; - - template - friend class type_erasure::invocation_table::operator_impl; - - using property_t = property; - using erasure_t = - type_erasure::erasure; - - template - using enable_if_can_accept_all_t = - std::enable_if_t, identity>::value>; - - template - struct is_convertible_to_this : std::false_type {}; - template - struct is_convertible_to_this< - function, - void_t, - enable_if_owning_correct_t>> - : std::true_type {}; - - template - using enable_if_not_convertible_to_this = - std::enable_if_t>::value>; - - template - using enable_if_owning_t = - std::enable_if_t::value && Config::is_owning>; - - template - using assert_wrong_copy_assign_t = - typename assert_wrong_copy_assign>::type; - - template - using assert_no_strong_except_guarantee_t = - typename assert_no_strong_except_guarantee>::type; - - erasure_t erasure_; - -public: - /// Default constructor which empty constructs the function - function() = default; - ~function() = default; - - explicit constexpr function(function const& /*right*/) = default; - explicit constexpr function(function&& /*right*/) = default; - - /// Copy construction from another copyable function - template * = nullptr, - enable_if_copyable_correct_t* = nullptr, - enable_if_owning_correct_t* = nullptr> - constexpr function(function const& right) - : erasure_(right.erasure_) { - } - - /// Move construction from another function - template * = nullptr, - enable_if_owning_correct_t* = nullptr> - constexpr function(function&& right) - : erasure_(std::move(right.erasure_)) { - } - - /// Construction from a callable object which overloads the `()` operator - template * = nullptr, - enable_if_can_accept_all_t* = nullptr, - assert_wrong_copy_assign_t* = nullptr, - assert_no_strong_except_guarantee_t* = nullptr> - constexpr function(T&& callable) - : erasure_(use_bool_op>{}, std::forward(callable)) { - } - template * = nullptr, - enable_if_can_accept_all_t* = nullptr, - enable_if_owning_t* = nullptr, - assert_wrong_copy_assign_t* = nullptr, - assert_no_strong_except_guarantee_t* = nullptr> - constexpr function(T&& callable, Allocator&& allocator) - : erasure_(use_bool_op>{}, std::forward(callable), - std::forward(allocator)) { - } - - /// Empty constructs the function - constexpr function(std::nullptr_t np) : erasure_(np) { - } - - function& operator=(function const& /*right*/) = default; - function& operator=(function&& /*right*/) = default; - - /// Copy assigning from another copyable function - template * = nullptr, - enable_if_copyable_correct_t* = nullptr, - enable_if_owning_correct_t* = nullptr> - function& operator=(function const& right) { - erasure_ = right.erasure_; - return *this; - } - - /// Move assigning from another function - template * = nullptr, - enable_if_owning_correct_t* = nullptr> - function& operator=(function&& right) { - erasure_ = std::move(right.erasure_); - return *this; - } - - /// Move assigning from a callable object - template * = nullptr, - enable_if_can_accept_all_t* = nullptr, - assert_wrong_copy_assign_t* = nullptr, - assert_no_strong_except_guarantee_t* = nullptr> - function& operator=(T&& callable) { - erasure_.assign(use_bool_op>{}, std::forward(callable)); - return *this; - } - - /// Clears the function - function& operator=(std::nullptr_t np) { - erasure_ = np; - return *this; - } - - /// Returns true when the function is empty - bool empty() const noexcept { - return erasure_.empty(); - } - - /// Returns true when the function isn't empty - explicit operator bool() const noexcept { - return !empty(); - } - - /// Assigns a new target with an optional allocator - template >, - enable_if_not_convertible_to_this* = nullptr, - enable_if_can_accept_all_t* = nullptr, - assert_wrong_copy_assign_t* = nullptr, - assert_no_strong_except_guarantee_t* = nullptr> - void assign(T&& callable, Allocator&& allocator = Allocator{}) { - erasure_.assign(use_bool_op>{}, std::forward(callable), - std::forward(allocator)); - } - - /// Swaps this function with the given function - void swap(function& other) noexcept(HasStrongExceptGuarantee) { - if (&other == this) { - return; - } - - function cache = std::move(other); - other = std::move(*this); - *this = std::move(cache); - } - - /// Swaps the left function with the right one - friend void swap(function& left, - function& right) noexcept(HasStrongExceptGuarantee) { - left.swap(right); - } - - /// Calls the wrapped callable object - using type_erasure::invocation_table::operator_impl< - 0U, function, Args...>::operator(); -}; - -template -bool operator==(function const& f, std::nullptr_t) { - return !bool(f); -} - -template -bool operator!=(function const& f, std::nullptr_t) { - return bool(f); -} - -template -bool operator==(std::nullptr_t, function const& f) { - return !bool(f); -} - -template -bool operator!=(std::nullptr_t, function const& f) { - return bool(f); -} - -// Default intended object size of the function -using object_size = std::integral_constant; -} // namespace detail -} // namespace abi_400 - -/// Can be passed to function_base as template argument which causes -/// the internal small buffer to be sized according to the given size, -/// and aligned with the given alignment. -template -struct capacity_fixed { - static constexpr std::size_t capacity = Capacity; - static constexpr std::size_t alignment = Alignment; -}; - -/// Default capacity for small functor optimization -struct capacity_default - : capacity_fixed {}; - -/// Can be passed to function_base as template argument which causes -/// the internal small buffer to be removed from the callable wrapper. -/// The owning function_base will then allocate memory for every object -/// it applies a type erasure on. -struct capacity_none : capacity_fixed<0UL> {}; - -/// Can be passed to function_base as template argument which causes -/// the internal small buffer to be sized such that it can hold -/// the given object without allocating memory for an applied type erasure. -template -struct capacity_can_hold { - static constexpr std::size_t capacity = sizeof(T); - static constexpr std::size_t alignment = alignof(T); -}; - -/// An adaptable function wrapper base for arbitrary functional types. -/// -/// \tparam IsOwning Is true when the type erasure shall be owning the object. -/// -/// \tparam IsCopyable Defines whether the function is copyable or not -/// -/// \tparam Capacity Defines the internal capacity of the function -/// for small functor optimization. -/// The size of the whole function object will be the capacity -/// plus the size of two pointers. If the capacity is zero, -/// the size will increase through one additional pointer -/// so the whole object has the size of 3 * sizeof(void*). -/// The type which is passed to the Capacity template parameter -/// shall provide a capacity and alignment member which -/// looks like the following example: -/// ```cpp -/// struct my_capacity { -/// static constexpr std::size_t capacity = sizeof(my_type); -/// static constexpr std::size_t alignment = alignof(my_type); -/// }; -/// ``` -/// -/// \tparam IsThrowing Defines whether the function throws an exception on -/// empty function call, `std::abort` is called otherwise. -/// -/// \tparam HasStrongExceptGuarantee Defines whether all objects satisfy the -/// strong exception guarantees, -/// which means the function type will satisfy -/// the strong exception guarantees too. -/// -/// \tparam Signatures Defines the signature of the callable wrapper -/// -template -using function_base = detail::function< - detail::config, - detail::property>; - -/// An owning copyable function wrapper for arbitrary callable types. -template -using function = function_base; - -/// An owning non copyable function wrapper for arbitrary callable types. -template -using unique_function = function_base; - -/// A non owning copyable function wrapper for arbitrary callable types. -template -using function_view = function_base; - -#if !defined(FU2_HAS_DISABLED_EXCEPTIONS) -/// Exception type that is thrown when invoking empty function objects -/// and exception support isn't disabled. -/// -/// Exception support is enabled if -/// the template parameter 'Throwing' is set to true (default). -/// -/// This type will default to std::bad_function_call if the -/// functional header is used, otherwise the library provides its own type. -/// -/// You may disable the inclusion of the functional header -/// through defining `FU2_WITH_NO_FUNCTIONAL_HEADER`. -/// -using detail::type_erasure::invocation_table::bad_function_call; -#endif - -/// Returns a callable object, which unifies all callable objects -/// that were passed to this function. -/// -/// ```cpp -/// auto overloaded = fu2::overload([](std::true_type) { return true; }, -/// [](std::false_type) { return false; }); -/// ``` -/// -/// \param callables A pack of callable objects with arbitrary signatures. -/// -/// \returns A callable object which exposes the -/// -template -constexpr auto overload(T&&... callables) { - return detail::overloading::overload(std::forward(callables)...); -} -} // namespace fu2 - -#undef FU2_DETAIL_EXPAND_QUALIFIERS -#undef FU2_DETAIL_EXPAND_QUALIFIERS_NOEXCEPT -#undef FU2_DETAIL_EXPAND_CV -#undef FU2_DETAIL_EXPAND_CV_NOEXCEPT -#undef FU2_DETAIL_UNREACHABLE_INTRINSIC -#undef FU2_DETAIL_UNREACHABLE_INTRINSIC -#undef FU2_DETAIL_TRAP - -#endif // FU2_INCLUDED_FUNCTION2_HPP_ - diff --git a/src/Vendor/uwebsockets/src/bsd.c b/src/Vendor/uwebsockets/src/bsd.c new file mode 100644 index 0000000..c525c39 --- /dev/null +++ b/src/Vendor/uwebsockets/src/bsd.c @@ -0,0 +1,308 @@ +/* + * Authored by Alex Hultman, 2018-2021. + * Intellectual property of third-party. + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Todo: this file should lie in networking/bsd.c */ + +#include "libusockets.h" +#include "internal/internal.h" + +#include + +#ifndef _WIN32 +//#define _GNU_SOURCE +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + +LIBUS_SOCKET_DESCRIPTOR apple_no_sigpipe(LIBUS_SOCKET_DESCRIPTOR fd) { +#ifdef __APPLE__ + if (fd != LIBUS_SOCKET_ERROR) { + int no_sigpipe = 1; + setsockopt(fd, SOL_SOCKET, SO_NOSIGPIPE, &no_sigpipe, sizeof(int)); + } +#endif + return fd; +} + +LIBUS_SOCKET_DESCRIPTOR bsd_set_nonblocking(LIBUS_SOCKET_DESCRIPTOR fd) { +#ifdef _WIN32 + /* Libuv will set windows sockets as non-blocking */ +#else + fcntl(fd, F_SETFL, fcntl(fd, F_GETFL, 0) | O_NONBLOCK); +#endif + return fd; +} + +void bsd_socket_nodelay(LIBUS_SOCKET_DESCRIPTOR fd, int enabled) { + setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, (void *) &enabled, sizeof(enabled)); +} + +void bsd_socket_flush(LIBUS_SOCKET_DESCRIPTOR fd) { + // Linux TCP_CORK has the same underlying corking mechanism as with MSG_MORE +#ifdef TCP_CORK + int enabled = 0; + setsockopt(fd, IPPROTO_TCP, TCP_CORK, &enabled, sizeof(int)); +#endif +} + +LIBUS_SOCKET_DESCRIPTOR bsd_create_socket(int domain, int type, int protocol) { + // returns INVALID_SOCKET on error + int flags = 0; +#if defined(SOCK_CLOEXEC) && defined(SOCK_NONBLOCK) + flags = SOCK_CLOEXEC | SOCK_NONBLOCK; +#endif + + LIBUS_SOCKET_DESCRIPTOR created_fd = socket(domain, type | flags, protocol); + + return bsd_set_nonblocking(apple_no_sigpipe(created_fd)); +} + +void bsd_close_socket(LIBUS_SOCKET_DESCRIPTOR fd) { +#ifdef _WIN32 + closesocket(fd); +#else + close(fd); +#endif +} + +void bsd_shutdown_socket(LIBUS_SOCKET_DESCRIPTOR fd) { +#ifdef _WIN32 + shutdown(fd, SD_SEND); +#else + shutdown(fd, SHUT_WR); +#endif +} + +void bsd_shutdown_socket_read(LIBUS_SOCKET_DESCRIPTOR fd) { +#ifdef _WIN32 + shutdown(fd, SD_RECEIVE); +#else + shutdown(fd, SHUT_RD); +#endif +} + +void internal_finalize_bsd_addr(struct bsd_addr_t *addr) { + // parse, so to speak, the address + if (addr->mem.ss_family == AF_INET6) { + addr->ip = (char *) &((struct sockaddr_in6 *) addr)->sin6_addr; + addr->ip_length = sizeof(struct in6_addr); + addr->port = ntohs(((struct sockaddr_in6 *) addr)->sin6_port); + } else if (addr->mem.ss_family == AF_INET) { + addr->ip = (char *) &((struct sockaddr_in *) addr)->sin_addr; + addr->ip_length = sizeof(struct in_addr); + addr->port = ntohs(((struct sockaddr_in *) addr)->sin_port); + } else { + addr->ip_length = 0; + addr->port = -1; + } +} + +int bsd_local_addr(LIBUS_SOCKET_DESCRIPTOR fd, struct bsd_addr_t *addr) { + addr->len = sizeof(addr->mem); + if (getsockname(fd, (struct sockaddr *) &addr->mem, &addr->len)) { + return -1; + } + internal_finalize_bsd_addr(addr); + return 0; +} + +int bsd_remote_addr(LIBUS_SOCKET_DESCRIPTOR fd, struct bsd_addr_t *addr) { + addr->len = sizeof(addr->mem); + if (getpeername(fd, (struct sockaddr *) &addr->mem, &addr->len)) { + return -1; + } + internal_finalize_bsd_addr(addr); + return 0; +} + +char *bsd_addr_get_ip(struct bsd_addr_t *addr) { + return addr->ip; +} + +int bsd_addr_get_ip_length(struct bsd_addr_t *addr) { + return addr->ip_length; +} + +int bsd_addr_get_port(struct bsd_addr_t *addr) { + return addr->port; +} + +// called by dispatch_ready_poll +LIBUS_SOCKET_DESCRIPTOR bsd_accept_socket(LIBUS_SOCKET_DESCRIPTOR fd, struct bsd_addr_t *addr) { + LIBUS_SOCKET_DESCRIPTOR accepted_fd; + addr->len = sizeof(addr->mem); + +#if defined(SOCK_CLOEXEC) && defined(SOCK_NONBLOCK) + // Linux, FreeBSD + accepted_fd = accept4(fd, (struct sockaddr *) addr, &addr->len, SOCK_CLOEXEC | SOCK_NONBLOCK); +#else + // Windows, OS X + accepted_fd = accept(fd, (struct sockaddr *) addr, &addr->len); + +#endif + + /* We cannot rely on addr since it is not initialized if failed */ + if (accepted_fd == LIBUS_SOCKET_ERROR) { + return LIBUS_SOCKET_ERROR; + } + + internal_finalize_bsd_addr(addr); + + return bsd_set_nonblocking(apple_no_sigpipe(accepted_fd)); +} + +int bsd_recv(LIBUS_SOCKET_DESCRIPTOR fd, void *buf, int length, int flags) { + return recv(fd, buf, length, flags); +} + +int bsd_send(LIBUS_SOCKET_DESCRIPTOR fd, const char *buf, int length, int msg_more) { + + // MSG_MORE (Linux), MSG_PARTIAL (Windows), TCP_NOPUSH (BSD) + +#ifndef MSG_NOSIGNAL +#define MSG_NOSIGNAL 0 +#endif + +#ifdef MSG_MORE + + // for Linux we do not want signals + return send(fd, buf, length, (msg_more * MSG_MORE) | MSG_NOSIGNAL); + +#else + + // use TCP_NOPUSH + + return send(fd, buf, length, MSG_NOSIGNAL); + +#endif +} + +int bsd_would_block() { +#ifdef _WIN32 + return WSAGetLastError() == WSAEWOULDBLOCK; +#else + return errno == EWOULDBLOCK;// || errno == EAGAIN; +#endif +} + +// return LIBUS_SOCKET_ERROR or the fd that represents listen socket +// listen both on ipv6 and ipv4 +LIBUS_SOCKET_DESCRIPTOR bsd_create_listen_socket(const char *host, int port, int options) { + struct addrinfo hints, *result; + memset(&hints, 0, sizeof(struct addrinfo)); + + hints.ai_flags = AI_PASSIVE; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + + char port_string[16]; + snprintf(port_string, 16, "%d", port); + + if (getaddrinfo(host, port_string, &hints, &result)) { + return LIBUS_SOCKET_ERROR; + } + + LIBUS_SOCKET_DESCRIPTOR listenFd = LIBUS_SOCKET_ERROR; + struct addrinfo *listenAddr; + for (struct addrinfo *a = result; a && listenFd == LIBUS_SOCKET_ERROR; a = a->ai_next) { + if (a->ai_family == AF_INET6) { + listenFd = bsd_create_socket(a->ai_family, a->ai_socktype, a->ai_protocol); + listenAddr = a; + } + } + + for (struct addrinfo *a = result; a && listenFd == LIBUS_SOCKET_ERROR; a = a->ai_next) { + if (a->ai_family == AF_INET) { + listenFd = bsd_create_socket(a->ai_family, a->ai_socktype, a->ai_protocol); + listenAddr = a; + } + } + + if (listenFd == LIBUS_SOCKET_ERROR) { + freeaddrinfo(result); + return LIBUS_SOCKET_ERROR; + } + + if (port != 0) { + /* Otherwise, always enable SO_REUSEPORT and SO_REUSEADDR _unless_ options specify otherwise */ +#if /*defined(__linux) &&*/ defined(SO_REUSEPORT) + if (!(options & LIBUS_LISTEN_EXCLUSIVE_PORT)) { + int optval = 1; + setsockopt(listenFd, SOL_SOCKET, SO_REUSEPORT, &optval, sizeof(optval)); + } +#endif + int enabled = 1; + setsockopt(listenFd, SOL_SOCKET, SO_REUSEADDR, (SETSOCKOPT_PTR_TYPE) &enabled, sizeof(enabled)); + } + +#ifdef IPV6_V6ONLY + int disabled = 0; + setsockopt(listenFd, IPPROTO_IPV6, IPV6_V6ONLY, (SETSOCKOPT_PTR_TYPE) &disabled, sizeof(disabled)); +#endif + + if (bind(listenFd, listenAddr->ai_addr, (socklen_t) listenAddr->ai_addrlen) || listen(listenFd, 512)) { + bsd_close_socket(listenFd); + freeaddrinfo(result); + return LIBUS_SOCKET_ERROR; + } + + freeaddrinfo(result); + return listenFd; +} + +LIBUS_SOCKET_DESCRIPTOR bsd_create_connect_socket(const char *host, int port, const char *source_host, int options) { + struct addrinfo hints, *result; + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + + char port_string[16]; + snprintf(port_string, 16, "%d", port); + + if (getaddrinfo(host, port_string, &hints, &result) != 0) { + return LIBUS_SOCKET_ERROR; + } + + LIBUS_SOCKET_DESCRIPTOR fd = bsd_create_socket(result->ai_family, result->ai_socktype, result->ai_protocol); + if (fd == LIBUS_SOCKET_ERROR) { + freeaddrinfo(result); + return LIBUS_SOCKET_ERROR; + } + + if (source_host) { + struct addrinfo *interface_result; + if (!getaddrinfo(source_host, NULL, NULL, &interface_result)) { + int ret = bind(fd, interface_result->ai_addr, (socklen_t) interface_result->ai_addrlen); + freeaddrinfo(interface_result); + if (ret == LIBUS_SOCKET_ERROR) { + return LIBUS_SOCKET_ERROR; + } + } + } + + connect(fd, result->ai_addr, (socklen_t) result->ai_addrlen); + freeaddrinfo(result); + + return fd; +} diff --git a/src/Vendor/uwebsockets/src/context.c b/src/Vendor/uwebsockets/src/context.c index e63cc0d..860101b 100644 --- a/src/Vendor/uwebsockets/src/context.c +++ b/src/Vendor/uwebsockets/src/context.c @@ -26,6 +26,10 @@ int default_ignore_data_handler(struct us_socket_t *s) { /* Shared with SSL */ +unsigned short us_socket_context_timestamp(int ssl, struct us_socket_context_t *context) { + return context->timestamp; +} + void us_listen_socket_close(int ssl, struct us_listen_socket_t *ls) { /* us_listen_socket_t extends us_socket_t so we close in similar ways */ if (!us_socket_is_closed(0, &ls->s)) { @@ -80,60 +84,92 @@ struct us_loop_t *us_socket_context_loop(int ssl, struct us_socket_context_t *co return context->loop; } -/* Returns the deep copy, to be freed */ -const char *deep_str_copy(const char *src) { - if (!src) { - return src; - } - size_t len = strlen(src) + 1; - char *dst = malloc(len); - memcpy(dst, src, len); - return dst; -} - /* Not shared with SSL */ -struct us_socket_context_t *us_create_socket_context(int ssl, struct us_loop_t *loop, int context_ext_size, struct us_socket_context_options_t options) { - /* For ease of use we copy all passed strings here */ - options.ca_file_name = deep_str_copy(options.ca_file_name); - options.cert_file_name = deep_str_copy(options.cert_file_name); - options.dh_params_file_name = deep_str_copy(options.dh_params_file_name); - options.key_file_name = deep_str_copy(options.key_file_name); - options.passphrase = deep_str_copy(options.passphrase); - +/* Add SNI context */ +void us_socket_context_add_server_name(int ssl, struct us_socket_context_t *context, const char *hostname_pattern, struct us_socket_context_options_t options) { #ifndef LIBUS_NO_SSL if (ssl) { + us_internal_ssl_socket_context_add_server_name((struct us_internal_ssl_socket_context_t *) context, hostname_pattern, options); + } +#endif +} + +/* Remove SNI context */ +void us_socket_context_remove_server_name(int ssl, struct us_socket_context_t *context, const char *hostname_pattern) { +#ifndef LIBUS_NO_SSL + if (ssl) { + us_internal_ssl_socket_context_remove_server_name((struct us_internal_ssl_socket_context_t *) context, hostname_pattern); + } +#endif +} + +/* I don't like this one - maybe rename it to on_missing_server_name? */ + +/* Called when SNI matching fails - not if a match could be made. + * You may modify the context by adding/removing names in this callback. + * If the correct name is added immediately in the callback, it will be used */ +void us_socket_context_on_server_name(int ssl, struct us_socket_context_t *context, void (*cb)(struct us_socket_context_t *, const char *hostname)) { +#ifndef LIBUS_NO_SSL + if (ssl) { + us_internal_ssl_socket_context_on_server_name((struct us_internal_ssl_socket_context_t *) context, (void (*)(struct us_internal_ssl_socket_context_t *, const char *hostname)) cb); + } +#endif +} + +/* Todo: get native context from SNI pattern */ + +void *us_socket_context_get_native_handle(int ssl, struct us_socket_context_t *context) { +#ifndef LIBUS_NO_SSL + if (ssl) { + return us_internal_ssl_socket_context_get_native_handle((struct us_internal_ssl_socket_context_t *) context); + } +#endif + + /* There is no native handle for a non-SSL socket context */ + return 0; +} + +/* Options is currently only applicable for SSL - this will change with time (prefer_low_memory is one example) */ +struct us_socket_context_t *us_create_socket_context(int ssl, struct us_loop_t *loop, int context_ext_size, struct us_socket_context_options_t options) { +#ifndef LIBUS_NO_SSL + if (ssl) { + /* This function will call us, again, with SSL = false and a bigger ext_size */ return (struct us_socket_context_t *) us_internal_create_ssl_socket_context(loop, context_ext_size, options); } #endif + /* This path is taken once either way - always BEFORE whatever SSL may do LATER. + * context_ext_size will however be modified larger in case of SSL, to hold SSL extensions */ + struct us_socket_context_t *context = malloc(sizeof(struct us_socket_context_t) + context_ext_size); context->loop = loop; context->head = 0; context->iterator = 0; context->next = 0; context->ignore_data = default_ignore_data_handler; - context->options = options; + + /* Begin at 0 */ + context->timestamp = 0; us_internal_loop_link(loop, context); + + /* If we are called from within SSL code, SSL code will make further changes to us */ return context; } void us_socket_context_free(int ssl, struct us_socket_context_t *context) { - /* We also simply free every copied string here */ - free((void *) context->options.ca_file_name); - free((void *) context->options.cert_file_name); - free((void *) context->options.dh_params_file_name); - free((void *) context->options.key_file_name); - free((void *) context->options.passphrase); - #ifndef LIBUS_NO_SSL if (ssl) { + /* This function will call us again with SSL=false */ us_internal_ssl_socket_context_free((struct us_internal_ssl_socket_context_t *) context); return; } #endif + /* This path is taken once either way - always AFTER whatever SSL may do BEFORE. + * This is the opposite order compared to when creating the context - SSL code is cleaning up before non-SSL */ + us_internal_loop_unlink(context->loop, context); free(context); } @@ -167,14 +203,14 @@ struct us_listen_socket_t *us_socket_context_listen(int ssl, struct us_socket_co return ls; } -struct us_socket_t *us_socket_context_connect(int ssl, struct us_socket_context_t *context, const char *host, int port, int options, int socket_ext_size) { +struct us_socket_t *us_socket_context_connect(int ssl, struct us_socket_context_t *context, const char *host, int port, const char *source_host, int options, int socket_ext_size) { #ifndef LIBUS_NO_SSL if (ssl) { - return (struct us_socket_t *) us_internal_ssl_socket_context_connect((struct us_internal_ssl_socket_context_t *) context, host, port, options, socket_ext_size); + return (struct us_socket_t *) us_internal_ssl_socket_context_connect((struct us_internal_ssl_socket_context_t *) context, host, port, source_host, options, socket_ext_size); } #endif - LIBUS_SOCKET_DESCRIPTOR connect_socket_fd = bsd_create_connect_socket(host, port, options); + LIBUS_SOCKET_DESCRIPTOR connect_socket_fd = bsd_create_connect_socket(host, port, source_host, options); if (connect_socket_fd == LIBUS_SOCKET_ERROR) { return 0; } @@ -239,10 +275,10 @@ void us_socket_context_on_open(int ssl, struct us_socket_context_t *context, str context->on_open = on_open; } -void us_socket_context_on_close(int ssl, struct us_socket_context_t *context, struct us_socket_t *(*on_close)(struct us_socket_t *s)) { +void us_socket_context_on_close(int ssl, struct us_socket_context_t *context, struct us_socket_t *(*on_close)(struct us_socket_t *s, int code, void *reason)) { #ifndef LIBUS_NO_SSL if (ssl) { - us_internal_ssl_socket_context_on_close((struct us_internal_ssl_socket_context_t *) context, (struct us_internal_ssl_socket_t * (*)(struct us_internal_ssl_socket_t *)) on_close); + us_internal_ssl_socket_context_on_close((struct us_internal_ssl_socket_context_t *) context, (struct us_internal_ssl_socket_t * (*)(struct us_internal_ssl_socket_t *, int code, void *reason)) on_close); return; } #endif @@ -294,6 +330,17 @@ void us_socket_context_on_end(int ssl, struct us_socket_context_t *context, stru context->on_end = on_end; } +void us_socket_context_on_connect_error(int ssl, struct us_socket_context_t *context, struct us_socket_t *(*on_connect_error)(struct us_socket_t *s, int code)) { +#ifndef LIBUS_NO_SSL + if (ssl) { + us_internal_ssl_socket_context_on_connect_error((struct us_internal_ssl_socket_context_t *) context, (struct us_internal_ssl_socket_t * (*)(struct us_internal_ssl_socket_t *, int)) on_connect_error); + return; + } +#endif + + context->on_connect_error = on_connect_error; +} + void *us_socket_context_ext(int ssl, struct us_socket_context_t *context) { #ifndef LIBUS_NO_SSL if (ssl) { diff --git a/src/Vendor/uwebsockets/src/crypto/openssl.c b/src/Vendor/uwebsockets/src/crypto/openssl.c index 2c9133e..bbeb024 100644 --- a/src/Vendor/uwebsockets/src/crypto/openssl.c +++ b/src/Vendor/uwebsockets/src/crypto/openssl.c @@ -17,8 +17,16 @@ #ifdef LIBUS_USE_OPENSSL +/* These are in sni_tree.cpp */ +void *sni_new(); +void sni_free(void *sni, void(*cb)(void *)); +int sni_add(void *sni, const char *hostname, void *user); +void *sni_remove(void *sni, const char *hostname); +void *sni_find(void *sni, const char *hostname); + #include "libusockets.h" #include "internal/internal.h" +#include /* This module contains the entire OpenSSL implementation * of the SSL socket and socket context interfaces. */ @@ -59,10 +67,17 @@ struct us_internal_ssl_socket_context_t { SSL_CTX *ssl_context; int is_parent; - // här måste det vara! + /* These decorate the base implementation */ struct us_internal_ssl_socket_t *(*on_open)(struct us_internal_ssl_socket_t *, int is_client, char *ip, int ip_length); struct us_internal_ssl_socket_t *(*on_data)(struct us_internal_ssl_socket_t *, char *data, int length); - struct us_internal_ssl_socket_t *(*on_close)(struct us_internal_ssl_socket_t *); + struct us_internal_ssl_socket_t *(*on_writable)(struct us_internal_ssl_socket_t *); + struct us_internal_ssl_socket_t *(*on_close)(struct us_internal_ssl_socket_t *, int code, void *reason); + + /* Called for missing SNI hostnames, if not NULL */ + void (*on_server_name)(struct us_internal_ssl_socket_context_t *, const char *hostname); + + /* Pointer to sni tree, created when the context is created and freed likewise when freed */ + void *sni; }; // same here, should or shouldn't it contain s? @@ -70,6 +85,7 @@ struct us_internal_ssl_socket_t { struct us_socket_t s; SSL *ssl; int ssl_write_wants_read; // we use this for now + int ssl_read_wants_write; }; int passphrase_cb(char *buf, int size, int rwflag, void *u) { @@ -103,7 +119,7 @@ int BIO_s_custom_write(BIO *bio, const char *data, int length) { int written = us_socket_write(0, loop_ssl_data->ssl_socket, data, length, loop_ssl_data->last_write_was_msg_more); if (!written) { - BIO_set_flags(bio, BIO_get_flags(bio) | BIO_FLAGS_SHOULD_RETRY | BIO_FLAGS_WRITE); + BIO_set_flags(bio, BIO_FLAGS_SHOULD_RETRY | BIO_FLAGS_WRITE); return -1; } @@ -118,7 +134,7 @@ int BIO_s_custom_read(BIO *bio, char *dst, int length) { //printf("BIO_s_custom_read\n"); if (!loop_ssl_data->ssl_read_input_length) { - BIO_set_flags(bio, BIO_get_flags(bio) | BIO_FLAGS_SHOULD_RETRY | BIO_FLAGS_READ); + BIO_set_flags(bio, BIO_FLAGS_SHOULD_RETRY | BIO_FLAGS_READ); return -1; } @@ -141,6 +157,7 @@ struct us_internal_ssl_socket_t *ssl_on_open(struct us_internal_ssl_socket_t *s, s->ssl = SSL_new(context->ssl_context); s->ssl_write_wants_read = 0; + s->ssl_read_wants_write = 0; SSL_set_bio(s->ssl, loop_ssl_data->shared_rbio, loop_ssl_data->shared_wbio); BIO_up_ref(loop_ssl_data->shared_rbio); @@ -155,19 +172,26 @@ struct us_internal_ssl_socket_t *ssl_on_open(struct us_internal_ssl_socket_t *s, return (struct us_internal_ssl_socket_t *) context->on_open(s, is_client, ip, ip_length); } -struct us_internal_ssl_socket_t *ssl_on_close(struct us_internal_ssl_socket_t *s) { +/* This one is a helper; it is entirely shared with non-SSL so can be removed */ +struct us_internal_ssl_socket_t *us_internal_ssl_socket_close(struct us_internal_ssl_socket_t *s, int code, void *reason) { + return (struct us_internal_ssl_socket_t *) us_socket_close(0, (struct us_socket_t *) s, code, reason); +} + +struct us_internal_ssl_socket_t *ssl_on_close(struct us_internal_ssl_socket_t *s, int code, void *reason) { struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s); SSL_free(s->ssl); - return context->on_close(s); + return context->on_close(s, code, reason); } struct us_internal_ssl_socket_t *ssl_on_end(struct us_internal_ssl_socket_t *s) { - struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s); + // struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s); // whatever state we are in, a TCP FIN is always an answered shutdown - return us_internal_ssl_socket_close(s); + + /* Todo: this should report CLEANLY SHUTDOWN as reason */ + return us_internal_ssl_socket_close(s, 0, NULL); } // this whole function needs a complete clean-up @@ -192,7 +216,8 @@ struct us_internal_ssl_socket_t *ssl_on_data(struct us_internal_ssl_socket_t *s, // two phase shutdown is complete here //printf("Two step SSL shutdown complete\n"); - return us_internal_ssl_socket_close(s); + /* Todo: this should also report some kind of clean shutdown */ + return us_internal_ssl_socket_close(s, 0, NULL); } else if (ret < 0) { int err = SSL_get_error(s->ssl, ret); @@ -226,13 +251,18 @@ struct us_internal_ssl_socket_t *ssl_on_data(struct us_internal_ssl_socket_t *s, } // terminate connection here - return us_internal_ssl_socket_close(s); + return us_internal_ssl_socket_close(s, 0, NULL); } else { // emit the data we have and exit + if (err == SSL_ERROR_WANT_WRITE) { + // here we need to trigger writable event next ssl_read! + s->ssl_read_wants_write = 1; + } + // assume we emptied the input buffer fully or error here as well! if (loop_ssl_data->ssl_read_input_length) { - return us_internal_ssl_socket_close(s); + return us_internal_ssl_socket_close(s, 0, NULL); } // cannot emit zero length to app @@ -240,6 +270,8 @@ struct us_internal_ssl_socket_t *ssl_on_data(struct us_internal_ssl_socket_t *s, break; } + context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s); + s = context->on_data(s, loop_ssl_data->ssl_read_output + LIBUS_RECV_BUFFER_PADDING, read); if (us_socket_is_closed(0, &s->s)) { return s; @@ -255,6 +287,8 @@ struct us_internal_ssl_socket_t *ssl_on_data(struct us_internal_ssl_socket_t *s, // at this point we might be full and need to emit the data to application and start over if (read == LIBUS_RECV_BUFFER_LENGTH) { + context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s); + // emit data and restart s = context->on_data(s, loop_ssl_data->ssl_read_output + LIBUS_RECV_BUFFER_PADDING, read); if (us_socket_is_closed(0, &s->s)) { @@ -287,7 +321,7 @@ struct us_internal_ssl_socket_t *ssl_on_data(struct us_internal_ssl_socket_t *s, //exit(-2); // not correct anyways! - s = us_internal_ssl_socket_close(s); + s = us_internal_ssl_socket_close(s, 0, NULL); //us_ } @@ -295,6 +329,28 @@ struct us_internal_ssl_socket_t *ssl_on_data(struct us_internal_ssl_socket_t *s, return s; } +struct us_internal_ssl_socket_t *ssl_on_writable(struct us_internal_ssl_socket_t *s) { + + struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s); + + // todo: cork here so that we efficiently output both from reading and from writing? + + if (s->ssl_read_wants_write) { + s->ssl_read_wants_write = 0; + + // make sure to update context before we call (context can change if the user adopts the socket!) + context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s); + + // if this one fails to write data, it sets ssl_read_wants_write again + s = (struct us_internal_ssl_socket_t *) context->sc.on_data(&s->s, 0, 0); // cast here! + } + + // should this one come before we have read? should it come always? spurious on_writable is okay + s = context->on_writable(s); + + return s; +} + /* Lazily inits loop ssl data first time */ void us_internal_init_loop_ssl_data(struct us_loop_t *loop) { if (!loop->data.ssl_data) { @@ -371,60 +427,77 @@ int ssl_ignore_data(struct us_internal_ssl_socket_t *s) { } /* Per-context functions */ -struct us_internal_ssl_socket_context_t *us_internal_create_child_ssl_socket_context(struct us_internal_ssl_socket_context_t *context, int context_ext_size) { - struct us_socket_context_options_t options = {0}; +void *us_internal_ssl_socket_context_get_native_handle(struct us_internal_ssl_socket_context_t *context) { + return context->ssl_context; +} +struct us_internal_ssl_socket_context_t *us_internal_create_child_ssl_socket_context(struct us_internal_ssl_socket_context_t *context, int context_ext_size) { + /* Create a new non-SSL context */ + struct us_socket_context_options_t options = {0}; struct us_internal_ssl_socket_context_t *child_context = (struct us_internal_ssl_socket_context_t *) us_create_socket_context(0, context->sc.loop, sizeof(struct us_internal_ssl_socket_context_t) + context_ext_size, options); - // I think this is the only thing being shared + /* The only thing we share is SSL_CTX */ child_context->ssl_context = context->ssl_context; child_context->is_parent = 0; return child_context; } -struct us_internal_ssl_socket_context_t *us_internal_create_ssl_socket_context(struct us_loop_t *loop, int context_ext_size, struct us_socket_context_options_t options) { +/* Common function for creating a context from options. + * We must NOT free a SSL_CTX with only SSL_CTX_free! Also free any password */ +void free_ssl_context(SSL_CTX *ssl_context) { + if (!ssl_context) { + return; + } - us_internal_init_loop_ssl_data(loop); + /* If we have set a password string, free it here */ + void *password = SSL_CTX_get_default_passwd_cb_userdata(ssl_context); + /* OpenSSL returns NULL if we have no set password */ + free(password); - struct us_socket_context_options_t no_options = {0}; + SSL_CTX_free(ssl_context); +} - struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_create_socket_context(0, loop, sizeof(struct us_internal_ssl_socket_context_t) + context_ext_size, no_options); +/* This function should take any options and return SSL_CTX - which has to be free'd with + * our destructor function - free_ssl_context() */ +SSL_CTX *create_ssl_context_from_options(struct us_socket_context_options_t options) { + /* Create the context */ + SSL_CTX *ssl_context = SSL_CTX_new(TLS_method()); - context->ssl_context = SSL_CTX_new(TLS_method()); - context->is_parent = 1; - // only parent ssl contexts may need to ignore data - context->sc.ignore_data = (int (*)(struct us_socket_t *)) ssl_ignore_data; + /* Default options we rely on - changing these will break our logic */ + SSL_CTX_set_read_ahead(ssl_context, 1); + SSL_CTX_set_mode(ssl_context, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); - // options - SSL_CTX_set_read_ahead(context->ssl_context, 1); - SSL_CTX_set_mode(context->ssl_context, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); - //SSL_CTX_set_mode(context->ssl_context, SSL_MODE_ENABLE_PARTIAL_WRITE); + /* Anything below TLS 1.2 is disabled */ + SSL_CTX_set_min_proto_version(ssl_context, TLS1_2_VERSION); - // this lowers performance a bit in benchmarks + /* The following are helpers. You may easily implement whatever you want by using the native handle directly */ + + /* Important option for lowering memory usage, but lowers performance slightly */ if (options.ssl_prefer_low_memory_usage) { - SSL_CTX_set_mode(context->ssl_context, SSL_MODE_RELEASE_BUFFERS); + SSL_CTX_set_mode(ssl_context, SSL_MODE_RELEASE_BUFFERS); } - //SSL_CTX_set_mode(context->ssl_context, SSL_MODE_RELEASE_BUFFERS); - SSL_CTX_set_options(context->ssl_context, SSL_OP_NO_SSLv3); - SSL_CTX_set_options(context->ssl_context, SSL_OP_NO_TLSv1); - - // these are going to be extended if (options.passphrase) { - SSL_CTX_set_default_passwd_cb_userdata(context->ssl_context, (void *) options.passphrase); - SSL_CTX_set_default_passwd_cb(context->ssl_context, passphrase_cb); + /* When freeing the CTX we need to check SSL_CTX_get_default_passwd_cb_userdata and + * free it if set */ + SSL_CTX_set_default_passwd_cb_userdata(ssl_context, (void *) strdup(options.passphrase)); + SSL_CTX_set_default_passwd_cb(ssl_context, passphrase_cb); } + /* This one most probably do not need the cert_file_name string to be kept alive */ if (options.cert_file_name) { - if (SSL_CTX_use_certificate_chain_file(context->ssl_context, options.cert_file_name) != 1) { - return 0; + if (SSL_CTX_use_certificate_chain_file(ssl_context, options.cert_file_name) != 1) { + free_ssl_context(ssl_context); + return NULL; } } + /* Same as above - we can discard this string afterwards I suppose */ if (options.key_file_name) { - if (SSL_CTX_use_PrivateKey_file(context->ssl_context, options.key_file_name, SSL_FILETYPE_PEM) != 1) { - return 0; + if (SSL_CTX_use_PrivateKey_file(ssl_context, options.key_file_name, SSL_FILETYPE_PEM) != 1) { + free_ssl_context(ssl_context); + return NULL; } } @@ -432,13 +505,15 @@ struct us_internal_ssl_socket_context_t *us_internal_create_ssl_socket_context(s STACK_OF(X509_NAME) *ca_list; ca_list = SSL_load_client_CA_file(options.ca_file_name); if(ca_list == NULL) { - return 0; + free_ssl_context(ssl_context); + return NULL; } - SSL_CTX_set_client_CA_list(context->ssl_context, ca_list); - if (SSL_CTX_load_verify_locations(context->ssl_context, options.ca_file_name, NULL) != 1) { - return 0; + SSL_CTX_set_client_CA_list(ssl_context, ca_list); + if (SSL_CTX_load_verify_locations(ssl_context, options.ca_file_name, NULL) != 1) { + free_ssl_context(ssl_context); + return NULL; } - SSL_CTX_set_verify(context->ssl_context, SSL_VERIFY_PEER, NULL); + SSL_CTX_set_verify(ssl_context, SSL_VERIFY_PEER, NULL); } if (options.dh_params_file_name) { @@ -451,29 +526,155 @@ struct us_internal_ssl_socket_context_t *us_internal_create_ssl_socket_context(s dh_2048 = PEM_read_DHparams(paramfile, NULL, NULL, NULL); fclose(paramfile); } else { - return 0; + free_ssl_context(ssl_context); + return NULL; } if (dh_2048 == NULL) { - return 0; + free_ssl_context(ssl_context); + return NULL; } - if (SSL_CTX_set_tmp_dh(context->ssl_context, dh_2048) != 1) { - return 0; + const long set_tmp_dh = SSL_CTX_set_tmp_dh(ssl_context, dh_2048); + DH_free(dh_2048); + + if (set_tmp_dh != 1) { + free_ssl_context(ssl_context); + return NULL; } /* OWASP Cipher String 'A+' (https://www.owasp.org/index.php/TLS_Cipher_String_Cheat_Sheet) */ - if (SSL_CTX_set_cipher_list(context->ssl_context, "DHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES128-GCM-SHA256:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES128-GCM-SHA256") != 1) { - return 0; + if (SSL_CTX_set_cipher_list(ssl_context, "DHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES128-GCM-SHA256:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES128-GCM-SHA256") != 1) { + free_ssl_context(ssl_context); + return NULL; } } + /* This must be free'd with free_ssl_context, not SSL_CTX_free */ + return ssl_context; +} + +/* Todo: return error on failure? */ +void us_internal_ssl_socket_context_add_server_name(struct us_internal_ssl_socket_context_t *context, const char *hostname_pattern, struct us_socket_context_options_t options) { + + /* Try and construct an SSL_CTX from options */ + SSL_CTX *ssl_context = create_ssl_context_from_options(options); + + /* We do not want to hold any nullptr's in our SNI tree */ + if (ssl_context) { + if (sni_add(context->sni, hostname_pattern, ssl_context)) { + /* If we already had that name, ignore */ + free_ssl_context(ssl_context); + } + } +} + +void us_internal_ssl_socket_context_on_server_name(struct us_internal_ssl_socket_context_t *context, void (*cb)(struct us_internal_ssl_socket_context_t *, const char *hostname)) { + context->on_server_name = cb; +} + +void us_internal_ssl_socket_context_remove_server_name(struct us_internal_ssl_socket_context_t *context, const char *hostname_pattern) { + + /* The same thing must happen for sni_free, that's why we have a callback */ + SSL_CTX *sni_node_ssl_context = (SSL_CTX *) sni_remove(context->sni, hostname_pattern); + free_ssl_context(sni_node_ssl_context); +} + +/* Returns NULL or SSL_CTX. May call missing server name callback */ +SSL_CTX *resolve_context(struct us_internal_ssl_socket_context_t *context, const char *hostname) { + + /* Try once first */ + void *user = sni_find(context->sni, hostname); + if (!user) { + /* Emit missing hostname then try again */ + if (!context->on_server_name) { + /* We have no callback registered, so fail */ + return NULL; + } + + context->on_server_name(context, hostname); + + /* Last try */ + user = sni_find(context->sni, hostname); + } + + return user; +} + +// arg is context +int sni_cb(SSL *ssl, int *al, void *arg) { + + if (ssl) { + const char *hostname = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); + if (hostname && hostname[0]) { + /* Try and resolve (match) required hostname with what we have registered */ + SSL_CTX *resolved_ssl_context = resolve_context((struct us_internal_ssl_socket_context_t *) arg, hostname); + if (resolved_ssl_context) { + //printf("Did find matching SNI context for hostname: <%s>!\n", hostname); + SSL_set_SSL_CTX(ssl, resolved_ssl_context); + } else { + /* Call a blocking callback notifying of missing context */ + } + + } + + return SSL_TLSEXT_ERR_OK; + } + + /* Can we even come here ever? */ + return SSL_TLSEXT_ERR_NOACK; +} + +struct us_internal_ssl_socket_context_t *us_internal_create_ssl_socket_context(struct us_loop_t *loop, int context_ext_size, struct us_socket_context_options_t options) { + /* If we haven't initialized the loop data yet, do so . + * This is needed because loop data holds shared OpenSSL data and + * the function is also responsible for initializing OpenSSL */ + us_internal_init_loop_ssl_data(loop); + + /* First of all we try and create the SSL context from options */ + SSL_CTX *ssl_context = create_ssl_context_from_options(options); + if (!ssl_context) { + /* We simply fail early if we cannot even create the OpenSSL context */ + return NULL; + } + + /* Otherwise ee continue by creating a non-SSL context, but with larger ext to hold our SSL stuff */ + struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_create_socket_context(0, loop, sizeof(struct us_internal_ssl_socket_context_t) + context_ext_size, options); + + /* I guess this is the only optional callback */ + context->on_server_name = NULL; + + /* Then we extend its SSL parts */ + context->ssl_context = ssl_context;//create_ssl_context_from_options(options); + context->is_parent = 1; + + /* We, as parent context, may ignore data */ + context->sc.ignore_data = (int (*)(struct us_socket_t *)) ssl_ignore_data; + + /* Parent contexts may use SNI */ + SSL_CTX_set_tlsext_servername_callback(context->ssl_context, sni_cb); + SSL_CTX_set_tlsext_servername_arg(context->ssl_context, context); + + /* Also create the SNI tree */ + context->sni = sni_new(); + return context; } +/* Our destructor for hostnames, used below */ +void sni_hostname_destructor(void *user) { + /* Some nodes hold null, so this one must ignore this case */ + free_ssl_context((SSL_CTX *) user); +} + void us_internal_ssl_socket_context_free(struct us_internal_ssl_socket_context_t *context) { + /* If we are parent then we need to free our OpenSSL context */ if (context->is_parent) { - SSL_CTX_free(context->ssl_context); + free_ssl_context(context->ssl_context); + + /* Here we need to register a temporary callback for all still-existing hostnames + * and their contexts. Only parents have an SNI tree */ + sni_free(context->sni, sni_hostname_destructor); } us_socket_context_free(0, &context->sc); @@ -483,8 +684,8 @@ struct us_listen_socket_t *us_internal_ssl_socket_context_listen(struct us_inter return us_socket_context_listen(0, &context->sc, host, port, options, sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t) + socket_ext_size); } -struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_connect(struct us_internal_ssl_socket_context_t *context, const char *host, int port, int options, int socket_ext_size) { - return (struct us_internal_ssl_socket_t *) us_socket_context_connect(0, &context->sc, host, port, options, sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t) + socket_ext_size); +struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_connect(struct us_internal_ssl_socket_context_t *context, const char *host, int port, const char *source_host, int options, int socket_ext_size) { + return (struct us_internal_ssl_socket_t *) us_socket_context_connect(0, &context->sc, host, port, source_host, options, sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t) + socket_ext_size); } void us_internal_ssl_socket_context_on_open(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *(*on_open)(struct us_internal_ssl_socket_t *s, int is_client, char *ip, int ip_length)) { @@ -492,8 +693,8 @@ void us_internal_ssl_socket_context_on_open(struct us_internal_ssl_socket_contex context->on_open = on_open; } -void us_internal_ssl_socket_context_on_close(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *(*on_close)(struct us_internal_ssl_socket_t *s)) { - us_socket_context_on_close(0, (struct us_socket_context_t *) context, (struct us_socket_t *(*)(struct us_socket_t *)) ssl_on_close); +void us_internal_ssl_socket_context_on_close(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *(*on_close)(struct us_internal_ssl_socket_t *s, int code, void *reason)) { + us_socket_context_on_close(0, (struct us_socket_context_t *) context, (struct us_socket_t *(*)(struct us_socket_t *, int, void *)) ssl_on_close); context->on_close = on_close; } @@ -503,22 +704,31 @@ void us_internal_ssl_socket_context_on_data(struct us_internal_ssl_socket_contex } void us_internal_ssl_socket_context_on_writable(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *(*on_writable)(struct us_internal_ssl_socket_t *s)) { - us_socket_context_on_writable(0, (struct us_socket_context_t *) context, (struct us_socket_t *(*)(struct us_socket_t *)) on_writable); + us_socket_context_on_writable(0, (struct us_socket_context_t *) context, (struct us_socket_t *(*)(struct us_socket_t *)) ssl_on_writable); + context->on_writable = on_writable; } void us_internal_ssl_socket_context_on_timeout(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *(*on_timeout)(struct us_internal_ssl_socket_t *s)) { us_socket_context_on_timeout(0, (struct us_socket_context_t *) context, (struct us_socket_t *(*)(struct us_socket_t *)) on_timeout); } +/* We do not really listen to passed FIN-handler, we entirely override it with our handler since SSL doesn't really have support for half-closed sockets */ void us_internal_ssl_socket_context_on_end(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *(*on_end)(struct us_internal_ssl_socket_t *)) { us_socket_context_on_end(0, (struct us_socket_context_t *) context, (struct us_socket_t *(*)(struct us_socket_t *)) ssl_on_end); } +void us_internal_ssl_socket_context_on_connect_error(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *(*on_connect_error)(struct us_internal_ssl_socket_t *, int code)) { + us_socket_context_on_connect_error(0, (struct us_socket_context_t *) context, (struct us_socket_t *(*)(struct us_socket_t *, int)) on_connect_error); +} + void *us_internal_ssl_socket_context_ext(struct us_internal_ssl_socket_context_t *context) { return context + 1; } /* Per socket functions */ +void *us_internal_ssl_socket_get_native_handle(struct us_internal_ssl_socket_t *s) { + return s->ssl; +} int us_internal_ssl_socket_write(struct us_internal_ssl_socket_t *s, const char *data, int length, int msg_more) { if (us_socket_is_closed(0, &s->s) || us_internal_ssl_socket_is_shut_down(s)) { @@ -617,10 +827,6 @@ void us_internal_ssl_socket_shutdown(struct us_internal_ssl_socket_t *s) { } } -struct us_internal_ssl_socket_t *us_internal_ssl_socket_close(struct us_internal_ssl_socket_t *s) { - return (struct us_internal_ssl_socket_t *) us_socket_close(0, (struct us_socket_t *) s); -} - struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_adopt_socket(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *s, int ext_size) { // todo: this is completely untested return (struct us_internal_ssl_socket_t *) us_socket_context_adopt_socket(0, &context->sc, &s->s, sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t) + ext_size); diff --git a/src/Vendor/uwebsockets/src/crypto/sni_tree.cpp b/src/Vendor/uwebsockets/src/crypto/sni_tree.cpp new file mode 100644 index 0000000..fbe0d38 --- /dev/null +++ b/src/Vendor/uwebsockets/src/crypto/sni_tree.cpp @@ -0,0 +1,218 @@ +/* + * Authored by Alex Hultman, 2018-2020. + * Intellectual property of third-party. + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* This Server Name Indication hostname tree is written in C++ but could be ported to C. + * Overall it looks like crap, but has no memory allocations in fast path and is O(log n). */ + +#ifndef SNI_TREE_H +#define SNI_TREE_H + +#ifndef LIBUS_NO_SSL + +#include +#include +#include +#include +#include +#include + +/* We only handle a maximum of 10 labels per hostname */ +#define MAX_LABELS 10 + +/* This cannot be shared */ +thread_local void (*sni_free_cb)(void *); + +struct sni_node { + /* Empty nodes must always hold null */ + void *user = nullptr; + std::map> children; + + ~sni_node() { + for (auto &p : children) { + /* The data of our string_views are managed by malloc */ + free((void *) p.first.data()); + + /* Call destructor passed to sni_free only if we hold data. + * This is important since sni_remove does not have sni_free_cb set */ + if (p.second.get()->user) { + sni_free_cb(p.second.get()->user); + } + } + } +}; + +// this can only delete ONE single node, but may cull "empty nodes with null as data" +void *removeUser(struct sni_node *root, unsigned int label, std::string_view *labels, unsigned int numLabels) { + + /* If we are in the bottom (past bottom by one), there is nothing to remove */ + if (label == numLabels) { + void *user = root->user; + /* Mark us for culling on the way up */ + root->user = nullptr; + return user; + } + + /* Is this label a child of root? */ + auto it = root->children.find(labels[label]); + if (it == root->children.end()) { + /* We cannot continue */ + return nullptr; + } + + void *removedUser = removeUser(it->second.get(), label + 1, labels, numLabels); + + /* On the way back up, we may cull empty nodes with no children. + * This ends up being where we remove all nodes */ + if (it->second.get()->children.empty() && it->second.get()->user == nullptr) { + + /* The data of our string_views are managed by malloc */ + free((void *) it->first.data()); + + /* This can only happen with user set to null, otherwise we use sni_free_cb which is unset by sni_remove */ + root->children.erase(it); + } + + return removedUser; +} + +void *getUser(struct sni_node *root, unsigned int label, std::string_view *labels, unsigned int numLabels) { + + /* Do we have labels to match? Otherwise, return where we stand */ + if (label == numLabels) { + return root->user; + } + + /* Try and match by our label */ + auto it = root->children.find(labels[label]); + if (it != root->children.end()) { + void *user = getUser(it->second.get(), label + 1, labels, numLabels); + if (user) { + return user; + } + } + + /* Try and match by wildcard */ + it = root->children.find("*"); + if (it == root->children.end()) { + /* Matching has failed for both label and wildcard */ + return nullptr; + } + + /* We matched by wildcard */ + return getUser(it->second.get(), label + 1, labels, numLabels); +} + +extern "C" { + + void *sni_new() { + return new sni_node; + } + + void sni_free(void *sni, void (*cb)(void *)) { + /* We want to run this callback for every remaining name */ + sni_free_cb = cb; + + delete (sni_node *) sni; + } + + /* Returns non-null if this name already exists */ + int sni_add(void *sni, const char *hostname, void *user) { + struct sni_node *root = (struct sni_node *) sni; + + /* Traverse all labels in hostname */ + for (std::string_view view(hostname, strlen(hostname)), label; + view.length(); view.remove_prefix(std::min(view.length(), label.length() + 1))) { + /* Label is the token separated by dot */ + label = view.substr(0, view.find('.', 0)); + + auto it = root->children.find(label); + if (it == root->children.end()) { + /* Duplicate this label for our kept string_view of it */ + void *labelString = malloc(label.length()); + memcpy(labelString, label.data(), label.length()); + + it = root->children.emplace(std::string_view((char *) labelString, label.length()), + std::make_unique()).first; + } + + root = it->second.get(); + } + + /* We must never add multiple contexts for the same name, as that would overwrite and leak */ + if (root->user) { + return 1; + } + + root->user = user; + + return 0; + } + + /* Removes the exact match. Wildcards are treated as the verbatim asterisk char, not as an actual wildcard */ + void *sni_remove(void *sni, const char *hostname) { + struct sni_node *root = (struct sni_node *) sni; + + /* I guess 10 labels is an okay limit */ + std::string_view labels[10]; + unsigned int numLabels = 0; + + /* We traverse all labels first of all */ + for (std::string_view view(hostname, strlen(hostname)), label; + view.length(); view.remove_prefix(std::min(view.length(), label.length() + 1))) { + /* Label is the token separated by dot */ + label = view.substr(0, view.find('.', 0)); + + /* Anything longer than 10 labels is forbidden */ + if (numLabels == 10) { + return nullptr; + } + + labels[numLabels++] = label; + } + + return removeUser(root, 0, labels, numLabels); + } + + void *sni_find(void *sni, const char *hostname) { + struct sni_node *root = (struct sni_node *) sni; + + /* I guess 10 labels is an okay limit */ + std::string_view labels[10]; + unsigned int numLabels = 0; + + /* We traverse all labels first of all */ + for (std::string_view view(hostname, strlen(hostname)), label; + view.length(); view.remove_prefix(std::min(view.length(), label.length() + 1))) { + /* Label is the token separated by dot */ + label = view.substr(0, view.find('.', 0)); + + /* Anything longer than 10 labels is forbidden */ + if (numLabels == 10) { + return nullptr; + } + + labels[numLabels++] = label; + } + + return getUser(root, 0, labels, numLabels); + } + +} + +#endif + +#endif \ No newline at end of file diff --git a/src/Vendor/uwebsockets/src/crypto/wolfssl.c b/src/Vendor/uwebsockets/src/crypto/wolfssl.c index 874d573..748bc91 100644 --- a/src/Vendor/uwebsockets/src/crypto/wolfssl.c +++ b/src/Vendor/uwebsockets/src/crypto/wolfssl.c @@ -369,7 +369,7 @@ struct us_internal_ssl_socket_context_t *us_internal_create_ssl_socket_context(s // options wolfSSL_CTX_set_read_ahead(context->ssl_context, 1); wolfSSL_CTX_set_mode(context->ssl_context, WOLFSSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); - + // this lowers performance a bit in benchmarks if (options.ssl_prefer_low_memory_usage) { //SSL_CTX_set_mode(context->ssl_context, SSL_MODE_RELEASE_BUFFERS); @@ -423,8 +423,8 @@ struct us_listen_socket_t *us_internal_ssl_socket_context_listen(struct us_inter return us_socket_context_listen(0, &context->sc, host, port, options, sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t) + socket_ext_size); } -struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_connect(struct us_internal_ssl_socket_context_t *context, const char *host, int port, int options, int socket_ext_size) { - return (struct us_internal_ssl_socket_t *) us_socket_context_connect(0, &context->sc, host, port, options, sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t) + socket_ext_size); +struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_connect(struct us_internal_ssl_socket_context_t *context, const char *host, int port, const char *source_host, int options, int socket_ext_size) { + return (struct us_internal_ssl_socket_t *) us_socket_context_connect(0, &context->sc, host, port, source_host, options, sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t) + socket_ext_size); } void us_internal_ssl_socket_context_on_open(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *(*on_open)(struct us_internal_ssl_socket_t *s, int is_client, char *ip, int ip_length)) { diff --git a/src/Vendor/uwebsockets/src/eventing/epoll_kqueue.c b/src/Vendor/uwebsockets/src/eventing/epoll_kqueue.c index 7861662..671db7e 100644 --- a/src/Vendor/uwebsockets/src/eventing/epoll_kqueue.c +++ b/src/Vendor/uwebsockets/src/eventing/epoll_kqueue.c @@ -21,6 +21,9 @@ #if defined(LIBUS_USE_EPOLL) || defined(LIBUS_USE_KQUEUE) +/* Cannot include this one on Windows */ +#include + #ifdef LIBUS_USE_EPOLL #define GET_READY_POLL(loop, index) (struct us_poll_t *) loop->ready_polls[index].data.ptr #define SET_READY_POLL(loop, index, poll) loop->ready_polls[index].data.ptr = poll @@ -329,8 +332,8 @@ void us_timer_set(struct us_timer_t *t, void (*cb)(struct us_timer_t *t), int ms internal_cb->cb = (void (*)(struct us_internal_callback_t *)) cb; struct itimerspec timer_spec = { - {repeat_ms / 1000, ((long)repeat_ms * 1000000) % 1000000000}, - {ms / 1000, ((long)ms * 1000000) % 1000000000} + {repeat_ms / 1000, (long) (repeat_ms % 1000) * (long) 1000000}, + {ms / 1000, (long) (ms % 1000) * (long) 1000000} }; timerfd_settime(us_poll_fd((struct us_poll_t *) t), 0, &timer_spec, NULL); diff --git a/src/Vendor/uwebsockets/src/eventing/gcd.c b/src/Vendor/uwebsockets/src/eventing/gcd.c index 120eec4..4aa3fec 100644 --- a/src/Vendor/uwebsockets/src/eventing/gcd.c +++ b/src/Vendor/uwebsockets/src/eventing/gcd.c @@ -145,7 +145,7 @@ void *us_poll_ext(struct us_poll_t *p) { } unsigned int us_internal_accept_poll_event(struct us_poll_t *p) { - printf("us_internal_accept_poll_event\n"); + //printf("us_internal_accept_poll_event\n"); return 0; } diff --git a/src/Vendor/uwebsockets/src/eventing/libuv.c b/src/Vendor/uwebsockets/src/eventing/libuv.c index 615749c..765137f 100644 --- a/src/Vendor/uwebsockets/src/eventing/libuv.c +++ b/src/Vendor/uwebsockets/src/eventing/libuv.c @@ -1,5 +1,5 @@ /* - * Authored by Alex Hultman, 2018-2019. + * Authored by Alex Hultman, 2018-2021. * Intellectual property of third-party. * Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,9 +21,9 @@ #ifdef LIBUS_USE_LIBUV -// poll dispatch +/* uv_poll_t->data always (except for most times after calling us_poll_stop) points to the us_poll_t */ static void poll_cb(uv_poll_t *p, int status, int events) { - us_internal_dispatch_ready_poll((struct us_poll_t *) p, status < 0, events); + us_internal_dispatch_ready_poll((struct us_poll_t *) p->data, status < 0, events); } static void prepare_cb(uv_prepare_t *p) { @@ -37,10 +37,21 @@ static void check_cb(uv_check_t *p) { us_internal_loop_post(loop); } +/* Not used for polls, since polls need two frees */ static void close_cb_free(uv_handle_t *h) { free(h->data); } +/* This one is different for polls, since we need two frees here */ +static void close_cb_free_poll(uv_handle_t *h) { + /* It is only in case we called us_poll_stop then quickly us_poll_free that we enter this. + * Most of the time, actual freeing is done by us_poll_free. */ + if (h->data) { + free(h->data); + free(h); + } +} + static void timer_cb(uv_timer_t *t) { struct us_internal_callback_t *cb = t->data; cb->cb(cb); @@ -59,9 +70,14 @@ void us_poll_init(struct us_poll_t *p, LIBUS_SOCKET_DESCRIPTOR fd, int poll_type } void us_poll_free(struct us_poll_t *p, struct us_loop_t *loop) { - if (uv_is_closing((uv_handle_t *) &p->uv_p)) { - p->uv_p.data = p; + /* The idea here is like so; in us_poll_stop we call uv_close after setting data of uv-poll to 0. + * This means that in close_cb_free we call free on 0 with does nothing, since us_poll_stop should + * not really free the poll. HOWEVER, if we then call us_poll_free while still closing the uv-poll, + * we simply change back the data to point to our structure so that we actually do free it like we should. */ + if (uv_is_closing((uv_handle_t *) p->uv_p)) { + p->uv_p->data = p; } else { + free(p->uv_p); free(p); } } @@ -69,24 +85,26 @@ void us_poll_free(struct us_poll_t *p, struct us_loop_t *loop) { void us_poll_start(struct us_poll_t *p, struct us_loop_t *loop, int events) { p->poll_type = us_internal_poll_type(p) | ((events & LIBUS_SOCKET_READABLE) ? POLL_TYPE_POLLING_IN : 0) | ((events & LIBUS_SOCKET_WRITABLE) ? POLL_TYPE_POLLING_OUT : 0); - uv_poll_init_socket(loop->uv_loop, &p->uv_p, p->fd); - uv_poll_start(&p->uv_p, events, poll_cb); + uv_poll_init_socket(loop->uv_loop, p->uv_p, p->fd); + uv_poll_start(p->uv_p, events, poll_cb); } void us_poll_change(struct us_poll_t *p, struct us_loop_t *loop, int events) { if (us_poll_events(p) != events) { p->poll_type = us_internal_poll_type(p) | ((events & LIBUS_SOCKET_READABLE) ? POLL_TYPE_POLLING_IN : 0) | ((events & LIBUS_SOCKET_WRITABLE) ? POLL_TYPE_POLLING_OUT : 0); - uv_poll_start(&p->uv_p, events, poll_cb); + uv_poll_start(p->uv_p, events, poll_cb); } } void us_poll_stop(struct us_poll_t *p, struct us_loop_t *loop) { - uv_poll_stop(&p->uv_p); + uv_poll_stop(p->uv_p); - // close but not free is needed here - p->uv_p.data = 0; - uv_close((uv_handle_t *) &p->uv_p, close_cb_free); // needed here + /* We normally only want to close the poll here, not free it. But if we stop it, then quickly "free" it with + * us_poll_free, we postpone the actual freeing to close_cb_free_poll whenever it triggers. + * That's why we set data to null here, so that us_poll_free can reset it if needed */ + p->uv_p->data = 0; + uv_close((uv_handle_t *) p->uv_p, close_cb_free_poll); } int us_poll_events(struct us_poll_t *p) { @@ -171,19 +189,17 @@ void us_loop_run(struct us_loop_t *loop) { } struct us_poll_t *us_create_poll(struct us_loop_t *loop, int fallthrough, unsigned int ext_size) { - return malloc(sizeof(struct us_poll_t) + ext_size); + struct us_poll_t *p = (struct us_poll_t *) malloc(sizeof(struct us_poll_t) + ext_size); + p->uv_p = malloc(sizeof(uv_poll_t)); + p->uv_p->data = p; + return p; } -// this one is broken, us_poll needs to hold a pointer to uv_poll_t for it to work (bad anyways) +/* If we update our block position we have to updarte the uv_poll data to point to us */ struct us_poll_t *us_poll_resize(struct us_poll_t *p, struct us_loop_t *loop, unsigned int ext_size) { - // do not support it yet - return p; - struct us_poll_t *new_p = realloc(p, sizeof(struct us_poll_t) + ext_size); - if (p != new_p) { - new_p->uv_p.data = new_p; - } + new_p->uv_p->data = new_p; return new_p; } @@ -207,7 +223,7 @@ struct us_timer_t *us_create_timer(struct us_loop_t *loop, int fallthrough, unsi } void *us_timer_ext(struct us_timer_t *timer) { - return ((struct us_internal_callback_t *) timer) + 1; + return ((char *) timer) + sizeof(struct us_internal_callback_t) + sizeof(uv_timer_t); } void us_timer_close(struct us_timer_t *t) { diff --git a/src/Vendor/uwebsockets/src/internal/eventing/epoll_kqueue.h b/src/Vendor/uwebsockets/src/internal/eventing/epoll_kqueue.h index 87e9348..b065354 100644 --- a/src/Vendor/uwebsockets/src/internal/eventing/epoll_kqueue.h +++ b/src/Vendor/uwebsockets/src/internal/eventing/epoll_kqueue.h @@ -59,7 +59,7 @@ struct us_loop_t { struct us_poll_t { alignas(LIBUS_EXT_ALIGNMENT) struct { - int fd : 28; + signed int fd : 28; // we could have this unsigned if we wanted to, -1 should never be used unsigned int poll_type : 4; } state; }; diff --git a/src/Vendor/uwebsockets/src/internal/eventing/libuv.h b/src/Vendor/uwebsockets/src/internal/eventing/libuv.h index 7488212..590878b 100644 --- a/src/Vendor/uwebsockets/src/internal/eventing/libuv.h +++ b/src/Vendor/uwebsockets/src/internal/eventing/libuv.h @@ -34,8 +34,10 @@ struct us_loop_t { uv_check_t *uv_check; }; +// it is no longer valid to cast a pointer to us_poll_t to a pointer of uv_poll_t struct us_poll_t { - uv_poll_t uv_p; + /* We need to hold a pointer to this uv_poll_t since we need to be able to resize our block */ + uv_poll_t *uv_p; LIBUS_SOCKET_DESCRIPTOR fd; unsigned char poll_type; }; diff --git a/src/Vendor/uwebsockets/src/internal/internal.h b/src/Vendor/uwebsockets/src/internal/internal.h index 459cd0b..ac54b0f 100644 --- a/src/Vendor/uwebsockets/src/internal/internal.h +++ b/src/Vendor/uwebsockets/src/internal/internal.h @@ -18,6 +18,12 @@ #ifndef INTERNAL_H #define INTERNAL_H +#if defined(_MSC_VER) +#define alignas(x) __declspec(align(x)) +#else +#include +#endif + /* We only have one networking implementation so far */ #include "internal/networking/bsd.h" @@ -100,7 +106,7 @@ struct us_listen_socket_t { struct us_socket_context_t { alignas(LIBUS_EXT_ALIGNMENT) struct us_loop_t *loop; - //unsigned short timeout; + unsigned short timestamp; struct us_socket_t *head; struct us_socket_t *iterator; struct us_socket_context_t *prev, *next; @@ -108,14 +114,12 @@ struct us_socket_context_t { struct us_socket_t *(*on_open)(struct us_socket_t *, int is_client, char *ip, int ip_length); struct us_socket_t *(*on_data)(struct us_socket_t *, char *data, int length); struct us_socket_t *(*on_writable)(struct us_socket_t *); - struct us_socket_t *(*on_close)(struct us_socket_t *); + struct us_socket_t *(*on_close)(struct us_socket_t *, int code, void *reason); //void (*on_timeout)(struct us_socket_context *); struct us_socket_t *(*on_socket_timeout)(struct us_socket_t *); struct us_socket_t *(*on_end)(struct us_socket_t *); + struct us_socket_t *(*on_connect_error)(struct us_socket_t *, int code); int (*ignore_data)(struct us_socket_t *); - - /* All contexts hold references to their own copied options */ - struct us_socket_context_options_t options; }; /* Internal SSL interface */ @@ -124,6 +128,14 @@ struct us_socket_context_t { struct us_internal_ssl_socket_context_t; struct us_internal_ssl_socket_t; +/* SNI functions */ +void us_internal_ssl_socket_context_add_server_name(struct us_internal_ssl_socket_context_t *context, const char *hostname_pattern, struct us_socket_context_options_t options); +void us_internal_ssl_socket_context_remove_server_name(struct us_internal_ssl_socket_context_t *context, const char *hostname_pattern); +void us_internal_ssl_socket_context_on_server_name(struct us_internal_ssl_socket_context_t *context, void (*cb)(struct us_internal_ssl_socket_context_t *, const char *)); + +void *us_internal_ssl_socket_get_native_handle(struct us_internal_ssl_socket_t *s); +void *us_internal_ssl_socket_context_get_native_handle(struct us_internal_ssl_socket_context_t *context); + struct us_internal_ssl_socket_context_t *us_internal_create_ssl_socket_context(struct us_loop_t *loop, int context_ext_size, struct us_socket_context_options_t options); @@ -132,7 +144,7 @@ void us_internal_ssl_socket_context_on_open(struct us_internal_ssl_socket_contex struct us_internal_ssl_socket_t *(*on_open)(struct us_internal_ssl_socket_t *s, int is_client, char *ip, int ip_length)); void us_internal_ssl_socket_context_on_close(struct us_internal_ssl_socket_context_t *context, - struct us_internal_ssl_socket_t *(*on_close)(struct us_internal_ssl_socket_t *s)); + struct us_internal_ssl_socket_t *(*on_close)(struct us_internal_ssl_socket_t *s, int code, void *reason)); void us_internal_ssl_socket_context_on_data(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *(*on_data)(struct us_internal_ssl_socket_t *s, char *data, int length)); @@ -146,11 +158,14 @@ void us_internal_ssl_socket_context_on_timeout(struct us_internal_ssl_socket_con void us_internal_ssl_socket_context_on_end(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *(*on_end)(struct us_internal_ssl_socket_t *s)); +void us_internal_ssl_socket_context_on_connect_error(struct us_internal_ssl_socket_context_t *context, + struct us_internal_ssl_socket_t *(*on_connect_error)(struct us_internal_ssl_socket_t *s, int code)); + struct us_listen_socket_t *us_internal_ssl_socket_context_listen(struct us_internal_ssl_socket_context_t *context, const char *host, int port, int options, int socket_ext_size); struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_connect(struct us_internal_ssl_socket_context_t *context, - const char *host, int port, int options, int socket_ext_size); + const char *host, int port, const char *source_host, int options, int socket_ext_size); int us_internal_ssl_socket_write(struct us_internal_ssl_socket_t *s, const char *data, int length, int msg_more); void us_internal_ssl_socket_timeout(struct us_internal_ssl_socket_t *s, unsigned int seconds); @@ -159,7 +174,7 @@ struct us_internal_ssl_socket_context_t *us_internal_ssl_socket_get_context(stru void *us_internal_ssl_socket_ext(struct us_internal_ssl_socket_t *s); int us_internal_ssl_socket_is_shut_down(struct us_internal_ssl_socket_t *s); void us_internal_ssl_socket_shutdown(struct us_internal_ssl_socket_t *s); -struct us_internal_ssl_socket_t *us_internal_ssl_socket_close(struct us_internal_ssl_socket_t *s); + struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_adopt_socket(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *s, int ext_size); diff --git a/src/Vendor/uwebsockets/src/internal/networking/bsd.h b/src/Vendor/uwebsockets/src/internal/networking/bsd.h index 9c91ff7..23b4fb7 100644 --- a/src/Vendor/uwebsockets/src/internal/networking/bsd.h +++ b/src/Vendor/uwebsockets/src/internal/networking/bsd.h @@ -24,265 +24,61 @@ // here everything about the syscalls are inline-wrapped and included #ifdef _WIN32 +#ifndef NOMINMAX #define NOMINMAX -#include -#include +#endif +#include +#include #pragma comment(lib, "ws2_32.lib") -#include #define SETSOCKOPT_PTR_TYPE const char * #define LIBUS_SOCKET_ERROR INVALID_SOCKET #else #define _GNU_SOURCE -#include +/* For socklen_t */ #include -#include -#include -#include -#include -#include -#include -#include -#include #define SETSOCKOPT_PTR_TYPE int * #define LIBUS_SOCKET_ERROR -1 #endif -static inline LIBUS_SOCKET_DESCRIPTOR apple_no_sigpipe(LIBUS_SOCKET_DESCRIPTOR fd) { -#ifdef __APPLE__ - if (fd != LIBUS_SOCKET_ERROR) { - int no_sigpipe = 1; - setsockopt(fd, SOL_SOCKET, SO_NOSIGPIPE, &no_sigpipe, sizeof(int)); - } -#endif - return fd; -} - -static inline LIBUS_SOCKET_DESCRIPTOR bsd_set_nonblocking(LIBUS_SOCKET_DESCRIPTOR fd) { -#ifdef _WIN32 - /* Libuv will set windows sockets as non-blocking */ -#else - fcntl(fd, F_SETFL, fcntl(fd, F_GETFL, 0) | O_NONBLOCK); -#endif - return fd; -} - -static inline void bsd_socket_nodelay(LIBUS_SOCKET_DESCRIPTOR fd, int enabled) { - setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, (void *) &enabled, sizeof(enabled)); -} - -static inline void bsd_socket_flush(LIBUS_SOCKET_DESCRIPTOR fd) { - // Linux TCP_CORK has the same underlying corking mechanism as with MSG_MORE -#ifdef TCP_CORK - int enabled = 0; - setsockopt(fd, IPPROTO_TCP, TCP_CORK, &enabled, sizeof(int)); -#endif -} - -static inline LIBUS_SOCKET_DESCRIPTOR bsd_create_socket(int domain, int type, int protocol) { - // returns INVALID_SOCKET on error - int flags = 0; -#if defined(SOCK_CLOEXEC) && defined(SOCK_NONBLOCK) - flags = SOCK_CLOEXEC | SOCK_NONBLOCK; -#endif - - LIBUS_SOCKET_DESCRIPTOR created_fd = socket(domain, type | flags, protocol); - - return bsd_set_nonblocking(apple_no_sigpipe(created_fd)); -} - -static inline void bsd_close_socket(LIBUS_SOCKET_DESCRIPTOR fd) { -#ifdef _WIN32 - closesocket(fd); -#else - close(fd); -#endif -} - -static inline void bsd_shutdown_socket(LIBUS_SOCKET_DESCRIPTOR fd) { -#ifdef _WIN32 - shutdown(fd, SD_SEND); -#else - shutdown(fd, SHUT_WR); -#endif -} - struct bsd_addr_t { struct sockaddr_storage mem; socklen_t len; char *ip; int ip_length; + int port; }; -static inline void internal_finalize_bsd_addr(struct bsd_addr_t *addr) { - // parse, so to speak, the address - if (addr->mem.ss_family == AF_INET6) { - addr->ip = (char *) &((struct sockaddr_in6 *) addr)->sin6_addr; - addr->ip_length = sizeof(struct in6_addr); - } else if (addr->mem.ss_family == AF_INET) { - addr->ip = (char *) &((struct sockaddr_in *) addr)->sin_addr; - addr->ip_length = sizeof(struct in_addr); - } else { - addr->ip_length = 0; - } -} +LIBUS_SOCKET_DESCRIPTOR apple_no_sigpipe(LIBUS_SOCKET_DESCRIPTOR fd); +LIBUS_SOCKET_DESCRIPTOR bsd_set_nonblocking(LIBUS_SOCKET_DESCRIPTOR fd); +void bsd_socket_nodelay(LIBUS_SOCKET_DESCRIPTOR fd, int enabled); +void bsd_socket_flush(LIBUS_SOCKET_DESCRIPTOR fd); +LIBUS_SOCKET_DESCRIPTOR bsd_create_socket(int domain, int type, int protocol); -static inline int bsd_socket_addr(LIBUS_SOCKET_DESCRIPTOR fd, struct bsd_addr_t *addr) { - addr->len = sizeof(addr->mem); - if (getpeername(fd, (struct sockaddr *) &addr->mem, &addr->len)) { - return -1; - } - internal_finalize_bsd_addr(addr); - return 0; -} +void bsd_close_socket(LIBUS_SOCKET_DESCRIPTOR fd); +void bsd_shutdown_socket(LIBUS_SOCKET_DESCRIPTOR fd); +void bsd_shutdown_socket_read(LIBUS_SOCKET_DESCRIPTOR fd); -static inline char *bsd_addr_get_ip(struct bsd_addr_t *addr) { - return addr->ip; -} +void internal_finalize_bsd_addr(struct bsd_addr_t *addr); -static inline int bsd_addr_get_ip_length(struct bsd_addr_t *addr) { - return addr->ip_length; -} +int bsd_local_addr(LIBUS_SOCKET_DESCRIPTOR fd, struct bsd_addr_t *addr); +int bsd_remote_addr(LIBUS_SOCKET_DESCRIPTOR fd, struct bsd_addr_t *addr); + +char *bsd_addr_get_ip(struct bsd_addr_t *addr); +int bsd_addr_get_ip_length(struct bsd_addr_t *addr); + +int bsd_addr_get_port(struct bsd_addr_t *addr); // called by dispatch_ready_poll -static inline LIBUS_SOCKET_DESCRIPTOR bsd_accept_socket(LIBUS_SOCKET_DESCRIPTOR fd, struct bsd_addr_t *addr) { - LIBUS_SOCKET_DESCRIPTOR accepted_fd; - addr->len = sizeof(addr->mem); +LIBUS_SOCKET_DESCRIPTOR bsd_accept_socket(LIBUS_SOCKET_DESCRIPTOR fd, struct bsd_addr_t *addr); -#if defined(SOCK_CLOEXEC) && defined(SOCK_NONBLOCK) - // Linux, FreeBSD - accepted_fd = accept4(fd, (struct sockaddr *) addr, &addr->len, SOCK_CLOEXEC | SOCK_NONBLOCK); -#else - // Windows, OS X - accepted_fd = accept(fd, (struct sockaddr *) addr, &addr->len); - -#endif - - internal_finalize_bsd_addr(addr); - - return bsd_set_nonblocking(apple_no_sigpipe(accepted_fd)); -} - -static inline int bsd_recv(LIBUS_SOCKET_DESCRIPTOR fd, void *buf, int length, int flags) { - return recv(fd, buf, length, flags); -} - -static inline int bsd_send(LIBUS_SOCKET_DESCRIPTOR fd, const char *buf, int length, int msg_more) { - - // MSG_MORE (Linux), MSG_PARTIAL (Windows), TCP_NOPUSH (BSD) - -#ifndef MSG_NOSIGNAL -#define MSG_NOSIGNAL 0 -#endif - -#ifdef MSG_MORE - - // for Linux we do not want signals - return send(fd, buf, length, (msg_more * MSG_MORE) | MSG_NOSIGNAL); - -#else - - // use TCP_NOPUSH - - return send(fd, buf, length, MSG_NOSIGNAL); - -#endif -} - -static inline int bsd_would_block() { -#ifdef _WIN32 - return WSAGetLastError() == WSAEWOULDBLOCK; -#else - return errno == EWOULDBLOCK;// || errno == EAGAIN; -#endif -} +int bsd_recv(LIBUS_SOCKET_DESCRIPTOR fd, void *buf, int length, int flags); +int bsd_send(LIBUS_SOCKET_DESCRIPTOR fd, const char *buf, int length, int msg_more); +int bsd_would_block(); // return LIBUS_SOCKET_ERROR or the fd that represents listen socket // listen both on ipv6 and ipv4 -static inline LIBUS_SOCKET_DESCRIPTOR bsd_create_listen_socket(const char *host, int port, int options) { - struct addrinfo hints, *result; - memset(&hints, 0, sizeof(struct addrinfo)); +LIBUS_SOCKET_DESCRIPTOR bsd_create_listen_socket(const char *host, int port, int options); - hints.ai_flags = AI_PASSIVE; - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - - char port_string[16]; - snprintf(port_string, 16, "%d", port); - - if (getaddrinfo(host, port_string, &hints, &result)) { - return LIBUS_SOCKET_ERROR; - } - - LIBUS_SOCKET_DESCRIPTOR listenFd = LIBUS_SOCKET_ERROR; - struct addrinfo *listenAddr; - for (struct addrinfo *a = result; a && listenFd == LIBUS_SOCKET_ERROR; a = a->ai_next) { - if (a->ai_family == AF_INET6) { - listenFd = bsd_create_socket(a->ai_family, a->ai_socktype, a->ai_protocol); - listenAddr = a; - } - } - - for (struct addrinfo *a = result; a && listenFd == LIBUS_SOCKET_ERROR; a = a->ai_next) { - if (a->ai_family == AF_INET) { - listenFd = bsd_create_socket(a->ai_family, a->ai_socktype, a->ai_protocol); - listenAddr = a; - } - } - - if (listenFd == LIBUS_SOCKET_ERROR) { - freeaddrinfo(result); - return LIBUS_SOCKET_ERROR; - } - - /* Always enable SO_REUSEPORT and SO_REUSEADDR _unless_ options specify otherwise */ -#if defined(__linux) && defined(SO_REUSEPORT) - if (!(options & LIBUS_LISTEN_EXCLUSIVE_PORT)) { - int optval = 1; - setsockopt(listenFd, SOL_SOCKET, SO_REUSEPORT, &optval, sizeof(optval)); - } -#endif - - int enabled = 1; - setsockopt(listenFd, SOL_SOCKET, SO_REUSEADDR, (SETSOCKOPT_PTR_TYPE) &enabled, sizeof(enabled)); - -#ifdef IPV6_V6ONLY - int disabled = 0; - setsockopt(listenFd, IPPROTO_IPV6, IPV6_V6ONLY, (SETSOCKOPT_PTR_TYPE) &disabled, sizeof(disabled)); -#endif - - if (bind(listenFd, listenAddr->ai_addr, (socklen_t) listenAddr->ai_addrlen) || listen(listenFd, 512)) { - bsd_close_socket(listenFd); - freeaddrinfo(result); - return LIBUS_SOCKET_ERROR; - } - - freeaddrinfo(result); - return listenFd; -} - -static inline LIBUS_SOCKET_DESCRIPTOR bsd_create_connect_socket(const char *host, int port, int options) { - struct addrinfo hints, *result; - memset(&hints, 0, sizeof(struct addrinfo)); - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - - char port_string[16]; - snprintf(port_string, 16, "%d", port); - - if (getaddrinfo(host, port_string, &hints, &result) != 0) { - return LIBUS_SOCKET_ERROR; - } - - LIBUS_SOCKET_DESCRIPTOR fd = bsd_create_socket(result->ai_family, result->ai_socktype, result->ai_protocol); - if (fd == LIBUS_SOCKET_ERROR) { - freeaddrinfo(result); - return LIBUS_SOCKET_ERROR; - } - - connect(fd, result->ai_addr, (socklen_t) result->ai_addrlen); - freeaddrinfo(result); - - return fd; -} +LIBUS_SOCKET_DESCRIPTOR bsd_create_connect_socket(const char *host, int port, const char *source_host, int options); #endif // BSD_H diff --git a/src/Vendor/uwebsockets/src/loop.c b/src/Vendor/uwebsockets/src/loop.c index 4fc09aa..ae34fa8 100644 --- a/src/Vendor/uwebsockets/src/loop.c +++ b/src/Vendor/uwebsockets/src/loop.c @@ -1,5 +1,5 @@ /* - * Authored by Alex Hultman, 2018-2019. + * Authored by Alex Hultman, 2018-2021. * Intellectual property of third-party. * Licensed under the Apache License, Version 2.0 (the "License"); @@ -80,24 +80,50 @@ void us_internal_loop_unlink(struct us_loop_t *loop, struct us_socket_context_t /* This functions should never run recursively */ void us_internal_timer_sweep(struct us_loop_t *loop) { struct us_internal_loop_data_t *loop_data = &loop->data; + /* For all socket contexts in this loop */ for (loop_data->iterator = loop_data->head; loop_data->iterator; loop_data->iterator = loop_data->iterator->next) { struct us_socket_context_t *context = loop_data->iterator; - for (context->iterator = context->head; context->iterator; ) { - struct us_socket_t *s = context->iterator; - if (s->timeout && --(s->timeout) == 0) { + /* Update this context's 15-bit timestamp */ + context->timestamp = (context->timestamp + 1) & 0x7fff; - context->on_socket_timeout(s); + /* Update our 16-bit full timestamp (the needle in the haystack) */ + unsigned short needle = 0x8000 | context->timestamp; - /* Check for unlink / link */ - if (s == context->iterator) { - context->iterator = s->next; + /* Begin at head */ + struct us_socket_t *s = context->head; + while (s) { + /* Seek until end or timeout found (tightest loop) */ + while (1) { + /* We only read from 1 random cache line here */ + if (needle == s->timeout) { + break; } + + /* Did we reach the end without a find? */ + if ((s = s->next) == 0) { + goto next_context; + } + } + + /* Here we have a timeout to emit (slow path) */ + s->timeout = 0; + context->iterator = s; + + context->on_socket_timeout(s); + + /* Check for unlink / link (if the event handler did not modify the chain, we step 1) */ + if (s == context->iterator) { + s = s->next; } else { - context->iterator = s->next; + /* The iterator was changed by event handler */ + s = context->iterator; } } + /* We always store a 0 to context->iterator here since we are no longer iterating this context */ + next_context: + context->iterator = 0; } } @@ -136,7 +162,10 @@ void us_internal_loop_post(struct us_loop_t *loop) { void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events) { switch (us_internal_poll_type(p)) { case POLL_TYPE_CALLBACK: { + /* Let's just do this to clear the CodeQL alert */ + #ifndef LIBUS_USE_LIBUV us_internal_accept_poll_event(p); + #endif struct us_internal_callback_t *cb = (struct us_internal_callback_t *) p; cb->cb(cb->cb_expects_the_loop ? (struct us_internal_callback_t *) cb->loop : (struct us_internal_callback_t *) &cb->p); } @@ -147,15 +176,26 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events) if (us_poll_events(p) == LIBUS_SOCKET_WRITABLE) { struct us_socket_t *s = (struct us_socket_t *) p; - us_poll_change(p, s->context->loop, LIBUS_SOCKET_READABLE); + /* It is perfectly possible to come here with an error */ + if (error) { + /* Emit error, close without emitting on_close */ + s->context->on_connect_error(s, 0); + us_socket_close_connecting(0, s); + } else { + /* All sockets poll for readable */ + us_poll_change(p, s->context->loop, LIBUS_SOCKET_READABLE); - /* We always use nodelay */ - bsd_socket_nodelay(us_poll_fd(p), 1); + /* We always use nodelay */ + bsd_socket_nodelay(us_poll_fd(p), 1); - /* We are now a proper socket */ - us_internal_poll_set_type(p, POLL_TYPE_SOCKET); + /* We are now a proper socket */ + us_internal_poll_set_type(p, POLL_TYPE_SOCKET); - s->context->on_open(s, 1, 0, 0); + /* If we used a connection timeout we have to reset it here */ + us_socket_timeout(0, s, 0); + + s->context->on_open(s, 1, 0, 0); + } } else { struct us_listen_socket_t *listen_socket = (struct us_listen_socket_t *) p; struct bsd_addr_t addr; @@ -169,11 +209,11 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events) /* Todo: stop timer if any */ do { - struct us_poll_t *p = us_create_poll(us_socket_context(0, &listen_socket->s)->loop, 0, sizeof(struct us_socket_t) - sizeof(struct us_poll_t) + listen_socket->socket_ext_size); - us_poll_init(p, client_fd, POLL_TYPE_SOCKET); - us_poll_start(p, listen_socket->s.context->loop, LIBUS_SOCKET_READABLE); + struct us_poll_t *accepted_p = us_create_poll(us_socket_context(0, &listen_socket->s)->loop, 0, sizeof(struct us_socket_t) - sizeof(struct us_poll_t) + listen_socket->socket_ext_size); + us_poll_init(accepted_p, client_fd, POLL_TYPE_SOCKET); + us_poll_start(accepted_p, listen_socket->s.context->loop, LIBUS_SOCKET_READABLE); - struct us_socket_t *s = (struct us_socket_t *) p; + struct us_socket_t *s = (struct us_socket_t *) accepted_p; s->context = listen_socket->s.context; @@ -201,7 +241,8 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events) /* Such as epollerr epollhup */ if (error) { - s = us_socket_close(0, s); + /* Todo: decide what code we give here */ + s = us_socket_close(0, s, 0, NULL); return; } @@ -235,14 +276,16 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events) } else if (!length) { if (us_socket_is_shut_down(0, s)) { /* We got FIN back after sending it */ - s = us_socket_close(0, s); + /* Todo: We should give "CLEAN SHUTDOWN" as reason here */ + s = us_socket_close(0, s, 0, NULL); } else { /* We got FIN, so stop polling for readable */ us_poll_change(&s->p, us_socket_context(0, s)->loop, us_poll_events(&s->p) & LIBUS_SOCKET_WRITABLE); s = s->context->on_end(s); } } else if (length == LIBUS_SOCKET_ERROR && !bsd_would_block()) { - s = us_socket_close(0, s); + /* Todo: decide also here what kind of reason we should give */ + s = us_socket_close(0, s, 0, NULL); } } } diff --git a/src/Vendor/uwebsockets/src/socket.c b/src/Vendor/uwebsockets/src/socket.c index 919f4a8..683be7c 100644 --- a/src/Vendor/uwebsockets/src/socket.c +++ b/src/Vendor/uwebsockets/src/socket.c @@ -1,5 +1,5 @@ /* - * Authored by Alex Hultman, 2018-2019. + * Authored by Alex Hultman, 2018-2021. * Intellectual property of third-party. * Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,12 +18,27 @@ #include "libusockets.h" #include "internal/internal.h" #include +#include /* Shared with SSL */ +int us_socket_local_port(int ssl, struct us_socket_t *s) { + struct bsd_addr_t addr; + if (bsd_local_addr(us_poll_fd(&s->p), &addr)) { + return -1; + } else { + return bsd_addr_get_port(&addr); + } +} + +void us_socket_shutdown_read(int ssl, struct us_socket_t *s) { + /* This syscall is idempotent so no extra check is needed */ + bsd_shutdown_socket_read(us_poll_fd((struct us_poll_t *) s)); +} + void us_socket_remote_address(int ssl, struct us_socket_t *s, char *buf, int *length) { struct bsd_addr_t addr; - if (bsd_socket_addr(us_poll_fd(&s->p), &addr) || *length < bsd_addr_get_ip_length(&addr)) { + if (bsd_remote_addr(us_poll_fd(&s->p), &addr) || *length < bsd_addr_get_ip_length(&addr)) { *length = 0; } else { *length = bsd_addr_get_ip_length(&addr); @@ -37,8 +52,7 @@ struct us_socket_context_t *us_socket_context(int ssl, struct us_socket_t *s) { void us_socket_timeout(int ssl, struct us_socket_t *s, unsigned int seconds) { if (seconds) { - unsigned short timeout_sweeps = (unsigned short) (0.5f + ((float) seconds) / ((float) LIBUS_TIMEOUT_GRANULARITY)); - s->timeout = timeout_sweeps ? timeout_sweeps : 1; + s->timeout = 0x8000 | (s->context->timestamp + (seconds >> 2)); } else { s->timeout = 0; } @@ -54,8 +68,61 @@ int us_socket_is_closed(int ssl, struct us_socket_t *s) { return s->prev == (struct us_socket_t *) s->context; } +int us_socket_is_established(int ssl, struct us_socket_t *s) { + /* Everything that is not POLL_TYPE_SEMI_SOCKET is established */ + return us_internal_poll_type((struct us_poll_t *) s) != POLL_TYPE_SEMI_SOCKET; +} + +/* Exactly the same as us_socket_close but does not emit on_close event */ +struct us_socket_t *us_socket_close_connecting(int ssl, struct us_socket_t *s) { + if (!us_socket_is_closed(0, s)) { + us_internal_socket_context_unlink(s->context, s); + us_poll_stop((struct us_poll_t *) s, s->context->loop); + bsd_close_socket(us_poll_fd((struct us_poll_t *) s)); + + /* Link this socket to the close-list and let it be deleted after this iteration */ + s->next = s->context->loop->data.closed_head; + s->context->loop->data.closed_head = s; + + /* Any socket with prev = context is marked as closed */ + s->prev = (struct us_socket_t *) s->context; + + //return s->context->on_close(s, code, reason); + } + return s; +} + +/* Same as above but emits on_close */ +struct us_socket_t *us_socket_close(int ssl, struct us_socket_t *s, int code, void *reason) { + if (!us_socket_is_closed(0, s)) { + us_internal_socket_context_unlink(s->context, s); + us_poll_stop((struct us_poll_t *) s, s->context->loop); + bsd_close_socket(us_poll_fd((struct us_poll_t *) s)); + + /* Link this socket to the close-list and let it be deleted after this iteration */ + s->next = s->context->loop->data.closed_head; + s->context->loop->data.closed_head = s; + + /* Any socket with prev = context is marked as closed */ + s->prev = (struct us_socket_t *) s->context; + + return s->context->on_close(s, code, reason); + } + return s; +} + /* Not shared with SSL */ +void *us_socket_get_native_handle(int ssl, struct us_socket_t *s) { +#ifndef LIBUS_NO_SSL + if (ssl) { + return us_internal_ssl_socket_get_native_handle((struct us_internal_ssl_socket_t *) s); + } +#endif + + return (void *) (uintptr_t) us_poll_fd((struct us_poll_t *) s); +} + int us_socket_write(int ssl, struct us_socket_t *s, const char *data, int length, int msg_more) { #ifndef LIBUS_NO_SSL if (ssl) { @@ -86,30 +153,6 @@ void *us_socket_ext(int ssl, struct us_socket_t *s) { return s + 1; } -struct us_socket_t *us_socket_close(int ssl, struct us_socket_t *s) { -#ifndef LIBUS_NO_SSL - if (ssl) { - return (struct us_socket_t *) us_internal_ssl_socket_close((struct us_internal_ssl_socket_t *) s); - } -#endif - - if (!us_socket_is_closed(0, s)) { - us_internal_socket_context_unlink(s->context, s); - us_poll_stop((struct us_poll_t *) s, s->context->loop); - bsd_close_socket(us_poll_fd((struct us_poll_t *) s)); - - /* Link this socket to the close-list and let it be deleted after this iteration */ - s->next = s->context->loop->data.closed_head; - s->context->loop->data.closed_head = s; - - /* Any socket with prev = context is marked as closed */ - s->prev = (struct us_socket_t *) s->context; - - return s->context->on_close(s); - } - return s; -} - int us_socket_is_shut_down(int ssl, struct us_socket_t *s) { #ifndef LIBUS_NO_SSL if (ssl) { diff --git a/src/src/BrokenithmServer.cpp b/src/src/BrokenithmServer.cpp index 6522a47..f256141 100644 --- a/src/src/BrokenithmServer.cpp +++ b/src/src/BrokenithmServer.cpp @@ -1,3 +1,5 @@ +#pragma warning(disable : 4267 4138) + #include "BrokenithmServer.hpp" #include @@ -53,11 +55,12 @@ uint64_t BrokenithmServer::get_controller_state() struct ConnectionData { + typedef uWS::WebSocket ConnectionDataSocket; static int s_connection_counter; static std::vector s_connections; int m_uid; - uWS::WebSocket *m_websocket; + ConnectionDataSocket *m_websocket; void static close_all_connections(); @@ -75,7 +78,7 @@ struct ConnectionData s_connections[m_uid] = nullptr; } - void save_socket(uWS::WebSocket *websocket) + void save_socket(ConnectionDataSocket *websocket) { m_websocket = websocket; } @@ -149,12 +152,17 @@ void BrokenithmServer::Impl::start_server() }) .ws( "/ws", - {uWS::DISABLED, - 16 * 1024 * 1024, - 10, - 16 * 1024 * 1024, + {uWS::DISABLED, // compression + 16 * 1024 * 1024, // maxPayloadLength + 16, // idleTimeout + 16 * 1024 * 1024, // maxBackpressure + false, // closeOnBackpressureLimit + false, // resetIdleTimeoutOnSend + true, // sendPingsAutomatically + 0, // maxLifetime + nullptr, // upgrade // Open handler - [](auto *ws, auto *req) { + [](auto *ws) { spdlog::info("Controller ID {} connected", ((ConnectionData *)ws->getUserData())->m_uid); ((ConnectionData *)ws->getUserData())->save_socket(ws); }, @@ -181,26 +189,21 @@ void BrokenithmServer::Impl::start_server() } } }, - // Drain handler - [](auto *ws) {}, - // Ping handler - [](auto *ws) {}, - // Pong handler - [](auto *ws) {}, + nullptr, // Drain handler + nullptr, // Ping handler + nullptr, // Pong handler // Close handler [](auto *ws, int code, std::string_view message) { spdlog::info("Controller ID {} disconnected", ((ConnectionData *)ws->getUserData())->m_uid); }}) - .listen( - m_port, - [&](auto *token) { - if (token) - { - spdlog::info("Server listening at port {}", m_port); - m_running = true; - m_uws_socket_token = token; - } - }) + .listen(m_port, [&](auto *token) { + if (token) + { + spdlog::info("Server listening at port {}", m_port); + m_running = true; + m_uws_socket_token = token; + } + }) .run(); m_running = false;