This commit is contained in:
Kurt Sassenrath 2023-10-26 07:23:51 -07:00
parent 6874da27a3
commit 8f8066c243
6 changed files with 112 additions and 91 deletions

30
.snippets/cpp.lua Normal file
View File

@ -0,0 +1,30 @@
local luasnip = require('luasnip')
local fmt = require('luasnip.extras.fmt').fmt
local header = [[
//-----------------------------------------------------------------------------
// ___ __ _ _
// / _ \__ _ _ __ ___ ___ / /(_)_ __ | | __
// / /_)/ _` | '__/ __|/ _ \/ / | | '_ \| |/ /
// / ___/ (_| | | \__ \ __/ /__| | | | | <
// \/ \__,_|_| |___/\___\____/_|_| |_|_|\_\ .
//
//-----------------------------------------------------------------------------
// Author: Kurt Sassenrath
// Module: {}
//
// {}
//
// Copyright (c) 2023 Kurt Sassenrath.
//
// License TBD.
//-----------------------------------------------------------------------------
]]
return {
s("hdr", fmt(header, {i(1, "<module>"), i(2, "<description>")}))
}

View File

@ -91,7 +91,7 @@ struct fmt::formatter<std::errc> : fmt::formatter<std::error_code> {
template <typename FormatContext>
auto format(std::errc const& v, FormatContext& ctx) const {
return fmt::formatter<std::error_code>::format(
std::make_error_code(v), ctx);
std::make_error_code(v), ctx);
}
};

View File

@ -31,6 +31,7 @@
#include "level.h"
#include "traits.h"
#include <span>
#include <system_error>
#include <type_traits>

View File

@ -42,27 +42,19 @@ struct header_info {
std::uint32_t bytes_read; // How many bytes of the buffer were used.
};
class session {
public:
session() = default;
// Parse the protocol header out of a buffer. This is a member function
// as we may choose to omit data once a session is established.
tl::expected<header_info, error> parse_header(
std::span<std::byte const> buffer) noexcept;
tl::expected<std::monostate, error> handle_connect(
std::span<msgpack::token> tokens) noexcept;
// The maximum size of a single message. Should probably depend on whether
// the session is established or not.
constexpr static std::uint32_t max_message_size = 128 * 1024;
std::string user_id;
private:
struct connect_info {
std::uint32_t version;
std::string_view user_id;
std::span<std::byte const> session_id;
};
// Parse the protocol header out of a buffer.
tl::expected<header_info, error> parse_header(
std::span<std::byte const> buffer) noexcept;
tl::expected<connect_info, error> parse_connect(
std::span<msgpack::token> tokens) noexcept;
} // namespace proto
} // namespace parselink

View File

