parselink-old/source/server.cpp
Kurt Sassenrath b2a6f25306 Start proto::parser, remove msgpack::reader
- proto::parser will likely contain helper functions for parsing
  messages from a buffer, for now it will explicitly define message
  parsers for each available message. It also leverages the new unpacker
  API, which has type safety in mind. These messages should not be
  unstructured, so it doesn't make sense to use the token API.
2024-01-12 19:36:56 -08:00

421 lines
14 KiB
C++

//-----------------------------------------------------------------------------
// ___ __ _ _
// / _ \__ _ _ __ ___ ___ / /(_)_ __ | | __
// / /_)/ _` | '__/ __|/ _ \/ / | | '_ \| |/ /
// / ___/ (_| | | \__ \ __/ /__| | | | | <
// \/ \__,_|_| |___/\___\____/_|_| |_|_|\_\ .
//
//-----------------------------------------------------------------------------
// Author: Kurt Sassenrath
// Module: Server
//
// Server implementation. Currently, a monolithic server which:
// * Communicates with users via TCP (msgpack).
// * Runs the websocket server for overlays to read.
//
// Copyright (c) 2023 Kurt Sassenrath.
//
// License TBD.
//-----------------------------------------------------------------------------
#include "parselink/server.h"
#include "parselink/logging.h"
#include "parselink/msgpack/core/packer.h"
#include "parselink/msgpack/core/unpacker.h"
#include "parselink/proto/parser.h"
#include "parselink/proto/session.h"
#include "parselink/utility/file.h"
#include <fmt/ranges.h>
#include "hydrogen.h"
#include <boost/asio/as_tuple.hpp>
#include <boost/asio/bind_executor.hpp>
#include <boost/asio/co_spawn.hpp>
#include <boost/asio/deferred.hpp>
#include <boost/asio/detached.hpp>
#include <boost/asio/error.hpp>
#include <boost/asio/io_context.hpp>
#include <boost/asio/ip/address.hpp>
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/redirect_error.hpp>
#include <boost/asio/signal_set.hpp>
#include <boost/asio/strand.hpp>
#include <boost/asio/write.hpp>
#include <boost/system/system_error.hpp>
#include <algorithm>
#include <exception>
#include <unordered_set>
using namespace parselink;
#include <fmt/ranges.h>
using namespace parselink;
using namespace std::chrono_literals;
namespace net = boost::asio;
using net::awaitable;
using net::co_spawn;
using net::deferred;
using net::detached;
using net::use_awaitable;
//-----------------------------------------------------------------------------
// TODO(ksassenrath): These are logging formatters for various boost/asio types.
// Not all code is exposed to them, so they cannot be defined inside the
// generic logging/formatters.h header. They should go somewhere else.
//-----------------------------------------------------------------------------
// Logging theme for boost::system::error_code values: selects the static
// fire_brick color theme.
template <>
struct parselink::logging::theme<boost::system::error_code>
: parselink::logging::static_theme<fmt::color::fire_brick> {};
// fmt support for boost::system::error_code. The value is rendered as the
// text returned by its message(), reusing the std::string_view formatter (and
// therefore its format-spec parsing).
template <>
struct fmt::formatter<boost::system::error_code>
    : fmt::formatter<std::string_view> {
    template <typename FormatContext>
    constexpr auto format(auto const& v, FormatContext& ctx) const {
        using text_formatter = fmt::formatter<std::string_view>;
        // The temporary message string outlives the call: it is destroyed at
        // the end of this full expression.
        return text_formatter::format(v.message(), ctx);
    }
};
// fmt support for tl::monostate. The value carries no information, so it is
// always rendered as a single "." through the std::string_view formatter.
template <>
struct fmt::formatter<tl::monostate> : fmt::formatter<std::string_view> {
    constexpr auto format(auto const&, auto& ctx) const {
        using text_formatter = fmt::formatter<std::string_view>;
        return text_formatter::format(".", ctx);
    }
};
// fmt support for msgpack::token.
//
// Output has the form "<msgpack TYPE = VALUE>", where TYPE is the themed
// token type and VALUE depends on that type:
//   * integers, booleans, strings and binary blobs print their themed value;
//   * maps and arrays print "(arity: N)" using the descriptor's count;
//   * nil / invalid print "(nil)" / "(invalid)";
//   * any other type prints no value at all.
template <>
struct fmt::formatter<msgpack::token> {
    // No format-spec options are supported: the spec is returned unparsed.
    template <typename ParseContext>
    constexpr auto parse(ParseContext& ctx) -> decltype(ctx.begin()) {
        return ctx.begin();
    }

