Commit 6189eeae authored by Thomas Jahns's avatar Thomas Jahns 🤸
Browse files

Merge allocations.

parent c96c095b
......@@ -126,12 +126,6 @@ static const struct xt_redist_vtable redist_collection_vtable = {
.get_MPI_Comm = redist_collection_get_MPI_Comm
};
struct redist_collection_msg {
int rank;
MPI_Datatype *component_dt; // datatypes of the redists (size == num_redists)
};
struct exchanger_cache
{
size_t token;
......@@ -151,66 +145,71 @@ struct Xt_redist_collection_ {
struct exchanger_cache cache;
unsigned ndst, nsrc;
struct redist_collection_msg * send_msgs;
struct redist_collection_msg * recv_msgs;
int *send_ranks, *recv_ranks;
size_t cache_size;
MPI_Comm comm;
int tag_offset;
MPI_Datatype all_component_dt[];
};
static void copy_component_dt(struct redist_collection_msg **msgs,
unsigned *nmsgs,
const Xt_redist *redists, unsigned num_redists,
enum xt_msg_direction direction,
MPI_Datatype (*get_MPI_datatype)(Xt_redist,int))
static unsigned
get_msg_count(const Xt_redist *redists, unsigned num_redists,
enum xt_msg_direction direction,
size_t num_ranks[num_redists],
int *restrict ranks[num_redists])
{
size_t num_ranks[num_redists], rank_pos[num_redists];
int *restrict ranks[num_redists];
bool ranks_left = false;
/* get lists of ranks to send/receive message to/from */
for (size_t j = 0; j < num_redists; ++j) {
num_ranks[j]
size_t nranks = num_ranks[j]
= (size_t)xt_redist_get_msg_ranks(redists[j], direction, ranks + j);
ranks_left |= (nranks > 0);
/* sort list */
xt_sort_int(ranks[j], num_ranks[j]);
ranks_left |= (num_ranks[j] > 0);
rank_pos[j] = 0;
xt_sort_int(ranks[j], nranks);
}
/* count number of different ranks to send/receive message to/from */
size_t num_messages = ranks_left
? xt_ranks_uniq_count(num_redists, num_ranks, (const int *const *)ranks)
: 0;
/* build messages */
struct redist_collection_msg *restrict p = NULL;
if (num_messages) {
MPI_Datatype *restrict dt
= xmalloc(num_messages * num_redists * sizeof (*dt));
p = xmalloc(num_messages * sizeof (*p));
for (size_t i = 0; i < num_messages; ++i) {
return (unsigned)num_messages;
}
static void align_component_dt(unsigned num_redists, unsigned nmsgs,
const Xt_redist *redists,
int *restrict in_ranks[num_redists],
size_t num_ranks[num_redists],
int *out_ranks,
MPI_Datatype *component_dt,
MPI_Datatype (*get_MPI_datatype)(Xt_redist,int))
{
size_t rank_pos[num_redists];
for (size_t j = 0; j < num_redists; ++j)
rank_pos[j] = 0;
if (nmsgs) {
/* find ranks and corresponding component datatypes */
for (size_t i = 0; i < nmsgs; ++i) {
int min_rank = INT_MAX;
for (size_t j = 0; j < num_redists; ++j)
if (rank_pos[j] < num_ranks[j] && ranks[j][rank_pos[j]] < min_rank)
min_rank = ranks[j][rank_pos[j]];
if (rank_pos[j] < num_ranks[j] && in_ranks[j][rank_pos[j]] < min_rank)
min_rank = in_ranks[j][rank_pos[j]];
MPI_Datatype *dts_rank = dt + (size_t)num_redists * i;
for (size_t j = 0; j < num_redists; ++j)
dts_rank[j] =
(rank_pos[j] < num_ranks[j] && ranks[j][rank_pos[j]] == min_rank)
component_dt[i * num_redists + j] =
(rank_pos[j] < num_ranks[j] && in_ranks[j][rank_pos[j]] == min_rank)
? get_MPI_datatype(redists[j], min_rank) : MPI_DATATYPE_NULL;
p[i].rank = min_rank;
p[i].component_dt = dts_rank;
out_ranks[i] = min_rank;
for (size_t j = 0; j < num_redists; ++j)
rank_pos[j]
+= (rank_pos[j] < num_ranks[j] && ranks[j][rank_pos[j]] == min_rank);
+= (rank_pos[j] < num_ranks[j] && in_ranks[j][rank_pos[j]] == min_rank);
}
}
for (size_t j = 0; j < num_redists; ++j)
free(ranks[j]);
*msgs = p;
*nmsgs = (unsigned)num_messages;
free(in_ranks[j]);
}
/* not yet used cache entries are marked with -1 as first displacement,
......@@ -244,16 +243,29 @@ destruct_cache(struct exchanger_cache *cache,
free(cache->src_displacements);
}
Xt_redist xt_redist_collection_new(Xt_redist * redists, int num_redists,
int cache_size, MPI_Comm comm) {
// ensure that yaxt is initialized
assert(xt_initialized());
Xt_redist_collection redist_coll = xmalloc(sizeof (*redist_coll));
redist_coll->vtable = &redist_collection_vtable;
unsigned num_redists_ = num_redists >= 0 ? (unsigned)num_redists : 0;
size_t num_ranks[2][num_redists_];
int *restrict ranks[2][num_redists_];
unsigned nmsg_send = get_msg_count(redists, num_redists_, SEND,
num_ranks[SEND], ranks[SEND]),
nmsg_recv = get_msg_count(redists, num_redists_, RECV,
num_ranks[RECV], ranks[RECV]);
size_t nmsg = (size_t)nmsg_send + nmsg_recv;
size_t size_all_component_dt = sizeof (MPI_Datatype) * num_redists_ * nmsg;
Xt_redist_collection redist_coll
= xmalloc(sizeof (*redist_coll)
+ size_all_component_dt + nmsg * sizeof (int));
redist_coll->ndst = nmsg_recv;
redist_coll->nsrc = nmsg_send;
redist_coll->send_ranks
= (int *)(redist_coll->all_component_dt + nmsg * num_redists_);
redist_coll->recv_ranks = redist_coll->send_ranks + nmsg_send;
redist_coll->vtable = &redist_collection_vtable;
redist_coll->num_redists = num_redists_;
if (cache_size < -1)
Xt_abort(comm, "ERROR: invalid cache size in xt_redist_collection_new",
......@@ -265,12 +277,15 @@ Xt_redist xt_redist_collection_new(Xt_redist * redists, int num_redists,
xt_redist_check_comms(redists, num_redists, comm);
copy_component_dt(&redist_coll->send_msgs, &redist_coll->nsrc, redists,
num_redists_, SEND, xt_redist_get_send_MPI_Datatype);
copy_component_dt(&redist_coll->recv_msgs, &redist_coll->ndst, redists,
num_redists_, RECV, xt_redist_get_recv_MPI_Datatype);
init_cache(&redist_coll->cache, redist_coll->cache_size,
(size_t)redist_coll->nsrc + (size_t)redist_coll->ndst,
MPI_Datatype *all_component_dt = redist_coll->all_component_dt;
align_component_dt(num_redists_, nmsg_send, redists,
ranks[SEND], num_ranks[SEND], redist_coll->send_ranks,
all_component_dt, xt_redist_get_send_MPI_Datatype);
align_component_dt(num_redists_, nmsg_recv, redists,
ranks[RECV], num_ranks[RECV], redist_coll->recv_ranks,
all_component_dt + nmsg_send * num_redists_,
xt_redist_get_recv_MPI_Datatype);
init_cache(&redist_coll->cache, redist_coll->cache_size, nmsg,
num_redists_);
return (Xt_redist)redist_coll;
......@@ -278,11 +293,13 @@ Xt_redist xt_redist_collection_new(Xt_redist * redists, int num_redists,
static void
create_all_dt_for_dir(struct redist_collection_msg *msgs,
unsigned num_messages, unsigned num_redists,
const MPI_Aint displacements[num_redists],
struct Xt_redist_msg redist_msgs[num_messages],
MPI_Comm comm)
create_all_dt_for_dir(
unsigned num_messages, unsigned num_redists,
const int ranks[num_messages],
const MPI_Datatype *component_dt,
const MPI_Aint displacements[num_redists],
struct Xt_redist_msg redist_msgs[num_messages],
MPI_Comm comm)
{
int block_lengths[num_redists];
......@@ -293,8 +310,9 @@ create_all_dt_for_dir(struct redist_collection_msg *msgs,
xt_mpi_call(MPI_Type_free(&(redist_msgs[i].datatype)), comm);
redist_msgs[i].datatype
= xt_create_compound_datatype(num_redists, displacements,
msgs[i].component_dt, block_lengths, comm);
redist_msgs[i].rank = msgs[i].rank;
component_dt + i * num_redists,
block_lengths, comm);
redist_msgs[i].rank = ranks[i];
}
}
......@@ -334,19 +352,23 @@ lookup_cache_index(unsigned num_redists,
}
static Xt_exchanger
get_exchanger(const void *const * src_data, void *const * dst_data,
struct redist_collection_msg * send_msgs, unsigned num_send_messages,
struct redist_collection_msg * recv_msgs, unsigned num_recv_messages,
unsigned num_redists,
struct exchanger_cache *cache, size_t cache_size,
MPI_Comm comm, int tag_offset)
get_exchanger(struct Xt_redist_collection_ *redist_coll,
const void *const * src_data, void *const * dst_data,
unsigned num_redists)
{
MPI_Aint displacements[2][num_redists];
unsigned num_send_messages = redist_coll->nsrc,
num_recv_messages = redist_coll->ndst;
compute_displ(src_data, num_redists, displacements[0]);
compute_displ((const void *const *)dst_data, num_redists, displacements[1]);
Xt_exchanger exchanger;
const MPI_Datatype *all_component_dt = redist_coll->all_component_dt;
struct exchanger_cache *restrict cache = &redist_coll->cache;
size_t cache_size = redist_coll->cache_size;
MPI_Comm comm = redist_coll->comm;
int tag_offset = redist_coll->tag_offset;
if (cache_size > 0)
{
size_t cache_index
......@@ -358,9 +380,12 @@ get_exchanger(const void *const * src_data, void *const * dst_data,
if (cache_index == cache_size)
{
cache_index = cache->token;
create_all_dt_for_dir(send_msgs, num_send_messages, num_redists,
create_all_dt_for_dir(num_send_messages, num_redists,
redist_coll->send_ranks, all_component_dt,
displacements[0], cache->msgs, comm);
create_all_dt_for_dir(recv_msgs, num_recv_messages, num_redists,
create_all_dt_for_dir(num_recv_messages, num_redists,
redist_coll->recv_ranks,
all_component_dt + num_send_messages * num_redists,
displacements[1], cache->msgs +
(size_t)num_send_messages, comm);
memcpy(cache->src_displacements + cache_index * num_redists,
......@@ -390,9 +415,12 @@ get_exchanger(const void *const * src_data, void *const * dst_data,
for (size_t i = 0; i < nmsg; ++i)
p[i].datatype = MPI_DATATYPE_NULL;
create_all_dt_for_dir(send_msgs, num_send_messages, num_redists,
displacements[0], p, comm);
create_all_dt_for_dir(recv_msgs, num_recv_messages, num_redists,
create_all_dt_for_dir(num_send_messages, num_redists,
redist_coll->send_ranks,
all_component_dt, displacements[0], p, comm);
create_all_dt_for_dir(num_recv_messages, num_redists,
redist_coll->recv_ranks,
all_component_dt + num_send_messages * num_redists,
displacements[1], p + num_send_messages, comm);
exchanger =
......@@ -424,16 +452,9 @@ redist_collection_s_exchange(Xt_redist redist, int num_arrays,
"redist_collection_s_exchange", __FILE__, __LINE__);
Xt_exchanger exchanger = get_exchanger(src_data, dst_data,
redist_coll->send_msgs,
redist_coll->nsrc,
redist_coll->recv_msgs,
redist_coll->ndst,
redist_coll->num_redists,
&(redist_coll->cache),
redist_coll->cache_size,
redist_coll->comm,
redist_coll->tag_offset);
Xt_exchanger exchanger = get_exchanger(redist_coll,
src_data, dst_data,
redist_coll->num_redists);
xt_exchanger_s_exchange(exchanger, src_data[0], dst_data[0]);
......@@ -453,16 +474,9 @@ redist_collection_a_exchange(Xt_redist redist, int num_arrays,
"redist_collection_a_exchange", __FILE__, __LINE__);
Xt_exchanger exchanger = get_exchanger(src_data, dst_data,
redist_coll->send_msgs,
redist_coll->nsrc,
redist_coll->recv_msgs,
redist_coll->ndst,
redist_coll->num_redists,
&(redist_coll->cache),
redist_coll->cache_size,
redist_coll->comm,
redist_coll->tag_offset);
Xt_exchanger exchanger = get_exchanger(redist_coll,
src_data, dst_data,
redist_coll->num_redists);
xt_exchanger_a_exchange(exchanger, src_data[0], dst_data[0], request);
......@@ -472,69 +486,61 @@ redist_collection_a_exchange(Xt_redist redist, int num_arrays,
}
static void
copy_msgs(size_t num_redists, unsigned nmsgs,
const struct redist_collection_msg *restrict msgs_orig,
struct redist_collection_msg **p_msgs_copy,
MPI_Comm comm)
copy_component_dt(size_t num_component_dt,
const MPI_Datatype *component_dt_orig,
MPI_Datatype *component_dt_copy,
MPI_Comm comm)
{
struct redist_collection_msg *restrict msgs_copy =
*p_msgs_copy = nmsgs > 0 ? xmalloc(nmsgs * sizeof (*msgs_copy)) : NULL;
MPI_Datatype *restrict dt_copy
= nmsgs * num_redists > 0
? xmalloc(nmsgs * num_redists * sizeof (*dt_copy)) : NULL;
for (size_t i = 0; i < nmsgs; ++i)
for (size_t i = 0; i < num_component_dt; ++i)
{
msgs_copy[i].rank = msgs_orig[i].rank;
msgs_copy[i].component_dt = dt_copy + i * num_redists;
for (size_t j = 0; j < num_redists; ++j)
if (msgs_orig[i].component_dt[j] != MPI_DATATYPE_NULL)
xt_mpi_call(MPI_Type_dup(msgs_orig[i].component_dt[j],
dt_copy + i * num_redists + j), comm);
else
dt_copy[i * num_redists + j] = MPI_DATATYPE_NULL;
MPI_Datatype orig_dt = component_dt_orig[i];
if (orig_dt != MPI_DATATYPE_NULL)
xt_mpi_call(MPI_Type_dup(orig_dt, component_dt_copy + i), comm);
else
component_dt_copy[i] = MPI_DATATYPE_NULL;
}
}
static Xt_redist
redist_collection_copy(Xt_redist redist)
{
Xt_redist_collection redist_coll = xrc(redist),
redist_copy = xmalloc(sizeof (*redist_copy));
Xt_redist_collection redist_coll = xrc(redist);
unsigned num_redists = redist_coll->num_redists,
nsrc = redist_coll->nsrc,
ndst = redist_coll->ndst;
size_t nmsg = (size_t)ndst + nsrc,
size_all_component_dt = sizeof (MPI_Datatype) * num_redists * nmsg;
Xt_redist_collection redist_copy
= xmalloc(sizeof (*redist_copy)
+ size_all_component_dt + nmsg * sizeof (int));
redist_copy->vtable = redist_coll->vtable;
unsigned num_redists = redist_coll->num_redists;
redist_copy->num_redists = num_redists;
redist_copy->nsrc = nsrc;
redist_copy->ndst = ndst;
redist_copy->send_ranks
= (int *)(redist_copy->all_component_dt + nmsg * num_redists);
redist_copy->recv_ranks = redist_copy->send_ranks + nsrc;
MPI_Comm copy_comm = redist_copy->comm
= xt_mpi_comm_smart_dup(redist_coll->comm, &redist_copy->tag_offset);
unsigned nsrc = redist_coll->nsrc;
redist_copy->nsrc = nsrc;
copy_msgs(num_redists, nsrc, redist_coll->send_msgs, &redist_copy->send_msgs,
copy_comm);
unsigned ndst = redist_coll->ndst;
redist_copy->ndst = ndst;
copy_msgs(num_redists, ndst, redist_coll->recv_msgs, &redist_copy->recv_msgs,
copy_comm);
memcpy(redist_copy->send_ranks, redist_coll->send_ranks,
sizeof (*redist_copy->send_ranks) * nmsg);
copy_component_dt(num_redists * nmsg,
redist_coll->all_component_dt,
redist_copy->all_component_dt, copy_comm);
size_t cache_size = redist_coll->cache_size;
redist_copy->cache_size = cache_size;
init_cache(&redist_copy->cache, cache_size, (size_t)ndst + nsrc, num_redists);
init_cache(&redist_copy->cache, cache_size, nmsg, num_redists);
return (Xt_redist)redist_copy;
}
static void
free_redist_collection_msgs(struct redist_collection_msg * msgs,
unsigned nmsgs, unsigned num_redists,
MPI_Comm comm) {
if (nmsgs) {
size_t ndt = (size_t)nmsgs * num_redists;
MPI_Datatype *all_component_dt = msgs[0].component_dt;
for (size_t i = 0; i < ndt; ++i)
if (all_component_dt[i] != MPI_DATATYPE_NULL)
xt_mpi_call(MPI_Type_free(all_component_dt + i), comm);
free(msgs[0].component_dt);
}
free(msgs);
free_component_dt(size_t num_dt, MPI_Datatype *all_component_dt, MPI_Comm comm)
{
for (size_t i = 0; i < num_dt; ++i)
if (all_component_dt[i] != MPI_DATATYPE_NULL)
xt_mpi_call(MPI_Type_free(all_component_dt + i), comm);
}
static void
......@@ -542,17 +548,13 @@ redist_collection_delete(Xt_redist redist) {
Xt_redist_collection redist_coll = xrc(redist);
free_redist_collection_msgs(redist_coll->send_msgs, redist_coll->nsrc,
redist_coll->num_redists,
redist_coll->comm);
free_redist_collection_msgs(redist_coll->recv_msgs, redist_coll->ndst,
redist_coll->num_redists,
redist_coll->comm);
unsigned num_redists = redist_coll->num_redists;
size_t nmsg = (size_t)redist_coll->ndst + redist_coll->nsrc;
free_component_dt(nmsg * num_redists, redist_coll->all_component_dt,
redist_coll->comm);
destruct_cache(&redist_coll->cache, redist_coll->cache_size,
(size_t)redist_coll->nsrc + (size_t)redist_coll->ndst,
redist_coll->comm);
nmsg, redist_coll->comm);
xt_mpi_comm_smart_dedup(&(redist_coll->comm), redist_coll->tag_offset);
......@@ -627,17 +629,16 @@ redist_collection_get_msg_ranks(Xt_redist redist,
{
Xt_redist_collection redist_coll = xrc(redist);
unsigned nmsg;
struct redist_collection_msg *restrict msg;
int *ranks_orig;
if (direction == SEND) {
nmsg = redist_coll->ndst;
msg = redist_coll->send_msgs;
ranks_orig = redist_coll->send_ranks;
} else {
nmsg = redist_coll->nsrc;
msg = redist_coll->recv_msgs;
ranks_orig = redist_coll->recv_ranks;
}
int *restrict ranks_ = *ranks = xmalloc(nmsg * sizeof (*ranks_));
for (size_t i = 0; i < nmsg; ++i)
ranks_[i] = msg[i].rank;
int *ranks_ = *ranks = xmalloc(nmsg * sizeof (*ranks_));
memcpy(ranks_, ranks_orig, nmsg * sizeof (*ranks));
return (int)nmsg;
}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment