diff --git a/include/parselink/msgpack/token/type.h b/include/parselink/msgpack/token/type.h index 9a3c3bf..50024f1 100644 --- a/include/parselink/msgpack/token/type.h +++ b/include/parselink/msgpack/token/type.h @@ -240,6 +240,9 @@ public: } } + template + constexpr bool operator==(char const (&t)[N]) const noexcept; + template constexpr bool operator==(T const& t) const noexcept { if constexpr (std::equality_comparable) { @@ -286,6 +289,13 @@ constexpr tl::expected token_base<8>::get() return std::string_view{value_.str, size_and_type_.get_size()}; } + template + constexpr bool token_base<8>::operator==(char const (&t)[N]) const noexcept { + auto result = get().map([&t](auto v) { + return v == std::string_view{t}; }); + return result && *result; + } + template<> inline tl::expected, error> token_base<8>::get() const noexcept diff --git a/include/parselink/msgpack/token/views.h b/include/parselink/msgpack/token/views.h index cefa21b..68267ad 100644 --- a/include/parselink/msgpack/token/views.h +++ b/include/parselink/msgpack/token/views.h @@ -25,6 +25,7 @@ #include "type.h" +#include #include #include @@ -35,6 +36,8 @@ namespace msgpack { std::ranges::range_value_t> struct map_view : public std::ranges::view_interface> { public: + class sentinel; + class iterator { friend class sentinel; @@ -57,10 +60,10 @@ namespace msgpack { } public: - using value_type = std::pair; + using value_type = std::pair; using reference = std::pair; using difference_type = std::ptrdiff_t; - using iterator_category = std::forward_iterator_tag; + using iterator_category = std::input_iterator_tag; iterator() = default; iterator(V const& base) @@ -76,10 +79,10 @@ namespace msgpack { } } - [[nodiscard]] reference operator*() const { return { *k_, *v_ }; } + iterator& operator++() { k_ = next(v_); v_ = next(k_); @@ -98,6 +101,7 @@ namespace msgpack { base_ == rhs.base_; } + private: V const* base_{}; base_iterator k_{}; base_iterator v_{}; diff --git a/include/parselink/proto/message.h b/include/parselink/proto/message.h index 93e4b75..2b5857b 100644 --- a/include/parselink/proto/message.h +++ b/include/parselink/proto/message.h @@ -70,53 +70,6 @@ struct parser_data_message { std::span payload; }; -using message = std::variant< - error_message, - connect_message, - challenge_message, - response_message, - session_established_message, - parser_data_message>; - -enum class error { - bad_magic, // Did not get the message magic expected - too_large, // The message size was too large. - unknown_type, // The message type is not known -}; - -// This class is responsible for consuming buffer data and yielding a message -// instance when complete. Will throw an error if data is incorrect. -// -class builder { -public: - // For now, builders don't manage any buffers themselves. Later, that - // may change. - builder() = default; - - // Reset the builder to its initial state. This means any partially-decoded - // message data will be lost. - void reset() noexcept; - - // How many bytes are needed to perform a meaningful amount of work. - std::size_t bytes_needed() noexcept; - - // Process data from a buffer, building messages. Returns the number of - // bytes read from the buffer for the caller's bookkeeping. May yield a - // message in addition. - std::size_t process(std::span buffer) noexcept; - -private: - enum class state { - magic, - size, - payload - }; - - state state_{state::magic}; - std::size_t payload_size_{}; - std::size_t payload_remaining_{}; -}; - } // namespace message } // namespace parselink diff --git a/source/server.cpp b/source/server.cpp index 586ed15..0810dad 100644 --- a/source/server.cpp +++ b/source/server.cpp @@ -27,7 +27,6 @@ #include - #include #include #include @@ -41,7 +40,11 @@ #include #include +#include +#include + using namespace parselink; +using namespace std::chrono_literals; namespace net = boost::asio; using net::co_spawn; @@ -50,6 +53,12 @@ using net::use_awaitable; using net::deferred; using net::detached; + +enum class error { + system, + msgpack, +}; + //----------------------------------------------------------------------------- // TODO(ksassenrath): These are logging formatters for various boost/asio types. // Not all code is exposed to them, so they cannot be defined inside the @@ -146,15 +155,87 @@ namespace { constexpr auto no_ex_defer = net::as_tuple(deferred); } -struct msgbuf { - std::vector payload; +struct user_session { + std::string user_id; + std::array session_id; + std::chrono::system_clock::time_point expires_at; }; -class user_session : public std::enable_shared_from_this { +class monolithic_server : public server { public: - user_session(net::ip::tcp::socket sock) : socket_(std::move(sock)) {} - ~user_session() { - logger.debug("Closing connection to {}", socket_.remote_endpoint()); + monolithic_server(std::string_view address, std::uint16_t user_port, + std::uint16_t websocket_port); + + std::error_code run() noexcept override; + +private: + + friend user_session; + + awaitable user_listen(); + + user_session* establish_session(user_session session); + + std::map> active_user_sessions_; + + net::io_context io_context_; + net::ip::address addr_; + std::uint16_t user_port_; + std::uint16_t websocket_port_; +}; + +tl::expected handle_connect( + std::span tokens) noexcept { + user_session user; + auto message_type = tokens.begin()->get(); + if (message_type) { + logger.debug("Received '{}' packet. Parsing body", *message_type); + auto map = msgpack::map_view(tokens.subspan(1)); + + constexpr auto find_version = [](auto const& kv) { + return kv.first == "version"; + }; + + auto field = std::ranges::find_if(map, find_version); + if (field != map.end() && (*field).second == 1u) { + logger.debug("Version {}\n", (*field).second); + } else { + logger.error("connect failed: missing / unsupported version"); + return tl::make_unexpected(msgpack::error::unsupported); + } + + for (auto const& [k, v] : msgpack::map_view(tokens.subspan(1))) { + tl::expected result; + if (k.type() != msgpack::format::type::string) { + logger.error("connect failed: non-string key {}", k); + return tl::make_unexpected(msgpack::error::bad_value); + } + if (k == "user_id") { + result = v.get().map([&user](auto uid){ + user.user_id = std::move(uid); + }); + } + if (!result) { + logger.error("connect failed: {} -> {}: {}", k, v, + result.error()); + return tl::make_unexpected(msgpack::error::bad_value); + } + } + } else { + logger.error("Did not get message type: {}", message_type.error()); + return tl::make_unexpected(msgpack::error::bad_value); + } + return user; +} + +class user_connection : public std::enable_shared_from_this { +public: + user_connection(monolithic_server& server, net::ip::tcp::socket sock) + : server_(server) + , socket_(std::move(sock)) {} + + ~user_connection() { + logger.debug("Connection to {} closed.", socket_.remote_endpoint()); boost::system::error_code ec; socket_.shutdown(net::ip::tcp::socket::shutdown_both, ec); socket_.close(); @@ -163,21 +244,23 @@ public: void start() { logger.debug("New connection from {}", socket_.remote_endpoint()); co_spawn(socket_.get_executor(), [self = shared_from_this()]{ - return self->await_auth(); + return self->await_connect(); }, detached); } tl::expected, msgpack::error> parse_header( std::span data) noexcept { auto reader = msgpack::token_reader(data); + auto magic = reader.read_one().map( - [](auto t){ return t == std::string_view{"prs"}; }); + [](auto t){ return t == "prs"; }); if (magic && *magic) { logger.debug("Got magic from client"); } else { logger.error("Failed to get magic from client: {}", magic); return tl::unexpected(magic.error()); } + auto sz = reader.read_one().and_then( [](auto t){ return t.template get(); }); if (sz && *sz) { @@ -186,6 +269,7 @@ public: logger.debug("Failed to get packet size from client: {}", sz); return tl::unexpected(magic.error()); } + // Copy the rest of the message to the buffer for parsing. // TODO(ksassenrath): Replace vector with custom buffer. std::vector msg; @@ -215,63 +299,53 @@ public: co_return std::monostate{}; } - tl::expected handle_auth(std::span tokens) { - auto message_type = tokens.begin()->get(); - if (message_type) { - logger.debug("Received '{}' packet. Parsing body", *message_type); - proto::connect_message message; - for (auto const& [k, v] : msgpack::map_view(tokens.subspan(1))) { - logger.debug("Parsing {} -> {}", k, v); - } - } else { - logger.error("Did not get message type: {}", message_type.error()); - } - // The first token should include - return true; - } - - awaitable await_auth() noexcept { - std::array buffer; + awaitable, boost::system::error_code>> + await_message() noexcept { + // Use a small buffer on the stack to read the initial header. + std::array buffer; auto [ec, n] = co_await socket_.async_read_some( net::buffer(buffer), no_ex_coro); if (ec) { logger.error("Reading from user socket failed: {}", ec); - co_return; + co_return tl::make_unexpected(ec); } logger.debug("Read {} bytes from client: {}", n, std::span(buffer.data(), n)); - auto hdr_result = parse_header(std::span(buffer.data(), n)); - if (!hdr_result) { + auto hdr = parse_header(std::span(buffer.data(), n)); + if (!hdr) { + logger.error("Unable to parse header: {}", hdr.error()); + co_return tl::make_unexpected(boost::system::error_code: ); co_return; } - auto msg = std::move(*hdr_result); - auto maybe_error = co_await buffer_message(msg); + auto msg = std::move(*hdr); - if (!maybe_error) { - logger.error("Unable to buffer message: {}", - maybe_error.error()); + if (auto result = co_await buffer_message(msg); !result) { + logger.error("Unable to parse header: {}", result.error()); co_return; } - logger.trace("Message: {}", msg); + } + awaitable await_connect() noexcept { auto reader = msgpack::token_reader(msg); std::array tokens; - auto parsed = reader.read_some(tokens).and_then( - [this](auto c) { for (auto t : c) logger.trace("{}", t); return handle_auth(c); }) - .or_else([](auto const& error) { + auto maybe_user = reader.read_some(tokens) + .and_then(handle_connect) + .map_error([](auto const& error) { logger.error("Unable to parse msgpack tokens: {}", error); }); - if (!parsed) { - co_return; + if (!maybe_user) { + co_return; } // Authenticate against database. + logger.debug("User {} established connection", maybe_user->user_id); + session_ = server_.establish_session(std::move(*maybe_user)); } enum class state { @@ -280,27 +354,11 @@ public: active }; + monolithic_server& server_; + user_session* session_{}; net::ip::tcp::socket socket_; }; -class monolithic_server : public server { -public: - monolithic_server(std::string_view address, std::uint16_t user_port, - std::uint16_t websocket_port); - - std::error_code run() noexcept override; - -private: - - awaitable user_listen(); - - net::io_context io_context_; - net::ip::address addr_; - std::uint16_t user_port_; - std::uint16_t websocket_port_; -}; - - monolithic_server::monolithic_server(std::string_view address, std::uint16_t user_port, std::uint16_t websocket_port) : io_context_{1} @@ -315,7 +373,7 @@ awaitable monolithic_server::user_listen() { auto exec = co_await net::this_coro::executor; net::ip::tcp::acceptor acceptor{exec, {addr_, user_port_}}; while (true) { - std::make_shared( + std::make_shared(*this, co_await acceptor.async_accept(use_awaitable))->start(); } } @@ -336,6 +394,15 @@ std::error_code monolithic_server::run() noexcept { return {}; } +user_session* monolithic_server::establish_session(user_session session) { + auto existing_session = active_user_sessions_.find(session.user_id); + if (existing_session == active_user_sessions_.end()) { + // No session exists with that user ID yet. + active_user_sessions_.emplace(session.user_id, + std::move(session.user_id)); + } +} + std::unique_ptr parselink::make_server(std::string_view address, std::uint16_t user_port, std::uint16_t websocket_port) { using impl = monolithic_server;