From 8f8066c243e908adcfdd2311d093862f2b179edb Mon Sep 17 00:00:00 2001 From: Kurt Sassenrath Date: Thu, 26 Oct 2023 07:23:51 -0700 Subject: [PATCH] WIP --- .snippets/cpp.lua | 30 +++++++++ include/parselink/logging/formatters.h | 2 +- include/parselink/logging/theme.h | 1 + include/parselink/proto/session.h | 30 ++++----- source/proto/session.cpp | 50 +++++++++----- source/server.cpp | 90 ++++++++++---------------- 6 files changed, 112 insertions(+), 91 deletions(-) create mode 100644 .snippets/cpp.lua diff --git a/.snippets/cpp.lua b/.snippets/cpp.lua new file mode 100644 index 0000000..4581905 --- /dev/null +++ b/.snippets/cpp.lua @@ -0,0 +1,30 @@ +local luasnip = require('luasnip') + +local fmt = require('luasnip.extras.fmt').fmt + +local header = [[ +//----------------------------------------------------------------------------- +// ___ __ _ _ +// / _ \__ _ _ __ ___ ___ / /(_)_ __ | | __ +// / /_)/ _` | '__/ __|/ _ \/ / | | '_ \| |/ / +// / ___/ (_| | | \__ \ __/ /__| | | | | < +// \/ \__,_|_| |___/\___\____/_|_| |_|_|\_\ . +// +//----------------------------------------------------------------------------- +// Author: Kurt Sassenrath +// Module: {} +// +// {} +// +// Copyright (c) 2023 Kurt Sassenrath. +// +// License TBD. +//----------------------------------------------------------------------------- + +]] + +return { + s("hdr", fmt(header, {i(1, ""), i(2, "")})) +} + + diff --git a/include/parselink/logging/formatters.h b/include/parselink/logging/formatters.h index 668bed4..353d734 100644 --- a/include/parselink/logging/formatters.h +++ b/include/parselink/logging/formatters.h @@ -91,7 +91,7 @@ struct fmt::formatter : fmt::formatter { template auto format(std::errc const& v, FormatContext& ctx) const { return fmt::formatter::format( - std::make_error_code(v), ctx); + std::make_error_code(v), ctx); } }; diff --git a/include/parselink/logging/theme.h b/include/parselink/logging/theme.h index f12bb8c..7515f6f 100644 --- a/include/parselink/logging/theme.h +++ b/include/parselink/logging/theme.h @@ -31,6 +31,7 @@ #include "level.h" #include "traits.h" +#include #include #include diff --git a/include/parselink/proto/session.h b/include/parselink/proto/session.h index 76a2f82..cc2af7d 100644 --- a/include/parselink/proto/session.h +++ b/include/parselink/proto/session.h @@ -42,27 +42,19 @@ struct header_info { std::uint32_t bytes_read; // How many bytes of the buffer were used. }; -class session { -public: - session() = default; - - // Parse the protocol header out of a buffer. This is a member function - // as we may choose to omit data once a session is established. - tl::expected parse_header( - std::span buffer) noexcept; - - tl::expected handle_connect( - std::span tokens) noexcept; - - // The maximum size of a single message. Should probably depend on whether - // the session is established or not. - constexpr static std::uint32_t max_message_size = 128 * 1024; - - std::string user_id; - -private: +struct connect_info { + std::uint32_t version; + std::string_view user_id; + std::span session_id; }; +// Parse the protocol header out of a buffer. +tl::expected parse_header( + std::span buffer) noexcept; + +tl::expected parse_connect( + std::span tokens) noexcept; + } // namespace proto } // namespace parselink diff --git a/source/proto/session.cpp b/source/proto/session.cpp index 3084f12..9b20db8 100644 --- a/source/proto/session.cpp +++ b/source/proto/session.cpp @@ -17,7 +17,6 @@ // // License TBD. //----------------------------------------------------------------------------- - #include "parselink/proto/session.h" #include "parselink/logging.h" @@ -82,9 +81,19 @@ struct fmt::formatter { namespace { logging::logger logger("session"); +constexpr std::uint32_t max_size = 128 * 1024; + +constexpr tl::expected check_support( + tl::expected const& val) { + if (val == 1u) { + return *(*val).get(); + } + return tl::make_unexpected(error::unsupported); } -tl::expected session::parse_header( +} // namespace + +tl::expected proto::parse_header( std::span buffer) noexcept { auto reader = msgpack::token_reader(buffer); auto magic = reader.read_one(); @@ -101,8 +110,8 @@ tl::expected session::parse_header( return tl::unexpected(magic ? error::bad_data : error::incomplete); } - if (*size > max_message_size) { - logger.error("Message {} exceeds max size {}", *size, max_message_size); + if (*size > max_size) { + logger.error("Message size {} exceeds max {}", *size, max_size); return tl::unexpected(error::too_large); } @@ -110,23 +119,30 @@ tl::expected session::parse_header( return tl::expected(tl::in_place, *size, amt); } -tl::expected session::handle_connect( +constexpr tl::expected lookup( + auto const& map_view, auto const& key) { + auto find_key = [&key](auto const& kv) { return kv.first == key; }; + if (auto field = std::ranges::find_if(map_view, find_key); + field != map_view.end()) { + return (*field).second; + } + return tl::make_unexpected(error::bad_data); +} + +tl::expected proto::parse_connect( std::span tokens) noexcept { auto message_type = tokens.begin()->get(); if (message_type && message_type == "connect") { logger.debug("Received '{}' packet. Parsing body", *message_type); auto map = msgpack::map_view(tokens.subspan(1)); - constexpr auto find_version = [](auto const& kv) { - return kv.first == "version"; - }; + connect_info info; + auto version = lookup(map, "version").and_then(check_support); - auto field = std::ranges::find_if(map, find_version); - if (field != map.end() && (*field).second == 1u) { - logger.debug("Version {}", (*field).second.get()); + if (version) { + info.version = *version; } else { - logger.error("connect failed: missing / unsupported version"); - return tl::make_unexpected(error::unsupported); + return tl::make_unexpected(version.error()); } for (auto const& [k, v] : msgpack::map_view(tokens.subspan(1))) { @@ -135,19 +151,21 @@ tl::expected session::handle_connect( logger.error("connect failed: non-string key {}", k); return tl::make_unexpected(error::bad_data); } + if (k == "user_id") { - result = v.get().map( - [this](auto uid) { user_id = std::move(uid); }); + result = v.get().map( + [&info](auto uid) { info.user_id = uid; }); } + if (!result) { logger.error( "connect failed: {} -> {}: {}", k, v, result.error()); return tl::make_unexpected(error::bad_data); } } + return info; } else { logger.error("Did not get message type: {}", message_type.error()); return tl::make_unexpected(error::bad_data); } - return {}; } diff --git a/source/server.cpp b/source/server.cpp index ee63191..c7fb087 100644 --- a/source/server.cpp +++ b/source/server.cpp @@ -25,6 +25,7 @@ #include "parselink/msgpack/token/views.h" #include "parselink/proto/session.h" #include +#include #include #include #include @@ -33,12 +34,15 @@ #include #include #include +#include #include -#include #include #include +using namespace parselink; +#include + using namespace parselink; using namespace std::chrono_literals; @@ -148,7 +152,10 @@ constexpr auto no_ex_coro = net::as_tuple(use_awaitable); constexpr auto no_ex_defer = net::as_tuple(deferred); } // namespace +class user_connection; + struct user_session { + std::weak_ptr conn; std::string user_id; std::array session_id; std::chrono::system_clock::time_point expires_at; @@ -161,16 +168,19 @@ public: std::error_code run() noexcept override; + net::awaitable> create_session( + std::weak_ptr conn, + proto::connect_info const& info); + private: friend user_session; awaitable user_listen(); - user_session* establish_session(user_session session); - - std::map> active_user_sessions_; + std::map> active_sessions_; net::io_context io_context_; + net::io_context::strand session_strand_; net::ip::address addr_; std::uint16_t user_port_; std::uint16_t websocket_port_; @@ -197,36 +207,6 @@ public: detached); } - tl::expected, msgpack::error> parse_header( - std::span data) noexcept { - auto reader = msgpack::token_reader(data); - - auto magic = reader.read_one().map([](auto t) { return t == "prs"; }); - if (magic && *magic) { - logger.debug("Got magic from client"); - } else { - logger.error("Failed to get magic from client: {}", magic); - return tl::unexpected(magic.error()); - } - - auto sz = reader.read_one().and_then( - [](auto t) { return t.template get(); }); - if (sz && *sz) { - logger.debug("Got packet size from client: {}", *sz); - } else { - logger.debug("Failed to get packet size from client: {}", sz); - return tl::unexpected(magic.error()); - } - - // Copy the rest of the message to the buffer for parsing. - // TODO(ksassenrath): Replace vector with custom buffer. - std::vector msg; - msg.reserve(*sz); - msg.resize(reader.remaining()); - std::copy(reader.current(), reader.end(), msg.begin()); - return msg; - } - awaitable> buffer_message(std::span buffer) noexcept { std::size_t amt = 0; @@ -259,9 +239,7 @@ public: logger.debug("Read {} bytes from client: {}", n, std::span(buffer.data(), n)); - proto::session session; - - auto maybe_hdr = session.parse_header(std::span(buffer.data(), n)); + auto maybe_hdr = proto::parse_header(std::span(buffer.data(), n)); if (!maybe_hdr) { logger.error("Unable to parse header: {}", maybe_hdr.error()); @@ -286,30 +264,30 @@ public: auto reader = msgpack::token_reader(msg); std::array tokens; - auto tokens_used = reader.read_some(tokens); - if (tokens_used) { - auto err = session.handle_connect(*tokens_used); - if (!err) co_return; - } else { - logger.error("Unable to parse msgpack tokens: {}", tokens_used); + auto session = + reader.read_some(tokens) + .map_error([](auto) { return proto::error::bad_data; }) + .and_then(proto::parse_connect); + + if (!session) { + logger.error("Session failed: {}", session.error()); co_return; } - // Authenticate against database. - logger.debug("User {} established connection", session.user_id); - // session_ = server_.establish_session(std::move(*maybe_user)); + co_await server_.create_session( + std::weak_ptr(shared_from_this()), *session); } enum class state { init, authenticated, active }; monolithic_server& server_; - user_session* session_{}; net::ip::tcp::socket socket_; }; monolithic_server::monolithic_server(std::string_view address, std::uint16_t user_port, std::uint16_t websocket_port) : io_context_{1} + , session_strand_(io_context_) , addr_(net::ip::address::from_string(std::string{address})) , user_port_{user_port} , websocket_port_{websocket_port} { @@ -344,14 +322,16 @@ std::error_code monolithic_server::run() noexcept { return {}; } -user_session* monolithic_server::establish_session(user_session session) { - auto existing_session = active_user_sessions_.find(session.user_id); - if (existing_session == active_user_sessions_.end()) { - // No session exists with that user ID yet. - active_user_sessions_.emplace( - session.user_id, std::move(session.user_id)); - } - return {}; +net::awaitable> +monolithic_server::create_session(std::weak_ptr conn, + proto::connect_info const& session) { + // Move to session strand. + co_await net::post(session_strand_, + net::bind_executor(session_strand_, use_awaitable)); + + // Pretend that there's no open + + co_return tl::make_unexpected(proto::error::unsupported); } std::unique_ptr parselink::make_server(std::string_view address,