mirror of https://github.com/4yn/brokenithm-kb
update uwebsockets version
parent
5531880481
commit
949f968667
|
@ -1,6 +1,6 @@
|
|||
# Repackaged uWebsockets library
|
||||
|
||||
Source was obtained from [uSockets](https://github.com/uNetworking/uSockets) and [uWebSockets](https://github.com/uNetworking/uWebSockets) and repackaged with a build system.
|
||||
Source was obtained from [uSockets v0.7.1](https://github.com/uNetworking/uSockets/tree/v0.7.1) and [uWebSockets v19.2.0](https://github.com/uNetworking/uWebSockets/tree/v19.2.0) and repackaged with build system configuration.
|
||||
|
||||
# Original uWebsockets README.md
|
||||
|
||||
|
|
|
@ -29,13 +29,13 @@
|
|||
|
||||
/* Define what a socket descriptor is based on platform */
|
||||
#ifdef _WIN32
|
||||
#ifndef NOMINMAX
|
||||
#define NOMINMAX
|
||||
#include <WinSock2.h>
|
||||
#endif
|
||||
#include <winsock2.h>
|
||||
#define LIBUS_SOCKET_DESCRIPTOR SOCKET
|
||||
#define WIN32_EXPORT __declspec(dllexport)
|
||||
#define alignas(x) __declspec(align(x))
|
||||
#else
|
||||
#include <stdalign.h>
|
||||
#define LIBUS_SOCKET_DESCRIPTOR int
|
||||
#define WIN32_EXPORT
|
||||
#endif
|
||||
|
@ -84,9 +84,20 @@ struct us_socket_context_options_t {
|
|||
const char *passphrase;
|
||||
const char *dh_params_file_name;
|
||||
const char *ca_file_name;
|
||||
int ssl_prefer_low_memory_usage;
|
||||
int ssl_prefer_low_memory_usage; /* Todo: rename to prefer_low_memory_usage and apply for TCP as well */
|
||||
};
|
||||
|
||||
/* Return 15-bit timestamp for this context */
|
||||
WIN32_EXPORT unsigned short us_socket_context_timestamp(int ssl, struct us_socket_context_t *context);
|
||||
|
||||
/* Adds SNI domain and cert in asn1 format */
|
||||
WIN32_EXPORT void us_socket_context_add_server_name(int ssl, struct us_socket_context_t *context, const char *hostname_pattern, struct us_socket_context_options_t options);
|
||||
WIN32_EXPORT void us_socket_context_remove_server_name(int ssl, struct us_socket_context_t *context, const char *hostname_pattern);
|
||||
WIN32_EXPORT void us_socket_context_on_server_name(int ssl, struct us_socket_context_t *context, void (*cb)(struct us_socket_context_t *, const char *hostname));
|
||||
|
||||
/* Returns the underlying SSL native handle, such as SSL_CTX or nullptr */
|
||||
WIN32_EXPORT void *us_socket_context_get_native_handle(int ssl, struct us_socket_context_t *context);
|
||||
|
||||
/* A socket context holds shared callbacks and user data extension for associated sockets */
|
||||
WIN32_EXPORT struct us_socket_context_t *us_create_socket_context(int ssl, struct us_loop_t *loop,
|
||||
int ext_size, struct us_socket_context_options_t options);
|
||||
|
@ -98,13 +109,16 @@ WIN32_EXPORT void us_socket_context_free(int ssl, struct us_socket_context_t *co
|
|||
WIN32_EXPORT void us_socket_context_on_open(int ssl, struct us_socket_context_t *context,
|
||||
struct us_socket_t *(*on_open)(struct us_socket_t *s, int is_client, char *ip, int ip_length));
|
||||
WIN32_EXPORT void us_socket_context_on_close(int ssl, struct us_socket_context_t *context,
|
||||
struct us_socket_t *(*on_close)(struct us_socket_t *s));
|
||||
struct us_socket_t *(*on_close)(struct us_socket_t *s, int code, void *reason));
|
||||
WIN32_EXPORT void us_socket_context_on_data(int ssl, struct us_socket_context_t *context,
|
||||
struct us_socket_t *(*on_data)(struct us_socket_t *s, char *data, int length));
|
||||
WIN32_EXPORT void us_socket_context_on_writable(int ssl, struct us_socket_context_t *context,
|
||||
struct us_socket_t *(*on_writable)(struct us_socket_t *s));
|
||||
WIN32_EXPORT void us_socket_context_on_timeout(int ssl, struct us_socket_context_t *context,
|
||||
struct us_socket_t *(*on_timeout)(struct us_socket_t *s));
|
||||
/* This one is only used for when a connecting socket fails in a late stage. */
|
||||
WIN32_EXPORT void us_socket_context_on_connect_error(int ssl, struct us_socket_context_t *context,
|
||||
struct us_socket_t *(*on_connect_error)(struct us_socket_t *s, int code));
|
||||
|
||||
/* Emitted when a socket has been half-closed */
|
||||
WIN32_EXPORT void us_socket_context_on_end(int ssl, struct us_socket_context_t *context, struct us_socket_t *(*on_end)(struct us_socket_t *s));
|
||||
|
@ -119,9 +133,18 @@ WIN32_EXPORT struct us_listen_socket_t *us_socket_context_listen(int ssl, struct
|
|||
/* listen_socket.c/.h */
|
||||
WIN32_EXPORT void us_listen_socket_close(int ssl, struct us_listen_socket_t *ls);
|
||||
|
||||
/* Land in on_open or on_close or return null or return socket */
|
||||
/* Land in on_open or on_connection_error or return null or return socket */
|
||||
WIN32_EXPORT struct us_socket_t *us_socket_context_connect(int ssl, struct us_socket_context_t *context,
|
||||
const char *host, int port, int options, int socket_ext_size);
|
||||
const char *host, int port, const char *source_host, int options, int socket_ext_size);
|
||||
|
||||
/* Is this socket established? Can be used to check if a connecting socket has fired the on_open event yet.
|
||||
* Can also be used to determine if a socket is a listen_socket or not, but you probably know that already. */
|
||||
WIN32_EXPORT int us_socket_is_established(int ssl, struct us_socket_t *s);
|
||||
|
||||
/* Cancel a connecting socket. Can be used together with us_socket_timeout to limit connection times.
|
||||
* Entirely destroys the socket - this function works like us_socket_close but does not trigger on_close event since
|
||||
* you never got the on_open event first. */
|
||||
WIN32_EXPORT struct us_socket_t *us_socket_close_connecting(int ssl, struct us_socket_t *s);
|
||||
|
||||
/* Returns the loop for this socket context. */
|
||||
WIN32_EXPORT struct us_loop_t *us_socket_context_loop(int ssl, struct us_socket_context_t *context);
|
||||
|
@ -189,6 +212,10 @@ WIN32_EXPORT struct us_poll_t *us_poll_resize(struct us_poll_t *p, struct us_loo
|
|||
|
||||
/* Public interfaces for sockets */
|
||||
|
||||
/* Returns the underlying native handle for a socket, such as SSL or file descriptor.
|
||||
* In the case of file descriptor, the value of pointer is fd. */
|
||||
WIN32_EXPORT void *us_socket_get_native_handle(int ssl, struct us_socket_t *s);
|
||||
|
||||
/* Write up to length bytes of data. Returns actual bytes written.
|
||||
* Will call the on_writable callback of active socket context on failure to write everything off in one go.
|
||||
* Set hint msg_more if you have more immediate data to write. */
|
||||
|
@ -210,6 +237,11 @@ WIN32_EXPORT void us_socket_flush(int ssl, struct us_socket_t *s);
|
|||
/* Shuts down the connection by sending FIN and/or close_notify */
|
||||
WIN32_EXPORT void us_socket_shutdown(int ssl, struct us_socket_t *s);
|
||||
|
||||
/* Shuts down the connection in terms of read, meaning next event loop
|
||||
* iteration will catch the socket being closed. Can be used to defer closing
|
||||
* to next event loop iteration. */
|
||||
WIN32_EXPORT void us_socket_shutdown_read(int ssl, struct us_socket_t *s);
|
||||
|
||||
/* Returns whether the socket has been shut down or not */
|
||||
WIN32_EXPORT int us_socket_is_shut_down(int ssl, struct us_socket_t *s);
|
||||
|
||||
|
@ -217,7 +249,10 @@ WIN32_EXPORT int us_socket_is_shut_down(int ssl, struct us_socket_t *s);
|
|||
WIN32_EXPORT int us_socket_is_closed(int ssl, struct us_socket_t *s);
|
||||
|
||||
/* Immediately closes the socket */
|
||||
WIN32_EXPORT struct us_socket_t *us_socket_close(int ssl, struct us_socket_t *s);
|
||||
WIN32_EXPORT struct us_socket_t *us_socket_close(int ssl, struct us_socket_t *s, int code, void *reason);
|
||||
|
||||
/* Returns local port or -1 on failure. */
|
||||
WIN32_EXPORT int us_socket_local_port(int ssl, struct us_socket_t *s);
|
||||
|
||||
/* Copy remote (IP) address of socket, or fail with zero length. */
|
||||
WIN32_EXPORT void us_socket_remote_address(int ssl, struct us_socket_t *s, char *buf, int *length);
|
||||
|
@ -230,7 +265,7 @@ WIN32_EXPORT void us_socket_remote_address(int ssl, struct us_socket_t *s, char
|
|||
#if !defined(LIBUS_USE_EPOLL) && !defined(LIBUS_USE_LIBUV) && !defined(LIBUS_USE_GCD) && !defined(LIBUS_USE_KQUEUE)
|
||||
#if defined(_WIN32)
|
||||
#define LIBUS_USE_LIBUV
|
||||
#elif defined(__APPLE__)
|
||||
#elif defined(__APPLE__) || defined(__FreeBSD__)
|
||||
#define LIBUS_USE_KQUEUE
|
||||
#else
|
||||
#define LIBUS_USE_EPOLL
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2019.
|
||||
* Authored by Alex Hultman, 2018-2020.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -25,42 +25,104 @@
|
|||
#include "HttpResponse.h"
|
||||
#include "WebSocketContext.h"
|
||||
#include "WebSocket.h"
|
||||
#include "WebSocketExtensions.h"
|
||||
#include "WebSocketHandshake.h"
|
||||
#include "PerMessageDeflate.h"
|
||||
|
||||
namespace uWS {
|
||||
|
||||
/* Compress options (really more like PerMessageDeflateOptions) */
|
||||
enum CompressOptions {
|
||||
/* Compression disabled */
|
||||
DISABLED = 0,
|
||||
/* We compress using a shared non-sliding window. No added memory usage, worse compression. */
|
||||
SHARED_COMPRESSOR = 1,
|
||||
/* We compress using a dedicated sliding window. Major memory usage added, better compression of similarly repeated messages. */
|
||||
DEDICATED_COMPRESSOR = 2
|
||||
};
|
||||
/* This one matches us_socket_context_options_t but has default values */
|
||||
struct SocketContextOptions {
|
||||
const char *key_file_name = nullptr;
|
||||
const char *cert_file_name = nullptr;
|
||||
const char *passphrase = nullptr;
|
||||
const char *dh_params_file_name = nullptr;
|
||||
const char *ca_file_name = nullptr;
|
||||
int ssl_prefer_low_memory_usage = 0;
|
||||
|
||||
/* Conversion operator used internally */
|
||||
operator struct us_socket_context_options_t() const {
|
||||
struct us_socket_context_options_t socket_context_options;
|
||||
memcpy(&socket_context_options, this, sizeof(SocketContextOptions));
|
||||
return socket_context_options;
|
||||
}
|
||||
};
|
||||
|
||||
static_assert(sizeof(struct us_socket_context_options_t) == sizeof(SocketContextOptions), "Mismatching uSockets/uWebSockets ABI");
|
||||
|
||||
template <bool SSL>
|
||||
struct TemplatedApp {
|
||||
private:
|
||||
/* The app always owns at least one http context, but creates websocket contexts on demand */
|
||||
HttpContext<SSL> *httpContext;
|
||||
std::vector<WebSocketContext<SSL, true> *> webSocketContexts;
|
||||
std::vector<WebSocketContext<SSL, true, int> *> webSocketContexts;
|
||||
|
||||
public:
|
||||
|
||||
/* Server name */
|
||||
TemplatedApp &&addServerName(std::string hostname_pattern, SocketContextOptions options = {}) {
|
||||
|
||||
us_socket_context_add_server_name(SSL, (struct us_socket_context_t *) httpContext, hostname_pattern.c_str(), options);
|
||||
return std::move(*this);
|
||||
}
|
||||
|
||||
TemplatedApp &&removeServerName(std::string hostname_pattern) {
|
||||
|
||||
us_socket_context_remove_server_name(SSL, (struct us_socket_context_t *) httpContext, hostname_pattern.c_str());
|
||||
return std::move(*this);
|
||||
}
|
||||
|
||||
TemplatedApp &&missingServerName(MoveOnlyFunction<void(const char *hostname)> handler) {
|
||||
|
||||
if (!constructorFailed()) {
|
||||
httpContext->getSocketContextData()->missingServerNameHandler = std::move(handler);
|
||||
|
||||
us_socket_context_on_server_name(SSL, (struct us_socket_context_t *) httpContext, [](struct us_socket_context_t *context, const char *hostname) {
|
||||
|
||||
/* This is the only requirements of being friends with HttpContextData */
|
||||
HttpContext<SSL> *httpContext = (HttpContext<SSL> *) context;
|
||||
httpContext->getSocketContextData()->missingServerNameHandler(hostname);
|
||||
});
|
||||
}
|
||||
|
||||
return std::move(*this);
|
||||
}
|
||||
|
||||
/* Returns the SSL_CTX of this app, or nullptr. */
|
||||
void *getNativeHandle() {
|
||||
return us_socket_context_get_native_handle(SSL, (struct us_socket_context_t *) httpContext);
|
||||
}
|
||||
|
||||
/* Attaches a "filter" function to track socket connections/disconnections */
|
||||
void filter(fu2::unique_function<void(HttpResponse<SSL> *, int)> &&filterHandler) {
|
||||
void filter(MoveOnlyFunction<void(HttpResponse<SSL> *, int)> &&filterHandler) {
|
||||
httpContext->filter(std::move(filterHandler));
|
||||
}
|
||||
|
||||
/* Publishes a message to all websocket contexts */
|
||||
/* Publishes a message to all websocket contexts - conceptually as if publishing to the one single
|
||||
* TopicTree of this app (technically there are many TopicTrees, however the concept is that one
|
||||
* app has one conceptual Topic tree) */
|
||||
void publish(std::string_view topic, std::string_view message, OpCode opCode, bool compress = false) {
|
||||
for (auto *webSocketContext : webSocketContexts) {
|
||||
webSocketContext->getExt()->publish(topic, message, opCode, compress);
|
||||
}
|
||||
}
|
||||
|
||||
/* Returns number of subscribers for this topic, or 0 for failure.
|
||||
* This function should probably be optimized a lot in future releases,
|
||||
* it could be O(1) with a hash map of fullnames and their counts. */
|
||||
unsigned int numSubscribers(std::string_view topic) {
|
||||
unsigned int subscribers = 0;
|
||||
|
||||
for (auto *webSocketContext : webSocketContexts) {
|
||||
auto *webSocketContextData = webSocketContext->getExt();
|
||||
|
||||
Topic *t = webSocketContextData->topicTree.lookupTopic(topic);
|
||||
if (t) {
|
||||
subscribers += t->subs.size();
|
||||
}
|
||||
}
|
||||
|
||||
return subscribers;
|
||||
}
|
||||
|
||||
~TemplatedApp() {
|
||||
/* Let's just put everything here */
|
||||
if (httpContext) {
|
||||
|
@ -84,42 +146,79 @@ public:
|
|||
webSocketContexts = std::move(other.webSocketContexts);
|
||||
}
|
||||
|
||||
TemplatedApp(us_socket_context_options_t options = {}) {
|
||||
httpContext = uWS::HttpContext<SSL>::create(uWS::Loop::get(), options);
|
||||
TemplatedApp(SocketContextOptions options = {}) {
|
||||
httpContext = HttpContext<SSL>::create(Loop::get(), options);
|
||||
}
|
||||
|
||||
bool constructorFailed() {
|
||||
return !httpContext;
|
||||
}
|
||||
|
||||
template <typename UserData>
|
||||
struct WebSocketBehavior {
|
||||
/* Disabled compression by default - probably a bad default */
|
||||
CompressOptions compression = DISABLED;
|
||||
int maxPayloadLength = 16 * 1024;
|
||||
int idleTimeout = 120;
|
||||
int maxBackpressure = 1 * 1024 * 1204;
|
||||
fu2::unique_function<void(uWS::WebSocket<SSL, true> *, HttpRequest *)> open = nullptr;
|
||||
fu2::unique_function<void(uWS::WebSocket<SSL, true> *, std::string_view, uWS::OpCode)> message = nullptr;
|
||||
fu2::unique_function<void(uWS::WebSocket<SSL, true> *)> drain = nullptr;
|
||||
fu2::unique_function<void(uWS::WebSocket<SSL, true> *)> ping = nullptr;
|
||||
fu2::unique_function<void(uWS::WebSocket<SSL, true> *)> pong = nullptr;
|
||||
fu2::unique_function<void(uWS::WebSocket<SSL, true> *, int, std::string_view)> close = nullptr;
|
||||
/* Maximum message size we can receive */
|
||||
unsigned int maxPayloadLength = 16 * 1024;
|
||||
/* 2 minutes timeout is good */
|
||||
unsigned short idleTimeout = 120;
|
||||
/* 64kb backpressure is probably good */
|
||||
unsigned int maxBackpressure = 64 * 1024;
|
||||
bool closeOnBackpressureLimit = false;
|
||||
/* This one depends on kernel timeouts and is a bad default */
|
||||
bool resetIdleTimeoutOnSend = false;
|
||||
/* A good default, esp. for newcomers */
|
||||
bool sendPingsAutomatically = true;
|
||||
/* Maximum socket lifetime in seconds before forced closure (defaults to disabled) */
|
||||
unsigned short maxLifetime = 0;
|
||||
MoveOnlyFunction<void(HttpResponse<SSL> *, HttpRequest *, struct us_socket_context_t *)> upgrade = nullptr;
|
||||
MoveOnlyFunction<void(WebSocket<SSL, true, UserData> *)> open = nullptr;
|
||||
MoveOnlyFunction<void(WebSocket<SSL, true, UserData> *, std::string_view, OpCode)> message = nullptr;
|
||||
MoveOnlyFunction<void(WebSocket<SSL, true, UserData> *)> drain = nullptr;
|
||||
MoveOnlyFunction<void(WebSocket<SSL, true, UserData> *, std::string_view)> ping = nullptr;
|
||||
MoveOnlyFunction<void(WebSocket<SSL, true, UserData> *, std::string_view)> pong = nullptr;
|
||||
MoveOnlyFunction<void(WebSocket<SSL, true, UserData> *, int, std::string_view)> close = nullptr;
|
||||
};
|
||||
|
||||
template <typename UserData>
|
||||
TemplatedApp &&ws(std::string pattern, WebSocketBehavior &&behavior) {
|
||||
TemplatedApp &&ws(std::string pattern, WebSocketBehavior<UserData> &&behavior) {
|
||||
/* Don't compile if alignment rules cannot be satisfied */
|
||||
static_assert(alignof(UserData) <= LIBUS_EXT_ALIGNMENT,
|
||||
"µWebSockets cannot satisfy UserData alignment requirements. You need to recompile µSockets with LIBUS_EXT_ALIGNMENT adjusted accordingly.");
|
||||
|
||||
if (!httpContext) {
|
||||
return std::move(*this);
|
||||
}
|
||||
|
||||
/* Terminate on misleading idleTimeout values */
|
||||
if (behavior.idleTimeout && behavior.idleTimeout < 8) {
|
||||
std::cerr << "Error: idleTimeout must be either 0 or greater than 8!" << std::endl;
|
||||
std::terminate();
|
||||
}
|
||||
|
||||
if (behavior.idleTimeout % 4) {
|
||||
std::cerr << "Warning: idleTimeout should be a multiple of 4!" << std::endl;
|
||||
}
|
||||
|
||||
/* Every route has its own websocket context with its own behavior and user data type */
|
||||
auto *webSocketContext = WebSocketContext<SSL, true>::create(Loop::get(), (us_socket_context_t *) httpContext);
|
||||
auto *webSocketContext = WebSocketContext<SSL, true, UserData>::create(Loop::get(), (us_socket_context_t *) httpContext);
|
||||
|
||||
/* Add all other WebSocketContextData to this new WebSocketContextData */
|
||||
for (WebSocketContext<SSL, true, int> *adjacentWebSocketContext : webSocketContexts) {
|
||||
webSocketContext->getExt()->adjacentWebSocketContextDatas.push_back(adjacentWebSocketContext->getExt());
|
||||
}
|
||||
|
||||
/* Add this WebSocketContextData to all other WebSocketContextData */
|
||||
for (WebSocketContext<SSL, true, int> *adjacentWebSocketContext : webSocketContexts) {
|
||||
adjacentWebSocketContext->getExt()->adjacentWebSocketContextDatas.push_back((WebSocketContextData<SSL, int> *) webSocketContext->getExt());
|
||||
}
|
||||
|
||||
/* We need to clear this later on */
|
||||
webSocketContexts.push_back(webSocketContext);
|
||||
webSocketContexts.push_back((WebSocketContext<SSL, true, int> *) webSocketContext);
|
||||
|
||||
/* Quick fix to disable any compression if set */
|
||||
#ifdef UWS_NO_ZLIB
|
||||
behavior.compression = uWS::DISABLED;
|
||||
behavior.compression = DISABLED;
|
||||
#endif
|
||||
|
||||
/* If we are the first one to use compression, initialize it */
|
||||
|
@ -130,14 +229,15 @@ public:
|
|||
if (!loopData->zlibContext) {
|
||||
loopData->zlibContext = new ZlibContext;
|
||||
loopData->inflationStream = new InflationStream;
|
||||
loopData->deflationStream = new DeflationStream;
|
||||
loopData->deflationStream = new DeflationStream(CompressOptions::DEDICATED_COMPRESSOR);
|
||||
}
|
||||
}
|
||||
|
||||
/* Copy all handlers */
|
||||
webSocketContext->getExt()->openHandler = std::move(behavior.open);
|
||||
webSocketContext->getExt()->messageHandler = std::move(behavior.message);
|
||||
webSocketContext->getExt()->drainHandler = std::move(behavior.drain);
|
||||
webSocketContext->getExt()->closeHandler = std::move([closeHandler = std::move(behavior.close)](WebSocket<SSL, true> *ws, int code, std::string_view message) mutable {
|
||||
webSocketContext->getExt()->closeHandler = std::move([closeHandler = std::move(behavior.close)](WebSocket<SSL, true, UserData> *ws, int code, std::string_view message) mutable {
|
||||
if (closeHandler) {
|
||||
closeHandler(ws, code, message);
|
||||
}
|
||||
|
@ -150,99 +250,30 @@ public:
|
|||
|
||||
/* Copy settings */
|
||||
webSocketContext->getExt()->maxPayloadLength = behavior.maxPayloadLength;
|
||||
webSocketContext->getExt()->idleTimeout = behavior.idleTimeout;
|
||||
webSocketContext->getExt()->maxBackpressure = behavior.maxBackpressure;
|
||||
webSocketContext->getExt()->closeOnBackpressureLimit = behavior.closeOnBackpressureLimit;
|
||||
webSocketContext->getExt()->resetIdleTimeoutOnSend = behavior.resetIdleTimeoutOnSend;
|
||||
webSocketContext->getExt()->sendPingsAutomatically = behavior.sendPingsAutomatically;
|
||||
webSocketContext->getExt()->compression = behavior.compression;
|
||||
|
||||
httpContext->onHttp("get", pattern, [webSocketContext, httpContext = this->httpContext, behavior = std::move(behavior)](auto *res, auto *req) mutable {
|
||||
/* Calculate idleTimeoutCompnents */
|
||||
webSocketContext->getExt()->calculateIdleTimeoutCompnents(behavior.idleTimeout);
|
||||
|
||||
httpContext->onHttp("get", pattern, [webSocketContext, behavior = std::move(behavior)](auto *res, auto *req) mutable {
|
||||
|
||||
/* If we have this header set, it's a websocket */
|
||||
std::string_view secWebSocketKey = req->getHeader("sec-websocket-key");
|
||||
if (secWebSocketKey.length() == 24) {
|
||||
/* Note: OpenSSL can be used here to speed this up somewhat */
|
||||
char secWebSocketAccept[29] = {};
|
||||
WebSocketHandshake::generate(secWebSocketKey.data(), secWebSocketAccept);
|
||||
|
||||
res->writeStatus("101 Switching Protocols")
|
||||
->writeHeader("Upgrade", "websocket")
|
||||
->writeHeader("Connection", "Upgrade")
|
||||
->writeHeader("Sec-WebSocket-Accept", secWebSocketAccept);
|
||||
/* Emit upgrade handler */
|
||||
if (behavior.upgrade) {
|
||||
behavior.upgrade(res, req, (struct us_socket_context_t *) webSocketContext);
|
||||
} else {
|
||||
/* Default handler upgrades to WebSocket */
|
||||
std::string_view secWebSocketProtocol = req->getHeader("sec-websocket-protocol");
|
||||
std::string_view secWebSocketExtensions = req->getHeader("sec-websocket-extensions");
|
||||
|
||||
/* Select first subprotocol if present */
|
||||
std::string_view secWebSocketProtocol = req->getHeader("sec-websocket-protocol");
|
||||
if (secWebSocketProtocol.length()) {
|
||||
res->writeHeader("Sec-WebSocket-Protocol", secWebSocketProtocol.substr(0, secWebSocketProtocol.find(',')));
|
||||
}
|
||||
|
||||
/* Negotiate compression */
|
||||
bool perMessageDeflate = false;
|
||||
bool slidingDeflateWindow = false;
|
||||
if (behavior.compression != DISABLED) {
|
||||
std::string_view extensions = req->getHeader("sec-websocket-extensions");
|
||||
if (extensions.length()) {
|
||||
/* We never support client context takeover (the client cannot compress with a sliding window). */
|
||||
int wantedOptions = PERMESSAGE_DEFLATE | CLIENT_NO_CONTEXT_TAKEOVER;
|
||||
|
||||
/* Shared compressor is the default */
|
||||
if (behavior.compression == SHARED_COMPRESSOR) {
|
||||
/* Disable per-socket compressor */
|
||||
wantedOptions |= SERVER_NO_CONTEXT_TAKEOVER;
|
||||
}
|
||||
|
||||
/* isServer = true */
|
||||
ExtensionsNegotiator<true> extensionsNegotiator(wantedOptions);
|
||||
extensionsNegotiator.readOffer(extensions);
|
||||
|
||||
/* Todo: remove these mid string copies */
|
||||
std::string offer = extensionsNegotiator.generateOffer();
|
||||
if (offer.length()) {
|
||||
res->writeHeader("Sec-WebSocket-Extensions", offer);
|
||||
}
|
||||
|
||||
/* Did we negotiate permessage-deflate? */
|
||||
if (extensionsNegotiator.getNegotiatedOptions() & PERMESSAGE_DEFLATE) {
|
||||
perMessageDeflate = true;
|
||||
}
|
||||
|
||||
/* Is the server allowed to compress with a sliding window? */
|
||||
if (!(extensionsNegotiator.getNegotiatedOptions() & SERVER_NO_CONTEXT_TAKEOVER)) {
|
||||
slidingDeflateWindow = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* This will add our mark */
|
||||
res->upgrade();
|
||||
|
||||
/* Move any backpressure */
|
||||
std::string backpressure(std::move(((AsyncSocketData<SSL> *) res->getHttpResponseData())->buffer));
|
||||
|
||||
/* Keep any fallback buffer alive until we returned from open event, keeping req valid */
|
||||
std::string fallback(std::move(res->getHttpResponseData()->salvageFallbackBuffer()));
|
||||
|
||||
/* Destroy HttpResponseData */
|
||||
res->getHttpResponseData()->~HttpResponseData();
|
||||
|
||||
/* Adopting a socket invalidates it, do not rely on it directly to carry any data */
|
||||
WebSocket<SSL, true> *webSocket = (WebSocket<SSL, true> *) us_socket_context_adopt_socket(SSL,
|
||||
(us_socket_context_t *) webSocketContext, (us_socket_t *) res, sizeof(WebSocketData) + sizeof(UserData));
|
||||
|
||||
/* Update corked socket in case we got a new one (assuming we always are corked in handlers). */
|
||||
webSocket->AsyncSocket<SSL>::cork();
|
||||
|
||||
/* Initialize websocket with any moved backpressure intact */
|
||||
httpContext->upgradeToWebSocket(
|
||||
webSocket->init(perMessageDeflate, slidingDeflateWindow, std::move(backpressure))
|
||||
);
|
||||
|
||||
/* Arm idleTimeout */
|
||||
us_socket_timeout(SSL, (us_socket_t *) webSocket, behavior.idleTimeout);
|
||||
|
||||
/* Default construct the UserData right before calling open handler */
|
||||
new (webSocket->getUserData()) UserData;
|
||||
|
||||
/* Emit open event and start the timeout */
|
||||
if (behavior.open) {
|
||||
behavior.open(webSocket, req);
|
||||
res->template upgrade<UserData>({}, secWebSocketKey, secWebSocketProtocol, secWebSocketExtensions, (struct us_socket_context_t *) webSocketContext);
|
||||
}
|
||||
|
||||
/* We are going to get uncorked by the Http get return */
|
||||
|
@ -257,84 +288,104 @@ public:
|
|||
return std::move(*this);
|
||||
}
|
||||
|
||||
TemplatedApp &&get(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
|
||||
httpContext->onHttp("get", pattern, std::move(handler));
|
||||
TemplatedApp &&get(std::string pattern, MoveOnlyFunction<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
|
||||
if (httpContext) {
|
||||
httpContext->onHttp("get", pattern, std::move(handler));
|
||||
}
|
||||
return std::move(*this);
|
||||
}
|
||||
|
||||
TemplatedApp &&post(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
|
||||
httpContext->onHttp("post", pattern, std::move(handler));
|
||||
TemplatedApp &&post(std::string pattern, MoveOnlyFunction<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
|
||||
if (httpContext) {
|
||||
httpContext->onHttp("post", pattern, std::move(handler));
|
||||
}
|
||||
return std::move(*this);
|
||||
}
|
||||
|
||||
TemplatedApp &&options(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
|
||||
httpContext->onHttp("options", pattern, std::move(handler));
|
||||
TemplatedApp &&options(std::string pattern, MoveOnlyFunction<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
|
||||
if (httpContext) {
|
||||
httpContext->onHttp("options", pattern, std::move(handler));
|
||||
}
|
||||
return std::move(*this);
|
||||
}
|
||||
|
||||
TemplatedApp &&del(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
|
||||
httpContext->onHttp("delete", pattern, std::move(handler));
|
||||
TemplatedApp &&del(std::string pattern, MoveOnlyFunction<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
|
||||
if (httpContext) {
|
||||
httpContext->onHttp("delete", pattern, std::move(handler));
|
||||
}
|
||||
return std::move(*this);
|
||||
}
|
||||
|
||||
TemplatedApp &&patch(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
|
||||
httpContext->onHttp("patch", pattern, std::move(handler));
|
||||
TemplatedApp &&patch(std::string pattern, MoveOnlyFunction<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
|
||||
if (httpContext) {
|
||||
httpContext->onHttp("patch", pattern, std::move(handler));
|
||||
}
|
||||
return std::move(*this);
|
||||
}
|
||||
|
||||
TemplatedApp &&put(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
|
||||
httpContext->onHttp("put", pattern, std::move(handler));
|
||||
TemplatedApp &&put(std::string pattern, MoveOnlyFunction<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
|
||||
if (httpContext) {
|
||||
httpContext->onHttp("put", pattern, std::move(handler));
|
||||
}
|
||||
return std::move(*this);
|
||||
}
|
||||
|
||||
TemplatedApp &&head(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
|
||||
httpContext->onHttp("head", pattern, std::move(handler));
|
||||
TemplatedApp &&head(std::string pattern, MoveOnlyFunction<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
|
||||
if (httpContext) {
|
||||
httpContext->onHttp("head", pattern, std::move(handler));
|
||||
}
|
||||
return std::move(*this);
|
||||
}
|
||||
|
||||
TemplatedApp &&connect(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
|
||||
httpContext->onHttp("connect", pattern, std::move(handler));
|
||||
TemplatedApp &&connect(std::string pattern, MoveOnlyFunction<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
|
||||
if (httpContext) {
|
||||
httpContext->onHttp("connect", pattern, std::move(handler));
|
||||
}
|
||||
return std::move(*this);
|
||||
}
|
||||
|
||||
TemplatedApp &&trace(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
|
||||
httpContext->onHttp("trace", pattern, std::move(handler));
|
||||
TemplatedApp &&trace(std::string pattern, MoveOnlyFunction<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
|
||||
if (httpContext) {
|
||||
httpContext->onHttp("trace", pattern, std::move(handler));
|
||||
}
|
||||
return std::move(*this);
|
||||
}
|
||||
|
||||
/* This one catches any method */
|
||||
TemplatedApp &&any(std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
|
||||
httpContext->onHttp("*", pattern, std::move(handler));
|
||||
TemplatedApp &&any(std::string pattern, MoveOnlyFunction<void(HttpResponse<SSL> *, HttpRequest *)> &&handler) {
|
||||
if (httpContext) {
|
||||
httpContext->onHttp("*", pattern, std::move(handler));
|
||||
}
|
||||
return std::move(*this);
|
||||
}
|
||||
|
||||
/* Host, port, callback */
|
||||
TemplatedApp &&listen(std::string host, int port, fu2::unique_function<void(us_listen_socket_t *)> &&handler) {
|
||||
TemplatedApp &&listen(std::string host, int port, MoveOnlyFunction<void(us_listen_socket_t *)> &&handler) {
|
||||
if (!host.length()) {
|
||||
return listen(port, std::move(handler));
|
||||
}
|
||||
handler(httpContext->listen(host.c_str(), port, 0));
|
||||
handler(httpContext ? httpContext->listen(host.c_str(), port, 0) : nullptr);
|
||||
return std::move(*this);
|
||||
}
|
||||
|
||||
/* Host, port, options, callback */
|
||||
TemplatedApp &&listen(std::string host, int port, int options, fu2::unique_function<void(us_listen_socket_t *)> &&handler) {
|
||||
TemplatedApp &&listen(std::string host, int port, int options, MoveOnlyFunction<void(us_listen_socket_t *)> &&handler) {
|
||||
if (!host.length()) {
|
||||
return listen(port, options, std::move(handler));
|
||||
}
|
||||
handler(httpContext->listen(host.c_str(), port, options));
|
||||
handler(httpContext ? httpContext->listen(host.c_str(), port, options) : nullptr);
|
||||
return std::move(*this);
|
||||
}
|
||||
|
||||
/* Port, callback */
|
||||
TemplatedApp &&listen(int port, fu2::unique_function<void(us_listen_socket_t *)> &&handler) {
|
||||
handler(httpContext->listen(nullptr, port, 0));
|
||||
TemplatedApp &&listen(int port, MoveOnlyFunction<void(us_listen_socket_t *)> &&handler) {
|
||||
handler(httpContext ? httpContext->listen(nullptr, port, 0) : nullptr);
|
||||
return std::move(*this);
|
||||
}
|
||||
|
||||
/* Port, options, callback */
|
||||
TemplatedApp &&listen(int port, int options, fu2::unique_function<void(us_listen_socket_t *)> &&handler) {
|
||||
handler(httpContext->listen(nullptr, port, options));
|
||||
TemplatedApp &&listen(int port, int options, MoveOnlyFunction<void(us_listen_socket_t *)> &&handler) {
|
||||
handler(httpContext ? httpContext->listen(nullptr, port, options) : nullptr);
|
||||
return std::move(*this);
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2019.
|
||||
* Authored by Alex Hultman, 2018-2020.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -20,21 +20,30 @@
|
|||
|
||||
/* This class implements async socket memory management strategies */
|
||||
|
||||
/* NOTE: Many unsigned/signed conversion warnings could be solved by moving from int length
|
||||
* to unsigned length for everything to/from uSockets - this would however remove the opportunity
|
||||
* to signal error with -1 (which is how the entire UNIX syscalling is built). */
|
||||
|
||||
#include "LoopData.h"
|
||||
#include "AsyncSocketData.h"
|
||||
|
||||
namespace uWS {
|
||||
|
||||
template <bool, bool> struct WebSocketContext;
|
||||
template <bool, bool, typename> struct WebSocketContext;
|
||||
|
||||
template <bool SSL>
|
||||
struct AsyncSocket {
|
||||
template <bool> friend struct HttpContext;
|
||||
template <bool, bool> friend struct WebSocketContext;
|
||||
template <bool> friend struct WebSocketContextData;
|
||||
template <bool, bool, typename> friend struct WebSocketContext;
|
||||
template <bool, typename> friend struct WebSocketContextData;
|
||||
friend struct TopicTree;
|
||||
|
||||
protected:
|
||||
/* Returns SSL pointer or FD as pointer */
|
||||
void *getNativeHandle() {
|
||||
return us_socket_get_native_handle(SSL, (us_socket_t *) this);
|
||||
}
|
||||
|
||||
/* Get loop data for socket */
|
||||
LoopData *getLoopData() {
|
||||
return (LoopData *) us_loop_ext(us_socket_context_loop(SSL, us_socket_context(SSL, (us_socket_t *) this)));
|
||||
|
@ -57,7 +66,7 @@ protected:
|
|||
|
||||
/* Immediately close socket */
|
||||
us_socket_t *close() {
|
||||
return us_socket_close(SSL, (us_socket_t *) this);
|
||||
return us_socket_close(SSL, (us_socket_t *) this, 0, nullptr);
|
||||
}
|
||||
|
||||
/* Cork this socket. Only one socket may ever be corked per-loop at any given time */
|
||||
|
@ -82,7 +91,7 @@ protected:
|
|||
LoopData *loopData = getLoopData();
|
||||
if (loopData->corkedSocket == this && loopData->corkOffset + size < LoopData::CORK_BUFFER_SIZE) {
|
||||
char *sendBuffer = loopData->corkBuffer + loopData->corkOffset;
|
||||
loopData->corkOffset += (int) size;
|
||||
loopData->corkOffset += (unsigned int) size;
|
||||
return {sendBuffer, false};
|
||||
} else {
|
||||
/* Slow path for now, we want to always be corked if possible */
|
||||
|
@ -91,8 +100,30 @@ protected:
|
|||
}
|
||||
|
||||
/* Returns the user space backpressure. */
|
||||
int getBufferedAmount() {
|
||||
return (int) getAsyncSocketData()->buffer.size();
|
||||
unsigned int getBufferedAmount() {
|
||||
return (unsigned int) getAsyncSocketData()->buffer.size();
|
||||
}
|
||||
|
||||
/* Returns the text representation of an IPv4 or IPv6 address */
|
||||
std::string_view addressAsText(std::string_view binary) {
|
||||
static thread_local char buf[64];
|
||||
int ipLength = 0;
|
||||
|
||||
if (!binary.length()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
unsigned char *b = (unsigned char *) binary.data();
|
||||
|
||||
if (binary.length() == 4) {
|
||||
ipLength = sprintf(buf, "%u.%u.%u.%u", b[0], b[1], b[2], b[3]);
|
||||
} else {
|
||||
ipLength = sprintf(buf, "%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x",
|
||||
b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11],
|
||||
b[12], b[13], b[14], b[15]);
|
||||
}
|
||||
|
||||
return {buf, (unsigned int) ipLength};
|
||||
}
|
||||
|
||||
/* Returns the remote IP address or empty string on failure */
|
||||
|
@ -100,7 +131,12 @@ protected:
|
|||
static thread_local char buf[16];
|
||||
int ipLength = 16;
|
||||
us_socket_remote_address(SSL, (us_socket_t *) this, buf, &ipLength);
|
||||
return std::string_view(buf, ipLength);
|
||||
return std::string_view(buf, (unsigned int) ipLength);
|
||||
}
|
||||
|
||||
/* Returns the text representation of IP */
|
||||
std::string_view getRemoteAddressAsText() {
|
||||
return addressAsText(getRemoteAddress());
|
||||
}
|
||||
|
||||
/* Write in three levels of prioritization: cork-buffer, syscall, socket-buffer. Always drain if possible.
|
||||
|
@ -124,14 +160,14 @@ protected:
|
|||
if ((unsigned int) written < asyncSocketData->buffer.length()) {
|
||||
|
||||
/* Update buffering (todo: we can do better here if we keep track of what happens to this guy later on) */
|
||||
asyncSocketData->buffer = asyncSocketData->buffer.substr(written);
|
||||
asyncSocketData->buffer = asyncSocketData->buffer.substr((size_t) written);
|
||||
|
||||
if (optionally) {
|
||||
/* Thankfully we can exit early here */
|
||||
return {0, true};
|
||||
} else {
|
||||
/* This path is horrible and points towards erroneous usage */
|
||||
asyncSocketData->buffer.append(src, length);
|
||||
asyncSocketData->buffer.append(src, (unsigned int) length);
|
||||
|
||||
return {length, true};
|
||||
}
|
||||
|
@ -144,21 +180,21 @@ protected:
|
|||
if (length) {
|
||||
if (loopData->corkedSocket == this) {
|
||||
/* We are corked */
|
||||
if (LoopData::CORK_BUFFER_SIZE - loopData->corkOffset >= length) {
|
||||
if (LoopData::CORK_BUFFER_SIZE - loopData->corkOffset >= (unsigned int) length) {
|
||||
/* If the entire chunk fits in cork buffer */
|
||||
memcpy(loopData->corkBuffer + loopData->corkOffset, src, length);
|
||||
loopData->corkOffset += length;
|
||||
memcpy(loopData->corkBuffer + loopData->corkOffset, src, (unsigned int) length);
|
||||
loopData->corkOffset += (unsigned int) length;
|
||||
/* Fall through to default return */
|
||||
} else {
|
||||
/* Strategy differences between SSL and non-SSL regarding syscall minimizing */
|
||||
if constexpr (SSL) {
|
||||
/* Cork up as much as we can */
|
||||
int stripped = LoopData::CORK_BUFFER_SIZE - loopData->corkOffset;
|
||||
unsigned int stripped = LoopData::CORK_BUFFER_SIZE - loopData->corkOffset;
|
||||
memcpy(loopData->corkBuffer + loopData->corkOffset, src, stripped);
|
||||
loopData->corkOffset = LoopData::CORK_BUFFER_SIZE;
|
||||
|
||||
auto [written, failed] = uncork(src + stripped, length - stripped, optionally);
|
||||
return {written + stripped, failed};
|
||||
auto [written, failed] = uncork(src + stripped, length - (int) stripped, optionally);
|
||||
return {written + (int) stripped, failed};
|
||||
}
|
||||
|
||||
/* For non-SSL we take the penalty of two syscalls */
|
||||
|
@ -178,11 +214,11 @@ protected:
|
|||
/* Fall back to worst possible case (should be very rare for HTTP) */
|
||||
/* At least we can reserve room for next chunk if we know it up front */
|
||||
if (nextLength) {
|
||||
asyncSocketData->buffer.reserve(asyncSocketData->buffer.length() + length - written + nextLength);
|
||||
asyncSocketData->buffer.reserve(asyncSocketData->buffer.length() + (size_t) (length - written + nextLength));
|
||||
}
|
||||
|
||||
/* Buffer this chunk */
|
||||
asyncSocketData->buffer.append(src + written, length - written);
|
||||
asyncSocketData->buffer.append(src + written, (size_t) (length - written));
|
||||
|
||||
/* Return the failure */
|
||||
return {length, true};
|
||||
|
@ -205,7 +241,7 @@ protected:
|
|||
|
||||
if (loopData->corkOffset) {
|
||||
/* Corked data is already accounted for via its write call */
|
||||
auto [written, failed] = write(loopData->corkBuffer, loopData->corkOffset, false, length);
|
||||
auto [written, failed] = write(loopData->corkBuffer, (int) loopData->corkOffset, false, length);
|
||||
loopData->corkOffset = 0;
|
||||
|
||||
if (failed) {
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2019.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef UWS_BLOOMFILTER_H
|
||||
#define UWS_BLOOMFILTER_H
|
||||
|
||||
/* This filter has a decently low amount of false positives for the
|
||||
* standard and non-standard common request headers */
|
||||
|
||||
#include <string_view>
|
||||
#include <bitset>
|
||||
|
||||
namespace uWS {
|
||||
|
||||
struct BloomFilter {
|
||||
private:
|
||||
std::bitset<512> filter;
|
||||
|
||||
unsigned int hash1(std::string_view key) {
|
||||
return ((size_t)key[key.length() - 1] - (key.length() << 3)) & 511;
|
||||
}
|
||||
|
||||
unsigned int hash2(std::string_view key) {
|
||||
return (((size_t)key[0] + (key.length() << 4)) & 511);
|
||||
}
|
||||
|
||||
unsigned int hash3(std::string_view key) {
|
||||
return ((unsigned int)key[key.length() - 2] - 97 - (key.length() << 5)) & 511;
|
||||
}
|
||||
|
||||
public:
|
||||
bool mightHave(std::string_view key) {
|
||||
return filter.test(hash1(key)) && filter.test(hash2(key)) && (key.length() < 2 || filter.test(hash3(key)));
|
||||
}
|
||||
|
||||
void add(std::string_view key) {
|
||||
filter.set(hash1(key));
|
||||
filter.set(hash2(key));
|
||||
if (key.length() >= 2) {
|
||||
filter.set(hash3(key));
|
||||
}
|
||||
}
|
||||
|
||||
void reset() {
|
||||
filter.reset();
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif // UWS_BLOOMFILTER_H
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2019.
|
||||
* Authored by Alex Hultman, 2018-2020.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -24,10 +24,11 @@
|
|||
#include "HttpContextData.h"
|
||||
#include "HttpResponseData.h"
|
||||
#include "AsyncSocket.h"
|
||||
#include "WebSocketData.h"
|
||||
|
||||
#include <string_view>
|
||||
#include <iostream>
|
||||
#include "f2/function2.hpp"
|
||||
#include "MoveOnlyFunction.h"
|
||||
|
||||
namespace uWS {
|
||||
template<bool> struct HttpResponse;
|
||||
|
@ -35,6 +36,7 @@ template<bool> struct HttpResponse;
|
|||
template <bool SSL>
|
||||
struct HttpContext {
|
||||
template<bool> friend struct TemplatedApp;
|
||||
template<bool> friend struct HttpResponse;
|
||||
private:
|
||||
HttpContext() = delete;
|
||||
|
||||
|
@ -60,7 +62,7 @@ private:
|
|||
/* Init the HttpContext by registering libusockets event handlers */
|
||||
HttpContext<SSL> *init() {
|
||||
/* Handle socket connections */
|
||||
us_socket_context_on_open(SSL, getSocketContext(), [](us_socket_t *s, int is_client, char *ip, int ip_length) {
|
||||
us_socket_context_on_open(SSL, getSocketContext(), [](us_socket_t *s, int /*is_client*/, char */*ip*/, int /*ip_length*/) {
|
||||
/* Any connected socket should timeout until it has a request */
|
||||
us_socket_timeout(SSL, s, HTTP_IDLE_TIMEOUT_S);
|
||||
|
||||
|
@ -77,7 +79,7 @@ private:
|
|||
});
|
||||
|
||||
/* Handle socket disconnections */
|
||||
us_socket_context_on_close(SSL, getSocketContext(), [](us_socket_t *s) {
|
||||
us_socket_context_on_close(SSL, getSocketContext(), [](us_socket_t *s, int /*code*/, void */*reason*/) {
|
||||
/* Get socket ext */
|
||||
HttpResponseData<SSL> *httpResponseData = (HttpResponseData<SSL> *) us_socket_ext(SSL, s);
|
||||
|
||||
|
@ -119,11 +121,19 @@ private:
|
|||
/* Cork this socket */
|
||||
((AsyncSocket<SSL> *) s)->cork();
|
||||
|
||||
/* Mark that we are inside the parser now */
|
||||
httpContextData->isParsingHttp = true;
|
||||
|
||||
// clients need to know the cursor after http parse, not servers!
|
||||
// how far did we read then? we need to know to continue with websocket parsing data? or?
|
||||
|
||||
void *proxyParser = nullptr;
|
||||
#ifdef UWS_WITH_PROXY
|
||||
proxyParser = &httpResponseData->proxyParser;
|
||||
#endif
|
||||
|
||||
/* The return value is entirely up to us to interpret. The HttpParser only care for whether the returned value is DIFFERENT or not from passed user */
|
||||
void *returnedSocket = httpResponseData->consumePostPadded(data, length, s, [httpContextData](void *s, uWS::HttpRequest *httpRequest) -> void * {
|
||||
void *returnedSocket = httpResponseData->consumePostPadded(data, (unsigned int) length, s, proxyParser, [httpContextData](void *s, HttpRequest *httpRequest) -> void * {
|
||||
/* For every request we reset the timeout and hang until user makes action */
|
||||
/* Warning: if we are in shutdown state, resetting the timer is a security issue! */
|
||||
us_socket_timeout(SSL, (us_socket_t *) s, 0);
|
||||
|
@ -134,18 +144,23 @@ private:
|
|||
|
||||
/* Are we not ready for another request yet? Terminate the connection. */
|
||||
if (httpResponseData->state & HttpResponseData<SSL>::HTTP_RESPONSE_PENDING) {
|
||||
us_socket_close(SSL, (us_socket_t *) s);
|
||||
us_socket_close(SSL, (us_socket_t *) s, 0, nullptr);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/* Mark pending request and emit it */
|
||||
httpResponseData->state = HttpResponseData<SSL>::HTTP_RESPONSE_PENDING;
|
||||
|
||||
/* Mark this response as connectionClose if ancient or connection: close */
|
||||
if (httpRequest->isAncient() || httpRequest->getHeader("connection").length() == 5) {
|
||||
httpResponseData->state |= HttpResponseData<SSL>::HTTP_CONNECTION_CLOSE;
|
||||
}
|
||||
|
||||
/* Route the method and URL */
|
||||
httpContextData->router.getUserData() = {(HttpResponse<SSL> *) s, httpRequest};
|
||||
if (!httpContextData->router.route(httpRequest->getMethod(), httpRequest->getUrl())) {
|
||||
/* We have to force close this socket as we have no handler for it */
|
||||
us_socket_close(SSL, (us_socket_t *) s);
|
||||
us_socket_close(SSL, (us_socket_t *) s, 0, nullptr);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -205,7 +220,7 @@ private:
|
|||
if (us_socket_is_shut_down(SSL, (us_socket_t *) user)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
/* If we were given the last data chunk, reset data handler to ensure following
|
||||
* requests on the same socket won't trigger any previously registered behavior */
|
||||
if (fin) {
|
||||
|
@ -215,10 +230,13 @@ private:
|
|||
return user;
|
||||
}, [](void *user) {
|
||||
/* Close any socket on HTTP errors */
|
||||
us_socket_close(SSL, (us_socket_t *) user);
|
||||
us_socket_close(SSL, (us_socket_t *) user, 0, nullptr);
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
/* Mark that we are no longer parsing Http */
|
||||
httpContextData->isParsingHttp = false;
|
||||
|
||||
/* We need to uncork in all cases, except for nullptr (closed socket, or upgraded socket) */
|
||||
if (returnedSocket != nullptr) {
|
||||
/* Timeout on uncork failure */
|
||||
|
@ -229,6 +247,18 @@ private:
|
|||
((AsyncSocket<SSL> *) s)->timeout(HTTP_IDLE_TIMEOUT_S);
|
||||
}
|
||||
|
||||
/* We need to check if we should close this socket here now */
|
||||
if (httpResponseData->state & HttpResponseData<SSL>::HTTP_CONNECTION_CLOSE) {
|
||||
if ((httpResponseData->state & HttpResponseData<SSL>::HTTP_RESPONSE_PENDING) == 0) {
|
||||
if (((AsyncSocket<SSL> *) s)->getBufferedAmount() == 0) {
|
||||
((AsyncSocket<SSL> *) s)->shutdown();
|
||||
/* We need to force close after sending FIN since we want to hinder
|
||||
* clients from keeping to send their huge data */
|
||||
((AsyncSocket<SSL> *) s)->close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return (us_socket_t *) returnedSocket;
|
||||
}
|
||||
|
||||
|
@ -238,7 +268,16 @@ private:
|
|||
AsyncSocket<SSL> *asyncSocket = (AsyncSocket<SSL> *) httpContextData->upgradedWebSocket;
|
||||
|
||||
/* Uncork here as well (note: what if we failed to uncork and we then pub/sub before we even upgraded?) */
|
||||
/*auto [written, failed] = */asyncSocket->uncork();
|
||||
auto [written, failed] = asyncSocket->uncork();
|
||||
|
||||
/* If we succeeded in uncorking, check if we have sent WebSocket FIN */
|
||||
if (!failed) {
|
||||
WebSocketData *webSocketData = (WebSocketData *) asyncSocket->getAsyncSocketData();
|
||||
if (webSocketData->isShuttingDown) {
|
||||
/* In that case, also send TCP FIN (this is similar to what we have in ws drain handler) */
|
||||
asyncSocket->shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
/* Reset upgradedWebSocket before we return */
|
||||
httpContextData->upgradedWebSocket = nullptr;
|
||||
|
@ -283,6 +322,18 @@ private:
|
|||
/* Drain any socket buffer, this might empty our backpressure and thus finish the request */
|
||||
/*auto [written, failed] = */asyncSocket->write(nullptr, 0, true, 0);
|
||||
|
||||
/* Should we close this connection after a response - and is this response really done? */
|
||||
if (httpResponseData->state & HttpResponseData<SSL>::HTTP_CONNECTION_CLOSE) {
|
||||
if ((httpResponseData->state & HttpResponseData<SSL>::HTTP_RESPONSE_PENDING) == 0) {
|
||||
if (asyncSocket->getBufferedAmount() == 0) {
|
||||
asyncSocket->shutdown();
|
||||
/* We need to force close after sending FIN since we want to hinder
|
||||
* clients from keeping to send their huge data */
|
||||
asyncSocket->close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Expect another writable event, or another request within the timeout */
|
||||
asyncSocket->timeout(HTTP_IDLE_TIMEOUT_S);
|
||||
|
||||
|
@ -310,13 +361,6 @@ private:
|
|||
return this;
|
||||
}
|
||||
|
||||
/* Used by App in its WebSocket handler */
|
||||
void upgradeToWebSocket(void *newSocket) {
|
||||
HttpContextData<SSL> *httpContextData = getSocketContextData();
|
||||
|
||||
httpContextData->upgradedWebSocket = newSocket;
|
||||
}
|
||||
|
||||
public:
|
||||
/* Construct a new HttpContext using specified loop */
|
||||
static HttpContext *create(Loop *loop, us_socket_context_options_t options = {}) {
|
||||
|
@ -343,12 +387,12 @@ public:
|
|||
us_socket_context_free(SSL, getSocketContext());
|
||||
}
|
||||
|
||||
void filter(fu2::unique_function<void(HttpResponse<SSL> *, int)> &&filterHandler) {
|
||||
void filter(MoveOnlyFunction<void(HttpResponse<SSL> *, int)> &&filterHandler) {
|
||||
getSocketContextData()->filterHandlers.emplace_back(std::move(filterHandler));
|
||||
}
|
||||
|
||||
/* Register an HTTP route handler acording to URL pattern */
|
||||
void onHttp(std::string method, std::string pattern, fu2::unique_function<void(HttpResponse<SSL> *, HttpRequest *)> &&handler, bool upgrade = false) {
|
||||
void onHttp(std::string method, std::string pattern, MoveOnlyFunction<void(HttpResponse<SSL> *, HttpRequest *)> &&handler, bool upgrade = false) {
|
||||
HttpContextData<SSL> *httpContextData = getSocketContextData();
|
||||
|
||||
/* Todo: This is ugly, fix */
|
||||
|
@ -363,6 +407,13 @@ public:
|
|||
auto user = r->getUserData();
|
||||
user.httpRequest->setYield(false);
|
||||
user.httpRequest->setParameters(r->getParameters());
|
||||
|
||||
/* Middleware? Automatically respond to expectations */
|
||||
std::string_view expect = user.httpRequest->getHeader("expect");
|
||||
if (expect.length() && expect == "100-continue") {
|
||||
user.httpResponse->writeContinue();
|
||||
}
|
||||
|
||||
handler(user.httpResponse, user.httpRequest);
|
||||
|
||||
/* If any handler yielded, the router will keep looking for a suitable handler. */
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2019.
|
||||
* Authored by Alex Hultman, 2018-2020.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -21,7 +21,7 @@
|
|||
#include "HttpRouter.h"
|
||||
|
||||
#include <vector>
|
||||
#include "f2/function2.hpp"
|
||||
#include "MoveOnlyFunction.h"
|
||||
|
||||
namespace uWS {
|
||||
template<bool> struct HttpResponse;
|
||||
|
@ -31,8 +31,11 @@ template <bool SSL>
|
|||
struct alignas(16) HttpContextData {
|
||||
template <bool> friend struct HttpContext;
|
||||
template <bool> friend struct HttpResponse;
|
||||
template <bool> friend struct TemplatedApp;
|
||||
private:
|
||||
std::vector<fu2::unique_function<void(HttpResponse<SSL> *, int)>> filterHandlers;
|
||||
std::vector<MoveOnlyFunction<void(HttpResponse<SSL> *, int)>> filterHandlers;
|
||||
|
||||
MoveOnlyFunction<void(const char *hostname)> missingServerNameHandler;
|
||||
|
||||
struct RouterData {
|
||||
HttpResponse<SSL> *httpResponse;
|
||||
|
@ -41,6 +44,7 @@ private:
|
|||
|
||||
HttpRouter<RouterData> router;
|
||||
void *upgradedWebSocket = nullptr;
|
||||
bool isParsingHttp = false;
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2019.
|
||||
* Authored by Alex Hultman, 2018-2020.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -25,12 +25,16 @@
|
|||
#include <string>
|
||||
#include <cstring>
|
||||
#include <algorithm>
|
||||
#include "f2/function2.hpp"
|
||||
#include "MoveOnlyFunction.h"
|
||||
|
||||
#include "BloomFilter.h"
|
||||
#include "ProxyParser.h"
|
||||
#include "QueryParser.h"
|
||||
|
||||
namespace uWS {
|
||||
|
||||
/* We require at least this much post padding */
|
||||
static const int MINIMUM_HTTP_POST_PADDING = 32;
|
||||
static const unsigned int MINIMUM_HTTP_POST_PADDING = 32;
|
||||
|
||||
struct HttpRequest {
|
||||
|
||||
|
@ -41,12 +45,17 @@ private:
|
|||
struct Header {
|
||||
std::string_view key, value;
|
||||
} headers[MAX_HEADERS];
|
||||
int querySeparator;
|
||||
bool ancientHttp;
|
||||
unsigned int querySeparator;
|
||||
bool didYield;
|
||||
|
||||
BloomFilter bf;
|
||||
std::pair<int, std::string_view *> currentParameters;
|
||||
|
||||
public:
|
||||
bool isAncient() {
|
||||
return ancientHttp;
|
||||
}
|
||||
|
||||
bool getYield() {
|
||||
return didYield;
|
||||
}
|
||||
|
@ -87,9 +96,11 @@ public:
|
|||
}
|
||||
|
||||
std::string_view getHeader(std::string_view lowerCasedHeader) {
|
||||
for (Header *h = headers; (++h)->key.length(); ) {
|
||||
if (h->key.length() == lowerCasedHeader.length() && !strncmp(h->key.data(), lowerCasedHeader.data(), lowerCasedHeader.length())) {
|
||||
return h->value;
|
||||
if (bf.mightHave(lowerCasedHeader)) {
|
||||
for (Header *h = headers; (++h)->key.length(); ) {
|
||||
if (h->key.length() == lowerCasedHeader.length() && !strncmp(h->key.data(), lowerCasedHeader.data(), lowerCasedHeader.length())) {
|
||||
return h->value;
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::string_view(nullptr, 0);
|
||||
|
@ -103,8 +114,9 @@ public:
|
|||
return std::string_view(headers->key.data(), headers->key.length());
|
||||
}
|
||||
|
||||
/* Returns the raw querystring as a whole, still encoded */
|
||||
std::string_view getQuery() {
|
||||
if (querySeparator < (int) headers->value.length()) {
|
||||
if (querySeparator < headers->value.length()) {
|
||||
/* Strip the initial ? */
|
||||
return std::string_view(headers->value.data() + querySeparator + 1, headers->value.length() - querySeparator - 1);
|
||||
} else {
|
||||
|
@ -112,11 +124,19 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
/* Finds and decodes the URI component. */
|
||||
std::string_view getQuery(std::string_view key) {
|
||||
/* Raw querystring including initial '?' sign */
|
||||
std::string_view queryString = std::string_view(headers->value.data() + querySeparator, headers->value.length() - querySeparator);
|
||||
|
||||
return getDecodedQueryValue(key, queryString);
|
||||
}
|
||||
|
||||
void setParameters(std::pair<int, std::string_view *> parameters) {
|
||||
currentParameters = parameters;
|
||||
}
|
||||
|
||||
std::string_view getParameter(unsigned int index) {
|
||||
std::string_view getParameter(unsigned short index) {
|
||||
if (currentParameters.first < (int) index) {
|
||||
return {};
|
||||
} else {
|
||||
|
@ -136,15 +156,40 @@ private:
|
|||
|
||||
static unsigned int toUnsignedInteger(std::string_view str) {
|
||||
unsigned int unsignedIntegerValue = 0;
|
||||
for (unsigned char c : str) {
|
||||
unsignedIntegerValue = unsignedIntegerValue * 10 + (c - '0');
|
||||
for (char c : str) {
|
||||
unsignedIntegerValue = unsignedIntegerValue * 10u + ((unsigned int) c - (unsigned int) '0');
|
||||
}
|
||||
return unsignedIntegerValue;
|
||||
}
|
||||
|
||||
static unsigned int getHeaders(char *postPaddedBuffer, char *end, struct HttpRequest::Header *headers) {
|
||||
static unsigned int getHeaders(char *postPaddedBuffer, char *end, struct HttpRequest::Header *headers, void *reserved) {
|
||||
char *preliminaryKey, *preliminaryValue, *start = postPaddedBuffer;
|
||||
|
||||
#ifdef UWS_WITH_PROXY
|
||||
/* ProxyParser is passed as reserved parameter */
|
||||
ProxyParser *pp = (ProxyParser *) reserved;
|
||||
|
||||
/* Parse PROXY protocol */
|
||||
auto [done, offset] = pp->parse({start, (size_t) (end - postPaddedBuffer)});
|
||||
if (!done) {
|
||||
/* We do not reset the ProxyParser (on filure) since it is tied to this
|
||||
* connection, which is really only supposed to ever get one PROXY frame
|
||||
* anyways. We do however allow multiple PROXY frames to be sent (overwrites former). */
|
||||
return 0;
|
||||
} else {
|
||||
/* We have consumed this data so skip it */
|
||||
start += offset;
|
||||
}
|
||||
#else
|
||||
/* This one is unused */
|
||||
(void) reserved;
|
||||
#endif
|
||||
|
||||
/* It is critical for fallback buffering logic that we only return with success
|
||||
* if we managed to parse a complete HTTP request (minus data). Returning success
|
||||
* for PROXY means we can end up succeeding, yet leaving bytes in the fallback buffer
|
||||
* which is then removed, and our counters to flip due to overflow and we end up with a crash */
|
||||
|
||||
for (unsigned int i = 0; i < HttpRequest::MAX_HEADERS; i++) {
|
||||
for (preliminaryKey = postPaddedBuffer; (*postPaddedBuffer != ':') & (*postPaddedBuffer > 32); *(postPaddedBuffer++) |= 32);
|
||||
if (*postPaddedBuffer == '\r') {
|
||||
|
@ -158,7 +203,7 @@ private:
|
|||
headers->key = std::string_view(preliminaryKey, (size_t) (postPaddedBuffer - preliminaryKey));
|
||||
for (postPaddedBuffer++; (*postPaddedBuffer == ':' || *postPaddedBuffer < 33) && *postPaddedBuffer != '\r'; postPaddedBuffer++);
|
||||
preliminaryValue = postPaddedBuffer;
|
||||
postPaddedBuffer = (char *) memchr(postPaddedBuffer, '\r', end - postPaddedBuffer);
|
||||
postPaddedBuffer = (char *) memchr(postPaddedBuffer, '\r', (size_t) (end - postPaddedBuffer));
|
||||
if (postPaddedBuffer && postPaddedBuffer[1] == '\n') {
|
||||
headers->value = std::string_view(preliminaryValue, (size_t) (postPaddedBuffer - preliminaryValue));
|
||||
postPaddedBuffer += 2;
|
||||
|
@ -173,20 +218,34 @@ private:
|
|||
|
||||
// the only caller of getHeaders
|
||||
template <int CONSUME_MINIMALLY>
|
||||
std::pair<int, void *> fenceAndConsumePostPadded(char *data, int length, void *user, HttpRequest *req, fu2::unique_function<void *(void *, HttpRequest *)> &requestHandler, fu2::unique_function<void *(void *, std::string_view, bool)> &dataHandler) {
|
||||
int consumedTotal = 0;
|
||||
std::pair<unsigned int, void *> fenceAndConsumePostPadded(char *data, unsigned int length, void *user, void *reserved, HttpRequest *req, MoveOnlyFunction<void *(void *, HttpRequest *)> &requestHandler, MoveOnlyFunction<void *(void *, std::string_view, bool)> &dataHandler) {
|
||||
|
||||
/* How much data we CONSUMED (to throw away) */
|
||||
unsigned int consumedTotal = 0;
|
||||
|
||||
/* Fence one byte past end of our buffer (buffer has post padded margins) */
|
||||
data[length] = '\r';
|
||||
|
||||
for (int consumed; length && (consumed = getHeaders(data, data + length, req->headers)); ) {
|
||||
for (unsigned int consumed; length && (consumed = getHeaders(data, data + length, req->headers, reserved)); ) {
|
||||
data += consumed;
|
||||
length -= consumed;
|
||||
consumedTotal += consumed;
|
||||
|
||||
req->headers->value = std::string_view(req->headers->value.data(), std::max<int>(0, (int) req->headers->value.length() - 9));
|
||||
/* Store HTTP version (ancient 1.0 or 1.1) */
|
||||
req->ancientHttp = req->headers->value.length() && (req->headers->value[req->headers->value.length() - 1] == '0');
|
||||
|
||||
/* Strip away tail of first "header value" aka URL */
|
||||
req->headers->value = std::string_view(req->headers->value.data(), (size_t) std::max<int>(0, (int) req->headers->value.length() - 9));
|
||||
|
||||
/* Add all headers to bloom filter */
|
||||
req->bf.reset();
|
||||
for (HttpRequest::Header *h = req->headers; (++h)->key.length(); ) {
|
||||
req->bf.add(h->key);
|
||||
}
|
||||
|
||||
/* Parse query */
|
||||
const char *querySeparatorPtr = (const char *) memchr(req->headers->value.data(), '?', req->headers->value.length());
|
||||
req->querySeparator = (int) ((querySeparatorPtr ? querySeparatorPtr : req->headers->value.data() + req->headers->value.length()) - req->headers->value.data());
|
||||
req->querySeparator = (unsigned int) ((querySeparatorPtr ? querySeparatorPtr : req->headers->value.data() + req->headers->value.length()) - req->headers->value.data());
|
||||
|
||||
/* If returned socket is not what we put in we need
|
||||
* to break here as we either have upgraded to
|
||||
|
@ -225,22 +284,18 @@ private:
|
|||
}
|
||||
|
||||
public:
|
||||
void *consumePostPadded(char *data, unsigned int length, void *user, void *reserved, MoveOnlyFunction<void *(void *, HttpRequest *)> &&requestHandler, MoveOnlyFunction<void *(void *, std::string_view, bool)> &&dataHandler, MoveOnlyFunction<void *(void *)> &&errorHandler) {
|
||||
|
||||
/* We do this to prolong the validity of parsed headers by keeping only the fallback buffer alive */
|
||||
std::string &&salvageFallbackBuffer() {
|
||||
return std::move(fallback);
|
||||
}
|
||||
|
||||
void *consumePostPadded(char *data, int length, void *user, fu2::unique_function<void *(void *, HttpRequest *)> &&requestHandler, fu2::unique_function<void *(void *, std::string_view, bool)> &&dataHandler, fu2::unique_function<void *(void *)> &&errorHandler) {
|
||||
|
||||
/* This resets BloomFilter by construction, but later we also reset it again.
|
||||
* Optimize this to skip resetting twice (req could be made global) */
|
||||
HttpRequest req;
|
||||
|
||||
if (remainingStreamingBytes) {
|
||||
|
||||
// this is exactly the same as below!
|
||||
// todo: refactor this
|
||||
if (remainingStreamingBytes >= (unsigned int) length) {
|
||||
void *returnedUser = dataHandler(user, std::string_view(data, length), remainingStreamingBytes == (unsigned int) length);
|
||||
if (remainingStreamingBytes >= length) {
|
||||
void *returnedUser = dataHandler(user, std::string_view(data, length), remainingStreamingBytes == length);
|
||||
remainingStreamingBytes -= length;
|
||||
return returnedUser;
|
||||
} else {
|
||||
|
@ -257,24 +312,26 @@ public:
|
|||
}
|
||||
|
||||
} else if (fallback.length()) {
|
||||
int had = (int) fallback.length();
|
||||
unsigned int had = (unsigned int) fallback.length();
|
||||
|
||||
int maxCopyDistance = (int) std::min(MAX_FALLBACK_SIZE - fallback.length(), (size_t) length);
|
||||
size_t maxCopyDistance = std::min(MAX_FALLBACK_SIZE - fallback.length(), (size_t) length);
|
||||
|
||||
/* We don't want fallback to be short string optimized, since we want to move it */
|
||||
fallback.reserve(fallback.length() + maxCopyDistance + std::max<int>(MINIMUM_HTTP_POST_PADDING, sizeof(std::string)));
|
||||
fallback.reserve(fallback.length() + maxCopyDistance + std::max<unsigned int>(MINIMUM_HTTP_POST_PADDING, sizeof(std::string)));
|
||||
fallback.append(data, maxCopyDistance);
|
||||
|
||||
// break here on break
|
||||
std::pair<int, void *> consumed = fenceAndConsumePostPadded<true>(fallback.data(), (int) fallback.length(), user, &req, requestHandler, dataHandler);
|
||||
std::pair<unsigned int, void *> consumed = fenceAndConsumePostPadded<true>(fallback.data(), (unsigned int) fallback.length(), user, reserved, &req, requestHandler, dataHandler);
|
||||
if (consumed.second != user) {
|
||||
return consumed.second;
|
||||
}
|
||||
|
||||
if (consumed.first) {
|
||||
|
||||
/* This logic assumes that we consumed everything in fallback buffer.
|
||||
* This is critically important, as we will get an integer overflow in case
|
||||
* of "had" being larger than what we consumed, and that we would drop data */
|
||||
fallback.clear();
|
||||
|
||||
data += consumed.first - had;
|
||||
length -= consumed.first - had;
|
||||
|
||||
|
@ -308,7 +365,7 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
std::pair<int, void *> consumed = fenceAndConsumePostPadded<false>(data, length, user, &req, requestHandler, dataHandler);
|
||||
std::pair<unsigned int, void *> consumed = fenceAndConsumePostPadded<false>(data, length, user, reserved, &req, requestHandler, dataHandler);
|
||||
if (consumed.second != user) {
|
||||
return consumed.second;
|
||||
}
|
||||
|
@ -317,7 +374,7 @@ public:
|
|||
length -= consumed.first;
|
||||
|
||||
if (length) {
|
||||
if ((unsigned int) length < MAX_FALLBACK_SIZE) {
|
||||
if (length < MAX_FALLBACK_SIZE) {
|
||||
fallback.append(data, length);
|
||||
} else {
|
||||
return errorHandler(user);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2019.
|
||||
* Authored by Alex Hultman, 2018-2020.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -25,7 +25,12 @@
|
|||
#include "HttpContextData.h"
|
||||
#include "Utilities.h"
|
||||
|
||||
#include "f2/function2.hpp"
|
||||
#include "WebSocketExtensions.h"
|
||||
#include "WebSocketHandshake.h"
|
||||
#include "WebSocket.h"
|
||||
#include "WebSocketContextData.h"
|
||||
|
||||
#include "MoveOnlyFunction.h"
|
||||
|
||||
/* todo: tryWrite is missing currently, only send smaller segments with write */
|
||||
|
||||
|
@ -56,10 +61,10 @@ private:
|
|||
Super::write(buf, length);
|
||||
}
|
||||
|
||||
/* Write an unsigned 32-bit integer */
|
||||
void writeUnsigned(unsigned int value) {
|
||||
char buf[10];
|
||||
int length = utils::u32toa(value, buf);
|
||||
/* Write an unsigned 64-bit integer */
|
||||
void writeUnsigned64(uint64_t value) {
|
||||
char buf[20];
|
||||
int length = utils::u64toa(value, buf);
|
||||
|
||||
/* For now we do this copy */
|
||||
Super::write(buf, length);
|
||||
|
@ -77,23 +82,43 @@ private:
|
|||
|
||||
/* Called only once per request */
|
||||
void writeMark() {
|
||||
/* You can disable this altogether */
|
||||
#ifndef UWS_HTTPRESPONSE_NO_WRITEMARK
|
||||
writeHeader("uWebSockets", "v0.17");
|
||||
if (!Super::getLoopData()->noMark) {
|
||||
/* We only expose major version */
|
||||
writeHeader("uWebSockets", "19");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
/* Returns true on success, indicating that it might be feasible to write more data.
|
||||
* Will start timeout if stream reaches totalSize or write failure. */
|
||||
bool internalEnd(std::string_view data, int totalSize, bool optional, bool allowContentLength = true) {
|
||||
bool internalEnd(std::string_view data, size_t totalSize, bool optional, bool allowContentLength = true, bool closeConnection = false) {
|
||||
/* Write status if not already done */
|
||||
writeStatus(HTTP_200_OK);
|
||||
|
||||
/* If no total size given then assume this chunk is everything */
|
||||
if (!totalSize) {
|
||||
totalSize = (int) data.length();
|
||||
totalSize = data.length();
|
||||
}
|
||||
|
||||
HttpResponseData<SSL> *httpResponseData = getHttpResponseData();
|
||||
|
||||
/* In some cases, such as when refusing huge data we want to close the connection when drained */
|
||||
if (closeConnection) {
|
||||
|
||||
/* HTTP 1.1 must send this back unless the client already sent it to us.
|
||||
* It is a connection close when either of the two parties say so but the
|
||||
* one party must tell the other one so.
|
||||
*
|
||||
* This check also serves to limit writing the header only once. */
|
||||
if ((httpResponseData->state & HttpResponseData<SSL>::HTTP_CONNECTION_CLOSE) == 0) {
|
||||
writeHeader("Connection", "close");
|
||||
}
|
||||
|
||||
httpResponseData->state |= HttpResponseData<SSL>::HTTP_CONNECTION_CLOSE;
|
||||
}
|
||||
|
||||
if (httpResponseData->state & HttpResponseData<SSL>::HTTP_WRITE_CALLED) {
|
||||
|
||||
/* We do not have tryWrite-like functionalities, so ignore optional in this path */
|
||||
|
@ -126,7 +151,7 @@ private:
|
|||
if (allowContentLength) {
|
||||
/* Even zero is a valid content-length */
|
||||
Super::write("Content-Length: ", 16);
|
||||
writeUnsigned(totalSize);
|
||||
writeUnsigned64(totalSize);
|
||||
Super::write("\r\n\r\n", 4);
|
||||
} else {
|
||||
Super::write("\r\n", 2);
|
||||
|
@ -140,11 +165,20 @@ private:
|
|||
* if it failed to drain any prior failed header writes */
|
||||
|
||||
/* Write as much as possible without causing backpressure */
|
||||
auto [written, failed] = Super::write(data.data(), (int) data.length(), optional);
|
||||
size_t written = 0;
|
||||
bool failed = false;
|
||||
while (written < data.length() && !failed) {
|
||||
/* uSockets only deals with int sizes, so pass chunks of max signed int size */
|
||||
auto writtenFailed = Super::write(data.data() + written, (int) std::min<size_t>(data.length() - written, INT_MAX), optional);
|
||||
|
||||
written += (size_t) writtenFailed.first;
|
||||
failed = writtenFailed.second;
|
||||
}
|
||||
|
||||
httpResponseData->offset += written;
|
||||
|
||||
/* Success is when we wrote the entire thing without any failures */
|
||||
bool success = (unsigned int) written == data.length() && !failed;
|
||||
bool success = written == data.length() && !failed;
|
||||
|
||||
/* If we are now at the end, start a timeout. Also start a timeout if we failed. */
|
||||
if (!success || httpResponseData->offset == totalSize) {
|
||||
|
@ -160,20 +194,140 @@ private:
|
|||
}
|
||||
}
|
||||
|
||||
/* This call is identical to end, but will never write content-length and is thus suitable for upgrades */
|
||||
void upgrade() {
|
||||
internalEnd({nullptr, 0}, 0, false, false);
|
||||
public:
|
||||
/* If we have proxy support; returns the proxed source address as reported by the proxy. */
|
||||
#ifdef UWS_WITH_PROXY
|
||||
std::string_view getProxiedRemoteAddress() {
|
||||
return getHttpResponseData()->proxyParser.getSourceAddress();
|
||||
}
|
||||
|
||||
std::string_view getProxiedRemoteAddressAsText() {
|
||||
return Super::addressAsText(getProxiedRemoteAddress());
|
||||
}
|
||||
#endif
|
||||
|
||||
/* Manually upgrade to WebSocket. Typically called in upgrade handler. Immediately calls open handler.
|
||||
* NOTE: Will invalidate 'this' as socket might change location in memory. Throw away aftert use. */
|
||||
template <typename UserData>
|
||||
void upgrade(UserData &&userData, std::string_view secWebSocketKey, std::string_view secWebSocketProtocol,
|
||||
std::string_view secWebSocketExtensions,
|
||||
struct us_socket_context_t *webSocketContext) {
|
||||
|
||||
/* Extract needed parameters from WebSocketContextData */
|
||||
WebSocketContextData<SSL, UserData> *webSocketContextData = (WebSocketContextData<SSL, UserData> *) us_socket_context_ext(SSL, webSocketContext);
|
||||
|
||||
/* Note: OpenSSL can be used here to speed this up somewhat */
|
||||
char secWebSocketAccept[29] = {};
|
||||
WebSocketHandshake::generate(secWebSocketKey.data(), secWebSocketAccept);
|
||||
|
||||
writeStatus("101 Switching Protocols")
|
||||
->writeHeader("Upgrade", "websocket")
|
||||
->writeHeader("Connection", "Upgrade")
|
||||
->writeHeader("Sec-WebSocket-Accept", secWebSocketAccept);
|
||||
|
||||
/* Select first subprotocol if present */
|
||||
if (secWebSocketProtocol.length()) {
|
||||
writeHeader("Sec-WebSocket-Protocol", secWebSocketProtocol.substr(0, secWebSocketProtocol.find(',')));
|
||||
}
|
||||
|
||||
/* Negotiate compression */
|
||||
bool perMessageDeflate = false;
|
||||
CompressOptions compressOptions = CompressOptions::DISABLED;
|
||||
if (secWebSocketExtensions.length() && webSocketContextData->compression != DISABLED) {
|
||||
|
||||
/* We always want shared inflation */
|
||||
int wantedInflationWindow = 0;
|
||||
|
||||
/* Map from selected compressor */
|
||||
int wantedCompressionWindow = (webSocketContextData->compression & 0xFF00) >> 8;
|
||||
|
||||
auto [negCompression, negCompressionWindow, negInflationWindow, negResponse] =
|
||||
negotiateCompression(true, wantedCompressionWindow, wantedInflationWindow,
|
||||
secWebSocketExtensions);
|
||||
|
||||
if (negCompression) {
|
||||
perMessageDeflate = true;
|
||||
|
||||
/* Map from windowBits to compressor */
|
||||
if (negCompressionWindow == 0) {
|
||||
compressOptions = CompressOptions::SHARED_COMPRESSOR;
|
||||
} else {
|
||||
compressOptions = (CompressOptions) ((uint32_t) (negCompressionWindow << 8)
|
||||
| (uint32_t) (negCompressionWindow - 7));
|
||||
|
||||
/* If we are dedicated and have the 3kb then correct any 4kb to 3kb,
|
||||
* (they both share the windowBits = 9) */
|
||||
if (webSocketContextData->compression == DEDICATED_COMPRESSOR_3KB) {
|
||||
compressOptions = DEDICATED_COMPRESSOR_3KB;
|
||||
}
|
||||
}
|
||||
|
||||
writeHeader("Sec-WebSocket-Extensions", negResponse);
|
||||
}
|
||||
}
|
||||
|
||||
internalEnd({nullptr, 0}, 0, false, false);
|
||||
|
||||
/* Grab the httpContext from res */
|
||||
HttpContext<SSL> *httpContext = (HttpContext<SSL> *) us_socket_context(SSL, (struct us_socket_t *) this);
|
||||
|
||||
/* Move any backpressure out of HttpResponse */
|
||||
std::string backpressure(std::move(((AsyncSocketData<SSL> *) getHttpResponseData())->buffer));
|
||||
|
||||
/* Destroy HttpResponseData */
|
||||
getHttpResponseData()->~HttpResponseData();
|
||||
|
||||
/* Before we adopt and potentially change socket, check if we are corked */
|
||||
bool wasCorked = Super::isCorked();
|
||||
|
||||
/* Adopting a socket invalidates it, do not rely on it directly to carry any data */
|
||||
WebSocket<SSL, true, UserData> *webSocket = (WebSocket<SSL, true, UserData> *) us_socket_context_adopt_socket(SSL,
|
||||
(us_socket_context_t *) webSocketContext, (us_socket_t *) this, sizeof(WebSocketData) + sizeof(UserData));
|
||||
|
||||
/* For whatever reason we were corked, update cork to the new socket */
|
||||
if (wasCorked) {
|
||||
webSocket->AsyncSocket<SSL>::cork();
|
||||
}
|
||||
|
||||
/* Initialize websocket with any moved backpressure intact */
|
||||
webSocket->init(perMessageDeflate, compressOptions, std::move(backpressure));
|
||||
|
||||
/* We should only mark this if inside the parser; if upgrading "async" we cannot set this */
|
||||
HttpContextData<SSL> *httpContextData = httpContext->getSocketContextData();
|
||||
if (httpContextData->isParsingHttp) {
|
||||
/* We need to tell the Http parser that we changed socket */
|
||||
httpContextData->upgradedWebSocket = webSocket;
|
||||
}
|
||||
|
||||
/* Arm idleTimeout */
|
||||
us_socket_timeout(SSL, (us_socket_t *) webSocket, webSocketContextData->idleTimeoutComponents.first);
|
||||
|
||||
/* Move construct the UserData right before calling open handler */
|
||||
new (webSocket->getUserData()) UserData(std::move(userData));
|
||||
|
||||
/* Emit open event and start the timeout */
|
||||
if (webSocketContextData->openHandler) {
|
||||
webSocketContextData->openHandler(webSocket);
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
/* Immediately terminate this Http response */
|
||||
using Super::close;
|
||||
|
||||
/* See AsyncSocket */
|
||||
using Super::getRemoteAddress;
|
||||
using Super::getRemoteAddressAsText;
|
||||
using Super::getNativeHandle;
|
||||
|
||||
/* Note: Headers are not checked in regards to timeout.
|
||||
* We only check when you actively push data or end the request */
|
||||
|
||||
/* Write 100 Continue, can be done any amount of times */
|
||||
HttpResponse *writeContinue() {
|
||||
Super::write("HTTP/1.1 100 Continue\r\n\r\n", 25);
|
||||
return this;
|
||||
}
|
||||
|
||||
/* Write the HTTP status */
|
||||
HttpResponse *writeStatus(std::string_view status) {
|
||||
HttpResponseData<SSL> *httpResponseData = getHttpResponseData();
|
||||
|
@ -204,22 +358,24 @@ public:
|
|||
}
|
||||
|
||||
/* Write an HTTP header with unsigned int value */
|
||||
HttpResponse *writeHeader(std::string_view key, unsigned int value) {
|
||||
HttpResponse *writeHeader(std::string_view key, uint64_t value) {
|
||||
writeStatus(HTTP_200_OK);
|
||||
|
||||
Super::write(key.data(), (int) key.length());
|
||||
Super::write(": ", 2);
|
||||
writeUnsigned(value);
|
||||
writeUnsigned64(value);
|
||||
Super::write("\r\n", 2);
|
||||
return this;
|
||||
}
|
||||
|
||||
/* End the response with an optional data chunk. Always starts a timeout. */
|
||||
void end(std::string_view data = {}) {
|
||||
internalEnd(data, (int) data.length(), false);
|
||||
void end(std::string_view data = {}, bool closeConnection = false) {
|
||||
internalEnd(data, data.length(), false, true, closeConnection);
|
||||
}
|
||||
|
||||
/* Try and end the response. Returns [true, true] on success.
|
||||
* Starts a timeout in some cases. Returns [ok, hasResponded] */
|
||||
std::pair<bool, bool> tryEnd(std::string_view data, int totalSize = 0) {
|
||||
std::pair<bool, bool> tryEnd(std::string_view data, size_t totalSize = 0) {
|
||||
return {internalEnd(data, totalSize, true), hasResponded()};
|
||||
}
|
||||
|
||||
|
@ -257,7 +413,7 @@ public:
|
|||
}
|
||||
|
||||
/* Get the current byte write offset for this Http response */
|
||||
int getWriteOffset() {
|
||||
size_t getWriteOffset() {
|
||||
HttpResponseData<SSL> *httpResponseData = getHttpResponseData();
|
||||
|
||||
return httpResponseData->offset;
|
||||
|
@ -271,7 +427,7 @@ public:
|
|||
}
|
||||
|
||||
/* Corks the response if possible. Leaves already corked socket be. */
|
||||
HttpResponse *cork(fu2::unique_function<void()> &&handler) {
|
||||
HttpResponse *cork(MoveOnlyFunction<void()> &&handler) {
|
||||
if (!Super::isCorked() && Super::canCork()) {
|
||||
Super::cork();
|
||||
handler();
|
||||
|
@ -292,7 +448,7 @@ public:
|
|||
}
|
||||
|
||||
/* Attach handler for writable HTTP response */
|
||||
HttpResponse *onWritable(fu2::unique_function<bool(int)> &&handler) {
|
||||
HttpResponse *onWritable(MoveOnlyFunction<bool(size_t)> &&handler) {
|
||||
HttpResponseData<SSL> *httpResponseData = getHttpResponseData();
|
||||
|
||||
httpResponseData->onWritable = std::move(handler);
|
||||
|
@ -300,7 +456,7 @@ public:
|
|||
}
|
||||
|
||||
/* Attach handler for aborted HTTP request */
|
||||
HttpResponse *onAborted(fu2::unique_function<void()> &&handler) {
|
||||
HttpResponse *onAborted(MoveOnlyFunction<void()> &&handler) {
|
||||
HttpResponseData<SSL> *httpResponseData = getHttpResponseData();
|
||||
|
||||
httpResponseData->onAborted = std::move(handler);
|
||||
|
@ -308,7 +464,7 @@ public:
|
|||
}
|
||||
|
||||
/* Attach a read handler for data sent. Will be called with FIN set true if last segment. */
|
||||
void onData(fu2::unique_function<void(std::string_view, bool)> &&handler) {
|
||||
void onData(MoveOnlyFunction<void(std::string_view, bool)> &&handler) {
|
||||
HttpResponseData<SSL> *data = getHttpResponseData();
|
||||
data->inStream = std::move(handler);
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2019.
|
||||
* Authored by Alex Hultman, 2018-2020.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -22,8 +22,9 @@
|
|||
|
||||
#include "HttpParser.h"
|
||||
#include "AsyncSocketData.h"
|
||||
#include "ProxyParser.h"
|
||||
|
||||
#include "f2/function2.hpp"
|
||||
#include "MoveOnlyFunction.h"
|
||||
|
||||
namespace uWS {
|
||||
|
||||
|
@ -38,18 +39,22 @@ private:
|
|||
HTTP_WRITE_CALLED = 2, // used
|
||||
HTTP_END_CALLED = 4, // used
|
||||
HTTP_RESPONSE_PENDING = 8, // used
|
||||
HTTP_ENDED_STREAM_OUT = 16 // not used
|
||||
HTTP_CONNECTION_CLOSE = 16 // used
|
||||
};
|
||||
|
||||
/* Per socket event handlers */
|
||||
fu2::unique_function<bool(int)> onWritable;
|
||||
fu2::unique_function<void()> onAborted;
|
||||
fu2::unique_function<void(std::string_view, bool)> inStream; // onData
|
||||
MoveOnlyFunction<bool(size_t)> onWritable;
|
||||
MoveOnlyFunction<void()> onAborted;
|
||||
MoveOnlyFunction<void(std::string_view, bool)> inStream; // onData
|
||||
/* Outgoing offset */
|
||||
int offset = 0;
|
||||
size_t offset = 0;
|
||||
|
||||
/* Current state (content-length sent, status sent, write called, etc */
|
||||
int state = 0;
|
||||
|
||||
#ifdef UWS_WITH_PROXY
|
||||
ProxyParser proxyParser;
|
||||
#endif
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2019.
|
||||
* Authored by Alex Hultman, 2018-2020.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -25,8 +25,9 @@
|
|||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "f2/function2.hpp"
|
||||
#include "MoveOnlyFunction.h"
|
||||
|
||||
namespace uWS {
|
||||
|
||||
|
@ -47,7 +48,7 @@ private:
|
|||
std::map<std::string, int> priority;
|
||||
|
||||
/* List of handlers */
|
||||
std::vector<fu2::unique_function<bool(HttpRouter *)>> handlers;
|
||||
std::vector<MoveOnlyFunction<bool(HttpRouter *)>> handlers;
|
||||
|
||||
/* Current URL cache */
|
||||
std::string_view currentUrl;
|
||||
|
@ -60,6 +61,8 @@ private:
|
|||
std::vector<std::unique_ptr<Node>> children;
|
||||
std::vector<uint32_t> handlers;
|
||||
bool isHighPriority;
|
||||
|
||||
Node(std::string name) : name(name) {}
|
||||
} root = {"rootNode"};
|
||||
|
||||
/* Advance from parent to child, adding child if necessary */
|
||||
|
@ -71,7 +74,7 @@ private:
|
|||
}
|
||||
|
||||
/* Insert sorted, but keep order if parent is root (we sort methods by priority elsewhere) */
|
||||
std::unique_ptr<Node> newNode(new Node({child}));
|
||||
std::unique_ptr<Node> newNode(new Node(child));
|
||||
newNode->isHighPriority = isHighPriority;
|
||||
return parent->children.emplace(std::upper_bound(parent->children.begin(), parent->children.end(), newNode, [parent, this](auto &a, auto &b) {
|
||||
|
||||
|
@ -107,19 +110,25 @@ private:
|
|||
|
||||
/* Set URL for router. Will reset any URL cache */
|
||||
inline void setUrl(std::string_view url) {
|
||||
/* Remove / from input URL */
|
||||
currentUrl = url.substr(std::min<unsigned int>((unsigned int) url.length(), 1));
|
||||
|
||||
/* Todo: URL may also start with "http://domain/" or "*", not only "/" */
|
||||
|
||||
/* We expect to stand on a slash */
|
||||
currentUrl = url;
|
||||
urlSegmentTop = -1;
|
||||
}
|
||||
|
||||
/* Lazily parse or read from cache */
|
||||
inline std::string_view getUrlSegment(int urlSegment) {
|
||||
inline std::pair<std::string_view, bool> getUrlSegment(int urlSegment) {
|
||||
if (urlSegment > urlSegmentTop) {
|
||||
/* Return empty segment if we are out of URL or stack space, but never for first url segment */
|
||||
/* Signal as STOP when we have no more URL or stack space */
|
||||
if (!currentUrl.length() || urlSegment > 99) {
|
||||
return {};
|
||||
return {{}, true};
|
||||
}
|
||||
|
||||
/* We always stand on a slash here, so step over it */
|
||||
currentUrl.remove_prefix(1);
|
||||
|
||||
auto segmentLength = currentUrl.find('/');
|
||||
if (segmentLength == std::string::npos) {
|
||||
segmentLength = currentUrl.length();
|
||||
|
@ -136,19 +145,22 @@ private:
|
|||
urlSegmentTop++;
|
||||
|
||||
/* Update currentUrl */
|
||||
currentUrl = currentUrl.substr(segmentLength + 1);
|
||||
currentUrl = currentUrl.substr(segmentLength);
|
||||
}
|
||||
}
|
||||
/* In any case we return it */
|
||||
return urlSegmentVector[urlSegment];
|
||||
return {urlSegmentVector[urlSegment], false};
|
||||
}
|
||||
|
||||
/* Executes as many handlers it can */
|
||||
bool executeHandlers(Node *parent, int urlSegment, USERDATA &userData) {
|
||||
/* If we have no more URL and not on first round, return where we may stand */
|
||||
if (urlSegment && !getUrlSegment(urlSegment).length()) {
|
||||
|
||||
auto [segment, isStop] = getUrlSegment(urlSegment);
|
||||
|
||||
/* If we are on STOP, return where we may stand */
|
||||
if (isStop) {
|
||||
/* We have reached accross the entire URL with no stoppage, execute */
|
||||
for (int handler : parent->handlers) {
|
||||
for (uint32_t handler : parent->handlers) {
|
||||
if (handlers[handler & HANDLER_MASK](this)) {
|
||||
return true;
|
||||
}
|
||||
|
@ -160,19 +172,19 @@ private:
|
|||
for (auto &p : parent->children) {
|
||||
if (p->name.length() && p->name[0] == '*') {
|
||||
/* Wildcard match (can be seen as a shortcut) */
|
||||
for (int handler : p->handlers) {
|
||||
for (uint32_t handler : p->handlers) {
|
||||
if (handlers[handler & HANDLER_MASK](this)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
} else if (p->name.length() && p->name[0] == ':' && getUrlSegment(urlSegment).length()) {
|
||||
} else if (p->name.length() && p->name[0] == ':' && segment.length()) {
|
||||
/* Parameter match */
|
||||
routeParameters.push(getUrlSegment(urlSegment));
|
||||
routeParameters.push(segment);
|
||||
if (executeHandlers(p.get(), urlSegment + 1, userData)) {
|
||||
return true;
|
||||
}
|
||||
routeParameters.pop();
|
||||
} else if (p->name == getUrlSegment(urlSegment)) {
|
||||
} else if (p->name == segment) {
|
||||
/* Static match */
|
||||
if (executeHandlers(p.get(), urlSegment + 1, userData)) {
|
||||
return true;
|
||||
|
@ -217,14 +229,14 @@ public:
|
|||
}
|
||||
|
||||
/* Adds the corresponding entires in matching tree and handler list */
|
||||
void add(std::vector<std::string> methods, std::string pattern, fu2::unique_function<bool(HttpRouter *)> &&handler, uint32_t priority = MEDIUM_PRIORITY) {
|
||||
void add(std::vector<std::string> methods, std::string pattern, MoveOnlyFunction<bool(HttpRouter *)> &&handler, uint32_t priority = MEDIUM_PRIORITY) {
|
||||
for (std::string method : methods) {
|
||||
/* Lookup method */
|
||||
Node *node = getNode(&root, method, false);
|
||||
/* Iterate over all segments */
|
||||
setUrl(pattern);
|
||||
for (int i = 0; getUrlSegment(i).length() || i == 0; i++) {
|
||||
node = getNode(node, std::string(getUrlSegment(i)), priority == HIGH_PRIORITY);
|
||||
for (int i = 0; !getUrlSegment(i).second; i++) {
|
||||
node = getNode(node, std::string(getUrlSegment(i).first), priority == HIGH_PRIORITY);
|
||||
}
|
||||
/* Insert handler in order sorted by priority (most significant 1 byte) */
|
||||
node->handlers.insert(std::upper_bound(node->handlers.begin(), node->handlers.end(), (uint32_t) (priority | handlers.size())), (uint32_t) (priority | handlers.size()));
|
||||
|
@ -237,4 +249,4 @@ public:
|
|||
|
||||
}
|
||||
|
||||
#endif // UWS_HTTPROUTER_HPP
|
||||
#endif // UWS_HTTPROUTER_HPP
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2019.
|
||||
* Authored by Alex Hultman, 2018-2020.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -18,10 +18,10 @@
|
|||
#ifndef UWS_LOOP_H
|
||||
#define UWS_LOOP_H
|
||||
|
||||
/* The loop is lazily created per-thread and run with uWS::run() */
|
||||
/* The loop is lazily created per-thread and run with run() */
|
||||
|
||||
#include "LoopData.h"
|
||||
#include "libusockets.h"
|
||||
#include <libusockets.h>
|
||||
|
||||
namespace uWS {
|
||||
struct Loop {
|
||||
|
@ -80,25 +80,29 @@ private:
|
|||
Loop *loop = nullptr;
|
||||
bool cleanMe = false;
|
||||
};
|
||||
|
||||
|
||||
static LoopCleaner &getLazyLoop() {
|
||||
static thread_local LoopCleaner lazyLoop;
|
||||
return lazyLoop;
|
||||
}
|
||||
|
||||
public:
|
||||
/* Lazily initializes a per-thread loop and returns it.
|
||||
* Will automatically free all initialized loops at exit. */
|
||||
static Loop *get(void *existingNativeLoop = nullptr) {
|
||||
static thread_local LoopCleaner lazyLoop;
|
||||
if (!lazyLoop.loop) {
|
||||
if (!getLazyLoop().loop) {
|
||||
/* If we are given a native loop pointer we pass that to uSockets and let it deal with it */
|
||||
if (existingNativeLoop) {
|
||||
/* Todo: here we want to pass the pointer, not a boolean */
|
||||
lazyLoop.loop = create(existingNativeLoop);
|
||||
getLazyLoop().loop = create(existingNativeLoop);
|
||||
/* We cannot register automatic free here, must be manually done */
|
||||
} else {
|
||||
lazyLoop.loop = create(nullptr);
|
||||
lazyLoop.cleanMe = true;
|
||||
getLazyLoop().loop = create(nullptr);
|
||||
getLazyLoop().cleanMe = true;
|
||||
}
|
||||
}
|
||||
|
||||
return lazyLoop.loop;
|
||||
return getLazyLoop().loop;
|
||||
}
|
||||
|
||||
/* Freeing the default loop should be done once */
|
||||
|
@ -107,9 +111,12 @@ public:
|
|||
loopData->~LoopData();
|
||||
/* uSockets will track whether this loop is owned by us or a borrowed alien loop */
|
||||
us_loop_free((us_loop_t *) this);
|
||||
|
||||
/* Reset lazyLoop */
|
||||
getLazyLoop().loop = nullptr;
|
||||
}
|
||||
|
||||
void addPostHandler(void *key, fu2::unique_function<void(Loop *)> &&handler) {
|
||||
void addPostHandler(void *key, MoveOnlyFunction<void(Loop *)> &&handler) {
|
||||
LoopData *loopData = (LoopData *) us_loop_ext((us_loop_t *) this);
|
||||
|
||||
loopData->postHandlers.emplace(key, std::move(handler));
|
||||
|
@ -122,7 +129,7 @@ public:
|
|||
loopData->postHandlers.erase(key);
|
||||
}
|
||||
|
||||
void addPreHandler(void *key, fu2::unique_function<void(Loop *)> &&handler) {
|
||||
void addPreHandler(void *key, MoveOnlyFunction<void(Loop *)> &&handler) {
|
||||
LoopData *loopData = (LoopData *) us_loop_ext((us_loop_t *) this);
|
||||
|
||||
loopData->preHandlers.emplace(key, std::move(handler));
|
||||
|
@ -136,7 +143,7 @@ public:
|
|||
}
|
||||
|
||||
/* Defer this callback on Loop's thread of execution */
|
||||
void defer(fu2::unique_function<void()> &&cb) {
|
||||
void defer(MoveOnlyFunction<void()> &&cb) {
|
||||
LoopData *loopData = (LoopData *) us_loop_ext((us_loop_t *) this);
|
||||
|
||||
//if (std::thread::get_id() == ) // todo: add fast path for same thread id
|
||||
|
@ -157,6 +164,11 @@ public:
|
|||
void integrate() {
|
||||
us_loop_integrate((us_loop_t *) this);
|
||||
}
|
||||
|
||||
/* Dynamically change this */
|
||||
void setSilent(bool silent) {
|
||||
((LoopData *) us_loop_ext((us_loop_t *) this))->noMark = silent;
|
||||
}
|
||||
};
|
||||
|
||||
/* Can be called from any thread to run the thread local loop */
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2019.
|
||||
* Authored by Alex Hultman, 2018-2020.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -26,7 +26,7 @@
|
|||
|
||||
#include "PerMessageDeflate.h"
|
||||
|
||||
#include "f2/function2.hpp"
|
||||
#include "MoveOnlyFunction.h"
|
||||
|
||||
namespace uWS {
|
||||
|
||||
|
@ -37,10 +37,10 @@ struct alignas(16) LoopData {
|
|||
private:
|
||||
std::mutex deferMutex;
|
||||
int currentDeferQueue = 0;
|
||||
std::vector<fu2::unique_function<void()>> deferQueues[2];
|
||||
std::vector<MoveOnlyFunction<void()>> deferQueues[2];
|
||||
|
||||
/* Map from void ptr to handler */
|
||||
std::map<void *, fu2::unique_function<void(Loop *)>> postHandlers, preHandlers;
|
||||
std::map<void *, MoveOnlyFunction<void(Loop *)>> postHandlers, preHandlers;
|
||||
|
||||
public:
|
||||
~LoopData() {
|
||||
|
@ -53,12 +53,15 @@ public:
|
|||
delete [] corkBuffer;
|
||||
}
|
||||
|
||||
/* Be silent */
|
||||
bool noMark = false;
|
||||
|
||||
/* Good 16k for SSL perf. */
|
||||
static const int CORK_BUFFER_SIZE = 16 * 1024;
|
||||
static const unsigned int CORK_BUFFER_SIZE = 16 * 1024;
|
||||
|
||||
/* Cork data */
|
||||
char *corkBuffer = new char[CORK_BUFFER_SIZE];
|
||||
int corkOffset = 0;
|
||||
unsigned int corkOffset = 0;
|
||||
void *corkedSocket = nullptr;
|
||||
|
||||
/* Per message deflate data */
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2020.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/* Implements the common parser (RFC 822) used in both HTTP and Multipart parsing */
|
||||
|
||||
#ifndef UWS_MESSAGE_PARSER_H
|
||||
#define UWS_MESSAGE_PARSER_H
|
||||
|
||||
#include <string_view>
|
||||
#include <utility>
|
||||
#include <cstring>
|
||||
|
||||
/* For now we have this one here */
|
||||
#define MAX_HEADERS 10
|
||||
|
||||
namespace uWS {
|
||||
|
||||
// should be templated on whether it needs at lest one header (http), or not (multipart)
|
||||
static inline unsigned int getHeaders(char *postPaddedBuffer, char *end, std::pair<std::string_view, std::string_view> *headers) {
|
||||
char *preliminaryKey, *preliminaryValue, *start = postPaddedBuffer;
|
||||
|
||||
for (unsigned int i = 0; i < MAX_HEADERS; i++) {
|
||||
for (preliminaryKey = postPaddedBuffer; (*postPaddedBuffer != ':') & (*postPaddedBuffer > 32); *(postPaddedBuffer++) |= 32);
|
||||
if (*postPaddedBuffer == '\r') {
|
||||
if ((postPaddedBuffer != end) & (postPaddedBuffer[1] == '\n') /* & (i > 0) */) { // multipart does not require any headers like http does
|
||||
headers->first = std::string_view(nullptr, 0);
|
||||
return (unsigned int) ((postPaddedBuffer + 2) - start);
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
} else {
|
||||
headers->first = std::string_view(preliminaryKey, (size_t) (postPaddedBuffer - preliminaryKey));
|
||||
for (postPaddedBuffer++; (*postPaddedBuffer == ':' || *postPaddedBuffer < 33) && *postPaddedBuffer != '\r'; postPaddedBuffer++);
|
||||
preliminaryValue = postPaddedBuffer;
|
||||
postPaddedBuffer = (char *) memchr(postPaddedBuffer, '\r', end - postPaddedBuffer);
|
||||
if (postPaddedBuffer && postPaddedBuffer[1] == '\n') {
|
||||
headers->second = std::string_view(preliminaryValue, (size_t) (postPaddedBuffer - preliminaryValue));
|
||||
postPaddedBuffer += 2;
|
||||
headers++;
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
|
@ -0,0 +1,377 @@
|
|||
/*
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2020 Oleg Fatkhiev
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
*/
|
||||
|
||||
/* Sources fetched from https://github.com/ofats/any_invocable on 2021-02-19. */
|
||||
|
||||
#ifndef _ANY_INVOKABLE_H_
|
||||
#define _ANY_INVOKABLE_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
|
||||
// clang-format off
|
||||
/*
|
||||
namespace std {
|
||||
template<class Sig> class any_invocable; // never defined
|
||||
|
||||
template<class R, class... ArgTypes>
|
||||
class any_invocable<R(ArgTypes...) cv ref noexcept(noex)> {
|
||||
public:
|
||||
using result_type = R;
|
||||
|
||||
// SECTION.3, construct/copy/destroy
|
||||
any_invocable() noexcept;
|
||||
any_invocable(nullptr_t) noexcept;
|
||||
any_invocable(any_invocable&&) noexcept;
|
||||
template<class F> any_invocable(F&&);
|
||||
|
||||
template<class T, class... Args>
|
||||
explicit any_invocable(in_place_type_t<T>, Args&&...);
|
||||
template<class T, class U, class... Args>
|
||||
explicit any_invocable(in_place_type_t<T>, initializer_list<U>, Args&&...);
|
||||
|
||||
any_invocable& operator=(any_invocable&&) noexcept;
|
||||
any_invocable& operator=(nullptr_t) noexcept;
|
||||
template<class F> any_invocable& operator=(F&&);
|
||||
template<class F> any_invocable& operator=(reference_wrapper<F>) noexcept;
|
||||
|
||||
~any_invocable();
|
||||
|
||||
// SECTION.4, any_invocable modifiers
|
||||
void swap(any_invocable&) noexcept;
|
||||
|
||||
// SECTION.5, any_invocable capacity
|
||||
explicit operator bool() const noexcept;
|
||||
|
||||
// SECTION.6, any_invocable invocation
|
||||
R operator()(ArgTypes...) cv ref noexcept(noex);
|
||||
|
||||
// SECTION.7, null pointer comparisons
|
||||
friend bool operator==(const any_invocable&, nullptr_t) noexcept;
|
||||
|
||||
// SECTION.8, specialized algorithms
|
||||
friend void swap(any_invocable&, any_invocable&) noexcept;
|
||||
};
|
||||
}
|
||||
*/
|
||||
// clang-format on
|
||||
|
||||
namespace ofats {
|
||||
|
||||
namespace any_detail {
|
||||
|
||||
using buffer = std::aligned_storage_t<sizeof(void*) * 2, alignof(void*)>;
|
||||
|
||||
template <class T>
|
||||
inline constexpr bool is_small_object_v =
|
||||
sizeof(T) <= sizeof(buffer) && alignof(buffer) % alignof(T) == 0 &&
|
||||
std::is_nothrow_move_constructible_v<T>;
|
||||
|
||||
union storage {
|
||||
void* ptr_ = nullptr;
|
||||
buffer buf_;
|
||||
};
|
||||
|
||||
enum class action { destroy, move };
|
||||
|
||||
template <class R, class... ArgTypes>
|
||||
struct handler_traits {
|
||||
template <class Derived>
|
||||
struct handler_base {
|
||||
static void handle(action act, storage* current, storage* other = nullptr) {
|
||||
switch (act) {
|
||||
case (action::destroy):
|
||||
Derived::destroy(*current);
|
||||
break;
|
||||
case (action::move):
|
||||
Derived::move(*current, *other);
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct small_handler : handler_base<small_handler<T>> {
|
||||
template <class... Args>
|
||||
static void create(storage& s, Args&&... args) {
|
||||
new (static_cast<void*>(&s.buf_)) T(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
static void destroy(storage& s) noexcept {
|
||||
T& value = *static_cast<T*>(static_cast<void*>(&s.buf_));
|
||||
value.~T();
|
||||
}
|
||||
|
||||
static void move(storage& dst, storage& src) noexcept {
|
||||
create(dst, std::move(*static_cast<T*>(static_cast<void*>(&src.buf_))));
|
||||
destroy(src);
|
||||
}
|
||||
|
||||
static R call(storage& s, ArgTypes... args) {
|
||||
return std::invoke(*static_cast<T*>(static_cast<void*>(&s.buf_)),
|
||||
std::forward<ArgTypes>(args)...);
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct large_handler : handler_base<large_handler<T>> {
|
||||
template <class... Args>
|
||||
static void create(storage& s, Args&&... args) {
|
||||
s.ptr_ = new T(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
static void destroy(storage& s) noexcept { delete static_cast<T*>(s.ptr_); }
|
||||
|
||||
static void move(storage& dst, storage& src) noexcept {
|
||||
dst.ptr_ = src.ptr_;
|
||||
}
|
||||
|
||||
static R call(storage& s, ArgTypes... args) {
|
||||
return std::invoke(*static_cast<T*>(s.ptr_),
|
||||
std::forward<ArgTypes>(args)...);
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
using handler = std::conditional_t<is_small_object_v<T>, small_handler<T>,
|
||||
large_handler<T>>;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct is_in_place_type : std::false_type {};
|
||||
|
||||
template <class T>
|
||||
struct is_in_place_type<std::in_place_type_t<T>> : std::true_type {};
|
||||
|
||||
template <class T>
|
||||
inline constexpr auto is_in_place_type_v = is_in_place_type<T>::value;
|
||||
|
||||
template <class R, bool is_noexcept, class... ArgTypes>
|
||||
class any_invocable_impl {
|
||||
template <class T>
|
||||
using handler =
|
||||
typename any_detail::handler_traits<R, ArgTypes...>::template handler<T>;
|
||||
|
||||
using storage = any_detail::storage;
|
||||
using action = any_detail::action;
|
||||
using handle_func = void (*)(any_detail::action, any_detail::storage*,
|
||||
any_detail::storage*);
|
||||
using call_func = R (*)(any_detail::storage&, ArgTypes...);
|
||||
|
||||
public:
|
||||
using result_type = R;
|
||||
|
||||
any_invocable_impl() noexcept = default;
|
||||
any_invocable_impl(std::nullptr_t) noexcept {}
|
||||
any_invocable_impl(any_invocable_impl&& rhs) noexcept {
|
||||
if (rhs.handle_) {
|
||||
handle_ = rhs.handle_;
|
||||
handle_(action::move, &storage_, &rhs.storage_);
|
||||
call_ = rhs.call_;
|
||||
rhs.handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
any_invocable_impl& operator=(any_invocable_impl&& rhs) noexcept {
|
||||
any_invocable_impl{std::move(rhs)}.swap(*this);
|
||||
return *this;
|
||||
}
|
||||
any_invocable_impl& operator=(std::nullptr_t) noexcept {
|
||||
destroy();
|
||||
return *this;
|
||||
}
|
||||
|
||||
~any_invocable_impl() { destroy(); }
|
||||
|
||||
void swap(any_invocable_impl& rhs) noexcept {
|
||||
if (handle_) {
|
||||
if (rhs.handle_) {
|
||||
storage tmp;
|
||||
handle_(action::move, &tmp, &storage_);
|
||||
rhs.handle_(action::move, &storage_, &rhs.storage_);
|
||||
handle_(action::move, &rhs.storage_, &tmp);
|
||||
std::swap(handle_, rhs.handle_);
|
||||
std::swap(call_, rhs.call_);
|
||||
} else {
|
||||
rhs.swap(*this);
|
||||
}
|
||||
} else if (rhs.handle_) {
|
||||
rhs.handle_(action::move, &storage_, &rhs.storage_);
|
||||
handle_ = rhs.handle_;
|
||||
call_ = rhs.call_;
|
||||
rhs.handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
explicit operator bool() const noexcept { return handle_ != nullptr; }
|
||||
|
||||
protected:
|
||||
template <class F, class... Args>
|
||||
void create(Args&&... args) {
|
||||
using hdl = handler<F>;
|
||||
hdl::create(storage_, std::forward<Args>(args)...);
|
||||
handle_ = &hdl::handle;
|
||||
call_ = &hdl::call;
|
||||
}
|
||||
|
||||
void destroy() noexcept {
|
||||
if (handle_) {
|
||||
handle_(action::destroy, &storage_, nullptr);
|
||||
handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
R call(ArgTypes... args) noexcept(is_noexcept) {
|
||||
return call_(storage_, std::forward<ArgTypes>(args)...);
|
||||
}
|
||||
|
||||
friend bool operator==(const any_invocable_impl& f, std::nullptr_t) noexcept {
|
||||
return !f;
|
||||
}
|
||||
friend bool operator==(std::nullptr_t, const any_invocable_impl& f) noexcept {
|
||||
return !f;
|
||||
}
|
||||
friend bool operator!=(const any_invocable_impl& f, std::nullptr_t) noexcept {
|
||||
return static_cast<bool>(f);
|
||||
}
|
||||
friend bool operator!=(std::nullptr_t, const any_invocable_impl& f) noexcept {
|
||||
return static_cast<bool>(f);
|
||||
}
|
||||
|
||||
friend void swap(any_invocable_impl& lhs, any_invocable_impl& rhs) noexcept {
|
||||
lhs.swap(rhs);
|
||||
}
|
||||
|
||||
private:
|
||||
storage storage_;
|
||||
handle_func handle_ = nullptr;
|
||||
call_func call_;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
using remove_cvref_t = std::remove_cv_t<std::remove_reference_t<T>>;
|
||||
|
||||
template <class AI, class F, bool noex, class R, class FCall, class... ArgTypes>
|
||||
using can_convert = std::conjunction<
|
||||
std::negation<std::is_same<remove_cvref_t<F>, AI>>,
|
||||
std::negation<any_detail::is_in_place_type<remove_cvref_t<F>>>,
|
||||
std::is_invocable_r<R, FCall, ArgTypes...>,
|
||||
std::bool_constant<(!noex ||
|
||||
std::is_nothrow_invocable_r_v<R, FCall, ArgTypes...>)>,
|
||||
std::is_constructible<std::decay_t<F>, F>>;
|
||||
|
||||
} // namespace any_detail
|
||||
|
||||
template <class Signature>
|
||||
class any_invocable;
|
||||
|
||||
#define __OFATS_ANY_INVOCABLE(cv, ref, noex, inv_quals) \
|
||||
template <class R, class... ArgTypes> \
|
||||
class any_invocable<R(ArgTypes...) cv ref noexcept(noex)> \
|
||||
: public any_detail::any_invocable_impl<R, noex, ArgTypes...> { \
|
||||
using base_type = any_detail::any_invocable_impl<R, noex, ArgTypes...>; \
|
||||
\
|
||||
public: \
|
||||
using base_type::base_type; \
|
||||
\
|
||||
template < \
|
||||
class F, \
|
||||
class = std::enable_if_t<any_detail::can_convert< \
|
||||
any_invocable, F, noex, R, F inv_quals, ArgTypes...>::value>> \
|
||||
any_invocable(F&& f) { \
|
||||
base_type::template create<std::decay_t<F>>(std::forward<F>(f)); \
|
||||
} \
|
||||
\
|
||||
template <class T, class... Args, class VT = std::decay_t<T>, \
|
||||
class = std::enable_if_t< \
|
||||
std::is_move_constructible_v<VT> && \
|
||||
std::is_constructible_v<VT, Args...> && \
|
||||
std::is_invocable_r_v<R, VT inv_quals, ArgTypes...> && \
|
||||
(!noex || std::is_nothrow_invocable_r_v<R, VT inv_quals, \
|
||||
ArgTypes...>)>> \
|
||||
explicit any_invocable(std::in_place_type_t<T>, Args&&... args) { \
|
||||
base_type::template create<VT>(std::forward<Args>(args)...); \
|
||||
} \
|
||||
\
|
||||
template < \
|
||||
class T, class U, class... Args, class VT = std::decay_t<T>, \
|
||||
class = std::enable_if_t< \
|
||||
std::is_move_constructible_v<VT> && \
|
||||
std::is_constructible_v<VT, std::initializer_list<U>&, Args...> && \
|
||||
std::is_invocable_r_v<R, VT inv_quals, ArgTypes...> && \
|
||||
(!noex || \
|
||||
std::is_nothrow_invocable_r_v<R, VT inv_quals, ArgTypes...>)>> \
|
||||
explicit any_invocable(std::in_place_type_t<T>, \
|
||||
std::initializer_list<U> il, Args&&... args) { \
|
||||
base_type::template create<VT>(il, std::forward<Args>(args)...); \
|
||||
} \
|
||||
\
|
||||
template <class F, class FDec = std::decay_t<F>> \
|
||||
std::enable_if_t<!std::is_same_v<FDec, any_invocable> && \
|
||||
std::is_move_constructible_v<FDec>, \
|
||||
any_invocable&> \
|
||||
operator=(F&& f) { \
|
||||
any_invocable{std::forward<F>(f)}.swap(*this); \
|
||||
return *this; \
|
||||
} \
|
||||
template <class F> \
|
||||
any_invocable& operator=(std::reference_wrapper<F> f) { \
|
||||
any_invocable{f}.swap(*this); \
|
||||
return *this; \
|
||||
} \
|
||||
\
|
||||
R operator()(ArgTypes... args) cv ref noexcept(noex) { \
|
||||
return base_type::call(std::forward<ArgTypes>(args)...); \
|
||||
} \
|
||||
};
|
||||
|
||||
// cv -> {`empty`, const}
|
||||
// ref -> {`empty`, &, &&}
|
||||
// noex -> {true, false}
|
||||
// inv_quals -> (is_empty(ref) ? & : ref)
|
||||
__OFATS_ANY_INVOCABLE(, , false, &) // 000
|
||||
__OFATS_ANY_INVOCABLE(, , true, &) // 001
|
||||
__OFATS_ANY_INVOCABLE(, &, false, &) // 010
|
||||
__OFATS_ANY_INVOCABLE(, &, true, &) // 011
|
||||
__OFATS_ANY_INVOCABLE(, &&, false, &&) // 020
|
||||
__OFATS_ANY_INVOCABLE(, &&, true, &&) // 021
|
||||
__OFATS_ANY_INVOCABLE(const, , false, const&) // 100
|
||||
__OFATS_ANY_INVOCABLE(const, , true, const&) // 101
|
||||
__OFATS_ANY_INVOCABLE(const, &, false, const&) // 110
|
||||
__OFATS_ANY_INVOCABLE(const, &, true, const&) // 111
|
||||
__OFATS_ANY_INVOCABLE(const, &&, false, const&&) // 120
|
||||
__OFATS_ANY_INVOCABLE(const, &&, true, const&&) // 121
|
||||
|
||||
#undef __OFATS_ANY_INVOCABLE
|
||||
|
||||
} // namespace ofats
|
||||
|
||||
/* We, uWebSockets define our own type */
|
||||
namespace uWS {
|
||||
template <class T>
|
||||
using MoveOnlyFunction = ofats::any_invocable<T>;
|
||||
}
|
||||
|
||||
#endif // _ANY_INVOKABLE_H_
|
|
@ -0,0 +1,231 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2020.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/* Implements the multipart protocol. Builds atop parts of our common http parser (not yet refactored that way). */
|
||||
/* https://www.w3.org/Protocols/rfc1341/7_2_Multipart.html */
|
||||
|
||||
#ifndef UWS_MULTIPART_H
|
||||
#define UWS_MULTIPART_H
|
||||
|
||||
#include "MessageParser.h"
|
||||
|
||||
#include <string_view>
|
||||
#include <optional>
|
||||
#include <cstring>
|
||||
#include <utility>
|
||||
#include <cctype>
|
||||
|
||||
namespace uWS {
|
||||
|
||||
/* This one could possibly be shared with ExtensionsParser to some degree */
|
||||
struct ParameterParser {
|
||||
|
||||
/* Takes the line, commonly given as content-disposition header in the multipart */
|
||||
ParameterParser(std::string_view line) {
|
||||
remainingLine = line;
|
||||
}
|
||||
|
||||
/* Returns next key/value where value can simply be empty.
|
||||
* If key (first) is empty then we are at the end */
|
||||
std::pair<std::string_view, std::string_view> getKeyValue() {
|
||||
auto key = getToken();
|
||||
auto op = getToken();
|
||||
|
||||
if (!op.length()) {
|
||||
return {key, ""};
|
||||
}
|
||||
|
||||
if (op[0] != ';') {
|
||||
auto value = getToken();
|
||||
/* Strip ; or if at end, nothing */
|
||||
getToken();
|
||||
return {key, value};
|
||||
}
|
||||
|
||||
return {key, ""};
|
||||
}
|
||||
|
||||
private:
|
||||
std::string_view remainingLine;
|
||||
|
||||
/* Consumes a token from the line. Will "unquote" strings */
|
||||
std::string_view getToken() {
|
||||
/* Strip whitespace */
|
||||
while (remainingLine.length() && isspace(remainingLine[0])) {
|
||||
remainingLine.remove_prefix(1);
|
||||
}
|
||||
|
||||
if (!remainingLine.length()) {
|
||||
/* All we had was space */
|
||||
return {};
|
||||
} else {
|
||||
/* Are we at an operator? */
|
||||
if (remainingLine[0] == ';' || remainingLine[0] == '=') {
|
||||
auto op = remainingLine.substr(0, 1);
|
||||
remainingLine.remove_prefix(1);
|
||||
return op;
|
||||
} else {
|
||||
/* Are we at a quoted string? */
|
||||
if (remainingLine[0] == '\"') {
|
||||
/* Remove first quote and start counting */
|
||||
remainingLine.remove_prefix(1);
|
||||
auto quote = remainingLine;
|
||||
int quoteLength = 0;
|
||||
|
||||
/* Read anything until other double quote appears */
|
||||
while (remainingLine.length() && remainingLine[0] != '\"') {
|
||||
remainingLine.remove_prefix(1);
|
||||
quoteLength++;
|
||||
}
|
||||
|
||||
/* We can't remove_prefix if we have nothing to remove */
|
||||
if (!remainingLine.length()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
remainingLine.remove_prefix(1);
|
||||
return quote.substr(0, quoteLength);
|
||||
} else {
|
||||
/* Read anything until ; = space or end */
|
||||
std::string_view token = remainingLine;
|
||||
|
||||
int tokenLength = 0;
|
||||
while (remainingLine.length() && remainingLine[0] != ';' && remainingLine[0] != '=' && !isspace(remainingLine[0])) {
|
||||
remainingLine.remove_prefix(1);
|
||||
tokenLength++;
|
||||
}
|
||||
|
||||
return token.substr(0, tokenLength);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Nothing */
|
||||
return "";
|
||||
}
|
||||
};
|
||||
|
||||
struct MultipartParser {
|
||||
|
||||
/* 2 chars of hyphen + 1 - 70 chars of boundary */
|
||||
char prependedBoundaryBuffer[72];
|
||||
std::string_view prependedBoundary;
|
||||
std::string_view remainingBody;
|
||||
bool first = true;
|
||||
|
||||
/* I think it is more than sane to limit this to 10 per part */
|
||||
//static const int MAX_HEADERS = 10;
|
||||
|
||||
/* Construct the parser based on contentType (reads boundary) */
|
||||
MultipartParser(std::string_view contentType) {
|
||||
|
||||
/* We expect the form "multipart/something;somethingboundary=something" */
|
||||
if (contentType.length() < 10 || contentType.substr(0, 10) != "multipart/") {
|
||||
return;
|
||||
}
|
||||
|
||||
/* For now we simply guess boundary will lie between = and end. This is not entirely
|
||||
* standards compliant as boundary may be expressed with or without " and spaces */
|
||||
auto equalToken = contentType.find('=', 10);
|
||||
if (equalToken != std::string_view::npos) {
|
||||
|
||||
/* Boundary must be less than or equal to 70 chars yet 1 char or longer */
|
||||
std::string_view boundary = contentType.substr(equalToken + 1);
|
||||
if (!boundary.length() || boundary.length() > 70) {
|
||||
/* Invalid size */
|
||||
return;
|
||||
}
|
||||
|
||||
/* Prepend it with two hyphens */
|
||||
prependedBoundaryBuffer[0] = prependedBoundaryBuffer[1] = '-';
|
||||
memcpy(&prependedBoundaryBuffer[2], boundary.data(), boundary.length());
|
||||
|
||||
prependedBoundary = {prependedBoundaryBuffer, boundary.length() + 2};
|
||||
}
|
||||
}
|
||||
|
||||
/* Is this even a valid multipart request? */
|
||||
bool isValid() {
|
||||
return prependedBoundary.length() != 0;
|
||||
}
|
||||
|
||||
/* Set the body once, before getting any parts */
|
||||
void setBody(std::string_view body) {
|
||||
remainingBody = body;
|
||||
}
|
||||
|
||||
/* Parse out the next part's data, filling the headers. Returns nullopt on end or error. */
|
||||
std::optional<std::string_view> getNextPart(std::pair<std::string_view, std::string_view> *headers) {
|
||||
|
||||
/* The remaining two hyphens should be shorter than the boundary */
|
||||
if (remainingBody.length() < prependedBoundary.length()) {
|
||||
/* We are done now */
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (first) {
|
||||
auto nextBoundary = remainingBody.find(prependedBoundary);
|
||||
if (nextBoundary == std::string_view::npos) {
|
||||
/* Cannot parse */
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
/* Toss away boundary and anything before it */
|
||||
remainingBody.remove_prefix(nextBoundary + prependedBoundary.length());
|
||||
first = false;
|
||||
}
|
||||
|
||||
auto nextEndBoundary = remainingBody.find(prependedBoundary);
|
||||
if (nextEndBoundary == std::string_view::npos) {
|
||||
/* Cannot parse (or simply done) */
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::string_view part = remainingBody.substr(0, nextEndBoundary);
|
||||
remainingBody.remove_prefix(nextEndBoundary + prependedBoundary.length());
|
||||
|
||||
/* Also strip rn before and rn after the part */
|
||||
if (part.length() < 4) {
|
||||
/* Cannot strip */
|
||||
return std::nullopt;
|
||||
}
|
||||
part.remove_prefix(2);
|
||||
part.remove_suffix(2);
|
||||
|
||||
/* We are allowed to post pad like this because we know the boundary is at least 2 bytes */
|
||||
/* This makes parsing a second pass invalid, so you can only iterate over parts once */
|
||||
memset((char *) part.data() + part.length(), '\r', 1);
|
||||
|
||||
/* For this to be a valid part, we need to consume at least 4 bytes (\r\n\r\n) */
|
||||
int consumed = getHeaders((char *) part.data(), (char *) part.data() + part.length(), headers);
|
||||
|
||||
if (!consumed) {
|
||||
/* This is an invalid part */
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
/* Strip away the headers from the part body data */
|
||||
part.remove_prefix(consumed);
|
||||
|
||||
/* Now pass whatever is remaining of the part */
|
||||
return part;
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2019.
|
||||
* Authored by Alex Hultman, 2018-2021.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -20,26 +20,53 @@
|
|||
#ifndef UWS_PERMESSAGEDEFLATE_H
|
||||
#define UWS_PERMESSAGEDEFLATE_H
|
||||
|
||||
#ifndef UWS_NO_ZLIB
|
||||
/* We always define these options no matter if ZLIB is enabled or not */
|
||||
namespace uWS {
|
||||
/* Compressor mode is HIGH8(windowBits), LOW8(memLevel) */
|
||||
enum CompressOptions : uint32_t {
|
||||
DISABLED = 0,
|
||||
SHARED_COMPRESSOR = 1,
|
||||
DEDICATED_COMPRESSOR_3KB = 9 << 8 | 1,
|
||||
DEDICATED_COMPRESSOR_4KB = 9 << 8 | 2,
|
||||
DEDICATED_COMPRESSOR_8KB = 10 << 8 | 3,
|
||||
DEDICATED_COMPRESSOR_16KB = 11 << 8 | 4,
|
||||
DEDICATED_COMPRESSOR_32KB = 12 << 8 | 5,
|
||||
DEDICATED_COMPRESSOR_64KB = 13 << 8 | 6,
|
||||
DEDICATED_COMPRESSOR_128KB = 14 << 8 | 7,
|
||||
DEDICATED_COMPRESSOR_256KB = 15 << 8 | 8,
|
||||
/* Same as 256kb */
|
||||
DEDICATED_COMPRESSOR = 15 << 8 | 8
|
||||
};
|
||||
}
|
||||
|
||||
#if !defined(UWS_NO_ZLIB) && !defined(UWS_MOCK_ZLIB)
|
||||
#include <zlib.h>
|
||||
#endif
|
||||
|
||||
#include <string>
|
||||
#include <optional>
|
||||
|
||||
#ifdef UWS_USE_LIBDEFLATE
|
||||
#include "libdeflate.h"
|
||||
#include <cstring>
|
||||
#endif
|
||||
|
||||
namespace uWS {
|
||||
|
||||
/* Do not compile this module if we don't want it */
|
||||
#ifdef UWS_NO_ZLIB
|
||||
#if defined(UWS_NO_ZLIB) || defined(UWS_MOCK_ZLIB)
|
||||
struct ZlibContext {};
|
||||
struct InflationStream {
|
||||
std::string_view inflate(ZlibContext *zlibContext, std::string_view compressed, size_t maxPayloadLength) {
|
||||
return compressed;
|
||||
std::optional<std::string_view> inflate(ZlibContext *zlibContext, std::string_view compressed, size_t maxPayloadLength) {
|
||||
return compressed.substr(0, std::min(maxPayloadLength, compressed.length()));
|
||||
}
|
||||
};
|
||||
struct DeflationStream {
|
||||
std::string_view deflate(ZlibContext *zlibContext, std::string_view raw, bool reset) {
|
||||
return raw;
|
||||
}
|
||||
DeflationStream(int compressOptions) {
|
||||
}
|
||||
};
|
||||
#else
|
||||
|
||||
|
@ -54,26 +81,65 @@ struct ZlibContext {
|
|||
char *deflationBuffer;
|
||||
char *inflationBuffer;
|
||||
|
||||
#ifdef UWS_USE_LIBDEFLATE
|
||||
libdeflate_decompressor *decompressor;
|
||||
libdeflate_compressor *compressor;
|
||||
#endif
|
||||
|
||||
ZlibContext() {
|
||||
deflationBuffer = (char *) malloc(LARGE_BUFFER_SIZE);
|
||||
inflationBuffer = (char *) malloc(LARGE_BUFFER_SIZE);
|
||||
|
||||
#ifdef UWS_USE_LIBDEFLATE
|
||||
decompressor = libdeflate_alloc_decompressor();
|
||||
compressor = libdeflate_alloc_compressor(6);
|
||||
#endif
|
||||
}
|
||||
|
||||
~ZlibContext() {
|
||||
free(deflationBuffer);
|
||||
free(inflationBuffer);
|
||||
|
||||
#ifdef UWS_USE_LIBDEFLATE
|
||||
libdeflate_free_decompressor(decompressor);
|
||||
libdeflate_free_compressor(compressor);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct DeflationStream {
|
||||
z_stream deflationStream = {};
|
||||
|
||||
DeflationStream() {
|
||||
deflateInit2(&deflationStream, 1, Z_DEFLATED, -15, 8, Z_DEFAULT_STRATEGY);
|
||||
DeflationStream(CompressOptions compressOptions) {
|
||||
|
||||
/* Sliding inflator should be about 44kb by default, less than compressor */
|
||||
|
||||
/* Memory usage is given by 2 ^ (windowBits + 2) + 2 ^ (memLevel + 9) */
|
||||
int windowBits = -(int) ((compressOptions & 0xFF00) >> 8), memLevel = compressOptions & 0x00FF;
|
||||
|
||||
//printf("windowBits: %d, memLevel: %d\n", windowBits, memLevel);
|
||||
|
||||
deflateInit2(&deflationStream, Z_DEFAULT_COMPRESSION, Z_DEFLATED, windowBits, memLevel, Z_DEFAULT_STRATEGY);
|
||||
}
|
||||
|
||||
/* Deflate and optionally reset */
|
||||
/* Deflate and optionally reset. You must not deflate an empty string. */
|
||||
std::string_view deflate(ZlibContext *zlibContext, std::string_view raw, bool reset) {
|
||||
|
||||
#ifdef UWS_USE_LIBDEFLATE
|
||||
/* Run a fast path in case of shared_compressor */
|
||||
if (reset) {
|
||||
size_t written = 0;
|
||||
static unsigned char buf[1024 + 1];
|
||||
|
||||
written = libdeflate_deflate_compress(zlibContext->compressor, raw.data(), raw.length(), buf, 1024);
|
||||
|
||||
if (written) {
|
||||
memcpy(&buf[written], "\x00", 1);
|
||||
return std::string_view((char *) buf, written + 1);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
/* Odd place to clear this one, fix */
|
||||
zlibContext->dynamicDeflationBuffer.clear();
|
||||
|
||||
|
@ -105,9 +171,11 @@ struct DeflationStream {
|
|||
if (zlibContext->dynamicDeflationBuffer.length()) {
|
||||
zlibContext->dynamicDeflationBuffer.append(zlibContext->deflationBuffer, DEFLATE_OUTPUT_CHUNK - deflationStream.avail_out);
|
||||
|
||||
return {(char *) zlibContext->dynamicDeflationBuffer.data(), zlibContext->dynamicDeflationBuffer.length() - 4};
|
||||
return std::string_view((char *) zlibContext->dynamicDeflationBuffer.data(), zlibContext->dynamicDeflationBuffer.length() - 4);
|
||||
}
|
||||
|
||||
/* Note: We will get an interger overflow resulting in heap buffer overflow if Z_BUF_ERROR is returned
|
||||
* from passing 0 as avail_in. Therefore we must not deflate an empty string */
|
||||
return {
|
||||
zlibContext->deflationBuffer,
|
||||
DEFLATE_OUTPUT_CHUNK - deflationStream.avail_out - 4
|
||||
|
@ -130,7 +198,26 @@ struct InflationStream {
|
|||
inflateEnd(&inflationStream);
|
||||
}
|
||||
|
||||
std::string_view inflate(ZlibContext *zlibContext, std::string_view compressed, size_t maxPayloadLength) {
|
||||
/* Zero length inflates are possible and valid */
|
||||
std::optional<std::string_view> inflate(ZlibContext *zlibContext, std::string_view compressed, size_t maxPayloadLength) {
|
||||
|
||||
#ifdef UWS_USE_LIBDEFLATE
|
||||
/* Try fast path first */
|
||||
size_t written = 0;
|
||||
static char buf[1024];
|
||||
|
||||
/* We have to pad 9 bytes and restore those bytes when done since 9 is more than 6 of next WebSocket message */
|
||||
char tmp[9];
|
||||
memcpy(tmp, (char *) compressed.data() + compressed.length(), 9);
|
||||
memcpy((char *) compressed.data() + compressed.length(), "\x00\x00\xff\xff\x01\x00\x00\xff\xff", 9);
|
||||
libdeflate_result res = libdeflate_deflate_decompress(zlibContext->decompressor, compressed.data(), compressed.length() + 9, buf, 1024, &written);
|
||||
memcpy((char *) compressed.data() + compressed.length(), tmp, 9);
|
||||
|
||||
if (res == 0) {
|
||||
/* Fast path wins */
|
||||
return std::string_view(buf, written);
|
||||
}
|
||||
#endif
|
||||
|
||||
/* We clear this one here, could be done better */
|
||||
zlibContext->dynamicInflationBuffer.clear();
|
||||
|
@ -156,7 +243,7 @@ struct InflationStream {
|
|||
inflateReset(&inflationStream);
|
||||
|
||||
if ((err != Z_BUF_ERROR && err != Z_OK) || zlibContext->dynamicInflationBuffer.length() > maxPayloadLength) {
|
||||
return {nullptr, 0};
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (zlibContext->dynamicInflationBuffer.length()) {
|
||||
|
@ -164,18 +251,18 @@ struct InflationStream {
|
|||
|
||||
/* Let's be strict about the max size */
|
||||
if (zlibContext->dynamicInflationBuffer.length() > maxPayloadLength) {
|
||||
return {nullptr, 0};
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
return {zlibContext->dynamicInflationBuffer.data(), zlibContext->dynamicInflationBuffer.length()};
|
||||
return std::string_view(zlibContext->dynamicInflationBuffer.data(), zlibContext->dynamicInflationBuffer.length());
|
||||
}
|
||||
|
||||
/* Let's be strict about the max size */
|
||||
if ((LARGE_BUFFER_SIZE - inflationStream.avail_out) > maxPayloadLength) {
|
||||
return {nullptr, 0};
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
return {zlibContext->inflationBuffer, LARGE_BUFFER_SIZE - inflationStream.avail_out};
|
||||
return std::string_view(zlibContext->inflationBuffer, LARGE_BUFFER_SIZE - inflationStream.avail_out);
|
||||
}
|
||||
|
||||
};
|
||||
|
|
|
@ -0,0 +1,163 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2020.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/* This module implements The PROXY Protocol v2 */
|
||||
|
||||
#ifndef UWS_PROXY_PARSER_H
|
||||
#define UWS_PROXY_PARSER_H
|
||||
|
||||
#ifdef UWS_WITH_PROXY
|
||||
|
||||
namespace uWS {
|
||||
|
||||
struct proxy_hdr_v2 {
|
||||
uint8_t sig[12]; /* hex 0D 0A 0D 0A 00 0D 0A 51 55 49 54 0A */
|
||||
uint8_t ver_cmd; /* protocol version and command */
|
||||
uint8_t fam; /* protocol family and address */
|
||||
uint16_t len; /* number of following bytes part of the header */
|
||||
};
|
||||
|
||||
union proxy_addr {
|
||||
struct { /* for TCP/UDP over IPv4, len = 12 */
|
||||
uint32_t src_addr;
|
||||
uint32_t dst_addr;
|
||||
uint16_t src_port;
|
||||
uint16_t dst_port;
|
||||
} ipv4_addr;
|
||||
struct { /* for TCP/UDP over IPv6, len = 36 */
|
||||
uint8_t src_addr[16];
|
||||
uint8_t dst_addr[16];
|
||||
uint16_t src_port;
|
||||
uint16_t dst_port;
|
||||
} ipv6_addr;
|
||||
};
|
||||
|
||||
/* Byte swap for little-endian systems */
|
||||
/* Todo: This functions should be shared with the one in WebSocketProtocol.h! */
|
||||
template <typename T>
|
||||
T _cond_byte_swap(T value) {
|
||||
uint32_t endian_test = 1;
|
||||
if (*((char *)&endian_test)) {
|
||||
union {
|
||||
T i;
|
||||
uint8_t b[sizeof(T)];
|
||||
} src = { value }, dst;
|
||||
|
||||
for (unsigned int i = 0; i < sizeof(value); i++) {
|
||||
dst.b[i] = src.b[sizeof(value) - 1 - i];
|
||||
}
|
||||
|
||||
return dst.i;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
struct ProxyParser {
|
||||
private:
|
||||
union proxy_addr addr;
|
||||
|
||||
/* Default family of 0 signals no proxy address */
|
||||
uint8_t family = 0;
|
||||
|
||||
public:
|
||||
/* Returns 4 or 16 bytes source address */
|
||||
std::string_view getSourceAddress() {
|
||||
|
||||
// UNSPEC family and protocol
|
||||
if (family == 0) {
|
||||
return {};
|
||||
}
|
||||
|
||||
if ((family & 0xf0) >> 4 == 1) {
|
||||
/* Family 1 is INET4 */
|
||||
return {(char *) &addr.ipv4_addr.src_addr, 4};
|
||||
} else {
|
||||
/* Family 2 is INET6 */
|
||||
return {(char *) &addr.ipv6_addr.src_addr, 16};
|
||||
}
|
||||
}
|
||||
|
||||
/* Returns [done, consumed] where done = false on failure */
|
||||
std::pair<bool, unsigned int> parse(std::string_view data) {
|
||||
|
||||
/* We require at least four bytes to determine protocol */
|
||||
if (data.length() < 4) {
|
||||
return {false, 0};
|
||||
}
|
||||
|
||||
/* HTTP can never start with "\r\n\r\n", but PROXY always does */
|
||||
if (memcmp(data.data(), "\r\n\r\n", 4)) {
|
||||
/* This is HTTP, so be done */
|
||||
return {true, 0};
|
||||
}
|
||||
|
||||
/* We assume we are parsing PROXY V2 here */
|
||||
|
||||
/* We require 16 bytes here */
|
||||
if (data.length() < 16) {
|
||||
return {false, 0};
|
||||
}
|
||||
|
||||
/* Header is 16 bytes */
|
||||
struct proxy_hdr_v2 header;
|
||||
memcpy(&header, data.data(), 16);
|
||||
|
||||
if (memcmp(header.sig, "\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A", 12)) {
|
||||
/* This is not PROXY protocol at all */
|
||||
return {false, 0};
|
||||
}
|
||||
|
||||
/* We only support version 2 */
|
||||
if ((header.ver_cmd & 0xf0) >> 4 != 2) {
|
||||
return {false, 0};
|
||||
}
|
||||
|
||||
//printf("Version: %d\n", (header.ver_cmd & 0xf0) >> 4);
|
||||
//printf("Command: %d\n", (header.ver_cmd & 0x0f));
|
||||
|
||||
/* We get length in network byte order (todo: share this function with the rest) */
|
||||
uint16_t hostLength = _cond_byte_swap<uint16_t>(header.len);
|
||||
|
||||
/* We must have all the data available */
|
||||
if (data.length() < 16u + hostLength) {
|
||||
return {false, 0};
|
||||
}
|
||||
|
||||
/* Payload cannot be more than sizeof proxy_addr */
|
||||
if (sizeof(proxy_addr) < hostLength) {
|
||||
return {false, 0};
|
||||
}
|
||||
|
||||
//printf("Family: %d\n", (header.fam & 0xf0) >> 4);
|
||||
//printf("Transport: %d\n", (header.fam & 0x0f));
|
||||
|
||||
/* We have 0 family by default, and UNSPEC is 0 as well */
|
||||
family = header.fam;
|
||||
|
||||
/* Copy payload */
|
||||
memcpy(&addr, data.data() + 16, hostLength);
|
||||
|
||||
/* We consumed everything */
|
||||
return {true, 16 + hostLength};
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif // UWS_PROXY_PARSER_H
|
|
@ -0,0 +1,120 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2020.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/* This module implements URI query parsing and retrieval of value given key */
|
||||
|
||||
#ifndef UWS_QUERYPARSER_H
|
||||
#define UWS_QUERYPARSER_H
|
||||
|
||||
#include <string_view>
|
||||
|
||||
namespace uWS {
|
||||
|
||||
/* Takes raw query including initial '?' sign. Will inplace decode, so input will mutate */
|
||||
static inline std::string_view getDecodedQueryValue(std::string_view key, std::string_view rawQuery) {
|
||||
|
||||
/* Can't have a value without a key */
|
||||
if (!key.length()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
/* Start with the whole querystring including initial '?' */
|
||||
std::string_view queryString = rawQuery;
|
||||
|
||||
/* List of key, value could be cached for repeated fetches similar to how headers are, todo! */
|
||||
while (queryString.length()) {
|
||||
/* Find boundaries of this statement */
|
||||
std::string_view statement = queryString.substr(1, queryString.find('&', 1) - 1);
|
||||
|
||||
/* Only bother if first char of key match (early exit) */
|
||||
if (statement.length() && statement[0] == key[0]) {
|
||||
/* Equal sign must be present and not in the end of statement */
|
||||
auto equality = statement.find('=');
|
||||
if (equality != std::string_view::npos && equality != statement.length() - 1) {
|
||||
|
||||
std::string_view statementKey = statement.substr(0, equality);
|
||||
std::string_view statementValue = statement.substr(equality + 1);
|
||||
|
||||
/* String comparison */
|
||||
if (key == statementKey) {
|
||||
|
||||
/* Decode value inplace, put null at end if before length of original */
|
||||
char *in = (char *) statementValue.data();
|
||||
|
||||
/* Write offset */
|
||||
unsigned int out = 0;
|
||||
|
||||
/* Walk over all chars until end or null char, decoding in place */
|
||||
for (unsigned int i = 0; i < statementValue.length() && in[i]; i++) {
|
||||
/* Only bother with '%' */
|
||||
if (in[i] == '%') {
|
||||
/* Do we have enough data for two bytes hex? */
|
||||
if (i + 2 >= statementValue.length()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
/* Two bytes hex */
|
||||
int hex1 = in[i + 1] - '0';
|
||||
if (hex1 > 9) {
|
||||
hex1 &= 223;
|
||||
hex1 -= 7;
|
||||
}
|
||||
|
||||
int hex2 = in[i + 2] - '0';
|
||||
if (hex2 > 9) {
|
||||
hex2 &= 223;
|
||||
hex2 -= 7;
|
||||
}
|
||||
|
||||
*((unsigned char *) &in[out]) = (unsigned char) (hex1 * 16 + hex2);
|
||||
i += 2;
|
||||
} else {
|
||||
/* Is this even a rule? */
|
||||
if (in[i] == '+') {
|
||||
in[out] = ' ';
|
||||
} else {
|
||||
in[out] = in[i];
|
||||
}
|
||||
}
|
||||
|
||||
/* We always only write one char */
|
||||
out++;
|
||||
}
|
||||
|
||||
/* If decoded string is shorter than original, put null char to stop next read */
|
||||
if (out < statementValue.length()) {
|
||||
in[out] = 0;
|
||||
}
|
||||
|
||||
return statementValue.substr(0, out);
|
||||
}
|
||||
} else {
|
||||
/* This querystring is invalid, cannot parse it */
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
queryString.remove_prefix(statement.length() + 1);
|
||||
}
|
||||
|
||||
/* Nothing found */
|
||||
return {};
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2019.
|
||||
* Authored by Alex Hultman, 2018-2021.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -26,6 +26,10 @@
|
|||
#include <set>
|
||||
#include <chrono>
|
||||
#include <list>
|
||||
#include <cstring>
|
||||
|
||||
/* We use std::function here, not MoveOnlyFunction */
|
||||
#include <functional>
|
||||
|
||||
namespace uWS {
|
||||
|
||||
|
@ -57,29 +61,123 @@ struct Topic {
|
|||
/* Terminating wildcard child */
|
||||
Topic *terminatingWildcardChild = nullptr;
|
||||
|
||||
/* What we published */
|
||||
std::map<unsigned int, std::string> messages;
|
||||
/* What we published, {inflated, deflated} */
|
||||
std::map<unsigned int, std::pair<std::string, std::string>> messages;
|
||||
|
||||
std::set<Subscriber *> subs;
|
||||
|
||||
/* Locked or not, used only when iterating over a Subscriber's topics */
|
||||
bool locked = false;
|
||||
|
||||
/* Full name is used when iterating topcis */
|
||||
std::string fullName;
|
||||
};
|
||||
|
||||
struct Hole {
|
||||
std::pair<size_t, size_t> lengths;
|
||||
unsigned int messageId;
|
||||
};
|
||||
|
||||
struct Intersection {
|
||||
std::pair<std::string, std::string> dataChannels;
|
||||
std::vector<Hole> holes;
|
||||
|
||||
void forSubscriber(std::vector<unsigned int> &senderForMessages, std::function<void(std::pair<std::string_view, std::string_view>, bool)> cb) {
|
||||
/* How far we already emitted of the two dataChannels */
|
||||
std::pair<size_t, size_t> emitted = {};
|
||||
|
||||
/* Holes are global to the entire topic tree, so we are not guaranteed to find
|
||||
* holes in this intersection - they are sorted, though */
|
||||
unsigned int examinedHoles = 0;
|
||||
|
||||
/* This is a slow path of sorts, most subscribers will be observers, not active senders */
|
||||
for (unsigned int id : senderForMessages) {
|
||||
std::pair<size_t, size_t> toEmit = {};
|
||||
std::pair<size_t, size_t> toIgnore = {};
|
||||
|
||||
/* This linear search is most probably very small - it could be made log2 if every hole
|
||||
* knows about its previous accumulated length, which is easy to set up. However this
|
||||
* log2 search will most likely never be a warranted perf. gain */
|
||||
for (; examinedHoles < holes.size(); examinedHoles++) {
|
||||
if (holes[examinedHoles].messageId == id) {
|
||||
toIgnore.first += holes[examinedHoles].lengths.first;
|
||||
toIgnore.second += holes[examinedHoles].lengths.second;
|
||||
examinedHoles++;
|
||||
break;
|
||||
}
|
||||
/* We are not the sender of this message so we should emit it in this segment */
|
||||
toEmit.first += holes[examinedHoles].lengths.first;
|
||||
toEmit.second += holes[examinedHoles].lengths.second;
|
||||
}
|
||||
|
||||
/* Emit this segment */
|
||||
if (toEmit.first || toEmit.second) {
|
||||
std::pair<std::string_view, std::string_view> cutDataChannels = {
|
||||
std::string_view(dataChannels.first.data() + emitted.first, toEmit.first),
|
||||
std::string_view(dataChannels.second.data() + emitted.second, toEmit.second),
|
||||
};
|
||||
|
||||
/* We only need to test the first data channel for "FIN" */
|
||||
cb(cutDataChannels, emitted.first + toEmit.first + toIgnore.first == dataChannels.first.length());
|
||||
}
|
||||
|
||||
emitted.first += toEmit.first + toIgnore.first;
|
||||
emitted.second += toEmit.second + toIgnore.second;
|
||||
}
|
||||
|
||||
if (emitted.first == dataChannels.first.length() && emitted.second == dataChannels.second.length()) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::pair<std::string_view, std::string_view> cutDataChannels = {
|
||||
std::string_view(dataChannels.first.data() + emitted.first, dataChannels.first.length() - emitted.first),
|
||||
std::string_view(dataChannels.second.data() + emitted.second, dataChannels.second.length() - emitted.second),
|
||||
};
|
||||
|
||||
cb(cutDataChannels, true);
|
||||
}
|
||||
};
|
||||
|
||||
struct TopicTree {
|
||||
/* Returns Topic, or nullptr. Topic can be root if empty string given. */
|
||||
Topic *lookupTopic(std::string_view topic) {
|
||||
/* Lookup exact Topic ptr from string */
|
||||
Topic *iterator = root;
|
||||
for (size_t start = 0, stop = 0; stop != std::string::npos; start = stop + 1) {
|
||||
stop = topic.find('/', start);
|
||||
std::string_view segment = topic.substr(start, stop - start);
|
||||
|
||||
std::map<std::string_view, Topic *>::iterator it = iterator->children.find(segment);
|
||||
if (it == iterator->children.end()) {
|
||||
/* This topic does not even exist */
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
iterator = it->second;
|
||||
}
|
||||
|
||||
return iterator;
|
||||
}
|
||||
|
||||
private:
|
||||
std::function<int(Subscriber *, std::string_view)> cb;
|
||||
std::function<int(Subscriber *, Intersection &)> cb;
|
||||
|
||||
Topic *root = new Topic;
|
||||
|
||||
/* Global messageId for deduplication of overlapping topics and ordering between topics */
|
||||
unsigned int messageId = 0;
|
||||
|
||||
/* Sender holes */
|
||||
std::map<Subscriber *, std::vector<unsigned int>> senderHoles;
|
||||
|
||||
/* The triggered topics */
|
||||
Topic *triggeredTopics[64];
|
||||
int numTriggeredTopics = 0;
|
||||
Subscriber *min = (Subscriber *) UINTPTR_MAX;
|
||||
|
||||
|
||||
/* Cull or trim unused Topic nodes from leaf to root */
|
||||
void trimTree(Topic *topic) {
|
||||
if (!topic->subs.size() && !topic->children.size() && !topic->terminatingWildcardChild && !topic->wildcardChild) {
|
||||
while (!topic->subs.size() && !topic->children.size() && !topic->terminatingWildcardChild && !topic->wildcardChild) {
|
||||
Topic *parent = topic->parent;
|
||||
|
||||
if (topic->length == 1) {
|
||||
|
@ -112,43 +210,64 @@ private:
|
|||
delete [] topic->name;
|
||||
delete topic;
|
||||
|
||||
if (parent != root) {
|
||||
trimTree(parent);
|
||||
if (parent == root) {
|
||||
break;
|
||||
}
|
||||
|
||||
topic = parent;
|
||||
}
|
||||
}
|
||||
|
||||
/* Should be getData and commit? */
|
||||
void publish(Topic *iterator, size_t start, size_t stop, std::string_view topic, std::string_view message) {
|
||||
/* If we already have 64 triggered topics make sure to drain it here */
|
||||
if (numTriggeredTopics == 64) {
|
||||
drain();
|
||||
}
|
||||
/* Publishes to all matching topics and wildcards. Returns whether at least one topic was a match. */
|
||||
bool publish(Topic *iterator, size_t start, size_t stop, std::string_view topic, std::pair<std::string_view, std::string_view> message) {
|
||||
|
||||
/* Whether we matched with at least one topic */
|
||||
bool didMatch = false;
|
||||
|
||||
/* Iterate over all segments in given topic */
|
||||
for (; stop != std::string::npos; start = stop + 1) {
|
||||
stop = topic.find('/', start);
|
||||
std::string_view segment = topic.substr(start, stop - start);
|
||||
|
||||
/* It is very important to disallow wildcards when publishing.
|
||||
* We will not catch EVERY misuse this lazy way, but enough to hinder
|
||||
* explosive recursion.
|
||||
* Terminating wildcards MAY still get triggered along the way, if for
|
||||
* instace the error is found late while iterating the topic segments. */
|
||||
if (segment.length() == 1) {
|
||||
if (segment[0] == '+' || segment[0] == '#') {
|
||||
/* "Fail" here, but not necessarily for the entire publish */
|
||||
return didMatch;
|
||||
}
|
||||
}
|
||||
|
||||
/* Do we have a terminating wildcard child? */
|
||||
if (iterator->terminatingWildcardChild) {
|
||||
iterator->terminatingWildcardChild->messages[messageId] = message;
|
||||
|
||||
/* Add this topic to triggered */
|
||||
if (!iterator->terminatingWildcardChild->triggered) {
|
||||
/* If we already have 64 triggered topics make sure to drain it here */
|
||||
if (numTriggeredTopics == 64) {
|
||||
drain();
|
||||
}
|
||||
|
||||
triggeredTopics[numTriggeredTopics++] = iterator->terminatingWildcardChild;
|
||||
iterator->terminatingWildcardChild->triggered = true;
|
||||
}
|
||||
|
||||
didMatch = true;
|
||||
}
|
||||
|
||||
/* Do we have a wildcard child? */
|
||||
if (iterator->wildcardChild) {
|
||||
publish(iterator->wildcardChild, stop + 1, stop, topic, message);
|
||||
didMatch |= publish(iterator->wildcardChild, stop + 1, stop, topic, message);
|
||||
}
|
||||
|
||||
std::map<std::string_view, Topic *>::iterator it = iterator->children.find(segment);
|
||||
if (it == iterator->children.end()) {
|
||||
/* Stop trying to match by exact string */
|
||||
return;
|
||||
return didMatch;
|
||||
}
|
||||
|
||||
iterator = it->second;
|
||||
|
@ -159,14 +278,22 @@ private:
|
|||
|
||||
/* Add this topic to triggered */
|
||||
if (!iterator->triggered) {
|
||||
/* If we already have 64 triggered topics make sure to drain it here */
|
||||
if (numTriggeredTopics == 64) {
|
||||
drain();
|
||||
}
|
||||
|
||||
triggeredTopics[numTriggeredTopics++] = iterator;
|
||||
iterator->triggered = true;
|
||||
}
|
||||
|
||||
/* We obviously matches exactly here */
|
||||
return true;
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
TopicTree(std::function<int(Subscriber *, std::string_view)> cb) {
|
||||
TopicTree(std::function<int(Subscriber *, Intersection &)> cb) {
|
||||
this->cb = cb;
|
||||
}
|
||||
|
||||
|
@ -174,7 +301,20 @@ public:
|
|||
delete root;
|
||||
}
|
||||
|
||||
void subscribe(std::string_view topic, Subscriber *subscriber) {
|
||||
/* This is part of the fast path, so should be optimal */
|
||||
std::vector<unsigned int> &getSenderFor(Subscriber *s) {
|
||||
static thread_local std::vector<unsigned int> emptyVector;
|
||||
|
||||
auto it = senderHoles.find(s);
|
||||
if (it != senderHoles.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
return emptyVector;
|
||||
}
|
||||
|
||||
/* Returns number of subscribers after the call and whether or not we were successful in subscribing */
|
||||
std::pair<unsigned int, bool> subscribe(std::string_view topic, Subscriber *subscriber, bool nonStrict = false) {
|
||||
/* Start iterating from the root */
|
||||
Topic *iterator = root;
|
||||
|
||||
|
@ -196,7 +336,18 @@ public:
|
|||
newTopic->terminatingWildcardChild = nullptr;
|
||||
newTopic->wildcardChild = nullptr;
|
||||
memcpy(newTopic->name, segment.data(), segment.length());
|
||||
|
||||
|
||||
/* Set fullname as parent's name plus our name */
|
||||
newTopic->fullName.reserve(newTopic->parent->fullName.length() + 1 + segment.length());
|
||||
|
||||
/* Only append parent's name if parent is not root */
|
||||
if (newTopic->parent != root) {
|
||||
newTopic->fullName.append(newTopic->parent->fullName);
|
||||
newTopic->fullName.append("/");
|
||||
}
|
||||
|
||||
newTopic->fullName.append(segment);
|
||||
|
||||
/* For simplicity we do insert wildcards with text */
|
||||
iterator->children.insert(lb, {std::string_view(newTopic->name, segment.length()), newTopic});
|
||||
|
||||
|
@ -216,22 +367,38 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
/* If this topic is triggered, drain the tree before we join */
|
||||
if (iterator->triggered) {
|
||||
if (!nonStrict) {
|
||||
drain();
|
||||
}
|
||||
}
|
||||
|
||||
/* Add socket to Topic's Set */
|
||||
auto [it, inserted] = iterator->subs.insert(subscriber);
|
||||
|
||||
/* Add Topic to list of subscriptions only if we weren't already subscribed */
|
||||
if (inserted) {
|
||||
subscriber->subscriptions.push_back(iterator);
|
||||
return {(unsigned int) iterator->subs.size(), true};
|
||||
}
|
||||
return {(unsigned int) iterator->subs.size(), false};
|
||||
}
|
||||
|
||||
void publish(std::string_view topic, std::string_view message) {
|
||||
publish(root, 0, 0, topic, message);
|
||||
bool publish(std::string_view topic, std::pair<std::string_view, std::string_view> message, Subscriber *sender = nullptr) {
|
||||
/* Add a hole for the sender if one */
|
||||
if (sender) {
|
||||
senderHoles[sender].push_back(messageId);
|
||||
}
|
||||
|
||||
auto ret = publish(root, 0, 0, topic, message);
|
||||
/* MessageIDs are reset on drain - this should be fine since messages itself are cleared on drain */
|
||||
messageId++;
|
||||
return ret;
|
||||
}
|
||||
|
||||
/* Returns whether we were subscribed prior */
|
||||
bool unsubscribe(std::string_view topic, Subscriber *subscriber) {
|
||||
/* Returns a pair of numSubscribers after operation, and whether we were subscribed prior */
|
||||
std::pair<unsigned int, bool> unsubscribe(std::string_view topic, Subscriber *subscriber, bool nonStrict = false) {
|
||||
/* Subscribers are likely to have very few subscriptions (20 or fewer) */
|
||||
if (subscriber) {
|
||||
/* Lookup exact Topic ptr from string */
|
||||
|
@ -243,32 +410,55 @@ public:
|
|||
std::map<std::string_view, Topic *>::iterator it = iterator->children.find(segment);
|
||||
if (it == iterator->children.end()) {
|
||||
/* This topic does not even exist */
|
||||
return false;
|
||||
return {0, false};
|
||||
}
|
||||
|
||||
iterator = it->second;
|
||||
}
|
||||
|
||||
/* Is this topic locked? If so, we cannot unsubscribe from it */
|
||||
if (iterator->locked) {
|
||||
return {iterator->subs.size(), false};
|
||||
}
|
||||
|
||||
/* Try and remove this topic from our list */
|
||||
for (auto it = subscriber->subscriptions.begin(); it != subscriber->subscriptions.end(); it++) {
|
||||
if (*it == iterator) {
|
||||
/* If this topic is triggered, drain the tree before we leave */
|
||||
if (iterator->triggered) {
|
||||
if (!nonStrict) {
|
||||
drain();
|
||||
}
|
||||
}
|
||||
|
||||
/* Remove topic ptr from our list */
|
||||
subscriber->subscriptions.erase(it);
|
||||
|
||||
/* Remove us from Topic's subs */
|
||||
iterator->subs.erase(subscriber);
|
||||
unsigned int numSubscribers = (unsigned int) iterator->subs.size();
|
||||
trimTree(iterator);
|
||||
return true;
|
||||
return {numSubscribers, true};
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
return {0, false};
|
||||
}
|
||||
|
||||
/* Can be called with nullptr, ignore it then */
|
||||
void unsubscribeAll(Subscriber *subscriber) {
|
||||
void unsubscribeAll(Subscriber *subscriber, bool mayFlush = true) {
|
||||
if (subscriber) {
|
||||
for (Topic *topic : subscriber->subscriptions) {
|
||||
|
||||
/* We do not want to flush when closing a socket, it makes no sense to do so */
|
||||
|
||||
/* If this topic is triggered, drain the tree before we leave */
|
||||
if (mayFlush && topic->triggered) {
|
||||
/* Never mind nonStrict here (yet?) */
|
||||
drain();
|
||||
}
|
||||
|
||||
/* Remove us from the topic's set */
|
||||
topic->subs.erase(subscriber);
|
||||
trimTree(topic);
|
||||
}
|
||||
|
@ -290,11 +480,18 @@ public:
|
|||
for (int i = 0; i < numTriggeredTopics; i++) {
|
||||
if (triggeredTopics[i]->subs.size()) {
|
||||
triggeredTopics[numFilteredTriggeredTopics++] = triggeredTopics[i];
|
||||
} else {
|
||||
/* If we no longer have any subscribers, yet still keep this Topic alive (parent),
|
||||
* make sure to clear its potential messages. */
|
||||
triggeredTopics[i]->messages.clear();
|
||||
triggeredTopics[i]->triggered = false;
|
||||
}
|
||||
}
|
||||
numTriggeredTopics = numFilteredTriggeredTopics;
|
||||
|
||||
if (!numTriggeredTopics) {
|
||||
senderHoles.clear();
|
||||
messageId = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -310,7 +507,7 @@ public:
|
|||
if (min != (Subscriber *)UINTPTR_MAX) {
|
||||
|
||||
/* Up to 64 triggered Topics per batch */
|
||||
std::map<uint64_t, std::string> intersectionCache;
|
||||
std::map<uint64_t, Intersection> intersectionCache;
|
||||
|
||||
/* Loop over these here */
|
||||
std::set<Subscriber *>::iterator it[64];
|
||||
|
@ -319,14 +516,14 @@ public:
|
|||
it[i] = triggeredTopics[i]->subs.begin();
|
||||
end[i] = triggeredTopics[i]->subs.end();
|
||||
}
|
||||
|
||||
|
||||
/* Empty all sets from unique subscribers */
|
||||
for (int nonEmpty = numTriggeredTopics; nonEmpty; ) {
|
||||
|
||||
Subscriber *nextMin = (Subscriber *)UINTPTR_MAX;
|
||||
|
||||
/* The message sets relevant for this intersection */
|
||||
std::map<unsigned int, std::string> *perSubscriberIntersectingTopicMessages[64];
|
||||
std::map<unsigned int, std::pair<std::string, std::string>> *perSubscriberIntersectingTopicMessages[64];
|
||||
int numPerSubscriberIntersectingTopicMessages = 0;
|
||||
|
||||
uint64_t intersection = 0;
|
||||
|
@ -357,18 +554,28 @@ public:
|
|||
}
|
||||
|
||||
/* Generate cache for intersection */
|
||||
if (intersectionCache[intersection].length() == 0) {
|
||||
if (intersectionCache[intersection].dataChannels.first.length() == 0) {
|
||||
|
||||
/* Build the union in order without duplicates */
|
||||
std::map<unsigned int, std::string> complete;
|
||||
std::map<unsigned int, std::pair<std::string, std::string>> complete;
|
||||
for (int i = 0; i < numPerSubscriberIntersectingTopicMessages; i++) {
|
||||
complete.insert(perSubscriberIntersectingTopicMessages[i]->begin(), perSubscriberIntersectingTopicMessages[i]->end());
|
||||
}
|
||||
|
||||
/* Create the linear cache */
|
||||
std::string res;
|
||||
/* Create the linear cache, {inflated, deflated} */
|
||||
Intersection res;
|
||||
for (auto &p : complete) {
|
||||
res.append(p.second);
|
||||
res.dataChannels.first.append(p.second.first);
|
||||
res.dataChannels.second.append(p.second.second);
|
||||
|
||||
/* Appends {id, length, length}
|
||||
* We could possibly append byte offset also,
|
||||
* if we want to use log2 search later. */
|
||||
Hole h;
|
||||
h.lengths.first = p.second.first.length();
|
||||
h.lengths.second = p.second.second.length();
|
||||
h.messageId = p.first;
|
||||
res.holes.push_back(h);
|
||||
}
|
||||
|
||||
cb(min, intersectionCache[intersection] = std::move(res));
|
||||
|
@ -379,7 +586,6 @@ public:
|
|||
|
||||
min = nextMin;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/* Clear messages of triggered Topics */
|
||||
|
@ -388,27 +594,8 @@ public:
|
|||
triggeredTopics[i]->triggered = false;
|
||||
}
|
||||
numTriggeredTopics = 0;
|
||||
}
|
||||
|
||||
void print(Topic *root = nullptr, int indentation = 1) {
|
||||
if (root == nullptr) {
|
||||
std::cout << "Print of tree:" << std::endl;
|
||||
root = this->root;
|
||||
}
|
||||
|
||||
for (auto p : root->children) {
|
||||
for (int i = 0; i < indentation; i++) {
|
||||
std::cout << " ";
|
||||
}
|
||||
std::cout << std::string_view(p.second->name, p.second->length) << " = " << p.second->messages.size() << " publishes, " << p.second->subs.size() << " subscribers {";
|
||||
|
||||
for (auto &p : p.second->subs) {
|
||||
std::cout << p << " referring to socket: " << p->user << ", ";
|
||||
}
|
||||
std::cout << "}" << std::endl;
|
||||
|
||||
print(p.second, indentation + 1);
|
||||
}
|
||||
senderHoles.clear();
|
||||
messageId = 0;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2019.
|
||||
* Authored by Alex Hultman, 2018-2020.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -43,8 +43,8 @@ inline int u32toaHex(uint32_t value, char *dst) {
|
|||
return ret;
|
||||
}
|
||||
|
||||
inline int u32toa(uint32_t value, char *dst) {
|
||||
char temp[10];
|
||||
inline int u64toa(uint64_t value, char *dst) {
|
||||
char temp[20];
|
||||
char *p = temp;
|
||||
do {
|
||||
*p++ = (char) ((value % 10) + '0');
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2019.
|
||||
* Authored by Alex Hultman, 2018-2021.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -27,40 +27,63 @@
|
|||
|
||||
namespace uWS {
|
||||
|
||||
template <bool SSL, bool isServer>
|
||||
template <bool SSL, bool isServer, typename USERDATA>
|
||||
struct WebSocket : AsyncSocket<SSL> {
|
||||
template <bool> friend struct TemplatedApp;
|
||||
template <bool> friend struct HttpResponse;
|
||||
private:
|
||||
typedef AsyncSocket<SSL> Super;
|
||||
|
||||
void *init(bool perMessageDeflate, bool slidingCompression, std::string &&backpressure) {
|
||||
new (us_socket_ext(SSL, (us_socket_t *) this)) WebSocketData(perMessageDeflate, slidingCompression, std::move(backpressure));
|
||||
void *init(bool perMessageDeflate, CompressOptions compressOptions, std::string &&backpressure) {
|
||||
new (us_socket_ext(SSL, (us_socket_t *) this)) WebSocketData(perMessageDeflate, compressOptions, std::move(backpressure));
|
||||
return this;
|
||||
}
|
||||
public:
|
||||
|
||||
/* Returns pointer to the per socket user data */
|
||||
void *getUserData() {
|
||||
USERDATA *getUserData() {
|
||||
WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this);
|
||||
/* We just have it overallocated by sizeof type */
|
||||
return (webSocketData + 1);
|
||||
return (USERDATA *) (webSocketData + 1);
|
||||
}
|
||||
|
||||
/* See AsyncSocket */
|
||||
using Super::getBufferedAmount;
|
||||
using Super::getRemoteAddress;
|
||||
using Super::getRemoteAddressAsText;
|
||||
using Super::getNativeHandle;
|
||||
|
||||
/* Simple, immediate close of the socket. Emits close event */
|
||||
using Super::close;
|
||||
|
||||
/* Send or buffer a WebSocket frame, compressed or not. Returns false on increased user space backpressure. */
|
||||
bool send(std::string_view message, uWS::OpCode opCode = uWS::OpCode::BINARY, bool compress = false) {
|
||||
enum SendStatus : int {
|
||||
BACKPRESSURE,
|
||||
SUCCESS,
|
||||
DROPPED
|
||||
};
|
||||
|
||||
/* Send or buffer a WebSocket frame, compressed or not. Returns BACKPRESSURE on increased user space backpressure,
|
||||
* DROPPED on dropped message (due to backpressure) or SUCCCESS if you are free to send even more now. */
|
||||
SendStatus send(std::string_view message, OpCode opCode = OpCode::BINARY, bool compress = false) {
|
||||
WebSocketContextData<SSL, USERDATA> *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL,
|
||||
(us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this)
|
||||
);
|
||||
|
||||
/* Skip sending and report success if we are over the limit of maxBackpressure */
|
||||
if (webSocketContextData->maxBackpressure && webSocketContextData->maxBackpressure < getBufferedAmount()) {
|
||||
/* Also defer a close if we should */
|
||||
if (webSocketContextData->closeOnBackpressureLimit) {
|
||||
us_socket_shutdown_read(SSL, (us_socket_t *) this);
|
||||
}
|
||||
return DROPPED;
|
||||
}
|
||||
|
||||
/* Transform the message to compressed domain if requested */
|
||||
if (compress) {
|
||||
WebSocketData *webSocketData = (WebSocketData *) Super::getAsyncSocketData();
|
||||
|
||||
/* Check and correct the compress hint */
|
||||
if (opCode < 3 && webSocketData->compressionStatus == WebSocketData::ENABLED) {
|
||||
/* Check and correct the compress hint. It is never valid to compress 0 bytes */
|
||||
if (message.length() && opCode < 3 && webSocketData->compressionStatus == WebSocketData::ENABLED) {
|
||||
LoopData *loopData = Super::getLoopData();
|
||||
/* Compress using either shared or dedicated deflationStream */
|
||||
if (webSocketData->deflationStream) {
|
||||
|
@ -82,7 +105,7 @@ public:
|
|||
|
||||
/* Get size, alloate size, write if needed */
|
||||
size_t messageFrameSize = protocol::messageFrameSize(message.length());
|
||||
auto[sendBuffer, requiresWrite] = Super::getSendBuffer(messageFrameSize);
|
||||
auto [sendBuffer, requiresWrite] = Super::getSendBuffer(messageFrameSize);
|
||||
protocol::formatMessage<isServer>(sendBuffer, message.data(), message.length(), opCode, message.length(), compress);
|
||||
/* This is the slow path, when we couldn't cork for the user */
|
||||
if (requiresWrite) {
|
||||
|
@ -93,7 +116,7 @@ public:
|
|||
|
||||
if (failed) {
|
||||
/* Return false for failure, skipping to reset the timeout below */
|
||||
return false;
|
||||
return BACKPRESSURE;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -101,22 +124,24 @@ public:
|
|||
if (automaticallyCorked) {
|
||||
auto [written, failed] = Super::uncork();
|
||||
if (failed) {
|
||||
return false;
|
||||
return BACKPRESSURE;
|
||||
}
|
||||
}
|
||||
|
||||
/* Every successful send resets the timeout */
|
||||
WebSocketContextData<SSL> *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL,
|
||||
(us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this)
|
||||
);
|
||||
AsyncSocket<SSL>::timeout(webSocketContextData->idleTimeout);
|
||||
if (webSocketContextData->resetIdleTimeoutOnSend) {
|
||||
Super::timeout(webSocketContextData->idleTimeoutComponents.first);
|
||||
WebSocketData *webSocketData = (WebSocketData *) Super::getAsyncSocketData();
|
||||
webSocketData->hasTimedOut = false;
|
||||
}
|
||||
|
||||
/* Return success */
|
||||
return true;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
/* Send websocket close frame, emit close event, send FIN if successful */
|
||||
void end(int code, std::string_view message = {}) {
|
||||
/* Send websocket close frame, emit close event, send FIN if successful.
|
||||
* Will not append a close reason if code is 0 or 1005. */
|
||||
void end(int code = 0, std::string_view message = {}) {
|
||||
/* Check if we already called this one */
|
||||
WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this);
|
||||
if (webSocketData->isShuttingDown) {
|
||||
|
@ -128,23 +153,22 @@ public:
|
|||
|
||||
/* Format and send the close frame */
|
||||
static const int MAX_CLOSE_PAYLOAD = 123;
|
||||
int length = (int) std::min<size_t>(MAX_CLOSE_PAYLOAD, message.length());
|
||||
size_t length = std::min<size_t>(MAX_CLOSE_PAYLOAD, message.length());
|
||||
char closePayload[MAX_CLOSE_PAYLOAD + 2];
|
||||
int closePayloadLength = (int) protocol::formatClosePayload(closePayload, (uint16_t) code, message.data(), length);
|
||||
size_t closePayloadLength = protocol::formatClosePayload(closePayload, (uint16_t) code, message.data(), length);
|
||||
bool ok = send(std::string_view(closePayload, closePayloadLength), OpCode::CLOSE);
|
||||
|
||||
/* FIN if we are ok and not corked */
|
||||
WebSocket<SSL, true> *webSocket = (WebSocket<SSL, true> *) this;
|
||||
if (!webSocket->isCorked()) {
|
||||
if (!this->isCorked()) {
|
||||
if (ok) {
|
||||
/* If we are not corked, and we just sent off everything, we need to FIN right here.
|
||||
* In all other cases, we need to fin either if uncork was successful, or when drainage is complete. */
|
||||
webSocket->shutdown();
|
||||
this->shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
/* Emit close event */
|
||||
WebSocketContextData<SSL> *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL,
|
||||
WebSocketContextData<SSL, USERDATA> *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL,
|
||||
(us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this)
|
||||
);
|
||||
if (webSocketContextData->closeHandler) {
|
||||
|
@ -152,13 +176,13 @@ public:
|
|||
}
|
||||
|
||||
/* Make sure to unsubscribe from any pub/sub node at exit */
|
||||
webSocketContextData->topicTree.unsubscribeAll(webSocketData->subscriber);
|
||||
webSocketContextData->topicTree.unsubscribeAll(webSocketData->subscriber, false);
|
||||
delete webSocketData->subscriber;
|
||||
webSocketData->subscriber = nullptr;
|
||||
}
|
||||
|
||||
/* Corks the response if possible. Leaves already corked socket be. */
|
||||
void cork(fu2::unique_function<void()> &&handler) {
|
||||
void cork(MoveOnlyFunction<void()> &&handler) {
|
||||
if (!Super::isCorked() && Super::canCork()) {
|
||||
Super::cork();
|
||||
handler();
|
||||
|
@ -172,9 +196,9 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
/* Subscribe to a topic according to MQTT rules and syntax */
|
||||
void subscribe(std::string_view topic) {
|
||||
WebSocketContextData<SSL> *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL,
|
||||
/* Subscribe to a topic according to MQTT rules and syntax. Returns success */
|
||||
/*std::pair<unsigned int, bool>*/ bool subscribe(std::string_view topic, bool nonStrict = false) {
|
||||
WebSocketContextData<SSL, USERDATA> *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL,
|
||||
(us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this)
|
||||
);
|
||||
|
||||
|
@ -184,38 +208,87 @@ public:
|
|||
webSocketData->subscriber = new Subscriber(this);
|
||||
}
|
||||
|
||||
webSocketContextData->topicTree.subscribe(topic, webSocketData->subscriber);
|
||||
/* Cannot return numSubscribers as this is only for this particular websocket context */
|
||||
return webSocketContextData->topicTree.subscribe(topic, webSocketData->subscriber, nonStrict).second;
|
||||
}
|
||||
|
||||
/* Unsubscribe from a topic, returns true if we were subscribed */
|
||||
bool unsubscribe(std::string_view topic) {
|
||||
WebSocketContextData<SSL> *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL,
|
||||
/* Unsubscribe from a topic, returns true if we were subscribed. */
|
||||
/*std::pair<unsigned int, bool>*/ bool unsubscribe(std::string_view topic, bool nonStrict = false) {
|
||||
WebSocketContextData<SSL, USERDATA> *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL,
|
||||
(us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this)
|
||||
);
|
||||
|
||||
WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this);
|
||||
|
||||
return webSocketContextData->topicTree.unsubscribe(topic, webSocketData->subscriber);
|
||||
/* Cannot return numSubscribers as this is only for this particular websocket context */
|
||||
return webSocketContextData->topicTree.unsubscribe(topic, webSocketData->subscriber, nonStrict).second;
|
||||
}
|
||||
|
||||
/* Unsubscribe from all topics you might be subscribed to */
|
||||
void unsubscribeAll() {
|
||||
WebSocketContextData<SSL> *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL,
|
||||
/* Returns whether this socket is subscribed to the specified topic */
|
||||
bool isSubscribed(std::string_view topic) {
|
||||
WebSocketContextData<SSL, USERDATA> *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL,
|
||||
(us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this)
|
||||
);
|
||||
|
||||
WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this);
|
||||
if (!webSocketData->subscriber) {
|
||||
return false;
|
||||
}
|
||||
|
||||
webSocketContextData->topicTree.unsubscribeAll(webSocketData->subscriber);
|
||||
Topic *t = webSocketContextData->topicTree.lookupTopic(topic);
|
||||
if (t) {
|
||||
return t->subs.find(webSocketData->subscriber) != t->subs.end();
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/* Publish a message to a topic according to MQTT rules and syntax */
|
||||
void publish(std::string_view topic, std::string_view message, OpCode opCode = OpCode::TEXT, bool compress = false) {
|
||||
WebSocketContextData<SSL> *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL,
|
||||
/* Iterates all topics of this WebSocket. Every topic is represented by its full name.
|
||||
* Can be called in close handler. It is possible to modify the subscription list while
|
||||
* inside the callback ONLY IF not modifying the topic passed to the callback.
|
||||
* Topic names are valid only for the duration of the callback. */
|
||||
void iterateTopics(MoveOnlyFunction<void(std::string_view/*, unsigned int*/)> cb) {
|
||||
WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this);
|
||||
|
||||
if (webSocketData->subscriber) {
|
||||
for (Topic *t : webSocketData->subscriber->subscriptions) {
|
||||
/* Lock this topic so that nobody may unsubscribe from it during this callback */
|
||||
t->locked = true;
|
||||
|
||||
cb(t->fullName/*, (unsigned int) t->subs.size()*/);
|
||||
|
||||
t->locked = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Publish a message to a topic according to MQTT rules and syntax. Returns success.
|
||||
* We, the WebSocket, must be subscribed to the topic itself and if so - no message will be sent to ourselves.
|
||||
* Use App::publish for an unconditional publish that simply publishes to whomever might be subscribed. */
|
||||
bool publish(std::string_view topic, std::string_view message, OpCode opCode = OpCode::TEXT, bool compress = false) {
|
||||
WebSocketContextData<SSL, USERDATA> *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL,
|
||||
(us_socket_context_t *) us_socket_context(SSL, (us_socket_t *) this)
|
||||
);
|
||||
/* Is the same as publishing per websocket context */
|
||||
webSocketContextData->publish(topic, message, opCode, compress);
|
||||
|
||||
/* We cannot be a subscriber of this topic if we are not a subscriber of anything */
|
||||
WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this);
|
||||
if (!webSocketData->subscriber) {
|
||||
/* Failure, but still do return the number of subscribers */
|
||||
return false;
|
||||
}
|
||||
|
||||
/* Publish as sender, does not receive its own messages even if subscribed to relevant topics */
|
||||
bool success = webSocketContextData->publish(topic, message, opCode, compress, webSocketData->subscriber);
|
||||
|
||||
/* Loop over all websocket contexts for this App */
|
||||
if (success) {
|
||||
/* Success is really only determined by the first publish. We must be subscribed to the topic. */
|
||||
for (auto *adjacentWebSocketContextData : webSocketContextData->adjacentWebSocketContextDatas) {
|
||||
adjacentWebSocketContextData->publish(topic, message, opCode, compress);
|
||||
}
|
||||
}
|
||||
|
||||
return success;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2019.
|
||||
* Authored by Alex Hultman, 2018-2020.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -25,7 +25,7 @@
|
|||
|
||||
namespace uWS {
|
||||
|
||||
template <bool SSL, bool isServer>
|
||||
template <bool SSL, bool isServer, typename USERDATA>
|
||||
struct WebSocketContext {
|
||||
template <bool> friend struct TemplatedApp;
|
||||
template <bool, typename> friend struct WebSocketProtocol;
|
||||
|
@ -36,12 +36,12 @@ private:
|
|||
return (us_socket_context_t *) this;
|
||||
}
|
||||
|
||||
WebSocketContextData<SSL> *getExt() {
|
||||
return (WebSocketContextData<SSL> *) us_socket_context_ext(SSL, (us_socket_context_t *) this);
|
||||
WebSocketContextData<SSL, USERDATA> *getExt() {
|
||||
return (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, (us_socket_context_t *) this);
|
||||
}
|
||||
|
||||
/* If we have negotiated compression, set this frame compressed */
|
||||
static bool setCompressed(uWS::WebSocketState<isServer> *wState, void *s) {
|
||||
static bool setCompressed(WebSocketState<isServer> */*wState*/, void *s) {
|
||||
WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) s);
|
||||
|
||||
if (webSocketData->compressionStatus == WebSocketData::CompressionStatus::ENABLED) {
|
||||
|
@ -52,14 +52,14 @@ private:
|
|||
}
|
||||
}
|
||||
|
||||
static void forceClose(uWS::WebSocketState<isServer> *wState, void *s) {
|
||||
us_socket_close(SSL, (us_socket_t *) s);
|
||||
static void forceClose(WebSocketState<isServer> */*wState*/, void *s, std::string_view reason = {}) {
|
||||
us_socket_close(SSL, (us_socket_t *) s, (int) reason.length(), (void *) reason.data());
|
||||
}
|
||||
|
||||
/* Returns true on breakage */
|
||||
static bool handleFragment(char *data, size_t length, unsigned int remainingBytes, int opCode, bool fin, uWS::WebSocketState<isServer> *webSocketState, void *s) {
|
||||
static bool handleFragment(char *data, size_t length, unsigned int remainingBytes, int opCode, bool fin, WebSocketState<isServer> *webSocketState, void *s) {
|
||||
/* WebSocketData and WebSocketContextData */
|
||||
WebSocketContextData<SSL> *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
|
||||
WebSocketContextData<SSL, USERDATA> *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
|
||||
WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) s);
|
||||
|
||||
/* Is this a non-control frame? */
|
||||
|
@ -72,25 +72,25 @@ private:
|
|||
webSocketData->compressionStatus = WebSocketData::CompressionStatus::ENABLED;
|
||||
|
||||
LoopData *loopData = (LoopData *) us_loop_ext(us_socket_context_loop(SSL, us_socket_context(SSL, (us_socket_t *) s)));
|
||||
std::string_view inflatedFrame = loopData->inflationStream->inflate(loopData->zlibContext, {data, length}, webSocketContextData->maxPayloadLength);
|
||||
if (!inflatedFrame.length()) {
|
||||
forceClose(webSocketState, s);
|
||||
auto inflatedFrame = loopData->inflationStream->inflate(loopData->zlibContext, {data, length}, webSocketContextData->maxPayloadLength);
|
||||
if (!inflatedFrame.has_value()) {
|
||||
forceClose(webSocketState, s, ERR_TOO_BIG_MESSAGE_INFLATION);
|
||||
return true;
|
||||
} else {
|
||||
data = (char *) inflatedFrame.data();
|
||||
length = inflatedFrame.length();
|
||||
data = (char *) inflatedFrame->data();
|
||||
length = inflatedFrame->length();
|
||||
}
|
||||
}
|
||||
|
||||
/* Check text messages for Utf-8 validity */
|
||||
if (opCode == 1 && !protocol::isValidUtf8((unsigned char *) data, length)) {
|
||||
forceClose(webSocketState, s);
|
||||
forceClose(webSocketState, s, ERR_INVALID_TEXT);
|
||||
return true;
|
||||
}
|
||||
|
||||
/* Emit message event & break if we are closed or shut down when returning */
|
||||
if (webSocketContextData->messageHandler) {
|
||||
webSocketContextData->messageHandler((WebSocket<SSL, isServer> *) s, std::string_view(data, length), (uWS::OpCode) opCode);
|
||||
webSocketContextData->messageHandler((WebSocket<SSL, isServer, USERDATA> *) s, std::string_view(data, length), (OpCode) opCode);
|
||||
if (us_socket_is_closed(SSL, (us_socket_t *) s) || webSocketData->isShuttingDown) {
|
||||
return true;
|
||||
}
|
||||
|
@ -102,7 +102,7 @@ private:
|
|||
}
|
||||
/* Fragments forming a big message are not caught until appending them */
|
||||
if (refusePayloadLength(length + webSocketData->fragmentBuffer.length(), webSocketState, s)) {
|
||||
forceClose(webSocketState, s);
|
||||
forceClose(webSocketState, s, ERR_TOO_BIG_MESSAGE);
|
||||
return true;
|
||||
}
|
||||
webSocketData->fragmentBuffer.append(data, length);
|
||||
|
@ -115,8 +115,8 @@ private:
|
|||
if (webSocketData->compressionStatus == WebSocketData::CompressionStatus::COMPRESSED_FRAME) {
|
||||
webSocketData->compressionStatus = WebSocketData::CompressionStatus::ENABLED;
|
||||
|
||||
// what's really the story here?
|
||||
webSocketData->fragmentBuffer.append("....");
|
||||
/* 9 bytes of padding for libdeflate */
|
||||
webSocketData->fragmentBuffer.append("123456789");
|
||||
|
||||
LoopData *loopData = (LoopData *) us_loop_ext(
|
||||
us_socket_context_loop(SSL,
|
||||
|
@ -124,13 +124,13 @@ private:
|
|||
)
|
||||
);
|
||||
|
||||
std::string_view inflatedFrame = loopData->inflationStream->inflate(loopData->zlibContext, {webSocketData->fragmentBuffer.data(), webSocketData->fragmentBuffer.length() - 4}, webSocketContextData->maxPayloadLength);
|
||||
if (!inflatedFrame.length()) {
|
||||
forceClose(webSocketState, s);
|
||||
auto inflatedFrame = loopData->inflationStream->inflate(loopData->zlibContext, {webSocketData->fragmentBuffer.data(), webSocketData->fragmentBuffer.length() - 9}, webSocketContextData->maxPayloadLength);
|
||||
if (!inflatedFrame.has_value()) {
|
||||
forceClose(webSocketState, s, ERR_TOO_BIG_MESSAGE_INFLATION);
|
||||
return true;
|
||||
} else {
|
||||
data = (char *) inflatedFrame.data();
|
||||
length = inflatedFrame.length();
|
||||
data = (char *) inflatedFrame->data();
|
||||
length = inflatedFrame->length();
|
||||
}
|
||||
|
||||
|
||||
|
@ -142,13 +142,13 @@ private:
|
|||
|
||||
/* Check text messages for Utf-8 validity */
|
||||
if (opCode == 1 && !protocol::isValidUtf8((unsigned char *) data, length)) {
|
||||
forceClose(webSocketState, s);
|
||||
forceClose(webSocketState, s, ERR_INVALID_TEXT);
|
||||
return true;
|
||||
}
|
||||
|
||||
/* Emit message and check for shutdown or close */
|
||||
if (webSocketContextData->messageHandler) {
|
||||
webSocketContextData->messageHandler((WebSocket<SSL, isServer> *) s, std::string_view(data, length), (uWS::OpCode) opCode);
|
||||
webSocketContextData->messageHandler((WebSocket<SSL, isServer, USERDATA> *) s, std::string_view(data, length), (OpCode) opCode);
|
||||
if (us_socket_is_closed(SSL, (us_socket_t *) s) || webSocketData->isShuttingDown) {
|
||||
return true;
|
||||
}
|
||||
|
@ -160,7 +160,7 @@ private:
|
|||
}
|
||||
} else {
|
||||
/* Control frames need the websocket to send pings, pongs and close */
|
||||
WebSocket<SSL, isServer> *webSocket = (WebSocket<SSL, isServer> *) s;
|
||||
WebSocket<SSL, isServer, USERDATA> *webSocket = (WebSocket<SSL, isServer, USERDATA> *) s;
|
||||
|
||||
if (!remainingBytes && fin && !webSocketData->controlTipLength) {
|
||||
if (opCode == CLOSE) {
|
||||
|
@ -171,14 +171,14 @@ private:
|
|||
if (opCode == PING) {
|
||||
webSocket->send(std::string_view(data, length), (OpCode) OpCode::PONG);
|
||||
if (webSocketContextData->pingHandler) {
|
||||
webSocketContextData->pingHandler(webSocket);
|
||||
webSocketContextData->pingHandler(webSocket, {data, length});
|
||||
if (us_socket_is_closed(SSL, (us_socket_t *) s) || webSocketData->isShuttingDown) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
} else if (opCode == PONG) {
|
||||
if (webSocketContextData->pongHandler) {
|
||||
webSocketContextData->pongHandler(webSocket);
|
||||
webSocketContextData->pongHandler(webSocket, {data, length});
|
||||
if (us_socket_is_closed(SSL, (us_socket_t *) s) || webSocketData->isShuttingDown) {
|
||||
return true;
|
||||
}
|
||||
|
@ -188,7 +188,7 @@ private:
|
|||
} else {
|
||||
/* Here we never mind any size optimizations as we are in the worst possible path */
|
||||
webSocketData->fragmentBuffer.append(data, length);
|
||||
webSocketData->controlTipLength += (int) length;
|
||||
webSocketData->controlTipLength += (unsigned int) length;
|
||||
|
||||
if (!remainingBytes && fin) {
|
||||
char *controlBuffer = (char *) webSocketData->fragmentBuffer.data() + webSocketData->fragmentBuffer.length() - webSocketData->controlTipLength;
|
||||
|
@ -200,14 +200,14 @@ private:
|
|||
if (opCode == PING) {
|
||||
webSocket->send(std::string_view(controlBuffer, webSocketData->controlTipLength), (OpCode) OpCode::PONG);
|
||||
if (webSocketContextData->pingHandler) {
|
||||
webSocketContextData->pingHandler(webSocket);
|
||||
webSocketContextData->pingHandler(webSocket, std::string_view(controlBuffer, webSocketData->controlTipLength));
|
||||
if (us_socket_is_closed(SSL, (us_socket_t *) s) || webSocketData->isShuttingDown) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
} else if (opCode == PONG) {
|
||||
if (webSocketContextData->pongHandler) {
|
||||
webSocketContextData->pongHandler(webSocket);
|
||||
webSocketContextData->pongHandler(webSocket, std::string_view(controlBuffer, webSocketData->controlTipLength));
|
||||
if (us_socket_is_closed(SSL, (us_socket_t *) s) || webSocketData->isShuttingDown) {
|
||||
return true;
|
||||
}
|
||||
|
@ -224,32 +224,32 @@ private:
|
|||
return false;
|
||||
}
|
||||
|
||||
static bool refusePayloadLength(uint64_t length, uWS::WebSocketState<isServer> *wState, void *s) {
|
||||
auto *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
|
||||
static bool refusePayloadLength(uint64_t length, WebSocketState<isServer> */*wState*/, void *s) {
|
||||
auto *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
|
||||
|
||||
/* Return true for refuse, false for accept */
|
||||
return webSocketContextData->maxPayloadLength < length;
|
||||
}
|
||||
|
||||
WebSocketContext<SSL, isServer> *init() {
|
||||
WebSocketContext<SSL, isServer, USERDATA> *init() {
|
||||
/* Adopting a socket does not trigger open event.
|
||||
* We arreive as WebSocket with timeout set and
|
||||
* any backpressure from HTTP state kept. */
|
||||
|
||||
/* Handle socket disconnections */
|
||||
us_socket_context_on_close(SSL, getSocketContext(), [](auto *s) {
|
||||
us_socket_context_on_close(SSL, getSocketContext(), [](auto *s, int code, void *reason) {
|
||||
/* For whatever reason, if we already have emitted close event, do not emit it again */
|
||||
WebSocketData *webSocketData = (WebSocketData *) (us_socket_ext(SSL, s));
|
||||
if (!webSocketData->isShuttingDown) {
|
||||
/* Emit close event */
|
||||
auto *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
|
||||
auto *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
|
||||
|
||||
if (webSocketContextData->closeHandler) {
|
||||
webSocketContextData->closeHandler((WebSocket<SSL, true> *) s, 1006, {});
|
||||
webSocketContextData->closeHandler((WebSocket<SSL, isServer, USERDATA> *) s, 1006, {(char *) reason, (size_t) code});
|
||||
}
|
||||
|
||||
/* Make sure to unsubscribe from any pub/sub node at exit */
|
||||
webSocketContextData->topicTree.unsubscribeAll(webSocketData->subscriber);
|
||||
webSocketContextData->topicTree.unsubscribeAll(webSocketData->subscriber, false);
|
||||
delete webSocketData->subscriber;
|
||||
webSocketData->subscriber = nullptr;
|
||||
}
|
||||
|
@ -272,17 +272,18 @@ private:
|
|||
return s;
|
||||
}
|
||||
|
||||
auto *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
|
||||
auto *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
|
||||
auto *asyncSocket = (AsyncSocket<SSL> *) s;
|
||||
|
||||
/* Every time we get data and not in shutdown state we simply reset the timeout */
|
||||
asyncSocket->timeout(webSocketContextData->idleTimeout);
|
||||
asyncSocket->timeout(webSocketContextData->idleTimeoutComponents.first);
|
||||
webSocketData->hasTimedOut = false;
|
||||
|
||||
/* We always cork on data */
|
||||
asyncSocket->cork();
|
||||
|
||||
/* This parser has virtually no overhead */
|
||||
uWS::WebSocketProtocol<isServer, WebSocketContext<SSL, isServer>>::consume(data, length, (WebSocketState<isServer> *) webSocketData, s);
|
||||
WebSocketProtocol<isServer, WebSocketContext<SSL, isServer, USERDATA>>::consume(data, (unsigned int) length, (WebSocketState<isServer> *) webSocketData, s);
|
||||
|
||||
/* Uncorking a closed socekt is fine, in fact it is needed */
|
||||
asyncSocket->uncork();
|
||||
|
@ -302,6 +303,10 @@ private:
|
|||
/* Handle HTTP write out (note: SSL_read may trigger this spuriously, the app need to handle spurious calls) */
|
||||
us_socket_context_on_writable(SSL, getSocketContext(), [](auto *s) {
|
||||
|
||||
/* NOTE: Are we called here corked? If so, the below write code is broken, since
|
||||
* we will have 0 as getBufferedAmount due to writing to cork buffer, then sending TCP FIN before
|
||||
* we actually uncorked and sent off things */
|
||||
|
||||
/* It makes sense to check for us_is_shut_down here and return if so, to avoid shutting down twice */
|
||||
if (us_socket_is_shut_down(SSL, (us_socket_t *) s)) {
|
||||
return s;
|
||||
|
@ -310,16 +315,19 @@ private:
|
|||
AsyncSocket<SSL> *asyncSocket = (AsyncSocket<SSL> *) s;
|
||||
WebSocketData *webSocketData = (WebSocketData *)(us_socket_ext(SSL, s));
|
||||
|
||||
/* We store old backpressure since it is unclear whether write drained anything */
|
||||
int backpressure = asyncSocket->getBufferedAmount();
|
||||
/* We store old backpressure since it is unclear whether write drained anything,
|
||||
* however, in case of coming here with 0 backpressure we still need to emit drain event */
|
||||
unsigned int backpressure = asyncSocket->getBufferedAmount();
|
||||
|
||||
/* Drain as much as possible */
|
||||
asyncSocket->write(nullptr, 0);
|
||||
|
||||
/* Behavior: if we actively drain backpressure, always reset timeout (even if we are in shutdown) */
|
||||
if (backpressure < asyncSocket->getBufferedAmount()) {
|
||||
auto *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
|
||||
asyncSocket->timeout(webSocketContextData->idleTimeout);
|
||||
/* Also reset timeout if we came here with 0 backpressure */
|
||||
if (!backpressure || backpressure > asyncSocket->getBufferedAmount()) {
|
||||
auto *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
|
||||
asyncSocket->timeout(webSocketContextData->idleTimeoutComponents.first);
|
||||
webSocketData->hasTimedOut = false;
|
||||
}
|
||||
|
||||
/* Are we in (WebSocket) shutdown mode? */
|
||||
|
@ -329,11 +337,11 @@ private:
|
|||
/* Now perform the actual TCP/TLS shutdown which was postponed due to backpressure */
|
||||
asyncSocket->shutdown();
|
||||
}
|
||||
} else if (backpressure > asyncSocket->getBufferedAmount()) {
|
||||
/* Only call drain if we actually drained backpressure */
|
||||
auto *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
|
||||
} else if (!backpressure || backpressure > asyncSocket->getBufferedAmount()) {
|
||||
/* Only call drain if we actually drained backpressure or if we came here with 0 backpressure */
|
||||
auto *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
|
||||
if (webSocketContextData->drainHandler) {
|
||||
webSocketContextData->drainHandler((WebSocket<SSL, isServer> *) s);
|
||||
webSocketContextData->drainHandler((WebSocket<SSL, isServer, USERDATA> *) s);
|
||||
}
|
||||
/* No need to check for closed here as we leave the handler immediately*/
|
||||
}
|
||||
|
@ -345,7 +353,7 @@ private:
|
|||
us_socket_context_on_end(SSL, getSocketContext(), [](auto *s) {
|
||||
|
||||
/* If we get a fin, we just close I guess */
|
||||
us_socket_close(SSL, (us_socket_t *) s);
|
||||
us_socket_close(SSL, (us_socket_t *) s, 0, nullptr);
|
||||
|
||||
return s;
|
||||
});
|
||||
|
@ -353,8 +361,20 @@ private:
|
|||
/* Handle socket timeouts, simply close them so to not confuse client with FIN */
|
||||
us_socket_context_on_timeout(SSL, getSocketContext(), [](auto *s) {
|
||||
|
||||
auto *webSocketData = (WebSocketData *)(us_socket_ext(SSL, s));
|
||||
auto *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
|
||||
|
||||
if (webSocketContextData->sendPingsAutomatically && !webSocketData->hasTimedOut) {
|
||||
webSocketData->hasTimedOut = true;
|
||||
us_socket_timeout(SSL, s, webSocketContextData->idleTimeoutComponents.second);
|
||||
/* Send ping without being corked */
|
||||
((AsyncSocket<SSL> *) s)->write("\x89\x00", 2);
|
||||
return s;
|
||||
}
|
||||
|
||||
/* Timeout is very simple; we just close it */
|
||||
us_socket_close(SSL, (us_socket_t *) s);
|
||||
/* Warning: we happen to know forceClose will not use first parameter so pass nullptr here */
|
||||
forceClose(nullptr, s, ERR_WEBSOCKET_TIMEOUT);
|
||||
|
||||
return s;
|
||||
});
|
||||
|
@ -363,7 +383,7 @@ private:
|
|||
}
|
||||
|
||||
void free() {
|
||||
WebSocketContextData<SSL> *webSocketContextData = (WebSocketContextData<SSL> *) us_socket_context_ext(SSL, (us_socket_context_t *) this);
|
||||
WebSocketContextData<SSL, USERDATA> *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, (us_socket_context_t *) this);
|
||||
webSocketContextData->~WebSocketContextData();
|
||||
|
||||
us_socket_context_free(SSL, (us_socket_context_t *) this);
|
||||
|
@ -371,14 +391,14 @@ private:
|
|||
|
||||
public:
|
||||
/* WebSocket contexts are always child contexts to a HTTP context so no SSL options are needed as they are inherited */
|
||||
static WebSocketContext *create(Loop *loop, us_socket_context_t *parentSocketContext) {
|
||||
WebSocketContext *webSocketContext = (WebSocketContext *) us_create_child_socket_context(SSL, parentSocketContext, sizeof(WebSocketContextData<SSL>));
|
||||
static WebSocketContext *create(Loop */*loop*/, us_socket_context_t *parentSocketContext) {
|
||||
WebSocketContext *webSocketContext = (WebSocketContext *) us_create_child_socket_context(SSL, parentSocketContext, sizeof(WebSocketContextData<SSL, USERDATA>));
|
||||
if (!webSocketContext) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/* Init socket context data */
|
||||
new ((WebSocketContextData<SSL> *) us_socket_context_ext(SSL, (us_socket_context_t *)webSocketContext)) WebSocketContextData<SSL>;
|
||||
new ((WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, (us_socket_context_t *)webSocketContext)) WebSocketContextData<SSL, USERDATA>;
|
||||
return webSocketContext->init();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2019.
|
||||
* Authored by Alex Hultman, 2018-2020.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -18,83 +18,279 @@
|
|||
#ifndef UWS_WEBSOCKETCONTEXTDATA_H
|
||||
#define UWS_WEBSOCKETCONTEXTDATA_H
|
||||
|
||||
#include "f2/function2.hpp"
|
||||
#include "MoveOnlyFunction.h"
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
|
||||
#include "WebSocketProtocol.h"
|
||||
#include "TopicTree.h"
|
||||
#include "WebSocketData.h"
|
||||
|
||||
namespace uWS {
|
||||
|
||||
template <bool, bool> struct WebSocket;
|
||||
template <bool, bool, typename> struct WebSocket;
|
||||
|
||||
/* todo: this looks identical to WebSocketBehavior, why not just std::move that entire thing in? */
|
||||
|
||||
template <bool SSL>
|
||||
template <bool SSL, typename USERDATA>
|
||||
struct WebSocketContextData {
|
||||
private:
|
||||
/* Used for prepending unframed messages when using dedicated compressors */
|
||||
struct MessageMetadata {
|
||||
unsigned int length;
|
||||
OpCode opCode;
|
||||
bool compress;
|
||||
/* Undefined init of all members */
|
||||
MessageMetadata() {}
|
||||
MessageMetadata(unsigned int length, OpCode opCode, bool compress)
|
||||
: length(length), opCode(opCode), compress(compress) {}
|
||||
};
|
||||
|
||||
public:
|
||||
/* All WebSocketContextData holds a list to all other WebSocketContextData in this app.
|
||||
* We cannot type it USERDATA since different WebSocketContextData can have different USERDATA. */
|
||||
std::vector<WebSocketContextData<SSL, int> *> adjacentWebSocketContextDatas;
|
||||
|
||||
/* The callbacks for this context */
|
||||
fu2::unique_function<void(WebSocket<SSL, true> *, std::string_view, uWS::OpCode)> messageHandler = nullptr;
|
||||
fu2::unique_function<void(WebSocket<SSL, true> *)> drainHandler = nullptr;
|
||||
fu2::unique_function<void(WebSocket<SSL, true> *, int, std::string_view)> closeHandler = nullptr;
|
||||
MoveOnlyFunction<void(WebSocket<SSL, true, USERDATA> *)> openHandler = nullptr;
|
||||
MoveOnlyFunction<void(WebSocket<SSL, true, USERDATA> *, std::string_view, OpCode)> messageHandler = nullptr;
|
||||
MoveOnlyFunction<void(WebSocket<SSL, true, USERDATA> *)> drainHandler = nullptr;
|
||||
MoveOnlyFunction<void(WebSocket<SSL, true, USERDATA> *, int, std::string_view)> closeHandler = nullptr;
|
||||
/* Todo: these should take message also; breaking change for v0.18 */
|
||||
fu2::unique_function<void(WebSocket<SSL, true> *)> pingHandler = nullptr;
|
||||
fu2::unique_function<void(WebSocket<SSL, true> *)> pongHandler = nullptr;
|
||||
MoveOnlyFunction<void(WebSocket<SSL, true, USERDATA> *, std::string_view)> pingHandler = nullptr;
|
||||
MoveOnlyFunction<void(WebSocket<SSL, true, USERDATA> *, std::string_view)> pongHandler = nullptr;
|
||||
|
||||
/* Settings for this context */
|
||||
size_t maxPayloadLength = 0;
|
||||
int idleTimeout = 0;
|
||||
|
||||
/* We do need these for async upgrade */
|
||||
CompressOptions compression;
|
||||
|
||||
/* There needs to be a maxBackpressure which will force close everything over that limit */
|
||||
size_t maxBackpressure = 0;
|
||||
bool closeOnBackpressureLimit;
|
||||
bool resetIdleTimeoutOnSend;
|
||||
bool sendPingsAutomatically;
|
||||
|
||||
/* These are calculated on creation */
|
||||
std::pair<unsigned short, unsigned short> idleTimeoutComponents;
|
||||
|
||||
/* Each websocket context has a topic tree for pub/sub */
|
||||
TopicTree topicTree;
|
||||
|
||||
/* This is run once on start-up */
|
||||
void calculateIdleTimeoutCompnents(unsigned short idleTimeout) {
|
||||
unsigned short margin = 4;
|
||||
/* 4, 8 or 16 seconds margin based on idleTimeout */
|
||||
while ((int) idleTimeout - margin * 2 >= margin * 2 && margin < 16) {
|
||||
margin = (unsigned short) (margin << 1);
|
||||
}
|
||||
/* We should have no margin if not using sendPingsAutomatically */
|
||||
if (!sendPingsAutomatically) {
|
||||
margin = 0;
|
||||
}
|
||||
idleTimeoutComponents = {
|
||||
idleTimeout - margin,
|
||||
margin
|
||||
};
|
||||
}
|
||||
|
||||
~WebSocketContextData() {
|
||||
/* We must unregister any loop post handler here */
|
||||
Loop::get()->removePostHandler(this);
|
||||
Loop::get()->removePreHandler(this);
|
||||
}
|
||||
|
||||
WebSocketContextData() : topicTree([this](Subscriber *s, std::string_view data) -> int {
|
||||
WebSocketContextData() : topicTree([this](Subscriber *s, Intersection &intersection) -> int {
|
||||
|
||||
/* We could potentially be called here even if we have nothing to send, since we can
|
||||
* be the sender of every single message in this intersection. Also "fin" of a segment is not
|
||||
* guaranteed to be set, in case remaining segments are all from us.
|
||||
* Essentially, we cannot make strict assumptions here. Also, we can even come here corked,
|
||||
* since publish can call drain! */
|
||||
|
||||
/* We rely on writing to regular asyncSockets */
|
||||
auto *asyncSocket = (AsyncSocket<SSL> *) s->user;
|
||||
|
||||
auto [written, failed] = asyncSocket->write(data.data(), (int) data.length());
|
||||
if (!failed) {
|
||||
asyncSocket->timeout(this->idleTimeout);
|
||||
} else {
|
||||
/* Note: this assumes we are not corked, as corking will swallow things and fail later on */
|
||||
/* If we are corked, do not uncork - otherwise if we cork in here, uncork before leaving */
|
||||
bool wasCorked = asyncSocket->isCorked();
|
||||
|
||||
/* Check if we now have too much backpressure (todo: don't buffer up before check) */
|
||||
if ((unsigned int) asyncSocket->getBufferedAmount() > maxBackpressure) {
|
||||
asyncSocket->close();
|
||||
/* Do we even have room for potential data? */
|
||||
if (!maxBackpressure || asyncSocket->getBufferedAmount() < maxBackpressure) {
|
||||
|
||||
/* Roll over all our segments */
|
||||
intersection.forSubscriber(topicTree.getSenderFor(s), [asyncSocket, this](std::pair<std::string_view, std::string_view> data, bool fin) {
|
||||
|
||||
/* We have a segment that is not marked as last ("fin").
|
||||
* Cork if not already so (purely for performance reasons). Does not touch "wasCorked". */
|
||||
if (!fin && !asyncSocket->isCorked() && asyncSocket->canCork()) {
|
||||
asyncSocket->cork();
|
||||
}
|
||||
|
||||
/* Pick uncompressed data track */
|
||||
std::string_view selectedData = data.first;
|
||||
|
||||
/* Are we using compression? Fine, pick the compressed data track */
|
||||
WebSocketData *webSocketData = (WebSocketData *) asyncSocket->getAsyncSocketData();
|
||||
if (webSocketData->compressionStatus != WebSocketData::CompressionStatus::DISABLED) {
|
||||
|
||||
/* This is used for both shared and dedicated paths */
|
||||
selectedData = data.second;
|
||||
|
||||
/* However, dedicated compression has its own path */
|
||||
if (compression != SHARED_COMPRESSOR) {
|
||||
|
||||
WebSocket<SSL, true, int> *ws = (WebSocket<SSL, true, int> *) asyncSocket;
|
||||
|
||||
/* For performance reasons we always cork when in dedicated mode.
|
||||
* Is this really the best? We already kind of cork things in Zlib?
|
||||
* Right, formatting needs a cork buffer, right. Never mind. */
|
||||
if (!ws->isCorked() && ws->canCork()) {
|
||||
asyncSocket->cork();
|
||||
}
|
||||
|
||||
while (selectedData.length()) {
|
||||
/* Interpret the data like so, because this is how we shoved it in */
|
||||
MessageMetadata mm;
|
||||
memcpy((char *) &mm, selectedData.data(), sizeof(MessageMetadata));
|
||||
std::string_view unframedMessage(selectedData.data() + sizeof(MessageMetadata), mm.length);
|
||||
|
||||
/* Skip this message if our backpressure is too high */
|
||||
if (maxBackpressure && ws->getBufferedAmount() > maxBackpressure) {
|
||||
break;
|
||||
}
|
||||
|
||||
/* Here we perform the actual compression and framing */
|
||||
ws->send(unframedMessage, mm.opCode, mm.compress);
|
||||
|
||||
/* Advance until empty */
|
||||
selectedData.remove_prefix(sizeof(MessageMetadata) + mm.length);
|
||||
}
|
||||
|
||||
/* Continue to next segment without executing below path */
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/* Common path for SHARED and DISABLED. It is an invalid assumption that we always are
|
||||
* uncorked here, however the following (invalid) assumption is not critically wrong either way */
|
||||
|
||||
/* Note: this assumes we are not corked, as corking will swallow things and fail later on */
|
||||
auto [written, failed] = asyncSocket->write(selectedData.data(), (int) selectedData.length());
|
||||
/* If we want strict check for success, we can ignore this check if corked and repeat below
|
||||
* when uncorking - however this is too strict as we really care about PROGRESS rather than
|
||||
* ENTIRE SUCCESS - we need minor API changes to support correct checks */
|
||||
if (!failed) {
|
||||
if (this->resetIdleTimeoutOnSend) {
|
||||
auto *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) asyncSocket);
|
||||
webSocketData->hasTimedOut = false;
|
||||
asyncSocket->timeout(this->idleTimeoutComponents.first);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/* We are done sending, for whatever reasons we ended up corked while not starting with "wasCorked",
|
||||
* we here need to uncork to restore the state we were called in */
|
||||
if (!wasCorked && asyncSocket->isCorked()) {
|
||||
/* Regarding timeout for writes; */
|
||||
auto [written, failed] = asyncSocket->uncork();
|
||||
/* Again, this check should be more like DID WE PROGRESS rather than DID WE SUCCEED ENTIRELY */
|
||||
if (!failed) {
|
||||
if (this->resetIdleTimeoutOnSend) {
|
||||
auto *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) asyncSocket);
|
||||
webSocketData->hasTimedOut = false;
|
||||
asyncSocket->timeout(this->idleTimeoutComponents.first);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Defer a close if we now have (or already had) too much backpressure, or simply skip */
|
||||
if (maxBackpressure && closeOnBackpressureLimit && asyncSocket->getBufferedAmount() > maxBackpressure) {
|
||||
/* We must not immediately close the socket, as that could result in stack overflow,
|
||||
* iterator invalidation and other TopicTree::drain bugs. We may shutdown the reading side of the socket,
|
||||
* causing next iteration to error-close the socket from that context instead, if we want to */
|
||||
us_socket_shutdown_read(SSL, (us_socket_t *) asyncSocket);
|
||||
}
|
||||
|
||||
/* Reserved, unused */
|
||||
return 0;
|
||||
}) {
|
||||
/* We empty for both pre and post just to make sure */
|
||||
Loop::get()->addPostHandler(this, [this](Loop *loop) {
|
||||
Loop::get()->addPostHandler(this, [this](Loop */*loop*/) {
|
||||
/* Commit pub/sub batches every loop iteration */
|
||||
topicTree.drain();
|
||||
});
|
||||
|
||||
Loop::get()->addPreHandler(this, [this](Loop *loop) {
|
||||
Loop::get()->addPreHandler(this, [this](Loop */*loop*/) {
|
||||
/* Commit pub/sub batches every loop iteration */
|
||||
topicTree.drain();
|
||||
});
|
||||
}
|
||||
|
||||
/* Helper for topictree publish, common path from app and ws */
|
||||
void publish(std::string_view topic, std::string_view message, OpCode opCode, bool compress) {
|
||||
bool publish(std::string_view topic, std::string_view message, OpCode opCode, bool compress, Subscriber *sender = nullptr) {
|
||||
bool didMatch = false;
|
||||
|
||||
/* We frame the message right here and only pass raw bytes to the pub/subber */
|
||||
char *dst = (char *) malloc(protocol::messageFrameSize(message.size()));
|
||||
size_t dst_length = protocol::formatMessage<true>(dst, message.data(), message.length(), opCode, message.length(), false);
|
||||
|
||||
topicTree.publish(topic, std::string_view(dst, dst_length));
|
||||
/* If compression is disabled */
|
||||
if (compression == DISABLED) {
|
||||
/* Leave second field empty as nobody will ever read it */
|
||||
didMatch |= topicTree.publish(topic, {std::string_view(dst, dst_length), {}}, sender);
|
||||
} else {
|
||||
/* DEDICATED_COMPRESSOR always takes the same path as must always have MessageMetadata as head */
|
||||
if (compress || compression != SHARED_COMPRESSOR) {
|
||||
/* Shared compression mode publishes compressed, framed data */
|
||||
if (compression == SHARED_COMPRESSOR) {
|
||||
/* Loop data holds shared compressor */
|
||||
LoopData *loopData = (LoopData *) us_loop_ext((us_loop_t *) Loop::get());
|
||||
|
||||
/* Compress it */
|
||||
std::string_view compressedMessage = loopData->deflationStream->deflate(loopData->zlibContext, message, true);
|
||||
|
||||
/* Frame it */
|
||||
char *dst_compressed = (char *) malloc(protocol::messageFrameSize(compressedMessage.size()));
|
||||
size_t dst_compressed_length = protocol::formatMessage<true>(dst_compressed, compressedMessage.data(), compressedMessage.length(), opCode, compressedMessage.length(), true);
|
||||
|
||||
/* Always publish the shortest one in any case */
|
||||
didMatch |= topicTree.publish(topic, {std::string_view(dst, dst_length), dst_compressed_length >= dst_length ? std::string_view(dst, dst_length) : std::string_view(dst_compressed, dst_compressed_length)}, sender);
|
||||
|
||||
/* We don't care for allocation here */
|
||||
::free(dst_compressed);
|
||||
} else {
|
||||
/* Dedicated compression mode publishes metadata + unframed uncompressed data */
|
||||
char *dst_compressed = (char *) malloc(message.length() + sizeof(MessageMetadata));
|
||||
|
||||
MessageMetadata mm(
|
||||
(unsigned int) message.length(),
|
||||
opCode,
|
||||
compress
|
||||
);
|
||||
|
||||
memcpy(dst_compressed, (char *) &mm, sizeof(MessageMetadata));
|
||||
memcpy(dst_compressed + sizeof(MessageMetadata), message.data(), message.length());
|
||||
|
||||
/* Interpretation of compressed data depends on what compressor we use */
|
||||
didMatch |= topicTree.publish(topic, {
|
||||
std::string_view(dst, dst_length),
|
||||
std::string_view(dst_compressed, message.length() + sizeof(MessageMetadata))
|
||||
}, sender);
|
||||
|
||||
::free(dst_compressed);
|
||||
}
|
||||
} else {
|
||||
/* If not compressing, put same message on both tracks (only valid for SHARED_COMPRESSOR).
|
||||
* DEDICATED_COMPRESSOR_xKB must never end up here as we don't put a proper head here. */
|
||||
didMatch |= topicTree.publish(topic, {std::string_view(dst, dst_length), std::string_view(dst, dst_length)}, sender);
|
||||
}
|
||||
}
|
||||
|
||||
::free(dst);
|
||||
|
||||
return didMatch;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2019.
|
||||
* Authored by Alex Hultman, 2018-2020.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -21,18 +21,23 @@
|
|||
#include "WebSocketProtocol.h"
|
||||
#include "AsyncSocketData.h"
|
||||
#include "PerMessageDeflate.h"
|
||||
#include "TopicTree.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace uWS {
|
||||
|
||||
struct WebSocketData : AsyncSocketData<false>, WebSocketState<true> {
|
||||
template <bool, bool> friend struct WebSocketContext;
|
||||
template <bool, bool> friend struct WebSocket;
|
||||
/* This guy has a lot of friends - why? */
|
||||
template <bool, bool, typename> friend struct WebSocketContext;
|
||||
template <bool, typename> friend struct WebSocketContextData;
|
||||
template <bool, bool, typename> friend struct WebSocket;
|
||||
template <bool> friend struct HttpContext;
|
||||
private:
|
||||
std::string fragmentBuffer;
|
||||
int controlTipLength = 0;
|
||||
unsigned int controlTipLength = 0;
|
||||
bool isShuttingDown = 0;
|
||||
bool hasTimedOut = false;
|
||||
enum CompressionStatus : char {
|
||||
DISABLED,
|
||||
ENABLED,
|
||||
|
@ -45,12 +50,12 @@ private:
|
|||
/* We could be a subscriber */
|
||||
Subscriber *subscriber = nullptr;
|
||||
public:
|
||||
WebSocketData(bool perMessageDeflate, bool slidingCompression, std::string &&backpressure) : AsyncSocketData<false>(std::move(backpressure)), WebSocketState<true>() {
|
||||
WebSocketData(bool perMessageDeflate, CompressOptions compressOptions, std::string &&backpressure) : AsyncSocketData<false>(std::move(backpressure)), WebSocketState<true>() {
|
||||
compressionStatus = perMessageDeflate ? ENABLED : DISABLED;
|
||||
|
||||
/* Initialize the dedicated sliding window */
|
||||
if (perMessageDeflate && slidingCompression) {
|
||||
deflationStream = new DeflationStream;
|
||||
if (perMessageDeflate && (compressOptions != CompressOptions::SHARED_COMPRESSOR)) {
|
||||
deflationStream = new DeflationStream(compressOptions);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2019.
|
||||
* Authored by Alex Hultman, 2018-2021.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -18,26 +18,35 @@
|
|||
#ifndef UWS_WEBSOCKETEXTENSIONS_H
|
||||
#define UWS_WEBSOCKETEXTENSIONS_H
|
||||
|
||||
/* There is a new, huge bug scenario that needs to be fixed:
|
||||
* pub/sub does not support being in DEDICATED_COMPRESSOR-mode while having
|
||||
* some clients downgraded to SHARED_COMPRESSOR - we cannot allow the client to
|
||||
* demand a downgrade to SHARED_COMPRESSOR (yet) until we fix that scenario in pub/sub */
|
||||
// #define UWS_ALLOW_SHARED_AND_DEDICATED_COMPRESSOR_MIX
|
||||
|
||||
/* We forbid negotiating 8 windowBits since Zlib has a bug with this */
|
||||
// #define UWS_ALLOW_8_WINDOW_BITS
|
||||
|
||||
#include <climits>
|
||||
#include <cctype>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <tuple>
|
||||
|
||||
namespace uWS {
|
||||
|
||||
enum Options : unsigned int {
|
||||
NO_OPTIONS = 0,
|
||||
PERMESSAGE_DEFLATE = 1,
|
||||
SERVER_NO_CONTEXT_TAKEOVER = 2, // remove this
|
||||
CLIENT_NO_CONTEXT_TAKEOVER = 4, // remove this
|
||||
NO_DELAY = 8,
|
||||
SLIDING_DEFLATE_WINDOW = 16
|
||||
};
|
||||
|
||||
enum ExtensionTokens {
|
||||
/* Standard permessage-deflate tokens */
|
||||
TOK_PERMESSAGE_DEFLATE = 1838,
|
||||
TOK_SERVER_NO_CONTEXT_TAKEOVER = 2807,
|
||||
TOK_CLIENT_NO_CONTEXT_TAKEOVER = 2783,
|
||||
TOK_SERVER_MAX_WINDOW_BITS = 2372,
|
||||
TOK_CLIENT_MAX_WINDOW_BITS = 2348
|
||||
TOK_CLIENT_MAX_WINDOW_BITS = 2348,
|
||||
/* Non-standard alias for Safari */
|
||||
TOK_X_WEBKIT_DEFLATE_FRAME = 2149,
|
||||
TOK_NO_CONTEXT_TAKEOVER = 2049,
|
||||
TOK_MAX_WINDOW_BITS = 1614
|
||||
|
||||
};
|
||||
|
||||
struct ExtensionsParser {
|
||||
|
@ -45,12 +54,18 @@ private:
|
|||
int *lastInteger = nullptr;
|
||||
|
||||
public:
|
||||
/* Standard */
|
||||
bool perMessageDeflate = false;
|
||||
bool serverNoContextTakeover = false;
|
||||
bool clientNoContextTakeover = false;
|
||||
int serverMaxWindowBits = 0;
|
||||
int clientMaxWindowBits = 0;
|
||||
|
||||
/* Non-standard Safari */
|
||||
bool xWebKitDeflateFrame = false;
|
||||
bool noContextTakeover = false;
|
||||
int maxWindowBits = 0;
|
||||
|
||||
int getToken(const char *&in, const char *stop) {
|
||||
while (in != stop && !isalnum(*in)) {
|
||||
in++;
|
||||
|
@ -78,12 +93,28 @@ public:
|
|||
ExtensionsParser(const char *data, size_t length) {
|
||||
const char *stop = data + length;
|
||||
int token = 1;
|
||||
for (; token && token != TOK_PERMESSAGE_DEFLATE; token = getToken(data, stop));
|
||||
|
||||
/* Ignore anything before permessage-deflate or x-webkit-deflate-frame */
|
||||
for (; token && token != TOK_PERMESSAGE_DEFLATE && token != TOK_X_WEBKIT_DEFLATE_FRAME; token = getToken(data, stop));
|
||||
|
||||
/* What protocol are we going to use? */
|
||||
perMessageDeflate = (token == TOK_PERMESSAGE_DEFLATE);
|
||||
xWebKitDeflateFrame = (token == TOK_X_WEBKIT_DEFLATE_FRAME);
|
||||
|
||||
while ((token = getToken(data, stop))) {
|
||||
switch (token) {
|
||||
case TOK_X_WEBKIT_DEFLATE_FRAME:
|
||||
/* Duplicates not allowed/supported */
|
||||
return;
|
||||
case TOK_NO_CONTEXT_TAKEOVER:
|
||||
noContextTakeover = true;
|
||||
break;
|
||||
case TOK_MAX_WINDOW_BITS:
|
||||
maxWindowBits = 1;
|
||||
lastInteger = &maxWindowBits;
|
||||
break;
|
||||
case TOK_PERMESSAGE_DEFLATE:
|
||||
/* Duplicates not allowed/supported */
|
||||
return;
|
||||
case TOK_SERVER_NO_CONTEXT_TAKEOVER:
|
||||
serverNoContextTakeover = true;
|
||||
|
@ -109,60 +140,116 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
template <bool isServer>
|
||||
struct ExtensionsNegotiator {
|
||||
protected:
|
||||
int options;
|
||||
/* Takes what we (the server) wants, returns what we got */
|
||||
static inline std::tuple<bool, int, int, std::string_view> negotiateCompression(bool wantCompression, int wantedCompressionWindow, int wantedInflationWindow, std::string_view offer) {
|
||||
|
||||
public:
|
||||
ExtensionsNegotiator(int wantedOptions) {
|
||||
options = wantedOptions;
|
||||
/* If we don't want compression then we are done here */
|
||||
if (!wantCompression) {
|
||||
return {false, 0, 0, ""};
|
||||
}
|
||||
|
||||
std::string generateOffer() {
|
||||
std::string extensionsOffer;
|
||||
if (options & Options::PERMESSAGE_DEFLATE) {
|
||||
extensionsOffer += "permessage-deflate";
|
||||
ExtensionsParser ep(offer.data(), offer.length());
|
||||
|
||||
if (options & Options::CLIENT_NO_CONTEXT_TAKEOVER) {
|
||||
extensionsOffer += "; client_no_context_takeover";
|
||||
static thread_local std::string response;
|
||||
response = "";
|
||||
|
||||
int compressionWindow = wantedCompressionWindow;
|
||||
int inflationWindow = wantedInflationWindow;
|
||||
bool compression = false;
|
||||
|
||||
if (ep.xWebKitDeflateFrame) {
|
||||
/* We now have compression */
|
||||
compression = true;
|
||||
response = "x-webkit-deflate-frame";
|
||||
|
||||
/* If the other peer has DEMANDED us no sliding window,
|
||||
* we cannot compress with anything other than shared compressor */
|
||||
if (ep.noContextTakeover) {
|
||||
/* We must fail here right now (fix pub/sub) */
|
||||
#ifndef UWS_ALLOW_SHARED_AND_DEDICATED_COMPRESSOR_MIX
|
||||
if (wantedCompressionWindow != 0) {
|
||||
return {false, 0, 0, ""};
|
||||
}
|
||||
#endif
|
||||
|
||||
/* It is questionable sending this improves anything */
|
||||
/*if (options & Options::SERVER_NO_CONTEXT_TAKEOVER) {
|
||||
extensionsOffer += "; server_no_context_takeover";
|
||||
}*/
|
||||
compressionWindow = 0;
|
||||
}
|
||||
|
||||
return extensionsOffer;
|
||||
}
|
||||
/* If the other peer has DEMANDED us to use a limited sliding window,
|
||||
* we have to limit out compression sliding window */
|
||||
if (ep.maxWindowBits && ep.maxWindowBits < compressionWindow) {
|
||||
compressionWindow = ep.maxWindowBits;
|
||||
#ifndef UWS_ALLOW_8_WINDOW_BITS
|
||||
/* We cannot really deny this, so we have to disable compression in this case */
|
||||
if (compressionWindow == 8) {
|
||||
return {false, 0, 0, ""};
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void readOffer(std::string_view offer) {
|
||||
if (isServer) {
|
||||
ExtensionsParser extensionsParser(offer.data(), offer.length());
|
||||
if ((options & PERMESSAGE_DEFLATE) && extensionsParser.perMessageDeflate) {
|
||||
if (extensionsParser.clientNoContextTakeover || (options & CLIENT_NO_CONTEXT_TAKEOVER)) {
|
||||
options |= CLIENT_NO_CONTEXT_TAKEOVER;
|
||||
}
|
||||
|
||||
/* We leave this option for us to read even if the client did not send it */
|
||||
if (extensionsParser.serverNoContextTakeover) {
|
||||
options |= SERVER_NO_CONTEXT_TAKEOVER;
|
||||
}/* else {
|
||||
options &= ~SERVER_NO_CONTEXT_TAKEOVER;
|
||||
}*/
|
||||
/* We decide our own inflation sliding window (and their compression sliding window) */
|
||||
if (wantedInflationWindow < 15) {
|
||||
if (!wantedInflationWindow) {
|
||||
response += "; no_context_takeover";
|
||||
} else {
|
||||
options &= ~PERMESSAGE_DEFLATE;
|
||||
response += "; max_window_bits=" + std::to_string(wantedInflationWindow);
|
||||
}
|
||||
}
|
||||
} else if (ep.perMessageDeflate) {
|
||||
/* We now have compression */
|
||||
compression = true;
|
||||
response = "permessage-deflate";
|
||||
|
||||
if (ep.clientNoContextTakeover) {
|
||||
inflationWindow = 0;
|
||||
} else if (ep.clientMaxWindowBits && ep.clientMaxWindowBits != 1) {
|
||||
inflationWindow = std::min<int>(ep.clientMaxWindowBits, inflationWindow);
|
||||
}
|
||||
|
||||
/* Whatever we have now, write */
|
||||
if (inflationWindow < 15) {
|
||||
if (!inflationWindow || !ep.clientMaxWindowBits) {
|
||||
response += "; client_no_context_takeover";
|
||||
inflationWindow = 0;
|
||||
} else {
|
||||
response += "; client_max_window_bits=" + std::to_string(inflationWindow);
|
||||
}
|
||||
}
|
||||
|
||||
/* This block basically lets the client lower it */
|
||||
if (ep.serverNoContextTakeover) {
|
||||
/* This is an important (temporary) fix since we haven't allowed
|
||||
* these two modes to mix, and pub/sub will not handle this case (yet) */
|
||||
#ifdef UWS_ALLOW_SHARED_AND_DEDICATED_COMPRESSOR_MIX
|
||||
compressionWindow = 0;
|
||||
#endif
|
||||
} else if (ep.serverMaxWindowBits) {
|
||||
compressionWindow = std::min<int>(ep.serverMaxWindowBits, compressionWindow);
|
||||
#ifndef UWS_ALLOW_8_WINDOW_BITS
|
||||
/* Zlib cannot do windowBits=8, memLevel=1 so we raise it up to 9 minimum */
|
||||
if (compressionWindow == 8) {
|
||||
compressionWindow = 9;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
/* Whatever we have now, write */
|
||||
if (compressionWindow < 15) {
|
||||
if (!compressionWindow) {
|
||||
response += "; server_no_context_takeover";
|
||||
} else {
|
||||
response += "; server_max_window_bits=" + std::to_string(compressionWindow);
|
||||
}
|
||||
} else {
|
||||
// todo!
|
||||
}
|
||||
}
|
||||
|
||||
int getNegotiatedOptions() {
|
||||
return options;
|
||||
/* A final sanity check (this check does not actually catch too high values!) */
|
||||
if ((compressionWindow && compressionWindow < 8) || compressionWindow > 15 || (inflationWindow && inflationWindow < 8) || inflationWindow > 15) {
|
||||
return {false, 0, 0, ""};
|
||||
}
|
||||
};
|
||||
|
||||
return {compression, compressionWindow, inflationWindow, response};
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2019.
|
||||
* Authored by Alex Hultman, 2018-2020.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -34,57 +34,68 @@ struct WebSocketHandshake {
|
|||
|
||||
template <typename T>
|
||||
struct static_for<0, T> {
|
||||
void operator()(uint32_t *a, uint32_t *hash) {}
|
||||
void operator()(uint32_t */*a*/, uint32_t */*hash*/) {}
|
||||
};
|
||||
|
||||
template <int state>
|
||||
struct Sha1Loop {
|
||||
static inline uint32_t rol(uint32_t value, size_t bits) {return (value << bits) | (value >> (32 - bits));}
|
||||
static inline uint32_t blk(uint32_t b[16], size_t i) {
|
||||
return rol(b[(i + 13) & 15] ^ b[(i + 8) & 15] ^ b[(i + 2) & 15] ^ b[i], 1);
|
||||
}
|
||||
static inline uint32_t rol(uint32_t value, size_t bits) {return (value << bits) | (value >> (32 - bits));}
|
||||
static inline uint32_t blk(uint32_t b[16], size_t i) {
|
||||
return rol(b[(i + 13) & 15] ^ b[(i + 8) & 15] ^ b[(i + 2) & 15] ^ b[i], 1);
|
||||
}
|
||||
|
||||
struct Sha1Loop1 {
|
||||
template <int i>
|
||||
static inline void f(uint32_t *a, uint32_t *b) {
|
||||
switch (state) {
|
||||
case 1:
|
||||
a[i % 5] += ((a[(3 + i) % 5] & (a[(2 + i) % 5] ^ a[(1 + i) % 5])) ^ a[(1 + i) % 5]) + b[i] + 0x5a827999 + rol(a[(4 + i) % 5], 5);
|
||||
a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30);
|
||||
break;
|
||||
case 2:
|
||||
b[i] = blk(b, i);
|
||||
a[(1 + i) % 5] += ((a[(4 + i) % 5] & (a[(3 + i) % 5] ^ a[(2 + i) % 5])) ^ a[(2 + i) % 5]) + b[i] + 0x5a827999 + rol(a[(5 + i) % 5], 5);
|
||||
a[(4 + i) % 5] = rol(a[(4 + i) % 5], 30);
|
||||
break;
|
||||
case 3:
|
||||
b[(i + 4) % 16] = blk(b, (i + 4) % 16);
|
||||
a[i % 5] += (a[(3 + i) % 5] ^ a[(2 + i) % 5] ^ a[(1 + i) % 5]) + b[(i + 4) % 16] + 0x6ed9eba1 + rol(a[(4 + i) % 5], 5);
|
||||
a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30);
|
||||
break;
|
||||
case 4:
|
||||
b[(i + 8) % 16] = blk(b, (i + 8) % 16);
|
||||
a[i % 5] += (((a[(3 + i) % 5] | a[(2 + i) % 5]) & a[(1 + i) % 5]) | (a[(3 + i) % 5] & a[(2 + i) % 5])) + b[(i + 8) % 16] + 0x8f1bbcdc + rol(a[(4 + i) % 5], 5);
|
||||
a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30);
|
||||
break;
|
||||
case 5:
|
||||
b[(i + 12) % 16] = blk(b, (i + 12) % 16);
|
||||
a[i % 5] += (a[(3 + i) % 5] ^ a[(2 + i) % 5] ^ a[(1 + i) % 5]) + b[(i + 12) % 16] + 0xca62c1d6 + rol(a[(4 + i) % 5], 5);
|
||||
a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30);
|
||||
break;
|
||||
case 6:
|
||||
b[i] += a[4 - i];
|
||||
}
|
||||
a[i % 5] += ((a[(3 + i) % 5] & (a[(2 + i) % 5] ^ a[(1 + i) % 5])) ^ a[(1 + i) % 5]) + b[i] + 0x5a827999 + rol(a[(4 + i) % 5], 5);
|
||||
a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30);
|
||||
}
|
||||
};
|
||||
struct Sha1Loop2 {
|
||||
template <int i>
|
||||
static inline void f(uint32_t *a, uint32_t *b) {
|
||||
b[i] = blk(b, i);
|
||||
a[(1 + i) % 5] += ((a[(4 + i) % 5] & (a[(3 + i) % 5] ^ a[(2 + i) % 5])) ^ a[(2 + i) % 5]) + b[i] + 0x5a827999 + rol(a[(5 + i) % 5], 5);
|
||||
a[(4 + i) % 5] = rol(a[(4 + i) % 5], 30);
|
||||
}
|
||||
};
|
||||
struct Sha1Loop3 {
|
||||
template <int i>
|
||||
static inline void f(uint32_t *a, uint32_t *b) {
|
||||
b[(i + 4) % 16] = blk(b, (i + 4) % 16);
|
||||
a[i % 5] += (a[(3 + i) % 5] ^ a[(2 + i) % 5] ^ a[(1 + i) % 5]) + b[(i + 4) % 16] + 0x6ed9eba1 + rol(a[(4 + i) % 5], 5);
|
||||
a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30);
|
||||
}
|
||||
};
|
||||
struct Sha1Loop4 {
|
||||
template <int i>
|
||||
static inline void f(uint32_t *a, uint32_t *b) {
|
||||
b[(i + 8) % 16] = blk(b, (i + 8) % 16);
|
||||
a[i % 5] += (((a[(3 + i) % 5] | a[(2 + i) % 5]) & a[(1 + i) % 5]) | (a[(3 + i) % 5] & a[(2 + i) % 5])) + b[(i + 8) % 16] + 0x8f1bbcdc + rol(a[(4 + i) % 5], 5);
|
||||
a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30);
|
||||
}
|
||||
};
|
||||
struct Sha1Loop5 {
|
||||
template <int i>
|
||||
static inline void f(uint32_t *a, uint32_t *b) {
|
||||
b[(i + 12) % 16] = blk(b, (i + 12) % 16);
|
||||
a[i % 5] += (a[(3 + i) % 5] ^ a[(2 + i) % 5] ^ a[(1 + i) % 5]) + b[(i + 12) % 16] + 0xca62c1d6 + rol(a[(4 + i) % 5], 5);
|
||||
a[(3 + i) % 5] = rol(a[(3 + i) % 5], 30);
|
||||
}
|
||||
};
|
||||
struct Sha1Loop6 {
|
||||
template <int i>
|
||||
static inline void f(uint32_t *a, uint32_t *b) {
|
||||
b[i] += a[4 - i];
|
||||
}
|
||||
};
|
||||
|
||||
static inline void sha1(uint32_t hash[5], uint32_t b[16]) {
|
||||
uint32_t a[5] = {hash[4], hash[3], hash[2], hash[1], hash[0]};
|
||||
static_for<16, Sha1Loop<1>>()(a, b);
|
||||
static_for<4, Sha1Loop<2>>()(a, b);
|
||||
static_for<20, Sha1Loop<3>>()(a, b);
|
||||
static_for<20, Sha1Loop<4>>()(a, b);
|
||||
static_for<20, Sha1Loop<5>>()(a, b);
|
||||
static_for<5, Sha1Loop<6>>()(a, hash);
|
||||
static_for<16, Sha1Loop1>()(a, b);
|
||||
static_for<4, Sha1Loop2>()(a, b);
|
||||
static_for<20, Sha1Loop3>()(a, b);
|
||||
static_for<20, Sha1Loop4>()(a, b);
|
||||
static_for<20, Sha1Loop5>()(a, b);
|
||||
static_for<5, Sha1Loop6>()(a, hash);
|
||||
}
|
||||
|
||||
static inline void base64(unsigned char *src, char *dst) {
|
||||
|
@ -112,7 +123,7 @@ public:
|
|||
};
|
||||
|
||||
for (int i = 0; i < 6; i++) {
|
||||
b_input[i] = (input[4 * i + 3] & 0xff) | (input[4 * i + 2] & 0xff) << 8 | (input[4 * i + 1] & 0xff) << 16 | (input[4 * i + 0] & 0xff) << 24;
|
||||
b_input[i] = (uint32_t) ((input[4 * i + 3] & 0xff) | (input[4 * i + 2] & 0xff) << 8 | (input[4 * i + 1] & 0xff) << 16 | (input[4 * i + 0] & 0xff) << 24);
|
||||
}
|
||||
sha1(b_output, b_input);
|
||||
uint32_t last_b[16] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 480};
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2019.
|
||||
* Authored by Alex Hultman, 2018-2020.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -21,9 +21,17 @@
|
|||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <cstdlib>
|
||||
#include <string_view>
|
||||
|
||||
namespace uWS {
|
||||
|
||||
/* We should not overcomplicate these */
|
||||
const std::string_view ERR_TOO_BIG_MESSAGE("Received too big message");
|
||||
const std::string_view ERR_WEBSOCKET_TIMEOUT("WebSocket timed out from inactivity");
|
||||
const std::string_view ERR_INVALID_TEXT("Received invalid UTF-8");
|
||||
const std::string_view ERR_TOO_BIG_MESSAGE_INFLATION("Received too big message, or other inflation error");
|
||||
const std::string_view ERR_INVALID_CLOSE_PAYLOAD("Received invalid close payload");
|
||||
|
||||
enum OpCode : unsigned char {
|
||||
TEXT = 1,
|
||||
BINARY = 2,
|
||||
|
@ -49,7 +57,7 @@ public:
|
|||
struct State {
|
||||
unsigned int wantsHead : 1;
|
||||
unsigned int spillLength : 4;
|
||||
int opStack : 2; // -1, 0, 1
|
||||
signed int opStack : 2; // -1, 0, 1
|
||||
unsigned int lastFin : 1;
|
||||
|
||||
// 15 bytes
|
||||
|
@ -151,20 +159,23 @@ struct CloseFrame {
|
|||
};
|
||||
|
||||
static inline CloseFrame parseClosePayload(char *src, size_t length) {
|
||||
CloseFrame cf = {};
|
||||
/* If we get no code or message, default to reporting 1005 no status code present */
|
||||
CloseFrame cf = {1005, nullptr, 0};
|
||||
if (length >= 2) {
|
||||
memcpy(&cf.code, src, 2);
|
||||
cf = {cond_byte_swap<uint16_t>(cf.code), src + 2, length - 2};
|
||||
if (cf.code < 1000 || cf.code > 4999 || (cf.code > 1011 && cf.code < 4000) ||
|
||||
(cf.code >= 1004 && cf.code <= 1006) || !isValidUtf8((unsigned char *) cf.message, cf.length)) {
|
||||
return {};
|
||||
/* Even though we got a WebSocket close frame, it in itself is abnormal */
|
||||
return {1006, nullptr, 0};
|
||||
}
|
||||
}
|
||||
return cf;
|
||||
}
|
||||
|
||||
static inline size_t formatClosePayload(char *dst, uint16_t code, const char *message, size_t length) {
|
||||
if (code) {
|
||||
/* We could have more strict checks here, but never append code 0 or 1005 or 1006 */
|
||||
if (code && code != 1005 && code != 1006) {
|
||||
code = cond_byte_swap<uint16_t>(code);
|
||||
memcpy(dst, &code, 2);
|
||||
/* It is invalid to pass nullptr to memcpy, even though length is 0 */
|
||||
|
@ -219,7 +230,7 @@ static inline size_t formatMessage(char *dst, const char *src, size_t length, Op
|
|||
char mask[4];
|
||||
if (!isServer) {
|
||||
dst[1] |= 0x80;
|
||||
uint32_t random = rand();
|
||||
uint32_t random = (uint32_t) rand();
|
||||
memcpy(mask, &random, 4);
|
||||
memcpy(dst + headerLength, &random, 4);
|
||||
headerLength += 4;
|
||||
|
@ -307,7 +318,7 @@ protected:
|
|||
wState->state.lastFin = isFin(src);
|
||||
|
||||
if (Impl::refusePayloadLength(payLength, wState, user)) {
|
||||
Impl::forceClose(wState, user);
|
||||
Impl::forceClose(wState, user, ERR_TOO_BIG_MESSAGE);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -351,9 +362,9 @@ protected:
|
|||
static inline bool consumeContinuation(char *&src, unsigned int &length, WebSocketState<isServer> *wState, void *user) {
|
||||
if (wState->remainingBytes <= length) {
|
||||
if (isServer) {
|
||||
int n = wState->remainingBytes >> 2;
|
||||
unsigned int n = wState->remainingBytes >> 2;
|
||||
unmaskInplace(src, src + n * 4, wState->mask);
|
||||
for (int i = 0, s = wState->remainingBytes % 4; i < s; i++) {
|
||||
for (unsigned int i = 0, s = wState->remainingBytes % 4; i < s; i++) {
|
||||
src[n * 4 + i] ^= wState->mask[i];
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,23 +0,0 @@
|
|||
Boost Software License - Version 1.0 - August 17th, 2003
|
||||
|
||||
Permission is hereby granted, free of charge, to any person or organization
|
||||
obtaining a copy of the software and accompanying documentation covered by
|
||||
this license (the "Software") to use, reproduce, display, distribute,
|
||||
execute, and transmit the Software, and to prepare derivative works of the
|
||||
Software, and to permit third-parties to whom the Software is furnished to
|
||||
do so, all subject to the following:
|
||||
|
||||
The copyright notices in the Software and this entire statement, including
|
||||
the above license grant, this restriction and the following disclaimer,
|
||||
must be included in all copies of the Software, in whole or in part, and
|
||||
all derivative works of the Software, unless such copies or derivative
|
||||
works are solely in the form of machine-executable object code generated by
|
||||
a source language processor.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
|
||||
SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
|
||||
FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
|
||||
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,308 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2021.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/* Todo: this file should lie in networking/bsd.c */
|
||||
|
||||
#include "libusockets.h"
|
||||
#include "internal/internal.h"
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#ifndef _WIN32
|
||||
//#define _GNU_SOURCE
|
||||
#include <sys/types.h>
|
||||
#include <sys/socket.h>
|
||||
#include <netinet/in.h>
|
||||
#include <netinet/tcp.h>
|
||||
#include <netdb.h>
|
||||
#include <string.h>
|
||||
#include <unistd.h>
|
||||
#include <fcntl.h>
|
||||
#include <errno.h>
|
||||
#endif
|
||||
|
||||
LIBUS_SOCKET_DESCRIPTOR apple_no_sigpipe(LIBUS_SOCKET_DESCRIPTOR fd) {
|
||||
#ifdef __APPLE__
|
||||
if (fd != LIBUS_SOCKET_ERROR) {
|
||||
int no_sigpipe = 1;
|
||||
setsockopt(fd, SOL_SOCKET, SO_NOSIGPIPE, &no_sigpipe, sizeof(int));
|
||||
}
|
||||
#endif
|
||||
return fd;
|
||||
}
|
||||
|
||||
LIBUS_SOCKET_DESCRIPTOR bsd_set_nonblocking(LIBUS_SOCKET_DESCRIPTOR fd) {
|
||||
#ifdef _WIN32
|
||||
/* Libuv will set windows sockets as non-blocking */
|
||||
#else
|
||||
fcntl(fd, F_SETFL, fcntl(fd, F_GETFL, 0) | O_NONBLOCK);
|
||||
#endif
|
||||
return fd;
|
||||
}
|
||||
|
||||
void bsd_socket_nodelay(LIBUS_SOCKET_DESCRIPTOR fd, int enabled) {
|
||||
setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, (void *) &enabled, sizeof(enabled));
|
||||
}
|
||||
|
||||
void bsd_socket_flush(LIBUS_SOCKET_DESCRIPTOR fd) {
|
||||
// Linux TCP_CORK has the same underlying corking mechanism as with MSG_MORE
|
||||
#ifdef TCP_CORK
|
||||
int enabled = 0;
|
||||
setsockopt(fd, IPPROTO_TCP, TCP_CORK, &enabled, sizeof(int));
|
||||
#endif
|
||||
}
|
||||
|
||||
LIBUS_SOCKET_DESCRIPTOR bsd_create_socket(int domain, int type, int protocol) {
|
||||
// returns INVALID_SOCKET on error
|
||||
int flags = 0;
|
||||
#if defined(SOCK_CLOEXEC) && defined(SOCK_NONBLOCK)
|
||||
flags = SOCK_CLOEXEC | SOCK_NONBLOCK;
|
||||
#endif
|
||||
|
||||
LIBUS_SOCKET_DESCRIPTOR created_fd = socket(domain, type | flags, protocol);
|
||||
|
||||
return bsd_set_nonblocking(apple_no_sigpipe(created_fd));
|
||||
}
|
||||
|
||||
void bsd_close_socket(LIBUS_SOCKET_DESCRIPTOR fd) {
|
||||
#ifdef _WIN32
|
||||
closesocket(fd);
|
||||
#else
|
||||
close(fd);
|
||||
#endif
|
||||
}
|
||||
|
||||
void bsd_shutdown_socket(LIBUS_SOCKET_DESCRIPTOR fd) {
|
||||
#ifdef _WIN32
|
||||
shutdown(fd, SD_SEND);
|
||||
#else
|
||||
shutdown(fd, SHUT_WR);
|
||||
#endif
|
||||
}
|
||||
|
||||
void bsd_shutdown_socket_read(LIBUS_SOCKET_DESCRIPTOR fd) {
|
||||
#ifdef _WIN32
|
||||
shutdown(fd, SD_RECEIVE);
|
||||
#else
|
||||
shutdown(fd, SHUT_RD);
|
||||
#endif
|
||||
}
|
||||
|
||||
void internal_finalize_bsd_addr(struct bsd_addr_t *addr) {
|
||||
// parse, so to speak, the address
|
||||
if (addr->mem.ss_family == AF_INET6) {
|
||||
addr->ip = (char *) &((struct sockaddr_in6 *) addr)->sin6_addr;
|
||||
addr->ip_length = sizeof(struct in6_addr);
|
||||
addr->port = ntohs(((struct sockaddr_in6 *) addr)->sin6_port);
|
||||
} else if (addr->mem.ss_family == AF_INET) {
|
||||
addr->ip = (char *) &((struct sockaddr_in *) addr)->sin_addr;
|
||||
addr->ip_length = sizeof(struct in_addr);
|
||||
addr->port = ntohs(((struct sockaddr_in *) addr)->sin_port);
|
||||
} else {
|
||||
addr->ip_length = 0;
|
||||
addr->port = -1;
|
||||
}
|
||||
}
|
||||
|
||||
int bsd_local_addr(LIBUS_SOCKET_DESCRIPTOR fd, struct bsd_addr_t *addr) {
|
||||
addr->len = sizeof(addr->mem);
|
||||
if (getsockname(fd, (struct sockaddr *) &addr->mem, &addr->len)) {
|
||||
return -1;
|
||||
}
|
||||
internal_finalize_bsd_addr(addr);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int bsd_remote_addr(LIBUS_SOCKET_DESCRIPTOR fd, struct bsd_addr_t *addr) {
|
||||
addr->len = sizeof(addr->mem);
|
||||
if (getpeername(fd, (struct sockaddr *) &addr->mem, &addr->len)) {
|
||||
return -1;
|
||||
}
|
||||
internal_finalize_bsd_addr(addr);
|
||||
return 0;
|
||||
}
|
||||
|
||||
char *bsd_addr_get_ip(struct bsd_addr_t *addr) {
|
||||
return addr->ip;
|
||||
}
|
||||
|
||||
int bsd_addr_get_ip_length(struct bsd_addr_t *addr) {
|
||||
return addr->ip_length;
|
||||
}
|
||||
|
||||
int bsd_addr_get_port(struct bsd_addr_t *addr) {
|
||||
return addr->port;
|
||||
}
|
||||
|
||||
// called by dispatch_ready_poll
|
||||
LIBUS_SOCKET_DESCRIPTOR bsd_accept_socket(LIBUS_SOCKET_DESCRIPTOR fd, struct bsd_addr_t *addr) {
|
||||
LIBUS_SOCKET_DESCRIPTOR accepted_fd;
|
||||
addr->len = sizeof(addr->mem);
|
||||
|
||||
#if defined(SOCK_CLOEXEC) && defined(SOCK_NONBLOCK)
|
||||
// Linux, FreeBSD
|
||||
accepted_fd = accept4(fd, (struct sockaddr *) addr, &addr->len, SOCK_CLOEXEC | SOCK_NONBLOCK);
|
||||
#else
|
||||
// Windows, OS X
|
||||
accepted_fd = accept(fd, (struct sockaddr *) addr, &addr->len);
|
||||
|
||||
#endif
|
||||
|
||||
/* We cannot rely on addr since it is not initialized if failed */
|
||||
if (accepted_fd == LIBUS_SOCKET_ERROR) {
|
||||
return LIBUS_SOCKET_ERROR;
|
||||
}
|
||||
|
||||
internal_finalize_bsd_addr(addr);
|
||||
|
||||
return bsd_set_nonblocking(apple_no_sigpipe(accepted_fd));
|
||||
}
|
||||
|
||||
int bsd_recv(LIBUS_SOCKET_DESCRIPTOR fd, void *buf, int length, int flags) {
|
||||
return recv(fd, buf, length, flags);
|
||||
}
|
||||
|
||||
int bsd_send(LIBUS_SOCKET_DESCRIPTOR fd, const char *buf, int length, int msg_more) {
|
||||
|
||||
// MSG_MORE (Linux), MSG_PARTIAL (Windows), TCP_NOPUSH (BSD)
|
||||
|
||||
#ifndef MSG_NOSIGNAL
|
||||
#define MSG_NOSIGNAL 0
|
||||
#endif
|
||||
|
||||
#ifdef MSG_MORE
|
||||
|
||||
// for Linux we do not want signals
|
||||
return send(fd, buf, length, (msg_more * MSG_MORE) | MSG_NOSIGNAL);
|
||||
|
||||
#else
|
||||
|
||||
// use TCP_NOPUSH
|
||||
|
||||
return send(fd, buf, length, MSG_NOSIGNAL);
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
int bsd_would_block() {
|
||||
#ifdef _WIN32
|
||||
return WSAGetLastError() == WSAEWOULDBLOCK;
|
||||
#else
|
||||
return errno == EWOULDBLOCK;// || errno == EAGAIN;
|
||||
#endif
|
||||
}
|
||||
|
||||
// return LIBUS_SOCKET_ERROR or the fd that represents listen socket
|
||||
// listen both on ipv6 and ipv4
|
||||
LIBUS_SOCKET_DESCRIPTOR bsd_create_listen_socket(const char *host, int port, int options) {
|
||||
struct addrinfo hints, *result;
|
||||
memset(&hints, 0, sizeof(struct addrinfo));
|
||||
|
||||
hints.ai_flags = AI_PASSIVE;
|
||||
hints.ai_family = AF_UNSPEC;
|
||||
hints.ai_socktype = SOCK_STREAM;
|
||||
|
||||
char port_string[16];
|
||||
snprintf(port_string, 16, "%d", port);
|
||||
|
||||
if (getaddrinfo(host, port_string, &hints, &result)) {
|
||||
return LIBUS_SOCKET_ERROR;
|
||||
}
|
||||
|
||||
LIBUS_SOCKET_DESCRIPTOR listenFd = LIBUS_SOCKET_ERROR;
|
||||
struct addrinfo *listenAddr;
|
||||
for (struct addrinfo *a = result; a && listenFd == LIBUS_SOCKET_ERROR; a = a->ai_next) {
|
||||
if (a->ai_family == AF_INET6) {
|
||||
listenFd = bsd_create_socket(a->ai_family, a->ai_socktype, a->ai_protocol);
|
||||
listenAddr = a;
|
||||
}
|
||||
}
|
||||
|
||||
for (struct addrinfo *a = result; a && listenFd == LIBUS_SOCKET_ERROR; a = a->ai_next) {
|
||||
if (a->ai_family == AF_INET) {
|
||||
listenFd = bsd_create_socket(a->ai_family, a->ai_socktype, a->ai_protocol);
|
||||
listenAddr = a;
|
||||
}
|
||||
}
|
||||
|
||||
if (listenFd == LIBUS_SOCKET_ERROR) {
|
||||
freeaddrinfo(result);
|
||||
return LIBUS_SOCKET_ERROR;
|
||||
}
|
||||
|
||||
if (port != 0) {
|
||||
/* Otherwise, always enable SO_REUSEPORT and SO_REUSEADDR _unless_ options specify otherwise */
|
||||
#if /*defined(__linux) &&*/ defined(SO_REUSEPORT)
|
||||
if (!(options & LIBUS_LISTEN_EXCLUSIVE_PORT)) {
|
||||
int optval = 1;
|
||||
setsockopt(listenFd, SOL_SOCKET, SO_REUSEPORT, &optval, sizeof(optval));
|
||||
}
|
||||
#endif
|
||||
int enabled = 1;
|
||||
setsockopt(listenFd, SOL_SOCKET, SO_REUSEADDR, (SETSOCKOPT_PTR_TYPE) &enabled, sizeof(enabled));
|
||||
}
|
||||
|
||||
#ifdef IPV6_V6ONLY
|
||||
int disabled = 0;
|
||||
setsockopt(listenFd, IPPROTO_IPV6, IPV6_V6ONLY, (SETSOCKOPT_PTR_TYPE) &disabled, sizeof(disabled));
|
||||
#endif
|
||||
|
||||
if (bind(listenFd, listenAddr->ai_addr, (socklen_t) listenAddr->ai_addrlen) || listen(listenFd, 512)) {
|
||||
bsd_close_socket(listenFd);
|
||||
freeaddrinfo(result);
|
||||
return LIBUS_SOCKET_ERROR;
|
||||
}
|
||||
|
||||
freeaddrinfo(result);
|
||||
return listenFd;
|
||||
}
|
||||
|
||||
LIBUS_SOCKET_DESCRIPTOR bsd_create_connect_socket(const char *host, int port, const char *source_host, int options) {
|
||||
struct addrinfo hints, *result;
|
||||
memset(&hints, 0, sizeof(struct addrinfo));
|
||||
hints.ai_family = AF_UNSPEC;
|
||||
hints.ai_socktype = SOCK_STREAM;
|
||||
|
||||
char port_string[16];
|
||||
snprintf(port_string, 16, "%d", port);
|
||||
|
||||
if (getaddrinfo(host, port_string, &hints, &result) != 0) {
|
||||
return LIBUS_SOCKET_ERROR;
|
||||
}
|
||||
|
||||
LIBUS_SOCKET_DESCRIPTOR fd = bsd_create_socket(result->ai_family, result->ai_socktype, result->ai_protocol);
|
||||
if (fd == LIBUS_SOCKET_ERROR) {
|
||||
freeaddrinfo(result);
|
||||
return LIBUS_SOCKET_ERROR;
|
||||
}
|
||||
|
||||
if (source_host) {
|
||||
struct addrinfo *interface_result;
|
||||
if (!getaddrinfo(source_host, NULL, NULL, &interface_result)) {
|
||||
int ret = bind(fd, interface_result->ai_addr, (socklen_t) interface_result->ai_addrlen);
|
||||
freeaddrinfo(interface_result);
|
||||
if (ret == LIBUS_SOCKET_ERROR) {
|
||||
return LIBUS_SOCKET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
connect(fd, result->ai_addr, (socklen_t) result->ai_addrlen);
|
||||
freeaddrinfo(result);
|
||||
|
||||
return fd;
|
||||
}
|
|
@ -26,6 +26,10 @@ int default_ignore_data_handler(struct us_socket_t *s) {
|
|||
|
||||
/* Shared with SSL */
|
||||
|
||||
unsigned short us_socket_context_timestamp(int ssl, struct us_socket_context_t *context) {
|
||||
return context->timestamp;
|
||||
}
|
||||
|
||||
void us_listen_socket_close(int ssl, struct us_listen_socket_t *ls) {
|
||||
/* us_listen_socket_t extends us_socket_t so we close in similar ways */
|
||||
if (!us_socket_is_closed(0, &ls->s)) {
|
||||
|
@ -80,60 +84,92 @@ struct us_loop_t *us_socket_context_loop(int ssl, struct us_socket_context_t *co
|
|||
return context->loop;
|
||||
}
|
||||
|
||||
/* Returns the deep copy, to be freed */
|
||||
const char *deep_str_copy(const char *src) {
|
||||
if (!src) {
|
||||
return src;
|
||||
}
|
||||
size_t len = strlen(src) + 1;
|
||||
char *dst = malloc(len);
|
||||
memcpy(dst, src, len);
|
||||
return dst;
|
||||
}
|
||||
|
||||
/* Not shared with SSL */
|
||||
|
||||
struct us_socket_context_t *us_create_socket_context(int ssl, struct us_loop_t *loop, int context_ext_size, struct us_socket_context_options_t options) {
|
||||
/* For ease of use we copy all passed strings here */
|
||||
options.ca_file_name = deep_str_copy(options.ca_file_name);
|
||||
options.cert_file_name = deep_str_copy(options.cert_file_name);
|
||||
options.dh_params_file_name = deep_str_copy(options.dh_params_file_name);
|
||||
options.key_file_name = deep_str_copy(options.key_file_name);
|
||||
options.passphrase = deep_str_copy(options.passphrase);
|
||||
|
||||
/* Add SNI context */
|
||||
void us_socket_context_add_server_name(int ssl, struct us_socket_context_t *context, const char *hostname_pattern, struct us_socket_context_options_t options) {
|
||||
#ifndef LIBUS_NO_SSL
|
||||
if (ssl) {
|
||||
us_internal_ssl_socket_context_add_server_name((struct us_internal_ssl_socket_context_t *) context, hostname_pattern, options);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
/* Remove SNI context */
|
||||
void us_socket_context_remove_server_name(int ssl, struct us_socket_context_t *context, const char *hostname_pattern) {
|
||||
#ifndef LIBUS_NO_SSL
|
||||
if (ssl) {
|
||||
us_internal_ssl_socket_context_remove_server_name((struct us_internal_ssl_socket_context_t *) context, hostname_pattern);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
/* I don't like this one - maybe rename it to on_missing_server_name? */
|
||||
|
||||
/* Called when SNI matching fails - not if a match could be made.
|
||||
* You may modify the context by adding/removing names in this callback.
|
||||
* If the correct name is added immediately in the callback, it will be used */
|
||||
void us_socket_context_on_server_name(int ssl, struct us_socket_context_t *context, void (*cb)(struct us_socket_context_t *, const char *hostname)) {
|
||||
#ifndef LIBUS_NO_SSL
|
||||
if (ssl) {
|
||||
us_internal_ssl_socket_context_on_server_name((struct us_internal_ssl_socket_context_t *) context, (void (*)(struct us_internal_ssl_socket_context_t *, const char *hostname)) cb);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
/* Todo: get native context from SNI pattern */
|
||||
|
||||
void *us_socket_context_get_native_handle(int ssl, struct us_socket_context_t *context) {
|
||||
#ifndef LIBUS_NO_SSL
|
||||
if (ssl) {
|
||||
return us_internal_ssl_socket_context_get_native_handle((struct us_internal_ssl_socket_context_t *) context);
|
||||
}
|
||||
#endif
|
||||
|
||||
/* There is no native handle for a non-SSL socket context */
|
||||
return 0;
|
||||
}
|
||||
|
||||
/* Options is currently only applicable for SSL - this will change with time (prefer_low_memory is one example) */
|
||||
struct us_socket_context_t *us_create_socket_context(int ssl, struct us_loop_t *loop, int context_ext_size, struct us_socket_context_options_t options) {
|
||||
#ifndef LIBUS_NO_SSL
|
||||
if (ssl) {
|
||||
/* This function will call us, again, with SSL = false and a bigger ext_size */
|
||||
return (struct us_socket_context_t *) us_internal_create_ssl_socket_context(loop, context_ext_size, options);
|
||||
}
|
||||
#endif
|
||||
|
||||
/* This path is taken once either way - always BEFORE whatever SSL may do LATER.
|
||||
* context_ext_size will however be modified larger in case of SSL, to hold SSL extensions */
|
||||
|
||||
struct us_socket_context_t *context = malloc(sizeof(struct us_socket_context_t) + context_ext_size);
|
||||
context->loop = loop;
|
||||
context->head = 0;
|
||||
context->iterator = 0;
|
||||
context->next = 0;
|
||||
context->ignore_data = default_ignore_data_handler;
|
||||
context->options = options;
|
||||
|
||||
/* Begin at 0 */
|
||||
context->timestamp = 0;
|
||||
|
||||
us_internal_loop_link(loop, context);
|
||||
|
||||
/* If we are called from within SSL code, SSL code will make further changes to us */
|
||||
return context;
|
||||
}
|
||||
|
||||
void us_socket_context_free(int ssl, struct us_socket_context_t *context) {
|
||||
/* We also simply free every copied string here */
|
||||
free((void *) context->options.ca_file_name);
|
||||
free((void *) context->options.cert_file_name);
|
||||
free((void *) context->options.dh_params_file_name);
|
||||
free((void *) context->options.key_file_name);
|
||||
free((void *) context->options.passphrase);
|
||||
|
||||
#ifndef LIBUS_NO_SSL
|
||||
if (ssl) {
|
||||
/* This function will call us again with SSL=false */
|
||||
us_internal_ssl_socket_context_free((struct us_internal_ssl_socket_context_t *) context);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
/* This path is taken once either way - always AFTER whatever SSL may do BEFORE.
|
||||
* This is the opposite order compared to when creating the context - SSL code is cleaning up before non-SSL */
|
||||
|
||||
us_internal_loop_unlink(context->loop, context);
|
||||
free(context);
|
||||
}
|
||||
|
@ -167,14 +203,14 @@ struct us_listen_socket_t *us_socket_context_listen(int ssl, struct us_socket_co
|
|||
return ls;
|
||||
}
|
||||
|
||||
struct us_socket_t *us_socket_context_connect(int ssl, struct us_socket_context_t *context, const char *host, int port, int options, int socket_ext_size) {
|
||||
struct us_socket_t *us_socket_context_connect(int ssl, struct us_socket_context_t *context, const char *host, int port, const char *source_host, int options, int socket_ext_size) {
|
||||
#ifndef LIBUS_NO_SSL
|
||||
if (ssl) {
|
||||
return (struct us_socket_t *) us_internal_ssl_socket_context_connect((struct us_internal_ssl_socket_context_t *) context, host, port, options, socket_ext_size);
|
||||
return (struct us_socket_t *) us_internal_ssl_socket_context_connect((struct us_internal_ssl_socket_context_t *) context, host, port, source_host, options, socket_ext_size);
|
||||
}
|
||||
#endif
|
||||
|
||||
LIBUS_SOCKET_DESCRIPTOR connect_socket_fd = bsd_create_connect_socket(host, port, options);
|
||||
LIBUS_SOCKET_DESCRIPTOR connect_socket_fd = bsd_create_connect_socket(host, port, source_host, options);
|
||||
if (connect_socket_fd == LIBUS_SOCKET_ERROR) {
|
||||
return 0;
|
||||
}
|
||||
|
@ -239,10 +275,10 @@ void us_socket_context_on_open(int ssl, struct us_socket_context_t *context, str
|
|||
context->on_open = on_open;
|
||||
}
|
||||
|
||||
void us_socket_context_on_close(int ssl, struct us_socket_context_t *context, struct us_socket_t *(*on_close)(struct us_socket_t *s)) {
|
||||
void us_socket_context_on_close(int ssl, struct us_socket_context_t *context, struct us_socket_t *(*on_close)(struct us_socket_t *s, int code, void *reason)) {
|
||||
#ifndef LIBUS_NO_SSL
|
||||
if (ssl) {
|
||||
us_internal_ssl_socket_context_on_close((struct us_internal_ssl_socket_context_t *) context, (struct us_internal_ssl_socket_t * (*)(struct us_internal_ssl_socket_t *)) on_close);
|
||||
us_internal_ssl_socket_context_on_close((struct us_internal_ssl_socket_context_t *) context, (struct us_internal_ssl_socket_t * (*)(struct us_internal_ssl_socket_t *, int code, void *reason)) on_close);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
@ -294,6 +330,17 @@ void us_socket_context_on_end(int ssl, struct us_socket_context_t *context, stru
|
|||
context->on_end = on_end;
|
||||
}
|
||||
|
||||
void us_socket_context_on_connect_error(int ssl, struct us_socket_context_t *context, struct us_socket_t *(*on_connect_error)(struct us_socket_t *s, int code)) {
|
||||
#ifndef LIBUS_NO_SSL
|
||||
if (ssl) {
|
||||
us_internal_ssl_socket_context_on_connect_error((struct us_internal_ssl_socket_context_t *) context, (struct us_internal_ssl_socket_t * (*)(struct us_internal_ssl_socket_t *, int)) on_connect_error);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
context->on_connect_error = on_connect_error;
|
||||
}
|
||||
|
||||
void *us_socket_context_ext(int ssl, struct us_socket_context_t *context) {
|
||||
#ifndef LIBUS_NO_SSL
|
||||
if (ssl) {
|
||||
|
|
|
@ -17,8 +17,16 @@
|
|||
|
||||
#ifdef LIBUS_USE_OPENSSL
|
||||
|
||||
/* These are in sni_tree.cpp */
|
||||
void *sni_new();
|
||||
void sni_free(void *sni, void(*cb)(void *));
|
||||
int sni_add(void *sni, const char *hostname, void *user);
|
||||
void *sni_remove(void *sni, const char *hostname);
|
||||
void *sni_find(void *sni, const char *hostname);
|
||||
|
||||
#include "libusockets.h"
|
||||
#include "internal/internal.h"
|
||||
#include <string.h>
|
||||
|
||||
/* This module contains the entire OpenSSL implementation
|
||||
* of the SSL socket and socket context interfaces. */
|
||||
|
@ -59,10 +67,17 @@ struct us_internal_ssl_socket_context_t {
|
|||
SSL_CTX *ssl_context;
|
||||
int is_parent;
|
||||
|
||||
// här måste det vara!
|
||||
/* These decorate the base implementation */
|
||||
struct us_internal_ssl_socket_t *(*on_open)(struct us_internal_ssl_socket_t *, int is_client, char *ip, int ip_length);
|
||||
struct us_internal_ssl_socket_t *(*on_data)(struct us_internal_ssl_socket_t *, char *data, int length);
|
||||
struct us_internal_ssl_socket_t *(*on_close)(struct us_internal_ssl_socket_t *);
|
||||
struct us_internal_ssl_socket_t *(*on_writable)(struct us_internal_ssl_socket_t *);
|
||||
struct us_internal_ssl_socket_t *(*on_close)(struct us_internal_ssl_socket_t *, int code, void *reason);
|
||||
|
||||
/* Called for missing SNI hostnames, if not NULL */
|
||||
void (*on_server_name)(struct us_internal_ssl_socket_context_t *, const char *hostname);
|
||||
|
||||
/* Pointer to sni tree, created when the context is created and freed likewise when freed */
|
||||
void *sni;
|
||||
};
|
||||
|
||||
// same here, should or shouldn't it contain s?
|
||||
|
@ -70,6 +85,7 @@ struct us_internal_ssl_socket_t {
|
|||
struct us_socket_t s;
|
||||
SSL *ssl;
|
||||
int ssl_write_wants_read; // we use this for now
|
||||
int ssl_read_wants_write;
|
||||
};
|
||||
|
||||
int passphrase_cb(char *buf, int size, int rwflag, void *u) {
|
||||
|
@ -103,7 +119,7 @@ int BIO_s_custom_write(BIO *bio, const char *data, int length) {
|
|||
int written = us_socket_write(0, loop_ssl_data->ssl_socket, data, length, loop_ssl_data->last_write_was_msg_more);
|
||||
|
||||
if (!written) {
|
||||
BIO_set_flags(bio, BIO_get_flags(bio) | BIO_FLAGS_SHOULD_RETRY | BIO_FLAGS_WRITE);
|
||||
BIO_set_flags(bio, BIO_FLAGS_SHOULD_RETRY | BIO_FLAGS_WRITE);
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
@ -118,7 +134,7 @@ int BIO_s_custom_read(BIO *bio, char *dst, int length) {
|
|||
//printf("BIO_s_custom_read\n");
|
||||
|
||||
if (!loop_ssl_data->ssl_read_input_length) {
|
||||
BIO_set_flags(bio, BIO_get_flags(bio) | BIO_FLAGS_SHOULD_RETRY | BIO_FLAGS_READ);
|
||||
BIO_set_flags(bio, BIO_FLAGS_SHOULD_RETRY | BIO_FLAGS_READ);
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
@ -141,6 +157,7 @@ struct us_internal_ssl_socket_t *ssl_on_open(struct us_internal_ssl_socket_t *s,
|
|||
|
||||
s->ssl = SSL_new(context->ssl_context);
|
||||
s->ssl_write_wants_read = 0;
|
||||
s->ssl_read_wants_write = 0;
|
||||
SSL_set_bio(s->ssl, loop_ssl_data->shared_rbio, loop_ssl_data->shared_wbio);
|
||||
|
||||
BIO_up_ref(loop_ssl_data->shared_rbio);
|
||||
|
@ -155,19 +172,26 @@ struct us_internal_ssl_socket_t *ssl_on_open(struct us_internal_ssl_socket_t *s,
|
|||
return (struct us_internal_ssl_socket_t *) context->on_open(s, is_client, ip, ip_length);
|
||||
}
|
||||
|
||||
struct us_internal_ssl_socket_t *ssl_on_close(struct us_internal_ssl_socket_t *s) {
|
||||
/* This one is a helper; it is entirely shared with non-SSL so can be removed */
|
||||
struct us_internal_ssl_socket_t *us_internal_ssl_socket_close(struct us_internal_ssl_socket_t *s, int code, void *reason) {
|
||||
return (struct us_internal_ssl_socket_t *) us_socket_close(0, (struct us_socket_t *) s, code, reason);
|
||||
}
|
||||
|
||||
struct us_internal_ssl_socket_t *ssl_on_close(struct us_internal_ssl_socket_t *s, int code, void *reason) {
|
||||
struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s);
|
||||
|
||||
SSL_free(s->ssl);
|
||||
|
||||
return context->on_close(s);
|
||||
return context->on_close(s, code, reason);
|
||||
}
|
||||
|
||||
struct us_internal_ssl_socket_t *ssl_on_end(struct us_internal_ssl_socket_t *s) {
|
||||
struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s);
|
||||
// struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s);
|
||||
|
||||
// whatever state we are in, a TCP FIN is always an answered shutdown
|
||||
return us_internal_ssl_socket_close(s);
|
||||
|
||||
/* Todo: this should report CLEANLY SHUTDOWN as reason */
|
||||
return us_internal_ssl_socket_close(s, 0, NULL);
|
||||
}
|
||||
|
||||
// this whole function needs a complete clean-up
|
||||
|
@ -192,7 +216,8 @@ struct us_internal_ssl_socket_t *ssl_on_data(struct us_internal_ssl_socket_t *s,
|
|||
// two phase shutdown is complete here
|
||||
//printf("Two step SSL shutdown complete\n");
|
||||
|
||||
return us_internal_ssl_socket_close(s);
|
||||
/* Todo: this should also report some kind of clean shutdown */
|
||||
return us_internal_ssl_socket_close(s, 0, NULL);
|
||||
} else if (ret < 0) {
|
||||
|
||||
int err = SSL_get_error(s->ssl, ret);
|
||||
|
@ -226,13 +251,18 @@ struct us_internal_ssl_socket_t *ssl_on_data(struct us_internal_ssl_socket_t *s,
|
|||
}
|
||||
|
||||
// terminate connection here
|
||||
return us_internal_ssl_socket_close(s);
|
||||
return us_internal_ssl_socket_close(s, 0, NULL);
|
||||
} else {
|
||||
// emit the data we have and exit
|
||||
|
||||
if (err == SSL_ERROR_WANT_WRITE) {
|
||||
// here we need to trigger writable event next ssl_read!
|
||||
s->ssl_read_wants_write = 1;
|
||||
}
|
||||
|
||||
// assume we emptied the input buffer fully or error here as well!
|
||||
if (loop_ssl_data->ssl_read_input_length) {
|
||||
return us_internal_ssl_socket_close(s);
|
||||
return us_internal_ssl_socket_close(s, 0, NULL);
|
||||
}
|
||||
|
||||
// cannot emit zero length to app
|
||||
|
@ -240,6 +270,8 @@ struct us_internal_ssl_socket_t *ssl_on_data(struct us_internal_ssl_socket_t *s,
|
|||
break;
|
||||
}
|
||||
|
||||
context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s);
|
||||
|
||||
s = context->on_data(s, loop_ssl_data->ssl_read_output + LIBUS_RECV_BUFFER_PADDING, read);
|
||||
if (us_socket_is_closed(0, &s->s)) {
|
||||
return s;
|
||||
|
@ -255,6 +287,8 @@ struct us_internal_ssl_socket_t *ssl_on_data(struct us_internal_ssl_socket_t *s,
|
|||
// at this point we might be full and need to emit the data to application and start over
|
||||
if (read == LIBUS_RECV_BUFFER_LENGTH) {
|
||||
|
||||
context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s);
|
||||
|
||||
// emit data and restart
|
||||
s = context->on_data(s, loop_ssl_data->ssl_read_output + LIBUS_RECV_BUFFER_PADDING, read);
|
||||
if (us_socket_is_closed(0, &s->s)) {
|
||||
|
@ -287,7 +321,7 @@ struct us_internal_ssl_socket_t *ssl_on_data(struct us_internal_ssl_socket_t *s,
|
|||
//exit(-2);
|
||||
|
||||
// not correct anyways!
|
||||
s = us_internal_ssl_socket_close(s);
|
||||
s = us_internal_ssl_socket_close(s, 0, NULL);
|
||||
|
||||
//us_
|
||||
}
|
||||
|
@ -295,6 +329,28 @@ struct us_internal_ssl_socket_t *ssl_on_data(struct us_internal_ssl_socket_t *s,
|
|||
return s;
|
||||
}
|
||||
|
||||
struct us_internal_ssl_socket_t *ssl_on_writable(struct us_internal_ssl_socket_t *s) {
|
||||
|
||||
struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s);
|
||||
|
||||
// todo: cork here so that we efficiently output both from reading and from writing?
|
||||
|
||||
if (s->ssl_read_wants_write) {
|
||||
s->ssl_read_wants_write = 0;
|
||||
|
||||
// make sure to update context before we call (context can change if the user adopts the socket!)
|
||||
context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s);
|
||||
|
||||
// if this one fails to write data, it sets ssl_read_wants_write again
|
||||
s = (struct us_internal_ssl_socket_t *) context->sc.on_data(&s->s, 0, 0); // cast here!
|
||||
}
|
||||
|
||||
// should this one come before we have read? should it come always? spurious on_writable is okay
|
||||
s = context->on_writable(s);
|
||||
|
||||
return s;
|
||||
}
|
||||
|
||||
/* Lazily inits loop ssl data first time */
|
||||
void us_internal_init_loop_ssl_data(struct us_loop_t *loop) {
|
||||
if (!loop->data.ssl_data) {
|
||||
|
@ -371,60 +427,77 @@ int ssl_ignore_data(struct us_internal_ssl_socket_t *s) {
|
|||
}
|
||||
|
||||
/* Per-context functions */
|
||||
struct us_internal_ssl_socket_context_t *us_internal_create_child_ssl_socket_context(struct us_internal_ssl_socket_context_t *context, int context_ext_size) {
|
||||
struct us_socket_context_options_t options = {0};
|
||||
void *us_internal_ssl_socket_context_get_native_handle(struct us_internal_ssl_socket_context_t *context) {
|
||||
return context->ssl_context;
|
||||
}
|
||||
|
||||
struct us_internal_ssl_socket_context_t *us_internal_create_child_ssl_socket_context(struct us_internal_ssl_socket_context_t *context, int context_ext_size) {
|
||||
/* Create a new non-SSL context */
|
||||
struct us_socket_context_options_t options = {0};
|
||||
struct us_internal_ssl_socket_context_t *child_context = (struct us_internal_ssl_socket_context_t *) us_create_socket_context(0, context->sc.loop, sizeof(struct us_internal_ssl_socket_context_t) + context_ext_size, options);
|
||||
|
||||
// I think this is the only thing being shared
|
||||
/* The only thing we share is SSL_CTX */
|
||||
child_context->ssl_context = context->ssl_context;
|
||||
child_context->is_parent = 0;
|
||||
|
||||
return child_context;
|
||||
}
|
||||
|
||||
struct us_internal_ssl_socket_context_t *us_internal_create_ssl_socket_context(struct us_loop_t *loop, int context_ext_size, struct us_socket_context_options_t options) {
|
||||
/* Common function for creating a context from options.
|
||||
* We must NOT free a SSL_CTX with only SSL_CTX_free! Also free any password */
|
||||
void free_ssl_context(SSL_CTX *ssl_context) {
|
||||
if (!ssl_context) {
|
||||
return;
|
||||
}
|
||||
|
||||
us_internal_init_loop_ssl_data(loop);
|
||||
/* If we have set a password string, free it here */
|
||||
void *password = SSL_CTX_get_default_passwd_cb_userdata(ssl_context);
|
||||
/* OpenSSL returns NULL if we have no set password */
|
||||
free(password);
|
||||
|
||||
struct us_socket_context_options_t no_options = {0};
|
||||
SSL_CTX_free(ssl_context);
|
||||
}
|
||||
|
||||
struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_create_socket_context(0, loop, sizeof(struct us_internal_ssl_socket_context_t) + context_ext_size, no_options);
|
||||
/* This function should take any options and return SSL_CTX - which has to be free'd with
|
||||
* our destructor function - free_ssl_context() */
|
||||
SSL_CTX *create_ssl_context_from_options(struct us_socket_context_options_t options) {
|
||||
/* Create the context */
|
||||
SSL_CTX *ssl_context = SSL_CTX_new(TLS_method());
|
||||
|
||||
context->ssl_context = SSL_CTX_new(TLS_method());
|
||||
context->is_parent = 1;
|
||||
// only parent ssl contexts may need to ignore data
|
||||
context->sc.ignore_data = (int (*)(struct us_socket_t *)) ssl_ignore_data;
|
||||
/* Default options we rely on - changing these will break our logic */
|
||||
SSL_CTX_set_read_ahead(ssl_context, 1);
|
||||
SSL_CTX_set_mode(ssl_context, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
|
||||
|
||||
// options
|
||||
SSL_CTX_set_read_ahead(context->ssl_context, 1);
|
||||
SSL_CTX_set_mode(context->ssl_context, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
|
||||
//SSL_CTX_set_mode(context->ssl_context, SSL_MODE_ENABLE_PARTIAL_WRITE);
|
||||
/* Anything below TLS 1.2 is disabled */
|
||||
SSL_CTX_set_min_proto_version(ssl_context, TLS1_2_VERSION);
|
||||
|
||||
// this lowers performance a bit in benchmarks
|
||||
/* The following are helpers. You may easily implement whatever you want by using the native handle directly */
|
||||
|
||||
/* Important option for lowering memory usage, but lowers performance slightly */
|
||||
if (options.ssl_prefer_low_memory_usage) {
|
||||
SSL_CTX_set_mode(context->ssl_context, SSL_MODE_RELEASE_BUFFERS);
|
||||
SSL_CTX_set_mode(ssl_context, SSL_MODE_RELEASE_BUFFERS);
|
||||
}
|
||||
|
||||
//SSL_CTX_set_mode(context->ssl_context, SSL_MODE_RELEASE_BUFFERS);
|
||||
SSL_CTX_set_options(context->ssl_context, SSL_OP_NO_SSLv3);
|
||||
SSL_CTX_set_options(context->ssl_context, SSL_OP_NO_TLSv1);
|
||||
|
||||
// these are going to be extended
|
||||
if (options.passphrase) {
|
||||
SSL_CTX_set_default_passwd_cb_userdata(context->ssl_context, (void *) options.passphrase);
|
||||
SSL_CTX_set_default_passwd_cb(context->ssl_context, passphrase_cb);
|
||||
/* When freeing the CTX we need to check SSL_CTX_get_default_passwd_cb_userdata and
|
||||
* free it if set */
|
||||
SSL_CTX_set_default_passwd_cb_userdata(ssl_context, (void *) strdup(options.passphrase));
|
||||
SSL_CTX_set_default_passwd_cb(ssl_context, passphrase_cb);
|
||||
}
|
||||
|
||||
/* This one most probably do not need the cert_file_name string to be kept alive */
|
||||
if (options.cert_file_name) {
|
||||
if (SSL_CTX_use_certificate_chain_file(context->ssl_context, options.cert_file_name) != 1) {
|
||||
return 0;
|
||||
if (SSL_CTX_use_certificate_chain_file(ssl_context, options.cert_file_name) != 1) {
|
||||
free_ssl_context(ssl_context);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
/* Same as above - we can discard this string afterwards I suppose */
|
||||
if (options.key_file_name) {
|
||||
if (SSL_CTX_use_PrivateKey_file(context->ssl_context, options.key_file_name, SSL_FILETYPE_PEM) != 1) {
|
||||
return 0;
|
||||
if (SSL_CTX_use_PrivateKey_file(ssl_context, options.key_file_name, SSL_FILETYPE_PEM) != 1) {
|
||||
free_ssl_context(ssl_context);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -432,13 +505,15 @@ struct us_internal_ssl_socket_context_t *us_internal_create_ssl_socket_context(s
|
|||
STACK_OF(X509_NAME) *ca_list;
|
||||
ca_list = SSL_load_client_CA_file(options.ca_file_name);
|
||||
if(ca_list == NULL) {
|
||||
return 0;
|
||||
free_ssl_context(ssl_context);
|
||||
return NULL;
|
||||
}
|
||||
SSL_CTX_set_client_CA_list(context->ssl_context, ca_list);
|
||||
if (SSL_CTX_load_verify_locations(context->ssl_context, options.ca_file_name, NULL) != 1) {
|
||||
return 0;
|
||||
SSL_CTX_set_client_CA_list(ssl_context, ca_list);
|
||||
if (SSL_CTX_load_verify_locations(ssl_context, options.ca_file_name, NULL) != 1) {
|
||||
free_ssl_context(ssl_context);
|
||||
return NULL;
|
||||
}
|
||||
SSL_CTX_set_verify(context->ssl_context, SSL_VERIFY_PEER, NULL);
|
||||
SSL_CTX_set_verify(ssl_context, SSL_VERIFY_PEER, NULL);
|
||||
}
|
||||
|
||||
if (options.dh_params_file_name) {
|
||||
|
@ -451,29 +526,155 @@ struct us_internal_ssl_socket_context_t *us_internal_create_ssl_socket_context(s
|
|||
dh_2048 = PEM_read_DHparams(paramfile, NULL, NULL, NULL);
|
||||
fclose(paramfile);
|
||||
} else {
|
||||
return 0;
|
||||
free_ssl_context(ssl_context);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (dh_2048 == NULL) {
|
||||
return 0;
|
||||
free_ssl_context(ssl_context);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (SSL_CTX_set_tmp_dh(context->ssl_context, dh_2048) != 1) {
|
||||
return 0;
|
||||
const long set_tmp_dh = SSL_CTX_set_tmp_dh(ssl_context, dh_2048);
|
||||
DH_free(dh_2048);
|
||||
|
||||
if (set_tmp_dh != 1) {
|
||||
free_ssl_context(ssl_context);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
/* OWASP Cipher String 'A+' (https://www.owasp.org/index.php/TLS_Cipher_String_Cheat_Sheet) */
|
||||
if (SSL_CTX_set_cipher_list(context->ssl_context, "DHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES128-GCM-SHA256:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES128-GCM-SHA256") != 1) {
|
||||
return 0;
|
||||
if (SSL_CTX_set_cipher_list(ssl_context, "DHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES128-GCM-SHA256:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES128-GCM-SHA256") != 1) {
|
||||
free_ssl_context(ssl_context);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
/* This must be free'd with free_ssl_context, not SSL_CTX_free */
|
||||
return ssl_context;
|
||||
}
|
||||
|
||||
/* Todo: return error on failure? */
|
||||
void us_internal_ssl_socket_context_add_server_name(struct us_internal_ssl_socket_context_t *context, const char *hostname_pattern, struct us_socket_context_options_t options) {
|
||||
|
||||
/* Try and construct an SSL_CTX from options */
|
||||
SSL_CTX *ssl_context = create_ssl_context_from_options(options);
|
||||
|
||||
/* We do not want to hold any nullptr's in our SNI tree */
|
||||
if (ssl_context) {
|
||||
if (sni_add(context->sni, hostname_pattern, ssl_context)) {
|
||||
/* If we already had that name, ignore */
|
||||
free_ssl_context(ssl_context);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void us_internal_ssl_socket_context_on_server_name(struct us_internal_ssl_socket_context_t *context, void (*cb)(struct us_internal_ssl_socket_context_t *, const char *hostname)) {
|
||||
context->on_server_name = cb;
|
||||
}
|
||||
|
||||
void us_internal_ssl_socket_context_remove_server_name(struct us_internal_ssl_socket_context_t *context, const char *hostname_pattern) {
|
||||
|
||||
/* The same thing must happen for sni_free, that's why we have a callback */
|
||||
SSL_CTX *sni_node_ssl_context = (SSL_CTX *) sni_remove(context->sni, hostname_pattern);
|
||||
free_ssl_context(sni_node_ssl_context);
|
||||
}
|
||||
|
||||
/* Returns NULL or SSL_CTX. May call missing server name callback */
|
||||
SSL_CTX *resolve_context(struct us_internal_ssl_socket_context_t *context, const char *hostname) {
|
||||
|
||||
/* Try once first */
|
||||
void *user = sni_find(context->sni, hostname);
|
||||
if (!user) {
|
||||
/* Emit missing hostname then try again */
|
||||
if (!context->on_server_name) {
|
||||
/* We have no callback registered, so fail */
|
||||
return NULL;
|
||||
}
|
||||
|
||||
context->on_server_name(context, hostname);
|
||||
|
||||
/* Last try */
|
||||
user = sni_find(context->sni, hostname);
|
||||
}
|
||||
|
||||
return user;
|
||||
}
|
||||
|
||||
// arg is context
|
||||
int sni_cb(SSL *ssl, int *al, void *arg) {
|
||||
|
||||
if (ssl) {
|
||||
const char *hostname = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
|
||||
if (hostname && hostname[0]) {
|
||||
/* Try and resolve (match) required hostname with what we have registered */
|
||||
SSL_CTX *resolved_ssl_context = resolve_context((struct us_internal_ssl_socket_context_t *) arg, hostname);
|
||||
if (resolved_ssl_context) {
|
||||
//printf("Did find matching SNI context for hostname: <%s>!\n", hostname);
|
||||
SSL_set_SSL_CTX(ssl, resolved_ssl_context);
|
||||
} else {
|
||||
/* Call a blocking callback notifying of missing context */
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return SSL_TLSEXT_ERR_OK;
|
||||
}
|
||||
|
||||
/* Can we even come here ever? */
|
||||
return SSL_TLSEXT_ERR_NOACK;
|
||||
}
|
||||
|
||||
struct us_internal_ssl_socket_context_t *us_internal_create_ssl_socket_context(struct us_loop_t *loop, int context_ext_size, struct us_socket_context_options_t options) {
|
||||
/* If we haven't initialized the loop data yet, do so .
|
||||
* This is needed because loop data holds shared OpenSSL data and
|
||||
* the function is also responsible for initializing OpenSSL */
|
||||
us_internal_init_loop_ssl_data(loop);
|
||||
|
||||
/* First of all we try and create the SSL context from options */
|
||||
SSL_CTX *ssl_context = create_ssl_context_from_options(options);
|
||||
if (!ssl_context) {
|
||||
/* We simply fail early if we cannot even create the OpenSSL context */
|
||||
return NULL;
|
||||
}
|
||||
|
||||
/* Otherwise ee continue by creating a non-SSL context, but with larger ext to hold our SSL stuff */
|
||||
struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_create_socket_context(0, loop, sizeof(struct us_internal_ssl_socket_context_t) + context_ext_size, options);
|
||||
|
||||
/* I guess this is the only optional callback */
|
||||
context->on_server_name = NULL;
|
||||
|
||||
/* Then we extend its SSL parts */
|
||||
context->ssl_context = ssl_context;//create_ssl_context_from_options(options);
|
||||
context->is_parent = 1;
|
||||
|
||||
/* We, as parent context, may ignore data */
|
||||
context->sc.ignore_data = (int (*)(struct us_socket_t *)) ssl_ignore_data;
|
||||
|
||||
/* Parent contexts may use SNI */
|
||||
SSL_CTX_set_tlsext_servername_callback(context->ssl_context, sni_cb);
|
||||
SSL_CTX_set_tlsext_servername_arg(context->ssl_context, context);
|
||||
|
||||
/* Also create the SNI tree */
|
||||
context->sni = sni_new();
|
||||
|
||||
return context;
|
||||
}
|
||||
|
||||
/* Our destructor for hostnames, used below */
|
||||
void sni_hostname_destructor(void *user) {
|
||||
/* Some nodes hold null, so this one must ignore this case */
|
||||
free_ssl_context((SSL_CTX *) user);
|
||||
}
|
||||
|
||||
void us_internal_ssl_socket_context_free(struct us_internal_ssl_socket_context_t *context) {
|
||||
/* If we are parent then we need to free our OpenSSL context */
|
||||
if (context->is_parent) {
|
||||
SSL_CTX_free(context->ssl_context);
|
||||
free_ssl_context(context->ssl_context);
|
||||
|
||||
/* Here we need to register a temporary callback for all still-existing hostnames
|
||||
* and their contexts. Only parents have an SNI tree */
|
||||
sni_free(context->sni, sni_hostname_destructor);
|
||||
}
|
||||
|
||||
us_socket_context_free(0, &context->sc);
|
||||
|
@ -483,8 +684,8 @@ struct us_listen_socket_t *us_internal_ssl_socket_context_listen(struct us_inter
|
|||
return us_socket_context_listen(0, &context->sc, host, port, options, sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t) + socket_ext_size);
|
||||
}
|
||||
|
||||
struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_connect(struct us_internal_ssl_socket_context_t *context, const char *host, int port, int options, int socket_ext_size) {
|
||||
return (struct us_internal_ssl_socket_t *) us_socket_context_connect(0, &context->sc, host, port, options, sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t) + socket_ext_size);
|
||||
struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_connect(struct us_internal_ssl_socket_context_t *context, const char *host, int port, const char *source_host, int options, int socket_ext_size) {
|
||||
return (struct us_internal_ssl_socket_t *) us_socket_context_connect(0, &context->sc, host, port, source_host, options, sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t) + socket_ext_size);
|
||||
}
|
||||
|
||||
void us_internal_ssl_socket_context_on_open(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *(*on_open)(struct us_internal_ssl_socket_t *s, int is_client, char *ip, int ip_length)) {
|
||||
|
@ -492,8 +693,8 @@ void us_internal_ssl_socket_context_on_open(struct us_internal_ssl_socket_contex
|
|||
context->on_open = on_open;
|
||||
}
|
||||
|
||||
void us_internal_ssl_socket_context_on_close(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *(*on_close)(struct us_internal_ssl_socket_t *s)) {
|
||||
us_socket_context_on_close(0, (struct us_socket_context_t *) context, (struct us_socket_t *(*)(struct us_socket_t *)) ssl_on_close);
|
||||
void us_internal_ssl_socket_context_on_close(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *(*on_close)(struct us_internal_ssl_socket_t *s, int code, void *reason)) {
|
||||
us_socket_context_on_close(0, (struct us_socket_context_t *) context, (struct us_socket_t *(*)(struct us_socket_t *, int, void *)) ssl_on_close);
|
||||
context->on_close = on_close;
|
||||
}
|
||||
|
||||
|
@ -503,22 +704,31 @@ void us_internal_ssl_socket_context_on_data(struct us_internal_ssl_socket_contex
|
|||
}
|
||||
|
||||
void us_internal_ssl_socket_context_on_writable(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *(*on_writable)(struct us_internal_ssl_socket_t *s)) {
|
||||
us_socket_context_on_writable(0, (struct us_socket_context_t *) context, (struct us_socket_t *(*)(struct us_socket_t *)) on_writable);
|
||||
us_socket_context_on_writable(0, (struct us_socket_context_t *) context, (struct us_socket_t *(*)(struct us_socket_t *)) ssl_on_writable);
|
||||
context->on_writable = on_writable;
|
||||
}
|
||||
|
||||
void us_internal_ssl_socket_context_on_timeout(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *(*on_timeout)(struct us_internal_ssl_socket_t *s)) {
|
||||
us_socket_context_on_timeout(0, (struct us_socket_context_t *) context, (struct us_socket_t *(*)(struct us_socket_t *)) on_timeout);
|
||||
}
|
||||
|
||||
/* We do not really listen to passed FIN-handler, we entirely override it with our handler since SSL doesn't really have support for half-closed sockets */
|
||||
void us_internal_ssl_socket_context_on_end(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *(*on_end)(struct us_internal_ssl_socket_t *)) {
|
||||
us_socket_context_on_end(0, (struct us_socket_context_t *) context, (struct us_socket_t *(*)(struct us_socket_t *)) ssl_on_end);
|
||||
}
|
||||
|
||||
void us_internal_ssl_socket_context_on_connect_error(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *(*on_connect_error)(struct us_internal_ssl_socket_t *, int code)) {
|
||||
us_socket_context_on_connect_error(0, (struct us_socket_context_t *) context, (struct us_socket_t *(*)(struct us_socket_t *, int)) on_connect_error);
|
||||
}
|
||||
|
||||
void *us_internal_ssl_socket_context_ext(struct us_internal_ssl_socket_context_t *context) {
|
||||
return context + 1;
|
||||
}
|
||||
|
||||
/* Per socket functions */
|
||||
void *us_internal_ssl_socket_get_native_handle(struct us_internal_ssl_socket_t *s) {
|
||||
return s->ssl;
|
||||
}
|
||||
|
||||
int us_internal_ssl_socket_write(struct us_internal_ssl_socket_t *s, const char *data, int length, int msg_more) {
|
||||
if (us_socket_is_closed(0, &s->s) || us_internal_ssl_socket_is_shut_down(s)) {
|
||||
|
@ -617,10 +827,6 @@ void us_internal_ssl_socket_shutdown(struct us_internal_ssl_socket_t *s) {
|
|||
}
|
||||
}
|
||||
|
||||
struct us_internal_ssl_socket_t *us_internal_ssl_socket_close(struct us_internal_ssl_socket_t *s) {
|
||||
return (struct us_internal_ssl_socket_t *) us_socket_close(0, (struct us_socket_t *) s);
|
||||
}
|
||||
|
||||
struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_adopt_socket(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *s, int ext_size) {
|
||||
// todo: this is completely untested
|
||||
return (struct us_internal_ssl_socket_t *) us_socket_context_adopt_socket(0, &context->sc, &s->s, sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t) + ext_size);
|
||||
|
|
|
@ -0,0 +1,218 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2020.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/* This Server Name Indication hostname tree is written in C++ but could be ported to C.
|
||||
* Overall it looks like crap, but has no memory allocations in fast path and is O(log n). */
|
||||
|
||||
#ifndef SNI_TREE_H
|
||||
#define SNI_TREE_H
|
||||
|
||||
#ifndef LIBUS_NO_SSL
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string_view>
|
||||
#include <cstring>
|
||||
#include <cstdlib>
|
||||
#include <algorithm>
|
||||
|
||||
/* We only handle a maximum of 10 labels per hostname */
|
||||
#define MAX_LABELS 10
|
||||
|
||||
/* This cannot be shared */
|
||||
thread_local void (*sni_free_cb)(void *);
|
||||
|
||||
struct sni_node {
|
||||
/* Empty nodes must always hold null */
|
||||
void *user = nullptr;
|
||||
std::map<std::string_view, std::unique_ptr<sni_node>> children;
|
||||
|
||||
~sni_node() {
|
||||
for (auto &p : children) {
|
||||
/* The data of our string_views are managed by malloc */
|
||||
free((void *) p.first.data());
|
||||
|
||||
/* Call destructor passed to sni_free only if we hold data.
|
||||
* This is important since sni_remove does not have sni_free_cb set */
|
||||
if (p.second.get()->user) {
|
||||
sni_free_cb(p.second.get()->user);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// this can only delete ONE single node, but may cull "empty nodes with null as data"
|
||||
void *removeUser(struct sni_node *root, unsigned int label, std::string_view *labels, unsigned int numLabels) {
|
||||
|
||||
/* If we are in the bottom (past bottom by one), there is nothing to remove */
|
||||
if (label == numLabels) {
|
||||
void *user = root->user;
|
||||
/* Mark us for culling on the way up */
|
||||
root->user = nullptr;
|
||||
return user;
|
||||
}
|
||||
|
||||
/* Is this label a child of root? */
|
||||
auto it = root->children.find(labels[label]);
|
||||
if (it == root->children.end()) {
|
||||
/* We cannot continue */
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void *removedUser = removeUser(it->second.get(), label + 1, labels, numLabels);
|
||||
|
||||
/* On the way back up, we may cull empty nodes with no children.
|
||||
* This ends up being where we remove all nodes */
|
||||
if (it->second.get()->children.empty() && it->second.get()->user == nullptr) {
|
||||
|
||||
/* The data of our string_views are managed by malloc */
|
||||
free((void *) it->first.data());
|
||||
|
||||
/* This can only happen with user set to null, otherwise we use sni_free_cb which is unset by sni_remove */
|
||||
root->children.erase(it);
|
||||
}
|
||||
|
||||
return removedUser;
|
||||
}
|
||||
|
||||
void *getUser(struct sni_node *root, unsigned int label, std::string_view *labels, unsigned int numLabels) {
|
||||
|
||||
/* Do we have labels to match? Otherwise, return where we stand */
|
||||
if (label == numLabels) {
|
||||
return root->user;
|
||||
}
|
||||
|
||||
/* Try and match by our label */
|
||||
auto it = root->children.find(labels[label]);
|
||||
if (it != root->children.end()) {
|
||||
void *user = getUser(it->second.get(), label + 1, labels, numLabels);
|
||||
if (user) {
|
||||
return user;
|
||||
}
|
||||
}
|
||||
|
||||
/* Try and match by wildcard */
|
||||
it = root->children.find("*");
|
||||
if (it == root->children.end()) {
|
||||
/* Matching has failed for both label and wildcard */
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/* We matched by wildcard */
|
||||
return getUser(it->second.get(), label + 1, labels, numLabels);
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void *sni_new() {
|
||||
return new sni_node;
|
||||
}
|
||||
|
||||
void sni_free(void *sni, void (*cb)(void *)) {
|
||||
/* We want to run this callback for every remaining name */
|
||||
sni_free_cb = cb;
|
||||
|
||||
delete (sni_node *) sni;
|
||||
}
|
||||
|
||||
/* Returns non-null if this name already exists */
|
||||
int sni_add(void *sni, const char *hostname, void *user) {
|
||||
struct sni_node *root = (struct sni_node *) sni;
|
||||
|
||||
/* Traverse all labels in hostname */
|
||||
for (std::string_view view(hostname, strlen(hostname)), label;
|
||||
view.length(); view.remove_prefix(std::min(view.length(), label.length() + 1))) {
|
||||
/* Label is the token separated by dot */
|
||||
label = view.substr(0, view.find('.', 0));
|
||||
|
||||
auto it = root->children.find(label);
|
||||
if (it == root->children.end()) {
|
||||
/* Duplicate this label for our kept string_view of it */
|
||||
void *labelString = malloc(label.length());
|
||||
memcpy(labelString, label.data(), label.length());
|
||||
|
||||
it = root->children.emplace(std::string_view((char *) labelString, label.length()),
|
||||
std::make_unique<sni_node>()).first;
|
||||
}
|
||||
|
||||
root = it->second.get();
|
||||
}
|
||||
|
||||
/* We must never add multiple contexts for the same name, as that would overwrite and leak */
|
||||
if (root->user) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
root->user = user;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/* Removes the exact match. Wildcards are treated as the verbatim asterisk char, not as an actual wildcard */
|
||||
void *sni_remove(void *sni, const char *hostname) {
|
||||
struct sni_node *root = (struct sni_node *) sni;
|
||||
|
||||
/* I guess 10 labels is an okay limit */
|
||||
std::string_view labels[10];
|
||||
unsigned int numLabels = 0;
|
||||
|
||||
/* We traverse all labels first of all */
|
||||
for (std::string_view view(hostname, strlen(hostname)), label;
|
||||
view.length(); view.remove_prefix(std::min(view.length(), label.length() + 1))) {
|
||||
/* Label is the token separated by dot */
|
||||
label = view.substr(0, view.find('.', 0));
|
||||
|
||||
/* Anything longer than 10 labels is forbidden */
|
||||
if (numLabels == 10) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
labels[numLabels++] = label;
|
||||
}
|
||||
|
||||
return removeUser(root, 0, labels, numLabels);
|
||||
}
|
||||
|
||||
void *sni_find(void *sni, const char *hostname) {
|
||||
struct sni_node *root = (struct sni_node *) sni;
|
||||
|
||||
/* I guess 10 labels is an okay limit */
|
||||
std::string_view labels[10];
|
||||
unsigned int numLabels = 0;
|
||||
|
||||
/* We traverse all labels first of all */
|
||||
for (std::string_view view(hostname, strlen(hostname)), label;
|
||||
view.length(); view.remove_prefix(std::min(view.length(), label.length() + 1))) {
|
||||
/* Label is the token separated by dot */
|
||||
label = view.substr(0, view.find('.', 0));
|
||||
|
||||
/* Anything longer than 10 labels is forbidden */
|
||||
if (numLabels == 10) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
labels[numLabels++] = label;
|
||||
}
|
||||
|
||||
return getUser(root, 0, labels, numLabels);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
|
@ -369,7 +369,7 @@ struct us_internal_ssl_socket_context_t *us_internal_create_ssl_socket_context(s
|
|||
// options
|
||||
wolfSSL_CTX_set_read_ahead(context->ssl_context, 1);
|
||||
wolfSSL_CTX_set_mode(context->ssl_context, WOLFSSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
|
||||
|
||||
|
||||
// this lowers performance a bit in benchmarks
|
||||
if (options.ssl_prefer_low_memory_usage) {
|
||||
//SSL_CTX_set_mode(context->ssl_context, SSL_MODE_RELEASE_BUFFERS);
|
||||
|
@ -423,8 +423,8 @@ struct us_listen_socket_t *us_internal_ssl_socket_context_listen(struct us_inter
|
|||
return us_socket_context_listen(0, &context->sc, host, port, options, sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t) + socket_ext_size);
|
||||
}
|
||||
|
||||
struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_connect(struct us_internal_ssl_socket_context_t *context, const char *host, int port, int options, int socket_ext_size) {
|
||||
return (struct us_internal_ssl_socket_t *) us_socket_context_connect(0, &context->sc, host, port, options, sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t) + socket_ext_size);
|
||||
struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_connect(struct us_internal_ssl_socket_context_t *context, const char *host, int port, const char *source_host, int options, int socket_ext_size) {
|
||||
return (struct us_internal_ssl_socket_t *) us_socket_context_connect(0, &context->sc, host, port, source_host, options, sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t) + socket_ext_size);
|
||||
}
|
||||
|
||||
void us_internal_ssl_socket_context_on_open(struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *(*on_open)(struct us_internal_ssl_socket_t *s, int is_client, char *ip, int ip_length)) {
|
||||
|
|
|
@ -21,6 +21,9 @@
|
|||
|
||||
#if defined(LIBUS_USE_EPOLL) || defined(LIBUS_USE_KQUEUE)
|
||||
|
||||
/* Cannot include this one on Windows */
|
||||
#include <unistd.h>
|
||||
|
||||
#ifdef LIBUS_USE_EPOLL
|
||||
#define GET_READY_POLL(loop, index) (struct us_poll_t *) loop->ready_polls[index].data.ptr
|
||||
#define SET_READY_POLL(loop, index, poll) loop->ready_polls[index].data.ptr = poll
|
||||
|
@ -329,8 +332,8 @@ void us_timer_set(struct us_timer_t *t, void (*cb)(struct us_timer_t *t), int ms
|
|||
internal_cb->cb = (void (*)(struct us_internal_callback_t *)) cb;
|
||||
|
||||
struct itimerspec timer_spec = {
|
||||
{repeat_ms / 1000, ((long)repeat_ms * 1000000) % 1000000000},
|
||||
{ms / 1000, ((long)ms * 1000000) % 1000000000}
|
||||
{repeat_ms / 1000, (long) (repeat_ms % 1000) * (long) 1000000},
|
||||
{ms / 1000, (long) (ms % 1000) * (long) 1000000}
|
||||
};
|
||||
|
||||
timerfd_settime(us_poll_fd((struct us_poll_t *) t), 0, &timer_spec, NULL);
|
||||
|
|
|
@ -145,7 +145,7 @@ void *us_poll_ext(struct us_poll_t *p) {
|
|||
}
|
||||
|
||||
unsigned int us_internal_accept_poll_event(struct us_poll_t *p) {
|
||||
printf("us_internal_accept_poll_event\n");
|
||||
//printf("us_internal_accept_poll_event\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2019.
|
||||
* Authored by Alex Hultman, 2018-2021.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -21,9 +21,9 @@
|
|||
|
||||
#ifdef LIBUS_USE_LIBUV
|
||||
|
||||
// poll dispatch
|
||||
/* uv_poll_t->data always (except for most times after calling us_poll_stop) points to the us_poll_t */
|
||||
static void poll_cb(uv_poll_t *p, int status, int events) {
|
||||
us_internal_dispatch_ready_poll((struct us_poll_t *) p, status < 0, events);
|
||||
us_internal_dispatch_ready_poll((struct us_poll_t *) p->data, status < 0, events);
|
||||
}
|
||||
|
||||
static void prepare_cb(uv_prepare_t *p) {
|
||||
|
@ -37,10 +37,21 @@ static void check_cb(uv_check_t *p) {
|
|||
us_internal_loop_post(loop);
|
||||
}
|
||||
|
||||
/* Not used for polls, since polls need two frees */
|
||||
static void close_cb_free(uv_handle_t *h) {
|
||||
free(h->data);
|
||||
}
|
||||
|
||||
/* This one is different for polls, since we need two frees here */
|
||||
static void close_cb_free_poll(uv_handle_t *h) {
|
||||
/* It is only in case we called us_poll_stop then quickly us_poll_free that we enter this.
|
||||
* Most of the time, actual freeing is done by us_poll_free. */
|
||||
if (h->data) {
|
||||
free(h->data);
|
||||
free(h);
|
||||
}
|
||||
}
|
||||
|
||||
static void timer_cb(uv_timer_t *t) {
|
||||
struct us_internal_callback_t *cb = t->data;
|
||||
cb->cb(cb);
|
||||
|
@ -59,9 +70,14 @@ void us_poll_init(struct us_poll_t *p, LIBUS_SOCKET_DESCRIPTOR fd, int poll_type
|
|||
}
|
||||
|
||||
void us_poll_free(struct us_poll_t *p, struct us_loop_t *loop) {
|
||||
if (uv_is_closing((uv_handle_t *) &p->uv_p)) {
|
||||
p->uv_p.data = p;
|
||||
/* The idea here is like so; in us_poll_stop we call uv_close after setting data of uv-poll to 0.
|
||||
* This means that in close_cb_free we call free on 0 with does nothing, since us_poll_stop should
|
||||
* not really free the poll. HOWEVER, if we then call us_poll_free while still closing the uv-poll,
|
||||
* we simply change back the data to point to our structure so that we actually do free it like we should. */
|
||||
if (uv_is_closing((uv_handle_t *) p->uv_p)) {
|
||||
p->uv_p->data = p;
|
||||
} else {
|
||||
free(p->uv_p);
|
||||
free(p);
|
||||
}
|
||||
}
|
||||
|
@ -69,24 +85,26 @@ void us_poll_free(struct us_poll_t *p, struct us_loop_t *loop) {
|
|||
void us_poll_start(struct us_poll_t *p, struct us_loop_t *loop, int events) {
|
||||
p->poll_type = us_internal_poll_type(p) | ((events & LIBUS_SOCKET_READABLE) ? POLL_TYPE_POLLING_IN : 0) | ((events & LIBUS_SOCKET_WRITABLE) ? POLL_TYPE_POLLING_OUT : 0);
|
||||
|
||||
uv_poll_init_socket(loop->uv_loop, &p->uv_p, p->fd);
|
||||
uv_poll_start(&p->uv_p, events, poll_cb);
|
||||
uv_poll_init_socket(loop->uv_loop, p->uv_p, p->fd);
|
||||
uv_poll_start(p->uv_p, events, poll_cb);
|
||||
}
|
||||
|
||||
void us_poll_change(struct us_poll_t *p, struct us_loop_t *loop, int events) {
|
||||
if (us_poll_events(p) != events) {
|
||||
p->poll_type = us_internal_poll_type(p) | ((events & LIBUS_SOCKET_READABLE) ? POLL_TYPE_POLLING_IN : 0) | ((events & LIBUS_SOCKET_WRITABLE) ? POLL_TYPE_POLLING_OUT : 0);
|
||||
|
||||
uv_poll_start(&p->uv_p, events, poll_cb);
|
||||
uv_poll_start(p->uv_p, events, poll_cb);
|
||||
}
|
||||
}
|
||||
|
||||
void us_poll_stop(struct us_poll_t *p, struct us_loop_t *loop) {
|
||||
uv_poll_stop(&p->uv_p);
|
||||
uv_poll_stop(p->uv_p);
|
||||
|
||||
// close but not free is needed here
|
||||
p->uv_p.data = 0;
|
||||
uv_close((uv_handle_t *) &p->uv_p, close_cb_free); // needed here
|
||||
/* We normally only want to close the poll here, not free it. But if we stop it, then quickly "free" it with
|
||||
* us_poll_free, we postpone the actual freeing to close_cb_free_poll whenever it triggers.
|
||||
* That's why we set data to null here, so that us_poll_free can reset it if needed */
|
||||
p->uv_p->data = 0;
|
||||
uv_close((uv_handle_t *) p->uv_p, close_cb_free_poll);
|
||||
}
|
||||
|
||||
int us_poll_events(struct us_poll_t *p) {
|
||||
|
@ -171,19 +189,17 @@ void us_loop_run(struct us_loop_t *loop) {
|
|||
}
|
||||
|
||||
struct us_poll_t *us_create_poll(struct us_loop_t *loop, int fallthrough, unsigned int ext_size) {
|
||||
return malloc(sizeof(struct us_poll_t) + ext_size);
|
||||
struct us_poll_t *p = (struct us_poll_t *) malloc(sizeof(struct us_poll_t) + ext_size);
|
||||
p->uv_p = malloc(sizeof(uv_poll_t));
|
||||
p->uv_p->data = p;
|
||||
return p;
|
||||
}
|
||||
|
||||
// this one is broken, us_poll needs to hold a pointer to uv_poll_t for it to work (bad anyways)
|
||||
/* If we update our block position we have to updarte the uv_poll data to point to us */
|
||||
struct us_poll_t *us_poll_resize(struct us_poll_t *p, struct us_loop_t *loop, unsigned int ext_size) {
|
||||
|
||||
// do not support it yet
|
||||
return p;
|
||||
|
||||
struct us_poll_t *new_p = realloc(p, sizeof(struct us_poll_t) + ext_size);
|
||||
if (p != new_p) {
|
||||
new_p->uv_p.data = new_p;
|
||||
}
|
||||
new_p->uv_p->data = new_p;
|
||||
|
||||
return new_p;
|
||||
}
|
||||
|
@ -207,7 +223,7 @@ struct us_timer_t *us_create_timer(struct us_loop_t *loop, int fallthrough, unsi
|
|||
}
|
||||
|
||||
void *us_timer_ext(struct us_timer_t *timer) {
|
||||
return ((struct us_internal_callback_t *) timer) + 1;
|
||||
return ((char *) timer) + sizeof(struct us_internal_callback_t) + sizeof(uv_timer_t);
|
||||
}
|
||||
|
||||
void us_timer_close(struct us_timer_t *t) {
|
||||
|
|
|
@ -59,7 +59,7 @@ struct us_loop_t {
|
|||
|
||||
struct us_poll_t {
|
||||
alignas(LIBUS_EXT_ALIGNMENT) struct {
|
||||
int fd : 28;
|
||||
signed int fd : 28; // we could have this unsigned if we wanted to, -1 should never be used
|
||||
unsigned int poll_type : 4;
|
||||
} state;
|
||||
};
|
||||
|
|
|
@ -34,8 +34,10 @@ struct us_loop_t {
|
|||
uv_check_t *uv_check;
|
||||
};
|
||||
|
||||
// it is no longer valid to cast a pointer to us_poll_t to a pointer of uv_poll_t
|
||||
struct us_poll_t {
|
||||
uv_poll_t uv_p;
|
||||
/* We need to hold a pointer to this uv_poll_t since we need to be able to resize our block */
|
||||
uv_poll_t *uv_p;
|
||||
LIBUS_SOCKET_DESCRIPTOR fd;
|
||||
unsigned char poll_type;
|
||||
};
|
||||
|
|
|
@ -18,6 +18,12 @@
|
|||
#ifndef INTERNAL_H
|
||||
#define INTERNAL_H
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#define alignas(x) __declspec(align(x))
|
||||
#else
|
||||
#include <stdalign.h>
|
||||
#endif
|
||||
|
||||
/* We only have one networking implementation so far */
|
||||
#include "internal/networking/bsd.h"
|
||||
|
||||
|
@ -100,7 +106,7 @@ struct us_listen_socket_t {
|
|||
|
||||
struct us_socket_context_t {
|
||||
alignas(LIBUS_EXT_ALIGNMENT) struct us_loop_t *loop;
|
||||
//unsigned short timeout;
|
||||
unsigned short timestamp;
|
||||
struct us_socket_t *head;
|
||||
struct us_socket_t *iterator;
|
||||
struct us_socket_context_t *prev, *next;
|
||||
|
@ -108,14 +114,12 @@ struct us_socket_context_t {
|
|||
struct us_socket_t *(*on_open)(struct us_socket_t *, int is_client, char *ip, int ip_length);
|
||||
struct us_socket_t *(*on_data)(struct us_socket_t *, char *data, int length);
|
||||
struct us_socket_t *(*on_writable)(struct us_socket_t *);
|
||||
struct us_socket_t *(*on_close)(struct us_socket_t *);
|
||||
struct us_socket_t *(*on_close)(struct us_socket_t *, int code, void *reason);
|
||||
//void (*on_timeout)(struct us_socket_context *);
|
||||
struct us_socket_t *(*on_socket_timeout)(struct us_socket_t *);
|
||||
struct us_socket_t *(*on_end)(struct us_socket_t *);
|
||||
struct us_socket_t *(*on_connect_error)(struct us_socket_t *, int code);
|
||||
int (*ignore_data)(struct us_socket_t *);
|
||||
|
||||
/* All contexts hold references to their own copied options */
|
||||
struct us_socket_context_options_t options;
|
||||
};
|
||||
|
||||
/* Internal SSL interface */
|
||||
|
@ -124,6 +128,14 @@ struct us_socket_context_t {
|
|||
struct us_internal_ssl_socket_context_t;
|
||||
struct us_internal_ssl_socket_t;
|
||||
|
||||
/* SNI functions */
|
||||
void us_internal_ssl_socket_context_add_server_name(struct us_internal_ssl_socket_context_t *context, const char *hostname_pattern, struct us_socket_context_options_t options);
|
||||
void us_internal_ssl_socket_context_remove_server_name(struct us_internal_ssl_socket_context_t *context, const char *hostname_pattern);
|
||||
void us_internal_ssl_socket_context_on_server_name(struct us_internal_ssl_socket_context_t *context, void (*cb)(struct us_internal_ssl_socket_context_t *, const char *));
|
||||
|
||||
void *us_internal_ssl_socket_get_native_handle(struct us_internal_ssl_socket_t *s);
|
||||
void *us_internal_ssl_socket_context_get_native_handle(struct us_internal_ssl_socket_context_t *context);
|
||||
|
||||
struct us_internal_ssl_socket_context_t *us_internal_create_ssl_socket_context(struct us_loop_t *loop,
|
||||
int context_ext_size, struct us_socket_context_options_t options);
|
||||
|
||||
|
@ -132,7 +144,7 @@ void us_internal_ssl_socket_context_on_open(struct us_internal_ssl_socket_contex
|
|||
struct us_internal_ssl_socket_t *(*on_open)(struct us_internal_ssl_socket_t *s, int is_client, char *ip, int ip_length));
|
||||
|
||||
void us_internal_ssl_socket_context_on_close(struct us_internal_ssl_socket_context_t *context,
|
||||
struct us_internal_ssl_socket_t *(*on_close)(struct us_internal_ssl_socket_t *s));
|
||||
struct us_internal_ssl_socket_t *(*on_close)(struct us_internal_ssl_socket_t *s, int code, void *reason));
|
||||
|
||||
void us_internal_ssl_socket_context_on_data(struct us_internal_ssl_socket_context_t *context,
|
||||
struct us_internal_ssl_socket_t *(*on_data)(struct us_internal_ssl_socket_t *s, char *data, int length));
|
||||
|
@ -146,11 +158,14 @@ void us_internal_ssl_socket_context_on_timeout(struct us_internal_ssl_socket_con
|
|||
void us_internal_ssl_socket_context_on_end(struct us_internal_ssl_socket_context_t *context,
|
||||
struct us_internal_ssl_socket_t *(*on_end)(struct us_internal_ssl_socket_t *s));
|
||||
|
||||
void us_internal_ssl_socket_context_on_connect_error(struct us_internal_ssl_socket_context_t *context,
|
||||
struct us_internal_ssl_socket_t *(*on_connect_error)(struct us_internal_ssl_socket_t *s, int code));
|
||||
|
||||
struct us_listen_socket_t *us_internal_ssl_socket_context_listen(struct us_internal_ssl_socket_context_t *context,
|
||||
const char *host, int port, int options, int socket_ext_size);
|
||||
|
||||
struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_connect(struct us_internal_ssl_socket_context_t *context,
|
||||
const char *host, int port, int options, int socket_ext_size);
|
||||
const char *host, int port, const char *source_host, int options, int socket_ext_size);
|
||||
|
||||
int us_internal_ssl_socket_write(struct us_internal_ssl_socket_t *s, const char *data, int length, int msg_more);
|
||||
void us_internal_ssl_socket_timeout(struct us_internal_ssl_socket_t *s, unsigned int seconds);
|
||||
|
@ -159,7 +174,7 @@ struct us_internal_ssl_socket_context_t *us_internal_ssl_socket_get_context(stru
|
|||
void *us_internal_ssl_socket_ext(struct us_internal_ssl_socket_t *s);
|
||||
int us_internal_ssl_socket_is_shut_down(struct us_internal_ssl_socket_t *s);
|
||||
void us_internal_ssl_socket_shutdown(struct us_internal_ssl_socket_t *s);
|
||||
struct us_internal_ssl_socket_t *us_internal_ssl_socket_close(struct us_internal_ssl_socket_t *s);
|
||||
|
||||
struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_adopt_socket(struct us_internal_ssl_socket_context_t *context,
|
||||
struct us_internal_ssl_socket_t *s, int ext_size);
|
||||
|
||||
|
|
|
@ -24,265 +24,61 @@
|
|||
// here everything about the syscalls are inline-wrapped and included
|
||||
|
||||
#ifdef _WIN32
|
||||
#ifndef NOMINMAX
|
||||
#define NOMINMAX
|
||||
#include <WinSock2.h>
|
||||
#include <Ws2tcpip.h>
|
||||
#endif
|
||||
#include <winsock2.h>
|
||||
#include <ws2tcpip.h>
|
||||
#pragma comment(lib, "ws2_32.lib")
|
||||
#include <stdio.h>
|
||||
#define SETSOCKOPT_PTR_TYPE const char *
|
||||
#define LIBUS_SOCKET_ERROR INVALID_SOCKET
|
||||
#else
|
||||
#define _GNU_SOURCE
|
||||
#include <sys/types.h>
|
||||
/* For socklen_t */
|
||||
#include <sys/socket.h>
|
||||
#include <netinet/in.h>
|
||||
#include <netinet/tcp.h>
|
||||
#include <netdb.h>
|
||||
#include <string.h>
|
||||
#include <unistd.h>
|
||||
#include <fcntl.h>
|
||||
#include <stdio.h>
|
||||
#include <errno.h>
|
||||
#define SETSOCKOPT_PTR_TYPE int *
|
||||
#define LIBUS_SOCKET_ERROR -1
|
||||
#endif
|
||||
|
||||
static inline LIBUS_SOCKET_DESCRIPTOR apple_no_sigpipe(LIBUS_SOCKET_DESCRIPTOR fd) {
|
||||
#ifdef __APPLE__
|
||||
if (fd != LIBUS_SOCKET_ERROR) {
|
||||
int no_sigpipe = 1;
|
||||
setsockopt(fd, SOL_SOCKET, SO_NOSIGPIPE, &no_sigpipe, sizeof(int));
|
||||
}
|
||||
#endif
|
||||
return fd;
|
||||
}
|
||||
|
||||
static inline LIBUS_SOCKET_DESCRIPTOR bsd_set_nonblocking(LIBUS_SOCKET_DESCRIPTOR fd) {
|
||||
#ifdef _WIN32
|
||||
/* Libuv will set windows sockets as non-blocking */
|
||||
#else
|
||||
fcntl(fd, F_SETFL, fcntl(fd, F_GETFL, 0) | O_NONBLOCK);
|
||||
#endif
|
||||
return fd;
|
||||
}
|
||||
|
||||
static inline void bsd_socket_nodelay(LIBUS_SOCKET_DESCRIPTOR fd, int enabled) {
|
||||
setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, (void *) &enabled, sizeof(enabled));
|
||||
}
|
||||
|
||||
static inline void bsd_socket_flush(LIBUS_SOCKET_DESCRIPTOR fd) {
|
||||
// Linux TCP_CORK has the same underlying corking mechanism as with MSG_MORE
|
||||
#ifdef TCP_CORK
|
||||
int enabled = 0;
|
||||
setsockopt(fd, IPPROTO_TCP, TCP_CORK, &enabled, sizeof(int));
|
||||
#endif
|
||||
}
|
||||
|
||||
static inline LIBUS_SOCKET_DESCRIPTOR bsd_create_socket(int domain, int type, int protocol) {
|
||||
// returns INVALID_SOCKET on error
|
||||
int flags = 0;
|
||||
#if defined(SOCK_CLOEXEC) && defined(SOCK_NONBLOCK)
|
||||
flags = SOCK_CLOEXEC | SOCK_NONBLOCK;
|
||||
#endif
|
||||
|
||||
LIBUS_SOCKET_DESCRIPTOR created_fd = socket(domain, type | flags, protocol);
|
||||
|
||||
return bsd_set_nonblocking(apple_no_sigpipe(created_fd));
|
||||
}
|
||||
|
||||
static inline void bsd_close_socket(LIBUS_SOCKET_DESCRIPTOR fd) {
|
||||
#ifdef _WIN32
|
||||
closesocket(fd);
|
||||
#else
|
||||
close(fd);
|
||||
#endif
|
||||
}
|
||||
|
||||
static inline void bsd_shutdown_socket(LIBUS_SOCKET_DESCRIPTOR fd) {
|
||||
#ifdef _WIN32
|
||||
shutdown(fd, SD_SEND);
|
||||
#else
|
||||
shutdown(fd, SHUT_WR);
|
||||
#endif
|
||||
}
|
||||
|
||||
struct bsd_addr_t {
|
||||
struct sockaddr_storage mem;
|
||||
socklen_t len;
|
||||
char *ip;
|
||||
int ip_length;
|
||||
int port;
|
||||
};
|
||||
|
||||
static inline void internal_finalize_bsd_addr(struct bsd_addr_t *addr) {
|
||||
// parse, so to speak, the address
|
||||
if (addr->mem.ss_family == AF_INET6) {
|
||||
addr->ip = (char *) &((struct sockaddr_in6 *) addr)->sin6_addr;
|
||||
addr->ip_length = sizeof(struct in6_addr);
|
||||
} else if (addr->mem.ss_family == AF_INET) {
|
||||
addr->ip = (char *) &((struct sockaddr_in *) addr)->sin_addr;
|
||||
addr->ip_length = sizeof(struct in_addr);
|
||||
} else {
|
||||
addr->ip_length = 0;
|
||||
}
|
||||
}
|
||||
LIBUS_SOCKET_DESCRIPTOR apple_no_sigpipe(LIBUS_SOCKET_DESCRIPTOR fd);
|
||||
LIBUS_SOCKET_DESCRIPTOR bsd_set_nonblocking(LIBUS_SOCKET_DESCRIPTOR fd);
|
||||
void bsd_socket_nodelay(LIBUS_SOCKET_DESCRIPTOR fd, int enabled);
|
||||
void bsd_socket_flush(LIBUS_SOCKET_DESCRIPTOR fd);
|
||||
LIBUS_SOCKET_DESCRIPTOR bsd_create_socket(int domain, int type, int protocol);
|
||||
|
||||
static inline int bsd_socket_addr(LIBUS_SOCKET_DESCRIPTOR fd, struct bsd_addr_t *addr) {
|
||||
addr->len = sizeof(addr->mem);
|
||||
if (getpeername(fd, (struct sockaddr *) &addr->mem, &addr->len)) {
|
||||
return -1;
|
||||
}
|
||||
internal_finalize_bsd_addr(addr);
|
||||
return 0;
|
||||
}
|
||||
void bsd_close_socket(LIBUS_SOCKET_DESCRIPTOR fd);
|
||||
void bsd_shutdown_socket(LIBUS_SOCKET_DESCRIPTOR fd);
|
||||
void bsd_shutdown_socket_read(LIBUS_SOCKET_DESCRIPTOR fd);
|
||||
|
||||
static inline char *bsd_addr_get_ip(struct bsd_addr_t *addr) {
|
||||
return addr->ip;
|
||||
}
|
||||
void internal_finalize_bsd_addr(struct bsd_addr_t *addr);
|
||||
|
||||
static inline int bsd_addr_get_ip_length(struct bsd_addr_t *addr) {
|
||||
return addr->ip_length;
|
||||
}
|
||||
int bsd_local_addr(LIBUS_SOCKET_DESCRIPTOR fd, struct bsd_addr_t *addr);
|
||||
int bsd_remote_addr(LIBUS_SOCKET_DESCRIPTOR fd, struct bsd_addr_t *addr);
|
||||
|
||||
char *bsd_addr_get_ip(struct bsd_addr_t *addr);
|
||||
int bsd_addr_get_ip_length(struct bsd_addr_t *addr);
|
||||
|
||||
int bsd_addr_get_port(struct bsd_addr_t *addr);
|
||||
|
||||
// called by dispatch_ready_poll
|
||||
static inline LIBUS_SOCKET_DESCRIPTOR bsd_accept_socket(LIBUS_SOCKET_DESCRIPTOR fd, struct bsd_addr_t *addr) {
|
||||
LIBUS_SOCKET_DESCRIPTOR accepted_fd;
|
||||
addr->len = sizeof(addr->mem);
|
||||
LIBUS_SOCKET_DESCRIPTOR bsd_accept_socket(LIBUS_SOCKET_DESCRIPTOR fd, struct bsd_addr_t *addr);
|
||||
|
||||
#if defined(SOCK_CLOEXEC) && defined(SOCK_NONBLOCK)
|
||||
// Linux, FreeBSD
|
||||
accepted_fd = accept4(fd, (struct sockaddr *) addr, &addr->len, SOCK_CLOEXEC | SOCK_NONBLOCK);
|
||||
#else
|
||||
// Windows, OS X
|
||||
accepted_fd = accept(fd, (struct sockaddr *) addr, &addr->len);
|
||||
|
||||
#endif
|
||||
|
||||
internal_finalize_bsd_addr(addr);
|
||||
|
||||
return bsd_set_nonblocking(apple_no_sigpipe(accepted_fd));
|
||||
}
|
||||
|
||||
static inline int bsd_recv(LIBUS_SOCKET_DESCRIPTOR fd, void *buf, int length, int flags) {
|
||||
return recv(fd, buf, length, flags);
|
||||
}
|
||||
|
||||
static inline int bsd_send(LIBUS_SOCKET_DESCRIPTOR fd, const char *buf, int length, int msg_more) {
|
||||
|
||||
// MSG_MORE (Linux), MSG_PARTIAL (Windows), TCP_NOPUSH (BSD)
|
||||
|
||||
#ifndef MSG_NOSIGNAL
|
||||
#define MSG_NOSIGNAL 0
|
||||
#endif
|
||||
|
||||
#ifdef MSG_MORE
|
||||
|
||||
// for Linux we do not want signals
|
||||
return send(fd, buf, length, (msg_more * MSG_MORE) | MSG_NOSIGNAL);
|
||||
|
||||
#else
|
||||
|
||||
// use TCP_NOPUSH
|
||||
|
||||
return send(fd, buf, length, MSG_NOSIGNAL);
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
static inline int bsd_would_block() {
|
||||
#ifdef _WIN32
|
||||
return WSAGetLastError() == WSAEWOULDBLOCK;
|
||||
#else
|
||||
return errno == EWOULDBLOCK;// || errno == EAGAIN;
|
||||
#endif
|
||||
}
|
||||
int bsd_recv(LIBUS_SOCKET_DESCRIPTOR fd, void *buf, int length, int flags);
|
||||
int bsd_send(LIBUS_SOCKET_DESCRIPTOR fd, const char *buf, int length, int msg_more);
|
||||
int bsd_would_block();
|
||||
|
||||
// return LIBUS_SOCKET_ERROR or the fd that represents listen socket
|
||||
// listen both on ipv6 and ipv4
|
||||
static inline LIBUS_SOCKET_DESCRIPTOR bsd_create_listen_socket(const char *host, int port, int options) {
|
||||
struct addrinfo hints, *result;
|
||||
memset(&hints, 0, sizeof(struct addrinfo));
|
||||
LIBUS_SOCKET_DESCRIPTOR bsd_create_listen_socket(const char *host, int port, int options);
|
||||
|
||||
hints.ai_flags = AI_PASSIVE;
|
||||
hints.ai_family = AF_UNSPEC;
|
||||
hints.ai_socktype = SOCK_STREAM;
|
||||
|
||||
char port_string[16];
|
||||
snprintf(port_string, 16, "%d", port);
|
||||
|
||||
if (getaddrinfo(host, port_string, &hints, &result)) {
|
||||
return LIBUS_SOCKET_ERROR;
|
||||
}
|
||||
|
||||
LIBUS_SOCKET_DESCRIPTOR listenFd = LIBUS_SOCKET_ERROR;
|
||||
struct addrinfo *listenAddr;
|
||||
for (struct addrinfo *a = result; a && listenFd == LIBUS_SOCKET_ERROR; a = a->ai_next) {
|
||||
if (a->ai_family == AF_INET6) {
|
||||
listenFd = bsd_create_socket(a->ai_family, a->ai_socktype, a->ai_protocol);
|
||||
listenAddr = a;
|
||||
}
|
||||
}
|
||||
|
||||
for (struct addrinfo *a = result; a && listenFd == LIBUS_SOCKET_ERROR; a = a->ai_next) {
|
||||
if (a->ai_family == AF_INET) {
|
||||
listenFd = bsd_create_socket(a->ai_family, a->ai_socktype, a->ai_protocol);
|
||||
listenAddr = a;
|
||||
}
|
||||
}
|
||||
|
||||
if (listenFd == LIBUS_SOCKET_ERROR) {
|
||||
freeaddrinfo(result);
|
||||
return LIBUS_SOCKET_ERROR;
|
||||
}
|
||||
|
||||
/* Always enable SO_REUSEPORT and SO_REUSEADDR _unless_ options specify otherwise */
|
||||
#if defined(__linux) && defined(SO_REUSEPORT)
|
||||
if (!(options & LIBUS_LISTEN_EXCLUSIVE_PORT)) {
|
||||
int optval = 1;
|
||||
setsockopt(listenFd, SOL_SOCKET, SO_REUSEPORT, &optval, sizeof(optval));
|
||||
}
|
||||
#endif
|
||||
|
||||
int enabled = 1;
|
||||
setsockopt(listenFd, SOL_SOCKET, SO_REUSEADDR, (SETSOCKOPT_PTR_TYPE) &enabled, sizeof(enabled));
|
||||
|
||||
#ifdef IPV6_V6ONLY
|
||||
int disabled = 0;
|
||||
setsockopt(listenFd, IPPROTO_IPV6, IPV6_V6ONLY, (SETSOCKOPT_PTR_TYPE) &disabled, sizeof(disabled));
|
||||
#endif
|
||||
|
||||
if (bind(listenFd, listenAddr->ai_addr, (socklen_t) listenAddr->ai_addrlen) || listen(listenFd, 512)) {
|
||||
bsd_close_socket(listenFd);
|
||||
freeaddrinfo(result);
|
||||
return LIBUS_SOCKET_ERROR;
|
||||
}
|
||||
|
||||
freeaddrinfo(result);
|
||||
return listenFd;
|
||||
}
|
||||
|
||||
static inline LIBUS_SOCKET_DESCRIPTOR bsd_create_connect_socket(const char *host, int port, int options) {
|
||||
struct addrinfo hints, *result;
|
||||
memset(&hints, 0, sizeof(struct addrinfo));
|
||||
hints.ai_family = AF_UNSPEC;
|
||||
hints.ai_socktype = SOCK_STREAM;
|
||||
|
||||
char port_string[16];
|
||||
snprintf(port_string, 16, "%d", port);
|
||||
|
||||
if (getaddrinfo(host, port_string, &hints, &result) != 0) {
|
||||
return LIBUS_SOCKET_ERROR;
|
||||
}
|
||||
|
||||
LIBUS_SOCKET_DESCRIPTOR fd = bsd_create_socket(result->ai_family, result->ai_socktype, result->ai_protocol);
|
||||
if (fd == LIBUS_SOCKET_ERROR) {
|
||||
freeaddrinfo(result);
|
||||
return LIBUS_SOCKET_ERROR;
|
||||
}
|
||||
|
||||
connect(fd, result->ai_addr, (socklen_t) result->ai_addrlen);
|
||||
freeaddrinfo(result);
|
||||
|
||||
return fd;
|
||||
}
|
||||
LIBUS_SOCKET_DESCRIPTOR bsd_create_connect_socket(const char *host, int port, const char *source_host, int options);
|
||||
|
||||
#endif // BSD_H
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2019.
|
||||
* Authored by Alex Hultman, 2018-2021.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -80,24 +80,50 @@ void us_internal_loop_unlink(struct us_loop_t *loop, struct us_socket_context_t
|
|||
/* This functions should never run recursively */
|
||||
void us_internal_timer_sweep(struct us_loop_t *loop) {
|
||||
struct us_internal_loop_data_t *loop_data = &loop->data;
|
||||
/* For all socket contexts in this loop */
|
||||
for (loop_data->iterator = loop_data->head; loop_data->iterator; loop_data->iterator = loop_data->iterator->next) {
|
||||
|
||||
struct us_socket_context_t *context = loop_data->iterator;
|
||||
for (context->iterator = context->head; context->iterator; ) {
|
||||
|
||||
struct us_socket_t *s = context->iterator;
|
||||
if (s->timeout && --(s->timeout) == 0) {
|
||||
/* Update this context's 15-bit timestamp */
|
||||
context->timestamp = (context->timestamp + 1) & 0x7fff;
|
||||
|
||||
context->on_socket_timeout(s);
|
||||
/* Update our 16-bit full timestamp (the needle in the haystack) */
|
||||
unsigned short needle = 0x8000 | context->timestamp;
|
||||
|
||||
/* Check for unlink / link */
|
||||
if (s == context->iterator) {
|
||||
context->iterator = s->next;
|
||||
/* Begin at head */
|
||||
struct us_socket_t *s = context->head;
|
||||
while (s) {
|
||||
/* Seek until end or timeout found (tightest loop) */
|
||||
while (1) {
|
||||
/* We only read from 1 random cache line here */
|
||||
if (needle == s->timeout) {
|
||||
break;
|
||||
}
|
||||
|
||||
/* Did we reach the end without a find? */
|
||||
if ((s = s->next) == 0) {
|
||||
goto next_context;
|
||||
}
|
||||
}
|
||||
|
||||
/* Here we have a timeout to emit (slow path) */
|
||||
s->timeout = 0;
|
||||
context->iterator = s;
|
||||
|
||||
context->on_socket_timeout(s);
|
||||
|
||||
/* Check for unlink / link (if the event handler did not modify the chain, we step 1) */
|
||||
if (s == context->iterator) {
|
||||
s = s->next;
|
||||
} else {
|
||||
context->iterator = s->next;
|
||||
/* The iterator was changed by event handler */
|
||||
s = context->iterator;
|
||||
}
|
||||
}
|
||||
/* We always store a 0 to context->iterator here since we are no longer iterating this context */
|
||||
next_context:
|
||||
context->iterator = 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -136,7 +162,10 @@ void us_internal_loop_post(struct us_loop_t *loop) {
|
|||
void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events) {
|
||||
switch (us_internal_poll_type(p)) {
|
||||
case POLL_TYPE_CALLBACK: {
|
||||
/* Let's just do this to clear the CodeQL alert */
|
||||
#ifndef LIBUS_USE_LIBUV
|
||||
us_internal_accept_poll_event(p);
|
||||
#endif
|
||||
struct us_internal_callback_t *cb = (struct us_internal_callback_t *) p;
|
||||
cb->cb(cb->cb_expects_the_loop ? (struct us_internal_callback_t *) cb->loop : (struct us_internal_callback_t *) &cb->p);
|
||||
}
|
||||
|
@ -147,15 +176,26 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events)
|
|||
if (us_poll_events(p) == LIBUS_SOCKET_WRITABLE) {
|
||||
struct us_socket_t *s = (struct us_socket_t *) p;
|
||||
|
||||
us_poll_change(p, s->context->loop, LIBUS_SOCKET_READABLE);
|
||||
/* It is perfectly possible to come here with an error */
|
||||
if (error) {
|
||||
/* Emit error, close without emitting on_close */
|
||||
s->context->on_connect_error(s, 0);
|
||||
us_socket_close_connecting(0, s);
|
||||
} else {
|
||||
/* All sockets poll for readable */
|
||||
us_poll_change(p, s->context->loop, LIBUS_SOCKET_READABLE);
|
||||
|
||||
/* We always use nodelay */
|
||||
bsd_socket_nodelay(us_poll_fd(p), 1);
|
||||
/* We always use nodelay */
|
||||
bsd_socket_nodelay(us_poll_fd(p), 1);
|
||||
|
||||
/* We are now a proper socket */
|
||||
us_internal_poll_set_type(p, POLL_TYPE_SOCKET);
|
||||
/* We are now a proper socket */
|
||||
us_internal_poll_set_type(p, POLL_TYPE_SOCKET);
|
||||
|
||||
s->context->on_open(s, 1, 0, 0);
|
||||
/* If we used a connection timeout we have to reset it here */
|
||||
us_socket_timeout(0, s, 0);
|
||||
|
||||
s->context->on_open(s, 1, 0, 0);
|
||||
}
|
||||
} else {
|
||||
struct us_listen_socket_t *listen_socket = (struct us_listen_socket_t *) p;
|
||||
struct bsd_addr_t addr;
|
||||
|
@ -169,11 +209,11 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events)
|
|||
/* Todo: stop timer if any */
|
||||
|
||||
do {
|
||||
struct us_poll_t *p = us_create_poll(us_socket_context(0, &listen_socket->s)->loop, 0, sizeof(struct us_socket_t) - sizeof(struct us_poll_t) + listen_socket->socket_ext_size);
|
||||
us_poll_init(p, client_fd, POLL_TYPE_SOCKET);
|
||||
us_poll_start(p, listen_socket->s.context->loop, LIBUS_SOCKET_READABLE);
|
||||
struct us_poll_t *accepted_p = us_create_poll(us_socket_context(0, &listen_socket->s)->loop, 0, sizeof(struct us_socket_t) - sizeof(struct us_poll_t) + listen_socket->socket_ext_size);
|
||||
us_poll_init(accepted_p, client_fd, POLL_TYPE_SOCKET);
|
||||
us_poll_start(accepted_p, listen_socket->s.context->loop, LIBUS_SOCKET_READABLE);
|
||||
|
||||
struct us_socket_t *s = (struct us_socket_t *) p;
|
||||
struct us_socket_t *s = (struct us_socket_t *) accepted_p;
|
||||
|
||||
s->context = listen_socket->s.context;
|
||||
|
||||
|
@ -201,7 +241,8 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events)
|
|||
|
||||
/* Such as epollerr epollhup */
|
||||
if (error) {
|
||||
s = us_socket_close(0, s);
|
||||
/* Todo: decide what code we give here */
|
||||
s = us_socket_close(0, s, 0, NULL);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -235,14 +276,16 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events)
|
|||
} else if (!length) {
|
||||
if (us_socket_is_shut_down(0, s)) {
|
||||
/* We got FIN back after sending it */
|
||||
s = us_socket_close(0, s);
|
||||
/* Todo: We should give "CLEAN SHUTDOWN" as reason here */
|
||||
s = us_socket_close(0, s, 0, NULL);
|
||||
} else {
|
||||
/* We got FIN, so stop polling for readable */
|
||||
us_poll_change(&s->p, us_socket_context(0, s)->loop, us_poll_events(&s->p) & LIBUS_SOCKET_WRITABLE);
|
||||
s = s->context->on_end(s);
|
||||
}
|
||||
} else if (length == LIBUS_SOCKET_ERROR && !bsd_would_block()) {
|
||||
s = us_socket_close(0, s);
|
||||
/* Todo: decide also here what kind of reason we should give */
|
||||
s = us_socket_close(0, s, 0, NULL);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Authored by Alex Hultman, 2018-2019.
|
||||
* Authored by Alex Hultman, 2018-2021.
|
||||
* Intellectual property of third-party.
|
||||
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -18,12 +18,27 @@
|
|||
#include "libusockets.h"
|
||||
#include "internal/internal.h"
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
/* Shared with SSL */
|
||||
|
||||
int us_socket_local_port(int ssl, struct us_socket_t *s) {
|
||||
struct bsd_addr_t addr;
|
||||
if (bsd_local_addr(us_poll_fd(&s->p), &addr)) {
|
||||
return -1;
|
||||
} else {
|
||||
return bsd_addr_get_port(&addr);
|
||||
}
|
||||
}
|
||||
|
||||
void us_socket_shutdown_read(int ssl, struct us_socket_t *s) {
|
||||
/* This syscall is idempotent so no extra check is needed */
|
||||
bsd_shutdown_socket_read(us_poll_fd((struct us_poll_t *) s));
|
||||
}
|
||||
|
||||
void us_socket_remote_address(int ssl, struct us_socket_t *s, char *buf, int *length) {
|
||||
struct bsd_addr_t addr;
|
||||
if (bsd_socket_addr(us_poll_fd(&s->p), &addr) || *length < bsd_addr_get_ip_length(&addr)) {
|
||||
if (bsd_remote_addr(us_poll_fd(&s->p), &addr) || *length < bsd_addr_get_ip_length(&addr)) {
|
||||
*length = 0;
|
||||
} else {
|
||||
*length = bsd_addr_get_ip_length(&addr);
|
||||
|
@ -37,8 +52,7 @@ struct us_socket_context_t *us_socket_context(int ssl, struct us_socket_t *s) {
|
|||
|
||||
void us_socket_timeout(int ssl, struct us_socket_t *s, unsigned int seconds) {
|
||||
if (seconds) {
|
||||
unsigned short timeout_sweeps = (unsigned short) (0.5f + ((float) seconds) / ((float) LIBUS_TIMEOUT_GRANULARITY));
|
||||
s->timeout = timeout_sweeps ? timeout_sweeps : 1;
|
||||
s->timeout = 0x8000 | (s->context->timestamp + (seconds >> 2));
|
||||
} else {
|
||||
s->timeout = 0;
|
||||
}
|
||||
|
@ -54,8 +68,61 @@ int us_socket_is_closed(int ssl, struct us_socket_t *s) {
|
|||
return s->prev == (struct us_socket_t *) s->context;
|
||||
}
|
||||
|
||||
int us_socket_is_established(int ssl, struct us_socket_t *s) {
|
||||
/* Everything that is not POLL_TYPE_SEMI_SOCKET is established */
|
||||
return us_internal_poll_type((struct us_poll_t *) s) != POLL_TYPE_SEMI_SOCKET;
|
||||
}
|
||||
|
||||
/* Exactly the same as us_socket_close but does not emit on_close event */
|
||||
struct us_socket_t *us_socket_close_connecting(int ssl, struct us_socket_t *s) {
|
||||
if (!us_socket_is_closed(0, s)) {
|
||||
us_internal_socket_context_unlink(s->context, s);
|
||||
us_poll_stop((struct us_poll_t *) s, s->context->loop);
|
||||
bsd_close_socket(us_poll_fd((struct us_poll_t *) s));
|
||||
|
||||
/* Link this socket to the close-list and let it be deleted after this iteration */
|
||||
s->next = s->context->loop->data.closed_head;
|
||||
s->context->loop->data.closed_head = s;
|
||||
|
||||
/* Any socket with prev = context is marked as closed */
|
||||
s->prev = (struct us_socket_t *) s->context;
|
||||
|
||||
//return s->context->on_close(s, code, reason);
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
/* Same as above but emits on_close */
|
||||
struct us_socket_t *us_socket_close(int ssl, struct us_socket_t *s, int code, void *reason) {
|
||||
if (!us_socket_is_closed(0, s)) {
|
||||
us_internal_socket_context_unlink(s->context, s);
|
||||
us_poll_stop((struct us_poll_t *) s, s->context->loop);
|
||||
bsd_close_socket(us_poll_fd((struct us_poll_t *) s));
|
||||
|
||||
/* Link this socket to the close-list and let it be deleted after this iteration */
|
||||
s->next = s->context->loop->data.closed_head;
|
||||
s->context->loop->data.closed_head = s;
|
||||
|
||||
/* Any socket with prev = context is marked as closed */
|
||||
s->prev = (struct us_socket_t *) s->context;
|
||||
|
||||
return s->context->on_close(s, code, reason);
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
/* Not shared with SSL */
|
||||
|
||||
void *us_socket_get_native_handle(int ssl, struct us_socket_t *s) {
|
||||
#ifndef LIBUS_NO_SSL
|
||||
if (ssl) {
|
||||
return us_internal_ssl_socket_get_native_handle((struct us_internal_ssl_socket_t *) s);
|
||||
}
|
||||
#endif
|
||||
|
||||
return (void *) (uintptr_t) us_poll_fd((struct us_poll_t *) s);
|
||||
}
|
||||
|
||||
int us_socket_write(int ssl, struct us_socket_t *s, const char *data, int length, int msg_more) {
|
||||
#ifndef LIBUS_NO_SSL
|
||||
if (ssl) {
|
||||
|
@ -86,30 +153,6 @@ void *us_socket_ext(int ssl, struct us_socket_t *s) {
|
|||
return s + 1;
|
||||
}
|
||||
|
||||
struct us_socket_t *us_socket_close(int ssl, struct us_socket_t *s) {
|
||||
#ifndef LIBUS_NO_SSL
|
||||
if (ssl) {
|
||||
return (struct us_socket_t *) us_internal_ssl_socket_close((struct us_internal_ssl_socket_t *) s);
|
||||
}
|
||||
#endif
|
||||
|
||||
if (!us_socket_is_closed(0, s)) {
|
||||
us_internal_socket_context_unlink(s->context, s);
|
||||
us_poll_stop((struct us_poll_t *) s, s->context->loop);
|
||||
bsd_close_socket(us_poll_fd((struct us_poll_t *) s));
|
||||
|
||||
/* Link this socket to the close-list and let it be deleted after this iteration */
|
||||
s->next = s->context->loop->data.closed_head;
|
||||
s->context->loop->data.closed_head = s;
|
||||
|
||||
/* Any socket with prev = context is marked as closed */
|
||||
s->prev = (struct us_socket_t *) s->context;
|
||||
|
||||
return s->context->on_close(s);
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
int us_socket_is_shut_down(int ssl, struct us_socket_t *s) {
|
||||
#ifndef LIBUS_NO_SSL
|
||||
if (ssl) {
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
#pragma warning(disable : 4267 4138)
|
||||
|
||||
#include "BrokenithmServer.hpp"
|
||||
|
||||
#include <thread>
|
||||
|
@ -53,11 +55,12 @@ uint64_t BrokenithmServer::get_controller_state()
|
|||
|
||||
struct ConnectionData
|
||||
{
|
||||
typedef uWS::WebSocket<false, true, ConnectionData> ConnectionDataSocket;
|
||||
static int s_connection_counter;
|
||||
static std::vector<ConnectionData *> s_connections;
|
||||
|
||||
int m_uid;
|
||||
uWS::WebSocket<false, true> *m_websocket;
|
||||
ConnectionDataSocket *m_websocket;
|
||||
|
||||
void static close_all_connections();
|
||||
|
||||
|
@ -75,7 +78,7 @@ struct ConnectionData
|
|||
s_connections[m_uid] = nullptr;
|
||||
}
|
||||
|
||||
void save_socket(uWS::WebSocket<false, true> *websocket)
|
||||
void save_socket(ConnectionDataSocket *websocket)
|
||||
{
|
||||
m_websocket = websocket;
|
||||
}
|
||||
|
@ -149,12 +152,17 @@ void BrokenithmServer::Impl::start_server()
|
|||
})
|
||||
.ws<ConnectionData>(
|
||||
"/ws",
|
||||
{uWS::DISABLED,
|
||||
16 * 1024 * 1024,
|
||||
10,
|
||||
16 * 1024 * 1024,
|
||||
{uWS::DISABLED, // compression
|
||||
16 * 1024 * 1024, // maxPayloadLength
|
||||
16, // idleTimeout
|
||||
16 * 1024 * 1024, // maxBackpressure
|
||||
false, // closeOnBackpressureLimit
|
||||
false, // resetIdleTimeoutOnSend
|
||||
true, // sendPingsAutomatically
|
||||
0, // maxLifetime
|
||||
nullptr, // upgrade
|
||||
// Open handler
|
||||
[](auto *ws, auto *req) {
|
||||
[](auto *ws) {
|
||||
spdlog::info("Controller ID {} connected", ((ConnectionData *)ws->getUserData())->m_uid);
|
||||
((ConnectionData *)ws->getUserData())->save_socket(ws);
|
||||
},
|
||||
|
@ -181,26 +189,21 @@ void BrokenithmServer::Impl::start_server()
|
|||
}
|
||||
}
|
||||
},
|
||||
// Drain handler
|
||||
[](auto *ws) {},
|
||||
// Ping handler
|
||||
[](auto *ws) {},
|
||||
// Pong handler
|
||||
[](auto *ws) {},
|
||||
nullptr, // Drain handler
|
||||
nullptr, // Ping handler
|
||||
nullptr, // Pong handler
|
||||
// Close handler
|
||||
[](auto *ws, int code, std::string_view message) {
|
||||
spdlog::info("Controller ID {} disconnected", ((ConnectionData *)ws->getUserData())->m_uid);
|
||||
}})
|
||||
.listen(
|
||||
m_port,
|
||||
[&](auto *token) {
|
||||
if (token)
|
||||
{
|
||||
spdlog::info("Server listening at port {}", m_port);
|
||||
m_running = true;
|
||||
m_uws_socket_token = token;
|
||||
}
|
||||
})
|
||||
.listen(m_port, [&](auto *token) {
|
||||
if (token)
|
||||
{
|
||||
spdlog::info("Server listening at port {}", m_port);
|
||||
m_running = true;
|
||||
m_uws_socket_token = token;
|
||||
}
|
||||
})
|
||||
.run();
|
||||
|
||||
m_running = false;
|
||||
|
|
Loading…
Reference in New Issue