This commit is contained in:
Kurt Sassenrath 2023-10-16 19:27:12 -07:00
parent 1c7047e314
commit 9346b5be5d
4 changed files with 143 additions and 109 deletions

View File

@ -240,6 +240,9 @@ public:
}
}
template <std::size_t N>
constexpr bool operator==(char const (&t)[N]) const noexcept;
template <typename T>
constexpr bool operator==(T const& t) const noexcept {
if constexpr (std::equality_comparable<T>) {
@ -286,6 +289,13 @@ constexpr tl::expected<std::string_view, error> token_base<8>::get()
return std::string_view{value_.str, size_and_type_.get_size()};
}
template <std::size_t N>
constexpr bool token_base<8>::operator==(char const (&t)[N]) const noexcept {
auto result = get<std::string_view>().map([&t](auto v) {
return v == std::string_view{t}; });
return result && *result;
}
template<>
inline tl::expected<std::vector<std::byte>, error> token_base<8>::get()
const noexcept

View File

@ -25,6 +25,7 @@
#include "type.h"
#include <bits/iterator_concepts.h>
#include <fmt/format.h>
#include <ranges>
@ -35,6 +36,8 @@ namespace msgpack {
std::ranges::range_value_t<V>>
struct map_view : public std::ranges::view_interface<map_view<V>> {
public:
class sentinel;
class iterator {
friend class sentinel;
@ -57,10 +60,10 @@ namespace msgpack {
}
public:
using value_type = std::pair<V, V>;
using value_type = std::pair<base_value_type, base_value_type>;
using reference = std::pair<base_reference, base_reference>;
using difference_type = std::ptrdiff_t;
using iterator_category = std::forward_iterator_tag;
using iterator_category = std::input_iterator_tag;
iterator() = default;
iterator(V const& base)
@ -76,10 +79,10 @@ namespace msgpack {
}
}
[[nodiscard]] reference operator*() const {
return { *k_, *v_ };
}
iterator& operator++() {
k_ = next(v_);
v_ = next(k_);
@ -98,6 +101,7 @@ namespace msgpack {
base_ == rhs.base_;
}
private:
V const* base_{};
base_iterator k_{};
base_iterator v_{};

View File

@ -70,53 +70,6 @@ struct parser_data_message {
std::span<std::byte> payload;
};
using message = std::variant<
error_message,
connect_message,
challenge_message,
response_message,
session_established_message,
parser_data_message>;
enum class error {
bad_magic, // Did not get the message magic expected
too_large, // The message size was too large.
unknown_type, // The message type is not known
};
// This class is responsible for consuming buffer data and yielding a message
// instance when complete. Will throw an error if data is incorrect.
//
class builder {
public:
// For now, builders don't manage any buffers themselves. Later, that
// may change.
builder() = default;
// Reset the builder to its initial state. This means any partially-decoded
// message data will be lost.
void reset() noexcept;
// How many bytes are needed to perform a meaningful amount of work.
std::size_t bytes_needed() noexcept;
// Process data from a buffer, building messages. Returns the number of
// bytes read from the buffer for the caller's bookkeeping. May yield a
// message in addition.
std::size_t process(std::span<std::byte const> buffer) noexcept;
private:
enum class state {
magic,
size,
payload
};
state state_{state::magic};
std::size_t payload_size_{};
std::size_t payload_remaining_{};
};
} // namespace message
} // namespace parselink

View File

@ -27,7 +27,6 @@
#include <fmt/ranges.h>
#include <boost/asio/io_context.hpp>
#include <boost/asio/signal_set.hpp>
#include <boost/asio/redirect_error.hpp>
@ -41,7 +40,11 @@
#include <boost/asio/detached.hpp>
#include <boost/asio/as_tuple.hpp>
#include <chrono>
#include <map>
using namespace parselink;
using namespace std::chrono_literals;
namespace net = boost::asio;
using net::co_spawn;
@ -50,6 +53,12 @@ using net::use_awaitable;
using net::deferred;
using net::detached;
enum class error {
system,
msgpack,
};
//-----------------------------------------------------------------------------
// TODO(ksassenrath): These are logging formatters for various boost/asio types.
// Not all code is exposed to them, so they cannot be defined inside the
@ -146,15 +155,87 @@ namespace {
constexpr auto no_ex_defer = net::as_tuple(deferred);
}
struct msgbuf {
std::vector<std::byte> payload;
struct user_session {
std::string user_id;
std::array<std::byte, 32> session_id;
std::chrono::system_clock::time_point expires_at;
};
class user_session : public std::enable_shared_from_this<user_session> {
class monolithic_server : public server {
public:
user_session(net::ip::tcp::socket sock) : socket_(std::move(sock)) {}
~user_session() {
logger.debug("Closing connection to {}", socket_.remote_endpoint());
monolithic_server(std::string_view address, std::uint16_t user_port,
std::uint16_t websocket_port);
std::error_code run() noexcept override;
private:
friend user_session;
awaitable<void> user_listen();
user_session* establish_session(user_session session);
std::map<std::string, user_session, std::less<>> active_user_sessions_;
net::io_context io_context_;
net::ip::address addr_;
std::uint16_t user_port_;
std::uint16_t websocket_port_;
};
tl::expected<user_session, msgpack::error> handle_connect(
std::span<msgpack::token> tokens) noexcept {
user_session user;
auto message_type = tokens.begin()->get<std::string_view>();
if (message_type) {
logger.debug("Received '{}' packet. Parsing body", *message_type);
auto map = msgpack::map_view(tokens.subspan(1));
constexpr auto find_version = [](auto const& kv) {
return kv.first == "version";
};
auto field = std::ranges::find_if(map, find_version);
if (field != map.end() && (*field).second == 1u) {
logger.debug("Version {}\n", (*field).second);
} else {
logger.error("connect failed: missing / unsupported version");
return tl::make_unexpected(msgpack::error::unsupported);
}
for (auto const& [k, v] : msgpack::map_view(tokens.subspan(1))) {
tl::expected<void, msgpack::error> result;
if (k.type() != msgpack::format::type::string) {
logger.error("connect failed: non-string key {}", k);
return tl::make_unexpected(msgpack::error::bad_value);
}
if (k == "user_id") {
result = v.get<std::string>().map([&user](auto uid){
user.user_id = std::move(uid);
});
}
if (!result) {
logger.error("connect failed: {} -> {}: {}", k, v,
result.error());
return tl::make_unexpected(msgpack::error::bad_value);
}
}
} else {
logger.error("Did not get message type: {}", message_type.error());
return tl::make_unexpected(msgpack::error::bad_value);
}
return user;
}
class user_connection : public std::enable_shared_from_this<user_connection> {
public:
user_connection(monolithic_server& server, net::ip::tcp::socket sock)
: server_(server)
, socket_(std::move(sock)) {}
~user_connection() {
logger.debug("Connection to {} closed.", socket_.remote_endpoint());
boost::system::error_code ec;
socket_.shutdown(net::ip::tcp::socket::shutdown_both, ec);
socket_.close();
@ -163,21 +244,23 @@ public:
void start() {
logger.debug("New connection from {}", socket_.remote_endpoint());
co_spawn(socket_.get_executor(), [self = shared_from_this()]{
return self->await_auth();
return self->await_connect();
}, detached);
}
tl::expected<std::vector<std::byte>, msgpack::error> parse_header(
std::span<std::byte> data) noexcept {
auto reader = msgpack::token_reader(data);
auto magic = reader.read_one().map(
[](auto t){ return t == std::string_view{"prs"}; });
[](auto t){ return t == "prs"; });
if (magic && *magic) {
logger.debug("Got magic from client");
} else {
logger.error("Failed to get magic from client: {}", magic);
return tl::unexpected(magic.error());
}
auto sz = reader.read_one().and_then(
[](auto t){ return t.template get<std::size_t>(); });
if (sz && *sz) {
@ -186,6 +269,7 @@ public:
logger.debug("Failed to get packet size from client: {}", sz);
return tl::unexpected(magic.error());
}
// Copy the rest of the message to the buffer for parsing.
// TODO(ksassenrath): Replace vector with custom buffer.
std::vector<std::byte> msg;
@ -215,63 +299,53 @@ public:
co_return std::monostate{};
}
tl::expected<bool, msgpack::error> handle_auth(std::span<msgpack::token> tokens) {
auto message_type = tokens.begin()->get<std::string_view>();
if (message_type) {
logger.debug("Received '{}' packet. Parsing body", *message_type);
proto::connect_message message;
for (auto const& [k, v] : msgpack::map_view(tokens.subspan(1))) {
logger.debug("Parsing {} -> {}", k, v);
}
} else {
logger.error("Did not get message type: {}", message_type.error());
}
// The first token should include
return true;
}
awaitable<void> await_auth() noexcept {
std::array<std::byte, 16> buffer;
awaitable<tl::expected<std::vector<std::byte>, boost::system::error_code>>
await_message() noexcept {
// Use a small buffer on the stack to read the initial header.
std::array<std::byte, 8> buffer;
auto [ec, n] = co_await socket_.async_read_some(
net::buffer(buffer), no_ex_coro);
if (ec) {
logger.error("Reading from user socket failed: {}", ec);
co_return;
co_return tl::make_unexpected(ec);
}
logger.debug("Read {} bytes from client: {}", n,
std::span(buffer.data(), n));
auto hdr_result = parse_header(std::span(buffer.data(), n));
if (!hdr_result) {
auto hdr = parse_header(std::span(buffer.data(), n));
if (!hdr) {
logger.error("Unable to parse header: {}", hdr.error());
co_return tl::make_unexpected(boost::system::error_code: );
co_return;
}
auto msg = std::move(*hdr_result);
auto maybe_error = co_await buffer_message(msg);
auto msg = std::move(*hdr);
if (!maybe_error) {
logger.error("Unable to buffer message: {}",
maybe_error.error());
if (auto result = co_await buffer_message(msg); !result) {
logger.error("Unable to parse header: {}", result.error());
co_return;
}
logger.trace("Message: {}", msg);
}
awaitable<bool> await_connect() noexcept {
auto reader = msgpack::token_reader(msg);
std::array<msgpack::token, 32> tokens;
auto parsed = reader.read_some(tokens).and_then(
[this](auto c) { for (auto t : c) logger.trace("{}", t); return handle_auth(c); })
.or_else([](auto const& error) {
auto maybe_user = reader.read_some(tokens)
.and_then(handle_connect)
.map_error([](auto const& error) {
logger.error("Unable to parse msgpack tokens: {}", error);
});
if (!parsed) {
co_return;
if (!maybe_user) {
co_return;
}
// Authenticate against database.
logger.debug("User {} established connection", maybe_user->user_id);
session_ = server_.establish_session(std::move(*maybe_user));
}
enum class state {
@ -280,27 +354,11 @@ public:
active
};
monolithic_server& server_;
user_session* session_{};
net::ip::tcp::socket socket_;
};
class monolithic_server : public server {
public:
monolithic_server(std::string_view address, std::uint16_t user_port,
std::uint16_t websocket_port);
std::error_code run() noexcept override;
private:
awaitable<void> user_listen();
net::io_context io_context_;
net::ip::address addr_;
std::uint16_t user_port_;
std::uint16_t websocket_port_;
};
monolithic_server::monolithic_server(std::string_view address,
std::uint16_t user_port, std::uint16_t websocket_port)
: io_context_{1}
@ -315,7 +373,7 @@ awaitable<void> monolithic_server::user_listen() {
auto exec = co_await net::this_coro::executor;
net::ip::tcp::acceptor acceptor{exec, {addr_, user_port_}};
while (true) {
std::make_shared<user_session>(
std::make_shared<user_connection>(*this,
co_await acceptor.async_accept(use_awaitable))->start();
}
}
@ -336,6 +394,15 @@ std::error_code monolithic_server::run() noexcept {
return {};
}
user_session* monolithic_server::establish_session(user_session session) {
auto existing_session = active_user_sessions_.find(session.user_id);
if (existing_session == active_user_sessions_.end()) {
// No session exists with that user ID yet.
active_user_sessions_.emplace(session.user_id,
std::move(session.user_id));
}
}
std::unique_ptr<server> parselink::make_server(std::string_view address,
std::uint16_t user_port, std::uint16_t websocket_port) {
using impl = monolithic_server;