From 1c7047e3145e8ec88d35d517c028eb880470bff4 Mon Sep 17 00:00:00 2001 From: Kurt Sassenrath Date: Thu, 12 Oct 2023 14:52:06 -0700 Subject: [PATCH] Implement token map_view, additional formatters. msgpack::map_view can be used to iterate, pair-wise, over a range of msgpack::token. It will immediately return if the first token is not a map, and will skip over nested map/arrays. Note for the future: It will be handy to be able to get the subspan corresponding to the nested map/array. Will think about how to solve that later. Begin incorporating map_view into the server. Add formatters for std::byte, dynamic theme for bool, and spans thereof. Maybe switch to range? --- include/parselink/logging.h | 2 +- include/parselink/logging/formatters.h | 14 ++ include/parselink/logging/theme.h | 16 +- include/parselink/msgpack/core/format.h | 4 +- include/parselink/msgpack/token.h | 1 + include/parselink/msgpack/token/views.h | 138 +++++++++++++++ source/server.cpp | 167 +++++++++++++++--- tests/msgpack/BUILD | 10 +- tests/msgpack/test_token_views.cpp | 219 ++++++++++++++++++++++++ tests/msgpack/test_utils.h | 98 +++++++++++ 10 files changed, 644 insertions(+), 25 deletions(-) create mode 100644 include/parselink/msgpack/token/views.h create mode 100644 tests/msgpack/test_token_views.cpp create mode 100644 tests/msgpack/test_utils.h diff --git a/include/parselink/logging.h b/include/parselink/logging.h index a2e3771..cb136cc 100644 --- a/include/parselink/logging.h +++ b/include/parselink/logging.h @@ -36,7 +36,7 @@ namespace logging { // enabled in the library, but the compiler should be able to optimize away // some/most of the calls. constexpr inline auto static_threshold = level::trace; -constexpr inline auto default_threshold = level::info; +constexpr inline auto default_threshold = level::debug; // Structure for holding a message. Note: message is a view over some buffer, // not it's own string. 
It will either be a static buffer supplied by the diff --git a/include/parselink/logging/formatters.h b/include/parselink/logging/formatters.h index a489eef..c747320 100644 --- a/include/parselink/logging/formatters.h +++ b/include/parselink/logging/formatters.h @@ -96,6 +96,20 @@ struct fmt::formatter : fmt::formatter { } }; +// Support printing bytes as their hex representation. +template <> +struct fmt::formatter { + template + constexpr auto parse(ParseContext& ctx) -> decltype(ctx.begin()) { + return ctx.begin(); + } + + template + auto format(std::byte const v, FormatContext& ctx) const { + return fmt::format_to(ctx.out(), "0x{:0x}", std::uint8_t(v)); + } +}; + // Support printing raw/smart pointers without needing to wrap them in fmt::ptr template struct fmt::formatter : fmt::formatter { diff --git a/include/parselink/logging/theme.h b/include/parselink/logging/theme.h index 6d6adcf..247914e 100644 --- a/include/parselink/logging/theme.h +++ b/include/parselink/logging/theme.h @@ -72,8 +72,12 @@ struct static_theme { }; template +requires (!std::same_as) struct theme : static_theme {}; +template <> +struct theme : static_theme {}; + template T> struct theme : static_theme {}; @@ -84,7 +88,6 @@ template requires std::is_enum_v struct theme : static_theme {}; - // Errors template <> struct theme : static_theme {}; @@ -96,6 +99,13 @@ template <> struct theme : static_theme {}; +template <> +struct theme { + static constexpr auto style(bool l) noexcept { + return fmt::fg(l ? 
fmt::color::spring_green : fmt::color::indian_red); + } +}; + template <> struct theme> { constexpr static fmt::color colors[] = { @@ -127,7 +137,9 @@ concept has_dynamic_theme = requires (T const& t) { template concept has_theme = has_static_theme || has_dynamic_theme; -static_assert(has_static_theme); +template +requires (has_theme) +struct theme> : theme {}; template constexpr auto get_theme(T const&) { diff --git a/include/parselink/msgpack/core/format.h b/include/parselink/msgpack/core/format.h index fbf0421..5dd9e34 100644 --- a/include/parselink/msgpack/core/format.h +++ b/include/parselink/msgpack/core/format.h @@ -70,7 +70,9 @@ namespace format { nil, boolean, array, - map + map, + array_view, + map_view, }; // Flags that may control the behavior of readers/writers. diff --git a/include/parselink/msgpack/token.h b/include/parselink/msgpack/token.h index a704a2e..01e058f 100644 --- a/include/parselink/msgpack/token.h +++ b/include/parselink/msgpack/token.h @@ -3,5 +3,6 @@ #include "token/type.h" #include "token/reader.h" +#include "token/views.h" #endif // msgpack_object_f1f3a9e5c8be6a11 diff --git a/include/parselink/msgpack/token/views.h b/include/parselink/msgpack/token/views.h new file mode 100644 index 0000000..cefa21b --- /dev/null +++ b/include/parselink/msgpack/token/views.h @@ -0,0 +1,138 @@ +//----------------------------------------------------------------------------- +// ___ __ _ _ +// / _ \__ _ _ __ ___ ___ / /(_)_ __ | | __ +// / /_)/ _` | '__/ __|/ _ \/ / | | '_ \| |/ / +// / ___/ (_| | | \__ \ __/ /__| | | | | < +// \/ \__,_|_| |___/\___\____/_|_| |_|_|\_\ . +// +//----------------------------------------------------------------------------- +// Author: Kurt Sassenrath +// Module: msgpack +// +// Token view utilities. +// +// MessagePack maps and arrays are nested, and the token reader only parses +// out the type. 
This file provides utilities for iterating over these +// "container" formats without incurring additional overhead on the parser when +// it is not needed. +// +// Copyright (c) 2023 Kurt Sassenrath. +// +// License TBD. +//----------------------------------------------------------------------------- +#ifndef msgpack_token_views_f19c250e782ed51c +#define msgpack_token_views_f19c250e782ed51c + +#include "type.h" + +#include +#include + +namespace msgpack { + + template + requires std::ranges::input_range && std::same_as> + struct map_view : public std::ranges::view_interface> { + public: + class iterator { + friend class sentinel; + + using base_iterator = std::ranges::iterator_t; + using base_sentinel = std::ranges::sentinel_t; + using base_value_type = std::ranges::range_value_t; + using base_reference = std::ranges::range_reference_t; + + base_iterator next(base_iterator current, std::size_t n = 1) { + while (n && current != std::ranges::end(*base_)) { + if (auto m = current->template get(); m) { + n += m->count * 2; + } else if (auto m = current->template get(); m) { + n += m->count; + } + ++current; + --n; + } + return current; + } + + public: + using value_type = std::pair; + using reference = std::pair; + using difference_type = std::ptrdiff_t; + using iterator_category = std::forward_iterator_tag; + + iterator() = default; + iterator(V const& base) + : base_{&base} + , k_{std::ranges::begin(base)} { + // Ensure that k_ points to a map_desc. If not, then we + // effectively treat this as the end. + if (k_->type() == msgpack::format::type::map) { + remaining_ = k_->template get()->count + 1; + // Advance to the first entry in the map. 
+ ++k_; + v_ = next(k_); + } + } + + + [[nodiscard]] reference operator*() const { + return { *k_, *v_ }; + } + iterator& operator++() { + k_ = next(v_); + v_ = next(k_); + --remaining_; + return *this; + } + + [[nodiscard]] iterator operator++(int) { + auto tmp = *this; + ++(*this); + return tmp; + } + + [[nodiscard]] bool operator==(iterator const& rhs) const { + return k_ == rhs.k_ && + base_ == rhs.base_; + } + + V const* base_{}; + base_iterator k_{}; + base_iterator v_{}; + std::size_t remaining_{}; + }; + + class sentinel { + public: + [[nodiscard]] bool operator==(sentinel const&) const { + return true; + } + + [[nodiscard]] bool operator==(iterator const& rhs) const { + return rhs.remaining_ == 0 + || rhs.k_ == std::ranges::end(*rhs.base_); + } + }; + + constexpr map_view() noexcept = default; + constexpr map_view(V base) : base_{std::move(base)} {} + + [[nodiscard]] constexpr iterator begin() const { + return { base_ }; + } + [[nodiscard]] constexpr sentinel end() const { + return {}; + } + private: + V base_; + }; + + template + map_view(Range&&) -> map_view>; + +} // namespace msgpack + + +#endif // msgpack_token_views_f19c250e782ed51c diff --git a/source/server.cpp b/source/server.cpp index 3f12098..586ed15 100644 --- a/source/server.cpp +++ b/source/server.cpp @@ -22,6 +22,11 @@ #include "parselink/server.h" #include "parselink/msgpack/token/reader.h" +#include "parselink/msgpack/token/views.h" +#include "parselink/proto/message.h" + +#include + #include #include @@ -63,6 +68,54 @@ struct fmt::formatter } }; +template <> +struct fmt::formatter { + template + constexpr auto parse(ParseContext& ctx) -> decltype(ctx.begin()) { + return ctx.begin(); + } + template + auto format(msgpack::token const& v, FormatContext& ctx) const { + using parselink::logging::themed_arg; + auto out = fmt::format_to(ctx.out(), "()))); + break; + case msgpack::format::type::signed_int: + out = fmt::format_to(out, "{}", themed_arg(*(v.get()))); + break; + case
msgpack::format::type::boolean: + out = fmt::format_to(out, "{}", themed_arg(*(v.get()))); + break; + case msgpack::format::type::string: + out = fmt::format_to(out, "{}", themed_arg(*(v.get()))); + break; + case msgpack::format::type::binary: + out = fmt::format_to(out, "{}", + themed_arg(*(v.get>()))); + break; + case msgpack::format::type::map: + out = fmt::format_to(out, "(arity: {})", + themed_arg(v.get()->count)); + break; + case msgpack::format::type::array: + out = fmt::format_to(out, "(arity: {})", + themed_arg(v.get()->count)); + break; + case msgpack::format::type::nil: + out = fmt::format_to(out, "(nil)"); + break; + case msgpack::format::type::invalid: + out = fmt::format_to(out, "(invalid)"); + break; + default: + break; + } + return fmt::format_to(out, ">"); + } +}; + template concept endpoint = requires(T const& t) { {t.address()}; @@ -102,6 +155,9 @@ public: user_session(net::ip::tcp::socket sock) : socket_(std::move(sock)) {} ~user_session() { logger.debug("Closing connection to {}", socket_.remote_endpoint()); + boost::system::error_code ec; + socket_.shutdown(net::ip::tcp::socket::shutdown_both, ec); + socket_.close(ec); } void start() { @@ -111,40 +167,111 @@ public: }, detached); } - awaitable await_auth() { - std::array buffer; - auto [ec, n] = co_await socket_.async_read_some( - net::buffer(buffer), no_ex_coro); - if (ec) { - logger.error("Reading from user socket failed: {}", ec); - co_return; - } - logger.debug("Read {} bytes from client: {}", n, - std::string_view(reinterpret_cast(buffer.data()), n)); - - // TODO(ksassenrath): Clean this part up. This could be handled in its - // own read_message_header() method.
- auto reader = msgpack::token_reader(std::span(buffer.data(), n)); + tl::expected, msgpack::error> parse_header( + std::span data) noexcept { + auto reader = msgpack::token_reader(data); auto magic = reader.read_one().map( [](auto t){ return t == std::string_view{"prs"}; }); if (magic && *magic) { logger.debug("Got magic from client"); } else { - logger.debug("Got error from client: {}", magic); - co_return; + logger.error("Failed to get magic from client: {}", magic); + return tl::unexpected(magic.error()); } auto sz = reader.read_one().and_then( [](auto t){ return t.template get(); }); if (sz && *sz) { logger.debug("Got packet size from client: {}", *sz); } else { - logger.debug("Got error from client: {}", sz); - co_return; + logger.debug("Failed to get packet size from client: {}", sz); + return tl::unexpected(sz.error()); } // Copy the rest of the message to the buffer for parsing. - std::vector msg(*sz); + // TODO(ksassenrath): Replace vector with custom buffer. + std::vector msg; + msg.reserve(*sz); + msg.resize(reader.remaining()); std::copy(reader.current(), reader.end(), msg.begin()); - //auto [ec, n] = co_await socket_.async_read_some(net::buffer()); + return msg; + } + + awaitable> + buffer_message(std::vector& buffer) noexcept { + auto amt = buffer.size(); + auto total = buffer.capacity(); + buffer.resize(total); + + while (amt < total) { + auto subf = std::span(buffer.begin() + amt, buffer.end()); + auto [ec, n] = co_await socket_.async_read_some( + net::buffer(subf), no_ex_coro); + logger.debug("Read {} bytes, total is now {}", n, amt + n); + if (ec || n == 0) { + logger.error("Reading from user socket failed: {}", ec); + co_return tl::make_unexpected(ec); + } + amt += n; + } + co_return std::monostate{}; + } + + tl::expected handle_auth(std::span tokens) { + auto message_type = tokens.begin()->get(); + if (message_type) { + logger.debug("Received '{}' packet.
Parsing body", *message_type); + proto::connect_message message; + for (auto const& [k, v] : msgpack::map_view(tokens.subspan(1))) { + logger.debug("Parsing {} -> {}", k, v); + } + } else { + logger.error("Did not get message type: {}", message_type.error()); + } + // The first token should include + return true; + } + + awaitable await_auth() noexcept { + std::array buffer; + auto [ec, n] = co_await socket_.async_read_some( + net::buffer(buffer), no_ex_coro); + + if (ec) { + logger.error("Reading from user socket failed: {}", ec); + co_return; + } + + logger.debug("Read {} bytes from client: {}", n, + std::span(buffer.data(), n)); + + auto hdr_result = parse_header(std::span(buffer.data(), n)); + if (!hdr_result) { + co_return; + } + + auto msg = std::move(*hdr_result); + auto maybe_error = co_await buffer_message(msg); + + if (!maybe_error) { + logger.error("Unable to buffer message: {}", + maybe_error.error()); + co_return; + } + + logger.trace("Message: {}", msg); + + auto reader = msgpack::token_reader(msg); + std::array tokens; + auto parsed = reader.read_some(tokens).and_then( + [this](auto c) { for (auto t : c) logger.trace("{}", t); return handle_auth(c); }) + .or_else([](auto const& error) { + logger.error("Unable to parse msgpack tokens: {}", error); + }); + + if (!parsed) { + co_return; + } + + // Authenticate against database. 
} enum class state { diff --git a/tests/msgpack/BUILD b/tests/msgpack/BUILD index 713872b..aa295cd 100644 --- a/tests/msgpack/BUILD +++ b/tests/msgpack/BUILD @@ -3,7 +3,7 @@ cc_library( name = "test_deps", srcs = [ "test_main.cpp", - "rng.h" + "rng.h", ], deps = [ "//include/parselink:msgpack", @@ -47,6 +47,14 @@ cc_test( deps = ["test_deps"], ) +cc_test( + name = "token_views", + srcs = [ + "test_token_views.cpp", + ], + deps = ["test_deps"], +) + cc_binary( name = "speed", srcs = [ diff --git a/tests/msgpack/test_token_views.cpp b/tests/msgpack/test_token_views.cpp new file mode 100644 index 0000000..48daff8 --- /dev/null +++ b/tests/msgpack/test_token_views.cpp @@ -0,0 +1,219 @@ +#include + +#include +#include +#include +#include + +using namespace boost::ut; +namespace { +template +constexpr bool operator==(std::span a, std::span b) noexcept { + return std::equal(a.begin(), a.end(), b.begin(), b.end()); +} + +template +constexpr std::array as_bytes(T&& t) { + return std::bit_cast>(std::forward(t)); +} + +template +constexpr std::array make_bytes(Bytes &&...bytes) { + return {std::byte(std::forward(bytes))...}; +} + +template struct oversized_array { + std::array data; + std::size_t size; +}; + + +constexpr auto to_bytes_array_oversized(auto const &container) { + using value_type = std::decay_t; + oversized_array arr; + std::copy(std::begin(container), std::end(container), std::begin(arr.data)); + arr.size = std::distance(std::begin(container), std::end(container)); + return arr; +} + +consteval auto generate_bytes(auto callable) { + constexpr auto oversized = to_bytes_array_oversized(callable()); + using value_type = std::decay_t; + std::array out; + std::copy(std::begin(oversized.data), + std::next(std::begin(oversized.data), oversized.size), + std::begin(out)); + return out; +} + +consteval auto build_string(auto callable) { + constexpr auto string_array = generate_bytes(callable); + return string_array; +} + +template +constexpr auto cat(std::arrayconst&... 
a) noexcept { + std::array out; + std::size_t index{}; + ((std::copy_n(a.begin(), Sizes, out.begin() + index), index += Sizes), ...); + return out; +} + +constexpr auto repeat(std::span sv, std::size_t count) { + std::vector range; + range.reserve(sv.size() * count); + for (decltype(count) i = 0; i < count; ++i) { + std::copy_n(sv.begin(), sv.size(), std::back_inserter(range)); + } + return range; +} + +constexpr auto repeat(std::string_view sv, std::size_t count) { + std::vector range; + range.reserve(sv.size() * count); + for (decltype(count) i = 0; i < count; ++i) { + std::copy_n(sv.begin(), sv.size(), std::back_inserter(range)); + } + return range; +} + +constexpr auto from_string_view(std::string_view sv) { + std::vector range; + range.resize(sv.size()); + auto itr = range.begin(); + for (auto c : sv) { + *itr = std::byte(c); + ++itr; + } + return range; +} + +template + std::ostream &operator<<(std::ostream &os, tl::expected const &exp) { + if (exp.has_value()) { + os << "Value: '" << *exp << "'"; + } else { + os << "Error"; + } + return os; + } + +bool test_incomplete_message(auto const& payload) { + // Test incomplete message. + for (decltype(payload.size()) i = 1; i < payload.size() - 1; ++i) { + // Test incomplete message. 
+ msgpack::token_reader reader(std::span(payload.data(), i)); + auto token = reader.read_one(); + if (token != tl::make_unexpected(msgpack::error::incomplete_message)) { + fmt::print("Got the wrong response reading subview[0,{}] of payload: {}\n", + i, + token->get().value()); + return false; + } + } + return true; +} + +bool test_end_of_message(auto& reader) { + return reader.read_one() == + tl::make_unexpected(msgpack::error::end_of_message); +} +} + +template <> +struct fmt::formatter { + template + constexpr auto parse(ParseContext& ctx) -> decltype(ctx.begin()) { + return ctx.begin(); + } + template + auto format(msgpack::token const& v, FormatContext& ctx) const { + auto out = fmt::format_to(ctx.out(), "()))); + break; + case msgpack::format::type::signed_int: + out = fmt::format_to(out, "{}", (*(v.get()))); + break; + case msgpack::format::type::boolean: + out = fmt::format_to(out, "{}", (*(v.get()))); + break; + case msgpack::format::type::string: + out = fmt::format_to(out, "{}", (*(v.get()))); + break; + case msgpack::format::type::binary: + out = fmt::format_to(out, "{}", + (*(v.get>()))); + break; + case msgpack::format::type::map: + out = fmt::format_to(out, "(arity: {})", + (v.get()->count)); + break; + case msgpack::format::type::array: + out = fmt::format_to(out, "(arity: {})", + (v.get()->count)); + break; + case msgpack::format::type::nil: + out = fmt::format_to(out, "(nil)"); + break; + case msgpack::format::type::invalid: + out = fmt::format_to(out, "(invalid)"); + break; + default: + break; + } + return fmt::format_to(out, ">"); + } +}; + +suite views_tests = [] { + "read format::fixmap"_test = [] { + // A MessagePack map of 3 strings to 8-bit unsigned integers. 
+ static constexpr auto strings = std::to_array({ + "one", "two", "three", "array", "four", "map", "five", + }); + + constexpr auto payload = cat( + make_bytes(0x87, 0xa3), + generate_bytes([] { return from_string_view(strings[0]); }), + + make_bytes(0x01, 0xa3), + generate_bytes([] { return from_string_view(strings[1]); }), + + make_bytes(0x02, 0xa5), + generate_bytes([] { return from_string_view(strings[2]); }), + + make_bytes(0x03, 0xa5), + generate_bytes([] { return from_string_view(strings[3]); }), + + // Array of size 2, two fixints: + make_bytes(0x92, 0x32, 0x33), + + make_bytes(0xa4), + generate_bytes([] { return from_string_view(strings[4]); }), + + make_bytes(0x04, 0xa3), + generate_bytes([] { return from_string_view(strings[5]); }), + + // Map of size 3, 6 total fixints. + make_bytes(0x83, 0x01, 0x02, 0x03, 0x04, 0x33, 0x11), + + make_bytes(0xa4), + generate_bytes([] { return from_string_view(strings[6]); }) + //make_bytes(0x05) + ); + + std::array tokens; + + msgpack::token_reader reader(payload); + + auto result = reader.read_some(tokens).map([](auto read_tokens) { + for (auto const& [k, v] : msgpack::map_view(read_tokens)) { + fmt::print("map[{}] = {}\n", k, v); + } + }); + + expect(result.has_value()); + }; +}; diff --git a/tests/msgpack/test_utils.h b/tests/msgpack/test_utils.h new file mode 100644 index 0000000..e05afe7 --- /dev/null +++ b/tests/msgpack/test_utils.h @@ -0,0 +1,98 @@ +#ifndef msgpack_test_utils_4573e6627d8efe78 +#define msgpack_test_utils_4573e6627d8efe78 + +#include +#include +#include + +template +constexpr bool operator==(std::span a, std::span b) noexcept { + return std::equal(a.begin(), a.end(), b.begin(), b.end()); +} + +template +constexpr std::array as_bytes(T&& t) { + return std::bit_cast>(std::forward(t)); +} + +template +constexpr std::array make_bytes(Bytes &&...bytes) { + return {std::byte(std::forward(bytes))...}; +} + +template struct oversized_array { + std::array data; + std::size_t size; +}; + +constexpr auto 
to_bytes_array_oversized(auto const &container) { + using value_type = std::decay_t; + oversized_array arr; + std::copy(std::begin(container), std::end(container), std::begin(arr.data)); + arr.size = std::distance(std::begin(container), std::end(container)); + return arr; +} + +consteval auto generate_bytes(auto callable) { + constexpr auto oversized = to_bytes_array_oversized(callable()); + using value_type = std::decay_t; + std::array out; + std::copy(std::begin(oversized.data), + std::next(std::begin(oversized.data), oversized.size), + std::begin(out)); + return out; +} + +consteval auto build_string(auto callable) { + constexpr auto string_array = generate_bytes(callable); + return string_array; +} + +template +constexpr auto cat(std::arrayconst&... a) noexcept { + std::array out; + std::size_t index{}; + ((std::copy_n(a.begin(), Sizes, out.begin() + index), index += Sizes), ...); + return out; +} + +constexpr auto repeat(std::span sv, std::size_t count) { + std::vector range; + range.reserve(sv.size() * count); + for (decltype(count) i = 0; i < count; ++i) { + std::copy_n(sv.begin(), sv.size(), std::back_inserter(range)); + } + return range; +} + +constexpr auto repeat(std::string_view sv, std::size_t count) { + std::vector range; + range.reserve(sv.size() * count); + for (decltype(count) i = 0; i < count; ++i) { + std::copy_n(sv.begin(), sv.size(), std::back_inserter(range)); + } + return range; +} + +constexpr auto from_string_view(std::string_view sv) { + std::vector range; + range.resize(sv.size()); + auto itr = range.begin(); + for (auto c : sv) { + *itr = std::byte(c); + ++itr; + } + return range; +} + +template +std::ostream &operator<<(std::ostream &os, tl::expected const &exp) { + if (exp.has_value()) { + os << "Value: '" << *exp << "'"; + } else { + os << "Error"; + } + return os; +} + +#endif // msgpack_test_utils_4573e6627d8efe78