update uwebsockets version

gh-actions
4yn 2021-05-05 11:58:59 +08:00
parent 5531880481
commit 949f968667
45 changed files with 3870 additions and 2838 deletions

View File

@ -1,6 +1,6 @@
# Repackaged uWebsockets library
Source was obtained from [uSockets](https://github.com/uNetworking/uSockets) and [uWebSockets](https://github.com/uNetworking/uWebSockets) and repackaged with a build system.
Source was obtained from [uSockets v0.7.1](https://github.com/uNetworking/uSockets/tree/v0.7.1) and [uWebSockets v19.2.0](https://github.com/uNetworking/uWebSockets/tree/v19.2.0) and repackaged with build system configuration.
# Original uWebsockets README.md

View File

@ -29,13 +29,13 @@
/* Define what a socket descriptor is based on platform */
#ifdef _WIN32
#ifndef NOMINMAX
#define NOMINMAX
#include <WinSock2.h>
#endif
#include <winsock2.h>
#define LIBUS_SOCKET_DESCRIPTOR SOCKET
#define WIN32_EXPORT __declspec(dllexport)
#define alignas(x) __declspec(align(x))
#else
#include <stdalign.h>
#define LIBUS_SOCKET_DESCRIPTOR int
#define WIN32_EXPORT
#endif
@ -84,9 +84,20 @@ struct us_socket_context_options_t {
const char *passphrase;
const char *dh_params_file_name;
const char *ca_file_name;
int ssl_prefer_low_memory_usage;
int ssl_prefer_low_memory_usage; /* Todo: rename to prefer_low_memory_usage and apply for TCP as well */
};
/* Return 15-bit timestamp for this context */
WIN32_EXPORT unsigned short us_socket_context_timestamp(int ssl, struct us_socket_context_t *context);
/* Adds SNI domain and cert in asn1 format */
WIN32_EXPORT void us_socket_context_add_server_name(int ssl, struct us_socket_context_t *context, const char *hostname_pattern, struct us_socket_context_options_t options);
WIN32_EXPORT void us_socket_context_remove_server_name(int ssl, struct us_socket_context_t *context, const char *hostname_pattern);
WIN32_EXPORT void us_socket_context_on_server_name(int ssl, struct us_socket_context_t *context, void (*cb)(struct us_socket_context_t *, const char *hostname));
/* Returns the underlying SSL native handle, such as SSL_CTX or nullptr */
WIN32_EXPORT void *us_socket_context_get_native_handle(int ssl, struct us_socket_context_t *context);
/* A socket context holds shared callbacks and user data extension for associated sockets */
WIN32_EXPORT struct us_socket_context_t *us_create_socket_context(int ssl, struct us_loop_t *loop,
int ext_size, struct us_socket_context_options_t options);
@ -98,13 +109,16 @@ WIN32_EXPORT void us_socket_context_free(int ssl, struct us_socket_context_t *co
WIN32_EXPORT void us_socket_context_on_open(int ssl, struct us_socket_context_t *context,
struct us_socket_t *(*on_open)(struct us_socket_t *s, int is_client, char *ip, int ip_length));
WIN32_EXPORT void us_socket_context_on_close(int ssl, struct us_socket_context_t *context,
struct us_socket_t *(*on_close)(struct us_socket_t *s));
struct us_socket_t *(*on_close)(struct us_socket_t *s, int code, void *reason));
WIN32_EXPORT void us_socket_context_on_data(int ssl, struct us_socket_context_t *context,
struct us_socket_t *(*on_data)(struct us_socket_t *s, char *data, int length));
WIN32_EXPORT void us_socket_context_on_writable(int ssl, struct us_socket_context_t *context,
struct us_socket_t *(*on_writable)(struct us_socket_t *s));
WIN32_EXPORT void us_socket_context_on_timeout(int ssl, struct us_socket_context_t *context,
struct us_socket_t *(*on_timeout)(struct us_socket_t *s));
/* This one is only used for when a connecting socket fails in a late stage. */
WIN32_EXPORT void us_socket_context_on_connect_error(int ssl, struct us_socket_context_t *context,
struct us_socket_t *(*on_connect_error)(struct us_socket_t *s, int code));
/* Emitted when a socket has been half-closed */
WIN32_EXPORT void us_socket_context_on_end(int ssl, struct us_socket_context_t *context, struct us_socket_t *(*on_end)(struct us_socket_t *s));
@ -119,9 +133,18 @@ WIN32_EXPORT struct us_listen_socket_t *us_socket_context_listen(int ssl, struct
/* listen_socket.c/.h */
WIN32_EXPORT void us_listen_socket_close(int ssl, struct us_listen_socket_t *ls);
/* Land in on_open or on_close or return null or return socket */
/* Land in on_open or on_connection_error or return null or return socket */
WIN32_EXPORT struct us_socket_t *us_socket_context_connect(int ssl, struct us_socket_context_t *context,
const char *host, int port, int options, int socket_ext_size);
const char *host, int port, const char *source_host, int options, int socket_ext_size);
/* Is this socket established? Can be used to check if a connecting socket has fired the on_open event yet.
* Can also be used to determine if a socket is a listen_socket or not, but you probably know that already. */
WIN32_EXPORT int us_socket_is_established(int ssl, struct us_socket_t *s);
/* Cancel a connecting socket. Can be used together with us_socket_timeout to limit connection times.
* Entirely destroys the socket - this function works like us_socket_close but does not trigger on_close event since
* you never got the on_open event first. */
WIN32_EXPORT struct us_socket_t *us_socket_close_connecting(int ssl, struct us_socket_t *s);
/* Returns the loop for this socket context. */
WIN32_EXPORT struct us_loop_t *us_socket_context_loop(int ssl, struct us_socket_context_t *context);
@ -189,6 +212,10 @@ WIN32_EXPORT struct us_poll_t *us_poll_resize(struct us_poll_t *p, struct us_loo
/* Public interfaces for sockets */
/* Returns the underlying native handle for a socket, such as SSL or file descriptor.
* In the case of file descriptor, the value of pointer is fd. */
WIN32_EXPORT void *us_socket_get_native_handle(int ssl, struct us_socket_t *s);
/* Write up to length bytes of data. Returns actual bytes written.
* Will call the on_writable callback of active socket context on failure to write everything off in one go.
* Set hint msg_more if you have more immediate data to write. */
@ -210,6 +237,11 @@ WIN32_EXPORT void us_socket_flush(int ssl, struct us_socket_t *s);
/* Shuts down the connection by sending FIN and/or close_notify */
WIN32_EXPORT void us_socket_shutdown(int ssl, struct us_socket_t *s);
/* Shuts down the connection in terms of read, meaning next event loop
* iteration will catch the socket being closed. Can be used to defer closing
* to next event loop iteration. */
WIN32_EXPORT void us_socket_shutdown_read(int ssl, struct us_socket_t *s);
/* Returns whether the socket has been shut down or not */
WIN32_EXPORT int us_socket_is_shut_down(int ssl, struct us_socket_t *s);
@ -217,7 +249,10 @@ WIN32_EXPORT int us_socket_is_shut_down(int ssl, struct us_socket_t *s);
WIN32_EXPORT int us_socket_is_closed(int ssl, struct us_socket_t *s);
/* Immediately closes the socket */
WIN32_EXPORT struct us_socket_t *us_socket_close(int ssl, struct us_socket_t *s);
WIN32_EXPORT struct us_socket_t *us_socket_close(int ssl, struct us_socket_t *s, int code, void *reason);
/* Returns local port or -1 on failure. */
WIN32_EXPORT int us_socket_local_port(int ssl, struct us_socket_t *s);
/* Copy remote (IP) address of socket, or fail with zero length. */
WIN32_EXPORT void us_socket_remote_address(int ssl, struct us_socket_t *s, char *buf, int *length);
@ -230,7 +265,7 @@ WIN32_EXPORT void us_socket_remote_address(int ssl, struct us_socket_t *s, char
#if !defined(LIBUS_USE_EPOLL) && !defined(LIBUS_USE_LIBUV) && !defined(LIBUS_USE_GCD) && !defined(LIBUS_USE_KQUEUE)
#if defined(_WIN32)
#define LIBUS_USE_LIBUV
#elif defined(__APPLE__)
#elif defined(__APPLE__) || defined(__FreeBSD__)
#define LIBUS_USE_KQUEUE
#else
#define LIBUS_USE_EPOLL

View File

@ -1,5 +1,5 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Authored by Alex Hultman, 2018-2020.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
@ -25,42 +25,104 @@
#include "HttpResponse.h"
#include "WebSocketContext.h"
#include "WebSocket.h"
#include "WebSocketExtensions.h"
#include "WebSocketHandshake.h"
#include "PerMessageDeflate.h"
namespace uWS {
/* Compress options (really more like PerMessageDeflateOptions) */
enum CompressOptions {
/* Compression disabled */
DISABLED = 0,
/* We compress using a shared non-sliding window. No added memory usage, worse compression. */
SHARED_COMPRESSOR = 1,
/* We compress using a dedicated sliding window. Major memory usage added, better compression of similarly repeated messages. */
DEDICATED_COMPRESSOR = 2
};
/* This one matches us_socket_context_options_t but has default values */
struct SocketContextOptions {
const char *key_file_name = nullptr;
const char *cert_file_name = nullptr;
const char *passphrase = nullptr;
const char *dh_params_file_name = nullptr;
const char *ca_file_name = nullptr;
int ssl_prefer_low_memory_usage = 0;
/* Conversion operator used internally */
operator struct us_socket_context_options_t() const {
struct us_socket_context_options_t socket_context_options;
memcpy(&socket_context_options, this, sizeof(SocketContextOptions));
return socket_context_options;
}
};
static_assert(sizeof(struct us_socket_context_options_t) == sizeof(SocketContextOptions), "Mismatching uSockets/uWebSockets ABI");
template <bool SSL>
struct TemplatedApp {
private:
/* The app always owns at least one http context, but creates websocket contexts on demand */
HttpContext<SSL> *httpContext;
std::vector<WebSocketContext<SSL, true> *> webSocketContexts;
std::vector<WebSocketContext<SSL, true, int> *> webSocketContexts;
public:
/* Server name */
TemplatedApp &&addServerName(std::string hostname_pattern, SocketContextOptions options = {}) {
us_socket_context_add_server_name(SSL, (struct us_socket_context_t *) httpContext, hostname_pattern.c_str(), options);
return std::move(*this);
}
TemplatedApp &&removeServerName(std::string hostname_pattern) {
us_socket_context_remove_server_name(SSL, (struct us_socket_context_t *) httpContext, hostname_pattern.c_str());
return std::move(*this);
}
TemplatedApp &&missingServerName(MoveOnlyFunction<void(const char *hostname)> handler) {
if (!constructorFailed()) {
httpContext->getSocketContextData()->missingServerNameHandler = std::move(handler);
us_socket_context_on_server_name(SSL, (struct us_socket_context_t *) httpContext, [](struct us_socket_context_t *context, const char *hostname) {
/* This is the only requirements of being friends with HttpContextData */
HttpContext<SSL> *httpContext = (HttpContext<SSL> *) context;
httpContext->getSocketContextData()->missingServerNameHandler(hostname);
});
}
return std::move(*this);
}
/* Returns the SSL_CTX of this app, or nullptr. */
void *getNativeHandle() {
return us_socket_context_get_native_handle(SSL, (struct us_socket_context_t *) httpContext);
}
/* Attaches a "filter" function to track socket connections/disconnections */
void filter(fu2::unique_function<void(HttpResponse<SSL> *, int)> &&filterHandler) {
void filter(MoveOnlyFunction<void(HttpResponse<SSL> *, int)> &&filterHandler) {
httpContext->filter(std::move(filterHandler));
}
/* Publishes a message to all websocket contexts */
/* Publishes a message to all websocket contexts - conceptually as if publishing to the one single
* TopicTree of this app (technically there are many TopicTrees, however the concept is that one
* app has one conceptual Topic tree) */
void publish(std::string_view topic, std::string_view message, OpCode opCode, bool compress = false) {
for (auto *webSocketContext : webSocketContexts) {
webSocketContext->getExt()->publish(topic, message, opCode, compress);
}
}
/* Returns number of subscribers for this topic, or 0 for failure.
* This function should probably be optimized a lot in future releases,
* it could be O(1) with a hash map of fullnames and their counts. */
unsigned int numSubscribers(std::string_view topic) {
unsigned int subscribers = 0;
for (auto *webSocketContext : webSocketContexts) {
auto *webSocketContextData = webSocketContext->getExt();
Topic *t = webSocketContextData->topicTree.lookupTopic(topic);
if (t) {
subscribers += t->subs.size();
}
}
return subscribers;
}
~TemplatedApp() {
/* Let's just put everything here */
if (httpContext) {
@ -84,42 +146,79 @@ public:
webSocketContexts = std::move(other.webSocketContexts);
}
TemplatedApp(us_socket_context_options_t options = {}) {
httpContext = uWS::HttpContext<SSL>::create(uWS::Loop::get(), options);
TemplatedApp(SocketContextOptions options = {}) {
httpContext = HttpContext<SSL>::create(Loop::get(), options);
}
bool constructorFailed() {
return !httpContext;
}
template <typename UserData>
struct WebSocketBehavior {
/* Disabled compression by default - probably a bad default */
CompressOptions compression = DISABLED;
int maxPayloadLength = 16 * 1024;
int idleTimeout = 120;
int maxBackpressure = 1 * 1024 * 1204;
fu2::unique_function<void(uWS::WebSocket<SSL, true> *, HttpRequest *)> open = nullptr;
fu2::unique_function<void(uWS::WebSocket<SSL, true> *, std::string_view, uWS::OpCode)> message = nullptr;
fu2::unique_function<void(uWS::WebSocket<SSL, true> *)> drain = nullptr;
fu2::unique_function<void(uWS::WebSocket<SSL, true> *)> ping = nullptr;
fu2::unique_function<void(uWS::WebSocket<SSL, true> *)> pong = nullptr;
fu2::unique_function<void(uWS::WebSocket<SSL, true> *, int, std::string_view)> close = nullptr;
/* Maximum message size we can receive */
unsigned int maxPayloadLength = 16 * 1024;
/* 2 minutes timeout is good */
unsigned short idleTimeout = 120;
/* 64kb backpressure is probably good */
unsigned int maxBackpressure = 64 * 1024;
bool closeOnBackpressureLimit = false;
/* This one depends on kernel timeouts and is a bad default */
bool resetIdleTimeoutOnSend = false;
/* A good default, esp. for newcomers */
bool sendPingsAutomatically = true;
/* Maximum socket lifetime in seconds before forced closure (defaults to disabled) */
unsigned short maxLifetime = 0;
MoveOnlyFunction<void(HttpResponse<SSL> *, HttpRequest *, struct us_socket_context_t *)> upgrade = nullptr;
MoveOnlyFunction<void(WebSocket<SSL, true, UserData> *)> open = nullptr;
MoveOnlyFunction<void(WebSocket<SSL, true, UserData> *, std::string_view, OpCode)> message = nullptr;
MoveOnlyFunction<void(WebSocket<SSL, true, UserData> *)> drain = nullptr;
MoveOnlyFunction<void(WebSocket<SSL, true, UserData> *, std::string_view)> ping = nullptr;
MoveOnlyFunction<void(WebSocket<SSL, true, UserData> *, std::string_view)> pong = nullptr;
MoveOnlyFunction<void(WebSocket<SSL, true, UserData> *, int, std::string_view)> close = nullptr;
};
template <typename UserData>
TemplatedApp &&ws(std::string pattern, WebSocketBehavior &&behavior) {
TemplatedApp &&ws(std::string pattern, WebSocketBehavior<UserData> &&behavior) {
/* Don't compile if alignment rules cannot be satisfied */
static_assert(alignof(UserData) <= LIBUS_EXT_ALIGNMENT,
"µWebSockets cannot satisfy UserData alignment requirements. You need to recompile µSockets with LIBUS_EXT_ALIGNMENT adjusted accordingly.");
if (!httpContext) {
return std::move(*this);
}
/* Terminate on misleading idleTimeout values */
if (behavior.idleTimeout && behavior.idleTimeout < 8) {
std::cerr << "Error: idleTimeout must be either 0 or greater than 8!" << std::endl;
std::terminate();
}
if (behavior.idleTimeout % 4) {
std::cerr << "Warning: idleTimeout should be a multiple of 4!" << std::endl;
}
/* Every route has its own websocket context with its own behavior and user data type */
auto *webSocketContext = WebSocketContext<SSL, true>::create(Loop::get(), (us_socket_context_t *) httpContext);
auto *webSocketContext = WebSocketContext<SSL, true, UserData>::create(Loop::get(), (us_socket_context_t *) httpContext);
/* Add all other WebSocketContextData to this new WebSocketContextData */
for (WebSocketContext<SSL, true, int> *adjacentWebSocketContext : webSocketContexts) {
webSocketContext->getExt()->adjacentWebSocketContextDatas.push_back(adjacentWebSocketContext->getExt());
}
/* Add this WebSocketContextData to all other WebSocketContextData */
for (WebSocketContext<SSL, true, int> *adjacentWebSocketContext : webSocketContexts) {
adjacentWebSocketContext->getExt()->adjacentWebSocketContextDatas.push_back((WebSocketContextData<SSL, int> *) webSocketContext->getExt());
}
/* We need to clear this later on */
webSocketContexts.push_back(webSocketContext);
webSocketContexts.push_back((WebSocketContext<SSL, true, int> *) webSocketContext);
/* Quick fix to disable any compression if set */
#ifdef UWS_NO_ZLIB
behavior.compression = uWS::DISABLED;
behavior.compression = DISABLED;
#endif
/* If we are the first one to use compression, initialize it */
@ -130,14 +229,15 @@ public:
if (!loopData->zlibContext) {
loopData->zlibContext = new ZlibContext;
loopData->inflationStream = new InflationStream;
loopData->deflationStream = new DeflationStream;
loopData->deflationStream = new DeflationStream(CompressOptions::DEDICATED_COMPRESSOR);
}
}
/* Copy all handlers */
webSocketContext->getExt()->openHandler = std::move(behavior.open);
webSocketContext->getExt()->messageHandler = std::move(behavior.message);
webSocketContext->getExt()->drainHandler = std::move(behavior.drain);
webSocketContext->getExt()->closeHandler = std::move([closeHandler = std::move(behavior.close)](WebSocket<SSL, true> *ws, int code, std::string_view message) mutable {
webSocketContext->getExt()->closeHandler = std::move([closeHandler = std::move(behavior.close)](WebSocket<SSL, true, UserData> *ws, int code, std::string_view message) mutable {
if (closeHandler) {
closeHandler(ws, code, message);
}
@ -150,99 +250,30 @@ public:
/* Copy settings */
webSocketContext->getExt()->maxPayloadLength = behavior.maxPayloadLength;
webSocketContext->getExt()->idleTimeout = behavior.idleTimeout;
webSocketContext->getExt()->maxBackpressure = behavior.maxBackpressure;
webSocketContext->getExt()->closeOnBackpressureLimit = behavior.closeOnBackpressureLimit;
webSocketContext->getExt()->resetIdleTimeoutOnSend = behavior.resetIdleTimeoutOnSend;
webSocketContext->getExt()->sendPingsAutomatically = behavior.sendPingsAutomatically;
webSocketContext->getExt()->compression = behavior.compression;
httpContext->onHttp("get", pattern, [webSocketContext, httpContext = this->httpContext, behavior = std::move(behavior)](auto *res, auto *req) mutable {
/* Calculate idleTimeoutCompnents */
webSocketContext->getExt()->calculateIdleTimeoutCompnents(behavior.idleTimeout);
httpContext->onHttp("get", pattern, [webSocketContext, behavior = std::move(behavior)](auto *res, auto *req) mutable {
/* If we have this header set, it's a websocket */
std::string_view secWebSocketKey = req->getHeader("sec-websocket-key");
if (secWebSocketKey.length() == 24) {
/* Note: OpenSSL can be used here to speed this up somewhat */
char secWebSocketAccept[29] = {};
WebSocketHandshake::generate(secWebSocketKey.data(), secWebSocketAccept);
res->writeStatus("101 Switching Protocols")
->writeHeader("Upgrade", "websocket")
->writeHeader("Connection", "Upgrade")
->writeHeader("Sec-WebSocket-Accept", secWebSocketAccept);
/* Emit upgrade handler */
if (behavior.upgrade) {
behavior.upgrade(res, req, (struct us_socket_context_t *) webSocketContext);
} else {
/* Default handler upgrades to WebSocket */
std::string_view secWebSocketProtocol = req->getHeader("sec-websocket-protocol");
std::string_view secWebSocketExtensions = req->getHeader("sec-websocket-extensions");
/* Select first subprotocol if present */
std::string_view secWebSocketProtocol = req->getHeader("sec-websocket-protocol");
if (secWebSocketProtocol.length()) {
res->writeHeader("Sec-WebSocket-Protocol", secWebSocketProtocol.substr(0, secWebSocketProtocol.find(',')));
}
/* Negotiate compression */
bool perMessageDeflate = false;
bool slidingDeflateWindow = false;
if (behavior.compression != DISABLED) {
std::string_view extensions = req->getHeader("sec-websocket-extensions");
if (extensions.length()) {
/* We never support client context takeover (the client cannot compress with a sliding window). */
int wantedOptions = PERMESSAGE_DEFLATE | CLIENT_NO_CONTEXT_TAKEOVER;
/* Shared compressor is the default */
if (behavior.compression == SHARED_COMPRESSOR) {
/* Disable per-socket compressor */
wantedOptions |= SERVER_NO_CONTEXT_TAKEOVER;
}
/* isServer = true */
ExtensionsNegotiator<true> extensionsNegotiator(wantedOptions);
extensionsNegotiator.readOffer(extensions);
/* Todo: remove these mid string copies */
std::string offer = extensionsNegotiator.generateOffer();
if (offer.length()) {
res->writeHeader("Sec-WebSocket-Extensions", offer);
}
/* Did we negotiate permessage-deflate? */
if (extensionsNegotiator.getNegotiatedOptions() & PERMESSAGE_DEFLATE) {
perMessageDeflate = true;
}
/* Is the server allowed to compress with a sliding window? */
if (!(extensionsNegotiator.getNegotiatedOptions() & SERVER_NO_CONTEXT_TAKEOVER)) {
slidingDeflateWindow = true;
}
}
}
/* This will add our mark */
res->upgrade();
/* Move any backpressure */
std::string backpressure(std::move(((AsyncSocketData<SSL> *) res->getHttpResponseData())->buffer));
/* Keep any fallback buffer alive until we returned from open event, keeping req valid */
std::string fallback(std::move(res->getHttpResponseData()->salvageFallbackBuffer()));
/* Destroy HttpResponseData */
res->getHttpResponseData()->~HttpResponseData();
/* Adopting a socket invalidates it, do not rely on it directly to carry any data */
WebSocket<SSL, true> *webSocket = (WebSocket<SSL, true> *) us_socket_context_adopt_socket(SSL,
(us_socket_context_t *) webSocketContext, (us_socket_t *) res, sizeof(WebSocketData) + sizeof(UserData));
/* Update corked socket in case we got a new one (assuming we always are corked in handlers). */
webSocket->AsyncSocket<SSL>::cork();
/* Initialize websocket with any moved backpressure intact */
httpContext->upgradeToWebSocket(
webSocket->init(perMessageDeflate, slidingDeflateWindow, std::move(backpressure))
);
/* Arm idleTimeout */
us_socket_timeout(SSL, (us_socket_t *) webSocket, behavior.idleTimeout);
/* Default construct the UserData right before calling open handler */
new (webSocket->getUserData()) UserData;
/* Emit open event and start the timeout */
if (behavior.open) {
behavior.open(webSocket, req);
res->template upgrade<UserData>({}, secWebSocketKey, secWebSocketProtocol, secWebSocketExtensions, (struct us_socket_context_t *) webSocketContext);
}
/* We are going to get uncorked by the Http get return */
@ -257,84 +288,104 @@ public:
return std::move(*this);
}
TemplatedApp &&get(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
httpContext->onHttp("get", pattern, std::move(handler));
TemplatedApp &&get(std::string pattern, MoveOnlyFunction<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
if (httpContext) {
httpContext->onHttp("get", pattern, std::move(handler));
}
return std::move(*this);
}
TemplatedApp &&post(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
httpContext->onHttp("post", pattern, std::move(handler));
TemplatedApp &&post(std::string pattern, MoveOnlyFunction<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
if (httpContext) {
httpContext->onHttp("post", pattern, std::move(handler));
}
return std::move(*this);
}
TemplatedApp &&options(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
httpContext->onHttp("options", pattern, std::move(handler));
TemplatedApp &&options(std::string pattern, MoveOnlyFunction<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
if (httpContext) {
httpContext->onHttp("options", pattern, std::move(handler));
}
return std::move(*this);
}
TemplatedApp &&del(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
httpContext->onHttp("delete", pattern, std::move(handler));
TemplatedApp &&del(std::string pattern, MoveOnlyFunction<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
if (httpContext) {
httpContext->onHttp("delete", pattern, std::move(handler));
}
return std::move(*this);
}
TemplatedApp &&patch(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
httpContext->onHttp("patch", pattern, std::move(handler));
TemplatedApp &&patch(std::string pattern, MoveOnlyFunction<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
if (httpContext) {
httpContext->onHttp("patch", pattern, std::move(handler));
}
return std::move(*this);
}
TemplatedApp &&put(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
httpContext->onHttp("put", pattern, std::move(handler));
TemplatedApp &&put(std::string pattern, MoveOnlyFunction<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
if (httpContext) {
httpContext->onHttp("put", pattern, std::move(handler));
}
return std::move(*this);
}
TemplatedApp &&head(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
httpContext->onHttp("head", pattern, std::move(handler));
TemplatedApp &&head(std::string pattern, MoveOnlyFunction<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
if (httpContext) {
httpContext->onHttp("head", pattern, std::move(handler));
}
return std::move(*this);
}
TemplatedApp &&connect(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
httpContext->onHttp("connect", pattern, std::move(handler));
TemplatedApp &&connect(std::string pattern, MoveOnlyFunction<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
if (httpContext) {
httpContext->onHttp("connect", pattern, std::move(handler));
}
return std::move(*this);
}
TemplatedApp &&trace(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
httpContext->onHttp("trace", pattern, std::move(handler));
TemplatedApp &&trace(std::string pattern, MoveOnlyFunction<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
if (httpContext) {
httpContext->onHttp("trace", pattern, std::move(handler));
}
return std::move(*this);
}
/* This one catches any method */
TemplatedApp &&any(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
httpContext->onHttp("*", pattern, std::move(handler));
TemplatedApp &&any(std::string pattern, MoveOnlyFunction<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
if (httpContext) {
httpContext->onHttp("*", pattern, std::move(handler));
}
return std::move(*this);
}
/* Host, port, callback */
TemplatedApp &&listen(std::string host, int port, fu2::unique_function<void(us_listen_socket_t *)> &&handler) {
TemplatedApp &&listen(std::string host, int port, MoveOnlyFunction<void(us_listen_socket_t *)> &&handler) {
if (!host.length()) {
return listen(port, std::move(handler));
}
handler(httpContext->listen(host.c_str(), port, 0));
handler(httpContext ? httpContext->listen(host.c_str(), port, 0) : nullptr);
return std::move(*this);
}
/* Host, port, options, callback */
TemplatedApp &&listen(std::string host, int port, int options, fu2::unique_function<void(us_listen_socket_t *)> &&handler) {
TemplatedApp &&listen(std::string host, int port, int options, MoveOnlyFunction<void(us_listen_socket_t *)> &&handler) {
if (!host.length()) {
return listen(port, options, std::move(handler));
}
handler(httpContext->listen(host.c_str(), port, options));
handler(httpContext ? httpContext->listen(host.c_str(), port, options) : nullptr);
return std::move(*this);
}
/* Port, callback */
TemplatedApp &&listen(int port, fu2::unique_function<void(us_listen_socket_t *)> &&handler) {
handler(httpContext->listen(nullptr, port, 0));
TemplatedApp &&listen(int port, MoveOnlyFunction<void(us_listen_socket_t *)> &&handler) {
handler(httpContext ? httpContext->listen(nullptr, port, 0) : nullptr);
return std::move(*this);
}
/* Port, options, callback */
TemplatedApp &&listen(int port, int options, fu2::unique_function<void(us_listen_socket_t *)> &&handler) {
handler(httpContext->listen(nullptr, port, options));
TemplatedApp &&listen(int port, int options, MoveOnlyFunction<void(us_listen_socket_t *)> &&handler) {
handler(httpContext ? httpContext->listen(nullptr, port, options) : nullptr);
return std::move(*this);
}

View File

@ -1,5 +1,5 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Authored by Alex Hultman, 2018-2020.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
@ -20,21 +20,30 @@
/* This class implements async socket memory management strategies */
/* NOTE: Many unsigned/signed conversion warnings could be solved by moving from int length
* to unsigned length for everything to/from uSockets - this would however remove the opportunity
* to signal error with -1 (which is how the entire UNIX syscalling is built). */
#include "LoopData.h"
#include "AsyncSocketData.h"
namespace uWS {
template <bool, bool> struct WebSocketContext;
template <bool, bool, typename> struct WebSocketContext;
template <bool SSL>
struct AsyncSocket {
template <bool> friend struct HttpContext;
template <bool, bool> friend struct WebSocketContext;
template <bool> friend struct WebSocketContextData;
template <bool, bool, typename> friend struct WebSocketContext;
template <bool, typename> friend struct WebSocketContextData;
friend struct TopicTree;
protected:
/* Returns SSL pointer or FD as pointer */
void *getNativeHandle() {
return us_socket_get_native_handle(SSL, (us_socket_t *) this);
}
/* Get loop data for socket */
LoopData *getLoopData() {
return (LoopData *) us_loop_ext(us_socket_context_loop(SSL, us_socket_context(SSL, (us_socket_t *) this)));
@ -57,7 +66,7 @@ protected:
/* Immediately close socket */
us_socket_t *close() {
return us_socket_close(SSL, (us_socket_t *) this);
return us_socket_close(SSL, (us_socket_t *) this, 0, nullptr);
}
/* Cork this socket. Only one socket may ever be corked per-loop at any given time */
@ -82,7 +91,7 @@ protected:
LoopData *loopData = getLoopData();
if (loopData->corkedSocket == this && loopData->corkOffset + size < LoopData::CORK_BUFFER_SIZE) {
char *sendBuffer = loopData->corkBuffer + loopData->corkOffset;
loopData->corkOffset += (int) size;
loopData->corkOffset += (unsigned int) size;
return {sendBuffer, false};
} else {
/* Slow path for now, we want to always be corked if possible */
@ -91,8 +100,30 @@ protected:
}
/* Returns the user space backpressure. */
int getBufferedAmount() {
return (int) getAsyncSocketData()->buffer.size();
unsigned int getBufferedAmount() {
return (unsigned int) getAsyncSocketData()->buffer.size();
}
/* Returns the text representation of an IPv4 or IPv6 address */
std::string_view addressAsText(std::string_view binary) {
static thread_local char buf[64];
int ipLength = 0;
if (!binary.length()) {
return {};
}
unsigned char *b = (unsigned char *) binary.data();
if (binary.length() == 4) {
ipLength = sprintf(buf, "%u.%u.%u.%u", b[0], b[1], b[2], b[3]);
} else {
ipLength = sprintf(buf, "%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x",
b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11],
b[12], b[13], b[14], b[15]);
}
return {buf, (unsigned int) ipLength};
}
/* Returns the remote IP address or empty string on failure */
@ -100,7 +131,12 @@ protected:
static thread_local char buf[16];
int ipLength = 16;
us_socket_remote_address(SSL, (us_socket_t *) this, buf, &ipLength);
return std::string_view(buf, ipLength);
return std::string_view(buf, (unsigned int) ipLength);
}
/* Returns the text representation of IP */
std::string_view getRemoteAddressAsText() {
return addressAsText(getRemoteAddress());
}
/* Write in three levels of prioritization: cork-buffer, syscall, socket-buffer. Always drain if possible.
@ -124,14 +160,14 @@ protected:
if ((unsigned int) written < asyncSocketData->buffer.length()) {
/* Update buffering (todo: we can do better here if we keep track of what happens to this guy later on) */
asyncSocketData->buffer = asyncSocketData->buffer.substr(written);
asyncSocketData->buffer = asyncSocketData->buffer.substr((size_t) written);
if (optionally) {
/* Thankfully we can exit early here */
return {0, true};
} else {
/* This path is horrible and points towards erroneous usage */
asyncSocketData->buffer.append(src, length);
asyncSocketData->buffer.append(src, (unsigned int) length);
return {length, true};
}
@ -144,21 +180,21 @@ protected:
if (length) {
if (loopData->corkedSocket == this) {
/* We are corked */
if (LoopData::CORK_BUFFER_SIZE - loopData->corkOffset >= length) {
if (LoopData::CORK_BUFFER_SIZE - loopData->corkOffset >= (unsigned int) length) {
/* If the entire chunk fits in cork buffer */
memcpy(loopData->corkBuffer + loopData->corkOffset, src, length);
loopData->corkOffset += length;
memcpy(loopData->corkBuffer + loopData->corkOffset, src, (unsigned int) length);
loopData->corkOffset += (unsigned int) length;
/* Fall through to default return */
} else {
/* Strategy differences between SSL and non-SSL regarding syscall minimizing */
if constexpr (SSL) {
/* Cork up as much as we can */
int stripped = LoopData::CORK_BUFFER_SIZE - loopData->corkOffset;
unsigned int stripped = LoopData::CORK_BUFFER_SIZE - loopData->corkOffset;
memcpy(loopData->corkBuffer + loopData->corkOffset, src, stripped);
loopData->corkOffset = LoopData::CORK_BUFFER_SIZE;
auto [written, failed] = uncork(src + stripped, length - stripped, optionally);
return {written + stripped, failed};
auto [written, failed] = uncork(src + stripped, length - (int) stripped, optionally);
return {written + (int) stripped, failed};
}
/* For non-SSL we take the penalty of two syscalls */
@ -178,11 +214,11 @@ protected:
/* Fall back to worst possible case (should be very rare for HTTP) */
/* At least we can reserve room for next chunk if we know it up front */
if (nextLength) {
asyncSocketData->buffer.reserve(asyncSocketData->buffer.length() + length - written + nextLength);
asyncSocketData->buffer.reserve(asyncSocketData->buffer.length() + (size_t) (length - written + nextLength));
}
/* Buffer this chunk */
asyncSocketData->buffer.append(src + written, length - written);
asyncSocketData->buffer.append(src + written, (size_t) (length - written));
/* Return the failure */
return {length, true};
@ -205,7 +241,7 @@ protected:
if (loopData->corkOffset) {
/* Corked data is already accounted for via its write call */
auto [written, failed] = write(loopData->corkBuffer, loopData->corkOffset, false, length);
auto [written, failed] = write(loopData->corkBuffer, (int) loopData->corkOffset, false, length);
loopData->corkOffset = 0;
if (failed) {

View File

@ -0,0 +1,65 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UWS_BLOOMFILTER_H
#define UWS_BLOOMFILTER_H
/* This filter has a decently low amount of false positives for the
* standard and non-standard common request headers */
#include <string_view>
#include <bitset>
namespace uWS {
struct BloomFilter {
private:
std::bitset<512> filter;
unsigned int hash1(std::string_view key) {
return ((size_t)key[key.length() - 1] - (key.length() << 3)) & 511;
}
unsigned int hash2(std::string_view key) {
return (((size_t)key[0] + (key.length() << 4)) & 511);
}
unsigned int hash3(std::string_view key) {
return ((unsigned int)key[key.length() - 2] - 97 - (key.length() << 5)) & 511;
}
public:
bool mightHave(std::string_view key) {
return filter.test(hash1(key)) && filter.test(hash2(key)) && (key.length() < 2 || filter.test(hash3(key)));
}
void add(std::string_view key) {
filter.set(hash1(key));
filter.set(hash2(key));
if (key.length() >= 2) {
filter.set(hash3(key));
}
}
void reset() {
filter.reset();
}
};
}
#endif // UWS_BLOOMFILTER_H

View File

@ -1,5 +1,5 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Authored by Alex Hultman, 2018-2020.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
@ -24,10 +24,11 @@
#include "HttpContextData.h"
#include "HttpResponseData.h"
#include "AsyncSocket.h"
#include "WebSocketData.h"
#include <string_view>
#include <iostream>
#include "f2/function2.hpp"
#include "MoveOnlyFunction.h"
namespace uWS {
template<bool> struct HttpResponse;
@ -35,6 +36,7 @@ template<bool> struct HttpResponse;
template <bool SSL>
struct HttpContext {
template<bool> friend struct TemplatedApp;
template<bool> friend struct HttpResponse;
private:
HttpContext() = delete;
@ -60,7 +62,7 @@ private:
/* Init the HttpContext by registering libusockets event handlers */
HttpContext<SSL> *init() {
/* Handle socket connections */
us_socket_context_on_open(SSL, getSocketContext(), [](us_socket_t *s, int is_client, char *ip, int ip_length) {
us_socket_context_on_open(SSL, getSocketContext(), [](us_socket_t *s, int /*is_client*/, char */*ip*/, int /*ip_length*/) {
/* Any connected socket should timeout until it has a request */
us_socket_timeout(SSL, s, HTTP_IDLE_TIMEOUT_S);
@ -77,7 +79,7 @@ private:
});
/* Handle socket disconnections */
us_socket_context_on_close(SSL, getSocketContext(), [](us_socket_t *s) {
us_socket_context_on_close(SSL, getSocketContext(), [](us_socket_t *s, int /*code*/, void */*reason*/) {
/* Get socket ext */
HttpResponseData<SSL> *httpResponseData = (HttpResponseData<SSL> *) us_socket_ext(SSL, s);
@ -119,11 +121,19 @@ private:
/* Cork this socket */
((AsyncSocket<SSL> *) s)->cork();
/* Mark that we are inside the parser now */
httpContextData->isParsingHttp = true;
// clients need to know the cursor after http parse, not servers!
// how far did we read then? we need to know to continue with websocket parsing data? or?
void *proxyParser = nullptr;
#ifdef UWS_WITH_PROXY
proxyParser = &httpResponseData->proxyParser;
#endif
/* The return value is entirely up to us to interpret. The HttpParser only care for whether the returned value is DIFFERENT or not from passed user */
void *returnedSocket = httpResponseData->consumePostPadded(data, length, s, [httpContextData](void *s, uWS::HttpRequest *httpRequest) -> void * {
void *returnedSocket = httpResponseData->consumePostPadded(data, (unsigned int) length, s, proxyParser, [httpContextData](void *s, HttpRequest *httpRequest) -> void * {
/* For every request we reset the timeout and hang until user makes action */
/* Warning: if we are in shutdown state, resetting the timer is a security issue! */
us_socket_timeout(SSL, (us_socket_t *) s, 0);
@ -134,18 +144,23 @@ private:
/* Are we not ready for another request yet? Terminate the connection. */
if (httpResponseData->state & HttpResponseData<SSL>::HTTP_RESPONSE_PENDING) {
us_socket_close(SSL, (us_socket_t *) s);
us_socket_close(SSL, (us_socket_t *) s, 0, nullptr);
return nullptr;
}
/* Mark pending request and emit it */
httpResponseData->state = HttpResponseData<SSL>::HTTP_RESPONSE_PENDING;
/* Mark this response as connectionClose if ancient or connection: close */
if (httpRequest->isAncient() || httpRequest->getHeader("connection").length() == 5) {
httpResponseData->state |= HttpResponseData<SSL>::HTTP_CONNECTION_CLOSE;
}
/* Route the method and URL */
httpContextData->router.getUserData() = {(HttpResponse<SSL> *) s, httpRequest};
if (!httpContextData->router.route(httpRequest->getMethod(), httpRequest->getUrl())) {
/* We have to force close this socket as we have no handler for it */
us_socket_close(SSL, (us_socket_t *) s);
us_socket_close(SSL, (us_socket_t *) s, 0, nullptr);
return nullptr;
}
@ -205,7 +220,7 @@ private:
if (us_socket_is_shut_down(SSL, (us_socket_t *) user)) {
return nullptr;
}
/* If we were given the last data chunk, reset data handler to ensure following
* requests on the same socket won't trigger any previously registered behavior */
if (fin) {
@ -215,10 +230,13 @@ private:
return user;
}, [](void *user) {
/* Close any socket on HTTP errors */
us_socket_close(SSL, (us_socket_t *) user);
us_socket_close(SSL, (us_socket_t *) user, 0, nullptr);
return nullptr;
});
/* Mark that we are no longer parsing Http */
httpContextData->isParsingHttp = false;
/* We need to uncork in all cases, except for nullptr (closed socket, or upgraded socket) */
if (returnedSocket != nullptr) {
/* Timeout on uncork failure */
@ -229,6 +247,18 @@ private:
((AsyncSocket<SSL> *) s)->timeout(HTTP_IDLE_TIMEOUT_S);
}
/* We need to check if we should close this socket here now */
if (httpResponseData->state & HttpResponseData<SSL>::HTTP_CONNECTION_CLOSE) {
if ((httpResponseData->state & HttpResponseData<SSL>::HTTP_RESPONSE_PENDING) == 0) {
if (((AsyncSocket<SSL> *) s)->getBufferedAmount() == 0) {
((AsyncSocket<SSL> *) s)->shutdown();
/* We need to force close after sending FIN since we want to hinder
* clients from keeping to send their huge data */
((AsyncSocket<SSL> *) s)->close();
}
}
}
return (us_socket_t *) returnedSocket;
}
@ -238,7 +268,16 @@ private:
AsyncSocket<SSL> *asyncSocket = (AsyncSocket<SSL> *) httpContextData->upgradedWebSocket;
/* Uncork here as well (note: what if we failed to uncork and we then pub/sub before we even upgraded?) */
/*auto [written, failed] = */asyncSocket->uncork();
auto [written, failed] = asyncSocket->uncork();
/* If we succeeded in uncorking, check if we have sent WebSocket FIN */
if (!failed) {
WebSocketData *webSocketData = (WebSocketData *) asyncSocket->getAsyncSocketData();
if (webSocketData->isShuttingDown) {
/* In that case, also send TCP FIN (this is similar to what we have in ws drain handler) */
asyncSocket->shutdown();
}
}
/* Reset upgradedWebSocket before we return */
httpContextData->upgradedWebSocket = nullptr;
@ -283,6 +322,18 @@ private:
/* Drain any socket buffer, this might empty our backpressure and thus finish the request */
/*auto [written, failed] = */asyncSocket->write(nullptr, 0, true, 0);
/* Should we close this connection after a response - and is this response really done? */
if (httpResponseData->state & HttpResponseData<SSL>::HTTP_CONNECTION_CLOSE) {
if ((httpResponseData->state & HttpResponseData<SSL>::HTTP_RESPONSE_PENDING) == 0) {
if (asyncSocket->getBufferedAmount() == 0) {
asyncSocket->shutdown();
/* We need to force close after sending FIN since we want to hinder
* clients from keeping to send their huge data */
asyncSocket->close();
}
}
}
/* Expect another writable event, or another request within the timeout */
asyncSocket->timeout(HTTP_IDLE_TIMEOUT_S);
@ -310,13 +361,6 @@ private:
return this;
}
/* Used by App in its WebSocket handler */
void upgradeToWebSocket(void *newSocket) {
HttpContextData<SSL> *httpContextData = getSocketContextData();
httpContextData->upgradedWebSocket = newSocket;
}
public:
/* Construct a new HttpContext using specified loop */
static HttpContext *create(Loop *loop, us_socket_context_options_t options = {}) {
@ -343,12 +387,12 @@ public:
us_socket_context_free(SSL, getSocketContext());
}
void filter(fu2::unique_function<void(HttpResponse<SSL> *, int)> &&filterHandler) {
void filter(MoveOnlyFunction<void(HttpResponse<SSL> *, int)> &&filterHandler) {
getSocketContextData()->filterHandlers.emplace_back(std::move(filterHandler));
}
/* Register an HTTP route handler acording to URL pattern */
void onHttp(std::string method, std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler, bool upgrade = false) {
void onHttp(std::string method, std::string pattern, MoveOnlyFunction<void(HttpResponse<SSL> *, HttpRequest *)> &&handler, bool upgrade = false) {
HttpContextData<SSL> *httpContextData = getSocketContextData();
/* Todo: This is ugly, fix */
@ -363,6 +407,13 @@ public:
auto user = r->getUserData();
user.httpRequest->setYield(false);
user.httpRequest->setParameters(r->getParameters());
/* Middleware? Automatically respond to expectations */
std::string_view expect = user.httpRequest->getHeader("expect");
if (expect.length() && expect == "100-continue") {
user.httpResponse->writeContinue();
}
handler(user.httpResponse, user.httpRequest);
/* If any handler yielded, the router will keep looking for a suitable handler. */

View File

@ -1,5 +1,5 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Authored by Alex Hultman, 2018-2020.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
@ -21,7 +21,7 @@
#include "HttpRouter.h"
#include <vector>
#include "f2/function2.hpp"
#include "MoveOnlyFunction.h"
namespace uWS {
template<bool> struct HttpResponse;
@ -31,8 +31,11 @@ template <bool SSL>
struct alignas(16) HttpContextData {
template <bool> friend struct HttpContext;
template <bool> friend struct HttpResponse;
template <bool> friend struct TemplatedApp;
private:
std::vector<fu2::unique_function<void(HttpResponse<SSL> *, int)>> filterHandlers;
std::vector<MoveOnlyFunction<void(HttpResponse<SSL> *, int)>> filterHandlers;
MoveOnlyFunction<void(const char *hostname)> missingServerNameHandler;
struct RouterData {
HttpResponse<SSL> *httpResponse;
@ -41,6 +44,7 @@ private:
HttpRouter<RouterData> router;
void *upgradedWebSocket = nullptr;
bool isParsingHttp = false;
};
}

View File

@ -1,5 +1,5 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Authored by Alex Hultman, 2018-2020.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
@ -25,12 +25,16 @@
#include <string>
#include <cstring>
#include <algorithm>
#include "f2/function2.hpp"
#include "MoveOnlyFunction.h"
#include "BloomFilter.h"
#include "ProxyParser.h"
#include "QueryParser.h"
namespace uWS {
/* We require at least this much post padding */
static const int MINIMUM_HTTP_POST_PADDING = 32;
static const unsigned int MINIMUM_HTTP_POST_PADDING = 32;
struct HttpRequest {
@ -41,12 +45,17 @@ private:
struct Header {
std::string_view key, value;
} headers[MAX_HEADERS];
int querySeparator;
bool ancientHttp;
unsigned int querySeparator;
bool didYield;
BloomFilter bf;
std::pair<int, std::string_view *> currentParameters;
public:
bool isAncient() {
return ancientHttp;
}
bool getYield() {
return didYield;
}
@ -87,9 +96,11 @@ public:
}
std::string_view getHeader(std::string_view lowerCasedHeader) {
for (Header *h = headers; (++h)->key.length(); ) {
if (h->key.length() == lowerCasedHeader.length() && !strncmp(h->key.data(), lowerCasedHeader.data(), lowerCasedHeader.length())) {
return h->value;
if (bf.mightHave(lowerCasedHeader)) {
for (Header *h = headers; (++h)->key.length(); ) {
if (h->key.length() == lowerCasedHeader.length() && !strncmp(h->key.data(), lowerCasedHeader.data(), lowerCasedHeader.length())) {
return h->value;
}
}
}
return std::string_view(nullptr, 0);
@ -103,8 +114,9 @@ public:
return std::string_view(headers->key.data(), headers->key.length());
}
/* Returns the raw querystring as a whole, still encoded */
std::string_view getQuery() {
if (querySeparator < (int) headers->value.length()) {
if (querySeparator < headers->value.length()) {
/* Strip the initial ? */
return std::string_view(headers->value.data() + querySeparator + 1, headers->value.length() - querySeparator - 1);
} else {
@ -112,11 +124,19 @@ public:
}
}
/* Finds and decodes the URI component. */
std::string_view getQuery(std::string_view key) {
/* Raw querystring including initial '?' sign */
std::string_view queryString = std::string_view(headers->value.data() + querySeparator, headers->value.length() - querySeparator);
return getDecodedQueryValue(key, queryString);
}
void setParameters(std::pair<int, std::string_view *> parameters) {
currentParameters = parameters;
}
std::string_view getParameter(unsigned int index) {
std::string_view getParameter(unsigned short index) {
if (currentParameters.first < (int) index) {
return {};
} else {
@ -136,15 +156,40 @@ private:
static unsigned int toUnsignedInteger(std::string_view str) {
unsigned int unsignedIntegerValue = 0;
for (unsigned char c : str) {
unsignedIntegerValue = unsignedIntegerValue * 10 + (c - '0');
for (char c : str) {
unsignedIntegerValue = unsignedIntegerValue * 10u + ((unsigned int) c - (unsigned int) '0');
}
return unsignedIntegerValue;
}
static unsigned int getHeaders(char *postPaddedBuffer, char *end, struct HttpRequest::Header *headers) {
static unsigned int getHeaders(char *postPaddedBuffer, char *end, struct HttpRequest::Header *headers, void *reserved) {
char *preliminaryKey, *preliminaryValue, *start = postPaddedBuffer;
#ifdef UWS_WITH_PROXY
/* ProxyParser is passed as reserved parameter */
ProxyParser *pp = (ProxyParser *) reserved;
/* Parse PROXY protocol */
auto [done, offset] = pp->parse({start, (size_t) (end - postPaddedBuffer)});
if (!done) {
/* We do not reset the ProxyParser (on filure) since it is tied to this
* connection, which is really only supposed to ever get one PROXY frame
* anyways. We do however allow multiple PROXY frames to be sent (overwrites former). */
return 0;
} else {
/* We have consumed this data so skip it */
start += offset;
}
#else
/* This one is unused */
(void) reserved;
#endif
/* It is critical for fallback buffering logic that we only return with success
* if we managed to parse a complete HTTP request (minus data). Returning success
* for PROXY means we can end up succeeding, yet leaving bytes in the fallback buffer
* which is then removed, and our counters to flip due to overflow and we end up with a crash */
for (unsigned int i = 0; i < HttpRequest::MAX_HEADERS; i++) {
for (preliminaryKey = postPaddedBuffer; (*postPaddedBuffer != ':') & (*postPaddedBuffer > 32); *(postPaddedBuffer++) |= 32);
if (*postPaddedBuffer == '\r') {
@ -158,7 +203,7 @@ private:
headers->key = std::string_view(preliminaryKey, (size_t) (postPaddedBuffer - preliminaryKey));
for (postPaddedBuffer++; (*postPaddedBuffer == ':' || *postPaddedBuffer < 33) && *postPaddedBuffer != '\r'; postPaddedBuffer++);
preliminaryValue = postPaddedBuffer;
postPaddedBuffer = (char *) memchr(postPaddedBuffer, '\r', end - postPaddedBuffer);
postPaddedBuffer = (char *) memchr(postPaddedBuffer, '\r', (size_t) (end - postPaddedBuffer));
if (postPaddedBuffer && postPaddedBuffer[1] == '\n') {
headers->value = std::string_view(preliminaryValue, (size_t) (postPaddedBuffer - preliminaryValue));
postPaddedBuffer += 2;
@ -173,20 +218,34 @@ private:
// the only caller of getHeaders
template <int CONSUME_MINIMALLY>
std::pair<int, void *> fenceAndConsumePostPadded(char *data, int length, void *user, HttpRequest *req, fu2::unique_function<void *(void *, HttpRequest *)> &requestHandler, fu2::unique_function<void *(void *, std::string_view, bool)> &dataHandler) {
int consumedTotal = 0;
std::pair<unsigned int, void *> fenceAndConsumePostPadded(char *data, unsigned int length, void *user, void *reserved, HttpRequest *req, MoveOnlyFunction<void *(void *, HttpRequest *)> &requestHandler, MoveOnlyFunction<void *(void *, std::string_view, bool)> &dataHandler) {
/* How much data we CONSUMED (to throw away) */
unsigned int consumedTotal = 0;
/* Fence one byte past end of our buffer (buffer has post padded margins) */
data[length] = '\r';
for (int consumed; length && (consumed = getHeaders(data, data + length, req->headers)); ) {
for (unsigned int consumed; length && (consumed = getHeaders(data, data + length, req->headers, reserved)); ) {
data += consumed;
length -= consumed;
consumedTotal += consumed;
req->headers->value = std::string_view(req->headers->value.data(), std::max<int>(0, (int) req->headers->value.length() - 9));
/* Store HTTP version (ancient 1.0 or 1.1) */
req->ancientHttp = req->headers->value.length() && (req->headers->value[req->headers->value.length() - 1] == '0');
/* Strip away tail of first "header value" aka URL */
req->headers->value = std::string_view(req->headers->value.data(), (size_t) std::max<int>(0, (int) req->headers->value.length() - 9));
/* Add all headers to bloom filter */
req->bf.reset();
for (HttpRequest::Header *h = req->headers; (++h)->key.length(); ) {
req->bf.add(h->key);
}
/* Parse query */
const char *querySeparatorPtr = (const char *) memchr(req->headers->value.data(), '?', req->headers->value.length());
req->querySeparator = (int) ((querySeparatorPtr ? querySeparatorPtr : req->headers->value.data() + req->headers->value.length()) - req->headers->value.data());
req->querySeparator = (unsigned int) ((querySeparatorPtr ? querySeparatorPtr : req->headers->value.data() + req->headers->value.length()) - req->headers->value.data());
/* If returned socket is not what we put in we need
* to break here as we either have upgraded to
@ -225,22 +284,18 @@ private:
}
public:
void *consumePostPadded(char *data, unsigned int length, void *user, void *reserved, MoveOnlyFunction<void *(void *, HttpRequest *)> &&requestHandler, MoveOnlyFunction<void *(void *, std::string_view, bool)> &&dataHandler, MoveOnlyFunction<void *(void *)> &&errorHandler) {
/* We do this to prolong the validity of parsed headers by keeping only the fallback buffer alive */
std::string &&salvageFallbackBuffer() {
return std::move(fallback);
}
void *consumePostPadded(char *data, int length, void *user, fu2::unique_function<void *(void *, HttpRequest *)> &&requestHandler, fu2::unique_function<void *(void *, std::string_view, bool)> &&dataHandler, fu2::unique_function<void *(void *)> &&errorHandler) {
/* This resets BloomFilter by construction, but later we also reset it again.
* Optimize this to skip resetting twice (req could be made global) */
HttpRequest req;
if (remainingStreamingBytes) {
// this is exactly the same as below!
// todo: refactor this
if (remainingStreamingBytes >= (unsigned int) length) {
void *returnedUser = dataHandler(user, std::string_view(data, length), remainingStreamingBytes == (unsigned int) length);
if (remainingStreamingBytes >= length) {
void *returnedUser = dataHandler(user, std::string_view(data, length), remainingStreamingBytes == length);
remainingStreamingBytes -= length;
return returnedUser;
} else {
@ -257,24 +312,26 @@ public:
}
} else if (fallback.length()) {
int had = (int) fallback.length();
unsigned int had = (unsigned int) fallback.length();
int maxCopyDistance = (int) std::min(MAX_FALLBACK_SIZE - fallback.length(), (size_t) length);
size_t maxCopyDistance = std::min(MAX_FALLBACK_SIZE - fallback.length(), (size_t) length);
/* We don't want fallback to be short string optimized, since we want to move it */
fallback.reserve(fallback.length() + maxCopyDistance + std::max<int>(MINIMUM_HTTP_POST_PADDING, sizeof(std::string)));
fallback.reserve(fallback.length() + maxCopyDistance + std::max<unsigned int>(MINIMUM_HTTP_POST_PADDING, sizeof(std::string)));
fallback.append(data, maxCopyDistance);
// break here on break
std::pair<int, void *> consumed = fenceAndConsumePostPadded<true>(fallback.data(), (int) fallback.length(), user, &req, requestHandler, dataHandler);
std::pair<unsigned int, void *> consumed = fenceAndConsumePostPadded<true>(fallback.data(), (unsigned int) fallback.length(), user, reserved, &req, requestHandler, dataHandler);
if (consumed.second != user) {
return consumed.second;
}
if (consumed.first) {
/* This logic assumes that we consumed everything in fallback buffer.
* This is critically important, as we will get an integer overflow in case
* of "had" being larger than what we consumed, and that we would drop data */
fallback.clear();
data += consumed.first - had;
length -= consumed.first - had;
@ -308,7 +365,7 @@ public:
}
}
std::pair<int, void *> consumed = fenceAndConsumePostPadded<false>(data, length, user, &req, requestHandler, dataHandler);
std::pair<unsigned int, void *> consumed = fenceAndConsumePostPadded<false>(data, length, user, reserved, &req, requestHandler, dataHandler);
if (consumed.second != user) {
return consumed.second;
}
@ -317,7 +374,7 @@ public:
length -= consumed.first;
if (length) {
if ((unsigned int) length < MAX_FALLBACK_SIZE) {
if (length < MAX_FALLBACK_SIZE) {
fallback.append(data, length);
} else {
return errorHandler(user);

View File

@ -1,5 +1,5 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Authored by Alex Hultman, 2018-2020.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
@ -25,7 +25,12 @@
#include "HttpContextData.h"
#include "Utilities.h"
#include "f2/function2.hpp"
#include "WebSocketExtensions.h"
#include "WebSocketHandshake.h"
#include "WebSocket.h"
#include "WebSocketContextData.h"
#include "MoveOnlyFunction.h"
/* todo: tryWrite is missing currently, only send smaller segments with write */
@ -56,10 +61,10 @@ private:
Super::write(buf, length);
}
/* Write an unsigned 32-bit integer */
void writeUnsigned(unsigned int value) {
char buf[10];
int length = utils::u32toa(value, buf);
/* Write an unsigned 64-bit integer */
void writeUnsigned64(uint64_t value) {
char buf[20];
int length = utils::u64toa(value, buf);
/* For now we do this copy */
Super::write(buf, length);
@ -77,23 +82,43 @@ private:
/* Called only once per request */
void writeMark() {
/* You can disable this altogether */
#ifndef UWS_HTTPRESPONSE_NO_WRITEMARK
writeHeader("uWebSockets", "v0.17");
if (!Super::getLoopData()->noMark) {
/* We only expose major version */
writeHeader("uWebSockets", "19");
}
#endif
}
/* Returns true on success, indicating that it might be feasible to write more data.
* Will start timeout if stream reaches totalSize or write failure. */
bool internalEnd(std::string_view data, int totalSize, bool optional, bool allowContentLength = true) {
bool internalEnd(std::string_view data, size_t totalSize, bool optional, bool allowContentLength = true, bool closeConnection = false) {
/* Write status if not already done */
writeStatus(HTTP_200_OK);
/* If no total size given then assume this chunk is everything */
if (!totalSize) {
totalSize = (int) data.length();
totalSize = data.length();
}
HttpResponseData<SSL> *httpResponseData = getHttpResponseData();
/* In some cases, such as when refusing huge data we want to close the connection when drained */
if (closeConnection) {
/* HTTP 1.1 must send this back unless the client already sent it to us.
* It is a connection close when either of the two parties say so but the
* one party must tell the other one so.
*
* This check also serves to limit writing the header only once. */
if ((httpResponseData->state & HttpResponseData<SSL>::HTTP_CONNECTION_CLOSE) == 0) {
writeHeader("Connection", "close");
}
httpResponseData->state |= HttpResponseData<SSL>::HTTP_CONNECTION_CLOSE;
}
if (httpResponseData->state & HttpResponseData<SSL>::HTTP_WRITE_CALLED) {
/* We do not have tryWrite-like functionalities, so ignore optional in this path */
@ -126,7 +151,7 @@ private:
if (allowContentLength) {
/* Even zero is a valid content-length */
Super::write("Content-Length: ", 16);
writeUnsigned(totalSize);
writeUnsigned64(totalSize);
Super::write("\r\n\r\n", 4);
} else {
Super::write("\r\n", 2);
@ -140,11 +165,20 @@ private:
* if it failed to drain any prior failed header writes */
/* Write as much as possible without causing backpressure */
auto [written, failed] = Super::write(data.data(), (int) data.length(), optional);
size_t written = 0;
bool failed = false;
while (written < data.length() && !failed) {
/* uSockets only deals with int sizes, so pass chunks of max signed int size */
auto writtenFailed = Super::write(data.data() + written, (int) std::min<size_t>(data.length() - written, INT_MAX), optional);
written += (size_t) writtenFailed.first;
failed = writtenFailed.second;
}
httpResponseData->offset += written;
/* Success is when we wrote the entire thing without any failures */
bool success = (unsigned int) written == data.length() && !failed;
bool success = written == data.length() && !failed;
/* If we are now at the end, start a timeout. Also start a timeout if we failed. */
if (!success || httpResponseData->offset == totalSize) {
@ -160,20 +194,140 @@ private:
}
}
/* This call is identical to end, but will never write content-length and is thus suitable for upgrades */
void upgrade() {
internalEnd({nullptr, 0}, 0, false, false);
public:
/* If we have proxy support; returns the proxed source address as reported by the proxy. */
#ifdef UWS_WITH_PROXY
std::string_view getProxiedRemoteAddress() {
return getHttpResponseData()->proxyParser.getSourceAddress();
}
std::string_view getProxiedRemoteAddressAsText() {
return Super::addressAsText(getProxiedRemoteAddress());
}
#endif
/* Manually upgrade to WebSocket. Typically called in upgrade handler. Immediately calls open handler.
* NOTE: Will invalidate 'this' as socket might change location in memory. Throw away aftert use. */
template <typename UserData>
void upgrade(UserData &&userData, std::string_view secWebSocketKey, std::string_view secWebSocketProtocol,
std::string_view secWebSocketExtensions,
struct us_socket_context_t *webSocketContext) {
/* Extract needed parameters from WebSocketContextData */
WebSocketContextData<SSL, UserData> *webSocketContextData = (WebSocketContextData<SSL, UserData> *) us_socket_context_ext(SSL, webSocketContext);
/* Note: OpenSSL can be used here to speed this up somewhat */
char secWebSocketAccept[29] = {};
WebSocketHandshake::generate(secWebSocketKey.data(), secWebSocketAccept);
writeStatus("101 Switching Protocols")
->writeHeader("Upgrade", "websocket")
->writeHeader("Connection", "Upgrade")
->writeHeader("Sec-WebSocket-Accept", secWebSocketAccept);
/* Select first subprotocol if present */
if (secWebSocketProtocol.length()) {
writeHeader("Sec-WebSocket-Protocol", secWebSocketProtocol.substr(0, secWebSocketProtocol.find(',')));
}
/* Negotiate compression */
bool perMessageDeflate = false;
CompressOptions compressOptions = CompressOptions::DISABLED;
if (secWebSocketExtensions.length() && webSocketContextData->compression != DISABLED) {
/* We always want shared inflation */
int wantedInflationWindow = 0;
/* Map from selected compressor */
int wantedCompressionWindow = (webSocketContextData->compression & 0xFF00) >> 8;
auto [negCompression, negCompressionWindow, negInflationWindow, negResponse] =
negotiateCompression(true, wantedCompressionWindow, wantedInflationWindow,
secWebSocketExtensions);
if (negCompression) {
perMessageDeflate = true;
/* Map from windowBits to compressor */
if (negCompressionWindow == 0) {
compressOptions = CompressOptions::SHARED_COMPRESSOR;
} else {
compressOptions = (CompressOptions) ((uint32_t) (negCompressionWindow << 8)
| (uint32_t) (negCompressionWindow - 7));
/* If we are dedicated and have the 3kb then correct any 4kb to 3kb,
* (they both share the windowBits = 9) */
if (webSocketContextData->compression == DEDICATED_COMPRESSOR_3KB) {
compressOptions = DEDICATED_COMPRESSOR_3KB;
}
}
writeHeader("Sec-WebSocket-Extensions", negResponse);
}
}
internalEnd({nullptr, 0}, 0, false, false);
/* Grab the httpContext from res */
HttpContext<SSL> *httpContext = (HttpContext<SSL> *) us_socket_context(SSL, (struct us_socket_t *) this);
/* Move any backpressure out of HttpResponse */
std::string backpressure(std::move(((AsyncSocketData<SSL> *) getHttpResponseData())->buffer));
/* Destroy HttpResponseData */
getHttpResponseData()->~HttpResponseData();
/* Before we adopt and potentially change socket, check if we are corked */
bool wasCorked = Super::isCorked();
/* Adopting a socket invalidates it, do not rely on it directly to carry any data */
WebSocket<SSL, true, UserData> *webSocket = (WebSocket<SSL, true, UserData> *) us_socket_context_adopt_socket(SSL,
(us_socket_context_t *) webSocketContext, (us_socket_t *) this, sizeof(WebSocketData) + sizeof(UserData));
/* For whatever reason we were corked, update cork to the new socket */
if (wasCorked) {
webSocket->AsyncSocket<SSL>::cork();
}
/* Initialize websocket with any moved backpressure intact */
webSocket->init(perMessageDeflate, compressOptions, std::move(backpressure));
/* We should only mark this if inside the parser; if upgrading "async" we cannot set this */
HttpContextData<SSL> *httpContextData = httpContext->getSocketContextData();
if (httpContextData->isParsingHttp) {
/* We need to tell the Http parser that we changed socket */
httpContextData->upgradedWebSocket = webSocket;
}
/* Arm idleTimeout */
us_socket_timeout(SSL, (us_socket_t *) webSocket, webSocketContextData->idleTimeoutComponents.first);
/* Move construct the UserData right before calling open handler */
new (webSocket->getUserData()) UserData(std::move(userData));
/* Emit open event and start the timeout */
if (webSocketContextData->openHandler) {
webSocketContextData->openHandler(webSocket);
}
}
public:
/* Immediately terminate this Http response */
using Super::close;
/* See AsyncSocket */
using Super::getRemoteAddress;
using Super::getRemoteAddressAsText;
using Super::getNativeHandle;
/* Note: Headers are not checked in regards to timeout.
* We only check when you actively push data or end the request */
/* Write 100 Continue, can be done any amount of times */
HttpResponse *writeContinue() {
Super::write("HTTP/1.1 100 Continue\r\n\r\n", 25);
return this;
}
/* Write the HTTP status */
HttpResponse *writeStatus(std::string_view status) {
HttpResponseData<SSL> *httpResponseData = getHttpResponseData();
@ -204,22 +358,24 @@ public:
}
/* Write an HTTP header with unsigned int value */
HttpResponse *writeHeader(std::string_view key, unsigned int value) {
HttpResponse *writeHeader(std::string_view key, uint64_t value) {
writeStatus(HTTP_200_OK);
Super::write(key.data(), (int) key.length());
Super::write(": ", 2);
writeUnsigned(value);
writeUnsigned64(value);
Super::write("\r\n", 2);
return this;
}
/* End the response with an optional data chunk. Always starts a timeout. */
void end(std::string_view data = {}) {
internalEnd(data, (int) data.length(), false);
void end(std::string_view data = {}, bool closeConnection = false) {
internalEnd(data, data.length(), false, true, closeConnection);
}
/* Try and end the response. Returns [true, true] on success.
* Starts a timeout in some cases. Returns [ok, hasResponded] */
std::pair<bool, bool> tryEnd(std::string_view data, int totalSize = 0) {
std::pair<bool, bool> tryEnd(std::string_view data, size_t totalSize = 0) {
return {internalEnd(data, totalSize, true), hasResponded()};
}
@ -257,7 +413,7 @@ public:
}
/* Get the current byte write offset for this Http response */
int getWriteOffset() {
size_t getWriteOffset() {
HttpResponseData<SSL> *httpResponseData = getHttpResponseData();
return httpResponseData->offset;
@ -271,7 +427,7 @@ public:
}
/* Corks the response if possible. Leaves already corked socket be. */
HttpResponse *cork(fu2::unique_function<void()> &&handler) {
HttpResponse *cork(MoveOnlyFunction<void()> &&handler) {
if (!Super::isCorked() && Super::canCork()) {
Super::cork();
handler();
@ -292,7 +448,7 @@ public:
}
/* Attach handler for writable HTTP response */
HttpResponse *onWritable(fu2::unique_function<bool(int)> &&handler) {
HttpResponse *onWritable(MoveOnlyFunction<bool(size_t)> &&handler) {
HttpResponseData<SSL> *httpResponseData = getHttpResponseData();
httpResponseData->onWritable = std::move(handler);
@ -300,7 +456,7 @@ public:
}
/* Attach handler for aborted HTTP request */
HttpResponse *onAborted(fu2::unique_function<void()> &&handler) {
HttpResponse *onAborted(MoveOnlyFunction<void()> &&handler) {
HttpResponseData<SSL> *httpResponseData = getHttpResponseData();
httpResponseData->onAborted = std::move(handler);
@ -308,7 +464,7 @@ public:
}
/* Attach a read handler for data sent. Will be called with FIN set true if last segment. */
void onData(fu2::unique_function<void(std::string_view, bool)> &&handler) {
void onData(MoveOnlyFunction<void(std::string_view, bool)> &&handler) {
HttpResponseData<SSL> *data = getHttpResponseData();
data->inStream = std::move(handler);
}

View File

@ -1,5 +1,5 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Authored by Alex Hultman, 2018-2020.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
@ -22,8 +22,9 @@
#include "HttpParser.h"
#include "AsyncSocketData.h"
#include "ProxyParser.h"
#include "f2/function2.hpp"
#include "MoveOnlyFunction.h"
namespace uWS {
@ -38,18 +39,22 @@ private:
HTTP_WRITE_CALLED = 2, // used
HTTP_END_CALLED = 4, // used
HTTP_RESPONSE_PENDING = 8, // used
HTTP_ENDED_STREAM_OUT = 16 // not used
HTTP_CONNECTION_CLOSE = 16 // used
};
/* Per socket event handlers */
fu2::unique_function<bool(int)> onWritable;
fu2::unique_function<void()> onAborted;
fu2::unique_function<void(std::string_view, bool)> inStream; // onData
MoveOnlyFunction<bool(size_t)> onWritable;
MoveOnlyFunction<void()> onAborted;
MoveOnlyFunction<void(std::string_view, bool)> inStream; // onData
/* Outgoing offset */
int offset = 0;
size_t offset = 0;
/* Current state (content-length sent, status sent, write called, etc */
int state = 0;
#ifdef UWS_WITH_PROXY
ProxyParser proxyParser;
#endif
};
}

View File

@ -1,5 +1,5 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Authored by Alex Hultman, 2018-2020.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
@ -25,8 +25,9 @@
#include <string>
#include <algorithm>
#include <memory>
#include <utility>
#include "f2/function2.hpp"
#include "MoveOnlyFunction.h"
namespace uWS {
@ -47,7 +48,7 @@ private:
std::map<std::string, int> priority;
/* List of handlers */
std::vector<fu2::unique_function<bool(HttpRouter *)>> handlers;
std::vector<MoveOnlyFunction<bool(HttpRouter *)>> handlers;
/* Current URL cache */
std::string_view currentUrl;
@ -60,6 +61,8 @@ private:
std::vector<std::unique_ptr<Node>> children;
std::vector<uint32_t> handlers;
bool isHighPriority;
Node(std::string name) : name(name) {}
} root = {"rootNode"};
/* Advance from parent to child, adding child if necessary */
@ -71,7 +74,7 @@ private:
}
/* Insert sorted, but keep order if parent is root (we sort methods by priority elsewhere) */
std::unique_ptr<Node> newNode(new Node({child}));
std::unique_ptr<Node> newNode(new Node(child));
newNode->isHighPriority = isHighPriority;
return parent->children.emplace(std::upper_bound(parent->children.begin(), parent->children.end(), newNode, [parent, this](auto &a, auto &b) {
@ -107,19 +110,25 @@ private:
/* Set URL for router. Will reset any URL cache */
inline void setUrl(std::string_view url) {
/* Remove / from input URL */
currentUrl = url.substr(std::min<unsigned int>((unsigned int) url.length(), 1));
/* Todo: URL may also start with "http://domain/" or "*", not only "/" */
/* We expect to stand on a slash */
currentUrl = url;
urlSegmentTop = -1;
}
/* Lazily parse or read from cache */
inline std::string_view getUrlSegment(int urlSegment) {
inline std::pair<std::string_view, bool> getUrlSegment(int urlSegment) {
if (urlSegment > urlSegmentTop) {
/* Return empty segment if we are out of URL or stack space, but never for first url segment */
/* Signal as STOP when we have no more URL or stack space */
if (!currentUrl.length() || urlSegment > 99) {
return {};
return {{}, true};
}
/* We always stand on a slash here, so step over it */
currentUrl.remove_prefix(1);
auto segmentLength = currentUrl.find('/');
if (segmentLength == std::string::npos) {
segmentLength = currentUrl.length();
@ -136,19 +145,22 @@ private:
urlSegmentTop++;
/* Update currentUrl */
currentUrl = currentUrl.substr(segmentLength + 1);
currentUrl = currentUrl.substr(segmentLength);
}
}
/* In any case we return it */
return urlSegmentVector[urlSegment];
return {urlSegmentVector[urlSegment], false};
}
/* Executes as many handlers it can */
bool executeHandlers(Node *parent, int urlSegment, USERDATA &userData) {
/* If we have no more URL and not on first round, return where we may stand */
if (urlSegment && !getUrlSegment(urlSegment).length()) {
auto [segment, isStop] = getUrlSegment(urlSegment);
/* If we are on STOP, return where we may stand */
if (isStop) {
/* We have reached accross the entire URL with no stoppage, execute */
for (int handler : parent->handlers) {
for (uint32_t handler : parent->handlers) {
if (handlers[handler & HANDLER_MASK](this)) {
return true;
}
@ -160,19 +172,19 @@ private:
for (auto &p : parent->children) {
if (p->name.length() && p->name[0] == '*') {
/* Wildcard match (can be seen as a shortcut) */
for (int handler : p->handlers) {
for (uint32_t handler : p->handlers) {
if (handlers[handler & HANDLER_MASK](this)) {
return true;
}
}
} else if (p->name.length() && p->name[0] == ':' && getUrlSegment(urlSegment).length()) {
} else if (p->name.length() && p->name[0] == ':' && segment.length()) {
/* Parameter match */
routeParameters.push(getUrlSegment(urlSegment));
routeParameters.push(segment);
if (executeHandlers(p.get(), urlSegment + 1, userData)) {
return true;
}
routeParameters.pop();
} else if (p->name == getUrlSegment(urlSegment)) {
} else if (p->name == segment) {
/* Static match */
if (executeHandlers(p.get(), urlSegment + 1, userData)) {
return true;
@ -217,14 +229,14 @@ public:
}
/* Adds the corresponding entires in matching tree and handler list */
void add(std::vector<std::string> methods, std::string pattern, fu2::unique_function<bool(HttpRouter *)> &&handler, uint32_t priority = MEDIUM_PRIORITY) {
void add(std::vector<std::string> methods, std::string pattern, MoveOnlyFunction<bool(HttpRouter *)> &&handler, uint32_t priority = MEDIUM_PRIORITY) {
for (std::string method : methods) {
/* Lookup method */
Node *node = getNode(&root, method, false);
/* Iterate over all segments */
setUrl(pattern);
for (int i = 0; getUrlSegment(i).length() || i == 0; i++) {
node = getNode(node, std::string(getUrlSegment(i)), priority == HIGH_PRIORITY);
for (int i = 0; !getUrlSegment(i).second; i++) {
node = getNode(node, std::string(getUrlSegment(i).first), priority == HIGH_PRIORITY);
}
/* Insert handler in order sorted by priority (most significant 1 byte) */
node->handlers.insert(std::upper_bound(node->handlers.begin(), node->handlers.end(), (uint32_t) (priority | handlers.size())), (uint32_t) (priority | handlers.size()));
@ -237,4 +249,4 @@ public:
}
#endif // UWS_HTTPROUTER_HPP
#endif // UWS_HTTPROUTER_HPP

View File

@ -1,5 +1,5 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Authored by Alex Hultman, 2018-2020.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
@ -18,10 +18,10 @@
#ifndef UWS_LOOP_H
#define UWS_LOOP_H
/* The loop is lazily created per-thread and run with uWS::run() */
/* The loop is lazily created per-thread and run with run() */
#include "LoopData.h"
#include "libusockets.h"
#include <libusockets.h>
namespace uWS {
struct Loop {
@ -80,25 +80,29 @@ private:
Loop *loop = nullptr;
bool cleanMe = false;
};
static LoopCleaner &getLazyLoop() {
static thread_local LoopCleaner lazyLoop;
return lazyLoop;
}
public:
/* Lazily initializes a per-thread loop and returns it.
* Will automatically free all initialized loops at exit. */
static Loop *get(void *existingNativeLoop = nullptr) {
static thread_local LoopCleaner lazyLoop;
if (!lazyLoop.loop) {
if (!getLazyLoop().loop) {
/* If we are given a native loop pointer we pass that to uSockets and let it deal with it */
if (existingNativeLoop) {
/* Todo: here we want to pass the pointer, not a boolean */
lazyLoop.loop = create(existingNativeLoop);
getLazyLoop().loop = create(existingNativeLoop);
/* We cannot register automatic free here, must be manually done */
} else {
lazyLoop.loop = create(nullptr);
lazyLoop.cleanMe = true;
getLazyLoop().loop = create(nullptr);
getLazyLoop().cleanMe = true;
}
}
return lazyLoop.loop;
return getLazyLoop().loop;
}
/* Freeing the default loop should be done once */
@ -107,9 +111,12 @@ public:
loopData->~LoopData();
/* uSockets will track whether this loop is owned by us or a borrowed alien loop */
us_loop_free((us_loop_t *) this);
/* Reset lazyLoop */
getLazyLoop().loop = nullptr;
}
void addPostHandler(void *key, fu2::unique_function<void(Loop *)> &&handler) {
void addPostHandler(void *key, MoveOnlyFunction<void(Loop *)> &&handler) {
LoopData *loopData = (LoopData *) us_loop_ext((us_loop_t *) this);
loopData->postHandlers.emplace(key, std::move(handler));
@ -122,7 +129,7 @@ public:
loopData->postHandlers.erase(key);
}
void addPreHandler(void *key, fu2::unique_function<void(Loop *)> &&handler) {
void addPreHandler(void *key, MoveOnlyFunction<void(Loop *)> &&handler) {
LoopData *loopData = (LoopData *) us_loop_ext((us_loop_t *) this);
loopData->preHandlers.emplace(key, std::move(handler));
@ -136,7 +143,7 @@ public:
}
/* Defer this callback on Loop's thread of execution */
void defer(fu2::unique_function<void()> &&cb) {
void defer(MoveOnlyFunction<void()> &&cb) {
LoopData *loopData = (LoopData *) us_loop_ext((us_loop_t *) this);
//if (std::thread::get_id() == ) // todo: add fast path for same thread id
@ -157,6 +164,11 @@ public:
void integrate() {
us_loop_integrate((us_loop_t *) this);
}
/* Dynamically change this */
void setSilent(bool silent) {
((LoopData *) us_loop_ext((us_loop_t *) this))->noMark = silent;
}
};
/* Can be called from any thread to run the thread local loop */

View File

@ -1,5 +1,5 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Authored by Alex Hultman, 2018-2020.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
@ -26,7 +26,7 @@
#include "PerMessageDeflate.h"
#include "f2/function2.hpp"
#include "MoveOnlyFunction.h"
namespace uWS {
@ -37,10 +37,10 @@ struct alignas(16) LoopData {
private:
std::mutex deferMutex;
int currentDeferQueue = 0;
std::vector<fu2::unique_function<void()>> deferQueues[2];
std::vector<MoveOnlyFunction<void()>> deferQueues[2];
/* Map from void ptr to handler */
std::map<void *, fu2::unique_function<void(Loop *)>> postHandlers, preHandlers;
std::map<void *, MoveOnlyFunction<void(Loop *)>> postHandlers, preHandlers;
public:
~LoopData() {
@ -53,12 +53,15 @@ public:
delete [] corkBuffer;
}
/* Be silent */
bool noMark = false;
/* Good 16k for SSL perf. */
static const int CORK_BUFFER_SIZE = 16 * 1024;
static const unsigned int CORK_BUFFER_SIZE = 16 * 1024;
/* Cork data */
char *corkBuffer = new char[CORK_BUFFER_SIZE];
int corkOffset = 0;
unsigned int corkOffset = 0;
void *corkedSocket = nullptr;
/* Per message deflate data */

View File

@ -0,0 +1,64 @@
/*
* Authored by Alex Hultman, 2018-2020.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Implements the common parser (RFC 822) used in both HTTP and Multipart parsing */
#ifndef UWS_MESSAGE_PARSER_H
#define UWS_MESSAGE_PARSER_H
#include <string_view>
#include <utility>
#include <cstring>
/* For now we have this one here */
#define MAX_HEADERS 10
namespace uWS {
// should be templated on whether it needs at lest one header (http), or not (multipart)
static inline unsigned int getHeaders(char *postPaddedBuffer, char *end, std::pair<std::string_view, std::string_view> *headers) {
char *preliminaryKey, *preliminaryValue, *start = postPaddedBuffer;
for (unsigned int i = 0; i < MAX_HEADERS; i++) {
for (preliminaryKey = postPaddedBuffer; (*postPaddedBuffer != ':') & (*postPaddedBuffer > 32); *(postPaddedBuffer++) |= 32);
if (*postPaddedBuffer == '\r') {
if ((postPaddedBuffer != end) & (postPaddedBuffer[1] == '\n') /* & (i > 0) */) { // multipart does not require any headers like http does
headers->first = std::string_view(nullptr, 0);
return (unsigned int) ((postPaddedBuffer + 2) - start);
} else {
return 0;
}
} else {
headers->first = std::string_view(preliminaryKey, (size_t) (postPaddedBuffer - preliminaryKey));
for (postPaddedBuffer++; (*postPaddedBuffer == ':' || *postPaddedBuffer < 33) && *postPaddedBuffer != '\r'; postPaddedBuffer++);
preliminaryValue = postPaddedBuffer;
postPaddedBuffer = (char *) memchr(postPaddedBuffer, '\r', end - postPaddedBuffer);
if (postPaddedBuffer && postPaddedBuffer[1] == '\n') {
headers->second = std::string_view(preliminaryValue, (size_t) (postPaddedBuffer - preliminaryValue));
postPaddedBuffer += 2;
headers++;
} else {
return 0;
}
}
}
return 0;
}
}
#endif

View File

@ -0,0 +1,377 @@
/*
MIT License
Copyright (c) 2020 Oleg Fatkhiev
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/
/* Sources fetched from https://github.com/ofats/any_invocable on 2021-02-19. */
#ifndef _ANY_INVOKABLE_H_
#define _ANY_INVOKABLE_H_
#include <functional>
#include <memory>
#include <type_traits>
// clang-format off
/*
namespace std {
template<class Sig> class any_invocable; // never defined
template<class R, class... ArgTypes>
class any_invocable<R(ArgTypes...) cv ref noexcept(noex)> {
public:
using result_type = R;
// SECTION.3, construct/copy/destroy
any_invocable() noexcept;
any_invocable(nullptr_t) noexcept;
any_invocable(any_invocable&&) noexcept;
template<class F> any_invocable(F&&);
template<class T, class... Args>
explicit any_invocable(in_place_type_t<T>, Args&&...);
template<class T, class U, class... Args>
explicit any_invocable(in_place_type_t<T>, initializer_list<U>, Args&&...);
any_invocable& operator=(any_invocable&&) noexcept;
any_invocable& operator=(nullptr_t) noexcept;
template<class F> any_invocable& operator=(F&&);
template<class F> any_invocable& operator=(reference_wrapper<F>) noexcept;
~any_invocable();
// SECTION.4, any_invocable modifiers
void swap(any_invocable&) noexcept;
// SECTION.5, any_invocable capacity
explicit operator bool() const noexcept;
// SECTION.6, any_invocable invocation
R operator()(ArgTypes...) cv ref noexcept(noex);
// SECTION.7, null pointer comparisons
friend bool operator==(const any_invocable&, nullptr_t) noexcept;
// SECTION.8, specialized algorithms
friend void swap(any_invocable&, any_invocable&) noexcept;
};
}
*/
// clang-format on
namespace ofats {
namespace any_detail {
using buffer = std::aligned_storage_t<sizeof(void*) * 2, alignof(void*)>;
template <class T>
inline constexpr bool is_small_object_v =
sizeof(T) <= sizeof(buffer) && alignof(buffer) % alignof(T) == 0 &&
std::is_nothrow_move_constructible_v<T>;
union storage {
void* ptr_ = nullptr;
buffer buf_;
};
enum class action { destroy, move };
template <class R, class... ArgTypes>
struct handler_traits {
template <class Derived>
struct handler_base {
static void handle(action act, storage* current, storage* other = nullptr) {
switch (act) {
case (action::destroy):
Derived::destroy(*current);
break;
case (action::move):
Derived::move(*current, *other);
break;
}
}
};
template <class T>
struct small_handler : handler_base<small_handler<T>> {
template <class... Args>
static void create(storage& s, Args&&... args) {
new (static_cast<void*>(&s.buf_)) T(std::forward<Args>(args)...);
}
static void destroy(storage& s) noexcept {
T& value = *static_cast<T*>(static_cast<void*>(&s.buf_));
value.~T();
}
static void move(storage& dst, storage& src) noexcept {
create(dst, std::move(*static_cast<T*>(static_cast<void*>(&src.buf_))));
destroy(src);
}
static R call(storage& s, ArgTypes... args) {
return std::invoke(*static_cast<T*>(static_cast<void*>(&s.buf_)),
std::forward<ArgTypes>(args)...);
}
};
template <class T>
struct large_handler : handler_base<large_handler<T>> {
template <class... Args>
static void create(storage& s, Args&&... args) {
s.ptr_ = new T(std::forward<Args>(args)...);
}
static void destroy(storage& s) noexcept { delete static_cast<T*>(s.ptr_); }
static void move(storage& dst, storage& src) noexcept {
dst.ptr_ = src.ptr_;
}
static R call(storage& s, ArgTypes... args) {
return std::invoke(*static_cast<T*>(s.ptr_),
std::forward<ArgTypes>(args)...);
}
};
template <class T>
using handler = std::conditional_t<is_small_object_v<T>, small_handler<T>,
large_handler<T>>;
};
template <class T>
struct is_in_place_type : std::false_type {};
template <class T>
struct is_in_place_type<std::in_place_type_t<T>> : std::true_type {};
template <class T>
inline constexpr auto is_in_place_type_v = is_in_place_type<T>::value;
template <class R, bool is_noexcept, class... ArgTypes>
class any_invocable_impl {
template <class T>
using handler =
typename any_detail::handler_traits<R, ArgTypes...>::template handler<T>;
using storage = any_detail::storage;
using action = any_detail::action;
using handle_func = void (*)(any_detail::action, any_detail::storage*,
any_detail::storage*);
using call_func = R (*)(any_detail::storage&, ArgTypes...);
public:
using result_type = R;
any_invocable_impl() noexcept = default;
any_invocable_impl(std::nullptr_t) noexcept {}
any_invocable_impl(any_invocable_impl&& rhs) noexcept {
if (rhs.handle_) {
handle_ = rhs.handle_;
handle_(action::move, &storage_, &rhs.storage_);
call_ = rhs.call_;
rhs.handle_ = nullptr;
}
}
any_invocable_impl& operator=(any_invocable_impl&& rhs) noexcept {
any_invocable_impl{std::move(rhs)}.swap(*this);
return *this;
}
any_invocable_impl& operator=(std::nullptr_t) noexcept {
destroy();
return *this;
}
~any_invocable_impl() { destroy(); }
void swap(any_invocable_impl& rhs) noexcept {
if (handle_) {
if (rhs.handle_) {
storage tmp;
handle_(action::move, &tmp, &storage_);
rhs.handle_(action::move, &storage_, &rhs.storage_);
handle_(action::move, &rhs.storage_, &tmp);
std::swap(handle_, rhs.handle_);
std::swap(call_, rhs.call_);
} else {
rhs.swap(*this);
}
} else if (rhs.handle_) {
rhs.handle_(action::move, &storage_, &rhs.storage_);
handle_ = rhs.handle_;
call_ = rhs.call_;
rhs.handle_ = nullptr;
}
}
explicit operator bool() const noexcept { return handle_ != nullptr; }
protected:
template <class F, class... Args>
void create(Args&&... args) {
using hdl = handler<F>;
hdl::create(storage_, std::forward<Args>(args)...);
handle_ = &hdl::handle;
call_ = &hdl::call;
}
void destroy() noexcept {
if (handle_) {
handle_(action::destroy, &storage_, nullptr);
handle_ = nullptr;
}
}
R call(ArgTypes... args) noexcept(is_noexcept) {
return call_(storage_, std::forward<ArgTypes>(args)...);
}
friend bool operator==(const any_invocable_impl& f, std::nullptr_t) noexcept {
return !f;
}
friend bool operator==(std::nullptr_t, const any_invocable_impl& f) noexcept {
return !f;
}
friend bool operator!=(const any_invocable_impl& f, std::nullptr_t) noexcept {
return static_cast<bool>(f);
}
friend bool operator!=(std::nullptr_t, const any_invocable_impl& f) noexcept {
return static_cast<bool>(f);
}
friend void swap(any_invocable_impl& lhs, any_invocable_impl& rhs) noexcept {
lhs.swap(rhs);
}
private:
storage storage_;
handle_func handle_ = nullptr;
call_func call_;
};
template <class T>
using remove_cvref_t = std::remove_cv_t<std::remove_reference_t<T>>;
template <class AI, class F, bool noex, class R, class FCall, class... ArgTypes>
using can_convert = std::conjunction<
std::negation<std::is_same<remove_cvref_t<F>, AI>>,
std::negation<any_detail::is_in_place_type<remove_cvref_t<F>>>,
std::is_invocable_r<R, FCall, ArgTypes...>,
std::bool_constant<(!noex ||
std::is_nothrow_invocable_r_v<R, FCall, ArgTypes...>)>,
std::is_constructible<std::decay_t<F>, F>>;
} // namespace any_detail
template <class Signature>
class any_invocable;
#define __OFATS_ANY_INVOCABLE(cv, ref, noex, inv_quals) \
template <class R, class... ArgTypes> \
class any_invocable<R(ArgTypes...) cv ref noexcept(noex)> \
: public any_detail::any_invocable_impl<R, noex, ArgTypes...> { \
using base_type = any_detail::any_invocable_impl<R, noex, ArgTypes...>; \
\
public: \
using base_type::base_type; \
\
template < \
class F, \
class = std::enable_if_t<any_detail::can_convert< \
any_invocable, F, noex, R, F inv_quals, ArgTypes...>::value>> \
any_invocable(F&& f) { \
base_type::template create<std::decay_t<F>>(std::forward<F>(f)); \
} \
\
template <class T, class... Args, class VT = std::decay_t<T>, \
class = std::enable_if_t< \
std::is_move_constructible_v<VT> && \
std::is_constructible_v<VT, Args...> && \
std::is_invocable_r_v<R, VT inv_quals, ArgTypes...> && \
(!noex || std::is_nothrow_invocable_r_v<R, VT inv_quals, \
ArgTypes...>)>> \
explicit any_invocable(std::in_place_type_t<T>, Args&&... args) { \
base_type::template create<VT>(std::forward<Args>(args)...); \
} \
\
template < \
class T, class U, class... Args, class VT = std::decay_t<T>, \
class = std::enable_if_t< \
std::is_move_constructible_v<VT> && \
std::is_constructible_v<VT, std::initializer_list<U>&, Args...> && \
std::is_invocable_r_v<R, VT inv_quals, ArgTypes...> && \
(!noex || \
std::is_nothrow_invocable_r_v<R, VT inv_quals, ArgTypes...>)>> \
explicit any_invocable(std::in_place_type_t<T>, \
std::initializer_list<U> il, Args&&... args) { \
base_type::template create<VT>(il, std::forward<Args>(args)...); \
} \
\
template <class F, class FDec = std::decay_t<F>> \
std::enable_if_t<!std::is_same_v<FDec, any_invocable> && \
std::is_move_constructible_v<FDec>, \
any_invocable&> \
operator=(F&& f) { \
any_invocable{std::forward<F>(f)}.swap(*this); \
return *this; \
} \
template <class F> \
any_invocable& operator=(std::reference_wrapper<F> f) { \
any_invocable{f}.swap(*this); \
return *this; \
} \
\
R operator()(ArgTypes... args) cv ref noexcept(noex) { \
return base_type::call(std::forward<ArgTypes>(args)...); \
} \
};
// cv -> {`empty`, const}
// ref -> {`empty`, &, &&}
// noex -> {true, false}
// inv_quals -> (is_empty(ref) ? & : ref)
__OFATS_ANY_INVOCABLE(, , false, &) // 000
__OFATS_ANY_INVOCABLE(, , true, &) // 001
__OFATS_ANY_INVOCABLE(, &, false, &) // 010
__OFATS_ANY_INVOCABLE(, &, true, &) // 011
__OFATS_ANY_INVOCABLE(, &&, false, &&) // 020
__OFATS_ANY_INVOCABLE(, &&, true, &&) // 021
__OFATS_ANY_INVOCABLE(const, , false, const&) // 100
__OFATS_ANY_INVOCABLE(const, , true, const&) // 101
__OFATS_ANY_INVOCABLE(const, &, false, const&) // 110
__OFATS_ANY_INVOCABLE(const, &, true, const&) // 111
__OFATS_ANY_INVOCABLE(const, &&, false, const&&) // 120
__OFATS_ANY_INVOCABLE(const, &&, true, const&&) // 121
#undef __OFATS_ANY_INVOCABLE
} // namespace ofats
/* We, uWebSockets define our own type */
namespace uWS {
template <class T>
using MoveOnlyFunction = ofats::any_invocable<T>;
}
#endif // _ANY_INVOKABLE_H_

View File

@ -0,0 +1,231 @@
/*
* Authored by Alex Hultman, 2018-2020.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Implements the multipart protocol. Builds atop parts of our common http parser (not yet refactored that way). */
/* https://www.w3.org/Protocols/rfc1341/7_2_Multipart.html */
#ifndef UWS_MULTIPART_H
#define UWS_MULTIPART_H
#include "MessageParser.h"
#include <string_view>
#include <optional>
#include <cstring>
#include <utility>
#include <cctype>
namespace uWS {
/* This one could possibly be shared with ExtensionsParser to some degree */
struct ParameterParser {
/* Takes the line, commonly given as content-disposition header in the multipart */
ParameterParser(std::string_view line) {
remainingLine = line;
}
/* Returns next key/value where value can simply be empty.
* If key (first) is empty then we are at the end */
std::pair<std::string_view, std::string_view> getKeyValue() {
auto key = getToken();
auto op = getToken();
if (!op.length()) {
return {key, ""};
}
if (op[0] != ';') {
auto value = getToken();
/* Strip ; or if at end, nothing */
getToken();
return {key, value};
}
return {key, ""};
}
private:
std::string_view remainingLine;
/* Consumes a token from the line. Will "unquote" strings */
std::string_view getToken() {
/* Strip whitespace */
while (remainingLine.length() && isspace(remainingLine[0])) {
remainingLine.remove_prefix(1);
}
if (!remainingLine.length()) {
/* All we had was space */
return {};
} else {
/* Are we at an operator? */
if (remainingLine[0] == ';' || remainingLine[0] == '=') {
auto op = remainingLine.substr(0, 1);
remainingLine.remove_prefix(1);
return op;
} else {
/* Are we at a quoted string? */
if (remainingLine[0] == '\"') {
/* Remove first quote and start counting */
remainingLine.remove_prefix(1);
auto quote = remainingLine;
int quoteLength = 0;
/* Read anything until other double quote appears */
while (remainingLine.length() && remainingLine[0] != '\"') {
remainingLine.remove_prefix(1);
quoteLength++;
}
/* We can't remove_prefix if we have nothing to remove */
if (!remainingLine.length()) {
return {};
}
remainingLine.remove_prefix(1);
return quote.substr(0, quoteLength);
} else {
/* Read anything until ; = space or end */
std::string_view token = remainingLine;
int tokenLength = 0;
while (remainingLine.length() && remainingLine[0] != ';' && remainingLine[0] != '=' && !isspace(remainingLine[0])) {
remainingLine.remove_prefix(1);
tokenLength++;
}
return token.substr(0, tokenLength);
}
}
}
/* Nothing */
return "";
}
};
struct MultipartParser {
/* 2 chars of hyphen + 1 - 70 chars of boundary */
char prependedBoundaryBuffer[72];
std::string_view prependedBoundary;
std::string_view remainingBody;
bool first = true;
/* I think it is more than sane to limit this to 10 per part */
//static const int MAX_HEADERS = 10;
/* Construct the parser based on contentType (reads boundary) */
MultipartParser(std::string_view contentType) {
/* We expect the form "multipart/something;somethingboundary=something" */
if (contentType.length() < 10 || contentType.substr(0, 10) != "multipart/") {
return;
}
/* For now we simply guess boundary will lie between = and end. This is not entirely
* standards compliant as boundary may be expressed with or without " and spaces */
auto equalToken = contentType.find('=', 10);
if (equalToken != std::string_view::npos) {
/* Boundary must be less than or equal to 70 chars yet 1 char or longer */
std::string_view boundary = contentType.substr(equalToken + 1);
if (!boundary.length() || boundary.length() > 70) {
/* Invalid size */
return;
}
/* Prepend it with two hyphens */
prependedBoundaryBuffer[0] = prependedBoundaryBuffer[1] = '-';
memcpy(&prependedBoundaryBuffer[2], boundary.data(), boundary.length());
prependedBoundary = {prependedBoundaryBuffer, boundary.length() + 2};
}
}
/* Is this even a valid multipart request? */
bool isValid() {
return prependedBoundary.length() != 0;
}
/* Set the body once, before getting any parts */
void setBody(std::string_view body) {
remainingBody = body;
}
/* Parse out the next part's data, filling the headers. Returns nullopt on end or error. */
std::optional<std::string_view> getNextPart(std::pair<std::string_view, std::string_view> *headers) {
/* The remaining two hyphens should be shorter than the boundary */
if (remainingBody.length() < prependedBoundary.length()) {
/* We are done now */
return std::nullopt;
}
if (first) {
auto nextBoundary = remainingBody.find(prependedBoundary);
if (nextBoundary == std::string_view::npos) {
/* Cannot parse */
return std::nullopt;
}
/* Toss away boundary and anything before it */
remainingBody.remove_prefix(nextBoundary + prependedBoundary.length());
first = false;
}
auto nextEndBoundary = remainingBody.find(prependedBoundary);
if (nextEndBoundary == std::string_view::npos) {
/* Cannot parse (or simply done) */
return std::nullopt;
}
std::string_view part = remainingBody.substr(0, nextEndBoundary);
remainingBody.remove_prefix(nextEndBoundary + prependedBoundary.length());
/* Also strip rn before and rn after the part */
if (part.length() < 4) {
/* Cannot strip */
return std::nullopt;
}
part.remove_prefix(2);
part.remove_suffix(2);
/* We are allowed to post pad like this because we know the boundary is at least 2 bytes */
/* This makes parsing a second pass invalid, so you can only iterate over parts once */
memset((char *) part.data() + part.length(), '\r', 1);
/* For this to be a valid part, we need to consume at least 4 bytes (\r\n\r\n) */
int consumed = getHeaders((char *) part.data(), (char *) part.data() + part.length(), headers);
if (!consumed) {
/* This is an invalid part */
return std::nullopt;
}
/* Strip away the headers from the part body data */
part.remove_prefix(consumed);
/* Now pass whatever is remaining of the part */
return part;
}
};
}
#endif

View File

@ -1,5 +1,5 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Authored by Alex Hultman, 2018-2021.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
@ -20,26 +20,53 @@
#ifndef UWS_PERMESSAGEDEFLATE_H
#define UWS_PERMESSAGEDEFLATE_H
#ifndef UWS_NO_ZLIB
/* We always define these options no matter if ZLIB is enabled or not */
namespace uWS {
/* Compressor mode is HIGH8(windowBits), LOW8(memLevel) */
enum CompressOptions : uint32_t {
DISABLED = 0,
SHARED_COMPRESSOR = 1,
DEDICATED_COMPRESSOR_3KB = 9 << 8 | 1,
DEDICATED_COMPRESSOR_4KB = 9 << 8 | 2,
DEDICATED_COMPRESSOR_8KB = 10 << 8 | 3,
DEDICATED_COMPRESSOR_16KB = 11 << 8 | 4,
DEDICATED_COMPRESSOR_32KB = 12 << 8 | 5,
DEDICATED_COMPRESSOR_64KB = 13 << 8 | 6,
DEDICATED_COMPRESSOR_128KB = 14 << 8 | 7,
DEDICATED_COMPRESSOR_256KB = 15 << 8 | 8,
/* Same as 256kb */
DEDICATED_COMPRESSOR = 15 << 8 | 8
};
}
#if !defined(UWS_NO_ZLIB) && !defined(UWS_MOCK_ZLIB)
#include <zlib.h>
#endif
#include <string>
#include <optional>
#ifdef UWS_USE_LIBDEFLATE
#include "libdeflate.h"
#include <cstring>
#endif
namespace uWS {
/* Do not compile this module if we don't want it */
#ifdef UWS_NO_ZLIB
#if defined(UWS_NO_ZLIB) || defined(UWS_MOCK_ZLIB)
struct ZlibContext {};
struct InflationStream {
std::string_view inflate(ZlibContext *zlibContext, std::string_view compressed, size_t maxPayloadLength) {
return compressed;
std::optional<std::string_view> inflate(ZlibContext *zlibContext, std::string_view compressed, size_t maxPayloadLength) {
return compressed.substr(0, std::min(maxPayloadLength, compressed.length()));
}
};
struct DeflationStream {
std::string_view deflate(ZlibContext *zlibContext, std::string_view raw, bool reset) {
return raw;
}
DeflationStream(int compressOptions) {
}
};
#else
@ -54,26 +81,65 @@ struct ZlibContext {
char *deflationBuffer;
char *inflationBuffer;
#ifdef UWS_USE_LIBDEFLATE
libdeflate_decompressor *decompressor;
libdeflate_compressor *compressor;
#endif
ZlibContext() {
deflationBuffer = (char *) malloc(LARGE_BUFFER_SIZE);
inflationBuffer = (char *) malloc(LARGE_BUFFER_SIZE);
#ifdef UWS_USE_LIBDEFLATE
decompressor = libdeflate_alloc_decompressor();
compressor = libdeflate_alloc_compressor(6);
#endif
}
~ZlibContext() {
free(deflationBuffer);
free(inflationBuffer);
#ifdef UWS_USE_LIBDEFLATE
libdeflate_free_decompressor(decompressor);
libdeflate_free_compressor(compressor);
#endif
}
};
struct DeflationStream {
z_stream deflationStream = {};
DeflationStream() {
deflateInit2(&deflationStream, 1, Z_DEFLATED, -15, 8, Z_DEFAULT_STRATEGY);
DeflationStream(CompressOptions compressOptions) {
/* Sliding inflator should be about 44kb by default, less than compressor */
/* Memory usage is given by 2 ^ (windowBits + 2) + 2 ^ (memLevel + 9) */
int windowBits = -(int) ((compressOptions & 0xFF00) >> 8), memLevel = compressOptions & 0x00FF;
//printf("windowBits: %d, memLevel: %d\n", windowBits, memLevel);
deflateInit2(&deflationStream, Z_DEFAULT_COMPRESSION, Z_DEFLATED, windowBits, memLevel, Z_DEFAULT_STRATEGY);
}
/* Deflate and optionally reset */
/* Deflate and optionally reset. You must not deflate an empty string. */
std::string_view deflate(ZlibContext *zlibContext, std::string_view raw, bool reset) {
#ifdef UWS_USE_LIBDEFLATE
/* Run a fast path in case of shared_compressor */
if (reset) {
size_t written = 0;
static unsigned char buf[1024 + 1];
written = libdeflate_deflate_compress(zlibContext->compressor, raw.data(), raw.length(), buf, 1024);
if (written) {
memcpy(&buf[written], "\x00", 1);
return std::string_view((char *) buf, written + 1);
}
}
#endif
/* Odd place to clear this one, fix */
zlibContext->dynamicDeflationBuffer.clear();
@ -105,9 +171,11 @@ struct DeflationStream {
if (zlibContext->dynamicDeflationBuffer.length()) {
zlibContext->dynamicDeflationBuffer.append(zlibContext->deflationBuffer, DEFLATE_OUTPUT_CHUNK - deflationStream.avail_out);
return {(char *) zlibContext->dynamicDeflationBuffer.data(), zlibContext->dynamicDeflationBuffer.length() - 4};
return std::string_view((char *) zlibContext->dynamicDeflationBuffer.data(), zlibContext->dynamicDeflationBuffer.length() - 4);
}
/* Note: We will get an interger overflow resulting in heap buffer overflow if Z_BUF_ERROR is returned
* from passing 0 as avail_in. Therefore we must not deflate an empty string */
return {
zlibContext->deflationBuffer,
DEFLATE_OUTPUT_CHUNK - deflationStream.avail_out - 4
@ -130,7 +198,26 @@ struct InflationStream {
inflateEnd(&inflationStream);
}
std::string_view inflate(ZlibContext *zlibContext, std::string_view compressed, size_t maxPayloadLength) {
/* Zero length inflates are possible and valid */
std::optional<std::string_view> inflate(ZlibContext *zlibContext, std::string_view compressed, size_t maxPayloadLength) {
#ifdef UWS_USE_LIBDEFLATE
/* Try fast path first */
size_t written = 0;
static char buf[1024];
/* We have to pad 9 bytes and restore those bytes when done since 9 is more than 6 of next WebSocket message */
char tmp[9];
memcpy(tmp, (char *) compressed.data() + compressed.length(), 9);
memcpy((char *) compressed.data() + compressed.length(), "\x00\x00\xff\xff\x01\x00\x00\xff\xff", 9);
libdeflate_result res = libdeflate_deflate_decompress(zlibContext->decompressor, compressed.data(), compressed.length() + 9, buf, 1024, &written);
memcpy((char *) compressed.data() + compressed.length(), tmp, 9);
if (res == 0) {
/* Fast path wins */
return std::string_view(buf, written);
}
#endif
/* We clear this one here, could be done better */
zlibContext->dynamicInflationBuffer.clear();
@ -156,7 +243,7 @@ struct InflationStream {
inflateReset(&inflationStream);
if ((err != Z_BUF_ERROR && err != Z_OK) || zlibContext->dynamicInflationBuffer.length() > maxPayloadLength) {
return {nullptr, 0};
return std::nullopt;
}
if (zlibContext->dynamicInflationBuffer.length()) {
@ -164,18 +251,18 @@ struct InflationStream {
/* Let's be strict about the max size */
if (zlibContext->dynamicInflationBuffer.length() > maxPayloadLength) {
return {nullptr, 0};
return std::nullopt;
}
return {zlibContext->dynamicInflationBuffer.data(), zlibContext->dynamicInflationBuffer.length()};
return std::string_view(zlibContext->dynamicInflationBuffer.data(), zlibContext->dynamicInflationBuffer.length());
}
/* Let's be strict about the max size */
if ((LARGE_BUFFER_SIZE - inflationStream.avail_out) > maxPayloadLength) {
return {nullptr, 0};
return std::nullopt;
}
return {zlibContext->inflationBuffer, LARGE_BUFFER_SIZE - inflationStream.avail_out};
return std::string_view(zlibContext->inflationBuffer, LARGE_BUFFER_SIZE - inflationStream.avail_out);
}
};

View File

@ -0,0 +1,163 @@
/*
* Authored by Alex Hultman, 2018-2020.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* This module implements The PROXY Protocol v2 */
#ifndef UWS_PROXY_PARSER_H
#define UWS_PROXY_PARSER_H
#ifdef UWS_WITH_PROXY
namespace uWS {
struct proxy_hdr_v2 {
uint8_t sig[12]; /* hex 0D 0A 0D 0A 00 0D 0A 51 55 49 54 0A */
uint8_t ver_cmd; /* protocol version and command */
uint8_t fam; /* protocol family and address */
uint16_t len; /* number of following bytes part of the header */
};
union proxy_addr {
struct { /* for TCP/UDP over IPv4, len = 12 */
uint32_t src_addr;
uint32_t dst_addr;
uint16_t src_port;
uint16_t dst_port;
} ipv4_addr;
struct { /* for TCP/UDP over IPv6, len = 36 */
uint8_t src_addr[16];
uint8_t dst_addr[16];
uint16_t src_port;
uint16_t dst_port;
} ipv6_addr;
};
/* Byte swap for little-endian systems */
/* Todo: This functions should be shared with the one in WebSocketProtocol.h! */
template <typename T>
T _cond_byte_swap(T value) {
uint32_t endian_test = 1;
if (*((char *)&endian_test)) {
union {
T i;
uint8_t b[sizeof(T)];
} src = { value }, dst;
for (unsigned int i = 0; i < sizeof(value); i++) {
dst.b[i] = src.b[sizeof(value) - 1 - i];
}
return dst.i;
}
return value;
}
struct ProxyParser {
private:
union proxy_addr addr;
/* Default family of 0 signals no proxy address */
uint8_t family = 0;
public:
/* Returns 4 or 16 bytes source address */
std::string_view getSourceAddress() {
// UNSPEC family and protocol
if (family == 0) {
return {};
}
if ((family & 0xf0) >> 4 == 1) {
/* Family 1 is INET4 */
return {(char *) &addr.ipv4_addr.src_addr, 4};
} else {
/* Family 2 is INET6 */
return {(char *) &addr.ipv6_addr.src_addr, 16};
}
}
/* Returns [done, consumed] where done = false on failure */
std::pair<bool, unsigned int> parse(std::string_view data) {
/* We require at least four bytes to determine protocol */
if (data.length() < 4) {
return {false, 0};
}
/* HTTP can never start with "\r\n\r\n", but PROXY always does */
if (memcmp(data.data(), "\r\n\r\n", 4)) {
/* This is HTTP, so be done */
return {true, 0};
}
/* We assume we are parsing PROXY V2 here */
/* We require 16 bytes here */
if (data.length() < 16) {
return {false, 0};
}
/* Header is 16 bytes */
struct proxy_hdr_v2 header;
memcpy(&header, data.data(), 16);
if (memcmp(header.sig, "\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A", 12)) {
/* This is not PROXY protocol at all */
return {false, 0};
}
/* We only support version 2 */
if ((header.ver_cmd & 0xf0) >> 4 != 2) {
return {false, 0};
}
//printf("Version: %d\n", (header.ver_cmd & 0xf0) >> 4);
//printf("Command: %d\n", (header.ver_cmd & 0x0f));
/* We get length in network byte order (todo: share this function with the rest) */
uint16_t hostLength = _cond_byte_swap<uint16_t>(header.len);
/* We must have all the data available */
if (data.length() < 16u + hostLength) {
return {false, 0};
}
/* Payload cannot be more than sizeof proxy_addr */
if (sizeof(proxy_addr) < hostLength) {
return {false, 0};
}
//printf("Family: %d\n", (header.fam & 0xf0) >> 4);
//printf("Transport: %d\n", (header.fam & 0x0f));
/* We have 0 family by default, and UNSPEC is 0 as well */
family = header.fam;
/* Copy payload */
memcpy(&addr, data.data() + 16, hostLength);
/* We consumed everything */
return {true, 16 + hostLength};
}
};
}
#endif
#endif // UWS_PROXY_PARSER_H

View File

@ -0,0 +1,120 @@
/*
* Authored by Alex Hultman, 2018-2020.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* This module implements URI query parsing and retrieval of value given key */
#ifndef UWS_QUERYPARSER_H
#define UWS_QUERYPARSER_H
#include <string_view>
namespace uWS {
/* Takes raw query including initial '?' sign. Will inplace decode, so input will mutate */
static inline std::string_view getDecodedQueryValue(std::string_view key, std::string_view rawQuery) {
/* Can't have a value without a key */
if (!key.length()) {
return {};
}
/* Start with the whole querystring including initial '?' */
std::string_view queryString = rawQuery;
/* List of key, value could be cached for repeated fetches similar to how headers are, todo! */
while (queryString.length()) {
/* Find boundaries of this statement */
std::string_view statement = queryString.substr(1, queryString.find('&', 1) - 1);
/* Only bother if first char of key match (early exit) */
if (statement.length() && statement[0] == key[0]) {
/* Equal sign must be present and not in the end of statement */
auto equality = statement.find('=');
if (equality != std::string_view::npos && equality != statement.length() - 1) {
std::string_view statementKey = statement.substr(0, equality);
std::string_view statementValue = statement.substr(equality + 1);
/* String comparison */
if (key == statementKey) {
/* Decode value inplace, put null at end if before length of original */
char *in = (char *) statementValue.data();
/* Write offset */
unsigned int out = 0;
/* Walk over all chars until end or null char, decoding in place */
for (unsigned int i = 0; i < statementValue.length() && in[i]; i++) {
/* Only bother with '%' */
if (in[i] == '%') {
/* Do we have enough data for two bytes hex? */
if (i + 2 >= statementValue.length()) {
return {};
}
/* Two bytes hex */
int hex1 = in[i + 1] - '0';
if (hex1 > 9) {
hex1 &= 223;
hex1 -= 7;
}
int hex2 = in[i + 2] - '0';
if (hex2 > 9) {
hex2 &= 223;
hex2 -= 7;
}
*((unsigned char *) &in[out]) = (unsigned char) (hex1 * 16 + hex2);
i += 2;
} else {
/* Is this even a rule? */
if (in[i] == '+') {
in[out] = ' ';
} else {
in[out] = in[i];
}
}
/* We always only write one char */
out++;
}
/* If decoded string is shorter than original, put null char to stop next read */
if (out < statementValue.length()) {
in[out] = 0;
}
return statementValue.substr(0, out);
}
} else {
/* This querystring is invalid, cannot parse it */
return {};
}
}
queryString.remove_prefix(statement.length() + 1);
}
/* Nothing found */
return {};
}
}
#endif

View File

@ -1,5 +1,5 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Authored by Alex Hultman, 2018-2021.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
@ -26,6 +26,10 @@
#include <set>
#include <chrono>
#include <list>
#include <cstring>
/* We use std::function here, not MoveOnlyFunction */
#include <functional>
namespace uWS {
@ -57,29 +61,123 @@ struct Topic {
/* Terminating wildcard child */
Topic *terminatingWildcardChild = nullptr;
/* What we published */
std::map<unsigned int, std::string> messages;
/* What we published, {inflated, deflated} */
std::map<unsigned int, std::pair<std::string, std::string>> messages;
std::set<Subscriber *> subs;
/* Locked or not, used only when iterating over a Subscriber's topics */
bool locked = false;
/* Full name is used when iterating topcis */
std::string fullName;
};
struct Hole {
std::pair<size_t, size_t> lengths;
unsigned int messageId;
};
struct Intersection {
std::pair<std::string, std::string> dataChannels;
std::vector<Hole> holes;
void forSubscriber(std::vector<unsigned int> &senderForMessages, std::function<void(std::pair<std::string_view, std::string_view>, bool)> cb) {
/* How far we already emitted of the two dataChannels */
std::pair<size_t, size_t> emitted = {};
/* Holes are global to the entire topic tree, so we are not guaranteed to find
* holes in this intersection - they are sorted, though */
unsigned int examinedHoles = 0;
/* This is a slow path of sorts, most subscribers will be observers, not active senders */
for (unsigned int id : senderForMessages) {
std::pair<size_t, size_t> toEmit = {};
std::pair<size_t, size_t> toIgnore = {};
/* This linear search is most probably very small - it could be made log2 if every hole
* knows about its previous accumulated length, which is easy to set up. However this
* log2 search will most likely never be a warranted perf. gain */
for (; examinedHoles < holes.size(); examinedHoles++) {
if (holes[examinedHoles].messageId == id) {
toIgnore.first += holes[examinedHoles].lengths.first;
toIgnore.second += holes[examinedHoles].lengths.second;
examinedHoles++;
break;
}
/* We are not the sender of this message so we should emit it in this segment */
toEmit.first += holes[examinedHoles].lengths.first;
toEmit.second += holes[examinedHoles].lengths.second;
}
/* Emit this segment */
if (toEmit.first || toEmit.second) {
std::pair<std::string_view, std::string_view> cutDataChannels = {
std::string_view(dataChannels.first.data() + emitted.first, toEmit.first),
std::string_view(dataChannels.second.data() + emitted.second, toEmit.second),
};
/* We only need to test the first data channel for "FIN" */
cb(cutDataChannels, emitted.first + toEmit.first + toIgnore.first == dataChannels.first.length());
}
emitted.first += toEmit.first + toIgnore.first;
emitted.second += toEmit.second + toIgnore.second;
}
if (emitted.first == dataChannels.first.length() && emitted.second == dataChannels.second.length()) {
return;
}
std::pair<std::string_view, std::string_view> cutDataChannels = {
std::string_view(dataChannels.first.data() + emitted.first, dataChannels.first.length() - emitted.first),
std::string_view(dataChannels.second.data() + emitted.second, dataChannels.second.length() - emitted.second),
};
cb(cutDataChannels, true);
}
};
struct TopicTree {
/* Returns Topic, or nullptr. Topic can be root if empty string given. */
Topic *lookupTopic(std::string_view topic) {
/* Lookup exact Topic ptr from string */
Topic *iterator = root;
for (size_t start = 0, stop = 0; stop != std::string::npos; start = stop + 1) {
stop = topic.find('/', start);
std::string_view segment = topic.substr(start, stop - start);
std::map<std::string_view, Topic *>::iterator it = iterator->children.find(segment);
if (it == iterator->children.end()) {
/* This topic does not even exist */
return nullptr;
}
iterator = it->second;
}
return iterator;
}
private:
std::function<int(Subscriber *, std::string_view)> cb;
std::function<int(Subscriber *, Intersection &)> cb;
Topic *root = new Topic;
/* Global messageId for deduplication of overlapping topics and ordering between topics */
unsigned int messageId = 0;
/* Sender holes */
std::map<Subscriber *, std::vector<unsigned int>> senderHoles;
/* The triggered topics */
Topic *triggeredTopics[64];
int numTriggeredTopics = 0;
Subscriber *min = (Subscriber *) UINTPTR_MAX;
/* Cull or trim unused Topic nodes from leaf to root */
void trimTree(Topic *topic) {
if (!topic->subs.size() && !topic->children.size() && !topic->terminatingWildcardChild && !topic->wildcardChild) {
while (!topic->subs.size() && !topic->children.size() && !topic->terminatingWildcardChild && !topic->wildcardChild) {
Topic *parent = topic->parent;
if (topic->length == 1) {
@ -112,43 +210,64 @@ private:
delete [] topic->name;
delete topic;
if (parent != root) {
trimTree(parent);
if (parent == root) {
break;
}
topic = parent;
}
}
/* Should be getData and commit? */
void publish(Topic *iterator, size_t start, size_t stop, std::string_view topic, std::string_view message) {
/* If we already have 64 triggered topics make sure to drain it here */
if (numTriggeredTopics == 64) {
drain();
}
/* Publishes to all matching topics and wildcards. Returns whether at least one topic was a match. */
bool publish(Topic *iterator, size_t start, size_t stop, std::string_view topic, std::pair<std::string_view, std::string_view> message) {
/* Whether we matched with at least one topic */
bool didMatch = false;
/* Iterate over all segments in given topic */
for (; stop != std::string::npos; start = stop + 1) {
stop = topic.find('/', start);
std::string_view segment = topic.substr(start, stop - start);
/* It is very important to disallow wildcards when publishing.
* We will not catch EVERY misuse this lazy way, but enough to hinder
* explosive recursion.
* Terminating wildcards MAY still get triggered along the way, if for
* instace the error is found late while iterating the topic segments. */
if (segment.length() == 1) {
if (segment[0] == '+' || segment[0] == '#') {
/* "Fail" here, but not necessarily for the entire publish */
return didMatch;
}
}
/* Do we have a terminating wildcard child? */
if (iterator->terminatingWildcardChild) {
iterator->terminatingWildcardChild->messages[messageId] = message;
/* Add this topic to triggered */
if (!iterator->terminatingWildcardChild->triggered) {
/* If we already have 64 triggered topics make sure to drain it here */
if (numTriggeredTopics == 64) {
drain();
}
triggeredTopics[numTriggeredTopics++] = iterator->terminatingWildcardChild;
iterator->terminatingWildcardChild->triggered = true;
}
didMatch = true;
}
/* Do we have a wildcard child? */
if (iterator->wildcardChild) {
publish(iterator->wildcardChild, stop + 1, stop, topic, message);
didMatch |= publish(iterator->wildcardChild, stop + 1, stop, topic, message);
}
std::map<std::string_view, Topic *>::iterator it = iterator->children.find(segment);
if (it == iterator->children.end()) {
/* Stop trying to match by exact string */
return;
return didMatch;
}
iterator = it->second;
@ -159,14 +278,22 @@ private:
/* Add this topic to triggered */
if (!iterator->triggered) {
/* If we already have 64 triggered topics make sure to drain it here */
if (numTriggeredTopics == 64) {
drain();
}
triggeredTopics[numTriggeredTopics++] = iterator;
iterator->triggered = true;
}
/* We obviously matches exactly here */
return true;
}
public:
TopicTree(std::function<int(Subscriber *, std::string_view)> cb) {
TopicTree(std::function<int(Subscriber *, Intersection &)> cb) {
this->cb = cb;
}
@ -174,7 +301,20 @@ public:
delete root;
}
void subscribe(std::string_view topic, Subscriber *subscriber) {
/* This is part of the fast path, so should be optimal */
std::vector<unsigned int> &getSenderFor(Subscriber *s) {
static thread_local std::vector<unsigned int> emptyVector;
auto it = senderHoles.find(s);
if (it != senderHoles.end()) {
return it->second;
}
return emptyVector;
}
/* Returns number of subscribers after the call and whether or not we were successful in subscribing */
std::pair<unsigned int, bool> subscribe(std::string_view topic, Subscriber *subscriber, bool nonStrict = false) {
/* Start iterating from the root */
Topic *iterator = root;
@ -196,7 +336,18 @@ public:
newTopic->terminatingWildcardChild = nullptr;
newTopic->wildcardChild = nullptr;
memcpy(newTopic->name, segment.data(), segment.length());
/* Set fullname as parent's name plus our name */
newTopic->fullName.reserve(newTopic->parent->fullName.length() + 1 + segment.length());
/* Only append parent's name if parent is not root */
if (newTopic->parent != root) {
newTopic->fullName.append(newTopic->parent->fullName);
newTopic->fullName.append("/");
}
newTopic->fullName.append(segment);
/* For simplicity we do insert wildcards with text */
iterator->children.insert(lb, {std::string_view(newTopic->name, segment.length()), newTopic});
@ -216,22 +367,38 @@ public:
}
}
/* If this topic is triggered, drain the tree before we join */
if (iterator->triggered) {
if (!nonStrict) {
drain();
}
}
/* Add socket to Topic's Set */
auto [it, inserted] = iterator->subs.insert(subscriber);
/* Add Topic to list of subscriptions only if we weren't already subscribed */
if (inserted) {
subscriber->subscriptions.push_back(iterator);
return {(unsigned int) iterator->subs.size(), true};
}
return {(unsigned int) iterator->subs.size(), false};
}
void publish(std::string_view topic, std::string_view message) {
publish(root, 0, 0, topic, message);
bool publish(std::string_view topic, std::pair<std::string_view, std::string_view> message, Subscriber *sender = nullptr) {
/* Add a hole for the sender if one */
if (sender) {
senderHoles[sender].push_back(messageId);
}
auto ret = publish(root, 0, 0, topic, message);
/* MessageIDs are reset on drain - this should be fine since messages itself are cleared on drain */
messageId++;
return ret;
}
/* Returns whether we were subscribed prior */
bool unsubscribe(std::string_view topic, Subscriber *subscriber) {
/* Returns a pair of numSubscribers after operation, and whether we were subscribed prior */
std::pair<unsigned int, bool> unsubscribe(std::string_view topic, Subscriber *subscriber, bool nonStrict = false) {
/* Subscribers are likely to have very few subscriptions (20 or fewer) */
if (subscriber) {
/* Lookup exact Topic ptr from string */
@ -243,32 +410,55 @@ public:
std::map<std::string_view, Topic *>::iterator it = iterator->children.find(segment);
if (it == iterator->children.end()) {
/* This topic does not even exist */
return false;
return {0, false};
}
iterator = it->second;
}
/* Is this topic locked? If so, we cannot unsubscribe from it */
if (iterator->locked) {
return {iterator->subs.size(), false};
}
/* Try and remove this topic from our list */
for (auto it = subscriber->subscriptions.begin(); it != subscriber->subscriptions.end(); it++) {
if (*it == iterator) {
/* If this topic is triggered, drain the tree before we leave */
if (iterator->triggered) {
if (!nonStrict) {
drain();
}
}
/* Remove topic ptr from our list */
subscriber->subscriptions.erase(it);
/* Remove us from Topic's subs */
iterator->subs.erase(subscriber);
unsigned int numSubscribers = (unsigned int) iterator->subs.size();
trimTree(iterator);
return true;
return {numSubscribers, true};
}
}
}
return false;
return {0, false};
}
/* Can be called with nullptr, ignore it then */
void unsubscribeAll(Subscriber *subscriber) {
void unsubscribeAll(Subscriber *subscriber, bool mayFlush = true) {
if (subscriber) {
for (Topic *topic : subscriber->subscriptions) {
/* We do not want to flush when closing a socket, it makes no sense to do so */
/* If this topic is triggered, drain the tree before we leave */
if (mayFlush && topic->triggered) {
/* Never mind nonStrict here (yet?) */
drain();
}
/* Remove us from the topic's set */
topic->subs.erase(subscriber);
trimTree(topic);
}
@ -290,11 +480,18 @@ public:
for (int i = 0; i < numTriggeredTopics; i++) {
if (triggeredTopics[i]->subs.size()) {
triggeredTopics[numFilteredTriggeredTopics++] = triggeredTopics[i];
} else {
/* If we no longer have any subscribers, yet still keep this Topic alive (parent),
* make sure to clear its potential messages. */
triggeredTopics[i]->messages.clear();
triggeredTopics[i]->triggered = false;
}
}
numTriggeredTopics = numFilteredTriggeredTopics;
if (!numTriggeredTopics) {
senderHoles.clear();
messageId = 0;
return;
}
@ -310,7 +507,7 @@ public:
if (min != (Subscriber *)UINTPTR_MAX) {
/* Up to 64 triggered Topics per batch */
std::map<uint64_t, std::string> intersectionCache;
std::map<uint64_t, Intersection> intersectionCache;
/* Loop over these here */
std::set<Subscriber *>::iterator it[64];
@ -319,14 +516,14 @@ public:
it[i] = triggeredTopics[i]->subs.begin();
end[i] = triggeredTopics[i]->subs.end();
}
/* Empty all sets from unique subscribers */
for (int nonEmpty = numTriggeredTopics; nonEmpty; ) {
Subscriber *nextMin = (Subscriber *)UINTPTR_MAX;
/* The message sets relevant for this intersection */
std::map<unsigned int, std::string> *perSubscriberIntersectingTopicMessages[64];
std::map<unsigned int, std::pair<std::string, std::string>> *perSubscriberIntersectingTopicMessages[64];
int numPerSubscriberIntersectingTopicMessages = 0;
uint64_t intersection = 0;
@ -357,18 +554,28 @@ public:
}
/* Generate cache for intersection */
if (intersectionCache[intersection].length() == 0) {
if (intersectionCache[intersection].dataChannels.first.length() == 0) {
/* Build the union in order without duplicates */
std::map<unsigned int, std::string> complete;
std::map<unsigned int, std::pair<std::string, std::string>> complete;
for (int i = 0; i < numPerSubscriberIntersectingTopicMessages; i++) {
complete.insert(perSubscriberIntersectingTopicMessages[i]->begin(), perSubscriberIntersectingTopicMessages[i]->end());
}
/* Create the linear cache */
std::string res;
/* Create the linear cache, {inflated, deflated} */
Intersection res;
for (auto &p : complete) {
res.append(p.second);
res.dataChannels.first.append(p.second.first);
res.dataChannels.second.append(p.second.second);
/* Appends {id, length, length}
* We could possibly append byte offset also,
* if we want to use log2 search later. */
Hole h;
h.lengths.first = p.second.first.length();
h.lengths.second = p.second.second.length();
h.messageId = p.first;
res.holes.push_back(h);
}
cb(min, intersectionCache[intersection] = std::move(res));
@ -379,7 +586,6 @@ public:
min = nextMin;
}
}
/* Clear messages of triggered Topics */
@ -388,27 +594,8 @@ public:
triggeredTopics[i]->triggered = false;
}
numTriggeredTopics = 0;
}
void print(Topic *root = nullptr, int indentation = 1) {
if (root == nullptr) {
std::cout << "Print of tree:" << std::endl;
root = this->root;
}
for (auto p : root->children) {
for (int i = 0; i < indentation; i++) {
std::cout << " ";
}
std::cout << std::string_view(p.second->name, p.second->length) << " = " << p.second->messages.size() << " publishes, " << p.second->subs.size() << " subscribers {";
for (auto &p : p.second->subs) {
std::cout << p << " referring to socket: " << p->user << ", ";
}
std::cout << "}" << std::endl;
print(p.second, indentation + 1);
}
senderHoles.clear();
messageId = 0;
}
};

View File

@ -1,5 +1,5 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Authored by Alex Hultman, 2018-2020.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
@ -43,8 +43,8 @@ inline int u32toaHex(uint32_t value, char *dst) {
return ret;
}
inline int u32toa(uint32_t value, char *dst) {
char temp[10];
inline int u64toa(uint64_t value, char *dst) {
char temp[20];
char *p = temp;
do {
*p++ = (char) ((value % 10) + '0');

View File

@ -1,5 +1,5 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Authored by Alex Hultman, 2018-2021.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
@ -27,40 +27,63 @@
namespace uWS {
template <bool SSL, bool isServer>
template <bool SSL, bool isServer, typename USERDATA>
struct WebSocket : AsyncSocket<SSL> {
template <bool> friend struct TemplatedApp;
template <bool> friend struct HttpResponse;
private:
typedef AsyncSocket<SSL> Super;
void *init(bool perMessageDeflate, bool slidingCompression, std::string &&backpressure) {
new (us_socket_ext(SSL, (us_socket_t *) this)) WebSocketData(perMessageDeflate, slidingCompression, std::move(backpressure));
void *init(bool perMessageDeflate, CompressOptions compressOptions, std::string &&backpressure) {
new (us_socket_ext(SSL, (us_socket_t *) this)) WebSocketData(perMessageDeflate, compressOptions, std::move(backpressure));
return this;
}
public:
/* Returns pointer to the per socket user data */
void *getUserData() {
USERDATA *getUserData() {
WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this);
/* We just have it overallocated by sizeof type */
return (webSocketData + 1);
return (USERDATA *) (webSocketData + 1);
}
/* See AsyncSocket */
using Super::getBufferedAmount;
using Super::getRemoteAddress;
using Super::getRemoteAddressAsText;
using Super::getNativeHandle;
/* Simple, immediate close of the socket. Emits close event */
using Super::close;
/* Send or buffer a WebSocket frame, compressed or not. Returns false on increased user space backpressure. */
bool send(std::string_view message, uWS::OpCode opCode = uWS::OpCode::BINARY, bool compress = false) {
enum SendStatus : int {
BACKPRESSURE,
SUCCESS,
DROPPED
};
/* Send or buffer a WebSocket frame, compressed or not. Returns BACKPRESSURE on increased user space backpressure,
* DROPPED on dropped message (due to backpressure) or SUCCCESS if you are free to send even more now. */
SendStatus send(std::string_view message, OpCode opCode = OpCode::BINARY, bool compress = false) {
WebSocketContextData<SSL, USERDATA> *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL,
(us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this)
);
/* Skip sending and report success if we are over the limit of maxBackpressure */
if (webSocketContextData->maxBackpressure && webSocketContextData->maxBackpressure < getBufferedAmount()) {
/* Also defer a close if we should */
if (webSocketContextData->closeOnBackpressureLimit) {
us_socket_shutdown_read(SSL, (us_socket_t *) this);
}
return DROPPED;
}
/* Transform the message to compressed domain if requested */
if (compress) {
WebSocketData *webSocketData = (WebSocketData *) Super::getAsyncSocketData();
/* Check and correct the compress hint */
if (opCode < 3 && webSocketData->compressionStatus == WebSocketData::ENABLED) {
/* Check and correct the compress hint. It is never valid to compress 0 bytes */
if (message.length() && opCode < 3 && webSocketData->compressionStatus == WebSocketData::ENABLED) {
LoopData *loopData = Super::getLoopData();
/* Compress using either shared or dedicated deflationStream */
if (webSocketData->deflationStream) {
@ -82,7 +105,7 @@ public:
/* Get size, alloate size, write if needed */
size_t messageFrameSize = protocol::messageFrameSize(message.length());
auto[sendBuffer, requiresWrite] = Super::getSendBuffer(messageFrameSize);
auto [sendBuffer, requiresWrite] = Super::getSendBuffer(messageFrameSize);
protocol::formatMessage<isServer>(sendBuffer, message.data(), message.length(), opCode, message.length(), compress);
/* This is the slow path, when we couldn't cork for the user */
if (requiresWrite) {
@ -93,7 +116,7 @@ public:
if (failed) {
/* Return false for failure, skipping to reset the timeout below */
return false;
return BACKPRESSURE;
}
}
@ -101,22 +124,24 @@ public:
if (automaticallyCorked) {
auto [written, failed] = Super::uncork();
if (failed) {
return false;
return BACKPRESSURE;
}
}
/* Every successful send resets the timeout */
WebSocketContextData<SSL> *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL,
(us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this)
);
AsyncSocket<SSL>::timeout(webSocketContextData->idleTimeout);
if (webSocketContextData->resetIdleTimeoutOnSend) {
Super::timeout(webSocketContextData->idleTimeoutComponents.first);
WebSocketData *webSocketData = (WebSocketData *) Super::getAsyncSocketData();
webSocketData->hasTimedOut = false;
}
/* Return success */
return true;
return SUCCESS;
}
/* Send websocket close frame, emit close event, send FIN if successful */
void end(int code, std::string_view message = {}) {
/* Send websocket close frame, emit close event, send FIN if successful.
* Will not append a close reason if code is 0 or 1005. */
void end(int code = 0, std::string_view message = {}) {
/* Check if we already called this one */
WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this);
if (webSocketData->isShuttingDown) {
@ -128,23 +153,22 @@ public:
/* Format and send the close frame */
static const int MAX_CLOSE_PAYLOAD = 123;
int length = (int) std::min<size_t>(MAX_CLOSE_PAYLOAD, message.length());
size_t length = std::min<size_t>(MAX_CLOSE_PAYLOAD, message.length());
char closePayload[MAX_CLOSE_PAYLOAD + 2];
int closePayloadLength = (int) protocol::formatClosePayload(closePayload, (uint16_t) code, message.data(), length);
size_t closePayloadLength = protocol::formatClosePayload(closePayload, (uint16_t) code, message.data(), length);
bool ok = send(std::string_view(closePayload, closePayloadLength), OpCode::CLOSE);
/* FIN if we are ok and not corked */
WebSocket<SSL, true> *webSocket = (WebSocket<SSL, true> *) this;
if (!webSocket->isCorked()) {
if (!this->isCorked()) {
if (ok) {
/* If we are not corked, and we just sent off everything, we need to FIN right here.
* In all other cases, we need to fin either if uncork was successful, or when drainage is complete. */
webSocket->shutdown();
this->shutdown();
}
}
/* Emit close event */
WebSocketContextData<SSL> *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL,
WebSocketContextData<SSL, USERDATA> *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL,
(us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this)
);
if (webSocketContextData->closeHandler) {
@ -152,13 +176,13 @@ public:
}
/* Make sure to unsubscribe from any pub/sub node at exit */
webSocketContextData->topicTree.unsubscribeAll(webSocketData->subscriber);
webSocketContextData->topicTree.unsubscribeAll(webSocketData->subscriber, false);
delete webSocketData->subscriber;
webSocketData->subscriber = nullptr;
}
/* Corks the response if possible. Leaves already corked socket be. */
void cork(fu2::unique_function<void()> &&handler) {
void cork(MoveOnlyFunction<void()> &&handler) {
if (!Super::isCorked() && Super::canCork()) {
Super::cork();
handler();
@ -172,9 +196,9 @@ public:
}
}
/* Subscribe to a topic according to MQTT rules and syntax */
void subscribe(std::string_view topic) {
WebSocketContextData<SSL> *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL,
/* Subscribe to a topic according to MQTT rules and syntax. Returns success */
/*std::pair<unsigned int, bool>*/ bool subscribe(std::string_view topic, bool nonStrict = false) {
WebSocketContextData<SSL, USERDATA> *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL,
(us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this)
);
@ -184,38 +208,87 @@ public:
webSocketData->subscriber = new Subscriber(this);
}
webSocketContextData->topicTree.subscribe(topic, webSocketData->subscriber);
/* Cannot return numSubscribers as this is only for this particular websocket context */
return webSocketContextData->topicTree.subscribe(topic, webSocketData->subscriber, nonStrict).second;
}
/* Unsubscribe from a topic, returns true if we were subscribed */
bool unsubscribe(std::string_view topic) {
WebSocketContextData<SSL> *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL,
/* Unsubscribe from a topic, returns true if we were subscribed. */
/*std::pair<unsigned int, bool>*/ bool unsubscribe(std::string_view topic, bool nonStrict = false) {
WebSocketContextData<SSL, USERDATA> *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL,
(us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this)
);
WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this);
return webSocketContextData->topicTree.unsubscribe(topic, webSocketData->subscriber);
/* Cannot return numSubscribers as this is only for this particular websocket context */
return webSocketContextData->topicTree.unsubscribe(topic, webSocketData->subscriber, nonStrict).second;
}
/* Unsubscribe from all topics you might be subscribed to */
void unsubscribeAll() {
WebSocketContextData<SSL> *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL,
/* Returns whether this socket is subscribed to the specified topic */
bool isSubscribed(std::string_view topic) {
WebSocketContextData<SSL, USERDATA> *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL,
(us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this)
);
WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this);
if (!webSocketData->subscriber) {
return false;
}
webSocketContextData->topicTree.unsubscribeAll(webSocketData->subscriber);
Topic *t = webSocketContextData->topicTree.lookupTopic(topic);
if (t) {
return t->subs.find(webSocketData->subscriber) != t->subs.end();
}
return false;
}
/* Publish a message to a topic according to MQTT rules and syntax */
void publish(std::string_view topic, std::string_view message, OpCode opCode = OpCode::TEXT, bool compress = false) {
WebSocketContextData<SSL> *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL,
/* Iterates all topics of this WebSocket. Every topic is represented by its full name.
* Can be called in close handler. It is possible to modify the subscription list while
* inside the callback ONLY IF not modifying the topic passed to the callback.
* Topic names are valid only for the duration of the callback. */
void iterateTopics(MoveOnlyFunction<void(std::string_view/*, unsigned int*/)> cb) {
WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this);
if (webSocketData->subscriber) {
for (Topic *t : webSocketData->subscriber->subscriptions) {
/* Lock this topic so that nobody may unsubscribe from it during this callback */
t->locked = true;
cb(t->fullName/*, (unsigned int) t->subs.size()*/);
t->locked = false;
}
}
}
/* Publish a message to a topic according to MQTT rules and syntax. Returns success.
* We, the WebSocket, must be subscribed to the topic itself and if so - no message will be sent to ourselves.
* Use App::publish for an unconditional publish that simply publishes to whomever might be subscribed. */
bool publish(std::string_view topic, std::string_view message, OpCode opCode = OpCode::TEXT, bool compress = false) {
WebSocketContextData<SSL, USERDATA> *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL,
(us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this)
);
/* Is the same as publishing per websocket context */
webSocketContextData->publish(topic, message, opCode, compress);
/* We cannot be a subscriber of this topic if we are not a subscriber of anything */
WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this);
if (!webSocketData->subscriber) {
/* Failure, but still do return the number of subscribers */
return false;
}
/* Publish as sender, does not receive its own messages even if subscribed to relevant topics */
bool success = webSocketContextData->publish(topic, message, opCode, compress, webSocketData->subscriber);
/* Loop over all websocket contexts for this App */
if (success) {
/* Success is really only determined by the first publish. We must be subscribed to the topic. */
for (auto *adjacentWebSocketContextData : webSocketContextData->adjacentWebSocketContextDatas) {
adjacentWebSocketContextData->publish(topic, message, opCode, compress);
}
}
return success;
}
};

View File

@ -1,5 +1,5 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Authored by Alex Hultman, 2018-2020.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
@ -25,7 +25,7 @@
namespace uWS {
template <bool SSL, bool isServer>
template <bool SSL, bool isServer, typename USERDATA>
struct WebSocketContext {
template <bool> friend struct TemplatedApp;
template <bool, typename> friend struct WebSocketProtocol;
@ -36,12 +36,12 @@ private:
return (us_socket_context_t *) this;
}
WebSocketContextData<SSL> *getExt() {
return (WebSocketContextData<SSL> *) us_socket_context_ext(SSL, (us_socket_context_t *) this);
WebSocketContextData<SSL, USERDATA> *getExt() {
return (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, (us_socket_context_t *) this);
}
/* If we have negotiated compression, set this frame compressed */
static bool setCompressed(uWS::WebSocketState<isServer> *wState, void *s) {
static bool setCompressed(WebSocketState<isServer> */*wState*/, void *s) {
WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) s);
if (webSocketData->compressionStatus == WebSocketData::CompressionStatus::ENABLED) {
@ -52,14 +52,14 @@ private:
}
}
static void forceClose(uWS::WebSocketState<isServer> *wState, void *s) {
us_socket_close(SSL, (us_socket_t *) s);
static void forceClose(WebSocketState<isServer> */*wState*/, void *s, std::string_view reason = {}) {
us_socket_close(SSL, (us_socket_t *) s, (int) reason.length(), (void *) reason.data());
}
/* Returns true on breakage */
static bool handleFragment(char *data, size_t length, unsigned int remainingBytes, int opCode, bool fin, uWS::WebSocketState<isServer> *webSocketState, void *s) {
static bool handleFragment(char *data, size_t length, unsigned int remainingBytes, int opCode, bool fin, WebSocketState<isServer> *webSocketState, void *s) {
/* WebSocketData and WebSocketContextData */
WebSocketContextData<SSL> *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
WebSocketContextData<SSL, USERDATA> *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) s);
/* Is this a non-control frame? */
@ -72,25 +72,25 @@ private:
webSocketData->compressionStatus = WebSocketData::CompressionStatus::ENABLED;
LoopData *loopData = (LoopData *) us_loop_ext(us_socket_context_loop(SSL, us_socket_context(SSL, (us_socket_t *) s)));
std::string_view inflatedFrame = loopData->inflationStream->inflate(loopData->zlibContext, {data, length}, webSocketContextData->maxPayloadLength);
if (!inflatedFrame.length()) {
forceClose(webSocketState, s);
auto inflatedFrame = loopData->inflationStream->inflate(loopData->zlibContext, {data, length}, webSocketContextData->maxPayloadLength);
if (!inflatedFrame.has_value()) {
forceClose(webSocketState, s, ERR_TOO_BIG_MESSAGE_INFLATION);
return true;
} else {
data = (char *) inflatedFrame.data();
length = inflatedFrame.length();
data = (char *) inflatedFrame->data();
length = inflatedFrame->length();
}
}
/* Check text messages for Utf-8 validity */
if (opCode == 1 && !protocol::isValidUtf8((unsigned char *) data, length)) {
forceClose(webSocketState, s);
forceClose(webSocketState, s, ERR_INVALID_TEXT);
return true;
}
/* Emit message event & break if we are closed or shut down when returning */
if (webSocketContextData->messageHandler) {
webSocketContextData->messageHandler((WebSocket<SSL, isServer> *) s, std::string_view(data, length), (uWS::OpCode) opCode);
webSocketContextData->messageHandler((WebSocket<SSL, isServer, USERDATA> *) s, std::string_view(data, length), (OpCode) opCode);
if (us_socket_is_closed(SSL, (us_socket_t *) s) || webSocketData->isShuttingDown) {
return true;
}
@ -102,7 +102,7 @@ private:
}
/* Fragments forming a big message are not caught until appending them */
if (refusePayloadLength(length + webSocketData->fragmentBuffer.length(), webSocketState, s)) {
forceClose(webSocketState, s);
forceClose(webSocketState, s, ERR_TOO_BIG_MESSAGE);
return true;
}
webSocketData->fragmentBuffer.append(data, length);
@ -115,8 +115,8 @@ private:
if (webSocketData->compressionStatus == WebSocketData::CompressionStatus::COMPRESSED_FRAME) {
webSocketData->compressionStatus = WebSocketData::CompressionStatus::ENABLED;
// what's really the story here?
webSocketData->fragmentBuffer.append("....");
/* 9 bytes of padding for libdeflate */
webSocketData->fragmentBuffer.append("123456789");
LoopData *loopData = (LoopData *) us_loop_ext(
us_socket_context_loop(SSL,
@ -124,13 +124,13 @@ private:
)
);
std::string_view inflatedFrame = loopData->inflationStream->inflate(loopData->zlibContext, {webSocketData->fragmentBuffer.data(), webSocketData->fragmentBuffer.length() - 4}, webSocketContextData->maxPayloadLength);
if (!inflatedFrame.length()) {
forceClose(webSocketState, s);
auto inflatedFrame = loopData->inflationStream->inflate(loopData->zlibContext, {webSocketData->fragmentBuffer.data(), webSocketData->fragmentBuffer.length() - 9}, webSocketContextData->maxPayloadLength);
if (!inflatedFrame.has_value()) {
forceClose(webSocketState, s, ERR_TOO_BIG_MESSAGE_INFLATION);
return true;
} else {
data = (char *) inflatedFrame.data();
length = inflatedFrame.length();
data = (char *) inflatedFrame->data();
length = inflatedFrame->length();
}
@ -142,13 +142,13 @@ private:
/* Check text messages for Utf-8 validity */
if (opCode == 1 && !protocol::isValidUtf8((unsigned char *) data, length)) {
forceClose(webSocketState, s);
forceClose(webSocketState, s, ERR_INVALID_TEXT);
return true;
}
/* Emit message and check for shutdown or close */
if (webSocketContextData->messageHandler) {
webSocketContextData->messageHandler((WebSocket<SSL, isServer> *) s, std::string_view(data, length), (uWS::OpCode) opCode);
webSocketContextData->messageHandler((WebSocket<SSL, isServer, USERDATA> *) s, std::string_view(data, length), (OpCode) opCode);
if (us_socket_is_closed(SSL, (us_socket_t *) s) || webSocketData->isShuttingDown) {
return true;
}
@ -160,7 +160,7 @@ private:
}
} else {
/* Control frames need the websocket to send pings, pongs and close */
WebSocket<SSL, isServer> *webSocket = (WebSocket<SSL, isServer> *) s;
WebSocket<SSL, isServer, USERDATA> *webSocket = (WebSocket<SSL, isServer, USERDATA> *) s;
if (!remainingBytes && fin && !webSocketData->controlTipLength) {
if (opCode == CLOSE) {
@ -171,14 +171,14 @@ private:
if (opCode == PING) {
webSocket->send(std::string_view(data, length), (OpCode) OpCode::PONG);
if (webSocketContextData->pingHandler) {
webSocketContextData->pingHandler(webSocket);
webSocketContextData->pingHandler(webSocket, {data, length});
if (us_socket_is_closed(SSL, (us_socket_t *) s) || webSocketData->isShuttingDown) {
return true;
}
}
} else if (opCode == PONG) {
if (webSocketContextData->pongHandler) {
webSocketContextData->pongHandler(webSocket);
webSocketContextData->pongHandler(webSocket, {data, length});
if (us_socket_is_closed(SSL, (us_socket_t *) s) || webSocketData->isShuttingDown) {
return true;
}
@ -188,7 +188,7 @@ private:
} else {
/* Here we never mind any size optimizations as we are in the worst possible path */
webSocketData->fragmentBuffer.append(data, length);
webSocketData->controlTipLength += (int) length;
webSocketData->controlTipLength += (unsigned int) length;
if (!remainingBytes && fin) {
char *controlBuffer = (char *) webSocketData->fragmentBuffer.data() + webSocketData->fragmentBuffer.length() - webSocketData->controlTipLength;
@ -200,14 +200,14 @@ private:
if (opCode == PING) {
webSocket->send(std::string_view(controlBuffer, webSocketData->controlTipLength), (OpCode) OpCode::PONG);
if (webSocketContextData->pingHandler) {
webSocketContextData->pingHandler(webSocket);
webSocketContextData->pingHandler(webSocket, std::string_view(controlBuffer, webSocketData->controlTipLength));
if (us_socket_is_closed(SSL, (us_socket_t *) s) || webSocketData->isShuttingDown) {
return true;
}
}
} else if (opCode == PONG) {
if (webSocketContextData->pongHandler) {
webSocketContextData->pongHandler(webSocket);
webSocketContextData->pongHandler(webSocket, std::string_view(controlBuffer, webSocketData->controlTipLength));
if (us_socket_is_closed(SSL, (us_socket_t *) s) || webSocketData->isShuttingDown) {
return true;
}
@ -224,32 +224,32 @@ private:
return false;
}
static bool refusePayloadLength(uint64_t length, uWS::WebSocketState<isServer> *wState, void *s) {
auto *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
static bool refusePayloadLength(uint64_t length, WebSocketState<isServer> */*wState*/, void *s) {
auto *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
/* Return true for refuse, false for accept */
return webSocketContextData->maxPayloadLength < length;
}
WebSocketContext<SSL, isServer> *init() {
WebSocketContext<SSL, isServer, USERDATA> *init() {
/* Adopting a socket does not trigger open event.
* We arreive as WebSocket with timeout set and
* any backpressure from HTTP state kept. */
/* Handle socket disconnections */
us_socket_context_on_close(SSL, getSocketContext(), [](auto *s) {
us_socket_context_on_close(SSL, getSocketContext(), [](auto *s, int code, void *reason) {
/* For whatever reason, if we already have emitted close event, do not emit it again */
WebSocketData *webSocketData = (WebSocketData *) (us_socket_ext(SSL, s));
if (!webSocketData->isShuttingDown) {
/* Emit close event */
auto *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
auto *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
if (webSocketContextData->closeHandler) {
webSocketContextData->closeHandler((WebSocket<SSL, true> *) s, 1006, {});
webSocketContextData->closeHandler((WebSocket<SSL, isServer, USERDATA> *) s, 1006, {(char *) reason, (size_t) code});
}
/* Make sure to unsubscribe from any pub/sub node at exit */
webSocketContextData->topicTree.unsubscribeAll(webSocketData->subscriber);
webSocketContextData->topicTree.unsubscribeAll(webSocketData->subscriber, false);
delete webSocketData->subscriber;
webSocketData->subscriber = nullptr;
}
@ -272,17 +272,18 @@ private:
return s;
}
auto *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
auto *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
auto *asyncSocket = (AsyncSocket<SSL> *) s;
/* Every time we get data and not in shutdown state we simply reset the timeout */
asyncSocket->timeout(webSocketContextData->idleTimeout);
asyncSocket->timeout(webSocketContextData->idleTimeoutComponents.first);
webSocketData->hasTimedOut = false;
/* We always cork on data */
asyncSocket->cork();
/* This parser has virtually no overhead */
uWS::WebSocketProtocol<isServer, WebSocketContext<SSL, isServer>>::consume(data, length, (WebSocketState<isServer> *) webSocketData, s);
WebSocketProtocol<isServer, WebSocketContext<SSL, isServer, USERDATA>>::consume(data, (unsigned int) length, (WebSocketState<isServer> *) webSocketData, s);
/* Uncorking a closed socekt is fine, in fact it is needed */
asyncSocket->uncork();
@ -302,6 +303,10 @@ private:
/* Handle HTTP write out (note: SSL_read may trigger this spuriously, the app need to handle spurious calls) */
us_socket_context_on_writable(SSL, getSocketContext(), [](auto *s) {
/* NOTE: Are we called here corked? If so, the below write code is broken, since
* we will have 0 as getBufferedAmount due to writing to cork buffer, then sending TCP FIN before
* we actually uncorked and sent off things */
/* It makes sense to check for us_is_shut_down here and return if so, to avoid shutting down twice */
if (us_socket_is_shut_down(SSL, (us_socket_t *) s)) {
return s;
@ -310,16 +315,19 @@ private:
AsyncSocket<SSL> *asyncSocket = (AsyncSocket<SSL> *) s;
WebSocketData *webSocketData = (WebSocketData *)(us_socket_ext(SSL, s));
/* We store old backpressure since it is unclear whether write drained anything */
int backpressure = asyncSocket->getBufferedAmount();
/* We store old backpressure since it is unclear whether write drained anything,
* however, in case of coming here with 0 backpressure we still need to emit drain event */
unsigned int backpressure = asyncSocket->getBufferedAmount();
/* Drain as much as possible */
asyncSocket->write(nullptr, 0);
/* Behavior: if we actively drain backpressure, always reset timeout (even if we are in shutdown) */
if (backpressure < asyncSocket->getBufferedAmount()) {
auto *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
asyncSocket->timeout(webSocketContextData->idleTimeout);
/* Also reset timeout if we came here with 0 backpressure */
if (!backpressure || backpressure > asyncSocket->getBufferedAmount()) {
auto *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
asyncSocket->timeout(webSocketContextData->idleTimeoutComponents.first);
webSocketData->hasTimedOut = false;
}
/* Are we in (WebSocket) shutdown mode? */
@ -329,11 +337,11 @@ private:
/* Now perform the actual TCP/TLS shutdown which was postponed due to backpressure */
asyncSocket->shutdown();
}
} else if (backpressure > asyncSocket->getBufferedAmount()) {
/* Only call drain if we actually drained backpressure */
auto *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
} else if (!backpressure || backpressure > asyncSocket->getBufferedAmount()) {
/* Only call drain if we actually drained backpressure or if we came here with 0 backpressure */
auto *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
if (webSocketContextData->drainHandler) {
webSocketContextData->drainHandler((WebSocket<SSL, isServer> *) s);
webSocketContextData->drainHandler((WebSocket<SSL, isServer, USERDATA> *) s);
}
/* No need to check for closed here as we leave the handler immediately*/
}
@ -345,7 +353,7 @@ private:
us_socket_context_on_end(SSL, getSocketContext(), [](auto *s) {
/* If we get a fin, we just close I guess */
us_socket_close(SSL, (us_socket_t *) s);
us_socket_close(SSL, (us_socket_t *) s, 0, nullptr);
return s;
});
@ -353,8 +361,20 @@ private:
/* Handle socket timeouts, simply close them so to not confuse client with FIN */
us_socket_context_on_timeout(SSL, getSocketContext(), [](auto *s) {
auto *webSocketData = (WebSocketData *)(us_socket_ext(SSL, s));
auto *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
if (webSocketContextData->sendPingsAutomatically && !webSocketData->hasTimedOut) {
webSocketData->hasTimedOut = true;
us_socket_timeout(SSL, s, webSocketContextData->idleTimeoutComponents.second);
/* Send ping without being corked */
((AsyncSocket<SSL> *) s)->write("\x89\x00", 2);
return s;
}
/* Timeout is very simple; we just close it */
us_socket_close(SSL, (us_socket_t *) s);
/* Warning: we happen to know forceClose will not use first parameter so pass nullptr here */
forceClose(nullptr, s, ERR_WEBSOCKET_TIMEOUT);
return s;
});
@ -363,7 +383,7 @@ private:
}
void free() {
WebSocketContextData<SSL> *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL, (us_socket_context_t *) this);
WebSocketContextData<SSL, USERDATA> *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, (us_socket_context_t *) this);
webSocketContextData->~WebSocketContextData();
us_socket_context_free(SSL, (us_socket_context_t *) this);
@ -371,14 +391,14 @@ private:
public:
/* WebSocket contexts are always child contexts to a HTTP context so no SSL options are needed as they are inherited */
static WebSocketContext *create(Loop *loop, us_socket_context_t *parentSocketContext) {
WebSocketContext *webSocketContext = (WebSocketContext *) us_create_child_socket_context(SSL, parentSocketContext, sizeof(WebSocketContextData<SSL>));
static WebSocketContext *create(Loop */*loop*/, us_socket_context_t *parentSocketContext) {
WebSocketContext *webSocketContext = (WebSocketContext *) us_create_child_socket_context(SSL, parentSocketContext, sizeof(WebSocketContextData<SSL, USERDATA>));
if (!webSocketContext) {
return nullptr;
}
/* Init socket context data */
new ((WebSocketContextData<SSL> *) us_socket_context_ext(SSL, (us_socket_context_t *)webSocketContext)) WebSocketContextData<SSL>;
new ((WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, (us_socket_context_t *)webSocketContext)) WebSocketContextData<SSL, USERDATA>;
return webSocketContext->init();
}
};

View File

@ -1,5 +1,5 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Authored by Alex Hultman, 2018-2020.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
@ -18,83 +18,279 @@
#ifndef UWS_WEBSOCKETCONTEXTDATA_H
#define UWS_WEBSOCKETCONTEXTDATA_H
#include "f2/function2.hpp"
#include "MoveOnlyFunction.h"
#include <string_view>
#include <vector>
#include "WebSocketProtocol.h"
#include "TopicTree.h"
#include "WebSocketData.h"
namespace uWS {
template <bool, bool> struct WebSocket;
template <bool, bool, typename> struct WebSocket;
/* todo: this looks identical to WebSocketBehavior, why not just std::move that entire thing in? */
template <bool SSL>
template <bool SSL, typename USERDATA>
struct WebSocketContextData {
private:
/* Used for prepending unframed messages when using dedicated compressors */
struct MessageMetadata {
unsigned int length;
OpCode opCode;
bool compress;
/* Undefined init of all members */
MessageMetadata() {}
MessageMetadata(unsigned int length, OpCode opCode, bool compress)
: length(length), opCode(opCode), compress(compress) {}
};
public:
/* All WebSocketContextData holds a list to all other WebSocketContextData in this app.
* We cannot type it USERDATA since different WebSocketContextData can have different USERDATA. */
std::vector<WebSocketContextData<SSL, int> *> adjacentWebSocketContextDatas;
/* The callbacks for this context */
fu2::unique_function<void(WebSocket<SSL, true> *, std::string_view, uWS::OpCode)> messageHandler = nullptr;
fu2::unique_function<void(WebSocket<SSL, true> *)> drainHandler = nullptr;
fu2::unique_function<void(WebSocket<SSL, true> *, int, std::string_view)> closeHandler = nullptr;
MoveOnlyFunction<void(WebSocket<SSL, true, USERDATA> *)> openHandler = nullptr;
MoveOnlyFunction<void(WebSocket<SSL, true, USERDATA> *, std::string_view, OpCode)> messageHandler = nullptr;
MoveOnlyFunction<void(WebSocket<SSL, true, USERDATA> *)> drainHandler = nullptr;
MoveOnlyFunction<void(WebSocket<SSL, true, USERDATA> *, int, std::string_view)> closeHandler = nullptr;
/* Todo: these should take message also; breaking change for v0.18 */
fu2::unique_function<void(WebSocket<SSL, true> *)> pingHandler = nullptr;
fu2::unique_function<void(WebSocket<SSL, true> *)> pongHandler = nullptr;
MoveOnlyFunction<void(WebSocket<SSL, true, USERDATA> *, std::string_view)> pingHandler = nullptr;
MoveOnlyFunction<void(WebSocket<SSL, true, USERDATA> *, std::string_view)> pongHandler = nullptr;
/* Settings for this context */
size_t maxPayloadLength = 0;
int idleTimeout = 0;
/* We do need these for async upgrade */
CompressOptions compression;
/* There needs to be a maxBackpressure which will force close everything over that limit */
size_t maxBackpressure = 0;
bool closeOnBackpressureLimit;
bool resetIdleTimeoutOnSend;
bool sendPingsAutomatically;
/* These are calculated on creation */
std::pair<unsigned short, unsigned short> idleTimeoutComponents;
/* Each websocket context has a topic tree for pub/sub */
TopicTree topicTree;
/* This is run once on start-up */
void calculateIdleTimeoutCompnents(unsigned short idleTimeout) {
unsigned short margin = 4;
/* 4, 8 or 16 seconds margin based on idleTimeout */
while ((int) idleTimeout - margin * 2 >= margin * 2 && margin < 16) {
margin = (unsigned short) (margin << 1);
}
/* We should have no margin if not using sendPingsAutomatically */
if (!sendPingsAutomatically) {
margin = 0;
}
idleTimeoutComponents = {
idleTimeout - margin,
margin
};
}
~WebSocketContextData() {
/* We must unregister any loop post handler here */
Loop::get()->removePostHandler(this);
Loop::get()->removePreHandler(this);
}
WebSocketContextData() : topicTree([this](Subscriber *s, std::string_view data) -> int {
WebSocketContextData() : topicTree([this](Subscriber *s, Intersection &intersection) -> int {
/* We could potentially be called here even if we have nothing to send, since we can
* be the sender of every single message in this intersection. Also "fin" of a segment is not
* guaranteed to be set, in case remaining segments are all from us.
* Essentially, we cannot make strict assumptions here. Also, we can even come here corked,
* since publish can call drain! */
/* We rely on writing to regular asyncSockets */
auto *asyncSocket = (AsyncSocket<SSL> *) s->user;
auto [written, failed] = asyncSocket->write(data.data(), (int) data.length());
if (!failed) {
asyncSocket->timeout(this->idleTimeout);
} else {
/* Note: this assumes we are not corked, as corking will swallow things and fail later on */
/* If we are corked, do not uncork - otherwise if we cork in here, uncork before leaving */
bool wasCorked = asyncSocket->isCorked();
/* Check if we now have too much backpressure (todo: don't buffer up before check) */
if ((unsigned int) asyncSocket->getBufferedAmount() > maxBackpressure) {
asyncSocket->close();
/* Do we even have room for potential data? */
if (!maxBackpressure || asyncSocket->getBufferedAmount() < maxBackpressure) {
/* Roll over all our segments */
intersection.forSubscriber(topicTree.getSenderFor(s), [asyncSocket, this](std::pair<std::string_view, std::string_view> data, bool fin) {
/* We have a segment that is not marked as last ("fin").
* Cork if not already so (purely for performance reasons). Does not touch "wasCorked". */
if (!fin && !asyncSocket->isCorked() && asyncSocket->canCork()) {
asyncSocket->cork();
}
/* Pick uncompressed data track */
std::string_view selectedData = data.first;
/* Are we using compression? Fine, pick the compressed data track */
WebSocketData *webSocketData = (WebSocketData *) asyncSocket->getAsyncSocketData();
if (webSocketData->compressionStatus != WebSocketData::CompressionStatus::DISABLED) {
/* This is used for both shared and dedicated paths */
selectedData = data.second;
/* However, dedicated compression has its own path */
if (compression != SHARED_COMPRESSOR) {
WebSocket<SSL, true, int> *ws = (WebSocket<SSL, true, int> *) asyncSocket;
/* For performance reasons we always cork when in dedicated mode.
* Is this really the best? We already kind of cork things in Zlib?
* Right, formatting needs a cork buffer, right. Never mind. */
if (!ws->isCorked() && ws->canCork()) {
asyncSocket->cork();
}
while (selectedData.length()) {
/* Interpret the data like so, because this is how we shoved it in */
MessageMetadata mm;
memcpy((char *) &mm, selectedData.data(), sizeof(MessageMetadata));
std::string_view unframedMessage(selectedData.data() + sizeof(MessageMetadata), mm.length);
/* Skip this message if our backpressure is too high */
if (maxBackpressure && ws->getBufferedAmount() > maxBackpressure) {
break;
}
/* Here we perform the actual compression and framing */
ws->send(unframedMessage, mm.opCode, mm.compress);
/* Advance until empty */
selectedData.remove_prefix(sizeof(MessageMetadata) + mm.length);
}
/* Continue to next segment without executing below path */
return;
}
}
/* Common path for SHARED and DISABLED. It is an invalid assumption that we always are
* uncorked here, however the following (invalid) assumption is not critically wrong either way */
/* Note: this assumes we are not corked, as corking will swallow things and fail later on */
auto [written, failed] = asyncSocket->write(selectedData.data(), (int) selectedData.length());
/* If we want strict check for success, we can ignore this check if corked and repeat below
* when uncorking - however this is too strict as we really care about PROGRESS rather than
* ENTIRE SUCCESS - we need minor API changes to support correct checks */
if (!failed) {
if (this->resetIdleTimeoutOnSend) {
auto *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) asyncSocket);
webSocketData->hasTimedOut = false;
asyncSocket->timeout(this->idleTimeoutComponents.first);
}
}
});
}
/* We are done sending, for whatever reasons we ended up corked while not starting with "wasCorked",
* we here need to uncork to restore the state we were called in */
if (!wasCorked && asyncSocket->isCorked()) {
/* Regarding timeout for writes; */
auto [written, failed] = asyncSocket->uncork();
/* Again, this check should be more like DID WE PROGRESS rather than DID WE SUCCEED ENTIRELY */
if (!failed) {
if (this->resetIdleTimeoutOnSend) {
auto *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) asyncSocket);
webSocketData->hasTimedOut = false;
asyncSocket->timeout(this->idleTimeoutComponents.first);
}
}
}
/* Defer a close if we now have (or already had) too much backpressure, or simply skip */
if (maxBackpressure && closeOnBackpressureLimit && asyncSocket->getBufferedAmount() > maxBackpressure) {
/* We must not immediately close the socket, as that could result in stack overflow,
* iterator invalidation and other TopicTree::drain bugs. We may shutdown the reading side of the socket,
* causing next iteration to error-close the socket from that context instead, if we want to */
us_socket_shutdown_read(SSL, (us_socket_t *) asyncSocket);
}
/* Reserved, unused */
return 0;
}) {
/* We empty for both pre and post just to make sure */
Loop::get()->addPostHandler(this, [this](Loop *loop) {
Loop::get()->addPostHandler(this, [this](Loop */*loop*/) {
/* Commit pub/sub batches every loop iteration */
topicTree.drain();
});
Loop::get()->addPreHandler(this, [this](Loop *loop) {
Loop::get()->addPreHandler(this, [this](Loop */*loop*/) {
/* Commit pub/sub batches every loop iteration */
topicTree.drain();
});
}
/* Helper for topictree publish, common path from app and ws */
void publish(std::string_view topic, std::string_view message, OpCode opCode, bool compress) {
bool publish(std::string_view topic, std::string_view message, OpCode opCode, bool compress, Subscriber *sender = nullptr) {
bool didMatch = false;
/* We frame the message right here and only pass raw bytes to the pub/subber */
char *dst = (char *) malloc(protocol::messageFrameSize(message.size()));
size_t dst_length = protocol::formatMessage<true>(dst, message.data(), message.length(), opCode, message.length(), false);
topicTree.publish(topic, std::string_view(dst, dst_length));
/* If compression is disabled */
if (compression == DISABLED) {
/* Leave second field empty as nobody will ever read it */
didMatch |= topicTree.publish(topic, {std::string_view(dst, dst_length), {}}, sender);
} else {
/* DEDICATED_COMPRESSOR always takes the same path as must always have MessageMetadata as head */
if (compress || compression != SHARED_COMPRESSOR) {
/* Shared compression mode publishes compressed, framed data */
if (compression == SHARED_COMPRESSOR) {
/* Loop data holds shared compressor */
LoopData *loopData = (LoopData *) us_loop_ext((us_loop_t *) Loop::get());
/* Compress it */
std::string_view compressedMessage = loopData->deflationStream->deflate(loopData->zlibContext, message, true);
/* Frame it */
char *dst_compressed = (char *) malloc(protocol::messageFrameSize(compressedMessage.size()));
size_t dst_compressed_length = protocol::formatMessage<true>(dst_compressed, compressedMessage.data(), compressedMessage.length(), opCode, compressedMessage.length(), true);
/* Always publish the shortest one in any case */
didMatch |= topicTree.publish(topic, {std::string_view(dst, dst_length), dst_compressed_length >= dst_length ? std::string_view(dst, dst_length) : std::string_view(dst_compressed, dst_compressed_length)}, sender);
/* We don't care for allocation here */
::free(dst_compressed);
} else {
/* Dedicated compression mode publishes metadata + unframed uncompressed data */
char *dst_compressed = (char *) malloc(message.length() + sizeof(MessageMetadata));
MessageMetadata mm(
(unsigned int) message.length(),
opCode,
compress
);
memcpy(dst_compressed, (char *) &mm, sizeof(MessageMetadata));
memcpy(dst_compressed + sizeof(MessageMetadata), message.data(), message.length());
/* Interpretation of compressed data depends on what compressor we use */
didMatch |= topicTree.publish(topic, {
std::string_view(dst, dst_length),
std::string_view(dst_compressed, message.length() + sizeof(MessageMetadata))
}, sender);
::free(dst_compressed);
}
} else {
/* If not compressing, put same message on both tracks (only valid for SHARED_COMPRESSOR).
* DEDICATED_COMPRESSOR_xKB must never end up here as we don't put a proper head here. */
didMatch |= topicTree.publish(topic, {std::string_view(dst, dst_length), std::string_view(dst, dst_length)}, sender);
}
}
::free(dst);
return didMatch;
}
};

View File

@ -1,5 +1,5 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Authored by Alex Hultman, 2018-2020.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
@ -21,18 +21,23 @@
#include "WebSocketProtocol.h"
#include "AsyncSocketData.h"
#include "PerMessageDeflate.h"
#include "TopicTree.h"
#include <string>
namespace uWS {
struct WebSocketData : AsyncSocketData<false>, WebSocketState<true> {
template <bool, bool> friend struct WebSocketContext;
template <bool, bool> friend struct WebSocket;
/* This guy has a lot of friends - why? */
template <bool, bool, typename> friend struct WebSocketContext;
template <bool, typename> friend struct WebSocketContextData;
template <bool, bool, typename> friend struct WebSocket;
template <bool> friend struct HttpContext;
private:
std::string fragmentBuffer;
int controlTipLength = 0;
unsigned int controlTipLength = 0;
bool isShuttingDown = 0;
bool hasTimedOut = false;
enum CompressionStatus : char {
DISABLED,
ENABLED,
@ -45,12 +50,12 @@ private:
/* We could be a subscriber */
Subscriber *subscriber = nullptr;
public:
WebSocketData(bool perMessageDeflate, bool slidingCompression, std::string &&backpressure) : AsyncSocketData<false>(std::move(backpressure)), WebSocketState<true>() {
WebSocketData(bool perMessageDeflate, CompressOptions compressOptions, std::string &&backpressure) : AsyncSocketData<false>(std::move(backpressure)), WebSocketState<true>() {
compressionStatus = perMessageDeflate ? ENABLED : DISABLED;
/* Initialize the dedicated sliding window */
if (perMessageDeflate && slidingCompression) {
deflationStream = new DeflationStream;
if (perMessageDeflate && (compressOptions != CompressOptions::SHARED_COMPRESSOR)) {
deflationStream = new DeflationStream(compressOptions);
}
}

View File

@ -1,5 +1,5 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Authored by Alex Hultman, 2018-2021.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
@ -18,26 +18,35 @@
#ifndef UWS_WEBSOCKETEXTENSIONS_H
#define UWS_WEBSOCKETEXTENSIONS_H
/* There is a new, huge bug scenario that needs to be fixed:
* pub/sub does not support being in DEDICATED_COMPRESSOR-mode while having
* some clients downgraded to SHARED_COMPRESSOR - we cannot allow the client to
* demand a downgrade to SHARED_COMPRESSOR (yet) until we fix that scenario in pub/sub */
// #define UWS_ALLOW_SHARED_AND_DEDICATED_COMPRESSOR_MIX
/* We forbid negotiating 8 windowBits since Zlib has a bug with this */
// #define UWS_ALLOW_8_WINDOW_BITS
#include <climits>
#include <cctype>
#include <string>
#include <string_view>
#include <tuple>
namespace uWS {
enum Options : unsigned int {
NO_OPTIONS = 0,
PERMESSAGE_DEFLATE = 1,
SERVER_NO_CONTEXT_TAKEOVER = 2, // remove this
CLIENT_NO_CONTEXT_TAKEOVER = 4, // remove this
NO_DELAY = 8,
SLIDING_DEFLATE_WINDOW = 16
};
enum ExtensionTokens {
/* Standard permessage-deflate tokens */
TOK_PERMESSAGE_DEFLATE = 1838,
TOK_SERVER_NO_CONTEXT_TAKEOVER = 2807,
TOK_CLIENT_NO_CONTEXT_TAKEOVER = 2783,
TOK_SERVER_MAX_WINDOW_BITS = 2372,
TOK_CLIENT_MAX_WINDOW_BITS = 2348
TOK_CLIENT_MAX_WINDOW_BITS = 2348,
/* Non-standard alias for Safari */
TOK_X_WEBKIT_DEFLATE_FRAME = 2149,
TOK_NO_CONTEXT_TAKEOVER = 2049,
TOK_MAX_WINDOW_BITS = 1614
};
struct ExtensionsParser {
@ -45,12 +54,18 @@ private:
int *lastInteger = nullptr;
public:
/* Standard */
bool perMessageDeflate = false;
bool serverNoContextTakeover = false;
bool clientNoContextTakeover = false;
int serverMaxWindowBits = 0;
int clientMaxWindowBits = 0;
/* Non-standard Safari */
bool xWebKitDeflateFrame = false;
bool noContextTakeover = false;
int maxWindowBits = 0;
int getToken(const char *&in, const char *stop) {
while (in != stop && !isalnum(*in)) {
in++;
@ -78,12 +93,28 @@ public:
ExtensionsParser(const char *data, size_t length) {
const char *stop = data + length;
int token = 1;
for (; token && token != TOK_PERMESSAGE_DEFLATE; token = getToken(data, stop));
/* Ignore anything before permessage-deflate or x-webkit-deflate-frame */
for (; token && token != TOK_PERMESSAGE_DEFLATE && token != TOK_X_WEBKIT_DEFLATE_FRAME; token = getToken(data, stop));
/* What protocol are we going to use? */
perMessageDeflate = (token == TOK_PERMESSAGE_DEFLATE);
xWebKitDeflateFrame = (token == TOK_X_WEBKIT_DEFLATE_FRAME);
while ((token = getToken(data, stop))) {
switch (token) {
case TOK_X_WEBKIT_DEFLATE_FRAME:
/* Duplicates not allowed/supported */
return;
case TOK_NO_CONTEXT_TAKEOVER:
noContextTakeover = true;
break;
case TOK_MAX_WINDOW_BITS:
maxWindowBits = 1;
lastInteger = &maxWindowBits;
break;
case TOK_PERMESSAGE_DEFLATE:
/* Duplicates not allowed/supported */
return;
case TOK_SERVER_NO_CONTEXT_TAKEOVER:
serverNoContextTakeover = true;
@ -109,60 +140,116 @@ public:
}
};
template <bool isServer>
struct ExtensionsNegotiator {
protected:
int options;
/* Takes what we (the server) wants, returns what we got */
static inline std::tuple<bool, int, int, std::string_view> negotiateCompression(bool wantCompression, int wantedCompressionWindow, int wantedInflationWindow, std::string_view offer) {
public:
ExtensionsNegotiator(int wantedOptions) {
options = wantedOptions;
/* If we don't want compression then we are done here */
if (!wantCompression) {
return {false, 0, 0, ""};
}
std::string generateOffer() {
std::string extensionsOffer;
if (options & Options::PERMESSAGE_DEFLATE) {
extensionsOffer += "permessage-deflate";
ExtensionsParser ep(offer.data(), offer.length());
if (options & Options::CLIENT_NO_CONTEXT_TAKEOVER) {
extensionsOffer += "; client_no_context_takeover";
static thread_local std::string response;
response = "";
int compressionWindow = wantedCompressionWindow;
int inflationWindow = wantedInflationWindow;
bool compression = false;
if (ep.xWebKitDeflateFrame) {
/* We now have compression */
compression = true;
response = "x-webkit-deflate-frame";
/* If the other peer has DEMANDED us no sliding window,
* we cannot compress with anything other than shared compressor */
if (ep.noContextTakeover) {
/* We must fail here right now (fix pub/sub) */
#ifndef UWS_ALLOW_SHARED_AND_DEDICATED_COMPRESSOR_MIX
if (wantedCompressionWindow != 0) {
return {false, 0, 0, ""};
}
#endif
/* It is questionable sending this improves anything */
/*if (options & Options::SERVER_NO_CONTEXT_TAKEOVER) {
extensionsOffer += "; server_no_context_takeover";
}*/
compressionWindow = 0;
}
return extensionsOffer;
}
/* If the other peer has DEMANDED us to use a limited sliding window,
* we have to limit out compression sliding window */
if (ep.maxWindowBits && ep.maxWindowBits < compressionWindow) {
compressionWindow = ep.maxWindowBits;
#ifndef UWS_ALLOW_8_WINDOW_BITS
/* We cannot really deny this, so we have to disable compression in this case */
if (compressionWindow == 8) {
return {false, 0, 0, ""};
}
#endif
}
void readOffer(std::string_view offer) {
if (isServer) {
ExtensionsParser extensionsParser(offer.data(), offer.length());
if ((options & PERMESSAGE_DEFLATE) && extensionsParser.perMessageDeflate) {
if (extensionsParser.clientNoContextTakeover || (options & CLIENT_NO_CONTEXT_TAKEOVER)) {
options |= CLIENT_NO_CONTEXT_TAKEOVER;
}
/* We leave this option for us to read even if the client did not send it */
if (extensionsParser.serverNoContextTakeover) {
options |= SERVER_NO_CONTEXT_TAKEOVER;
}/* else {
options &= ~SERVER_NO_CONTEXT_TAKEOVER;
}*/
/* We decide our own inflation sliding window (and their compression sliding window) */
if (wantedInflationWindow < 15) {
if (!wantedInflationWindow) {
response += "; no_context_takeover";
} else {
options &= ~PERMESSAGE_DEFLATE;
response += "; max_window_bits=" + std::to_string(wantedInflationWindow);
}
}
} else if (ep.perMessageDeflate) {
/* We now have compression */
compression = true;
response = "permessage-deflate";
if (ep.clientNoContextTakeover) {
inflationWindow = 0;
} else if (ep.clientMaxWindowBits && ep.clientMaxWindowBits != 1) {
inflationWindow = std::min<int>(ep.clientMaxWindowBits, inflationWindow);
}
/* Whatever we have now, write */
if (inflationWindow < 15) {
if (!inflationWindow || !ep.clientMaxWindowBits) {
response += "; client_no_context_takeover";
inflationWindow = 0;
} else {
response += "; client_max_window_bits=" + std::to_string(inflationWindow);
}
}
/* This block basically lets the client lower it */
if (ep.serverNoContextTakeover) {
/* This is an important (temporary) fix since we haven't allowed
* these two modes to mix, and pub/sub will not handle this case (yet) */
#ifdef UWS_ALLOW_SHARED_AND_DEDICATED_COMPRESSOR_MIX
compressionWindow = 0;
#endif
} else if (ep.serverMaxWindowBits) {
compressionWindow = std::min<int>(ep.serverMaxWindowBits, compressionWindow);
#ifndef UWS_ALLOW_8_WINDOW_BITS
/* Zlib cannot do windowBits=8, memLevel=1 so we raise it up to 9 minimum */
if (compressionWindow == 8) {
compressionWindow = 9;
}
#endif
}
/* Whatever we have now, write */
if (compressionWindow < 15) {
if (!compressionWindow) {
response += "; server_no_context_takeover";
} else {
response += "; server_max_window_bits=" + std::to_string(compressionWindow);
}
} else {
// todo!
}
}
int getNegotiatedOptions() {
return options;
/* A final sanity check (this check does not actually catch too high values!) */
if ((compressionWindow && compressionWindow < 8) || compressionWindow > 15 || (inflationWindow && inflationWindow < 8) || inflationWindow > 15) {
return {false, 0, 0, ""};
}
};
return {compression, compressionWindow, inflationWindow, response};
}
}

View File

@ -1,5 +1,5 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Authored by Alex Hultman, 2018-2020.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
@ -34,57 +34,68 @@ struct WebSocketHandshake {
template <typename T>
struct static_for<0, T> {
void operator()(uint32_t *a, uint32_t *hash) {}
void operator()(uint32_t */*a*/, uint32_t */*hash*/) {}
};
template <int state>
struct Sha1Loop {
static inline uint32_t rol(uint32_t value, size_t bits) {return (value << bits) | (value >> (32 - bits));}
static inline uint32_t blk(uint32_t b[16], size_t i) {
return rol(b[(i + 13) & 15] ^ b[(i + 8) & 15] ^ b[(i + 2) & 15] ^ b[i], 1);
}
static inline uint32_t rol(uint32_t value, size_t bits) {return (value << bits) | (value >> (32 - bits));}
static inline uint32_t blk(uint32_t b[16], size_t i) {
return rol(b[(i + 13) & 15] ^ b[(i + 8) & 15] ^ b[(i + 2) & 15] ^ b[i], 1);
}
struct Sha1Loop1 {
template <int i>
static inline void f(uint32_t *a, uint32_t *b) {
switch (state) {
case 1:
a[i % 5] += ((a[(3 + i) % 5] & (a[(2 + i) % 5] ^ a[(1 + i) % 5])) ^ a[(1 + i) % 5]) + b[i] + 0x5a827999 + rol(a[(4 + i) % 5], 5);
a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30);
break;
case 2:
b[i] = blk(b, i);
a[(1 + i) % 5] += ((a[(4 + i) % 5] & (a[(3 + i) % 5] ^ a[(2 + i) % 5])) ^ a[(2 + i) % 5]) + b[i] + 0x5a827999 + rol(a[(5 + i) % 5], 5);
a[(4 + i) % 5] = rol(a[(4 + i) % 5], 30);
break;
case 3:
b[(i + 4) % 16] = blk(b, (i + 4) % 16);
a[i % 5] += (a[(3 + i) % 5] ^ a[(2 + i) % 5] ^ a[(1 + i) % 5]) + b[(i + 4) % 16] + 0x6ed9eba1 + rol(a[(4 + i) % 5], 5);
a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30);
break;
case 4:
b[(i + 8) % 16] = blk(b, (i + 8) % 16);
a[i % 5] += (((a[(3 + i) % 5] | a[(2 + i) % 5]) & a[(1 + i) % 5]) | (a[(3 + i) % 5] & a[(2 + i) % 5])) + b[(i + 8) % 16] + 0x8f1bbcdc + rol(a[(4 + i) % 5], 5);
a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30);
break;
case 5:
b[(i + 12) % 16] = blk(b, (i + 12) % 16);
a[i % 5] += (a[(3 + i) % 5] ^ a[(2 + i) % 5] ^ a[(1 + i) % 5]) + b[(i + 12) % 16] + 0xca62c1d6 + rol(a[(4 + i) % 5], 5);
a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30);
break;
case 6:
b[i] += a[4 - i];
}
a[i % 5] += ((a[(3 + i) % 5] & (a[(2 + i) % 5] ^ a[(1 + i) % 5])) ^ a[(1 + i) % 5]) + b[i] + 0x5a827999 + rol(a[(4 + i) % 5], 5);
a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30);
}
};
struct Sha1Loop2 {
template <int i>
static inline void f(uint32_t *a, uint32_t *b) {
b[i] = blk(b, i);
a[(1 + i) % 5] += ((a[(4 + i) % 5] & (a[(3 + i) % 5] ^ a[(2 + i) % 5])) ^ a[(2 + i) % 5]) + b[i] + 0x5a827999 + rol(a[(5 + i) % 5], 5);
a[(4 + i) % 5] = rol(a[(4 + i) % 5], 30);
}
};
struct Sha1Loop3 {
template <int i>
static inline void f(uint32_t *a, uint32_t *b) {
b[(i + 4) % 16] = blk(b, (i + 4) % 16);
a[i % 5] += (a[(3 + i) % 5] ^ a[(2 + i) % 5] ^ a[(1 + i) % 5]) + b[(i + 4) % 16] + 0x6ed9eba1 + rol(a[(4 + i) % 5], 5);
a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30);
}
};
struct Sha1Loop4 {
template <int i>
static inline void f(uint32_t *a, uint32_t *b) {
b[(i + 8) % 16] = blk(b, (i + 8) % 16);
a[i % 5] += (((a[(3 + i) % 5] | a[(2 + i) % 5]) & a[(1 + i) % 5]) | (a[(3 + i) % 5] & a[(2 + i) % 5])) + b[(i + 8) % 16] + 0x8f1bbcdc + rol(a[(4 + i) % 5], 5);
a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30);
}
};
struct Sha1Loop5 {
template <int i>
static inline void f(uint32_t *a, uint32_t *b) {
b[(i + 12) % 16] = blk(b, (i + 12) % 16);
a[i % 5] += (a[(3 + i) % 5] ^ a[(2 + i) % 5] ^ a[(1 + i) % 5]) + b[(i + 12) % 16] + 0xca62c1d6 + rol(a[(4 + i) % 5], 5);
a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30);
}
};
struct Sha1Loop6 {
template <int i>
static inline void f(uint32_t *a, uint32_t *b) {
b[i] += a[4 - i];
}
};
static inline void sha1(uint32_t hash[5], uint32_t b[16]) {
uint32_t a[5] = {hash[4], hash[3], hash[2], hash[1], hash[0]};
static_for<16, Sha1Loop<1>>()(a, b);
static_for<4, Sha1Loop<2>>()(a, b);
static_for<20, Sha1Loop<3>>()(a, b);
static_for<20, Sha1Loop<4>>()(a, b);
static_for<20, Sha1Loop<5>>()(a, b);
static_for<5, Sha1Loop<6>>()(a, hash);
static_for<16, Sha1Loop1>()(a, b);
static_for<4, Sha1Loop2>()(a, b);
static_for<20, Sha1Loop3>()(a, b);
static_for<20, Sha1Loop4>()(a, b);
static_for<20, Sha1Loop5>()(a, b);
static_for<5, Sha1Loop6>()(a, hash);
}
static inline void base64(unsigned char *src, char *dst) {
@ -112,7 +123,7 @@ public:
};
for (int i = 0; i < 6; i++) {
b_input[i] = (input[4 * i + 3] & 0xff) | (input[4 * i + 2] & 0xff) << 8 | (input[4 * i + 1] & 0xff) << 16 | (input[4 * i + 0] & 0xff) << 24;
b_input[i] = (uint32_t) ((input[4 * i + 3] & 0xff) | (input[4 * i + 2] & 0xff) << 8 | (input[4 * i + 1] & 0xff) << 16 | (input[4 * i + 0] & 0xff) << 24);
}
sha1(b_output, b_input);
uint32_t last_b[16] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 480};

View File

@ -1,5 +1,5 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Authored by Alex Hultman, 2018-2020.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
@ -21,9 +21,17 @@
#include <cstdint>
#include <cstring>
#include <cstdlib>
#include <string_view>
namespace uWS {
/* We should not overcomplicate these */
const std::string_view ERR_TOO_BIG_MESSAGE("Received too big message");
const std::string_view ERR_WEBSOCKET_TIMEOUT("WebSocket timed out from inactivity");
const std::string_view ERR_INVALID_TEXT("Received invalid UTF-8");
const std::string_view ERR_TOO_BIG_MESSAGE_INFLATION("Received too big message, or other inflation error");
const std::string_view ERR_INVALID_CLOSE_PAYLOAD("Received invalid close payload");
enum OpCode : unsigned char {
TEXT = 1,
BINARY = 2,
@ -49,7 +57,7 @@ public:
struct State {
unsigned int wantsHead : 1;
unsigned int spillLength : 4;
int opStack : 2; // -1, 0, 1
signed int opStack : 2; // -1, 0, 1
unsigned int lastFin : 1;
// 15 bytes
@ -151,20 +159,23 @@ struct CloseFrame {
};
static inline CloseFrame parseClosePayload(char *src, size_t length) {
CloseFrame cf = {};
/* If we get no code or message, default to reporting 1005 no status code present */
CloseFrame cf = {1005, nullptr, 0};
if (length >= 2) {
memcpy(&cf.code, src, 2);
cf = {cond_byte_swap<uint16_t>(cf.code), src + 2, length - 2};
if (cf.code < 1000 || cf.code > 4999 || (cf.code > 1011 && cf.code < 4000) ||
(cf.code >= 1004 && cf.code <= 1006) || !isValidUtf8((unsigned char *) cf.message, cf.length)) {
return {};
/* Even though we got a WebSocket close frame, it in itself is abnormal */
return {1006, nullptr, 0};
}
}
return cf;
}
static inline size_t formatClosePayload(char *dst, uint16_t code, const char *message, size_t length) {
if (code) {
/* We could have more strict checks here, but never append code 0 or 1005 or 1006 */
if (code && code != 1005 && code != 1006) {
code = cond_byte_swap<uint16_t>(code);
memcpy(dst, &code, 2);
/* It is invalid to pass nullptr to memcpy, even though length is 0 */
@ -219,7 +230,7 @@ static inline size_t formatMessage(char *dst, const char *src, size_t length, Op
char mask[4];
if (!isServer) {
dst[1] |= 0x80;
uint32_t random = rand();
uint32_t random = (uint32_t) rand();
memcpy(mask, &random, 4);
memcpy(dst + headerLength, &random, 4);
headerLength += 4;
@ -307,7 +318,7 @@ protected:
wState->state.lastFin = isFin(src);
if (Impl::refusePayloadLength(payLength, wState, user)) {
Impl::forceClose(wState, user);
Impl::forceClose(wState, user, ERR_TOO_BIG_MESSAGE);
return true;
}
@ -351,9 +362,9 @@ protected:
static inline bool consumeContinuation(char *&src, unsigned int &length, WebSocketState<isServer> *wState, void *user) {
if (wState->remainingBytes <= length) {
if (isServer) {
int n = wState->remainingBytes >> 2;
unsigned int n = wState->remainingBytes >> 2;
unmaskInplace(src, src + n * 4, wState->mask);
for (int i = 0, s = wState->remainingBytes % 4; i < s; i++) {
for (unsigned int i = 0, s = wState->remainingBytes % 4; i < s; i++) {
src[n * 4 + i] ^= wState->mask[i];
}
}

View File

@ -1,23 +0,0 @@
Boost Software License - Version 1.0 - August 17th, 2003
Permission is hereby granted, free of charge, to any person or organization
obtaining a copy of the software and accompanying documentation covered by
this license (the "Software") to use, reproduce, display, distribute,
execute, and transmit the Software, and to prepare derivative works of the
Software, and to permit third-parties to whom the Software is furnished to
do so, all subject to the following:
The copyright notices in the Software and this entire statement, including
the above license grant, this restriction and the following disclaimer,
must be included in all copies of the Software, in whole or in part, and
all derivative works of the Software, unless such copies or derivative
works are solely in the form of machine-executable object code generated by
a source language processor.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.

File diff suppressed because it is too large Load Diff

308
src/Vendor/uwebsockets/src/bsd.c vendored 100644
View File

@ -0,0 +1,308 @@
/*
* Authored by Alex Hultman, 2018-2021.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Todo: this file should lie in networking/bsd.c */
#include "libusockets.h"
#include "internal/internal.h"
#include <stdio.h>
#ifndef _WIN32
//#define _GNU_SOURCE
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <netdb.h>
#include <string.h>
#include <unistd.h>
#include <fcntl.h>
#include <errno.h>
#endif
LIBUS_SOCKET_DESCRIPTOR apple_no_sigpipe(LIBUS_SOCKET_DESCRIPTOR fd) {
#ifdef __APPLE__
if (fd != LIBUS_SOCKET_ERROR) {
int no_sigpipe = 1;
setsockopt(fd, SOL_SOCKET, SO_NOSIGPIPE, &no_sigpipe, sizeof(int));
}
#endif
return fd;
}
LIBUS_SOCKET_DESCRIPTOR bsd_set_nonblocking(LIBUS_SOCKET_DESCRIPTOR fd) {
#ifdef _WIN32
/* Libuv will set windows sockets as non-blocking */
#else
fcntl(fd, F_SETFL, fcntl(fd, F_GETFL, 0) | O_NONBLOCK);
#endif
return fd;
}
void bsd_socket_nodelay(LIBUS_SOCKET_DESCRIPTOR fd, int enabled) {
setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, (void *) &enabled, sizeof(enabled));
}
void bsd_socket_flush(LIBUS_SOCKET_DESCRIPTOR fd) {
// Linux TCP_CORK has the same underlying corking mechanism as with MSG_MORE
#ifdef TCP_CORK
int enabled = 0;
setsockopt(fd, IPPROTO_TCP, TCP_CORK, &enabled, sizeof(int));
#endif
}
LIBUS_SOCKET_DESCRIPTOR bsd_create_socket(int domain, int type, int protocol) {
// returns INVALID_SOCKET on error
int flags = 0;
#if defined(SOCK_CLOEXEC) && defined(SOCK_NONBLOCK)
flags = SOCK_CLOEXEC | SOCK_NONBLOCK;
#endif
LIBUS_SOCKET_DESCRIPTOR created_fd = socket(domain, type | flags, protocol);
return bsd_set_nonblocking(apple_no_sigpipe(created_fd));
}
void bsd_close_socket(LIBUS_SOCKET_DESCRIPTOR fd) {
#ifdef _WIN32
closesocket(fd);
#else
close(fd);
#endif
}
void bsd_shutdown_socket(LIBUS_SOCKET_DESCRIPTOR fd) {
#ifdef _WIN32
shutdown(fd, SD_SEND);
#else
shutdown(fd, SHUT_WR);
#endif
}
void bsd_shutdown_socket_read(LIBUS_SOCKET_DESCRIPTOR fd) {
#ifdef _WIN32
shutdown(fd, SD_RECEIVE);
#else
shutdown(fd, SHUT_RD);
#endif
}
void internal_finalize_bsd_addr(struct bsd_addr_t *addr) {
// parse, so to speak, the address
if (addr->mem.ss_family == AF_INET6) {
addr->ip = (char *) &((struct sockaddr_in6 *) addr)->sin6_addr;
addr->ip_length = sizeof(struct in6_addr);
addr->port = ntohs(((struct sockaddr_in6 *) addr)->sin6_port);
} else if (addr->mem.ss_family == AF_INET) {
addr->ip = (char *) &((struct sockaddr_in *) addr)->sin_addr;
addr->ip_length = sizeof(struct in_addr);
addr->port = ntohs(((struct sockaddr_in *) addr)->sin_port);
} else {
addr->ip_length = 0;
addr->port = -1;
}
}
int bsd_local_addr(LIBUS_SOCKET_DESCRIPTOR fd, struct bsd_addr_t *addr) {
addr->len = sizeof(addr->mem);
if (getsockname(fd, (struct sockaddr *) &addr->mem, &addr->len)) {
return -1;
}
internal_finalize_bsd_addr(addr);
return 0;
}
int bsd_remote_addr(LIBUS_SOCKET_DESCRIPTOR fd, struct bsd_addr_t *addr) {
addr->len = sizeof(addr->mem);
if (getpeername(fd, (struct sockaddr *) &addr->mem, &addr->len)) {
return -1;
}
internal_finalize_bsd_addr(addr);
return 0;
}
char *bsd_addr_get_ip(struct bsd_addr_t *addr) {
return addr->ip;
}
int bsd_addr_get_ip_length(struct bsd_addr_t *addr) {
return addr->ip_length;
}
int bsd_addr_get_port(struct bsd_addr_t *addr) {
return addr->port;
}
// called by dispatch_ready_poll
LIBUS_SOCKET_DESCRIPTOR bsd_accept_socket(LIBUS_SOCKET_DESCRIPTOR fd, struct bsd_addr_t *addr) {
LIBUS_SOCKET_DESCRIPTOR accepted_fd;
addr->len = sizeof(addr->mem);
#if defined(SOCK_CLOEXEC) && defined(SOCK_NONBLOCK)
// Linux, FreeBSD
accepted_fd = accept4(fd, (struct sockaddr *) addr, &addr->len, SOCK_CLOEXEC | SOCK_NONBLOCK);
#else
// Windows, OS X
accepted_fd = accept(fd, (struct sockaddr *) addr, &addr->len);
#endif
/* We cannot rely on addr since it is not initialized if failed */
if (accepted_fd == LIBUS_SOCKET_ERROR) {
return LIBUS_SOCKET_ERROR;
}
internal_finalize_bsd_addr(addr);
return bsd_set_nonblocking(apple_no_sigpipe(accepted_fd));
}
int bsd_recv(LIBUS_SOCKET_DESCRIPTOR fd, void *buf, int length, int flags) {
return recv(fd, buf, length, flags);
}
int bsd_send(LIBUS_SOCKET_DESCRIPTOR fd, const char *buf, int length, int msg_more) {
// MSG_MORE (Linux), MSG_PARTIAL (Windows), TCP_NOPUSH (BSD)
#ifndef MSG_NOSIGNAL
#define MSG_NOSIGNAL 0
#endif
#ifdef MSG_MORE
// for Linux we do not want signals
return send(fd, buf, length, (msg_more * MSG_MORE) | MSG_NOSIGNAL);
#else
// use TCP_NOPUSH
return send(fd, buf, length, MSG_NOSIGNAL);
#endif
}
int bsd_would_block() {
#ifdef _WIN32
return WSAGetLastError() == WSAEWOULDBLOCK;
#else
return errno == EWOULDBLOCK;// || errno == EAGAIN;
#endif
}
// return LIBUS_SOCKET_ERROR or the fd that represents listen socket
// listen both on ipv6 and ipv4
LIBUS_SOCKET_DESCRIPTOR bsd_create_listen_socket(const char *host, int port, int options) {
struct addrinfo hints, *result;
memset(&hints, 0, sizeof(struct addrinfo));
hints.ai_flags = AI_PASSIVE;
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
char port_string[16];
snprintf(port_string, 16, "%d", port);
if (getaddrinfo(host, port_string, &hints, &result)) {
return LIBUS_SOCKET_ERROR;
}
LIBUS_SOCKET_DESCRIPTOR listenFd = LIBUS_SOCKET_ERROR;
struct addrinfo *listenAddr;
for (struct addrinfo *a = result; a && listenFd == LIBUS_SOCKET_ERROR; a = a->ai_next) {
if (a->ai_family == AF_INET6) {
listenFd = bsd_create_socket(a->ai_family, a->ai_socktype, a->ai_protocol);
listenAddr = a;
}
}
for (struct addrinfo *a = result; a && listenFd == LIBUS_SOCKET_ERROR; a = a->ai_next) {
if (a->ai_family == AF_INET) {
listenFd = bsd_create_socket(a->ai_family, a->ai_socktype, a->ai_protocol);
listenAddr = a;
}
}
if (listenFd == LIBUS_SOCKET_ERROR) {
freeaddrinfo(result);
return LIBUS_SOCKET_ERROR;
}
if (port != 0) {
/* Otherwise, always enable SO_REUSEPORT and SO_REUSEADDR _unless_ options specify otherwise */
#if /*defined(__linux) &&*/ defined(SO_REUSEPORT)
if (!(options & LIBUS_LISTEN_EXCLUSIVE_PORT)) {
int optval = 1;
setsockopt(listenFd, SOL_SOCKET, SO_REUSEPORT, &optval, sizeof(optval));
}
#endif
int enabled = 1;
setsockopt(listenFd, SOL_SOCKET, SO_REUSEADDR, (SETSOCKOPT_PTR_TYPE) &enabled, sizeof(enabled));
}
#ifdef IPV6_V6ONLY
int disabled = 0;
setsockopt(listenFd, IPPROTO_IPV6, IPV6_V6ONLY, (SETSOCKOPT_PTR_TYPE) &disabled, sizeof(disabled));
#endif
if (bind(listenFd, listenAddr->ai_addr, (socklen_t) listenAddr->ai_addrlen) || listen(listenFd, 512)) {
bsd_close_socket(listenFd);
freeaddrinfo(result);
return LIBUS_SOCKET_ERROR;
}
freeaddrinfo(result);
return listenFd;
}
LIBUS_SOCKET_DESCRIPTOR bsd_create_connect_socket(const char *host, int port, const char *source_host, int options) {
struct addrinfo hints, *result;
memset(&hints, 0, sizeof(struct addrinfo));
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
char port_string[16];
snprintf(port_string, 16, "%d", port);
if (getaddrinfo(host, port_string, &hints, &result) != 0) {
return LIBUS_SOCKET_ERROR;
}
LIBUS_SOCKET_DESCRIPTOR fd = bsd_create_socket(result->ai_family, result->ai_socktype, result->ai_protocol);
if (fd == LIBUS_SOCKET_ERROR) {
freeaddrinfo(result);
return LIBUS_SOCKET_ERROR;
}
if (source_host) {
struct addrinfo *interface_result;
if (!getaddrinfo(source_host, NULL, NULL, &interface_result)) {
int ret = bind(fd, interface_result->ai_addr, (socklen_t) interface_result->ai_addrlen);
freeaddrinfo(interface_result);
if (ret == LIBUS_SOCKET_ERROR) {
return LIBUS_SOCKET_ERROR;
}
}
}
connect(fd, result->ai_addr, (socklen_t) result->ai_addrlen);
freeaddrinfo(result);
return fd;
}

View File

@ -26,6 +26,10 @@ int default_ignore_data_handler(struct us_socket_t *s) {
/* Shared with SSL */
unsigned short us_socket_context_timestamp(int ssl, struct us_socket_context_t *context) {
return context->timestamp;
}
void us_listen_socket_close(int ssl, struct us_listen_socket_t *ls) {
/* us_listen_socket_t extends us_socket_t so we close in similar ways */
if (!us_socket_is_closed(0, &ls->s)) {
@ -80,60 +84,92 @@ struct us_loop_t *us_socket_context_loop(int ssl, struct us_socket_context_t *co
return context->loop;
}
/* Returns the deep copy, to be freed */
const char *deep_str_copy(const char *src) {
if (!src) {
return src;
}
size_t len = strlen(src) + 1;
char *dst = malloc(len);
memcpy(dst, src, len);
return dst;
}
/* Not shared with SSL */
struct us_socket_context_t *us_create_socket_context(int ssl, struct us_loop_t *loop, int context_ext_size, struct us_socket_context_options_t options) {
/* For ease of use we copy all passed strings here */
options.ca_file_name = deep_str_copy(options.ca_file_name);
options.cert_file_name = deep_str_copy(options.cert_file_name);
options.dh_params_file_name = deep_str_copy(options.dh_params_file_name);
options.key_file_name = deep_str_copy(options.key_file_name);
options.passphrase = deep_str_copy(options.passphrase);
/* Add SNI context */
void us_socket_context_add_server_name(int ssl, struct us_socket_context_t *context, const char *hostname_pattern, struct us_socket_context_options_t options) {
#ifndef LIBUS_NO_SSL
if (ssl) {
us_internal_ssl_socket_context_add_server_name((struct us_internal_ssl_socket_context_t *) context, hostname_pattern, options);
}
#endif
}
/* Remove SNI context */
void us_socket_context_remove_server_name(int ssl, struct us_socket_context_t *context, const char *hostname_pattern) {
#ifndef LIBUS_NO_SSL
if (ssl) {
us_internal_ssl_socket_context_remove_server_name((struct us_internal_ssl_socket_context_t *) context, hostname_pattern);
}
#endif
}
/* I don't like this one - maybe rename it to on_missing_server_name? */
/* Called when SNI matching fails - not if a match could be made.
* You may modify the context by adding/removing names in this callback.
* If the correct name is added immediately in the callback, it will be used */
void us_socket_context_on_server_name(int ssl, struct us_socket_context_t *context, void (*cb)(struct us_socket_context_t *, const char *hostname)) {
#ifndef LIBUS_NO_SSL
if (ssl) {
us_internal_ssl_socket_context_on_server_name((struct us_internal_ssl_socket_context_t *) context, (void (*)(struct us_internal_ssl_socket_context_t *, const char *hostname)) cb);
}
#endif
}
/* Todo: get native context from SNI pattern */
void *us_socket_context_get_native_handle(int ssl, struct us_socket_context_t *context) {
#ifndef LIBUS_NO_SSL
if (ssl) {
return us_internal_ssl_socket_context_get_native_handle((struct us_internal_ssl_socket_context_t *) context);
}
#endif
/* There is no native handle for a non-SSL socket context */
return 0;
}
/* Options is currently only applicable for SSL - this will change with time (prefer_low_memory is one example) */
struct us_socket_context_t *us_create_socket_context(int ssl, struct us_loop_t *loop, int context_ext_size, struct us_socket_context_options_t options) {
#ifndef LIBUS_NO_SSL
if (ssl) {
/* This function will call us, again, with SSL = false and a bigger ext_size */
return (struct us_socket_context_t *) us_internal_create_ssl_socket_context(loop, context_ext_size, options);
}
#endif
/* This path is taken once either way - always BEFORE whatever SSL may do LATER.
* context_ext_size will however be modified larger in case of SSL, to hold SSL extensions */
struct us_socket_context_t *context = malloc(sizeof(struct us_socket_context_t) + context_ext_size);
context->loop = loop;
context->head = 0;
context->iterator = 0;
context->next = 0;
context->ignore_data = default_ignore_data_handler;
context->options = options;
/* Begin at 0 */
context->timestamp = 0;
us_internal_loop_link(loop, context);
/* If we are called from within SSL code, SSL code will make further changes to us */
return context;
}
void us_socket_context_free(int ssl, struct us_socket_context_t *context) {
/* We also simply free every copied string here */
free((void *) context->options.ca_file_name);
free((void *) context->options.cert_file_name);
free((void *) context->options.dh_params_file_name);
free((void *) context->options.key_file_name);
free((void *) context->options.passphrase);
#ifndef LIBUS_NO_SSL
if (ssl) {
/* This function will call us again with SSL=false */
us_internal_ssl_socket_context_free((struct us_internal_ssl_socket_context_t *) context);
return;
}
#endif
/* This path is taken once either way - always AFTER whatever SSL may do BEFORE.
* This is the opposite order compared to when creating the context - SSL code is cleaning up before non-SSL */
us_internal_loop_unlink(context->loop, context);
free(context);
}
@ -167,14 +203,14 @@ struct us_listen_socket_t *us_socket_context_listen(int ssl, struct us_socket_co
return ls;
}
struct us_socket_t *us_socket_context_connect(int ssl, struct us_socket_context_t *context, const char *host, int port, int options, int socket_ext_size) {
struct us_socket_t *us_socket_context_connect(int ssl, struct us_socket_context_t *context, const char *host, int port, const char *source_host, int options, int socket_ext_size) {
#ifndef LIBUS_NO_SSL
if (ssl) {
return (struct us_socket_t *) us_internal_ssl_socket_context_connect((struct us_internal_ssl_socket_context_t *) context, host, port, options, socket_ext_size);
return (struct us_socket_t *) us_internal_ssl_socket_context_connect((struct us_internal_ssl_socket_context_t *) context, host, port, source_host, options, socket_ext_size);
}
#endif
LIBUS_SOCKET_DESCRIPTOR connect_socket_fd = bsd_create_connect_socket(host, port, options);
LIBUS_SOCKET_DESCRIPTOR connect_socket_fd = bsd_create_connect_socket(host, port, source_host, options);
if (connect_socket_fd == LIBUS_SOCKET_ERROR) {
return 0;
}
@ -239,10 +275,10 @@ void us_socket_context_on_open(int ssl, struct us_socket_context_t *context, str
context->on_open = on_open;
}
void us_socket_context_on_close(int ssl, struct us_socket_context_t *context, struct us_socket_t *(*on_close)(struct us_socket_t *s)) {
void us_socket_context_on_close(int ssl, struct us_socket_context_t *context, struct us_socket_t *(*on_close)(struct us_socket_t *s, int code, void *reason)) {
#ifndef LIBUS_NO_SSL
if (ssl) {
us_internal_ssl_socket_context_on_close((struct us_internal_ssl_socket_context_t *) context, (struct us_internal_ssl_socket_t * (*)(struct us_internal_ssl_socket_t *)) on_close);
us_internal_ssl_socket_context_on_close((struct us_internal_ssl_socket_context_t *) context, (struct us_internal_ssl_socket_t * (*)(struct us_internal_ssl_socket_t *, int code, void *reason)) on_close);
return;
}
#endif
@ -294,6 +330,17 @@ void us_socket_context_on_end(int ssl, struct us_socket_context_t *context, stru
context->on_end = on_end;
}
void us_socket_context_on_connect_error(int ssl, struct us_socket_context_t *context, struct us_socket_t *(*on_connect_error)(struct us_socket_t *s, int code)) {
#ifndef LIBUS_NO_SSL
if (ssl) {
us_internal_ssl_socket_context_on_connect_error((struct us_internal_ssl_socket_context_t *) context, (struct us_internal_ssl_socket_t * (*)(struct us_internal_ssl_socket_t *, int)) on_connect_error);
return;
}
#endif
context->on_connect_error = on_connect_error;
}
void *us_socket_context_ext(int ssl, struct us_socket_context_t *context) {
#ifndef LIBUS_NO_SSL
if (ssl) {

View File

@ -17,8 +17,16 @@
#ifdef LIBUS_USE_OPENSSL
/* These are in sni_tree.cpp */
void *sni_new();
void sni_free(void *sni, void(*cb)(void *));
int sni_add(void *sni, const char *hostname, void *user);
void *sni_remove(void *sni, const char *hostname);
void *sni_find(void *sni, const char *hostname);
#include "libusockets.h"
#include "internal/internal.h"
#include <string.h>
/* This module contains the entire OpenSSL implementation
* of the SSL socket and socket context interfaces. */
@ -59,10 +67,17 @@ struct us_internal_ssl_socket_context_t {
SSL_CTX *ssl_context;
int is_parent;
// här måste det vara!
/* These decorate the base implementation */
struct us_internal_ssl_socket_t *(*on_open)(struct us_internal_ssl_socket_t *, int is_client, char *ip, int ip_length);
struct us_internal_ssl_socket_t *(*on_data)(struct us_internal_ssl_socket_t *, char *data, int length);
struct us_internal_ssl_socket_t *(*on_close)(struct us_internal_ssl_socket_t *);
struct us_internal_ssl_socket_t *(*on_writable)(struct us_internal_ssl_socket_t *);
struct us_internal_ssl_socket_t *(*on_close)(struct us_internal_ssl_socket_t *, int code, void *reason);
/* Called for missing SNI hostnames, if not NULL */
void (*on_server_name)(struct us_internal_ssl_socket_context_t *, const char *hostname);
/* Pointer to sni tree, created when the context is created and freed likewise when freed */
void *sni;
};
// same here, should or shouldn't it contain s?
@ -70,6 +85,7 @@ struct us_internal_ssl_socket_t {
struct us_socket_t s;
SSL *ssl;
int ssl_write_wants_read; // we use this for now
int ssl_read_wants_write;
};
int passphrase_cb(char *buf, int size, int rwflag, void *u) {
@ -103,7 +119,7 @@ int BIO_s_custom_write(BIO *bio, const char *data, int length) {
int written = us_socket_write(0, loop_ssl_data->ssl_socket, data, length, loop_ssl_data->last_write_was_msg_more);
if (!written) {
BIO_set_flags(bio, BIO_get_flags(bio) | BIO_FLAGS_SHOULD_RETRY | BIO_FLAGS_WRITE);
BIO_set_flags(bio, BIO_FLAGS_SHOULD_RETRY | BIO_FLAGS_WRITE);
return -1;
}
@ -118,7 +134,7 @@ int BIO_s_custom_read(BIO *bio, char *dst, int length) {
//printf("BIO_s_custom_read\n");
if (!loop_ssl_data->ssl_read_input_length) {
BIO_set_flags(bio, BIO_get_flags(bio) | BIO_FLAGS_SHOULD_RETRY | BIO_FLAGS_READ);
BIO_set_flags(bio, BIO_FLAGS_SHOULD_RETRY | BIO_FLAGS_READ);
return -1;
}
@ -141,6 +157,7 @@ struct us_internal_ssl_socket_t *ssl_on_open(struct us_internal_ssl_socket_t *s,
s->ssl = SSL_new(context->ssl_context);
s->ssl_write_wants_read = 0;
s->ssl_read_wants_write = 0;
SSL_set_bio(s->ssl, loop_ssl_data->shared_rbio, loop_ssl_data->shared_wbio);
BIO_up_ref(loop_ssl_data->shared_rbio);
@ -155,19 +172,26 @@ struct us_internal_ssl_socket_t *ssl_on_open(struct us_internal_ssl_socket_t *s,
return (struct us_internal_ssl_socket_t *) context->on_open(s, is_client, ip, ip_length);
}
struct us_internal_ssl_socket_t *ssl_on_close(struct us_internal_ssl_socket_t *s) {
/* This one is a helper; it is entirely shared with non-SSL so can be removed */
struct us_internal_ssl_socket_t *us_internal_ssl_socket_close(struct us_internal_ssl_socket_t *s, int code, void *reason) {
return (struct us_internal_ssl_socket_t *) us_socket_close(0, (struct us_socket_t *) s, code, reason);
}
struct us_internal_ssl_socket_t *ssl_on_close(struct us_internal_ssl_socket_t *s, int code, void *reason) {
struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s);
SSL_free(s->ssl);
return context->on_close(s);
return context->on_close(s, code, reason);
}
struct us_internal_ssl_socket_t *ssl_on_end(struct us_internal_ssl_socket_t *s) {
struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s);
// struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s);
// whatever state we are in, a TCP FIN is always an answered shutdown
return us_internal_ssl_socket_close(s);
/* Todo: this should report CLEANLY SHUTDOWN as reason */
return us_internal_ssl_socket_close(s, 0, NULL);
}
// this whole function needs a complete clean-up
@ -192,7 +216,8 @@ struct us_internal_ssl_socket_t *ssl_on_data(struct us_internal_ssl_socket_t *s,
// two phase shutdown is complete here
//printf("Two step SSL shutdown complete\n");
return us_internal_ssl_socket_close(s);
/* Todo: this should also report some kind of clean shutdown */
return us_internal_ssl_socket_close(s, 0, NULL);
} else if (ret < 0) {
int err = SSL_get_error(s->ssl, ret);
@ -226,13 +251,18 @@ struct us_internal_ssl_socket_t *ssl_on_data(struct us_internal_ssl_socket_t *s,
}
// terminate connection here
return us_internal_ssl_socket_close(s);
return us_internal_ssl_socket_close(s, 0, NULL);
} else {
// emit the data we have and exit
if (err == SSL_ERROR_WANT_WRITE) {
// here we need to trigger writable event next ssl_read!
s->ssl_read_wants_write = 1;
}
// assume we emptied the input buffer fully or error here as well!
if (loop_ssl_data->ssl_read_input_length) {
return us_internal_ssl_socket_close(s);
return us_internal_ssl_socket_close(s, 0, NULL);
}
// cannot emit zero length to app
@ -240,6 +270,8 @@ struct us_internal_ssl_socket_t *ssl_on_data(struct us_internal_ssl_socket_t *s,
break;
}
context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s);
s = context->on_data(s, loop_ssl_data->ssl_read_output + LIBUS_RECV_BUFFER_PADDING, read);
if (us_socket_is_closed(0, &s->s)) {
return s;
@ -255,6 +287,8 @@ struct us_internal_ssl_socket_t *ssl_on_data(struct us_internal_ssl_socket_t *s,
// at this point we might be full and need to emit the data to application and start over
if (read == LIBUS_RECV_BUFFER_LENGTH) {
context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s);
// emit data and restart
s = context->on_data(s, loop_ssl_data->ssl_read_output + LIBUS_RECV_BUFFER_PADDING, read);
if (us_socket_is_closed(0, &s->s)) {
@ -287,7 +321,7 @@ struct us_internal_ssl_socket_t *ssl_on_data(struct us_internal_ssl_socket_t *s,
//exit(-2);
// not correct anyways!
s = us_internal_ssl_socket_close(s);
s = us_internal_ssl_socket_close(s, 0, NULL);
//us_
}
@ -295,6 +329,28 @@ struct us_internal_ssl_socket_t *ssl_on_data(struct us_internal_ssl_socket_t *s,
return s;
}
struct us_internal_ssl_socket_t *ssl_on_writable(struct us_internal_ssl_socket_t *s) {
struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s);
// todo: cork here so that we efficiently output both from reading and from writing?
if (s->ssl_read_wants_write) {
s->ssl_read_wants_write = 0;
// make sure to update context before we call (context can change if the user adopts the socket!)
context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s);
// if this one fails to write data, it sets ssl_read_wants_write again
s = (struct us_internal_ssl_socket_t *) context->sc.on_data(&s->s, 0, 0); // cast here!
}
// should this one come before we have read? should it come always? spurious on_writable is okay
s = context->on_writable(s);
return s;
}
/* Lazily inits loop ssl data first time */
void us_internal_init_loop_ssl_data(struct us_loop_t *loop) {
if (!loop->data.ssl_data) {
@ -371,60 +427,77 @@ int ssl_ignore_data(struct us_internal_ssl_socket_t *s) {
}
/* Per-context functions */
struct us_internal_ssl_socket_context_t *us_internal_create_child_ssl_socket_context(struct us_internal_ssl_socket_context_t *context, int context_ext_size) {
struct us_socket_context_options_t options = {0};
void *us_internal_ssl_socket_context_get_native_handle(struct us_internal_ssl_socket_context_t *context) {
return context->ssl_context;
}
struct us_internal_ssl_socket_context_t *us_internal_create_child_ssl_socket_context(struct us_internal_ssl_socket_context_t *context, int context_ext_size) {
/* Create a new non-SSL context */
struct us_socket_context_options_t options = {0};
struct us_internal_ssl_socket_context_t *child_context = (struct us_internal_ssl_socket_context_t *) us_create_socket_context(0, context->sc.loop, sizeof(struct us_internal_ssl_socket_context_t) + context_ext_size, options);
// I think this is the only thing being shared
/* The only thing we share is SSL_CTX */
child_context->ssl_context = context->ssl_context;
child_context->is_parent = 0;
return child_context;
}
struct us_internal_ssl_socket_context_t *us_internal_create_ssl_socket_context(struct us_loop_t *loop, int context_ext_size, struct us_socket_context_options_t options) {
/* Common function for creating a context from options.
* We must NOT free a SSL_CTX with only SSL_CTX_free! Also free any password */
void free_ssl_context(SSL_CTX *ssl_context) {
if (!ssl_context) {
return;
}
us_internal_init_loop_ssl_data(loop);
/* If we have set a password string, free it here */
void *password = SSL_CTX_get_default_passwd_cb_userdata(ssl_context);
/* OpenSSL returns NULL if we have no set password */
free(password);
struct us_socket_context_options_t no_options = {0};
SSL_CTX_free(ssl_context);
}
struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_create_socket_context(0, loop, sizeof(struct us_internal_ssl_socket_context_t) + context_ext_size, no_options);
/* This function should take any options and return SSL_CTX - which has to be free'd with
* our destructor function - free_ssl_context() */
SSL_CTX *create_ssl_context_from_options(struct us_socket_context_options_t options) {
/* Create the context */
SSL_CTX *ssl_context = SSL_CTX_new(TLS_method());
context->ssl_context = SSL_CTX_new(TLS_method());
context->is_parent = 1;
// only parent ssl contexts may need to ignore data
context->sc.ignore_data = (int (*)(struct us_socket_t *)) ssl_ignore_data;
/* Default options we rely on - changing these will break our logic */
SSL_CTX_set_read_ahead(ssl_context, 1);
SSL_CTX_set_mode(ssl_context, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
// options
SSL_CTX_set_read_ahead(context->ssl_context, 1);
SSL_CTX_set_mode(context->ssl_context, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
//SSL_CTX_set_mode(context->ssl_context, SSL_MODE_ENABLE_PARTIAL_WRITE);
/* Anything below TLS 1.2 is disabled */
SSL_CTX_set_min_proto_version(ssl_context, TLS1_2_VERSION);
// this lowers performance a bit in benchmarks
/* The following are helpers. You may easily implement whatever you want by using the native handle directly */
/* Important option for lowering memory usage, but lowers performance slightly */
if (options.ssl_prefer_low_memory_usage) {
SSL_CTX_set_mode(context->ssl_context, SSL_MODE_RELEASE_BUFFERS);
SSL_CTX_set_mode(ssl_context, SSL_MODE_RELEASE_BUFFERS);
}
//SSL_CTX_set_mode(context->ssl_context, SSL_MODE_RELEASE_BUFFERS);
SSL_CTX_set_options(context->ssl_context, SSL_OP_NO_SSLv3);
SSL_CTX_set_options(context->ssl_context, SSL_OP_NO_TLSv1);
// these are going to be extended
if (options.passphrase) {
SSL_CTX_set_default_passwd_cb_userdata(context->ssl_context, (void *) options.passphrase);
SSL_CTX_set_default_passwd_cb(context->ssl_context, passphrase_cb);
/* When freeing the CTX we need to check SSL_CTX_get_default_passwd_cb_userdata and
* free it if set */
SSL_CTX_set_default_passwd_cb_userdata(ssl_context, (void *) strdup(options.passphrase));
SSL_CTX_set_default_passwd_cb(ssl_context, passphrase_cb);
}
/* This one most probably do not need the cert_file_name string to be kept alive */
if (options.cert_file_name) {
if (SSL_CTX_use_certificate_chain_file(context->ssl_context, options.cert_file_name) != 1) {
return 0;
if (SSL_CTX_use_certificate_chain_file(ssl_context, options.cert_file_name) != 1) {
free_ssl_context(ssl_context);
return NULL;
}
}
/* Same as above - we can discard this string afterwards I suppose */
if (options.key_file_name) {
if (SSL_CTX_use_PrivateKey_file(context->ssl_context, options.key_file_name, SSL_FILETYPE_PEM) != 1) {
return 0;
if (SSL_CTX_use_PrivateKey_file(ssl_context, options.key_file_name, SSL_FILETYPE_PEM) != 1) {
free_ssl_context(ssl_context);
return NULL;
}
}
@ -432,13 +505,15 @@ struct us_internal_ssl_socket_context_t *us_internal_create_ssl_socket_context(s
STACK_OF(X509_NAME) *ca_list;
ca_list = SSL_load_client_CA_file(options.ca_file_name);
if(ca_list == NULL) {
return 0;
free_ssl_context(ssl_context);
return NULL;
}
SSL_CTX_set_client_CA_list(context->ssl_context, ca_list);
if (SSL_CTX_load_verify_locations(context->ssl_context, options.ca_file_name, NULL) != 1) {
return 0;
SSL_CTX_set_client_CA_list(ssl_context, ca_list);
if (SSL_CTX_load_verify_locations(ssl_context, options.ca_file_name, NULL) != 1) {
free_ssl_context(ssl_context);
return NULL;
}
SSL_CTX_set_verify(context->ssl_context, SSL_VERIFY_PEER, NULL);
SSL_CTX_set_verify(ssl_context, SSL_VERIFY_PEER, NULL);
}
if (options.dh_params_file_name) {
@ -451,29 +526,155 @@ struct us_internal_ssl_socket_context_t *us_internal_create_ssl_socket_context(s
dh_2048 = PEM_read_DHparams(paramfile, NULL, NULL, NULL);
fclose(paramfile);
} else {
return 0;
free_ssl_context(ssl_context);
return NULL;
}
if (dh_2048 == NULL) {
return 0;
free_ssl_context(ssl_context);
return NULL;
}
if (SSL_CTX_set_tmp_dh(context->ssl_context, dh_2048) != 1) {
return 0;
const long set_tmp_dh = SSL_CTX_set_tmp_dh(ssl_context, dh_2048);
DH_free(dh_2048);
if (set_tmp_dh != 1) {
free_ssl_context(ssl_context);
return NULL;
}
/* OWASP Cipher String 'A+' (https://www.owasp.org/index.php/TLS_Cipher_String_Cheat_Sheet) */
if (SSL_CTX_set_cipher_list(context->ssl_context, "DHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES128-GCM-SHA256:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES128-GCM-SHA256") != 1) {
return 0;
if (SSL_CTX_set_cipher_list(ssl_context, "DHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES128-GCM-SHA256:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES128-GCM-SHA256") != 1) {
free_ssl_context(ssl_context);
return NULL;
}
}
/* This must be free'd with free_ssl_context, not SSL_CTX_free */
return ssl_context;
}
/* Todo: return error on failure? */
void us_internal_ssl_socket_context_add_server_name(struct us_internal_ssl_socket_context_t *context, const char *hostname_pattern, struct us_socket_context_options_t options) {
/* Try and construct an SSL_CTX from options */
SSL_CTX *ssl_context = create_ssl_context_from_options(options);
/* We do not want to hold any nullptr's in our SNI tree */
if (ssl_context) {
if (sni_add(context->sni, hostname_pattern, ssl_context)) {
/* If we already had that name, ignore */
free_ssl_context(ssl_context);
}
}
}
void us_internal_ssl_socket_context_on_server_name(struct us_internal_ssl_socket_context_t *context, void (*cb)(struct us_internal_ssl_socket_context_t *, const char *hostname)) {
context->on_server_name = cb;
}
void us_internal_ssl_socket_context_remove_server_name(struct us_internal_ssl_socket_context_t *context, const char *hostname_pattern) {
/* The same thing must happen for sni_free, that's why we have a callback */
SSL_CTX *sni_node_ssl_context = (SSL_CTX *) sni_remove(context->sni, hostname_pattern);
free_ssl_context(sni_node_ssl_context);
}
/* Returns NULL or SSL_CTX. May call missing server name callback */
SSL_CTX *resolve_context(struct us_internal_ssl_socket_context_t *context, const char *hostname) {
/* Try once first */
void *user = sni_find(context->sni, hostname);
if (!user) {
/* Emit missing hostname then try again */
if (!context->on_server_name) {
/* We have no callback registered, so fail */
return NULL;
}
context->on_server_name(context, hostname);
/* Last try */
user = sni_find(context->sni, hostname);
}
return user;
}
// arg is context
int sni_cb(SSL *ssl, int *al, void *arg) {
if (ssl) {
const char *hostname = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
if (hostname && hostname[0]) {
/* Try and resolve (match) required hostname with what we have registered */
SSL_CTX *resolved_ssl_context = resolve_context((struct us_internal_ssl_socket_context_t *) arg, hostname);
if (resolved_ssl_context) {
//printf("Did find matching SNI context for hostname: <%s>!\n", hostname);
SSL_set_SSL_CTX(ssl, resolved_ssl_context);
} else {
/* Call a blocking callback notifying of missing context */
}
}
return SSL_TLSEXT_ERR_OK;
}
/* Can we even come here ever? */
return SSL_TLSEXT_ERR_NOACK;
}
struct us_internal_ssl_socket_context_t *us_internal_create_ssl_socket_context(struct us_loop_t *loop, int context_ext_size, struct us_socket_context_options_t options) {
/* If we haven't initialized the loop data yet, do so .
* This is needed because loop data holds shared OpenSSL data and
* the function is also responsible for initializing OpenSSL */
us_internal_init_loop_ssl_data(loop);
/* First of all we try and create the SSL context from options */
SSL_CTX *ssl_context = create_ssl_context_from_options(options);
if (!ssl_context) {
/* We simply fail early if we cannot even create the OpenSSL context */
return NULL;
}
/* Otherwise ee continue by creating a non-SSL context, but with larger ext to hold our SSL stuff */
struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_create_socket_context(0, loop, sizeof(struct us_internal_ssl_socket_context_t) + context_ext_size, options);
/* I guess this is the only optional callback */
context->on_server_name = NULL;
/* Then we extend its SSL parts */
context->ssl_context = ssl_context;//create_ssl_context_from_options(options);
context->is_parent = 1;
/* We, as parent context, may ignore data */
context->sc.ignore_data = (int (*)(struct us_socket_t *)) ssl_ignore_data;
/* Parent contexts may use SNI */
SSL_CTX_set_tlsext_servername_callback(context->ssl_context, sni_cb);
SSL_CTX_set_tlsext_servername_arg(context->ssl_context, context);
/* Also create the SNI tree */
context->sni = sni_new();
return context;
}
/* Our destructor for hostnames, used below */
void sni_hostname_destructor(void *user) {
/* Some nodes hold null, so this one must ignore this case */
free_ssl_context((SSL_CTX *) user);
}
void us_internal_ssl_socket_context_free(struct us_internal_ssl_socket_context_t *context) {
/* If we are parent then we need to free our OpenSSL context */
if (context->is_parent) {
SSL_CTX_free(context->ssl_context);
free_ssl_context(context->ssl_context);
/* Here we need to register a temporary callback for all still-existing hostnames
* and their contexts. Only parents have an SNI tree */
sni_free(context->sni, sni_hostname_destructor);
}
us_socket_context_free(0, &context->sc);
@ -483,8 +684,8 @@ struct us_listen_socket_t *us_internal_ssl_socket_context_listen(struct us_inter
return us_socket_context_listen(0, &context->sc, host, port, options, sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t) + socket_ext_size);
}
struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_connect(struct us_internal_ssl_socket_context_t *context, const char *host, int port, int options, int socket_ext_size) {
return (struct us_internal_ssl_socket_t *) us_socket_context_connect(0, &context->sc, host, port, options, sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t) + socket_ext_size);
struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_connect(struct us_internal_ssl_socket_context_t *context, const char *host, int port, const char *source_host, int options, int socket_ext_size) {
return (struct us_internal_ssl_socket_t *) us_socket_context_connect(0, &context->sc, host, port, source_host, options, sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t) + socket_ext_size);
}
void us_internal_ssl_socket_context_on_open(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *(*on_open)(struct us_internal_ssl_socket_t *s, int is_client, char *ip, int ip_length)) {
@ -492,8 +693,8 @@ void us_internal_ssl_socket_context_on_open(struct us_internal_ssl_socket_contex
context->on_open = on_open;
}
void us_internal_ssl_socket_context_on_close(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *(*on_close)(struct us_internal_ssl_socket_t *s)) {
us_socket_context_on_close(0, (struct us_socket_context_t *) context, (struct us_socket_t *(*)(struct us_socket_t *)) ssl_on_close);
void us_internal_ssl_socket_context_on_close(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *(*on_close)(struct us_internal_ssl_socket_t *s, int code, void *reason)) {
us_socket_context_on_close(0, (struct us_socket_context_t *) context, (struct us_socket_t *(*)(struct us_socket_t *, int, void *)) ssl_on_close);
context->on_close = on_close;
}
@ -503,22 +704,31 @@ void us_internal_ssl_socket_context_on_data(struct us_internal_ssl_socket_contex
}
void us_internal_ssl_socket_context_on_writable(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *(*on_writable)(struct us_internal_ssl_socket_t *s)) {
us_socket_context_on_writable(0, (struct us_socket_context_t *) context, (struct us_socket_t *(*)(struct us_socket_t *)) on_writable);
us_socket_context_on_writable(0, (struct us_socket_context_t *) context, (struct us_socket_t *(*)(struct us_socket_t *)) ssl_on_writable);
context->on_writable = on_writable;
}
void us_internal_ssl_socket_context_on_timeout(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *(*on_timeout)(struct us_internal_ssl_socket_t *s)) {
us_socket_context_on_timeout(0, (struct us_socket_context_t *) context, (struct us_socket_t *(*)(struct us_socket_t *)) on_timeout);
}
/* We do not really listen to passed FIN-handler, we entirely override it with our handler since SSL doesn't really have support for half-closed sockets */
void us_internal_ssl_socket_context_on_end(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *(*on_end)(struct us_internal_ssl_socket_t *)) {
us_socket_context_on_end(0, (struct us_socket_context_t *) context, (struct us_socket_t *(*)(struct us_socket_t *)) ssl_on_end);
}
void us_internal_ssl_socket_context_on_connect_error(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *(*on_connect_error)(struct us_internal_ssl_socket_t *, int code)) {
us_socket_context_on_connect_error(0, (struct us_socket_context_t *) context, (struct us_socket_t *(*)(struct us_socket_t *, int)) on_connect_error);
}
void *us_internal_ssl_socket_context_ext(struct us_internal_ssl_socket_context_t *context) {
return context + 1;
}
/* Per socket functions */
void *us_internal_ssl_socket_get_native_handle(struct us_internal_ssl_socket_t *s) {
return s->ssl;
}
int us_internal_ssl_socket_write(struct us_internal_ssl_socket_t *s, const char *data, int length, int msg_more) {
if (us_socket_is_closed(0, &s->s) || us_internal_ssl_socket_is_shut_down(s)) {
@ -617,10 +827,6 @@ void us_internal_ssl_socket_shutdown(struct us_internal_ssl_socket_t *s) {
}
}
struct us_internal_ssl_socket_t *us_internal_ssl_socket_close(struct us_internal_ssl_socket_t *s) {
return (struct us_internal_ssl_socket_t *) us_socket_close(0, (struct us_socket_t *) s);
}
struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_adopt_socket(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *s, int ext_size) {
// todo: this is completely untested
return (struct us_internal_ssl_socket_t *) us_socket_context_adopt_socket(0, &context->sc, &s->s, sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t) + ext_size);

View File

@ -0,0 +1,218 @@
/*
* Authored by Alex Hultman, 2018-2020.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* This Server Name Indication hostname tree is written in C++ but could be ported to C.
* Overall it looks like crap, but has no memory allocations in fast path and is O(log n). */
#ifndef SNI_TREE_H
#define SNI_TREE_H
#ifndef LIBUS_NO_SSL
#include <map>
#include <memory>
#include <string_view>
#include <cstring>
#include <cstdlib>
#include <algorithm>
/* We only handle a maximum of 10 labels per hostname */
#define MAX_LABELS 10
/* This cannot be shared */
thread_local void (*sni_free_cb)(void *);
struct sni_node {
/* Empty nodes must always hold null */
void *user = nullptr;
std::map<std::string_view, std::unique_ptr<sni_node>> children;
~sni_node() {
for (auto &p : children) {
/* The data of our string_views are managed by malloc */
free((void *) p.first.data());
/* Call destructor passed to sni_free only if we hold data.
* This is important since sni_remove does not have sni_free_cb set */
if (p.second.get()->user) {
sni_free_cb(p.second.get()->user);
}
}
}
};
// this can only delete ONE single node, but may cull "empty nodes with null as data"
void *removeUser(struct sni_node *root, unsigned int label, std::string_view *labels, unsigned int numLabels) {
/* If we are in the bottom (past bottom by one), there is nothing to remove */
if (label == numLabels) {
void *user = root->user;
/* Mark us for culling on the way up */
root->user = nullptr;
return user;
}
/* Is this label a child of root? */
auto it = root->children.find(labels[label]);
if (it == root->children.end()) {
/* We cannot continue */
return nullptr;
}
void *removedUser = removeUser(it->second.get(), label + 1, labels, numLabels);
/* On the way back up, we may cull empty nodes with no children.
* This ends up being where we remove all nodes */
if (it->second.get()->children.empty() && it->second.get()->user == nullptr) {
/* The data of our string_views are managed by malloc */
free((void *) it->first.data());
/* This can only happen with user set to null, otherwise we use sni_free_cb which is unset by sni_remove */
root->children.erase(it);
}
return removedUser;
}
void *getUser(struct sni_node *root, unsigned int label, std::string_view *labels, unsigned int numLabels) {
/* Do we have labels to match? Otherwise, return where we stand */
if (label == numLabels) {
return root->user;
}
/* Try and match by our label */
auto it = root->children.find(labels[label]);
if (it != root->children.end()) {
void *user = getUser(it->second.get(), label + 1, labels, numLabels);
if (user) {
return user;
}
}
/* Try and match by wildcard */
it = root->children.find("*");
if (it == root->children.end()) {
/* Matching has failed for both label and wildcard */
return nullptr;
}
/* We matched by wildcard */
return getUser(it->second.get(), label + 1, labels, numLabels);
}
extern "C" {
void *sni_new() {
return new sni_node;
}
void sni_free(void *sni, void (*cb)(void *)) {
/* We want to run this callback for every remaining name */
sni_free_cb = cb;
delete (sni_node *) sni;
}
/* Returns non-null if this name already exists */
int sni_add(void *sni, const char *hostname, void *user) {
struct sni_node *root = (struct sni_node *) sni;
/* Traverse all labels in hostname */
for (std::string_view view(hostname, strlen(hostname)), label;
view.length(); view.remove_prefix(std::min(view.length(), label.length() + 1))) {
/* Label is the token separated by dot */
label = view.substr(0, view.find('.', 0));
auto it = root->children.find(label);
if (it == root->children.end()) {
/* Duplicate this label for our kept string_view of it */
void *labelString = malloc(label.length());
memcpy(labelString, label.data(), label.length());
it = root->children.emplace(std::string_view((char *) labelString, label.length()),
std::make_unique<sni_node>()).first;
}
root = it->second.get();
}
/* We must never add multiple contexts for the same name, as that would overwrite and leak */
if (root->user) {
return 1;
}
root->user = user;
return 0;
}
/* Removes the exact match. Wildcards are treated as the verbatim asterisk char, not as an actual wildcard */
void *sni_remove(void *sni, const char *hostname) {
struct sni_node *root = (struct sni_node *) sni;
/* I guess 10 labels is an okay limit */
std::string_view labels[10];
unsigned int numLabels = 0;
/* We traverse all labels first of all */
for (std::string_view view(hostname, strlen(hostname)), label;
view.length(); view.remove_prefix(std::min(view.length(), label.length() + 1))) {
/* Label is the token separated by dot */
label = view.substr(0, view.find('.', 0));
/* Anything longer than 10 labels is forbidden */
if (numLabels == 10) {
return nullptr;
}
labels[numLabels++] = label;
}
return removeUser(root, 0, labels, numLabels);
}
void *sni_find(void *sni, const char *hostname) {
struct sni_node *root = (struct sni_node *) sni;
/* I guess 10 labels is an okay limit */
std::string_view labels[10];
unsigned int numLabels = 0;
/* We traverse all labels first of all */
for (std::string_view view(hostname, strlen(hostname)), label;
view.length(); view.remove_prefix(std::min(view.length(), label.length() + 1))) {
/* Label is the token separated by dot */
label = view.substr(0, view.find('.', 0));
/* Anything longer than 10 labels is forbidden */
if (numLabels == 10) {
return nullptr;
}
labels[numLabels++] = label;
}
return getUser(root, 0, labels, numLabels);
}
}
#endif
#endif

View File

@ -369,7 +369,7 @@ struct us_internal_ssl_socket_context_t *us_internal_create_ssl_socket_context(s
// options
wolfSSL_CTX_set_read_ahead(context->ssl_context, 1);
wolfSSL_CTX_set_mode(context->ssl_context, WOLFSSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
// this lowers performance a bit in benchmarks
if (options.ssl_prefer_low_memory_usage) {
//SSL_CTX_set_mode(context->ssl_context, SSL_MODE_RELEASE_BUFFERS);
@ -423,8 +423,8 @@ struct us_listen_socket_t *us_internal_ssl_socket_context_listen(struct us_inter
return us_socket_context_listen(0, &context->sc, host, port, options, sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t) + socket_ext_size);
}
struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_connect(struct us_internal_ssl_socket_context_t *context, const char *host, int port, int options, int socket_ext_size) {
return (struct us_internal_ssl_socket_t *) us_socket_context_connect(0, &context->sc, host, port, options, sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t) + socket_ext_size);
struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_connect(struct us_internal_ssl_socket_context_t *context, const char *host, int port, const char *source_host, int options, int socket_ext_size) {
return (struct us_internal_ssl_socket_t *) us_socket_context_connect(0, &context->sc, host, port, source_host, options, sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t) + socket_ext_size);
}
void us_internal_ssl_socket_context_on_open(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *(*on_open)(struct us_internal_ssl_socket_t *s, int is_client, char *ip, int ip_length)) {

View File

@ -21,6 +21,9 @@
#if defined(LIBUS_USE_EPOLL) || defined(LIBUS_USE_KQUEUE)
/* Cannot include this one on Windows */
#include <unistd.h>
#ifdef LIBUS_USE_EPOLL
#define GET_READY_POLL(loop, index) (struct us_poll_t *) loop->ready_polls[index].data.ptr
#define SET_READY_POLL(loop, index, poll) loop->ready_polls[index].data.ptr = poll
@ -329,8 +332,8 @@ void us_timer_set(struct us_timer_t *t, void (*cb)(struct us_timer_t *t), int ms
internal_cb->cb = (void (*)(struct us_internal_callback_t *)) cb;
struct itimerspec timer_spec = {
{repeat_ms / 1000, ((long)repeat_ms * 1000000) % 1000000000},
{ms / 1000, ((long)ms * 1000000) % 1000000000}
{repeat_ms / 1000, (long) (repeat_ms % 1000) * (long) 1000000},
{ms / 1000, (long) (ms % 1000) * (long) 1000000}
};
timerfd_settime(us_poll_fd((struct us_poll_t *) t), 0, &timer_spec, NULL);

View File

@ -145,7 +145,7 @@ void *us_poll_ext(struct us_poll_t *p) {
}
unsigned int us_internal_accept_poll_event(struct us_poll_t *p) {
printf("us_internal_accept_poll_event\n");
//printf("us_internal_accept_poll_event\n");
return 0;
}

View File

@ -1,5 +1,5 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Authored by Alex Hultman, 2018-2021.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
@ -21,9 +21,9 @@
#ifdef LIBUS_USE_LIBUV
// poll dispatch
/* uv_poll_t->data always (except for most times after calling us_poll_stop) points to the us_poll_t */
static void poll_cb(uv_poll_t *p, int status, int events) {
us_internal_dispatch_ready_poll((struct us_poll_t *) p, status < 0, events);
us_internal_dispatch_ready_poll((struct us_poll_t *) p->data, status < 0, events);
}
static void prepare_cb(uv_prepare_t *p) {
@ -37,10 +37,21 @@ static void check_cb(uv_check_t *p) {
us_internal_loop_post(loop);
}
/* Not used for polls, since polls need two frees */
static void close_cb_free(uv_handle_t *h) {
free(h->data);
}
/* This one is different for polls, since we need two frees here */
static void close_cb_free_poll(uv_handle_t *h) {
/* It is only in case we called us_poll_stop then quickly us_poll_free that we enter this.
* Most of the time, actual freeing is done by us_poll_free. */
if (h->data) {
free(h->data);
free(h);
}
}
static void timer_cb(uv_timer_t *t) {
struct us_internal_callback_t *cb = t->data;
cb->cb(cb);
@ -59,9 +70,14 @@ void us_poll_init(struct us_poll_t *p, LIBUS_SOCKET_DESCRIPTOR fd, int poll_type
}
void us_poll_free(struct us_poll_t *p, struct us_loop_t *loop) {
if (uv_is_closing((uv_handle_t *) &p->uv_p)) {
p->uv_p.data = p;
/* The idea here is like so; in us_poll_stop we call uv_close after setting data of uv-poll to 0.
* This means that in close_cb_free we call free on 0 with does nothing, since us_poll_stop should
* not really free the poll. HOWEVER, if we then call us_poll_free while still closing the uv-poll,
* we simply change back the data to point to our structure so that we actually do free it like we should. */
if (uv_is_closing((uv_handle_t *) p->uv_p)) {
p->uv_p->data = p;
} else {
free(p->uv_p);
free(p);
}
}
@ -69,24 +85,26 @@ void us_poll_free(struct us_poll_t *p, struct us_loop_t *loop) {
void us_poll_start(struct us_poll_t *p, struct us_loop_t *loop, int events) {
p->poll_type = us_internal_poll_type(p) | ((events & LIBUS_SOCKET_READABLE) ? POLL_TYPE_POLLING_IN : 0) | ((events & LIBUS_SOCKET_WRITABLE) ? POLL_TYPE_POLLING_OUT : 0);
uv_poll_init_socket(loop->uv_loop, &p->uv_p, p->fd);
uv_poll_start(&p->uv_p, events, poll_cb);
uv_poll_init_socket(loop->uv_loop, p->uv_p, p->fd);
uv_poll_start(p->uv_p, events, poll_cb);
}
void us_poll_change(struct us_poll_t *p, struct us_loop_t *loop, int events) {
if (us_poll_events(p) != events) {
p->poll_type = us_internal_poll_type(p) | ((events & LIBUS_SOCKET_READABLE) ? POLL_TYPE_POLLING_IN : 0) | ((events & LIBUS_SOCKET_WRITABLE) ? POLL_TYPE_POLLING_OUT : 0);
uv_poll_start(&p->uv_p, events, poll_cb);
uv_poll_start(p->uv_p, events, poll_cb);
}
}
void us_poll_stop(struct us_poll_t *p, struct us_loop_t *loop) {
uv_poll_stop(&p->uv_p);
uv_poll_stop(p->uv_p);
// close but not free is needed here
p->uv_p.data = 0;
uv_close((uv_handle_t *) &p->uv_p, close_cb_free); // needed here
/* We normally only want to close the poll here, not free it. But if we stop it, then quickly "free" it with
* us_poll_free, we postpone the actual freeing to close_cb_free_poll whenever it triggers.
* That's why we set data to null here, so that us_poll_free can reset it if needed */
p->uv_p->data = 0;
uv_close((uv_handle_t *) p->uv_p, close_cb_free_poll);
}
int us_poll_events(struct us_poll_t *p) {
@ -171,19 +189,17 @@ void us_loop_run(struct us_loop_t *loop) {
}
struct us_poll_t *us_create_poll(struct us_loop_t *loop, int fallthrough, unsigned int ext_size) {
return malloc(sizeof(struct us_poll_t) + ext_size);
struct us_poll_t *p = (struct us_poll_t *) malloc(sizeof(struct us_poll_t) + ext_size);
p->uv_p = malloc(sizeof(uv_poll_t));
p->uv_p->data = p;
return p;
}
// this one is broken, us_poll needs to hold a pointer to uv_poll_t for it to work (bad anyways)
/* If we update our block position we have to updarte the uv_poll data to point to us */
struct us_poll_t *us_poll_resize(struct us_poll_t *p, struct us_loop_t *loop, unsigned int ext_size) {
// do not support it yet
return p;
struct us_poll_t *new_p = realloc(p, sizeof(struct us_poll_t) + ext_size);
if (p != new_p) {
new_p->uv_p.data = new_p;
}
new_p->uv_p->data = new_p;
return new_p;
}
@ -207,7 +223,7 @@ struct us_timer_t *us_create_timer(struct us_loop_t *loop, int fallthrough, unsi
}
void *us_timer_ext(struct us_timer_t *timer) {
return ((struct us_internal_callback_t *) timer) + 1;
return ((char *) timer) + sizeof(struct us_internal_callback_t) + sizeof(uv_timer_t);
}
void us_timer_close(struct us_timer_t *t) {

View File

@ -59,7 +59,7 @@ struct us_loop_t {
struct us_poll_t {
alignas(LIBUS_EXT_ALIGNMENT) struct {
int fd : 28;
signed int fd : 28; // we could have this unsigned if we wanted to, -1 should never be used
unsigned int poll_type : 4;
} state;
};

View File

@ -34,8 +34,10 @@ struct us_loop_t {
uv_check_t *uv_check;
};
// it is no longer valid to cast a pointer to us_poll_t to a pointer of uv_poll_t
struct us_poll_t {
uv_poll_t uv_p;
/* We need to hold a pointer to this uv_poll_t since we need to be able to resize our block */
uv_poll_t *uv_p;
LIBUS_SOCKET_DESCRIPTOR fd;
unsigned char poll_type;
};

View File

@ -18,6 +18,12 @@
#ifndef INTERNAL_H
#define INTERNAL_H
#if defined(_MSC_VER)
#define alignas(x) __declspec(align(x))
#else
#include <stdalign.h>
#endif
/* We only have one networking implementation so far */
#include "internal/networking/bsd.h"
@ -100,7 +106,7 @@ struct us_listen_socket_t {
struct us_socket_context_t {
alignas(LIBUS_EXT_ALIGNMENT) struct us_loop_t *loop;
//unsigned short timeout;
unsigned short timestamp;
struct us_socket_t *head;
struct us_socket_t *iterator;
struct us_socket_context_t *prev, *next;
@ -108,14 +114,12 @@ struct us_socket_context_t {
struct us_socket_t *(*on_open)(struct us_socket_t *, int is_client, char *ip, int ip_length);
struct us_socket_t *(*on_data)(struct us_socket_t *, char *data, int length);
struct us_socket_t *(*on_writable)(struct us_socket_t *);
struct us_socket_t *(*on_close)(struct us_socket_t *);
struct us_socket_t *(*on_close)(struct us_socket_t *, int code, void *reason);
//void (*on_timeout)(struct us_socket_context *);
struct us_socket_t *(*on_socket_timeout)(struct us_socket_t *);
struct us_socket_t *(*on_end)(struct us_socket_t *);
struct us_socket_t *(*on_connect_error)(struct us_socket_t *, int code);
int (*ignore_data)(struct us_socket_t *);
/* All contexts hold references to their own copied options */
struct us_socket_context_options_t options;
};
/* Internal SSL interface */
@ -124,6 +128,14 @@ struct us_socket_context_t {
struct us_internal_ssl_socket_context_t;
struct us_internal_ssl_socket_t;
/* SNI functions */
void us_internal_ssl_socket_context_add_server_name(struct us_internal_ssl_socket_context_t *context, const char *hostname_pattern, struct us_socket_context_options_t options);
void us_internal_ssl_socket_context_remove_server_name(struct us_internal_ssl_socket_context_t *context, const char *hostname_pattern);
void us_internal_ssl_socket_context_on_server_name(struct us_internal_ssl_socket_context_t *context, void (*cb)(struct us_internal_ssl_socket_context_t *, const char *));
void *us_internal_ssl_socket_get_native_handle(struct us_internal_ssl_socket_t *s);
void *us_internal_ssl_socket_context_get_native_handle(struct us_internal_ssl_socket_context_t *context);
struct us_internal_ssl_socket_context_t *us_internal_create_ssl_socket_context(struct us_loop_t *loop,
int context_ext_size, struct us_socket_context_options_t options);
@ -132,7 +144,7 @@ void us_internal_ssl_socket_context_on_open(struct us_internal_ssl_socket_contex
struct us_internal_ssl_socket_t *(*on_open)(struct us_internal_ssl_socket_t *s, int is_client, char *ip, int ip_length));
void us_internal_ssl_socket_context_on_close(struct us_internal_ssl_socket_context_t *context,
struct us_internal_ssl_socket_t *(*on_close)(struct us_internal_ssl_socket_t *s));
struct us_internal_ssl_socket_t *(*on_close)(struct us_internal_ssl_socket_t *s, int code, void *reason));
void us_internal_ssl_socket_context_on_data(struct us_internal_ssl_socket_context_t *context,
struct us_internal_ssl_socket_t *(*on_data)(struct us_internal_ssl_socket_t *s, char *data, int length));
@ -146,11 +158,14 @@ void us_internal_ssl_socket_context_on_timeout(struct us_internal_ssl_socket_con
void us_internal_ssl_socket_context_on_end(struct us_internal_ssl_socket_context_t *context,
struct us_internal_ssl_socket_t *(*on_end)(struct us_internal_ssl_socket_t *s));
void us_internal_ssl_socket_context_on_connect_error(struct us_internal_ssl_socket_context_t *context,
struct us_internal_ssl_socket_t *(*on_connect_error)(struct us_internal_ssl_socket_t *s, int code));
struct us_listen_socket_t *us_internal_ssl_socket_context_listen(struct us_internal_ssl_socket_context_t *context,
const char *host, int port, int options, int socket_ext_size);
struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_connect(struct us_internal_ssl_socket_context_t *context,
const char *host, int port, int options, int socket_ext_size);
const char *host, int port, const char *source_host, int options, int socket_ext_size);
int us_internal_ssl_socket_write(struct us_internal_ssl_socket_t *s, const char *data, int length, int msg_more);
void us_internal_ssl_socket_timeout(struct us_internal_ssl_socket_t *s, unsigned int seconds);
@ -159,7 +174,7 @@ struct us_internal_ssl_socket_context_t *us_internal_ssl_socket_get_context(stru
void *us_internal_ssl_socket_ext(struct us_internal_ssl_socket_t *s);
int us_internal_ssl_socket_is_shut_down(struct us_internal_ssl_socket_t *s);
void us_internal_ssl_socket_shutdown(struct us_internal_ssl_socket_t *s);
struct us_internal_ssl_socket_t *us_internal_ssl_socket_close(struct us_internal_ssl_socket_t *s);
struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_adopt_socket(struct us_internal_ssl_socket_context_t *context,
struct us_internal_ssl_socket_t *s, int ext_size);

View File

@ -24,265 +24,61 @@
// here everything about the syscalls are inline-wrapped and included
#ifdef _WIN32
#ifndef NOMINMAX
#define NOMINMAX
#include <WinSock2.h>
#include <Ws2tcpip.h>
#endif
#include <winsock2.h>
#include <ws2tcpip.h>
#pragma comment(lib, "ws2_32.lib")
#include <stdio.h>
#define SETSOCKOPT_PTR_TYPE const char *
#define LIBUS_SOCKET_ERROR INVALID_SOCKET
#else
#define _GNU_SOURCE
#include <sys/types.h>
/* For socklen_t */
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <netdb.h>
#include <string.h>
#include <unistd.h>
#include <fcntl.h>
#include <stdio.h>
#include <errno.h>
#define SETSOCKOPT_PTR_TYPE int *
#define LIBUS_SOCKET_ERROR -1
#endif
static inline LIBUS_SOCKET_DESCRIPTOR apple_no_sigpipe(LIBUS_SOCKET_DESCRIPTOR fd) {
#ifdef __APPLE__
if (fd != LIBUS_SOCKET_ERROR) {
int no_sigpipe = 1;
setsockopt(fd, SOL_SOCKET, SO_NOSIGPIPE, &no_sigpipe, sizeof(int));
}
#endif
return fd;
}
static inline LIBUS_SOCKET_DESCRIPTOR bsd_set_nonblocking(LIBUS_SOCKET_DESCRIPTOR fd) {
#ifdef _WIN32
/* Libuv will set windows sockets as non-blocking */
#else
fcntl(fd, F_SETFL, fcntl(fd, F_GETFL, 0) | O_NONBLOCK);
#endif
return fd;
}
static inline void bsd_socket_nodelay(LIBUS_SOCKET_DESCRIPTOR fd, int enabled) {
setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, (void *) &enabled, sizeof(enabled));
}
static inline void bsd_socket_flush(LIBUS_SOCKET_DESCRIPTOR fd) {
// Linux TCP_CORK has the same underlying corking mechanism as with MSG_MORE
#ifdef TCP_CORK
int enabled = 0;
setsockopt(fd, IPPROTO_TCP, TCP_CORK, &enabled, sizeof(int));
#endif
}
static inline LIBUS_SOCKET_DESCRIPTOR bsd_create_socket(int domain, int type, int protocol) {
// returns INVALID_SOCKET on error
int flags = 0;
#if defined(SOCK_CLOEXEC) && defined(SOCK_NONBLOCK)
flags = SOCK_CLOEXEC | SOCK_NONBLOCK;
#endif
LIBUS_SOCKET_DESCRIPTOR created_fd = socket(domain, type | flags, protocol);
return bsd_set_nonblocking(apple_no_sigpipe(created_fd));
}
static inline void bsd_close_socket(LIBUS_SOCKET_DESCRIPTOR fd) {
#ifdef _WIN32
closesocket(fd);
#else
close(fd);
#endif
}
static inline void bsd_shutdown_socket(LIBUS_SOCKET_DESCRIPTOR fd) {
#ifdef _WIN32
shutdown(fd, SD_SEND);
#else
shutdown(fd, SHUT_WR);
#endif
}
struct bsd_addr_t {
struct sockaddr_storage mem;
socklen_t len;
char *ip;
int ip_length;
int port;
};
static inline void internal_finalize_bsd_addr(struct bsd_addr_t *addr) {
// parse, so to speak, the address
if (addr->mem.ss_family == AF_INET6) {
addr->ip = (char *) &((struct sockaddr_in6 *) addr)->sin6_addr;
addr->ip_length = sizeof(struct in6_addr);
} else if (addr->mem.ss_family == AF_INET) {
addr->ip = (char *) &((struct sockaddr_in *) addr)->sin_addr;
addr->ip_length = sizeof(struct in_addr);
} else {
addr->ip_length = 0;
}
}
LIBUS_SOCKET_DESCRIPTOR apple_no_sigpipe(LIBUS_SOCKET_DESCRIPTOR fd);
LIBUS_SOCKET_DESCRIPTOR bsd_set_nonblocking(LIBUS_SOCKET_DESCRIPTOR fd);
void bsd_socket_nodelay(LIBUS_SOCKET_DESCRIPTOR fd, int enabled);
void bsd_socket_flush(LIBUS_SOCKET_DESCRIPTOR fd);
LIBUS_SOCKET_DESCRIPTOR bsd_create_socket(int domain, int type, int protocol);
static inline int bsd_socket_addr(LIBUS_SOCKET_DESCRIPTOR fd, struct bsd_addr_t *addr) {
addr->len = sizeof(addr->mem);
if (getpeername(fd, (struct sockaddr *) &addr->mem, &addr->len)) {
return -1;
}
internal_finalize_bsd_addr(addr);
return 0;
}
void bsd_close_socket(LIBUS_SOCKET_DESCRIPTOR fd);
void bsd_shutdown_socket(LIBUS_SOCKET_DESCRIPTOR fd);
void bsd_shutdown_socket_read(LIBUS_SOCKET_DESCRIPTOR fd);
static inline char *bsd_addr_get_ip(struct bsd_addr_t *addr) {
return addr->ip;
}
void internal_finalize_bsd_addr(struct bsd_addr_t *addr);
static inline int bsd_addr_get_ip_length(struct bsd_addr_t *addr) {
return addr->ip_length;
}
int bsd_local_addr(LIBUS_SOCKET_DESCRIPTOR fd, struct bsd_addr_t *addr);
int bsd_remote_addr(LIBUS_SOCKET_DESCRIPTOR fd, struct bsd_addr_t *addr);
char *bsd_addr_get_ip(struct bsd_addr_t *addr);
int bsd_addr_get_ip_length(struct bsd_addr_t *addr);
int bsd_addr_get_port(struct bsd_addr_t *addr);
// called by dispatch_ready_poll
static inline LIBUS_SOCKET_DESCRIPTOR bsd_accept_socket(LIBUS_SOCKET_DESCRIPTOR fd, struct bsd_addr_t *addr) {
LIBUS_SOCKET_DESCRIPTOR accepted_fd;
addr->len = sizeof(addr->mem);
LIBUS_SOCKET_DESCRIPTOR bsd_accept_socket(LIBUS_SOCKET_DESCRIPTOR fd, struct bsd_addr_t *addr);
#if defined(SOCK_CLOEXEC) && defined(SOCK_NONBLOCK)
// Linux, FreeBSD
accepted_fd = accept4(fd, (struct sockaddr *) addr, &addr->len, SOCK_CLOEXEC | SOCK_NONBLOCK);
#else
// Windows, OS X
accepted_fd = accept(fd, (struct sockaddr *) addr, &addr->len);
#endif
internal_finalize_bsd_addr(addr);
return bsd_set_nonblocking(apple_no_sigpipe(accepted_fd));
}
static inline int bsd_recv(LIBUS_SOCKET_DESCRIPTOR fd, void *buf, int length, int flags) {
return recv(fd, buf, length, flags);
}
static inline int bsd_send(LIBUS_SOCKET_DESCRIPTOR fd, const char *buf, int length, int msg_more) {
// MSG_MORE (Linux), MSG_PARTIAL (Windows), TCP_NOPUSH (BSD)
#ifndef MSG_NOSIGNAL
#define MSG_NOSIGNAL 0
#endif
#ifdef MSG_MORE
// for Linux we do not want signals
return send(fd, buf, length, (msg_more * MSG_MORE) | MSG_NOSIGNAL);
#else
// use TCP_NOPUSH
return send(fd, buf, length, MSG_NOSIGNAL);
#endif
}
static inline int bsd_would_block() {
#ifdef _WIN32
return WSAGetLastError() == WSAEWOULDBLOCK;
#else
return errno == EWOULDBLOCK;// || errno == EAGAIN;
#endif
}
int bsd_recv(LIBUS_SOCKET_DESCRIPTOR fd, void *buf, int length, int flags);
int bsd_send(LIBUS_SOCKET_DESCRIPTOR fd, const char *buf, int length, int msg_more);
int bsd_would_block();
// return LIBUS_SOCKET_ERROR or the fd that represents listen socket
// listen both on ipv6 and ipv4
static inline LIBUS_SOCKET_DESCRIPTOR bsd_create_listen_socket(const char *host, int port, int options) {
struct addrinfo hints, *result;
memset(&hints, 0, sizeof(struct addrinfo));
LIBUS_SOCKET_DESCRIPTOR bsd_create_listen_socket(const char *host, int port, int options);
hints.ai_flags = AI_PASSIVE;
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
char port_string[16];
snprintf(port_string, 16, "%d", port);
if (getaddrinfo(host, port_string, &hints, &result)) {
return LIBUS_SOCKET_ERROR;
}
LIBUS_SOCKET_DESCRIPTOR listenFd = LIBUS_SOCKET_ERROR;
struct addrinfo *listenAddr;
for (struct addrinfo *a = result; a && listenFd == LIBUS_SOCKET_ERROR; a = a->ai_next) {
if (a->ai_family == AF_INET6) {
listenFd = bsd_create_socket(a->ai_family, a->ai_socktype, a->ai_protocol);
listenAddr = a;
}
}
for (struct addrinfo *a = result; a && listenFd == LIBUS_SOCKET_ERROR; a = a->ai_next) {
if (a->ai_family == AF_INET) {
listenFd = bsd_create_socket(a->ai_family, a->ai_socktype, a->ai_protocol);
listenAddr = a;
}
}
if (listenFd == LIBUS_SOCKET_ERROR) {
freeaddrinfo(result);
return LIBUS_SOCKET_ERROR;
}
/* Always enable SO_REUSEPORT and SO_REUSEADDR _unless_ options specify otherwise */
#if defined(__linux) && defined(SO_REUSEPORT)
if (!(options & LIBUS_LISTEN_EXCLUSIVE_PORT)) {
int optval = 1;
setsockopt(listenFd, SOL_SOCKET, SO_REUSEPORT, &optval, sizeof(optval));
}
#endif
int enabled = 1;
setsockopt(listenFd, SOL_SOCKET, SO_REUSEADDR, (SETSOCKOPT_PTR_TYPE) &enabled, sizeof(enabled));
#ifdef IPV6_V6ONLY
int disabled = 0;
setsockopt(listenFd, IPPROTO_IPV6, IPV6_V6ONLY, (SETSOCKOPT_PTR_TYPE) &disabled, sizeof(disabled));
#endif
if (bind(listenFd, listenAddr->ai_addr, (socklen_t) listenAddr->ai_addrlen) || listen(listenFd, 512)) {
bsd_close_socket(listenFd);
freeaddrinfo(result);
return LIBUS_SOCKET_ERROR;
}
freeaddrinfo(result);
return listenFd;
}
static inline LIBUS_SOCKET_DESCRIPTOR bsd_create_connect_socket(const char *host, int port, int options) {
struct addrinfo hints, *result;
memset(&hints, 0, sizeof(struct addrinfo));
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
char port_string[16];
snprintf(port_string, 16, "%d", port);
if (getaddrinfo(host, port_string, &hints, &result) != 0) {
return LIBUS_SOCKET_ERROR;
}
LIBUS_SOCKET_DESCRIPTOR fd = bsd_create_socket(result->ai_family, result->ai_socktype, result->ai_protocol);
if (fd == LIBUS_SOCKET_ERROR) {
freeaddrinfo(result);
return LIBUS_SOCKET_ERROR;
}
connect(fd, result->ai_addr, (socklen_t) result->ai_addrlen);
freeaddrinfo(result);
return fd;
}
LIBUS_SOCKET_DESCRIPTOR bsd_create_connect_socket(const char *host, int port, const char *source_host, int options);
#endif // BSD_H

View File

@ -1,5 +1,5 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Authored by Alex Hultman, 2018-2021.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
@ -80,24 +80,50 @@ void us_internal_loop_unlink(struct us_loop_t *loop, struct us_socket_context_t
/* This functions should never run recursively */
void us_internal_timer_sweep(struct us_loop_t *loop) {
struct us_internal_loop_data_t *loop_data = &loop->data;
/* For all socket contexts in this loop */
for (loop_data->iterator = loop_data->head; loop_data->iterator; loop_data->iterator = loop_data->iterator->next) {
struct us_socket_context_t *context = loop_data->iterator;
for (context->iterator = context->head; context->iterator; ) {
struct us_socket_t *s = context->iterator;
if (s->timeout && --(s->timeout) == 0) {
/* Update this context's 15-bit timestamp */
context->timestamp = (context->timestamp + 1) & 0x7fff;
context->on_socket_timeout(s);
/* Update our 16-bit full timestamp (the needle in the haystack) */
unsigned short needle = 0x8000 | context->timestamp;
/* Check for unlink / link */
if (s == context->iterator) {
context->iterator = s->next;
/* Begin at head */
struct us_socket_t *s = context->head;
while (s) {
/* Seek until end or timeout found (tightest loop) */
while (1) {
/* We only read from 1 random cache line here */
if (needle == s->timeout) {
break;
}
/* Did we reach the end without a find? */
if ((s = s->next) == 0) {
goto next_context;
}
}
/* Here we have a timeout to emit (slow path) */
s->timeout = 0;
context->iterator = s;
context->on_socket_timeout(s);
/* Check for unlink / link (if the event handler did not modify the chain, we step 1) */
if (s == context->iterator) {
s = s->next;
} else {
context->iterator = s->next;
/* The iterator was changed by event handler */
s = context->iterator;
}
}
/* We always store a 0 to context->iterator here since we are no longer iterating this context */
next_context:
context->iterator = 0;
}
}
@ -136,7 +162,10 @@ void us_internal_loop_post(struct us_loop_t *loop) {
void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events) {
switch (us_internal_poll_type(p)) {
case POLL_TYPE_CALLBACK: {
/* Let's just do this to clear the CodeQL alert */
#ifndef LIBUS_USE_LIBUV
us_internal_accept_poll_event(p);
#endif
struct us_internal_callback_t *cb = (struct us_internal_callback_t *) p;
cb->cb(cb->cb_expects_the_loop ? (struct us_internal_callback_t *) cb->loop : (struct us_internal_callback_t *) &cb->p);
}
@ -147,15 +176,26 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events)
if (us_poll_events(p) == LIBUS_SOCKET_WRITABLE) {
struct us_socket_t *s = (struct us_socket_t *) p;
us_poll_change(p, s->context->loop, LIBUS_SOCKET_READABLE);
/* It is perfectly possible to come here with an error */
if (error) {
/* Emit error, close without emitting on_close */
s->context->on_connect_error(s, 0);
us_socket_close_connecting(0, s);
} else {
/* All sockets poll for readable */
us_poll_change(p, s->context->loop, LIBUS_SOCKET_READABLE);
/* We always use nodelay */
bsd_socket_nodelay(us_poll_fd(p), 1);
/* We always use nodelay */
bsd_socket_nodelay(us_poll_fd(p), 1);
/* We are now a proper socket */
us_internal_poll_set_type(p, POLL_TYPE_SOCKET);
/* We are now a proper socket */
us_internal_poll_set_type(p, POLL_TYPE_SOCKET);
s->context->on_open(s, 1, 0, 0);
/* If we used a connection timeout we have to reset it here */
us_socket_timeout(0, s, 0);
s->context->on_open(s, 1, 0, 0);
}
} else {
struct us_listen_socket_t *listen_socket = (struct us_listen_socket_t *) p;
struct bsd_addr_t addr;
@ -169,11 +209,11 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events)
/* Todo: stop timer if any */
do {
struct us_poll_t *p = us_create_poll(us_socket_context(0, &listen_socket->s)->loop, 0, sizeof(struct us_socket_t) - sizeof(struct us_poll_t) + listen_socket->socket_ext_size);
us_poll_init(p, client_fd, POLL_TYPE_SOCKET);
us_poll_start(p, listen_socket->s.context->loop, LIBUS_SOCKET_READABLE);
struct us_poll_t *accepted_p = us_create_poll(us_socket_context(0, &listen_socket->s)->loop, 0, sizeof(struct us_socket_t) - sizeof(struct us_poll_t) + listen_socket->socket_ext_size);
us_poll_init(accepted_p, client_fd, POLL_TYPE_SOCKET);
us_poll_start(accepted_p, listen_socket->s.context->loop, LIBUS_SOCKET_READABLE);
struct us_socket_t *s = (struct us_socket_t *) p;
struct us_socket_t *s = (struct us_socket_t *) accepted_p;
s->context = listen_socket->s.context;
@ -201,7 +241,8 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events)
/* Such as epollerr epollhup */
if (error) {
s = us_socket_close(0, s);
/* Todo: decide what code we give here */
s = us_socket_close(0, s, 0, NULL);
return;
}
@ -235,14 +276,16 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events)
} else if (!length) {
if (us_socket_is_shut_down(0, s)) {
/* We got FIN back after sending it */
s = us_socket_close(0, s);
/* Todo: We should give "CLEAN SHUTDOWN" as reason here */
s = us_socket_close(0, s, 0, NULL);
} else {
/* We got FIN, so stop polling for readable */
us_poll_change(&s->p, us_socket_context(0, s)->loop, us_poll_events(&s->p) & LIBUS_SOCKET_WRITABLE);
s = s->context->on_end(s);
}
} else if (length == LIBUS_SOCKET_ERROR && !bsd_would_block()) {
s = us_socket_close(0, s);
/* Todo: decide also here what kind of reason we should give */
s = us_socket_close(0, s, 0, NULL);
}
}
}

View File

@ -1,5 +1,5 @@
/*
* Authored by Alex Hultman, 2018-2019.
* Authored by Alex Hultman, 2018-2021.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
@ -18,12 +18,27 @@
#include "libusockets.h"
#include "internal/internal.h"
#include <stdlib.h>
#include <string.h>
/* Shared with SSL */
int us_socket_local_port(int ssl, struct us_socket_t *s) {
struct bsd_addr_t addr;
if (bsd_local_addr(us_poll_fd(&s->p), &addr)) {
return -1;
} else {
return bsd_addr_get_port(&addr);
}
}
void us_socket_shutdown_read(int ssl, struct us_socket_t *s) {
/* This syscall is idempotent so no extra check is needed */
bsd_shutdown_socket_read(us_poll_fd((struct us_poll_t *) s));
}
void us_socket_remote_address(int ssl, struct us_socket_t *s, char *buf, int *length) {
struct bsd_addr_t addr;
if (bsd_socket_addr(us_poll_fd(&s->p), &addr) || *length < bsd_addr_get_ip_length(&addr)) {
if (bsd_remote_addr(us_poll_fd(&s->p), &addr) || *length < bsd_addr_get_ip_length(&addr)) {
*length = 0;
} else {
*length = bsd_addr_get_ip_length(&addr);
@ -37,8 +52,7 @@ struct us_socket_context_t *us_socket_context(int ssl, struct us_socket_t *s) {
void us_socket_timeout(int ssl, struct us_socket_t *s, unsigned int seconds) {
if (seconds) {
unsigned short timeout_sweeps = (unsigned short) (0.5f + ((float) seconds) / ((float) LIBUS_TIMEOUT_GRANULARITY));
s->timeout = timeout_sweeps ? timeout_sweeps : 1;
s->timeout = 0x8000 | (s->context->timestamp + (seconds >> 2));
} else {
s->timeout = 0;
}
@ -54,8 +68,61 @@ int us_socket_is_closed(int ssl, struct us_socket_t *s) {
return s->prev == (struct us_socket_t *) s->context;
}
int us_socket_is_established(int ssl, struct us_socket_t *s) {
/* Everything that is not POLL_TYPE_SEMI_SOCKET is established */
return us_internal_poll_type((struct us_poll_t *) s) != POLL_TYPE_SEMI_SOCKET;
}
/* Exactly the same as us_socket_close but does not emit on_close event */
struct us_socket_t *us_socket_close_connecting(int ssl, struct us_socket_t *s) {
if (!us_socket_is_closed(0, s)) {
us_internal_socket_context_unlink(s->context, s);
us_poll_stop((struct us_poll_t *) s, s->context->loop);
bsd_close_socket(us_poll_fd((struct us_poll_t *) s));
/* Link this socket to the close-list and let it be deleted after this iteration */
s->next = s->context->loop->data.closed_head;
s->context->loop->data.closed_head = s;
/* Any socket with prev = context is marked as closed */
s->prev = (struct us_socket_t *) s->context;
//return s->context->on_close(s, code, reason);
}
return s;
}
/* Same as above but emits on_close */
struct us_socket_t *us_socket_close(int ssl, struct us_socket_t *s, int code, void *reason) {
if (!us_socket_is_closed(0, s)) {
us_internal_socket_context_unlink(s->context, s);
us_poll_stop((struct us_poll_t *) s, s->context->loop);
bsd_close_socket(us_poll_fd((struct us_poll_t *) s));
/* Link this socket to the close-list and let it be deleted after this iteration */
s->next = s->context->loop->data.closed_head;
s->context->loop->data.closed_head = s;
/* Any socket with prev = context is marked as closed */
s->prev = (struct us_socket_t *) s->context;
return s->context->on_close(s, code, reason);
}
return s;
}
/* Not shared with SSL */
void *us_socket_get_native_handle(int ssl, struct us_socket_t *s) {
#ifndef LIBUS_NO_SSL
if (ssl) {
return us_internal_ssl_socket_get_native_handle((struct us_internal_ssl_socket_t *) s);
}
#endif
return (void *) (uintptr_t) us_poll_fd((struct us_poll_t *) s);
}
int us_socket_write(int ssl, struct us_socket_t *s, const char *data, int length, int msg_more) {
#ifndef LIBUS_NO_SSL
if (ssl) {
@ -86,30 +153,6 @@ void *us_socket_ext(int ssl, struct us_socket_t *s) {
return s + 1;
}
struct us_socket_t *us_socket_close(int ssl, struct us_socket_t *s) {
#ifndef LIBUS_NO_SSL
if (ssl) {
return (struct us_socket_t *) us_internal_ssl_socket_close((struct us_internal_ssl_socket_t *) s);
}
#endif
if (!us_socket_is_closed(0, s)) {
us_internal_socket_context_unlink(s->context, s);
us_poll_stop((struct us_poll_t *) s, s->context->loop);
bsd_close_socket(us_poll_fd((struct us_poll_t *) s));
/* Link this socket to the close-list and let it be deleted after this iteration */
s->next = s->context->loop->data.closed_head;
s->context->loop->data.closed_head = s;
/* Any socket with prev = context is marked as closed */
s->prev = (struct us_socket_t *) s->context;
return s->context->on_close(s);
}
return s;
}
int us_socket_is_shut_down(int ssl, struct us_socket_t *s) {
#ifndef LIBUS_NO_SSL
if (ssl) {

View File

@ -1,3 +1,5 @@
#pragma warning(disable : 4267 4138)
#include "BrokenithmServer.hpp"
#include <thread>
@ -53,11 +55,12 @@ uint64_t BrokenithmServer::get_controller_state()
struct ConnectionData
{
typedef uWS::WebSocket<false, true, ConnectionData> ConnectionDataSocket;
static int s_connection_counter;
static std::vector<ConnectionData *> s_connections;
int m_uid;
uWS::WebSocket<false, true> *m_websocket;
ConnectionDataSocket *m_websocket;
void static close_all_connections();
@ -75,7 +78,7 @@ struct ConnectionData
s_connections[m_uid] = nullptr;
}
void save_socket(uWS::WebSocket<false, true> *websocket)
void save_socket(ConnectionDataSocket *websocket)
{
m_websocket = websocket;
}
@ -149,12 +152,17 @@ void BrokenithmServer::Impl::start_server()
})
.ws<ConnectionData>(
"/ws",
{uWS::DISABLED,
16 * 1024 * 1024,
10,
16 * 1024 * 1024,
{uWS::DISABLED, // compression
16 * 1024 * 1024, // maxPayloadLength
16, // idleTimeout
16 * 1024 * 1024, // maxBackpressure
false, // closeOnBackpressureLimit
false, // resetIdleTimeoutOnSend
true, // sendPingsAutomatically
0, // maxLifetime
nullptr, // upgrade
// Open handler
[](auto *ws, auto *req) {
[](auto *ws) {
spdlog::info("Controller ID {} connected", ((ConnectionData *)ws->getUserData())->m_uid);
((ConnectionData *)ws->getUserData())->save_socket(ws);
},
@ -181,26 +189,21 @@ void BrokenithmServer::Impl::start_server()
}
}
},
// Drain handler
[](auto *ws) {},
// Ping handler
[](auto *ws) {},
// Pong handler
[](auto *ws) {},
nullptr, // Drain handler
nullptr, // Ping handler
nullptr, // Pong handler
// Close handler
[](auto *ws, int code, std::string_view message) {
spdlog::info("Controller ID {} disconnected", ((ConnectionData *)ws->getUserData())->m_uid);
}})
.listen(
m_port,
[&](auto *token) {
if (token)
{
spdlog::info("Server listening at port {}", m_port);
m_running = true;
m_uws_socket_token = token;
}
})
.listen(m_port, [&](auto *token) {
if (token)
{
spdlog::info("Server listening at port {}", m_port);
m_running = true;
m_uws_socket_token = token;
}
})
.run();
m_running = false;