    template <typename FormatContext>
    auto format(msgpack::token const& v, FormatContext& ctx) const {
        using parselink::logging::themed_arg;
        // Every format_to call returns the advanced output iterator. It must
        // be re-assigned each time, otherwise output is lost for iterators
        // that are not simple back-inserters.
        auto out = fmt::format_to(
                ctx.out(), "<msgpack {} = ", themed_arg(v.type()));
        switch (v.type()) {
            case msgpack::format::type::unsigned_int:
                out = fmt::format_to(
                        out, "{}", themed_arg(*(v.get<std::uint64_t>())));
                break;
            case msgpack::format::type::signed_int:
                // Signed tokens are read back as a signed integer so that
                // negative values print correctly.
                // NOTE(review): assumes token::get supports std::int64_t --
                // confirm against the msgpack token API.
                out = fmt::format_to(
                        out, "{}", themed_arg(*(v.get<std::int64_t>())));
                break;
            case msgpack::format::type::boolean:
                out = fmt::format_to(out, "{}", themed_arg(*(v.get<bool>())));
                break;
            case msgpack::format::type::string:
                out = fmt::format_to(
                        out, "{}", themed_arg(*(v.get<std::string_view>())));
                break;
            case msgpack::format::type::binary:
                out = fmt::format_to(out, "{}",
                        themed_arg(*(v.get<std::span<std::byte const>>())));
                break;
            case msgpack::format::type::map:
                out = fmt::format_to(out, "(arity: {})",
                        themed_arg(v.get<msgpack::map_desc>()->count));
                break;
            case msgpack::format::type::array:
                out = fmt::format_to(out, "(arity: {})",
                        themed_arg(v.get<msgpack::array_desc>()->count));
                break;
            case msgpack::format::type::nil:
                out = fmt::format_to(out, "(nil)");
                break;
            case msgpack::format::type::invalid:
                out = fmt::format_to(out, "(invalid)");
                break;
            default: break;
        }
        return fmt::format_to(out, ">");
    }
};
// Satisfied by any type on which address() and port() can be called through
// a const reference. The return types are not constrained.
template <typename T>
concept endpoint = requires(T const& t) {
    t.address();
    t.port();
};
// Logging theme for every type satisfying the endpoint concept: selects the
// static coral color theme.
template <endpoint T>
struct parselink::logging::theme<T>
: parselink::logging::static_theme<fmt::color::coral> {};
// fmt support for every type satisfying the endpoint concept. The value is
// rendered as "ADDRESS:PORT", where ADDRESS is address().to_string(). Format
// specs are parsed by the std::string_view base but not applied to the output.
template <endpoint T>
struct fmt::formatter<T> : fmt::formatter<std::string_view> {
    template <typename FormatContext>
    constexpr auto format(auto const& v, FormatContext& ctx) const {
        auto sink = ctx.out();
        return fmt::format_to(
                sink, "{}:{}", v.address().to_string(), v.port());
    }
};
//-----------------------------------------------------------------------------
// End formatters
//-----------------------------------------------------------------------------
// File-local helpers shared by the server and connection code below.
namespace {
// Logger used throughout this file; constructed with the name "server".
logging::logger logger("server");
// use_awaitable wrapped in as_tuple: awaited operations yield their results
// as a tuple (e.g. `auto [ec, n] = co_await ...`) so callers inspect the
// error code themselves.
constexpr auto no_ex_coro = net::as_tuple(use_awaitable);
// Same as_tuple wrapping applied to the deferred completion token. Not
// referenced elsewhere in the visible part of this file.
constexpr auto no_ex_defer = net::as_tuple(deferred);
} // namespace
#include <parselink/server/memory_session_manager.h>
// Forward declaration: monolithic_server::create_session() takes a handle to
// the connection requesting the session.
class user_connection;
// Server implementation that owns the io_context, accepts user connections
// over TCP and creates sessions through a memory_session_manager.
class monolithic_server : public server {
public:
// Parses `address`, stores both ports and attempts to load the server
// keypair. Nothing is listened on until run() is called.
monolithic_server(std::string_view address, std::uint16_t user_port,
std::uint16_t websocket_port);
// Installs a SIGINT/SIGTERM handler that stops the io_context, spawns the
// user accept loop and runs the io_context on the calling thread.
std::error_code run() noexcept override;
// Coroutine: moves onto session_strand_, then asks session_mgr_ to create a
// session for info.user_id and returns the manager's result.
net::awaitable<tl::expected<proto::session*, proto::error>> create_session(
std::shared_ptr<user_connection> const& conn,
proto::connect_message const& info);
private:
// Coroutine: accepts TCP connections on addr_:user_port_ and starts a
// user_connection for each one.
awaitable<void> user_listen();
// Reads the server keypair from disk into kp_, generating (and writing) a new
// one when the stored data cannot be read or parsed.
tl::expected<tl::monostate, std::errc> load_keys() noexcept;
// NOTE: members are initialized in declaration order. io_context_ must stay
// declared before session_strand_ and session_mgr_, which are both
// constructed from it.
hydro_kx_keypair kp_;
net::io_context io_context_;
net::io_context::strand session_strand_;
memory_session_manager session_mgr_;
net::ip::address addr_;
std::uint16_t user_port_;
// Stored by the constructor; not otherwise used in the visible code.
std::uint16_t websocket_port_;
};
// A single TCP connection from a user to the server.
//
// Instances must be owned by a std::shared_ptr: start() and await_connect()
// call shared_from_this(). The spawned coroutine holds a reference to the
// connection for as long as it is running.
class user_connection : public std::enable_shared_from_this<user_connection> {
public:
    // Takes ownership of an accepted socket. `server` must outlive the
    // connection; only a reference to it is stored.
    user_connection(monolithic_server& server, net::ip::tcp::socket sock)
        : server_(server)
        , socket_(std::move(sock)) {}

