// parselink-old/source/server.cpp
//
// 362 lines
// 12 KiB
// C++

//-----------------------------------------------------------------------------
// ___ __ _ _
// / _ \__ _ _ __ ___ ___ / /(_)_ __ | | __
// / /_)/ _` | '__/ __|/ _ \/ / | | '_ \| |/ /
// / ___/ (_| | | \__ \ __/ /__| | | | | <
// \/ \__,_|_| |___/\___\____/_|_| |_|_|\_\ .
//
//-----------------------------------------------------------------------------
// Author: Kurt Sassenrath
// Module: Server
//
// Server implementation. Currently, a monolithic server which:
// * Communicates with users via TCP (msgpack).
// * Runs the websocket server for overlays to read.
//
// Copyright (c) 2023 Kurt Sassenrath.
//
// License TBD.
//-----------------------------------------------------------------------------
#include "parselink/server.h"
#include "parselink/logging.h"
#include "parselink/msgpack/token/reader.h"
#include "parselink/msgpack/token/views.h"
#include "parselink/proto/session.h"
#include <boost/asio/as_tuple.hpp>
#include <boost/asio/co_spawn.hpp>
#include <boost/asio/deferred.hpp>
#include <boost/asio/detached.hpp>
#include <boost/asio/error.hpp>
#include <boost/asio/io_context.hpp>
#include <boost/asio/ip/address.hpp>
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/redirect_error.hpp>
#include <boost/asio/signal_set.hpp>
#include <boost/asio/steady_timer.hpp>
#include <boost/asio/write.hpp>
#include <algorithm>
#include <array>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <span>
#include <string>
#include <vector>
#include <fmt/ranges.h>
using namespace parselink;
using namespace std::chrono_literals;
namespace net = boost::asio;
using net::awaitable;
using net::co_spawn;
using net::deferred;
using net::detached;
using net::use_awaitable;
//-----------------------------------------------------------------------------
// TODO(ksassenrath): These are logging formatters for various boost/asio types.
// Not all code is exposed to them, so they cannot be defined inside the
// generic logging/formatters.h header. They should go somewhere else.
//-----------------------------------------------------------------------------
// Associates the fire_brick colour with boost error codes in parselink's
// logging theme system.
template <>
struct parselink::logging::theme<boost::system::error_code>
        : parselink::logging::static_theme<fmt::color::fire_brick> {};

// Formats a boost::system::error_code as its human-readable message. Spec
// parsing is inherited from the string_view formatter, so width/alignment
// specs apply to the message text.
template <>
struct fmt::formatter<boost::system::error_code>
        : fmt::formatter<std::string_view> {
    template <typename FormatContext>
    constexpr auto format(auto const& code, FormatContext& ctx) const {
        using base = fmt::formatter<std::string_view>;
        return base::format(code.message(), ctx);
    }
};
// Formats a msgpack token as `<msgpack TYPE = VALUE>`.
//
// Scalars show their decoded value, maps/arrays show their arity, and nil /
// invalid tokens show a fixed marker. Every value is wrapped in themed_arg so
// the logging theme colours it. No format specifiers are supported.
template <>
struct fmt::formatter<msgpack::token> {
    template <typename ParseContext>
    constexpr auto parse(ParseContext& ctx) -> decltype(ctx.begin()) {
        return ctx.begin();
    }

    template <typename FormatContext>
    auto format(msgpack::token const& v, FormatContext& ctx) const {
        using parselink::logging::themed_arg;
        auto out = fmt::format_to(
                ctx.out(), "<msgpack {} = ", themed_arg(v.type()));
        switch (v.type()) {
            case msgpack::format::type::unsigned_int:
                // Keep the returned iterator, as every other case does.
                out = fmt::format_to(
                        out, "{}", themed_arg(*(v.get<std::uint64_t>())));
                break;
            case msgpack::format::type::signed_int:
                // Signed tokens must be read back as a signed type. Asking
                // for uint64_t presumably fails for negative values, and
                // dereferencing the failed expected would be UB.
                out = fmt::format_to(
                        out, "{}", themed_arg(*(v.get<std::int64_t>())));
                break;
            case msgpack::format::type::boolean:
                out = fmt::format_to(out, "{}", themed_arg(*(v.get<bool>())));
                break;
            case msgpack::format::type::string:
                out = fmt::format_to(
                        out, "{}", themed_arg(*(v.get<std::string_view>())));
                break;
            case msgpack::format::type::binary:
                out = fmt::format_to(out, "{}",
                        themed_arg(*(v.get<std::span<std::byte const>>())));
                break;
            case msgpack::format::type::map:
                out = fmt::format_to(out, "(arity: {})",
                        themed_arg(v.get<msgpack::map_desc>()->count));
                break;
            case msgpack::format::type::array:
                out = fmt::format_to(out, "(arity: {})",
                        themed_arg(v.get<msgpack::array_desc>()->count));
                break;
            case msgpack::format::type::nil:
                out = fmt::format_to(out, "(nil)");
                break;
            case msgpack::format::type::invalid:
                out = fmt::format_to(out, "(invalid)");
                break;
            default: break;
        }
        return fmt::format_to(out, ">");
    }
};
// Matches endpoint-like types (such as net::ip::tcp::endpoint) that expose
// address() and port() on a const object.
template <typename T>
concept endpoint = requires(T const& ep) {
    ep.address();
    ep.port();
};

// Associates the coral colour with endpoints in the logging theme system.
template <endpoint T>
struct parselink::logging::theme<T>
        : parselink::logging::static_theme<fmt::color::coral> {};

// Formats an endpoint as "address:port". Specs are parsed by the inherited
// string_view formatter but are not applied to the output.
template <endpoint T>
struct fmt::formatter<T> : fmt::formatter<std::string_view> {
    template <typename FormatContext>
    constexpr auto format(auto const& ep, FormatContext& ctx) const {
        return fmt::format_to(ctx.out(), "{}:{}",
                ep.address().to_string(), ep.port());
    }
};
//-----------------------------------------------------------------------------
// End formatters
//-----------------------------------------------------------------------------
namespace {

// Logger used by everything in this translation unit.
logging::logger logger("server");

// Completion tokens that wrap results with net::as_tuple. The error_code
// arrives as the first tuple element instead of being thrown, so call sites
// branch on it explicitly.
constexpr auto no_ex_coro{net::as_tuple(use_awaitable)};
constexpr auto no_ex_defer{net::as_tuple(deferred)};

} // namespace
// State tracked for one established user session.
struct user_session {
    // Identifier of the user that owns this session.
    std::string user_id;
    // Opaque session identifier. Value-initialized (all zero bytes) so a
    // default-constructed session never carries indeterminate bytes.
    std::array<std::byte, 32> session_id{};
    // Point in time after which the session should be considered expired.
    std::chrono::system_clock::time_point expires_at{};
};
// Server implementation that runs everything in one process on a single
// io_context. The visible code only starts the user TCP listener; the
// websocket port is stored but not yet used here.
class monolithic_server : public server {
public:
    // Parses `address` as an IP address and records the ports. No sockets
    // are opened until run().
    monolithic_server(std::string_view address, std::uint16_t user_port,
            std::uint16_t websocket_port);

    // Runs the event loop until it is stopped (SIGINT/SIGTERM).
    std::error_code run() noexcept override;

private:
    // user_connection is the intended caller of the private
    // establish_session() (its call site is currently commented out), so it
    // is the class that needs access. This replaces `friend user_session;`,
    // which befriended an aggregate with no member functions and therefore
    // granted nothing.
    friend class user_connection;

    // Accept loop for user TCP connections on addr_:user_port_.
    awaitable<void> user_listen();

    // Stores a user session in active_user_sessions_.
    user_session* establish_session(user_session session);

    // Active sessions keyed by user ID. std::less<> makes the comparator
    // transparent, so lookups can use std::string_view keys.
    std::map<std::string, user_session, std::less<>> active_user_sessions_;
    net::io_context io_context_;
    net::ip::address addr_;
    std::uint16_t user_port_;
    std::uint16_t websocket_port_;
};
// One accepted user TCP connection.
//
// Instances must be owned by a std::shared_ptr: start() passes a
// shared_from_this() copy into the spawned coroutine, which keeps the
// connection alive while the coroutine runs.
class user_connection : public std::enable_shared_from_this<user_connection> {
public:
    user_connection(monolithic_server& server, net::ip::tcp::socket sock)
        : server_(server)
        , socket_(std::move(sock)) {}

    // Logs the closure and tears the socket down.
    //
    // Every socket call uses its error_code overload. The throwing overloads
    // can fail once the peer is gone (remote_endpoint() may report e.g.
    // ENOTCONN). An exception escaping this implicitly-noexcept destructor
    // would call std::terminate. On failure a default-constructed endpoint is
    // logged.
    ~user_connection() {
        boost::system::error_code ec;
        auto const remote = socket_.remote_endpoint(ec);
        logger.debug("Connection to {} closed.", remote);
        socket_.shutdown(net::ip::tcp::socket::shutdown_both, ec);
        socket_.close(ec);
    }

    // Logs the new peer and spawns await_connect() on the socket's executor.
    void start() {
        // Non-throwing overload: start() is called from the accept loop, and
        // a peer that already disconnected must not throw out of it.
        boost::system::error_code ec;
        logger.debug("New connection from {}", socket_.remote_endpoint(ec));
        co_spawn(
                socket_.get_executor(),
                [self = shared_from_this()] { return self->await_connect(); },
                detached);
    }

    // Parses a "prs" magic string followed by a packet size from `data`.
    //
    // Returns a vector holding the bytes that follow the header, with its
    // capacity reserved for the announced packet size. Returns the msgpack
    // error on failure.
    tl::expected<std::vector<std::byte>, msgpack::error> parse_header(
            std::span<std::byte> data) noexcept {
        auto reader = msgpack::token_reader(data);
        auto magic = reader.read_one().map([](auto t) { return t == "prs"; });
        if (!magic) {
            logger.error("Failed to get magic from client: {}", magic);
            return tl::unexpected(magic.error());
        }
        if (!*magic) {
            // A token was read but it is not the magic string. `magic` holds
            // a value here, so calling magic.error() would be UB.
            // NOTE(review): no dedicated "bad magic" enumerator is visible
            // from this file; a value-initialized msgpack::error is a
            // placeholder, so replace it with the proper enumerator.
            logger.error("Failed to get magic from client: {}", magic);
            return tl::unexpected(msgpack::error{});
        }
        logger.debug("Got magic from client");

        auto sz = reader.read_one().and_then(
                [](auto t) { return t.template get<std::size_t>(); });
        if (!sz) {
            logger.debug("Failed to get packet size from client: {}", sz);
            // Previously returned magic.error() here, which is UB because
            // `magic` holds a value at this point.
            return tl::unexpected(sz.error());
        }
        if (*sz == 0) {
            // NOTE(review): placeholder error, as for the magic mismatch.
            logger.debug("Failed to get packet size from client: {}", sz);
            return tl::unexpected(msgpack::error{});
        }
        logger.debug("Got packet size from client: {}", *sz);

        // Copy the rest of the message to the buffer for parsing.
        // TODO(ksassenrath): Replace vector with custom buffer.
        std::vector<std::byte> msg;
        msg.reserve(*sz);
        msg.resize(reader.remaining());
        std::copy(reader.current(), reader.end(), msg.begin());
        return msg;
    }

    // Reads from the socket until `buffer` is completely filled.
    //
    // Returns the socket error if a read fails or returns zero bytes. In the
    // zero-byte case the error code may be a success value.
    awaitable<tl::expected<std::monostate, boost::system::error_code>>
    buffer_message(std::span<std::byte> buffer) noexcept {
        std::size_t amt = 0;
        while (amt < buffer.size()) {
            auto subsp = buffer.subspan(amt);
            auto [ec, n] = co_await socket_.async_read_some(
                    net::buffer(subsp), no_ex_coro);
            logger.debug("Read {} bytes, total is now {}", n, amt + n);
            if (ec || n == 0) {
                logger.error("Reading from user socket failed: {}", ec);
                co_return tl::make_unexpected(ec);
            }
            amt += n;
        }
        co_return std::monostate{};
    }

    // Connection coroutine: reads the header, buffers the full message body,
    // then hands the parsed msgpack tokens to proto::session::handle_connect.
    awaitable<void> await_connect() noexcept {
        // Use a small buffer on the stack to read the initial header.
        std::array<std::byte, 8> buffer;
        auto [ec, n] = co_await socket_.async_read_some(
                net::buffer(buffer), no_ex_coro);
        if (ec) {
            logger.error("Reading from user socket failed: {}", ec);
            co_return;
        }
        logger.debug("Read {} bytes from client: {}", n,
                std::span(buffer.data(), n));
        proto::session session;
        auto maybe_hdr = session.parse_header(std::span(buffer.data(), n));
        if (!maybe_hdr) {
            logger.error("Unable to parse header: {}", maybe_hdr.error());
            co_return;
        }
        // TODO(ksassenrath): Replace with specialized allocator.
        auto msg = std::vector<std::byte>(maybe_hdr->message_size);
        // The initial read may already hold the start of the body. Clamp both
        // ends: the header cannot extend past what was read, and a client
        // that sent more than message_size bytes must not overflow `msg`.
        std::size_t const body_start =
                std::min<std::size_t>(maybe_hdr->bytes_read, n);
        std::size_t const carried = std::min(n - body_start, msg.size());
        std::copy_n(buffer.data() + body_start, carried, msg.begin());
        // Read the remainder of the body directly from the socket.
        auto msg_span = std::span(msg).subspan(carried);
        if (auto result = co_await buffer_message(msg_span); !result) {
            logger.error("Unable to read message body: {}", result.error());
            co_return;
        }
        auto reader = msgpack::token_reader(msg);
        std::array<msgpack::token, 32> tokens;
        auto tokens_used = reader.read_some(tokens);
        if (tokens_used) {
            // NOTE(review): returns when handle_connect's result is falsy;
            // confirm that falsy means failure for its return type.
            auto err = session.handle_connect(*tokens_used);
            if (!err) co_return;
        } else {
            logger.error("Unable to parse msgpack tokens: {}", tokens_used);
            co_return;
        }
        // Authenticate against database.
        logger.debug("User {} established connection", session.user_id);
        // session_ = server_.establish_session(std::move(*maybe_user));
    }

    // Connection states; not referenced in this file yet.
    enum class state { init, authenticated, active };

    monolithic_server& server_;
    user_session* session_{};
    net::ip::tcp::socket socket_;
};
// Parses `address` into an IP address and stores the listening ports.
//
// net::ip::make_address replaces the deprecated address::from_string, which
// newer Boost releases no longer provide. Like from_string, it throws
// boost::system::system_error when `address` is not a valid IP address.
// io_context_ is created with a concurrency hint of 1.
monolithic_server::monolithic_server(std::string_view address,
        std::uint16_t user_port, std::uint16_t websocket_port)
    : io_context_{1}
    , addr_(net::ip::make_address(std::string{address}))
    , user_port_{user_port}
    , websocket_port_{websocket_port} {
    logger.debug("Creating monolithic_server(address = {}, user_port = {}, "
                 "websocket_port = {})",
            address, user_port_, websocket_port_);
}
// Accepts user TCP connections on addr_:user_port_ and hands each accepted
// socket to a new shared user_connection.
//
// Accept errors are received as values (no_ex_coro) instead of exceptions.
// This coroutine is co_spawn'ed with `detached`, so a thrown accept error
// would end it and the server would silently stop accepting connections.
awaitable<void> monolithic_server::user_listen() {
    auto exec = co_await net::this_coro::executor;
    net::ip::tcp::acceptor acceptor{exec, {addr_, user_port_}};
    while (true) {
        auto [ec, sock] = co_await acceptor.async_accept(no_ex_coro);
        if (ec) {
            if (ec == net::error::operation_aborted) {
                // The acceptor was cancelled or closed; stop listening.
                co_return;
            }
            logger.error("Failed to accept user connection: {}", ec);
            // Back off briefly so a persistent error (e.g. descriptor
            // exhaustion) does not turn the loop into a busy spin.
            net::steady_timer backoff{exec, 100ms};
            co_await backoff.async_wait(no_ex_coro);
            continue;
        }
        std::make_shared<user_connection>(*this, std::move(sock))->start();
    }
}
std::error_code monolithic_server::run() noexcept {
logger.info("Starting server.");
net::signal_set signals(io_context_, SIGINT, SIGTERM);
signals.async_wait([&](auto, auto) {
logger.info("Received signal... Shutting down.");
io_context_.stop();
});
co_spawn(io_context_, user_listen(), detached);
io_context_.run();
return {};
}
// Registers `session` under its user ID.
//
// Returns a pointer to the stored session when it was newly inserted. Returns
// nullptr if a session for that user ID already exists; the existing session
// is left untouched. std::map never relocates its nodes, so the pointer stays
// valid until the entry is erased.
//
// The previous version moved only `user_id` into the map (dropping
// session_id/expires_at), relied on parenthesized aggregate init, and always
// returned nullptr.
user_session* monolithic_server::establish_session(user_session session) {
    // Copy the key first: `session` itself is moved into the map below.
    std::string key = session.user_id;
    auto [it, inserted] = active_user_sessions_.try_emplace(
            std::move(key), std::move(session));
    if (!inserted) {
        // NOTE(review): the policy for a duplicate session (reject, replace,
        // or refresh) is not decided anywhere visible; this rejects it.
        return nullptr;
    }
    return &it->second;
}
std::unique_ptr<server> parselink::make_server(std::string_view address,
std::uint16_t user_port, std::uint16_t websocket_port) {
using impl = monolithic_server;
return std::make_unique<impl>(address, user_port, websocket_port);
}