parselink-old/source/proto/session.cpp

154 lines
5.7 KiB
C++

//-----------------------------------------------------------------------------
// ___ __ _ _
// / _ \__ _ _ __ ___ ___ / /(_)_ __ | | __
// / /_)/ _` | '__/ __|/ _ \/ / | | '_ \| |/ /
// / ___/ (_| | | \__ \ __/ /__| | | | | <
// \/ \__,_|_| |___/\___\____/_|_| |_|_|\_\ .
//
//-----------------------------------------------------------------------------
// Author: Kurt Sassenrath
// Module: Server
//
// Server implementation. Currently, a monolithic server which:
// * Communicates with users via TCP (msgpack).
// * Runs the websocket server for overlays to read.
//
// Copyright (c) 2023 Kurt Sassenrath.
//
// License TBD.
//-----------------------------------------------------------------------------
#include "parselink/proto/session.h"
#include "parselink/logging.h"
#include "parselink/msgpack/token.h"
#include <fmt/ranges.h>
using namespace parselink;
using namespace parselink::proto;
// fmt formatter for msgpack tokens, used when logging tokens (e.g. keys and
// values in session::handle_connect). Output shape:
//     <msgpack TYPE = VALUE>
// where VALUE depends on the token type: scalars print their value,
// maps/arrays print their element count, nil/invalid print a marker.
// No format specifiers are accepted.
template <>
struct fmt::formatter<msgpack::token> {
    // No format-spec parsing: "{}" is the only supported form.
    template <typename ParseContext>
    constexpr auto parse(ParseContext& ctx) -> decltype(ctx.begin()) {
        return ctx.begin();
    }