    ~user_connection() { stop(); }

    // Shuts down and closes the socket. Safe to call more than once and on a
    // socket whose peer is already gone: only the non-throwing (error_code)
    // overloads are used, because this also runs from the destructor.
    void stop() {
        if (!socket_.is_open()) {
            return;
        }
        boost::system::error_code ec;
        auto const peer = socket_.remote_endpoint(ec);
        if (ec) {
            logger.debug("Connection closed (peer unavailable: {}).", ec);
        } else {
            logger.debug("Connection to {} closed.", peer);
        }
        // Failures while tearing the socket down are deliberately ignored.
        socket_.shutdown(net::ip::tcp::socket::shutdown_both, ec);
        socket_.close(ec);
    }

    // Spawns the connect handshake coroutine on the socket's executor. The
    // coroutine keeps this connection alive through the captured shared_ptr.
    void start() {
        // The peer may already have disconnected; use the non-throwing
        // overload so a dead socket cannot throw into the accept loop.
        boost::system::error_code ec;
        auto const peer = socket_.remote_endpoint(ec);
        if (ec) {
            logger.debug("New connection (peer unavailable: {})", ec);
        } else {
            logger.debug("New connection from {}", peer);
        }
        co_spawn(
                socket_.get_executor(),
                [self = shared_from_this()] { return self->await_connect(); },
                detached);
    }

