Commit f3b4431f authored by Thomas Jahns
Browse files

Use already existing enum to simplify internal interface.

parent 94c5aace
......@@ -95,22 +95,28 @@ void xt_redist_a_exchange1(Xt_redist redist, const void *src_data,
/* Number of messages this process sends in the redistribution.
 * Delegates through the type-specific vtable; uses the generic
 * get_num_msg entry point with the SEND direction.
 * (Removed the stale pre-refactor call to the deleted
 * get_num_send_msg vtable member, which was unreachable diff residue.) */
int xt_redist_get_num_send_msg(Xt_redist redist) {
  return redist->vtable->get_num_msg(redist, SEND);
}
/* Number of messages this process receives in the redistribution.
 * Delegates through the type-specific vtable; uses the generic
 * get_num_msg entry point with the RECV direction.
 * (Removed the stale pre-refactor call to the deleted
 * get_num_recv_msg vtable member, which was unreachable diff residue.) */
int xt_redist_get_num_recv_msg(Xt_redist redist) {
  return redist->vtable->get_num_msg(redist, RECV);
}
/* MPI datatype describing the data sent to the given rank.
 * Delegates through the type-specific vtable via the direction-generic
 * get_msg_MPI_Datatype entry point with SEND.
 * (Removed the stale pre-refactor call to the deleted
 * get_send_MPI_Datatype vtable member, which was unreachable diff residue.) */
MPI_Datatype xt_redist_get_send_MPI_Datatype(Xt_redist redist, int rank) {
  return redist->vtable->get_msg_MPI_Datatype(redist, rank, SEND);
}
/* MPI datatype describing the data received from the given rank.
 * Delegates through the type-specific vtable via the direction-generic
 * get_msg_MPI_Datatype entry point with RECV.
 * (Removed the stale pre-refactor call to the deleted
 * get_recv_MPI_Datatype vtable member, which was unreachable diff residue.) */
MPI_Datatype xt_redist_get_recv_MPI_Datatype(Xt_redist redist, int rank) {
  return redist->vtable->get_msg_MPI_Datatype(redist, rank, RECV);
}
/* Direction-generic accessor for the MPI datatype describing the data
 * exchanged with rank in the requested direction (SEND or RECV).
 * Thin dispatch through the redist's vtable. */
MPI_Datatype xt_redist_get_MPI_Datatype(Xt_redist redist, int rank,
                                        enum xt_msg_direction direction)
{
  const struct xt_redist_vtable *vtab = redist->vtable;
  return vtab->get_msg_MPI_Datatype(redist, rank, direction);
}
MPI_Comm xt_redist_get_MPI_Comm(Xt_redist redist) {
......
......@@ -93,15 +93,12 @@ redist_collection_a_exchange1(Xt_redist redist,
const void *src_data, void *dst_data,
Xt_request *request);
static int redist_collection_get_num_send_msg(Xt_redist redist);
static int redist_collection_get_num_recv_msg(Xt_redist redist);
static MPI_Datatype
redist_collection_get_send_MPI_Datatype(Xt_redist redist, int rank);
static int redist_collection_get_num_msg(Xt_redist redist,
enum xt_msg_direction direction);
static MPI_Datatype
redist_collection_get_recv_MPI_Datatype(Xt_redist redist, int rank);
redist_collection_get_MPI_Datatype(Xt_redist redist, int rank,
enum xt_msg_direction direction);
static int
redist_collection_get_msg_ranks(Xt_redist redist,
......@@ -118,10 +115,8 @@ static const struct xt_redist_vtable redist_collection_vtable = {
.a_exchange = redist_collection_a_exchange,
.s_exchange1 = redist_collection_s_exchange1,
.a_exchange1 = redist_collection_a_exchange1,
.get_num_send_msg = redist_collection_get_num_send_msg,
.get_num_recv_msg = redist_collection_get_num_recv_msg,
.get_send_MPI_Datatype = redist_collection_get_send_MPI_Datatype,
.get_recv_MPI_Datatype = redist_collection_get_recv_MPI_Datatype,
.get_num_msg = redist_collection_get_num_msg,
.get_msg_MPI_Datatype = redist_collection_get_MPI_Datatype,
.get_msg_ranks = redist_collection_get_msg_ranks,
.get_MPI_Comm = redist_collection_get_MPI_Comm
};
......@@ -184,7 +179,7 @@ static void align_component_dt(unsigned num_redists, unsigned nmsgs,
size_t num_ranks[num_redists],
int *out_ranks,
MPI_Datatype *component_dt,
MPI_Datatype (*get_MPI_datatype)(Xt_redist,int))
enum xt_msg_direction direction)
{
size_t rank_pos[num_redists];
for (size_t j = 0; j < num_redists; ++j)
......@@ -200,7 +195,8 @@ static void align_component_dt(unsigned num_redists, unsigned nmsgs,
for (size_t j = 0; j < num_redists; ++j)
component_dt[i * num_redists + j] =
(rank_pos[j] < num_ranks[j] && in_ranks[j][rank_pos[j]] == min_rank)
? get_MPI_datatype(redists[j], min_rank) : MPI_DATATYPE_NULL;
? xt_redist_get_MPI_Datatype(redists[j], min_rank, direction)
: MPI_DATATYPE_NULL;
out_ranks[i] = min_rank;
for (size_t j = 0; j < num_redists; ++j)
......@@ -280,11 +276,10 @@ Xt_redist xt_redist_collection_new(Xt_redist * redists, int num_redists,
MPI_Datatype *all_component_dt = redist_coll->all_component_dt;
align_component_dt(num_redists_, nmsg_send, redists,
ranks[SEND], num_ranks[SEND], redist_coll->send_ranks,
all_component_dt, xt_redist_get_send_MPI_Datatype);
all_component_dt, SEND);
align_component_dt(num_redists_, nmsg_recv, redists,
ranks[RECV], num_ranks[RECV], redist_coll->recv_ranks,
all_component_dt + nmsg_send * num_redists_,
xt_redist_get_recv_MPI_Datatype);
all_component_dt + nmsg_send * num_redists_, RECV);
init_cache(&redist_coll->cache, redist_coll->cache_size, nmsg,
num_redists_);
......@@ -561,34 +556,19 @@ redist_collection_delete(Xt_redist redist) {
free(redist_coll);
}
/* Count of messages this collection sends; reads the cached per-direction
 * message count for SEND. */
static int redist_collection_get_num_send_msg(Xt_redist redist)
{
  Xt_redist_collection redist_coll = xrc(redist);
  return (int)redist_coll->nmsg[SEND];
}
/* Count of messages this collection receives; reads the cached
 * per-direction message count for RECV. */
static int redist_collection_get_num_recv_msg(Xt_redist redist)
{
  Xt_redist_collection redist_coll = xrc(redist);
  return (int)redist_coll->nmsg[RECV];
}
/* Number of messages of this collection for the given direction
 * (SEND or RECV), read from the cached nmsg array indexed by direction.
 * (This span interleaved the removed get_send_MPI_Datatype abort stub
 * with the added get_num_msg; kept only the committed final function.) */
static int redist_collection_get_num_msg(Xt_redist redist,
                                         enum xt_msg_direction direction)
{
  return (int)(xrc(redist)->nmsg[direction]);
}
static MPI_Datatype
redist_collection_get_recv_MPI_Datatype(Xt_redist redist, int XT_UNUSED(rank)) {
redist_collection_get_MPI_Datatype(Xt_redist redist, int XT_UNUSED(rank),
enum xt_msg_direction XT_UNUSED(direction))
{
Xt_redist_collection redist_coll = xrc(redist);
Xt_abort(redist_coll->comm, "ERROR: get_recv_MPI_Datatype is not"
Xt_abort(redist_coll->comm, "ERROR: datatype retrieval is not"
" supported for this xt_redist type (Xt_redist_collection)",
__FILE__, __LINE__);
......
......@@ -70,9 +70,8 @@ static size_t
generate_msg_infos(struct Xt_redist_msg ** msgs,
const MPI_Aint *displacements, Xt_redist *redists,
size_t num_redists, MPI_Comm comm,
enum xt_msg_direction direction,
MPI_Datatype (*get_MPI_datatype)(Xt_redist,int)) {
enum xt_msg_direction direction)
{
int block_lengths[num_redists];
MPI_Datatype datatypes[num_redists];
......@@ -106,7 +105,8 @@ generate_msg_infos(struct Xt_redist_msg ** msgs,
for (size_t j = 0; j < num_redists; ++j)
datatypes[j] =
(rank_pos[j] < num_ranks[j] && ranks[j][rank_pos[j]] == min_rank)
? get_MPI_datatype(redists[j], min_rank) : MPI_DATATYPE_NULL;
? xt_redist_get_MPI_Datatype(redists[j], min_rank, direction)
: MPI_DATATYPE_NULL;
p[i].rank = min_rank;
p[i].datatype
......@@ -144,11 +144,11 @@ xt_redist_collection_static_new(Xt_redist * redists, int num_redists,
size_t num_redists_ = num_redists >= 0 ? (size_t)num_redists : 0;
size_t nsend
= generate_msg_infos(&send_msgs, src_displacements, redists, num_redists_,
new_comm, SEND, xt_redist_get_send_MPI_Datatype);
new_comm, SEND);
size_t nrecv
= generate_msg_infos(&recv_msgs, dst_displacements, redists, num_redists_,
new_comm, RECV, xt_redist_get_recv_MPI_Datatype);
new_comm, RECV);
Xt_redist redist_collection =
xt_redist_single_array_base_new((int)nsend, (int)nrecv,
......
......@@ -70,10 +70,8 @@ struct xt_redist_vtable {
void (*a_exchange)(Xt_redist, int, const void **, void **, Xt_request *);
void (*s_exchange1)(Xt_redist, const void *, void *);
void (*a_exchange1)(Xt_redist, const void *, void *, Xt_request *);
MPI_Datatype (*get_send_MPI_Datatype)(Xt_redist, int);
MPI_Datatype (*get_recv_MPI_Datatype)(Xt_redist, int);
int (*get_num_send_msg)(Xt_redist);
int (*get_num_recv_msg)(Xt_redist);
MPI_Datatype (*get_msg_MPI_Datatype)(Xt_redist, int, enum xt_msg_direction);
int (*get_num_msg)(Xt_redist, enum xt_msg_direction);
int (*get_msg_ranks)(Xt_redist, enum xt_msg_direction, int *restrict *);
MPI_Comm (*get_MPI_Comm)(Xt_redist);
};
......@@ -105,6 +103,9 @@ void xt_redist_check_comms(Xt_redist *redists, int num_redists, MPI_Comm comm);
int xt_redist_get_msg_ranks(Xt_redist redist, enum xt_msg_direction direction,
int *restrict *ranks);
MPI_Datatype xt_redist_get_MPI_Datatype(Xt_redist redist, int rank,
enum xt_msg_direction direction);
size_t
xt_ranks_uniq_count(size_t num_rank_sets,
const size_t num_ranks[num_rank_sets],
......
......@@ -67,9 +67,8 @@ static void
generate_msg_infos(struct Xt_redist_msg **msgs, int *nmsgs,
MPI_Aint extent, const int *displacements, Xt_redist redist,
int num_repetitions, MPI_Comm comm,
enum xt_msg_direction direction,
MPI_Datatype (*get_MPI_datatype)(Xt_redist,int)) {
enum xt_msg_direction direction)
{
assert(*nmsgs >= 0);
size_t num_messages = (size_t)*nmsgs;
int *restrict ranks = NULL;
......@@ -78,7 +77,8 @@ generate_msg_infos(struct Xt_redist_msg **msgs, int *nmsgs,
struct Xt_redist_msg *restrict p
= xrealloc(*msgs, sizeof (*p) * (num_messages + num_ranks));
for (size_t i = 0; i < num_ranks; ++i) {
MPI_Datatype datatype = get_MPI_datatype(redist, ranks[i]);
MPI_Datatype datatype
= xt_redist_get_MPI_Datatype(redist, ranks[i], direction);
MPI_Aint curr_lb, curr_extent;
MPI_Datatype datatype_with_extent;
......@@ -118,12 +118,10 @@ Xt_redist xt_redist_repeat_asym_new(Xt_redist redist, MPI_Aint src_extent,
generate_msg_infos(&send_msgs, &nsend, src_extent,
src_displacements, redist, num_repetitions, comm,
SEND, xt_redist_get_send_MPI_Datatype);
src_displacements, redist, num_repetitions, comm, SEND);
generate_msg_infos(&recv_msgs, &nrecv, dst_extent,
dst_displacements, redist, num_repetitions, comm,
RECV, xt_redist_get_recv_MPI_Datatype);
dst_displacements, redist, num_repetitions, comm, RECV);
Xt_redist result
= xt_redist_single_array_base_new(nsend, nrecv, send_msgs, recv_msgs, comm);
......
......@@ -85,15 +85,12 @@ static void
redist_sab_a_exchange1(Xt_redist redist, const void *src_data, void *dst_data,
Xt_request *request);
static int redist_sab_get_num_send_msg(Xt_redist redist);
static int redist_sab_get_num_recv_msg(Xt_redist redist);
static MPI_Datatype
redist_sab_get_send_MPI_Datatype(Xt_redist redist, int rank);
static int redist_sab_get_num_msg(Xt_redist redist,
enum xt_msg_direction direction);
static MPI_Datatype
redist_sab_get_recv_MPI_Datatype(Xt_redist redist, int rank);
redist_sab_get_MPI_Datatype(Xt_redist redist, int rank,
enum xt_msg_direction direction);
static int
redist_sab_get_msg_ranks(Xt_redist redist,
......@@ -110,10 +107,8 @@ static const struct xt_redist_vtable redist_sab_vtable = {
.a_exchange = redist_sab_a_exchange,
.s_exchange1 = redist_sab_s_exchange1,
.a_exchange1 = redist_sab_a_exchange1,
.get_num_send_msg = redist_sab_get_num_send_msg,
.get_num_recv_msg = redist_sab_get_num_recv_msg,
.get_send_MPI_Datatype = redist_sab_get_send_MPI_Datatype,
.get_recv_MPI_Datatype = redist_sab_get_recv_MPI_Datatype,
.get_num_msg = redist_sab_get_num_msg,
.get_msg_MPI_Datatype = redist_sab_get_MPI_Datatype,
.get_msg_ranks = redist_sab_get_msg_ranks,
.get_MPI_Comm = redist_sab_get_MPI_Comm
};
......@@ -126,7 +121,7 @@ struct Xt_redist_sab_ {
Xt_exchanger exchanger;
int nsend, nrecv;
int nmsg[2];
MPI_Comm comm;
int tag_offset;
......@@ -148,8 +143,8 @@ Xt_redist xt_redist_single_array_base_new(int nsend, int nrecv,
redist->tag_offset);
redist->vtable = &redist_sab_vtable;
redist->nsend = nsend;
redist->nrecv = nrecv;
redist->nmsg[SEND] = nsend;
redist->nmsg[RECV] = nrecv;
return (Xt_redist)redist;
}
......@@ -165,8 +160,8 @@ redist_sab_copy(Xt_redist redist)
Xt_redist_sab redist_sab = xrsab(redist);
Xt_redist_sab redist_sab_new = xmalloc(sizeof *redist_sab_new);
redist_sab_new->vtable = redist_sab->vtable;
redist_sab_new->nsend = redist_sab->nsend;
redist_sab_new->nrecv = redist_sab->nrecv;
for (size_t i = 0; i < 2; ++i)
redist_sab_new->nmsg[i] = redist_sab->nmsg[i];
redist_sab_new->comm = xt_mpi_comm_smart_dup(redist_sab->comm,
&redist_sab_new->tag_offset);
redist_sab_new->exchanger
......@@ -231,26 +226,18 @@ redist_sab_a_exchange1(Xt_redist redist, const void *src_data, void *dst_data,
xt_exchanger_a_exchange(redist_sab->exchanger, src_data, dst_data, request);
}
/* Number of messages sent by this single-array-base redist. */
static int redist_sab_get_num_send_msg(Xt_redist redist)
{
  Xt_redist_sab redist_sab = xrsab(redist);
  return redist_sab->nsend;
}
/* Number of messages received by this single-array-base redist. */
static int redist_sab_get_num_recv_msg(Xt_redist redist)
{
  Xt_redist_sab redist_sab = xrsab(redist);
  return redist_sab->nrecv;
}
/* Message count for the requested direction, read from the per-direction
 * nmsg array of the single-array-base redist.
 * (This span interleaved the removed redist_sab_get_send_MPI_Datatype
 * with the added get_num_msg; kept only the committed final function.) */
static int redist_sab_get_num_msg(Xt_redist redist,
                                  enum xt_msg_direction direction)
{
  return xrsab(redist)->nmsg[direction];
}
/* MPI datatype for the message exchanged with rank in the given
 * direction; forwarded to the underlying exchanger.
 * (This span interleaved the removed redist_sab_get_recv_MPI_Datatype
 * with the added direction-generic variant; kept only the committed
 * final function.) */
static MPI_Datatype
redist_sab_get_MPI_Datatype(Xt_redist redist, int rank,
                            enum xt_msg_direction direction)
{
  return xt_exchanger_get_MPI_Datatype(xrsab(redist)->exchanger, rank,
                                       direction);
}
static int
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment