This commit is contained in:
Kurt Sassenrath 2023-10-26 07:23:51 -07:00
parent 6874da27a3
commit 8f8066c243
6 changed files with 112 additions and 91 deletions

30
.snippets/cpp.lua Normal file
View File

@ -0,0 +1,30 @@
local luasnip = require('luasnip')
local fmt = require('luasnip.extras.fmt').fmt
local header = [[
//-----------------------------------------------------------------------------
// ___ __ _ _
// / _ \__ _ _ __ ___ ___ / /(_)_ __ | | __
// / /_)/ _` | '__/ __|/ _ \/ / | | '_ \| |/ /
// / ___/ (_| | | \__ \ __/ /__| | | | | <
// \/ \__,_|_| |___/\___\____/_|_| |_|_|\_\ .
//
//-----------------------------------------------------------------------------
// Author: Kurt Sassenrath
// Module: {}
//
// {}
//
// Copyright (c) 2023 Kurt Sassenrath.
//
// License TBD.
//-----------------------------------------------------------------------------
]]
return {
s("hdr", fmt(header, {i(1, "<module>"), i(2, "<description>")}))
}

View File

@ -91,7 +91,7 @@ struct fmt::formatter<std::errc> : fmt::formatter<std::error_code> {
template <typename FormatContext> template <typename FormatContext>
auto format(std::errc const& v, FormatContext& ctx) const { auto format(std::errc const& v, FormatContext& ctx) const {
return fmt::formatter<std::error_code>::format( return fmt::formatter<std::error_code>::format(
std::make_error_code(v), ctx); std::make_error_code(v), ctx);
} }
}; };

View File

@ -31,6 +31,7 @@
#include "level.h" #include "level.h"
#include "traits.h" #include "traits.h"
#include <span>
#include <system_error> #include <system_error>
#include <type_traits> #include <type_traits>

View File

@ -42,27 +42,19 @@ struct header_info {
std::uint32_t bytes_read; // How many bytes of the buffer were used. std::uint32_t bytes_read; // How many bytes of the buffer were used.
}; };
class session { struct connect_info {
public: std::uint32_t version;
session() = default; std::string_view user_id;
std::span<std::byte const> session_id;
// Parse the protocol header out of a buffer. This is a member function
// as we may choose to omit data once a session is established.
tl::expected<header_info, error> parse_header(
std::span<std::byte const> buffer) noexcept;
tl::expected<std::monostate, error> handle_connect(
std::span<msgpack::token> tokens) noexcept;
// The maximum size of a single message. Should probably depend on whether
// the session is established or not.
constexpr static std::uint32_t max_message_size = 128 * 1024;
std::string user_id;
private:
}; };
// Parse the protocol header out of a buffer.
tl::expected<header_info, error> parse_header(
std::span<std::byte const> buffer) noexcept;
tl::expected<connect_info, error> parse_connect(
std::span<msgpack::token> tokens) noexcept;
} // namespace proto } // namespace proto
} // namespace parselink } // namespace parselink

View File