    // Reads one message from the socket into `buffer`.
    //
    // The header is parsed from an initial read of up to 8 bytes; `buffer` is
    // resized to the message size stated in the header and filled with the
    // message payload.
    //
    // Returns tl::monostate on success, proto::error::system_error when a
    // socket read fails, or the error reported by proto::parse_header.
    awaitable<tl::expected<tl::monostate, proto::error>> read_message(
            std::vector<std::byte>& buffer) noexcept {
        // Use a small buffer on the stack to read the initial header.
        std::array<std::byte, 8> hdrbuff;
        auto [ec, n] = co_await socket_.async_read_some(
                net::buffer(hdrbuff), no_ex_coro);
        if (ec) {
            logger.error("Reading hdr from user socket failed: {}", ec);
            co_return tl::make_unexpected(proto::error::system_error);
        }
        logger.debug("Read {} bytes from client: {}", n,
                std::span(hdrbuff.data(), n));
        auto maybe_hdr = proto::parse_header(std::span(hdrbuff.data(), n));
        if (!maybe_hdr) {
            logger.error("Unable to parse header: {}", maybe_hdr.error());
            co_return tl::make_unexpected(maybe_hdr.error());
        }
        // TODO(ksassenrath): Replace with specialized allocator.
        // NOTE(review): message_size comes straight from the peer and is not
        // capped here -- consider enforcing a protocol maximum.
        buffer.resize(maybe_hdr->message_size);

        // The initial read may have picked up payload bytes after the header.
        // Copy them over, but never more than the message itself holds: the
        // peer may have sent more than one message's worth of data.
        std::size_t const header_len =
                std::min<std::size_t>(maybe_hdr->bytes_read, n);
        std::size_t const leftover = n - header_len;
        std::size_t const prefix = std::min(leftover, buffer.size());
        if (prefix < leftover) {
            logger.warning("Discarding {} bytes received past end of message",
                    leftover - prefix);
        }
        std::copy_n(hdrbuff.data() + header_len, prefix, buffer.data());

        // Buffer remaining message.
        auto const pending = std::span<std::byte>(buffer).subspan(prefix);
        std::size_t amt = 0;
        while (amt < pending.size()) {
            auto const chunk = pending.subspan(amt);
            auto [read_ec, read_n] = co_await socket_.async_read_some(
                    net::buffer(chunk.data(), chunk.size()), no_ex_coro);
            // A zero-byte read without an error would loop forever; treat it
            // as a failure as well.
            if (read_ec || read_n == 0) {
                logger.error(
                        "Reading msg from user socket failed: {}", read_ec);
                co_return tl::make_unexpected(proto::error::system_error);
            }
            amt += read_n;
            logger.debug("Read {} bytes, total is now {}", read_n, amt);
        }
        co_return tl::monostate{};
    }

    // Connect handshake: reads one message, parses it as a connect_message
    // and asks the server to create a session for it. On success the session
    // pointer is stored in session_; on any failure the coroutine just ends.
    awaitable<void> await_connect() noexcept {
        std::vector<std::byte> msgbuf;
        if (auto maybe_msg = co_await read_message(msgbuf); !maybe_msg) {
            logger.debug("returning");
            co_return;
        }
        auto connect_info = proto::parse<proto::connect_message>(msgbuf);
        if (!connect_info) {
            logger.error("Session failed: {}", connect_info.error());
            co_return;
        }
        auto session = co_await server_.create_session(
                shared_from_this(), *connect_info);
        if (!session) {
            co_return;
        }
        session_ = *session;
    }

    // Connection lifecycle states. Not referenced in the visible code.
    enum class state { init, authenticated, active };