    template <typename FormatContext>
    auto format(msgpack::token const& v, FormatContext& ctx) const {
        using parselink::logging::themed_arg;
        auto out = fmt::format_to(
                ctx.out(), "<msgpack {} = ", themed_arg(v.type()));
        // Each case extracts the value with the getter matching the token's
        // own type, so the dereference of the returned expected is safe.
        // The iterator returned by format_to is always threaded back into
        // `out` so this works for any output iterator, not just inserters.
        switch (v.type()) {
            case msgpack::format::type::unsigned_int:
                out = fmt::format_to(
                        out, "{}", themed_arg(*(v.get<std::uint64_t>())));
                break;
            case msgpack::format::type::signed_int:
                // Signed tokens must be read as a signed integer; reading
                // them as uint64_t would misprint (or fail on) negatives.
                out = fmt::format_to(
                        out, "{}", themed_arg(*(v.get<std::int64_t>())));
                break;
            case msgpack::format::type::boolean:
                out = fmt::format_to(out, "{}", themed_arg(*(v.get<bool>())));
                break;
            case msgpack::format::type::string:
                out = fmt::format_to(
                        out, "{}", themed_arg(*(v.get<std::string_view>())));
                break;
            case msgpack::format::type::binary:
                out = fmt::format_to(out, "{}",
                        themed_arg(*(v.get<std::span<std::byte const>>())));
                break;
            case msgpack::format::type::map:
                out = fmt::format_to(out, "(arity: {})",
                        themed_arg(v.get<msgpack::map_desc>()->count));
                break;
            case msgpack::format::type::array:
                out = fmt::format_to(out, "(arity: {})",
                        themed_arg(v.get<msgpack::array_desc>()->count));
                break;
            case msgpack::format::type::nil:
                out = fmt::format_to(out, "(nil)");
                break;
            case msgpack::format::type::invalid:
                out = fmt::format_to(out, "(invalid)");
                break;
            default: break;
        }
        return fmt::format_to(out, ">");
    }
};
namespace {
// File-local logger for this translation unit (anonymous namespace), used by
// the session:: member functions below. Constructed with the name "session"
// (presumably its log category/tag -- confirm against logging::logger).
logging::logger logger("session");
}
// Parse a message header from the front of `buffer`.
//
// The header is two msgpack tokens: the magic string "prs" followed by a
// non-zero unsigned message size.
//
// Returns header_info{size, header_length} on success, where header_length
// is the number of bytes the header occupied in `buffer`. Errors:
//   * error::incomplete - the buffer ended before a header token could be
//                         read (caller should wait for more data).
//   * error::bad_data   - a token was read but is not the expected magic, or
//                         the size is missing/not a uint32/zero.
//   * error::too_large  - the size exceeds max_message_size.
tl::expected<header_info, error> session::parse_header(
        std::span<std::byte const> buffer) noexcept {
    auto reader = msgpack::token_reader(buffer);

    auto magic = reader.read_one();
    if (!magic || *magic != "prs") {
        logger.error("Failed to parse magic");
        return tl::unexpected(magic ? error::bad_data : error::incomplete);
    }

    // Read the size token separately from converting it, so a truncated
    // buffer (incomplete) is distinguishable from a present-but-invalid
    // size (bad_data), mirroring the magic check above.
    auto size_token = reader.read_one();
    if (!size_token) {
        logger.error("Failed to get valid message size");
        return tl::unexpected(error::incomplete);
    }
    auto size = size_token->get<std::uint32_t>();
    if (!size || !*size) {
        logger.error("Failed to get valid message size");
        return tl::unexpected(error::bad_data);
    }

    if (*size > max_message_size) {
        logger.error("Message {} exceeds max size {}", *size, max_message_size);
        return tl::unexpected(error::too_large);
    }

    // Bytes consumed by the header; bounded by the buffer size, so the
    // narrowing to uint32 is made explicit.
    auto const amt = static_cast<std::uint32_t>(
            std::distance(buffer.begin(), reader.current()));
    return tl::expected<header_info, error>(tl::in_place, *size, amt);
}
// Handle a "connect" message.
//
// `tokens` is expected to be: a string message type ("connect") followed by
// a msgpack map body. The body must contain "version" == 1; every key must
// be a string; a "user_id" key, if present, must be a string and is stored
// into this->user_id.
//
// Returns an empty success value, or:
//   * error::bad_data    - empty token list, missing/non-string message
//                          type, a message type other than "connect", a
//                          non-string key, or a malformed "user_id".
//   * error::unsupported - "version" missing or not equal to 1.
tl::expected<std::monostate, error> session::handle_connect(
        std::span<msgpack::token> tokens) noexcept {
    // Both front() and subspan(1) below require at least one token.
    if (tokens.empty()) {
        logger.error("Did not get message type: empty message");
        return tl::make_unexpected(error::bad_data);
    }

    auto message_type = tokens.front().get<std::string_view>();
    if (!message_type) {
        logger.error("Did not get message type: {}", message_type.error());
        return tl::make_unexpected(error::bad_data);
    }
    // Checked separately: message_type holds a value here, so calling
    // .error() on it (as a combined check would) is undefined behavior.
    if (*message_type != "connect") {
        logger.error("Unexpected message type: {}", *message_type);
        return tl::make_unexpected(error::bad_data);
    }

    logger.debug("Received '{}' packet. Parsing body", *message_type);

    auto map = msgpack::map_view(tokens.subspan(1));
    constexpr auto find_version = [](auto const& kv) {
        return kv.first == "version";
    };
    auto field = std::ranges::find_if(map, find_version);
    bool const version_ok = field != map.end() && (*field).second == 1u;
    if (!version_ok) {
        logger.error("connect failed: missing / unsupported version");
        return tl::make_unexpected(error::unsupported);
    }
    logger.debug("Version {}", (*field).second.get<std::uint32_t>());

    // Second pass over the body with a fresh view: validate keys and pick
    // out recognized fields. Unrecognized string keys are ignored.
    for (auto const& [k, v] : msgpack::map_view(tokens.subspan(1))) {
        // Default-constructed expected<void> is a success value; only
        // recognized keys can turn it into an error.
        tl::expected<void, msgpack::error> result;
        if (k.type() != msgpack::format::type::string) {
            logger.error("connect failed: non-string key {}", k);
            return tl::make_unexpected(error::bad_data);
        }
        if (k == "user_id") {
            result = v.get<std::string>().map(
                    [this](auto uid) { user_id = std::move(uid); });
        }
        if (!result) {
            logger.error(
                    "connect failed: {} -> {}: {}", k, v, result.error());
            return tl::make_unexpected(error::bad_data);
        }
    }
    return {};
}