@ -17,7 +17,6 @@
// //
// License TBD. // License TBD.
//----------------------------------------------------------------------------- //-----------------------------------------------------------------------------
#include "parselink/proto/session.h" #include "parselink/proto/session.h"
#include "parselink/logging.h" #include "parselink/logging.h"
@ -82,9 +81,19 @@ struct fmt::formatter<msgpack::token> {
namespace { namespace {
logging::logger logger("session"); logging::logger logger("session");
constexpr std::uint32_t max_size = 128 * 1024;
constexpr tl::expected<std::uint32_t, error> check_support(
tl::expected<msgpack::token, error> const& val) {
if (val == 1u) {
return *(*val).get<std::uint32_t>();
}
return tl::make_unexpected(error::unsupported);
} }
tl::expected<header_info, error> session::parse_header( } // namespace
tl::expected<header_info, error> proto::parse_header(
std::span<std::byte const> buffer) noexcept { std::span<std::byte const> buffer) noexcept {
auto reader = msgpack::token_reader(buffer); auto reader = msgpack::token_reader(buffer);
auto magic = reader.read_one(); auto magic = reader.read_one();
@ -101,8 +110,8 @@ tl::expected<header_info, error> session::parse_header(
return tl::unexpected(magic ? error::bad_data : error::incomplete); return tl::unexpected(magic ? error::bad_data : error::incomplete);
} }
if (*size > max_message_size) { if (*size > max_size) {
logger.error("Message {} exceeds max size {}", *size, max_message_size); logger.error("Message size {} exceeds max {}", *size, max_size);
return tl::unexpected(error::too_large); return tl::unexpected(error::too_large);
} }
@ -110,23 +119,30 @@ tl::expected<header_info, error> session::parse_header(
return tl::expected<header_info, error>(tl::in_place, *size, amt); return tl::expected<header_info, error>(tl::in_place, *size, amt);
} }
tl::expected<std::monostate, error> session::handle_connect( constexpr tl::expected<msgpack::token, error> lookup(
auto const& map_view, auto const& key) {
auto find_key = [&key](auto const& kv) { return kv.first == key; };
if (auto field = std::ranges::find_if(map_view, find_key);
field != map_view.end()) {
return (*field).second;
}
return tl::make_unexpected(error::bad_data);
}
tl::expected<connect_info, error> proto::parse_connect(
std::span<msgpack::token> tokens) noexcept { std::span<msgpack::token> tokens) noexcept {
auto message_type = tokens.begin()->get<std::string_view>(); auto message_type = tokens.begin()->get<std::string_view>();
if (message_type && message_type == "connect") { if (message_type && message_type == "connect") {
logger.debug("Received '{}' packet. Parsing body", *message_type); logger.debug("Received '{}' packet. Parsing body", *message_type);
auto map = msgpack::map_view(tokens.subspan(1)); auto map = msgpack::map_view(tokens.subspan(1));
constexpr auto find_version = [](auto const& kv) { connect_info info;
return kv.first == "version"; auto version = lookup(map, "version").and_then(check_support);
};
auto field = std::ranges::find_if(map, find_version); if (version) {
if (field != map.end() && (*field).second == 1u) { info.version = *version;
logger.debug("Version {}", (*field).second.get<std::uint32_t>());
} else { } else {
logger.error("connect failed: missing / unsupported version"); return tl::make_unexpected(version.error());
return tl::make_unexpected(error::unsupported);
} }
for (auto const& [k, v] : msgpack::map_view(tokens.subspan(1))) { for (auto const& [k, v] : msgpack::map_view(tokens.subspan(1))) {
@ -135,19 +151,21 @@ tl::expected<std::monostate, error> session::handle_connect(
logger.error("connect failed: non-string key {}", k); logger.error("connect failed: non-string key {}", k);
return tl::make_unexpected(error::bad_data); return tl::make_unexpected(error::bad_data);
} }
if (k == "user_id") { if (k == "user_id") {
result = v.get<std::string>().map( result = v.get<std::string_view>().map(
[this](auto uid) { user_id = std::move(uid); }); [&info](auto uid) { info.user_id = uid; });
} }
if (!result) { if (!result) {
logger.error( logger.error(
"connect failed: {} -> {}: {}", k, v, result.error()); "connect failed: {} -> {}: {}", k, v, result.error());
return tl::make_unexpected(error::bad_data); return tl::make_unexpected(error::bad_data);
} }
} }
return info;
} else { } else {
logger.error("Did not get message type: {}", message_type.error()); logger.error("Did not get message type: {}", message_type.error());
return tl::make_unexpected(error::bad_data); return tl::make_unexpected(error::bad_data);
} }
return {};
} }

View File