    monolithic_server& server_;
    net::ip::tcp::socket socket_;
    // Session returned by the server; null until the handshake succeeds.
    proto::session* session_{};
};
// Builds the server state without starting any I/O.
//
// address:        textual IP address to listen on. ip::make_address throws
//                 boost::system::system_error when it cannot be parsed.
// user_port:      TCP port for user connections.
// websocket_port: stored for later use; not otherwise used in this file.
//
// The io_context is created with a concurrency hint of 1. The outcome of
// load_keys() is only logged: a failure does not abort construction.
monolithic_server::monolithic_server(std::string_view address,
        std::uint16_t user_port, std::uint16_t websocket_port)
    : io_context_{1}
    , session_strand_(io_context_)
    , session_mgr_(io_context_)
    // make_address is the supported replacement for the deprecated
    // address::from_string static member.
    , addr_(net::ip::make_address(std::string{address}))
    , user_port_{user_port}
    , websocket_port_{websocket_port} {
    logger.debug("Loaded keys: {}", load_keys());
    logger.debug("Creating monolithic_server(address = {}, user_port = {}, "
                 "websocket_port = {})",
            address, user_port_, websocket_port_);
}
// Accept loop for user connections.
//
// Binds an acceptor to addr_:user_port_ on the coroutine's executor and
// starts a user_connection for every accepted socket. Constructing the
// acceptor can still throw (e.g. when the address is already in use).
//
// Individual accept failures are logged and the loop keeps going, so one bad
// accept does not take the whole server down. The loop ends when the accept
// is cancelled (operation_aborted).
// NOTE(review): a persistent failure (e.g. descriptor exhaustion) is retried
// immediately; consider backing off between attempts.
awaitable<void> monolithic_server::user_listen() {
    auto exec = co_await net::this_coro::executor;
    net::ip::tcp::acceptor acceptor{exec, {addr_, user_port_}};
    while (true) {
        auto [ec, sock] = co_await acceptor.async_accept(no_ex_coro);
        if (ec == net::error::operation_aborted) {
            co_return;
        }
        if (ec) {
            logger.error("Accepting user connection failed: {}", ec);
            continue;
        }
        std::make_shared<user_connection>(*this, std::move(sock))->start();
    }
}
std::error_code monolithic_server::run() noexcept {
logger.info("Starting server.");
net::signal_set signals(io_context_, SIGINT, SIGTERM);
signals.async_wait([&](auto, auto) {
logger.info("Received signal... Shutting down.");
io_context_.stop();
});
co_spawn(io_context_, user_listen(), detached);
io_context_.run();
return {};
}
// Loads the server key-exchange keypair into kp_.
//
// The key file holds two msgpack binary values: the public key followed by
// the secret key. The steps are:
//   1. open (creating if needed) the key file, owner read/write only;
//   2. read and unpack both keys into kp_;
//   3. if reading or unpacking fails, generate a fresh keypair and write it
//      to the file.
// If the file cannot be opened at all, a keypair is still generated in memory
// so kp_ never holds uninitialized data; the open error is returned anyway.
//
// Returns tl::monostate on success, or the std::errc reported by the file
// utilities when opening or writing fails.
// NOTE(review): the key path is hard-coded to one user's home directory and
// should become configurable.
tl::expected<tl::monostate, std::errc> monolithic_server::load_keys() noexcept {
    std::string_view filename = "/home/rihya/server_kp.keys";
    enum class result { loaded, generated };

    // Set once kp_ holds a usable keypair (loaded or generated).
    bool keys_ready = false;

    // Unpacks both keys from the raw file contents. kp_ is written only after
    // both values were unpacked successfully.
    auto load_key = [this, &keys_ready](
                            auto const& raw) -> tl::expected<result, std::errc> {
        msgpack::unpacker unpacker(raw);
        auto pk = unpacker.unpack<std::span<std::byte const, sizeof(kp_.pk)>>();
        if (!pk) return tl::make_unexpected(std::errc::bad_message);
        auto sk = unpacker.unpack<std::span<std::byte const, sizeof(kp_.sk)>>();
        if (!sk) return tl::make_unexpected(std::errc::bad_message);
        std::ranges::transform(pk->begin(), pk->end(), std::begin(kp_.pk),
                [](auto c) { return std::bit_cast<unsigned char>(c); });
        std::ranges::transform(sk->begin(), sk->end(), std::begin(kp_.sk),
                [](auto c) { return std::bit_cast<unsigned char>(c); });
        keys_ready = true;
        return result::loaded;
    };

    // Fallback for any read/unpack failure: generate a new keypair. The error
    // that led here is intentionally ignored.
    auto generate_keys = [this, &keys_ready](
                                 auto const&) -> tl::expected<result, std::errc> {
        logger.warning("Could not load server keys, generating a keypair");
        hydro_kx_keygen(&kp_);
        keys_ready = true;
        return result::generated;
    };

    // Packs kp_ and writes it through the already-open handle.
    // NOTE(review): the handle was read from before this write; confirm
    // utility::file::write overwrites rather than appends.
    auto commit = [this](auto const& handle)
            -> tl::expected<tl::monostate, std::errc> {
        // The 4 extra bytes leave room for the msgpack framing of both keys.
        std::vector<std::byte> buff(4 + sizeof(kp_.pk) + sizeof(kp_.sk));
        std::span<std::byte> pk(
                reinterpret_cast<std::byte*>(kp_.pk), sizeof(kp_.pk));
        // Sized by the secret key itself, not by the public key.
        std::span<std::byte> sk(
                reinterpret_cast<std::byte*>(kp_.sk), sizeof(kp_.sk));
        msgpack::packer packer(buff);
        packer.pack(pk);
        packer.pack(sk);
        return utility::file::write(handle, packer.subspan()).map([](auto) {
            logger.info("Wrote new keys to disk");
            return tl::monostate{};
        });
    };

    auto load_or_generate_keys = [&load_key, &generate_keys, &commit](
                                         auto const& handle)
            -> tl::expected<tl::monostate, std::errc> {
        return utility::file::read<std::byte>(handle)
                .and_then(load_key)
                .or_else(generate_keys)
                .and_then([&handle, &commit](auto r)
                                  -> tl::expected<tl::monostate, std::errc> {
                    // Only freshly generated keys need to be persisted.
                    if (r == result::generated) {
                        return commit(handle);
                    }
                    return tl::monostate{};
                });
    };

    auto outcome =
            utility::file::open(filename, O_RDWR | O_CREAT, S_IRUSR | S_IWUSR)
                    .and_then(load_or_generate_keys);
    if (!keys_ready) {
        // The key file could not even be opened, so none of the steps above
        // ran. Fall back to an unpersisted keypair instead of leaving kp_
        // uninitialized.
        logger.warning(
                "Could not open {}, using an in-memory keypair", filename);
        hydro_kx_keygen(&kp_);
    }
    return outcome;
}
// Creates a session for the user named in `info`.
//
// The coroutine first posts itself onto session_strand_, so the session
// manager is only used from that strand. It then returns whatever
// session_mgr_.create_session() produced for info.user_id. `conn` is
// currently unused.
net::awaitable<tl::expected<proto::session*, proto::error>>
monolithic_server::create_session(std::shared_ptr<user_connection> const& conn,
        proto::connect_message const& info) {
    // Hop over to the session strand before touching the manager.
    co_await net::post(session_strand_,
            net::bind_executor(session_strand_, use_awaitable));

    auto created = session_mgr_.create_session(info.user_id);
    logger.info("Created session: {}", created);
    co_return created;
}
std::unique_ptr<server> parselink::make_server(std::string_view address,
std::uint16_t user_port, std::uint16_t websocket_port) {
using impl = monolithic_server;
return std::make_unique<impl>(address, user_port, websocket_port);
}