parselink-old/source/proto/session.cpp
2023-10-26 07:23:51 -07:00

172 lines
6.1 KiB
C++

//-----------------------------------------------------------------------------
// ___ __ _ _
// / _ \__ _ _ __ ___ ___ / /(_)_ __ | | __
// / /_)/ _` | '__/ __|/ _ \/ / | | '_ \| |/ /
// / ___/ (_| | | \__ \ __/ /__| | | | | <
// \/ \__,_|_| |___/\___\____/_|_| |_|_|\_\ .
//
//-----------------------------------------------------------------------------
// Author: Kurt Sassenrath
// Module: Server
//
// Server implementation. Currently, a monolithic server which:
// * Communicates with users via TCP (msgpack).
// * Runs the websocket server for overlays to read.
//
// Copyright (c) 2023 Kurt Sassenrath.
//
// License TBD.
//-----------------------------------------------------------------------------
#include "parselink/proto/session.h"
#include "parselink/logging.h"
#include "parselink/msgpack/token.h"
#include <fmt/ranges.h>
using namespace parselink;
using namespace parselink::proto;
// fmt formatter for msgpack::token, so tokens can be passed directly to the
// logger. Renders "<msgpack TYPE = VALUE>":
//   * unsigned/signed ints, booleans, strings and binary show their value,
//   * maps and arrays show "(arity: N)",
//   * nil and invalid show "(nil)" / "(invalid)",
//   * any other type shows nothing between "= " and ">".
//
// No format specifiers are accepted: parse() consumes nothing, so "{}" is the
// only usable replacement field for a token.
template <>
struct fmt::formatter<msgpack::token> {
    template <typename ParseContext>
    constexpr auto parse(ParseContext& ctx) -> decltype(ctx.begin()) {
        return ctx.begin();
    }

    template <typename FormatContext>
    auto format(msgpack::token const& v, FormatContext& ctx) const {
        using parselink::logging::themed_arg;
        // format_to returns the advanced output iterator; every subsequent
        // write must continue from the iterator returned by the previous one.
        auto out = fmt::format_to(
                ctx.out(), "<msgpack {} = ", themed_arg(v.type()));
        // Each case dereferences get<T>() without checking it, relying on T
        // matching the type() the token reports.
        switch (v.type()) {
        case msgpack::format::type::unsigned_int:
            out = fmt::format_to(
                    out, "{}", themed_arg(*(v.get<std::uint64_t>())));
            break;
        case msgpack::format::type::signed_int:
            // Signed payloads are read as signed so negative values are not
            // routed through an unsigned type.
            out = fmt::format_to(
                    out, "{}", themed_arg(*(v.get<std::int64_t>())));
            break;
        case msgpack::format::type::boolean:
            out = fmt::format_to(out, "{}", themed_arg(*(v.get<bool>())));
            break;
        case msgpack::format::type::string:
            out = fmt::format_to(
                    out, "{}", themed_arg(*(v.get<std::string_view>())));
            break;
        case msgpack::format::type::binary:
            // Printed as a range via fmt/ranges.h.
            out = fmt::format_to(out, "{}",
                    themed_arg(*(v.get<std::span<std::byte const>>())));
            break;
        case msgpack::format::type::map:
            out = fmt::format_to(out, "(arity: {})",
                    themed_arg(v.get<msgpack::map_desc>()->count));
            break;
        case msgpack::format::type::array:
            out = fmt::format_to(out, "(arity: {})",
                    themed_arg(v.get<msgpack::array_desc>()->count));
            break;
        case msgpack::format::type::nil:
            out = fmt::format_to(out, "(nil)");
            break;
        case msgpack::format::type::invalid:
            out = fmt::format_to(out, "(invalid)");
            break;
        default:
            break;
        }
        return fmt::format_to(out, ">");
    }
};
namespace {

// Logger used by every function in this translation unit, named "session".
logging::logger logger("session");

// Largest message size, in bytes, that parse_header accepts (128 KiB).
constexpr std::uint32_t max_size = 128 * 1024;

// Protocol-version gate for the "version" field of a connect message.
//
// Accepts only version 1: returns that value as std::uint32_t. Anything else
// yields error::unsupported. Because tl::expected's comparison with a plain
// value is false when the expected holds an error, an error input also maps
// to error::unsupported.
constexpr tl::expected<std::uint32_t, error> check_support(
        tl::expected<msgpack::token, error> const& val) {
    if (!(val == 1u)) {
        return tl::make_unexpected(error::unsupported);
    }
    return *val->get<std::uint32_t>();
}

} // namespace
// Parses the message header at the start of `buffer`: the magic string "prs"
// followed by the message size as an unsigned integer.
//
// On success returns header_info built from (message size, number of bytes
// of `buffer` consumed by the header). Errors:
//   * error::incomplete - the magic or size token could not be read from
//                         `buffer` (treated as "not enough data yet", the
//                         same way for both tokens).
//   * error::bad_data   - the magic is not "prs", or the size token is not a
//                         non-zero std::uint32_t.
//   * error::too_large  - the size exceeds max_size.
tl::expected<header_info, error> proto::parse_header(
        std::span<std::byte const> buffer) noexcept {
    auto reader = msgpack::token_reader(buffer);

    auto magic = reader.read_one();
    if (!magic || *magic != "prs") {
        logger.error("Failed to parse magic");
        return tl::unexpected(magic ? error::bad_data : error::incomplete);
    }

    // Read the size token separately from decoding it, so a token that could
    // not be read at all (incomplete) is distinguished from one that was read
    // but is the wrong type or zero (bad_data).
    auto size_token = reader.read_one();
    if (!size_token) {
        logger.error("Failed to get valid message size");
        return tl::unexpected(error::incomplete);
    }
    auto size = size_token->get<std::uint32_t>();
    if (!size || !*size) {
        logger.error("Failed to get valid message size");
        return tl::unexpected(error::bad_data);
    }

    if (*size > max_size) {
        logger.error("Message size {} exceeds max {}", *size, max_size);
        return tl::unexpected(error::too_large);
    }

    // Header length in bytes; the header is only two small tokens, so the
    // narrowing to std::uint32_t is made explicit.
    auto const amt = static_cast<std::uint32_t>(
            std::distance(buffer.begin(), reader.current()));
    return tl::expected<header_info, error>(tl::in_place, *size, amt);
}
namespace {

// Finds the value paired with `key` in a map-like range of key/value pairs.
//
// `map_view` must be a range whose elements expose `.first` (compared to
// `key` with operator==) and `.second` (a msgpack::token), presumably a
// msgpack::map_view. Returns a copy of the value of the first matching pair,
// or error::bad_data when no key matches.
//
// Lives in an anonymous namespace, like check_support, so this generic
// file-local helper has internal linkage and cannot clash with a `lookup`
// defined in another translation unit.
constexpr tl::expected<msgpack::token, error> lookup(
        auto const& map_view, auto const& key) {
    auto const matches_key = [&key](auto const& kv) {
        return kv.first == key;
    };
    if (auto field = std::ranges::find_if(map_view, matches_key);
            field != map_view.end()) {
        return (*field).second;
    }
    return tl::make_unexpected(error::bad_data);
}

} // namespace
// Parses a "connect" message from an already-tokenized msgpack message.
//
// Layout read here: tokens[0] is the string "connect", and tokens[1..] are
// read through msgpack::map_view, whose keys must all be strings.
// Recognized keys:
//   * "version" - required; looked up and passed to check_support, so a
//                 missing key yields error::bad_data and a value other than
//                 1 yields error::unsupported.
//   * "user_id" - optional; if present its value must be a string.
// Other string keys are accepted and ignored.
//
// Returns the populated connect_info. Returns error::bad_data for an empty
// token span, a first token that is not a string or not "connect", a
// non-string key, or a non-string "user_id" value.
tl::expected<connect_info, error> proto::parse_connect(
        std::span<msgpack::token> tokens) noexcept {
    // The tokens come from the network; never dereference an empty span.
    if (tokens.empty()) {
        logger.error("connect failed: no tokens");
        return tl::make_unexpected(error::bad_data);
    }

    auto message_type = tokens.front().get<std::string_view>();
    if (!message_type) {
        logger.error("Did not get message type: {}", message_type.error());
        return tl::make_unexpected(error::bad_data);
    }
    // Handled separately: message_type holds a value here, so calling
    // .error() on it would be invalid.
    if (*message_type != "connect") {
        logger.error("Unexpected message type: '{}'", *message_type);
        return tl::make_unexpected(error::bad_data);
    }

    logger.debug("Received '{}' packet. Parsing body", *message_type);
    auto map = msgpack::map_view(tokens.subspan(1));
    connect_info info;

    auto version = lookup(map, "version").and_then(check_support);
    if (!version) {
        return tl::make_unexpected(version.error());
    }
    info.version = *version;

    for (auto const& [k, v] : map) {
        // Default-constructed: success unless a recognized key fails to decode.
        tl::expected<void, msgpack::error> result;
        if (k.type() != msgpack::format::type::string) {
            logger.error("connect failed: non-string key {}", k);
            return tl::make_unexpected(error::bad_data);
        }
        if (k == "user_id") {
            result = v.get<std::string_view>().map(
                    [&info](auto uid) { info.user_id = uid; });
        }
        if (!result) {
            logger.error(
                    "connect failed: {} -> {}: {}", k, v, result.error());
            return tl::make_unexpected(error::bad_data);
        }
    }
    return info;
}