diff --git a/Framework/Core/include/Framework/DataModelViews.h b/Framework/Core/include/Framework/DataModelViews.h index b7a334454bb6e..f42ef85ec78e1 100644 --- a/Framework/Core/include/Framework/DataModelViews.h +++ b/Framework/Core/include/Framework/DataModelViews.h @@ -206,15 +206,10 @@ struct get_num_payloads { struct MessageSet; -struct MessageStore { - std::span sets; - size_t inputsPerSlot = 0; -}; - struct inputs_for_slot { TimesliceSlot slot; template - requires requires(R r) { std::ranges::random_access_range; } + requires requires(R r) { requires std::ranges::random_access_range; } friend std::span operator|(R&& r, inputs_for_slot self) { return std::span(r.sets[self.slot.index * r.inputsPerSlot]); diff --git a/Framework/Core/include/Framework/MessageSet.h b/Framework/Core/include/Framework/MessageSet.h index e7ae70e0ea2e5..281f9c42a0773 100644 --- a/Framework/Core/include/Framework/MessageSet.h +++ b/Framework/Core/include/Framework/MessageSet.h @@ -12,13 +12,13 @@ #define FRAMEWORK_MESSAGESET_H #include "Framework/PartRef.h" +#include +#include "Framework/DataModelViews.h" #include #include #include -namespace o2 -{ -namespace framework +namespace o2::framework { /// A set of inflight messages. @@ -83,21 +83,21 @@ struct MessageSet { } /// get number of in-flight O2 messages - size_t size() const + [[nodiscard]] size_t size() const { - return messageMap.size(); + return messages | count_parts{}; } /// get number of header-payload pairs - size_t getNumberOfPairs() const + [[nodiscard]] size_t getNumberOfPairs() const { - return pairMap.size(); + return messages | count_payloads{}; } /// get number of payloads for an in-flight message - size_t getNumberOfPayloads(size_t mi) const + [[nodiscard]] size_t getNumberOfPayloads(size_t mi) const { - return messageMap[mi].size; + return messages | get_num_payloads{mi}; } /// clear the set @@ -179,6 +179,6 @@ struct MessageSet { } }; -} // namespace framework -} // namespace o2 +} // namespace o2::framework + #endif // FRAMEWORK_MESSAGESET_H diff --git a/Framework/Core/test/test_MessageSet.cxx b/Framework/Core/test/test_MessageSet.cxx index d56e32fea1adb..37f823197ef18 100644 --- a/Framework/Core/test/test_MessageSet.cxx +++ b/Framework/Core/test/test_MessageSet.cxx @@ -10,126 +10,164 @@ // or submit itself to any jurisdiction. #include +#include #include "Framework/MessageSet.h" +#include "Framework/DataProcessingHeader.h" +#include "Headers/Stack.h" +#include "MemoryResources/MemoryResources.h" #include using namespace o2::framework; -TEST_CASE("MessageSet") { +TEST_CASE("MessageSet") +{ o2::framework::MessageSet msgSet; - std::vector ptrs; - std::unique_ptr msg(nullptr); + o2::header::DataHeader dh{}; + dh.splitPayloadParts = 0; + dh.splitPayloadIndex = 0; + o2::framework::DataProcessingHeader dph{0, 1}; + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + fair::mq::MessagePtr payload(transport->CreateMessage()); + auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); + fair::mq::MessagePtr header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph}); std::unique_ptr msg2(nullptr); - ptrs.emplace_back(std::move(msg)); + std::vector ptrs; + ptrs.emplace_back(std::move(header)); ptrs.emplace_back(std::move(msg2)); msgSet.add([&ptrs](size_t i) -> fair::mq::MessagePtr& { return ptrs[i]; }, 2); REQUIRE(msgSet.messages.size() == 2); - REQUIRE(msgSet.messageMap.size() == 1); - REQUIRE(msgSet.pairMap.size() == 1); - REQUIRE(msgSet.messageMap[0].position == 0); - REQUIRE(msgSet.messageMap[0].size == 1); - - REQUIRE(msgSet.pairMap[0].partIndex == 0); - REQUIRE(msgSet.pairMap[0].payloadIndex == 0); + REQUIRE((msgSet.messages | count_payloads{}) == 1); + REQUIRE((msgSet.messages | get_dataref_indices{0, 0}).headerIdx == 0); + REQUIRE((msgSet.messages | get_dataref_indices{0, 0}).payloadIdx == 1); + REQUIRE((msgSet.messages | get_pair{0}).headerIdx == 0); + REQUIRE((msgSet.messages | get_pair{0}).payloadIdx == 1); + CHECK_THROWS((msgSet.messages | get_pair{1})); } -TEST_CASE("MessageSetWithFunction") { +TEST_CASE("MessageSetWithFunction") +{ std::vector ptrs; - std::unique_ptr msg(nullptr); + o2::header::DataHeader dh{}; + dh.splitPayloadParts = 0; + dh.splitPayloadIndex = 0; + o2::framework::DataProcessingHeader dph{0, 1}; + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + fair::mq::MessagePtr payload(transport->CreateMessage()); + auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); + fair::mq::MessagePtr header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph}); std::unique_ptr msg2(nullptr); - ptrs.emplace_back(std::move(msg)); + ptrs.emplace_back(std::move(header)); ptrs.emplace_back(std::move(msg2)); o2::framework::MessageSet msgSet([&ptrs](size_t i) -> fair::mq::MessagePtr& { return ptrs[i]; }, 2); REQUIRE(msgSet.messages.size() == 2); - REQUIRE(msgSet.messageMap.size() == 1); - REQUIRE(msgSet.pairMap.size() == 1); - REQUIRE(msgSet.messageMap[0].position == 0); - REQUIRE(msgSet.messageMap[0].size == 1); - - REQUIRE(msgSet.pairMap[0].partIndex == 0); - REQUIRE(msgSet.pairMap[0].payloadIndex == 0); + REQUIRE((msgSet.messages | count_payloads{}) == 1); + REQUIRE((msgSet.messages | get_dataref_indices{0, 0}).headerIdx == 0); + REQUIRE((msgSet.messages | get_dataref_indices{0, 0}).payloadIdx == 1); + REQUIRE((msgSet.messages | get_pair{0}).headerIdx == 0); + REQUIRE((msgSet.messages | get_pair{0}).payloadIdx == 1); + CHECK_THROWS((msgSet.messages | get_pair{1})); } -TEST_CASE("MessageSetWithMultipart") { +TEST_CASE("MessageSetWithMultipart") +{ std::vector ptrs; - std::unique_ptr msg(nullptr); + o2::header::DataHeader dh{}; + dh.splitPayloadParts = 2; + dh.splitPayloadIndex = 2; + o2::framework::DataProcessingHeader dph{0, 1}; + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + fair::mq::MessagePtr payload(transport->CreateMessage()); + auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); + fair::mq::MessagePtr header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph}); std::unique_ptr msg2(nullptr); std::unique_ptr msg3(nullptr); - ptrs.emplace_back(std::move(msg)); + ptrs.emplace_back(std::move(header)); ptrs.emplace_back(std::move(msg2)); ptrs.emplace_back(std::move(msg3)); o2::framework::MessageSet msgSet([&ptrs](size_t i) -> fair::mq::MessagePtr& { return ptrs[i]; }, 3); REQUIRE(msgSet.messages.size() == 3); - REQUIRE(msgSet.messageMap.size() == 1); - REQUIRE(msgSet.pairMap.size() == 2); - REQUIRE(msgSet.messageMap[0].position == 0); - REQUIRE(msgSet.messageMap[0].size == 2); - - REQUIRE(msgSet.pairMap[0].partIndex == 0); - REQUIRE(msgSet.pairMap[0].payloadIndex == 0); - REQUIRE(msgSet.pairMap[1].partIndex == 0); - REQUIRE(msgSet.pairMap[1].payloadIndex == 1); + REQUIRE((msgSet.messages | count_payloads{}) == 2); + REQUIRE((msgSet.messages | get_dataref_indices{0, 0}).headerIdx == 0); + REQUIRE((msgSet.messages | get_dataref_indices{0, 0}).payloadIdx == 1); + REQUIRE((msgSet.messages | get_dataref_indices{0, 1}).headerIdx == 0); + REQUIRE((msgSet.messages | get_dataref_indices{0, 1}).payloadIdx == 2); + REQUIRE((msgSet.messages | get_pair{0}).headerIdx == 0); + REQUIRE((msgSet.messages | get_pair{0}).payloadIdx == 1); + REQUIRE((msgSet.messages | get_pair{1}).headerIdx == 0); + REQUIRE((msgSet.messages | get_pair{1}).payloadIdx == 2); + CHECK_THROWS((msgSet.messages | get_pair{2})); } -TEST_CASE("MessageSetAddPartRef") { +TEST_CASE("MessageSetAddPartRef") +{ std::vector ptrs; std::unique_ptr msg(nullptr); std::unique_ptr msg2(nullptr); ptrs.emplace_back(std::move(msg)); ptrs.emplace_back(std::move(msg2)); - PartRef ref {std::move(msg), std::move(msg2)}; + PartRef ref{std::move(msg), std::move(msg2)}; o2::framework::MessageSet msgSet; msgSet.add(std::move(ref)); REQUIRE(msgSet.messages.size() == 2); - REQUIRE(msgSet.messageMap.size() == 1); - REQUIRE(msgSet.pairMap.size() == 1); - REQUIRE(msgSet.messageMap[0].position == 0); - REQUIRE(msgSet.messageMap[0].size == 1); - - REQUIRE(msgSet.pairMap[0].partIndex == 0); - REQUIRE(msgSet.pairMap[0].payloadIndex == 0); } TEST_CASE("MessageSetAddMultiple") { std::vector ptrs; - std::unique_ptr msg(nullptr); + o2::header::DataHeader dh1{}; + dh1.splitPayloadParts = 0; + dh1.splitPayloadIndex = 0; + o2::header::DataHeader dh2{}; + dh2.splitPayloadParts = 1; + dh2.splitPayloadIndex = 0; + o2::header::DataHeader dh3{}; + dh3.splitPayloadParts = 2; + dh3.splitPayloadIndex = 2; + o2::framework::DataProcessingHeader dph{0, 1}; + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + fair::mq::MessagePtr payload(transport->CreateMessage()); + auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); + fair::mq::MessagePtr header1 = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh1, dph}); + fair::mq::MessagePtr header2 = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh2, dph}); + fair::mq::MessagePtr header3 = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh3, dph}); + std::unique_ptr msg2(nullptr); - ptrs.emplace_back(std::move(msg)); - ptrs.emplace_back(std::move(msg2)); - PartRef ref{std::move(msg), std::move(msg2)}; + std::unique_ptr msg3(nullptr); + PartRef ref{std::move(header1), std::move(msg2)}; o2::framework::MessageSet msgSet; msgSet.add(std::move(ref)); - PartRef ref2{std::move(msg), std::move(msg2)}; + PartRef ref2{std::move(header2), std::move(msg2)}; msgSet.add(std::move(ref2)); std::vector msgs; - msgs.push_back(std::unique_ptr(nullptr)); + msgs.push_back(std::move(header3)); msgs.push_back(std::unique_ptr(nullptr)); msgs.push_back(std::unique_ptr(nullptr)); msgSet.add([&msgs](size_t i) { return std::move(msgs[i]); - }, 3); + }, + 3); REQUIRE(msgSet.messages.size() == 7); - REQUIRE(msgSet.messageMap.size() == 3); - REQUIRE(msgSet.pairMap.size() == 4); - REQUIRE(msgSet.messageMap[0].position == 0); - REQUIRE(msgSet.messageMap[0].size == 1); - REQUIRE(msgSet.messageMap[1].position == 2); - REQUIRE(msgSet.messageMap[1].size == 1); - REQUIRE(msgSet.messageMap[2].position == 4); - REQUIRE(msgSet.messageMap[2].size == 2); - REQUIRE(msgSet.pairMap[0].partIndex == 0); - REQUIRE(msgSet.pairMap[0].payloadIndex == 0); - REQUIRE(msgSet.pairMap[1].partIndex == 1); - REQUIRE(msgSet.pairMap[1].payloadIndex == 0); - REQUIRE(msgSet.pairMap[2].partIndex == 2); - REQUIRE(msgSet.pairMap[2].payloadIndex == 0); - REQUIRE(msgSet.pairMap[3].partIndex == 2); - REQUIRE(msgSet.pairMap[3].payloadIndex == 1); + REQUIRE((msgSet.messages | count_payloads{}) == 4); + REQUIRE((msgSet.messages | get_dataref_indices{0, 0}).headerIdx == 0); + REQUIRE((msgSet.messages | get_dataref_indices{0, 0}).payloadIdx == 1); + REQUIRE((msgSet.messages | get_dataref_indices{1, 0}).headerIdx == 2); + REQUIRE((msgSet.messages | get_dataref_indices{1, 0}).payloadIdx == 3); + REQUIRE((msgSet.messages | get_dataref_indices{2, 0}).headerIdx == 4); + REQUIRE((msgSet.messages | get_dataref_indices{2, 0}).payloadIdx == 5); + REQUIRE((msgSet.messages | get_dataref_indices{2, 1}).headerIdx == 4); + REQUIRE((msgSet.messages | get_dataref_indices{2, 1}).payloadIdx == 6); + REQUIRE((msgSet.messages | get_pair{0}).headerIdx == 0); + REQUIRE((msgSet.messages | get_pair{0}).payloadIdx == 1); + REQUIRE((msgSet.messages | get_pair{1}).headerIdx == 2); + REQUIRE((msgSet.messages | get_pair{1}).payloadIdx == 3); + REQUIRE((msgSet.messages | get_pair{2}).headerIdx == 4); + REQUIRE((msgSet.messages | get_pair{2}).payloadIdx == 5); + REQUIRE((msgSet.messages | get_pair{3}).headerIdx == 4); + REQUIRE((msgSet.messages | get_pair{3}).payloadIdx == 6); }