@ -25,6 +25,7 @@
#include "parselink/msgpack/token/views.h" #include "parselink/msgpack/token/views.h"
#include "parselink/proto/session.h" #include "parselink/proto/session.h"
#include <boost/asio/as_tuple.hpp> #include <boost/asio/as_tuple.hpp>
#include <boost/asio/bind_executor.hpp>
#include <boost/asio/co_spawn.hpp> #include <boost/asio/co_spawn.hpp>
#include <boost/asio/deferred.hpp> #include <boost/asio/deferred.hpp>
#include <boost/asio/detached.hpp> #include <boost/asio/detached.hpp>
@ -33,12 +34,15 @@
#include <boost/asio/ip/tcp.hpp> #include <boost/asio/ip/tcp.hpp>
#include <boost/asio/redirect_error.hpp> #include <boost/asio/redirect_error.hpp>
#include <boost/asio/signal_set.hpp> #include <boost/asio/signal_set.hpp>
#include <boost/asio/strand.hpp>
#include <boost/asio/write.hpp> #include <boost/asio/write.hpp>
#include <chrono>
#include <map> #include <map>
#include <fmt/ranges.h> #include <fmt/ranges.h>
using namespace parselink;
#include <fmt/ranges.h>
using namespace parselink; using namespace parselink;
using namespace std::chrono_literals; using namespace std::chrono_literals;
@ -148,7 +152,10 @@ constexpr auto no_ex_coro = net::as_tuple(use_awaitable);
constexpr auto no_ex_defer = net::as_tuple(deferred); constexpr auto no_ex_defer = net::as_tuple(deferred);
} // namespace } // namespace
class user_connection;
struct user_session { struct user_session {
std::weak_ptr<user_connection> conn;
std::string user_id; std::string user_id;
std::array<std::byte, 32> session_id; std::array<std::byte, 32> session_id;
std::chrono::system_clock::time_point expires_at; std::chrono::system_clock::time_point expires_at;
@ -161,16 +168,19 @@ public:
std::error_code run() noexcept override; std::error_code run() noexcept override;
net::awaitable<tl::expected<tl::monostate, proto::error>> create_session(
std::weak_ptr<user_connection> conn,
proto::connect_info const& info);
private: private:
friend user_session; friend user_session;
awaitable<void> user_listen(); awaitable<void> user_listen();
user_session* establish_session(user_session session); std::map<std::string, user_session, std::less<>> active_sessions_;
std::map<std::string, user_session, std::less<>> active_user_sessions_;
net::io_context io_context_; net::io_context io_context_;
net::io_context::strand session_strand_;
net::ip::address addr_; net::ip::address addr_;
std::uint16_t user_port_; std::uint16_t user_port_;
std::uint16_t websocket_port_; std::uint16_t websocket_port_;
@ -197,36 +207,6 @@ public:
detached); detached);
} }
tl::expected<std::vector<std::byte>, msgpack::error> parse_header(
std::span<std::byte> data) noexcept {
auto reader = msgpack::token_reader(data);
auto magic = reader.read_one().map([](auto t) { return t == "prs"; });
if (magic && *magic) {
logger.debug("Got magic from client");
} else {
logger.error("Failed to get magic from client: {}", magic);
return tl::unexpected(magic.error());
}
auto sz = reader.read_one().and_then(
[](auto t) { return t.template get<std::size_t>(); });
if (sz && *sz) {
logger.debug("Got packet size from client: {}", *sz);
} else {
logger.debug("Failed to get packet size from client: {}", sz);
return tl::unexpected(magic.error());
}
// Copy the rest of the message to the buffer for parsing.
// TODO(ksassenrath): Replace vector with custom buffer.
std::vector<std::byte> msg;
msg.reserve(*sz);
msg.resize(reader.remaining());
std::copy(reader.current(), reader.end(), msg.begin());
return msg;
}
awaitable<tl::expected<std::monostate, boost::system::error_code>> awaitable<tl::expected<std::monostate, boost::system::error_code>>
buffer_message(std::span<std::byte> buffer) noexcept { buffer_message(std::span<std::byte> buffer) noexcept {
std::size_t amt = 0; std::size_t amt = 0;
@ -259,9 +239,7 @@ public:
logger.debug("Read {} bytes from client: {}", n, logger.debug("Read {} bytes from client: {}", n,
std::span(buffer.data(), n)); std::span(buffer.data(), n));
proto::session session; auto maybe_hdr = proto::parse_header(std::span(buffer.data(), n));
auto maybe_hdr = session.parse_header(std::span(buffer.data(), n));
if (!maybe_hdr) { if (!maybe_hdr) {
logger.error("Unable to parse header: {}", maybe_hdr.error()); logger.error("Unable to parse header: {}", maybe_hdr.error());
@ -286,30 +264,30 @@ public:
auto reader = msgpack::token_reader(msg); auto reader = msgpack::token_reader(msg);
std::array<msgpack::token, 32> tokens; std::array<msgpack::token, 32> tokens;
auto tokens_used = reader.read_some(tokens); auto session =
if (tokens_used) { reader.read_some(tokens)
auto err = session.handle_connect(*tokens_used); .map_error([](auto) { return proto::error::bad_data; })
if (!err) co_return; .and_then(proto::parse_connect);
} else {
logger.error("Unable to parse msgpack tokens: {}", tokens_used); if (!session) {
logger.error("Session failed: {}", session.error());
co_return; co_return;
} }
// Authenticate against database. co_await server_.create_session(
logger.debug("User {} established connection", session.user_id); std::weak_ptr<user_connection>(shared_from_this()), *session);
// session_ = server_.establish_session(std::move(*maybe_user));
} }
enum class state { init, authenticated, active }; enum class state { init, authenticated, active };
monolithic_server& server_; monolithic_server& server_;
user_session* session_{};
net::ip::tcp::socket socket_; net::ip::tcp::socket socket_;
}; };
monolithic_server::monolithic_server(std::string_view address, monolithic_server::monolithic_server(std::string_view address,
std::uint16_t user_port, std::uint16_t websocket_port) std::uint16_t user_port, std::uint16_t websocket_port)
: io_context_{1} : io_context_{1}
, session_strand_(io_context_)
, addr_(net::ip::address::from_string(std::string{address})) , addr_(net::ip::address::from_string(std::string{address}))
, user_port_{user_port} , user_port_{user_port}
, websocket_port_{websocket_port} { , websocket_port_{websocket_port} {
@ -344,14 +322,16 @@ std::error_code monolithic_server::run() noexcept {
return {}; return {};
} }
user_session* monolithic_server::establish_session(user_session session) { net::awaitable<tl::expected<tl::monostate, proto::error>>
auto existing_session = active_user_sessions_.find(session.user_id); monolithic_server::create_session(std::weak_ptr<user_connection> conn,
if (existing_session == active_user_sessions_.end()) { proto::connect_info const& session) {
// No session exists with that user ID yet. // Move to session strand.
active_user_sessions_.emplace( co_await net::post(session_strand_,
session.user_id, std::move(session.user_id)); net::bind_executor(session_strand_, use_awaitable));
}
return {}; // Pretend that there's no open
co_return tl::make_unexpected(proto::error::unsupported);
} }
std::unique_ptr<server> parselink::make_server(std::string_view address, std::unique_ptr<server> parselink::make_server(std::string_view address,