diff --git a/source/include/parselink/msgpack/token/reader.h b/source/include/parselink/msgpack/token/reader.h index ba61494..960cfd4 100644 --- a/source/include/parselink/msgpack/token/reader.h +++ b/source/include/parselink/msgpack/token/reader.h @@ -54,9 +54,13 @@ public: constexpr token_reader(std::span src) noexcept : data_(src), curr_{src.begin()}, end_(src.end()) {} + constexpr auto remaining(auto itr) noexcept { + using dist_type = decltype(std::distance(itr, end_)); + return std::max(dist_type(0), std::distance(itr, end_)); + } + constexpr auto remaining() noexcept { - using dist_type = decltype(std::distance(curr_, end_)); - return std::max(dist_type(0), std::distance(curr_, end_)); + return remaining(curr_); } // Read the next token. If the reader currently points to the end of the @@ -64,45 +68,102 @@ public: // some data present in the buffer, then incomplete_message is returned, // potentially hinting that further buffering is required. constexpr tl::expected read_one() noexcept { + token tok; if (curr_ >= end_) { return tl::make_unexpected(error::end_of_message); } + auto curr = curr_; + // Enumerate the current byte first by switch statement, then by // fix types. long int size = 0; - auto id = *curr_; - ++curr_; + std::size_t var_size = 0; + auto id = *curr; + format::type token_type; + ++curr; switch (id) { case format::uint8::marker: - size = 1; + size = sizeof(format::uint8::value_type); + token_type = format::type::unsigned_int; + break; case format::uint16::marker: - size = 2; + size = sizeof(format::uint16::value_type); + token_type = format::type::unsigned_int; + break; case format::uint32::marker: - size = 3; + size = sizeof(format::uint32::value_type); + token_type = format::type::unsigned_int; + break; case format::uint64::marker: - size = 4; - if (remaining() < size) { - return tl::make_unexpected(error::incomplete_message); - } - return detail::read_value(size, curr_); - + size = sizeof(format::uint64::value_type); + token_type = format::type::unsigned_int; + break; case format::int8::marker: - size = 1; + size = sizeof(format::int8::value_type); + token_type = format::type::signed_int; + break; case format::int16::marker: - size = 2; + size = sizeof(format::int16::value_type); + token_type = format::type::signed_int; + break; case format::int32::marker: - size = 3; + size = sizeof(format::int32::value_type); + token_type = format::type::signed_int; + break; case format::int64::marker: - size = 4; - if (remaining() < size) { + size = sizeof(format::int64::value_type); + token_type = format::type::signed_int; + break; + case format::str8::marker: + size = sizeof(format::str8::first_type); + token_type = format::type::string; + break; + case format::str16::marker: + size = sizeof(format::str16::first_type); + token_type = format::type::string; + break; + case format::str32::marker: + size = sizeof(format::str32::first_type); + token_type = format::type::string; + break; + default: + return tl::make_unexpected(error::not_implemented); + } + + switch (token_type) { + case format::type::unsigned_int: + if (remaining(curr) < size) { + return tl::make_unexpected(error::out_of_space); + } + tok = detail::read_value(size, curr); + break; + case format::type::signed_int: + if (remaining(curr) < size) { + return tl::make_unexpected(error::out_of_space); + } + tok = detail::read_value(size, curr); + break; + case format::type::string: + if (remaining(curr) < size) { return tl::make_unexpected(error::incomplete_message); } - return detail::read_value(size, curr_); + var_size = std::bit_cast(detail::read(size, curr)); + if (std::size_t(remaining(curr)) < var_size) { + return tl::make_unexpected(error::incomplete_message); + } + using type = token_traits::storage_type; + { + auto ptr = reinterpret_cast(&*curr); + tok = token{std::string_view{ptr, var_size}}; + curr += var_size; + } + default: break; } - return tl::make_unexpected(error::not_implemented); - } + curr_ = curr; + return {tok}; + } // Read multiple tokens from the byte buffer. The number of tokens parsed // can be surmised from the returned span of tokens. If the reader diff --git a/tests/msgpack/BUILD b/tests/msgpack/BUILD index 113d3b7..5562a3f 100644 --- a/tests/msgpack/BUILD +++ b/tests/msgpack/BUILD @@ -8,8 +8,9 @@ cc_library( deps = [ "//source:msgpack", "@expected", - "@ut", "@fmt", + "@magic_enum", + "@ut", ], ) diff --git a/tests/msgpack/test_token_reader.cpp b/tests/msgpack/test_token_reader.cpp index 976ab76..e36d13f 100644 --- a/tests/msgpack/test_token_reader.cpp +++ b/tests/msgpack/test_token_reader.cpp @@ -10,6 +10,49 @@ namespace { return {std::byte(std::forward(bytes))...}; } +template struct oversized_array { + std::array data; + std::size_t size; +}; + +constexpr auto to_bytes_array_oversized(auto const &container) { + oversized_array arr; + std::copy(std::begin(container), std::end(container), std::begin(arr.data)); + arr.size = std::distance(std::begin(container), std::end(container)); + return arr; +} + + consteval auto generate_bytes(auto callable) { + constexpr auto oversized = to_bytes_array_oversized(callable()); + std::array out; + std::copy(std::begin(oversized.data), + std::next(std::begin(oversized.data), oversized.size), + std::begin(out)); + return out; + } + +template +consteval auto cat(std::array const &a, + std::array const &b) { + std::array out; + std::copy(std::begin(a), std::next(std::begin(a), std::size(a)), + std::begin(out)); + std::copy(std::begin(b), std::next(std::begin(b), std::size(b)), + std::next(std::begin(out), std::size(a))); + return out; +} + + constexpr auto from_string_view(std::string_view sv) { + std::vector range; + range.resize(sv.size()); + auto itr = range.begin(); + for (auto c : sv) { + *itr = std::byte(c); + ++itr; + } + return range; + } + template constexpr bool wrong_types(auto const& obj) { auto err = tl::make_unexpected(msgpack::error::wrong_type); @@ -33,14 +76,29 @@ namespace { } suite reader = [] { - "construction"_test = [] { - constexpr auto payload = make_bytes(0xce, 0x01, 0x02, 0x03, 0x09, 0xce); + "read uint32"_test = [] { + constexpr auto payload = make_bytes(0xce, 0x01, 0x02, 0x03, 0x09); msgpack::token_reader reader(payload); auto token = reader.read_one(); expect(token && token->type() == msgpack::format::type::unsigned_int); expect(token->get() == tl::make_unexpected(msgpack::error::will_truncate)); + expect(token->get() == tl::make_unexpected(msgpack::error::will_truncate)); expect(token->get() == 0x01020309); + expect(token->get() == 0x01020309); token = reader.read_one(); - expect(token == tl::make_unexpected(msgpack::error::incomplete_message)); + expect(token == tl::make_unexpected(msgpack::error::end_of_message)); + }; + "read str8"_test = [] { + constexpr std::string_view sv = "hello d"; + constexpr auto payload = + cat(make_bytes(0xd9, sv.size()), + generate_bytes([sv] { return from_string_view(sv); })); + + msgpack::token_reader reader(payload); + auto token = reader.read_one(); + expect(token && token->type() == msgpack::format::type::string); + expect(token->get() == sv); + token = reader.read_one(); + expect(token == tl::make_unexpected(msgpack::error::end_of_message)); }; };