Program Listing for File mpi_utils.h
↰ Return to documentation for file (framework/mpi/mpi_utils.h)
// SPDX-FileCopyrightText: 2024 The OpenSn Authors <https://open-sn.github.io/opensn/>
// SPDX-License-Identifier: MIT
#pragma once
#include "framework/runtime.h"
#include <map>
#include <vector>
#include <type_traits>
#include <set>
namespace opensn
{
/**
 * Given each location's local size (number of items), builds a vector of dimension
 * `comm.size() + 1` describing where each location's global indices start and end. Example:
 * location i starts at extents[i] and ends at extents[i+1].
 *
 * \param local_size Number of items owned by the calling location.
 * \param comm Communicator whose locations (ranks) the extents are built for.
 * \return The `comm.size() + 1` location extents.
 *
 * NOTE(review): the result depends on every location's size, so this is presumably a collective
 * call that all ranks of `comm` must make — confirm against the implementation.
 */
std::vector<uint64_t> BuildLocationExtents(uint64_t local_size, const mpi::Communicator& comm);

/**
 * Exchanges per-destination lists of values among all locations of a communicator.
 *
 * Given a map whose keys are destination process ids and whose values are the items of type T
 * to send to each destination, returns a map whose keys are the source process ids that sent
 * this location at least one item and whose values are the items received from each source, in
 * the order they were sent. Sources that sent nothing do not appear in the result.
 *
 * \tparam K Integral key type; keys must be ranks of `comm`, i.e. lie in `[0, comm.size())`.
 * \tparam T Element type; must be transmittable by `comm.all_to_all` and default-constructible.
 * \param pid_data_pairs Destination process id -> items to send to that process.
 * \param comm Communicator over which the exchange takes place (defaults to `opensn::mpi_comm`).
 * \return Source process id -> items received from that process.
 * \throws std::out_of_range if a key of `pid_data_pairs` is not a rank of `comm`.
 *
 * Every location of `comm` must call this function (it performs two all-to-all exchanges).
 * MPI counts and displacements are `int`, so the total number of items sent or received by one
 * location must fit in an `int`.
 */
template <typename K, class T>
std::map<K, std::vector<T>>
MapAllToAll(const std::map<K, std::vector<T>>& pid_data_pairs,
            const mpi::Communicator& comm = opensn::mpi_comm)
{
  static_assert(std::is_integral<K>::value, "Integral datatype required.");

  // All per-rank arrays are sized from `comm` itself (not the global communicator) so that the
  // exchange is also correct on sub-communicators.
  const auto num_ranks = static_cast<size_t>(comm.size());

  // Pack the outgoing items into one contiguous buffer. std::map iterates in key order, so each
  // destination's displacement is simply the buffer size before its items are appended.
  std::vector<int> sendcounts(num_ranks, 0);
  std::vector<int> senddispls(num_ranks, 0);
  std::vector<T> sendbuf;
  for (const auto& [pid, data] : pid_data_pairs)
  {
    // A negative key converts to a huge size_t, so .at() rejects it just like a key >= num_ranks
    // (instead of writing out of bounds).
    const auto rank = static_cast<size_t>(pid);
    sendcounts.at(rank) = static_cast<int>(data.size());
    senddispls.at(rank) = static_cast<int>(sendbuf.size());
    sendbuf.insert(sendbuf.end(), data.begin(), data.end());
  }

  // Exchange the counts so every location learns how many items each source will send it.
  std::vector<int> recvcounts(num_ranks, 0);
  comm.all_to_all(sendcounts, recvcounts);

  // Receive displacements are the exclusive prefix sum of the receive counts; the final running
  // total is the size of the receive buffer.
  std::vector<int> recvdispls(num_ranks, 0);
  int total_recv_count = 0;
  for (size_t rank = 0; rank < num_ranks; ++rank)
  {
    recvdispls[rank] = total_recv_count;
    total_recv_count += recvcounts[rank];
  }

  // Exchange the payload.
  std::vector<T> recvbuf(static_cast<size_t>(total_recv_count));
  comm.all_to_all(sendbuf, sendcounts, senddispls, recvbuf, recvcounts, recvdispls);

  // Unpack: one entry per source that actually sent something.
  std::map<K, std::vector<T>> output_data;
  for (size_t rank = 0; rank < num_ranks; ++rank)
  {
    if (recvcounts[rank] <= 0)
      continue;
    const auto first = recvbuf.begin() + recvdispls[rank];
    output_data[static_cast<K>(rank)].assign(first, first + recvcounts[rank]);
  }
  return output_data;
}
} // namespace opensn