diff --git a/include/parselink/BUILD b/include/parselink/BUILD index 6dc630a..6c43e38 100644 --- a/include/parselink/BUILD +++ b/include/parselink/BUILD @@ -19,7 +19,7 @@ cc_library( name = "proto", hdrs = glob(["proto/**/*.h"]), includes = ["."], - deps = ["//include:path"], + deps = ["@expected", "//include:path"], visibility = ["//visibility:public"], ) diff --git a/include/parselink/proto/session.h b/include/parselink/proto/session.h new file mode 100644 index 0000000..1a69c6e --- /dev/null +++ b/include/parselink/proto/session.h @@ -0,0 +1,71 @@ +//----------------------------------------------------------------------------- +// ___ __ _ _ +// / _ \__ _ _ __ ___ ___ / /(_)_ __ | | __ +// / /_)/ _` | '__/ __|/ _ \/ / | | '_ \| |/ / +// / ___/ (_| | | \__ \ __/ /__| | | | | < +// \/ \__,_|_| |___/\___\____/_|_| |_|_|\_\ . +// +//----------------------------------------------------------------------------- +// Author: Kurt Sassenrath +// Module: Proto +// +// Session management for the "user" protocol. +// +// Copyright (c) 2023 Kurt Sassenrath. +// +// License TBD. +//----------------------------------------------------------------------------- +#ifndef session_07eae057feface79 +#define session_07eae057feface79 + +#include + +#include "parselink/msgpack/token.h" + +#include +#include +#include + +namespace parselink { +namespace proto { + +enum class error { + system_error, + incomplete, + unsupported, + bad_data, + too_large, +}; + +// Structure containing header information parsed from a buffer. +struct header_info { + std::uint32_t message_size; // Size of the message, minus the header. + std::uint32_t bytes_read; // How many bytes of the buffer were used. +}; + +class session { +public: + session() = default; + + // Parse the protocol header out of a buffer. This is a member function + // as we may choose to omit data once a session is established. + tl::expected parse_header( + std::span buffer) noexcept; + + tl::expected handle_connect( + std::span tokens) noexcept; + + // The maximum size of a single message. Should probably depend on whether + // the session is established or not. + constexpr static std::uint32_t max_message_size = 128 * 1024; + + + std::string user_id; +private: +}; + +} // namespace proto +} // namespace parselink + +#endif // session_0c61530748b9f966 + diff --git a/source/BUILD b/source/BUILD index 3540a31..1ef99f1 100644 --- a/source/BUILD +++ b/source/BUILD @@ -15,9 +15,9 @@ cc_binary( deps = [ "headers", "//include/parselink:msgpack", - "//include/parselink:proto", "//include/parselink:utility", "//source/logging", + "//source/proto", "@boost//:beast", ], ) diff --git a/source/proto/BUILD b/source/proto/BUILD new file mode 100644 index 0000000..41225b4 --- /dev/null +++ b/source/proto/BUILD @@ -0,0 +1,17 @@ +# parselink + +cc_library( + name = "proto", + srcs = [ + "session.cpp", + ], + deps = [ + "//include/parselink:proto", + "//include/parselink:msgpack", + "//source/logging", + ], + visibility = [ + # TODO: Fix visibility + "//visibility:public", + ], +) diff --git a/source/proto/session.cpp b/source/proto/session.cpp new file mode 100644 index 0000000..822618a --- /dev/null +++ b/source/proto/session.cpp @@ -0,0 +1,151 @@ +//----------------------------------------------------------------------------- +// ___ __ _ _ +// / _ \__ _ _ __ ___ ___ / /(_)_ __ | | __ +// / /_)/ _` | '__/ __|/ _ \/ / | | '_ \| |/ / +// / ___/ (_| | | \__ \ __/ /__| | | | | < +// \/ \__,_|_| |___/\___\____/_|_| |_|_|\_\ . +// +//----------------------------------------------------------------------------- +// Author: Kurt Sassenrath +// Module: Server +// +// Server implementation. Currently, a monolithic server which: +// * Communicates with users via TCP (msgpack). +// * Runs the websocket server for overlays to read. +// +// Copyright (c) 2023 Kurt Sassenrath. +// +// License TBD. +//----------------------------------------------------------------------------- + +#include "parselink/proto/session.h" + +#include + +#include "parselink/logging.h" +#include "parselink/msgpack/token.h" + +using namespace parselink; +using namespace parselink::proto; + +template <> +struct fmt::formatter { + template + constexpr auto parse(ParseContext& ctx) -> decltype(ctx.begin()) { + return ctx.begin(); + } + template + auto format(msgpack::token const& v, FormatContext& ctx) const { + using parselink::logging::themed_arg; + auto out = fmt::format_to(ctx.out(), "()))); + break; + case msgpack::format::type::signed_int: + out = fmt::format_to(out, "{}", themed_arg(*(v.get()))); + break; + case msgpack::format::type::boolean: + out = fmt::format_to(out, "{}", themed_arg(*(v.get()))); + break; + case msgpack::format::type::string: + out = fmt::format_to(out, "{}", themed_arg(*(v.get()))); + break; + case msgpack::format::type::binary: + out = fmt::format_to(out, "{}", + themed_arg(*(v.get>()))); + break; + case msgpack::format::type::map: + out = fmt::format_to(out, "(arity: {})", + themed_arg(v.get()->count)); + break; + case msgpack::format::type::array: + out = fmt::format_to(out, "(arity: {})", + themed_arg(v.get()->count)); + break; + case msgpack::format::type::nil: + out = fmt::format_to(out, "(nil)"); + break; + case msgpack::format::type::invalid: + out = fmt::format_to(out, "(invalid)"); + break; + default: + break; + } + return fmt::format_to(out, ">"); + } +}; + + +namespace { +logging::logger logger("session"); +} + +tl::expected session::parse_header( + std::span buffer) noexcept { + auto reader = msgpack::token_reader(buffer); + auto magic = reader.read_one(); + if (!magic || *magic != "prs") { + logger.error("Failed to parse magic"); + return tl::unexpected(magic ? error::bad_data : error::incomplete); + } + + auto size = reader.read_one().and_then( + [](auto t){ return t.template get(); }); + + if (!size || !*size) { + logger.error("Failed to get valid message size"); + return tl::unexpected(magic ? error::bad_data : error::incomplete); + } + + if (*size > max_message_size) { + logger.error("Message {} exceeds max size {}", *size, max_message_size); + return tl::unexpected(error::too_large); + } + + std::uint32_t amt = std::distance(buffer.begin(), reader.current()); + return tl::expected(tl::in_place, *size, amt); +} + +tl::expected session::handle_connect( + std::span tokens) noexcept { + auto message_type = tokens.begin()->get(); + if (message_type && message_type == "connect") { + logger.debug("Received '{}' packet. Parsing body", *message_type); + auto map = msgpack::map_view(tokens.subspan(1)); + + constexpr auto find_version = [](auto const& kv) { + return kv.first == "version"; + }; + + auto field = std::ranges::find_if(map, find_version); + if (field != map.end() && (*field).second == 1u) { + logger.debug("Version {}", (*field).second.get()); + } else { + logger.error("connect failed: missing / unsupported version"); + return tl::make_unexpected(error::unsupported); + } + + for (auto const& [k, v] : msgpack::map_view(tokens.subspan(1))) { + tl::expected result; + if (k.type() != msgpack::format::type::string) { + logger.error("connect failed: non-string key {}", k); + return tl::make_unexpected(error::bad_data); + } + if (k == "user_id") { + result = v.get().map([this](auto uid){ + user_id = std::move(uid); + }); + } + if (!result) { + logger.error("connect failed: {} -> {}: {}", k, v, + result.error()); + return tl::make_unexpected(error::bad_data); + } + } + } else { + logger.error("Did not get message type: {}", message_type.error()); + return tl::make_unexpected(error::bad_data); + } + return {}; +} diff --git a/source/server.cpp b/source/server.cpp index 0810dad..69360fe 100644 --- a/source/server.cpp +++ b/source/server.cpp @@ -23,7 +23,7 @@ #include "parselink/msgpack/token/reader.h" #include "parselink/msgpack/token/views.h" -#include "parselink/proto/message.h" +#include "parselink/proto/session.h" #include @@ -53,12 +53,6 @@ using net::use_awaitable; using net::deferred; using net::detached; - -enum class error { - system, - msgpack, -}; - //----------------------------------------------------------------------------- // TODO(ksassenrath): These are logging formatters for various boost/asio types. // Not all code is exposed to them, so they cannot be defined inside the @@ -184,50 +178,6 @@ private: std::uint16_t websocket_port_; }; -tl::expected handle_connect( - std::span tokens) noexcept { - user_session user; - auto message_type = tokens.begin()->get(); - if (message_type) { - logger.debug("Received '{}' packet. Parsing body", *message_type); - auto map = msgpack::map_view(tokens.subspan(1)); - - constexpr auto find_version = [](auto const& kv) { - return kv.first == "version"; - }; - - auto field = std::ranges::find_if(map, find_version); - if (field != map.end() && (*field).second == 1u) { - logger.debug("Version {}\n", (*field).second); - } else { - logger.error("connect failed: missing / unsupported version"); - return tl::make_unexpected(msgpack::error::unsupported); - } - - for (auto const& [k, v] : msgpack::map_view(tokens.subspan(1))) { - tl::expected result; - if (k.type() != msgpack::format::type::string) { - logger.error("connect failed: non-string key {}", k); - return tl::make_unexpected(msgpack::error::bad_value); - } - if (k == "user_id") { - result = v.get().map([&user](auto uid){ - user.user_id = std::move(uid); - }); - } - if (!result) { - logger.error("connect failed: {} -> {}: {}", k, v, - result.error()); - return tl::make_unexpected(msgpack::error::bad_value); - } - } - } else { - logger.error("Did not get message type: {}", message_type.error()); - return tl::make_unexpected(msgpack::error::bad_value); - } - return user; -} - class user_connection : public std::enable_shared_from_this { public: user_connection(monolithic_server& server, net::ip::tcp::socket sock) @@ -280,15 +230,14 @@ public: } awaitable> - buffer_message(std::vector& buffer) noexcept { - auto amt = buffer.size(); - auto total = buffer.capacity(); - buffer.resize(total); + buffer_message(std::span buffer) noexcept { - while (amt < total) { - auto subf = std::span(buffer.begin() + amt, buffer.end()); + std::size_t amt = 0; + + while (amt < buffer.size()) { + auto subsp = buffer.subspan(amt); auto [ec, n] = co_await socket_.async_read_some( - net::buffer(subf), no_ex_coro); + net::buffer(subsp), no_ex_coro); logger.debug("Read {} bytes, total is now {}", n, amt + n); if (ec || n == 0) { logger.error("Reading from user socket failed: {}", ec); @@ -299,8 +248,7 @@ public: co_return std::monostate{}; } - awaitable, boost::system::error_code>> - await_message() noexcept { + awaitable await_connect() noexcept { // Use a small buffer on the stack to read the initial header. std::array buffer; auto [ec, n] = co_await socket_.async_read_some( @@ -308,44 +256,52 @@ public: if (ec) { logger.error("Reading from user socket failed: {}", ec); - co_return tl::make_unexpected(ec); + co_return; } logger.debug("Read {} bytes from client: {}", n, std::span(buffer.data(), n)); - auto hdr = parse_header(std::span(buffer.data(), n)); - if (!hdr) { - logger.error("Unable to parse header: {}", hdr.error()); - co_return tl::make_unexpected(boost::system::error_code: ); + proto::session session; + + auto maybe_hdr = session.parse_header(std::span(buffer.data(), n)); + + if (!maybe_hdr) { + logger.error("Unable to parse header: {}", maybe_hdr.error()); co_return; } - auto msg = std::move(*hdr); + // TODO(ksassenrath): Replace with specialized allocator. + auto msg = std::vector(maybe_hdr->message_size); - if (auto result = co_await buffer_message(msg); !result) { + // Copy remaining portion of message in initial read to the message + // buffer. + std::copy( + std::next(buffer.begin(), maybe_hdr->bytes_read), + std::next(buffer.begin(), n), + msg.begin()); + + auto msg_span = std::span(msg.begin() + n - maybe_hdr->bytes_read, msg.end()); + + if (auto result = co_await buffer_message(msg_span); !result) { logger.error("Unable to parse header: {}", result.error()); co_return; } - } - - awaitable await_connect() noexcept { auto reader = msgpack::token_reader(msg); std::array tokens; - auto maybe_user = reader.read_some(tokens) - .and_then(handle_connect) - .map_error([](auto const& error) { - logger.error("Unable to parse msgpack tokens: {}", error); - }); - - if (!maybe_user) { + auto tokens_used = reader.read_some(tokens); + if (tokens_used) { + auto err = session.handle_connect(*tokens_used); + if (!err) co_return; + } else { + logger.error("Unable to parse msgpack tokens: {}", tokens_used); co_return; } // Authenticate against database. - logger.debug("User {} established connection", maybe_user->user_id); - session_ = server_.establish_session(std::move(*maybe_user)); + logger.debug("User {} established connection", session.user_id); + //session_ = server_.establish_session(std::move(*maybe_user)); } enum class state { @@ -401,6 +357,7 @@ user_session* monolithic_server::establish_session(user_session session) { active_user_sessions_.emplace(session.user_id, std::move(session.user_id)); } + return {}; } std::unique_ptr parselink::make_server(std::string_view address,