@ -17,7 +17,6 @@
//
// License TBD.
//-----------------------------------------------------------------------------
#include "parselink/proto/session.h"
#include "parselink/logging.h"
@ -82,9 +81,19 @@ struct fmt::formatter<msgpack::token> {
namespace {
logging::logger logger("session");
constexpr std::uint32_t max_size = 128 * 1024;
constexpr tl::expected<std::uint32_t, error> check_support(
tl::expected<msgpack::token, error> const& val) {
if (val == 1u) {
return *(*val).get<std::uint32_t>();
}
return tl::make_unexpected(error::unsupported);
}
tl::expected<header_info, error> session::parse_header(
} // namespace
tl::expected<header_info, error> proto::parse_header(
std::span<std::byte const> buffer) noexcept {
auto reader = msgpack::token_reader(buffer);
auto magic = reader.read_one();
@ -101,8 +110,8 @@ tl::expected<header_info, error> session::parse_header(
return tl::unexpected(magic ? error::bad_data : error::incomplete);
}
if (*size > max_message_size) {
logger.error("Message {} exceeds max size {}", *size, max_message_size);
if (*size > max_size) {
logger.error("Message size {} exceeds max {}", *size, max_size);
return tl::unexpected(error::too_large);
}
@ -110,23 +119,30 @@ tl::expected<header_info, error> session::parse_header(
return tl::expected<header_info, error>(tl::in_place, *size, amt);
}
tl::expected<std::monostate, error> session::handle_connect(
constexpr tl::expected<msgpack::token, error> lookup(
auto const& map_view, auto const& key) {
auto find_key = [&key](auto const& kv) { return kv.first == key; };
if (auto field = std::ranges::find_if(map_view, find_key);
field != map_view.end()) {
return (*field).second;
}
return tl::make_unexpected(error::bad_data);
}
tl::expected<connect_info, error> proto::parse_connect(
std::span<msgpack::token> tokens) noexcept {
auto message_type = tokens.begin()->get<std::string_view>();
if (message_type && message_type == "connect") {
logger.debug("Received '{}' packet. Parsing body", *message_type);
auto map = msgpack::map_view(tokens.subspan(1));
constexpr auto find_version = [](auto const& kv) {
return kv.first == "version";
};
connect_info info;
auto version = lookup(map, "version").and_then(check_support);
auto field = std::ranges::find_if(map, find_version);
if (field != map.end() && (*field).second == 1u) {
logger.debug("Version {}", (*field).second.get<std::uint32_t>());
if (version) {
info.version = *version;
} else {
logger.error("connect failed: missing / unsupported version");
return tl::make_unexpected(error::unsupported);
return tl::make_unexpected(version.error());
}
for (auto const& [k, v] : msgpack::map_view(tokens.subspan(1))) {
@ -135,19 +151,21 @@ tl::expected<std::monostate, error> session::handle_connect(
logger.error("connect failed: non-string key {}", k);
return tl::make_unexpected(error::bad_data);
}
if (k == "user_id") {
result = v.get<std::string>().map(
[this](auto uid) { user_id = std::move(uid); });
result = v.get<std::string_view>().map(
[&info](auto uid) { info.user_id = uid; });
}
if (!result) {
logger.error(
"connect failed: {} -> {}: {}", k, v, result.error());
return tl::make_unexpected(error::bad_data);
}
}
return info;
} else {
logger.error("Did not get message type: {}", message_type.error());
return tl::make_unexpected(error::bad_data);
}
return {};
}

View File

@ -25,6 +25,7 @@
#include "parselink/msgpack/token/views.h"
#include "parselink/proto/session.h"
#include <boost/asio/as_tuple.hpp>
#include <boost/asio/bind_executor.hpp>
#include <boost/asio/co_spawn.hpp>
#include <boost/asio/deferred.hpp>
#include <boost/asio/detached.hpp>
@ -33,12 +34,15 @@
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/redirect_error.hpp>
#include <boost/asio/signal_set.hpp>
#include <boost/asio/strand.hpp>
#include <boost/asio/write.hpp>
#include <chrono>
#include <map>
#include <fmt/ranges.h>
using namespace parselink;
#include <fmt/ranges.h>
using namespace parselink;
using namespace std::chrono_literals;
@ -148,7 +152,10 @@ constexpr auto no_ex_coro = net::as_tuple(use_awaitable);
constexpr auto no_ex_defer = net::as_tuple(deferred);
} // namespace
class user_connection;
struct user_session {
std::weak_ptr<user_connection> conn;
std::string user_id;
std::array<std::byte, 32> session_id;
std::chrono::system_clock::time_point expires_at;
@ -161,16 +168,19 @@ public:
std::error_code run() noexcept override;
net::awaitable<tl::expected<tl::monostate, proto::error>> create_session(
std::weak_ptr<user_connection> conn,
proto::connect_info const& info);
private:
friend user_session;
awaitable<void> user_listen();
user_session* establish_session(user_session session);
std::map<std::string, user_session, std::less<>> active_user_sessions_;
std::map<std::string, user_session, std::less<>> active_sessions_;
net::io_context io_context_;
net::io_context::strand session_strand_;
net::ip::address addr_;
std::uint16_t user_port_;
std::uint16_t websocket_port_;
@ -197,36 +207,6 @@ public:
detached);
}
tl::expected<std::vector<std::byte>, msgpack::error> parse_header(
std::span<std::byte> data) noexcept {
auto reader = msgpack::token_reader(data);
auto magic = reader.read_one().map([](auto t) { return t == "prs"; });
if (magic && *magic) {
logger.debug("Got magic from client");
} else {
logger.error("Failed to get magic from client: {}", magic);
return tl::unexpected(magic.error());
}
auto sz = reader.read_one().and_then(
[](auto t) { return t.template get<std::size_t>(); });
if (sz && *sz) {
logger.debug("Got packet size from client: {}", *sz);
} else {
logger.debug("Failed to get packet size from client: {}", sz);
return tl::unexpected(magic.error());
}
// Copy the rest of the message to the buffer for parsing.
// TODO(ksassenrath): Replace vector with custom buffer.
std::vector<std::byte> msg;
msg.reserve(*sz);
msg.resize(reader.remaining());
std::copy(reader.current(), reader.end(), msg.begin());
return msg;
}
awaitable<tl::expected<std::monostate, boost::system::error_code>>
buffer_message(std::span<std::byte> buffer) noexcept {
std::size_t amt = 0;
@ -259,9 +239,7 @@ public:
logger.debug("Read {} bytes from client: {}", n,
std::span(buffer.data(), n));
proto::session session;
auto maybe_hdr = session.parse_header(std::span(buffer.data(), n));
auto maybe_hdr = proto::parse_header(std::span(buffer.data(), n));
if (!maybe_hdr) {
logger.error("Unable to parse header: {}", maybe_hdr.error());
@ -286,30 +264,30 @@ public:
auto reader = msgpack::token_reader(msg);
std::array<msgpack::token, 32> tokens;
auto tokens_used = reader.read_some(tokens);
if (tokens_used) {
auto err = session.handle_connect(*tokens_used);
if (!err) co_return;
} else {
logger.error("Unable to parse msgpack tokens: {}", tokens_used);
auto session =
reader.read_some(tokens)
.map_error([](auto) { return proto::error::bad_data; })
.and_then(proto::parse_connect);
if (!session) {
logger.error("Session failed: {}", session.error());
co_return;
}
// Authenticate against database.
logger.debug("User {} established connection", session.user_id);
// session_ = server_.establish_session(std::move(*maybe_user));
co_await server_.create_session(
std::weak_ptr<user_connection>(shared_from_this()), *session);
}
enum class state { init, authenticated, active };
monolithic_server& server_;
user_session* session_{};
net::ip::tcp::socket socket_;
};
monolithic_server::monolithic_server(std::string_view address,
std::uint16_t user_port, std::uint16_t websocket_port)
: io_context_{1}
, session_strand_(io_context_)
, addr_(net::ip::address::from_string(std::string{address}))
, user_port_{user_port}
, websocket_port_{websocket_port} {
@ -344,14 +322,16 @@ std::error_code monolithic_server::run() noexcept {
return {};
}
user_session* monolithic_server::establish_session(user_session session) {
auto existing_session = active_user_sessions_.find(session.user_id);
if (existing_session == active_user_sessions_.end()) {
// No session exists with that user ID yet.
active_user_sessions_.emplace(
session.user_id, std::move(session.user_id));
}
return {};
net::awaitable<tl::expected<tl::monostate, proto::error>>
monolithic_server::create_session(std::weak_ptr<user_connection> conn,
proto::connect_info const& session) {
// Move to session strand.
co_await net::post(session_strand_,
net::bind_executor(session_strand_, use_awaitable));
// Pretend that there's no open
co_return tl::make_unexpected(proto::error::unsupported);
}
std::unique_ptr<server> parselink::make_server(std::string_view address,