From dcdb3783e1d5b6aa72131a775c7c30500756901b Mon Sep 17 00:00:00 2001 From: Giulio Eulisse <10544+ktf@users.noreply.github.com> Date: Thu, 19 Mar 2026 07:16:18 +0100 Subject: [PATCH] DPL: propaedeutic to navigate a MessageSet without caching pairs All this should be fairly straight forward changes while still preserving the old API. If something fails at this level it means that even the counting of dataset changes with this included, which it should not be. --- .../Core/include/Framework/DataModelViews.h | 7 +- Framework/Core/include/Framework/MessageSet.h | 22 +-- Framework/Core/test/test_MessageSet.cxx | 168 +++++++++++------- 3 files changed, 115 insertions(+), 82 deletions(-) diff --git a/Framework/Core/include/Framework/DataModelViews.h b/Framework/Core/include/Framework/DataModelViews.h index b7a334454bb6e..f42ef85ec78e1 100644 --- a/Framework/Core/include/Framework/DataModelViews.h +++ b/Framework/Core/include/Framework/DataModelViews.h @@ -206,15 +206,10 @@ struct get_num_payloads { struct MessageSet; -struct MessageStore { - std::span sets; - size_t inputsPerSlot = 0; -}; - struct inputs_for_slot { TimesliceSlot slot; template - requires requires(R r) { std::ranges::random_access_range; } + requires requires(R r) { requires std::ranges::random_access_range; } friend std::span operator|(R&& r, inputs_for_slot self) { return std::span(r.sets[self.slot.index * r.inputsPerSlot]); diff --git a/Framework/Core/include/Framework/MessageSet.h b/Framework/Core/include/Framework/MessageSet.h index e7ae70e0ea2e5..281f9c42a0773 100644 --- a/Framework/Core/include/Framework/MessageSet.h +++ b/Framework/Core/include/Framework/MessageSet.h @@ -12,13 +12,13 @@ #define FRAMEWORK_MESSAGESET_H #include "Framework/PartRef.h" +#include +#include "Framework/DataModelViews.h" #include #include #include -namespace o2 -{ -namespace framework +namespace o2::framework { /// A set of inflight messages. @@ -83,21 +83,21 @@ struct MessageSet { } /// get number of in-flight O2 messages - size_t size() const + [[nodiscard]] size_t size() const { - return messageMap.size(); + return messages | count_parts{}; } /// get number of header-payload pairs - size_t getNumberOfPairs() const + [[nodiscard]] size_t getNumberOfPairs() const { - return pairMap.size(); + return messages | count_payloads{}; } /// get number of payloads for an in-flight message - size_t getNumberOfPayloads(size_t mi) const + [[nodiscard]] size_t getNumberOfPayloads(size_t mi) const { - return messageMap[mi].size; + return messages | get_num_payloads{mi}; } /// clear the set @@ -179,6 +179,6 @@ struct MessageSet { } }; -} // namespace framework -} // namespace o2 +} // namespace o2::framework + #endif // FRAMEWORK_MESSAGESET_H diff --git a/Framework/Core/test/test_MessageSet.cxx b/Framework/Core/test/test_MessageSet.cxx index d56e32fea1adb..37f823197ef18 100644 --- a/Framework/Core/test/test_MessageSet.cxx +++ b/Framework/Core/test/test_MessageSet.cxx @@ -10,126 +10,164 @@ // or submit itself to any jurisdiction. #include +#include #include "Framework/MessageSet.h" +#include "Framework/DataProcessingHeader.h" +#include "Headers/Stack.h" +#include "MemoryResources/MemoryResources.h" #include using namespace o2::framework; -TEST_CASE("MessageSet") { +TEST_CASE("MessageSet") +{ o2::framework::MessageSet msgSet; - std::vector ptrs; - std::unique_ptr msg(nullptr); + o2::header::DataHeader dh{}; + dh.splitPayloadParts = 0; + dh.splitPayloadIndex = 0; + o2::framework::DataProcessingHeader dph{0, 1}; + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + fair::mq::MessagePtr payload(transport->CreateMessage()); + auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); + fair::mq::MessagePtr header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph}); std::unique_ptr msg2(nullptr); - ptrs.emplace_back(std::move(msg)); + std::vector ptrs; + ptrs.emplace_back(std::move(header)); ptrs.emplace_back(std::move(msg2)); msgSet.add([&ptrs](size_t i) -> fair::mq::MessagePtr& { return ptrs[i]; }, 2); REQUIRE(msgSet.messages.size() == 2); - REQUIRE(msgSet.messageMap.size() == 1); - REQUIRE(msgSet.pairMap.size() == 1); - REQUIRE(msgSet.messageMap[0].position == 0); - REQUIRE(msgSet.messageMap[0].size == 1); - - REQUIRE(msgSet.pairMap[0].partIndex == 0); - REQUIRE(msgSet.pairMap[0].payloadIndex == 0); + REQUIRE((msgSet.messages | count_payloads{}) == 1); + REQUIRE((msgSet.messages | get_dataref_indices{0, 0}).headerIdx == 0); + REQUIRE((msgSet.messages | get_dataref_indices{0, 0}).payloadIdx == 1); + REQUIRE((msgSet.messages | get_pair{0}).headerIdx == 0); + REQUIRE((msgSet.messages | get_pair{0}).payloadIdx == 1); + CHECK_THROWS((msgSet.messages | get_pair{1})); } -TEST_CASE("MessageSetWithFunction") { +TEST_CASE("MessageSetWithFunction") +{ std::vector ptrs; - std::unique_ptr msg(nullptr); + o2::header::DataHeader dh{}; + dh.splitPayloadParts = 0; + dh.splitPayloadIndex = 0; + o2::framework::DataProcessingHeader dph{0, 1}; + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + fair::mq::MessagePtr payload(transport->CreateMessage()); + auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); + fair::mq::MessagePtr header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph}); std::unique_ptr msg2(nullptr); - ptrs.emplace_back(std::move(msg)); + ptrs.emplace_back(std::move(header)); ptrs.emplace_back(std::move(msg2)); o2::framework::MessageSet msgSet([&ptrs](size_t i) -> fair::mq::MessagePtr& { return ptrs[i]; }, 2); REQUIRE(msgSet.messages.size() == 2); - REQUIRE(msgSet.messageMap.size() == 1); - REQUIRE(msgSet.pairMap.size() == 1); - REQUIRE(msgSet.messageMap[0].position == 0); - REQUIRE(msgSet.messageMap[0].size == 1); - - REQUIRE(msgSet.pairMap[0].partIndex == 0); - REQUIRE(msgSet.pairMap[0].payloadIndex == 0); + REQUIRE((msgSet.messages | count_payloads{}) == 1); + REQUIRE((msgSet.messages | get_dataref_indices{0, 0}).headerIdx == 0); + REQUIRE((msgSet.messages | get_dataref_indices{0, 0}).payloadIdx == 1); + REQUIRE((msgSet.messages | get_pair{0}).headerIdx == 0); + REQUIRE((msgSet.messages | get_pair{0}).payloadIdx == 1); + CHECK_THROWS((msgSet.messages | get_pair{1})); } -TEST_CASE("MessageSetWithMultipart") { +TEST_CASE("MessageSetWithMultipart") +{ std::vector ptrs; - std::unique_ptr msg(nullptr); + o2::header::DataHeader dh{}; + dh.splitPayloadParts = 2; + dh.splitPayloadIndex = 2; + o2::framework::DataProcessingHeader dph{0, 1}; + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + fair::mq::MessagePtr payload(transport->CreateMessage()); + auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); + fair::mq::MessagePtr header = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh, dph}); std::unique_ptr msg2(nullptr); std::unique_ptr msg3(nullptr); - ptrs.emplace_back(std::move(msg)); + ptrs.emplace_back(std::move(header)); ptrs.emplace_back(std::move(msg2)); ptrs.emplace_back(std::move(msg3)); o2::framework::MessageSet msgSet([&ptrs](size_t i) -> fair::mq::MessagePtr& { return ptrs[i]; }, 3); REQUIRE(msgSet.messages.size() == 3); - REQUIRE(msgSet.messageMap.size() == 1); - REQUIRE(msgSet.pairMap.size() == 2); - REQUIRE(msgSet.messageMap[0].position == 0); - REQUIRE(msgSet.messageMap[0].size == 2); - - REQUIRE(msgSet.pairMap[0].partIndex == 0); - REQUIRE(msgSet.pairMap[0].payloadIndex == 0); - REQUIRE(msgSet.pairMap[1].partIndex == 0); - REQUIRE(msgSet.pairMap[1].payloadIndex == 1); + REQUIRE((msgSet.messages | count_payloads{}) == 2); + REQUIRE((msgSet.messages | get_dataref_indices{0, 0}).headerIdx == 0); + REQUIRE((msgSet.messages | get_dataref_indices{0, 0}).payloadIdx == 1); + REQUIRE((msgSet.messages | get_dataref_indices{0, 1}).headerIdx == 0); + REQUIRE((msgSet.messages | get_dataref_indices{0, 1}).payloadIdx == 2); + REQUIRE((msgSet.messages | get_pair{0}).headerIdx == 0); + REQUIRE((msgSet.messages | get_pair{0}).payloadIdx == 1); + REQUIRE((msgSet.messages | get_pair{1}).headerIdx == 0); + REQUIRE((msgSet.messages | get_pair{1}).payloadIdx == 2); + CHECK_THROWS((msgSet.messages | get_pair{2})); } -TEST_CASE("MessageSetAddPartRef") { +TEST_CASE("MessageSetAddPartRef") +{ std::vector ptrs; std::unique_ptr msg(nullptr); std::unique_ptr msg2(nullptr); ptrs.emplace_back(std::move(msg)); ptrs.emplace_back(std::move(msg2)); - PartRef ref {std::move(msg), std::move(msg2)}; + PartRef ref{std::move(msg), std::move(msg2)}; o2::framework::MessageSet msgSet; msgSet.add(std::move(ref)); REQUIRE(msgSet.messages.size() == 2); - REQUIRE(msgSet.messageMap.size() == 1); - REQUIRE(msgSet.pairMap.size() == 1); - REQUIRE(msgSet.messageMap[0].position == 0); - REQUIRE(msgSet.messageMap[0].size == 1); - - REQUIRE(msgSet.pairMap[0].partIndex == 0); - REQUIRE(msgSet.pairMap[0].payloadIndex == 0); } TEST_CASE("MessageSetAddMultiple") { std::vector ptrs; - std::unique_ptr msg(nullptr); + o2::header::DataHeader dh1{}; + dh1.splitPayloadParts = 0; + dh1.splitPayloadIndex = 0; + o2::header::DataHeader dh2{}; + dh2.splitPayloadParts = 1; + dh2.splitPayloadIndex = 0; + o2::header::DataHeader dh3{}; + dh3.splitPayloadParts = 2; + dh3.splitPayloadIndex = 2; + o2::framework::DataProcessingHeader dph{0, 1}; + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + fair::mq::MessagePtr payload(transport->CreateMessage()); + auto channelAlloc = o2::pmr::getTransportAllocator(transport.get()); + fair::mq::MessagePtr header1 = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh1, dph}); + fair::mq::MessagePtr header2 = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh2, dph}); + fair::mq::MessagePtr header3 = o2::pmr::getMessage(o2::header::Stack{channelAlloc, dh3, dph}); + std::unique_ptr msg2(nullptr); - ptrs.emplace_back(std::move(msg)); - ptrs.emplace_back(std::move(msg2)); - PartRef ref{std::move(msg), std::move(msg2)}; + std::unique_ptr msg3(nullptr); + PartRef ref{std::move(header1), std::move(msg2)}; o2::framework::MessageSet msgSet; msgSet.add(std::move(ref)); - PartRef ref2{std::move(msg), std::move(msg2)}; + PartRef ref2{std::move(header2), std::move(msg2)}; msgSet.add(std::move(ref2)); std::vector msgs; - msgs.push_back(std::unique_ptr(nullptr)); + msgs.push_back(std::move(header3)); msgs.push_back(std::unique_ptr(nullptr)); msgs.push_back(std::unique_ptr(nullptr)); msgSet.add([&msgs](size_t i) { return std::move(msgs[i]); - }, 3); + }, + 3); REQUIRE(msgSet.messages.size() == 7); - REQUIRE(msgSet.messageMap.size() == 3); - REQUIRE(msgSet.pairMap.size() == 4); - REQUIRE(msgSet.messageMap[0].position == 0); - REQUIRE(msgSet.messageMap[0].size == 1); - REQUIRE(msgSet.messageMap[1].position == 2); - REQUIRE(msgSet.messageMap[1].size == 1); - REQUIRE(msgSet.messageMap[2].position == 4); - REQUIRE(msgSet.messageMap[2].size == 2); - REQUIRE(msgSet.pairMap[0].partIndex == 0); - REQUIRE(msgSet.pairMap[0].payloadIndex == 0); - REQUIRE(msgSet.pairMap[1].partIndex == 1); - REQUIRE(msgSet.pairMap[1].payloadIndex == 0); - REQUIRE(msgSet.pairMap[2].partIndex == 2); - REQUIRE(msgSet.pairMap[2].payloadIndex == 0); - REQUIRE(msgSet.pairMap[3].partIndex == 2); - REQUIRE(msgSet.pairMap[3].payloadIndex == 1); + REQUIRE((msgSet.messages | count_payloads{}) == 4); + REQUIRE((msgSet.messages | get_dataref_indices{0, 0}).headerIdx == 0); + REQUIRE((msgSet.messages | get_dataref_indices{0, 0}).payloadIdx == 1); + REQUIRE((msgSet.messages | get_dataref_indices{1, 0}).headerIdx == 2); + REQUIRE((msgSet.messages | get_dataref_indices{1, 0}).payloadIdx == 3); + REQUIRE((msgSet.messages | get_dataref_indices{2, 0}).headerIdx == 4); + REQUIRE((msgSet.messages | get_dataref_indices{2, 0}).payloadIdx == 5); + REQUIRE((msgSet.messages | get_dataref_indices{2, 1}).headerIdx == 4); + REQUIRE((msgSet.messages | get_dataref_indices{2, 1}).payloadIdx == 6); + REQUIRE((msgSet.messages | get_pair{0}).headerIdx == 0); + REQUIRE((msgSet.messages | get_pair{0}).payloadIdx == 1); + REQUIRE((msgSet.messages | get_pair{1}).headerIdx == 2); + REQUIRE((msgSet.messages | get_pair{1}).payloadIdx == 3); + REQUIRE((msgSet.messages | get_pair{2}).headerIdx == 4); + REQUIRE((msgSet.messages | get_pair{2}).payloadIdx == 5); + REQUIRE((msgSet.messages | get_pair{3}).headerIdx == 4); + REQUIRE((msgSet.messages | get_pair{3}).payloadIdx == 6); }