This commit is contained in:
Kurt Sassenrath 2023-10-16 19:27:12 -07:00
parent 1c7047e314
commit 9346b5be5d
4 changed files with 143 additions and 109 deletions

View File

@ -240,6 +240,9 @@ public:
} }
} }
template <std::size_t N>
constexpr bool operator==(char const (&t)[N]) const noexcept;
template <typename T> template <typename T>
constexpr bool operator==(T const& t) const noexcept { constexpr bool operator==(T const& t) const noexcept {
if constexpr (std::equality_comparable<T>) { if constexpr (std::equality_comparable<T>) {
@ -286,6 +289,13 @@ constexpr tl::expected<std::string_view, error> token_base<8>::get()
return std::string_view{value_.str, size_and_type_.get_size()}; return std::string_view{value_.str, size_and_type_.get_size()};
} }
template <std::size_t N>
constexpr bool token_base<8>::operator==(char const (&t)[N]) const noexcept {
auto result = get<std::string_view>().map([&t](auto v) {
return v == std::string_view{t}; });
return result && *result;
}
template<> template<>
inline tl::expected<std::vector<std::byte>, error> token_base<8>::get() inline tl::expected<std::vector<std::byte>, error> token_base<8>::get()
const noexcept const noexcept

View File

@ -25,6 +25,7 @@
#include "type.h" #include "type.h"
#include <bits/iterator_concepts.h>
#include <fmt/format.h> #include <fmt/format.h>
#include <ranges> #include <ranges>
@ -35,6 +36,8 @@ namespace msgpack {
std::ranges::range_value_t<V>> std::ranges::range_value_t<V>>
struct map_view : public std::ranges::view_interface<map_view<V>> { struct map_view : public std::ranges::view_interface<map_view<V>> {
public: public:
class sentinel;
class iterator { class iterator {
friend class sentinel; friend class sentinel;
@ -57,10 +60,10 @@ namespace msgpack {
} }
public: public:
using value_type = std::pair<V, V>; using value_type = std::pair<base_value_type, base_value_type>;
using reference = std::pair<base_reference, base_reference>; using reference = std::pair<base_reference, base_reference>;
using difference_type = std::ptrdiff_t; using difference_type = std::ptrdiff_t;
using iterator_category = std::forward_iterator_tag; using iterator_category = std::input_iterator_tag;
iterator() = default; iterator() = default;
iterator(V const& base) iterator(V const& base)
@ -76,10 +79,10 @@ namespace msgpack {
} }
} }
[[nodiscard]] reference operator*() const { [[nodiscard]] reference operator*() const {
return { *k_, *v_ }; return { *k_, *v_ };
} }
iterator& operator++() { iterator& operator++() {
k_ = next(v_); k_ = next(v_);
v_ = next(k_); v_ = next(k_);
@ -98,6 +101,7 @@ namespace msgpack {
base_ == rhs.base_; base_ == rhs.base_;
} }
private:
V const* base_{}; V const* base_{};
base_iterator k_{}; base_iterator k_{};
base_iterator v_{}; base_iterator v_{};

View File

@ -70,53 +70,6 @@ struct parser_data_message {
std::span<std::byte> payload; std::span<std::byte> payload;
}; };
using message = std::variant<
error_message,
connect_message,
challenge_message,
response_message,
session_established_message,
parser_data_message>;
enum class error {
bad_magic, // Did not get the message magic expected
too_large, // The message size was too large.
unknown_type, // The message type is not known
};
// This class is responsible for consuming buffer data and yielding a message
// instance when complete. Will throw an error if data is incorrect.
//
class builder {
public:
// For now, builders don't manage any buffers themselves. Later, that
// may change.
builder() = default;
// Reset the builder to its initial state. This means any partially-decoded
// message data will be lost.
void reset() noexcept;
// How many bytes are needed to perform a meaningful amount of work.
std::size_t bytes_needed() noexcept;
// Process data from a buffer, building messages. Returns the number of
// bytes read from the buffer for the caller's bookkeeping. May yield a
// message in addition.
std::size_t process(std::span<std::byte const> buffer) noexcept;
private:
enum class state {
magic,
size,
payload
};
state state_{state::magic};
std::size_t payload_size_{};
std::size_t payload_remaining_{};
};
} // namespace message } // namespace message
} // namespace parselink } // namespace parselink

View File

@ -27,7 +27,6 @@
#include <fmt/ranges.h> #include <fmt/ranges.h>
#include <boost/asio/io_context.hpp> #include <boost/asio/io_context.hpp>
#include <boost/asio/signal_set.hpp> #include <boost/asio/signal_set.hpp>
#include <boost/asio/redirect_error.hpp> #include <boost/asio/redirect_error.hpp>
@ -41,7 +40,11 @@
#include <boost/asio/detached.hpp> #include <boost/asio/detached.hpp>
#include <boost/asio/as_tuple.hpp> #include <boost/asio/as_tuple.hpp>
#include <chrono>
#include <map>
using namespace parselink; using namespace parselink;
using namespace std::chrono_literals;
namespace net = boost::asio; namespace net = boost::asio;
using net::co_spawn; using net::co_spawn;
@ -50,6 +53,12 @@ using net::use_awaitable;
using net::deferred; using net::deferred;
using net::detached; using net::detached;
enum class error {
system,
msgpack,
};
//----------------------------------------------------------------------------- //-----------------------------------------------------------------------------
// TODO(ksassenrath): These are logging formatters for various boost/asio types. // TODO(ksassenrath): These are logging formatters for various boost/asio types.
// Not all code is exposed to them, so they cannot be defined inside the // Not all code is exposed to them, so they cannot be defined inside the
@ -146,15 +155,87 @@ namespace {
constexpr auto no_ex_defer = net::as_tuple(deferred); constexpr auto no_ex_defer = net::as_tuple(deferred);
} }
struct msgbuf { struct user_session {
std::vector<std::byte> payload; std::string user_id;
std::array<std::byte, 32> session_id;
std::chrono::system_clock::time_point expires_at;
}; };
class user_session : public std::enable_shared_from_this<user_session> { class monolithic_server : public server {
public: public:
user_session(net::ip::tcp::socket sock) : socket_(std::move(sock)) {} monolithic_server(std::string_view address, std::uint16_t user_port,
~user_session() { std::uint16_t websocket_port);
logger.debug("Closing connection to {}", socket_.remote_endpoint());
std::error_code run() noexcept override;
private:
friend user_session;
awaitable<void> user_listen();
user_session* establish_session(user_session session);
std::map<std::string, user_session, std::less<>> active_user_sessions_;
net::io_context io_context_;
net::ip::address addr_;
std::uint16_t user_port_;
std::uint16_t websocket_port_;
};
tl::expected<user_session, msgpack::error> handle_connect(
std::span<msgpack::token> tokens) noexcept {
user_session user;
auto message_type = tokens.begin()->get<std::string_view>();
if (message_type) {
logger.debug("Received '{}' packet. Parsing body", *message_type);
auto map = msgpack::map_view(tokens.subspan(1));
constexpr auto find_version = [](auto const& kv) {
return kv.first == "version";
};
auto field = std::ranges::find_if(map, find_version);
if (field != map.end() && (*field).second == 1u) {
logger.debug("Version {}\n", (*field).second);
} else {
logger.error("connect failed: missing / unsupported version");
return tl::make_unexpected(msgpack::error::unsupported);
}
for (auto const& [k, v] : msgpack::map_view(tokens.subspan(1))) {
tl::expected<void, msgpack::error> result;
if (k.type() != msgpack::format::type::string) {
logger.error("connect failed: non-string key {}", k);
return tl::make_unexpected(msgpack::error::bad_value);
}
if (k == "user_id") {
result = v.get<std::string>().map([&user](auto uid){
user.user_id = std::move(uid);
});
}
if (!result) {
logger.error("connect failed: {} -> {}: {}", k, v,
result.error());
return tl::make_unexpected(msgpack::error::bad_value);
}
}
} else {
logger.error("Did not get message type: {}", message_type.error());
return tl::make_unexpected(msgpack::error::bad_value);
}
return user;
}
class user_connection : public std::enable_shared_from_this<user_connection> {
public:
user_connection(monolithic_server& server, net::ip::tcp::socket sock)
: server_(server)
, socket_(std::move(sock)) {}
~user_connection() {
logger.debug("Connection to {} closed.", socket_.remote_endpoint());
boost::system::error_code ec; boost::system::error_code ec;
socket_.shutdown(net::ip::tcp::socket::shutdown_both, ec); socket_.shutdown(net::ip::tcp::socket::shutdown_both, ec);
socket_.close(); socket_.close();
@ -163,21 +244,23 @@ public:
void start() { void start() {
logger.debug("New connection from {}", socket_.remote_endpoint()); logger.debug("New connection from {}", socket_.remote_endpoint());
co_spawn(socket_.get_executor(), [self = shared_from_this()]{ co_spawn(socket_.get_executor(), [self = shared_from_this()]{
return self->await_auth(); return self->await_connect();
}, detached); }, detached);
} }
tl::expected<std::vector<std::byte>, msgpack::error> parse_header( tl::expected<std::vector<std::byte>, msgpack::error> parse_header(
std::span<std::byte> data) noexcept { std::span<std::byte> data) noexcept {
auto reader = msgpack::token_reader(data); auto reader = msgpack::token_reader(data);
auto magic = reader.read_one().map( auto magic = reader.read_one().map(
[](auto t){ return t == std::string_view{"prs"}; }); [](auto t){ return t == "prs"; });
if (magic && *magic) { if (magic && *magic) {
logger.debug("Got magic from client"); logger.debug("Got magic from client");
} else { } else {
logger.error("Failed to get magic from client: {}", magic); logger.error("Failed to get magic from client: {}", magic);
return tl::unexpected(magic.error()); return tl::unexpected(magic.error());
} }
auto sz = reader.read_one().and_then( auto sz = reader.read_one().and_then(
[](auto t){ return t.template get<std::size_t>(); }); [](auto t){ return t.template get<std::size_t>(); });
if (sz && *sz) { if (sz && *sz) {
@ -186,6 +269,7 @@ public:
logger.debug("Failed to get packet size from client: {}", sz); logger.debug("Failed to get packet size from client: {}", sz);
return tl::unexpected(magic.error()); return tl::unexpected(magic.error());
} }
// Copy the rest of the message to the buffer for parsing. // Copy the rest of the message to the buffer for parsing.
// TODO(ksassenrath): Replace vector with custom buffer. // TODO(ksassenrath): Replace vector with custom buffer.
std::vector<std::byte> msg; std::vector<std::byte> msg;
@ -215,63 +299,53 @@ public:
co_return std::monostate{}; co_return std::monostate{};
} }
tl::expected<bool, msgpack::error> handle_auth(std::span<msgpack::token> tokens) { awaitable<tl::expected<std::vector<std::byte>, boost::system::error_code>>
auto message_type = tokens.begin()->get<std::string_view>(); await_message() noexcept {
if (message_type) { // Use a small buffer on the stack to read the initial header.
logger.debug("Received '{}' packet. Parsing body", *message_type); std::array<std::byte, 8> buffer;
proto::connect_message message;
for (auto const& [k, v] : msgpack::map_view(tokens.subspan(1))) {
logger.debug("Parsing {} -> {}", k, v);
}
} else {
logger.error("Did not get message type: {}", message_type.error());
}
// The first token should include
return true;
}
awaitable<void> await_auth() noexcept {
std::array<std::byte, 16> buffer;
auto [ec, n] = co_await socket_.async_read_some( auto [ec, n] = co_await socket_.async_read_some(
net::buffer(buffer), no_ex_coro); net::buffer(buffer), no_ex_coro);
if (ec) { if (ec) {
logger.error("Reading from user socket failed: {}", ec); logger.error("Reading from user socket failed: {}", ec);
co_return; co_return tl::make_unexpected(ec);
} }
logger.debug("Read {} bytes from client: {}", n, logger.debug("Read {} bytes from client: {}", n,
std::span(buffer.data(), n)); std::span(buffer.data(), n));
auto hdr_result = parse_header(std::span(buffer.data(), n)); auto hdr = parse_header(std::span(buffer.data(), n));
if (!hdr_result) { if (!hdr) {
logger.error("Unable to parse header: {}", hdr.error());
co_return tl::make_unexpected(boost::system::error_code: );
co_return; co_return;
} }
auto msg = std::move(*hdr_result); auto msg = std::move(*hdr);
auto maybe_error = co_await buffer_message(msg);
if (!maybe_error) { if (auto result = co_await buffer_message(msg); !result) {
logger.error("Unable to buffer message: {}", logger.error("Unable to parse header: {}", result.error());
maybe_error.error());
co_return; co_return;
} }
logger.trace("Message: {}", msg); }
awaitable<bool> await_connect() noexcept {
auto reader = msgpack::token_reader(msg); auto reader = msgpack::token_reader(msg);
std::array<msgpack::token, 32> tokens; std::array<msgpack::token, 32> tokens;
auto parsed = reader.read_some(tokens).and_then( auto maybe_user = reader.read_some(tokens)
[this](auto c) { for (auto t : c) logger.trace("{}", t); return handle_auth(c); }) .and_then(handle_connect)
.or_else([](auto const& error) { .map_error([](auto const& error) {
logger.error("Unable to parse msgpack tokens: {}", error); logger.error("Unable to parse msgpack tokens: {}", error);
}); });
if (!parsed) { if (!maybe_user) {
co_return; co_return;
} }
// Authenticate against database. // Authenticate against database.
logger.debug("User {} established connection", maybe_user->user_id);
session_ = server_.establish_session(std::move(*maybe_user));
} }
enum class state { enum class state {
@ -280,27 +354,11 @@ public:
active active
}; };
monolithic_server& server_;
user_session* session_{};
net::ip::tcp::socket socket_; net::ip::tcp::socket socket_;
}; };
class monolithic_server : public server {
public:
monolithic_server(std::string_view address, std::uint16_t user_port,
std::uint16_t websocket_port);
std::error_code run() noexcept override;
private:
awaitable<void> user_listen();
net::io_context io_context_;
net::ip::address addr_;
std::uint16_t user_port_;
std::uint16_t websocket_port_;
};
monolithic_server::monolithic_server(std::string_view address, monolithic_server::monolithic_server(std::string_view address,
std::uint16_t user_port, std::uint16_t websocket_port) std::uint16_t user_port, std::uint16_t websocket_port)
: io_context_{1} : io_context_{1}
@ -315,7 +373,7 @@ awaitable<void> monolithic_server::user_listen() {
auto exec = co_await net::this_coro::executor; auto exec = co_await net::this_coro::executor;
net::ip::tcp::acceptor acceptor{exec, {addr_, user_port_}}; net::ip::tcp::acceptor acceptor{exec, {addr_, user_port_}};
while (true) { while (true) {
std::make_shared<user_session>( std::make_shared<user_connection>(*this,
co_await acceptor.async_accept(use_awaitable))->start(); co_await acceptor.async_accept(use_awaitable))->start();
} }
} }
@ -336,6 +394,15 @@ std::error_code monolithic_server::run() noexcept {
return {}; return {};
} }
user_session* monolithic_server::establish_session(user_session session) {
auto existing_session = active_user_sessions_.find(session.user_id);
if (existing_session == active_user_sessions_.end()) {
// No session exists with that user ID yet.
active_user_sessions_.emplace(session.user_id,
std::move(session.user_id));
}
}
std::unique_ptr<server> parselink::make_server(std::string_view address, std::unique_ptr<server> parselink::make_server(std::string_view address,
std::uint16_t user_port, std::uint16_t websocket_port) { std::uint16_t user_port, std::uint16_t websocket_port) {
using impl = monolithic_server; using impl = monolithic_server;