Commit 7a799ec1 authored by Thomas Jahns's avatar Thomas Jahns 🤸
Browse files

Extend internal interface to save on unneeded dups.

parent 373d7059
......@@ -131,9 +131,9 @@ xt_exchanger_get_msg_ranks(Xt_exchanger exchanger,
}
MPI_Datatype
xt_exchanger_get_MPI_Datatype(Xt_exchanger exchanger,
int rank, enum xt_msg_direction direction) {
return exchanger->vtable->get_MPI_Datatype(exchanger, rank, direction);
xt_exchanger_get_MPI_Datatype(Xt_exchanger exchanger, int rank,
enum xt_msg_direction direction, bool do_dup) {
return exchanger->vtable->get_MPI_Datatype(exchanger, rank, direction, do_dup);
}
/*
......
......@@ -51,6 +51,7 @@
#include <config.h>
#endif
#include <stdbool.h>
#include <mpi.h>
#include "xt/xt_core.h"
......@@ -68,7 +69,8 @@ struct xt_exchanger_vtable {
void (*s_exchange)(Xt_exchanger, const void *, void *);
void (*a_exchange)(Xt_exchanger, const void *, void *, Xt_request *request);
int (*get_msg_ranks)(Xt_exchanger, enum xt_msg_direction, int *restrict *);
MPI_Datatype (*get_MPI_Datatype)(Xt_exchanger, int, enum xt_msg_direction);
MPI_Datatype (*get_MPI_Datatype)(Xt_exchanger, int, enum xt_msg_direction,
bool);
};
struct Xt_exchanger_ {
......@@ -127,13 +129,15 @@ void xt_exchanger_internal_optimize(size_t n, void * msgs, size_t msg_type_size,
* @param[in] rank MPI rank
* @param[in] direction specific whether the datatype of an incoming or outgoing
* message is requested
* @param[in[ do_dup mpi datatype copy will be dup if true
* @return MPI_Datatype for the specificed message
* @remark returns MPI_DATATYPE_NULL if there is no message matching the
* specificed configuration
*/
MPI_Datatype
xt_exchanger_get_MPI_Datatype(Xt_exchanger exchanger,
int rank, enum xt_msg_direction direction);
int rank, enum xt_msg_direction direction,
bool do_dup);
/**
* Gets the ranks of all processes that receive data from/send data to the local
......
......@@ -77,7 +77,8 @@ xt_exchanger_mix_isend_irecv_get_msg_ranks(Xt_exchanger exchanger,
static MPI_Datatype
xt_exchanger_mix_isend_irecv_get_MPI_Datatype(Xt_exchanger exchanger,
int rank,
enum xt_msg_direction direction);
enum xt_msg_direction direction,
bool do_dup);
static const struct xt_exchanger_vtable exchanger_mix_isend_irecv_vtable = {
.copy = xt_exchanger_mix_isend_irecv_copy,
......@@ -292,7 +293,8 @@ xt_exchanger_mix_isend_irecv_get_msg_ranks(Xt_exchanger exchanger,
static MPI_Datatype
xt_exchanger_mix_isend_irecv_get_MPI_Datatype(Xt_exchanger exchanger,
int rank,
enum xt_msg_direction direction)
enum xt_msg_direction direction,
bool do_dup)
{
Xt_exchanger_mix_isend_irecv exchanger_msr =
(Xt_exchanger_mix_isend_irecv)exchanger;
......@@ -301,8 +303,11 @@ xt_exchanger_mix_isend_irecv_get_MPI_Datatype(Xt_exchanger exchanger,
MPI_Datatype datatype_copy = MPI_DATATYPE_NULL;
for (size_t i = 0; i < nmsg; ++i)
if (MSG_DIR(msgs[i]) == direction && msgs[i].data.rank == rank) {
xt_mpi_call(MPI_Type_dup(msgs[i].data.datatype, &datatype_copy),
exchanger_msr->comm);
if (do_dup)
xt_mpi_call(MPI_Type_dup(msgs[i].data.datatype, &datatype_copy),
exchanger_msr->comm);
else
datatype_copy = msgs[i].data.datatype;
break;
}
return datatype_copy;
......
......@@ -85,7 +85,8 @@ xt_exchanger_neigh_alltoall_get_msg_ranks(Xt_exchanger exchanger,
static MPI_Datatype
xt_exchanger_neigh_alltoall_get_MPI_Datatype(Xt_exchanger exchanger,
int rank,
enum xt_msg_direction direction);
enum xt_msg_direction direction,
bool do_dup);
static const struct xt_exchanger_vtable exchanger_neigh_alltoall_vtable = {
......@@ -271,7 +272,8 @@ static void xt_exchanger_neigh_alltoall_a_exchange(Xt_exchanger exchanger,
static MPI_Datatype
xt_exchanger_neigh_alltoall_get_MPI_Datatype(Xt_exchanger exchanger,
int rank,
enum xt_msg_direction direction)
enum xt_msg_direction direction,
bool do_dup)
{
Xt_exchanger_neigh_alltoall exchanger_na =
(Xt_exchanger_neigh_alltoall)exchanger;
......@@ -282,8 +284,11 @@ xt_exchanger_neigh_alltoall_get_MPI_Datatype(Xt_exchanger exchanger,
MPI_Datatype datatype_copy = MPI_DATATYPE_NULL;
for (size_t i = 0; i < nmsg; ++i) {
if (ranks[i] == rank) {
xt_mpi_call(MPI_Type_dup(exchanger_na->datatypes[i+ofs], &datatype_copy),
exchanger_na->comm);
if (do_dup)
xt_mpi_call(MPI_Type_dup(exchanger_na->datatypes[i+ofs], &datatype_copy),
exchanger_na->comm);
else
datatype_copy = exchanger_na->datatypes[i+ofs];
break;
}
}
......
......@@ -49,6 +49,7 @@
#include <assert.h>
#include <mpi.h>
#include <stdbool.h>
#include "core/core.h"
#include "core/ppm_xfuncs.h"
......@@ -77,7 +78,8 @@ xt_exchanger_simple_base_get_msg_ranks(Xt_exchanger exchanger,
static MPI_Datatype
xt_exchanger_simple_base_get_MPI_Datatype(Xt_exchanger exchanger,
int rank,
enum xt_msg_direction direction);
enum xt_msg_direction direction,
bool do_dup);
static const struct xt_exchanger_vtable exchanger_simple_base_vtable = {
......@@ -235,7 +237,8 @@ static void xt_exchanger_simple_base_a_exchange(Xt_exchanger exchanger,
static MPI_Datatype
xt_exchanger_simple_base_get_MPI_Datatype(Xt_exchanger exchanger,
int rank,
enum xt_msg_direction direction)
enum xt_msg_direction direction,
bool do_dup)
{
Xt_exchanger_simple_base exchanger_sb =
(Xt_exchanger_simple_base)exchanger;
......@@ -246,8 +249,11 @@ xt_exchanger_simple_base_get_MPI_Datatype(Xt_exchanger exchanger,
MPI_Datatype datatype_copy = MPI_DATATYPE_NULL;
for (size_t i = 0; i < nmsg; ++i)
if (msgs[i].rank == rank) {
xt_mpi_call(MPI_Type_dup(msgs[i].datatype, &datatype_copy),
exchanger_sb->comm);
if (do_dup)
xt_mpi_call(MPI_Type_dup(msgs[i].datatype, &datatype_copy),
exchanger_sb->comm);
else
datatype_copy = msgs[i].datatype;
break;
}
return datatype_copy;
......
......@@ -106,18 +106,19 @@ int xt_redist_get_num_recv_msg(Xt_redist redist) {
MPI_Datatype xt_redist_get_send_MPI_Datatype(Xt_redist redist, int rank) {
return redist->vtable->get_msg_MPI_Datatype(redist, rank, SEND);
return redist->vtable->get_msg_MPI_Datatype(redist, rank, SEND, true);
}
MPI_Datatype xt_redist_get_recv_MPI_Datatype(Xt_redist redist, int rank) {
return redist->vtable->get_msg_MPI_Datatype(redist, rank, RECV);
return redist->vtable->get_msg_MPI_Datatype(redist, rank, RECV, true);
}
MPI_Datatype xt_redist_get_MPI_Datatype(Xt_redist redist, int rank,
enum xt_msg_direction direction)
enum xt_msg_direction direction,
bool do_dup)
{
return redist->vtable->get_msg_MPI_Datatype(redist, rank, direction);
return redist->vtable->get_msg_MPI_Datatype(redist, rank, direction, do_dup);
}
MPI_Comm xt_redist_get_MPI_Comm(Xt_redist redist) {
......
......@@ -98,7 +98,7 @@ static int redist_collection_get_num_msg(Xt_redist redist,
static MPI_Datatype
redist_collection_get_MPI_Datatype(Xt_redist redist, int rank,
enum xt_msg_direction direction);
enum xt_msg_direction direction, bool do_dup);
static int
redist_collection_get_msg_ranks(Xt_redist redist,
......@@ -173,7 +173,7 @@ static void align_component_dt(unsigned num_redists, unsigned nmsgs,
for (size_t j = 0; j < num_redists; ++j)
component_dt[i * num_redists + j] =
(rank_pos[j] < num_ranks[j] && in_ranks[j][rank_pos[j]] == min_rank)
? xt_redist_get_MPI_Datatype(redists[j], min_rank, direction)
? xt_redist_get_MPI_Datatype(redists[j], min_rank, direction, true)
: MPI_DATATYPE_NULL;
out_ranks[i] = min_rank;
......@@ -550,7 +550,8 @@ static int redist_collection_get_num_msg(Xt_redist redist,
static MPI_Datatype
redist_collection_get_MPI_Datatype(Xt_redist redist, int XT_UNUSED(rank),
enum xt_msg_direction XT_UNUSED(direction))
enum xt_msg_direction XT_UNUSED(direction),
bool XT_UNUSED(do_dup))
{
Xt_redist_collection redist_coll = xrc(redist);
......
......@@ -98,7 +98,7 @@ generate_msg_infos(struct Xt_redist_msg ** msgs,
for (size_t j = 0; j < num_redists; ++j)
datatypes[j] =
(rank_pos[j] < num_ranks[j] && ranks[j][rank_pos[j]] == min_rank)
? xt_redist_get_MPI_Datatype(redists[j], min_rank, direction)
? xt_redist_get_MPI_Datatype(redists[j], min_rank, direction, false)
: MPI_DATATYPE_NULL;
p[i].rank = min_rank;
......@@ -106,8 +106,6 @@ generate_msg_infos(struct Xt_redist_msg ** msgs,
= xt_create_compound_datatype(num_redists, displacements, datatypes,
block_lengths, comm);
for (size_t j = 0; j < num_redists; ++j) {
if (datatypes[j] != MPI_DATATYPE_NULL)
xt_mpi_call(MPI_Type_free(datatypes+j), comm);
rank_pos[j]
+= (rank_pos[j] < num_ranks[j] && ranks[j][rank_pos[j]] == min_rank);
}
......
......@@ -54,6 +54,7 @@
#include <config.h>
#endif
#include <stdbool.h>
#include <stdlib.h>
#include <mpi.h>
......@@ -71,7 +72,8 @@ struct xt_redist_vtable {
void (*a_exchange)(Xt_redist, int, const void **, void **, Xt_request *);
void (*s_exchange1)(Xt_redist, const void *, void *);
void (*a_exchange1)(Xt_redist, const void *, void *, Xt_request *);
MPI_Datatype (*get_msg_MPI_Datatype)(Xt_redist, int, enum xt_msg_direction);
MPI_Datatype (*get_msg_MPI_Datatype)(Xt_redist, int, enum xt_msg_direction,
bool need_dup);
int (*get_num_msg)(Xt_redist, enum xt_msg_direction);
int (*get_msg_ranks)(Xt_redist, enum xt_msg_direction, int *restrict *);
MPI_Comm (*get_MPI_Comm)(Xt_redist);
......@@ -135,12 +137,14 @@ xt_redist_get_msg_ranks(Xt_redist redist, enum xt_msg_direction direction,
* @param[in] rank MPI rank of the communicator partner
* @param[in] direction specifices whether the datatype for an outgoing or
* incoming message is requested
* @param[in] do_dup if true only return MPI_Datatype_dup of
* internally stored datatype
* @return Datatype for the specified message. The return value is
* MPI_DATATYPE_NULL, if no data for the specified message.
*/
PPM_DSO_INTERNAL MPI_Datatype
xt_redist_get_MPI_Datatype(Xt_redist redist, int rank,
enum xt_msg_direction direction);
enum xt_msg_direction direction, bool do_dup);
/**
* Generates a new MPI derived datatype from a number of MPI derived datatypes.
......
......@@ -75,7 +75,7 @@ generate_msg_infos(struct Xt_redist_msg *restrict msgs,
= (size_t)xt_redist_get_msg_ranks(redist, direction, &ranks);
for (size_t i = 0; i < num_ranks; ++i) {
MPI_Datatype datatype
= xt_redist_get_MPI_Datatype(redist, ranks[i], direction);
= xt_redist_get_MPI_Datatype(redist, ranks[i], direction, false);
MPI_Aint curr_lb, curr_extent;
MPI_Datatype datatype_with_extent;
......@@ -89,7 +89,6 @@ generate_msg_infos(struct Xt_redist_msg *restrict msgs,
= xt_mpi_generate_datatype(displacements, num_repetitions,
datatype_with_extent, comm);
MPI_Type_free(&datatype_with_extent);
MPI_Type_free(&datatype);
}
free(ranks);
}
......
......@@ -91,7 +91,8 @@ static int redist_sab_get_num_msg(Xt_redist redist,
static MPI_Datatype
redist_sab_get_MPI_Datatype(Xt_redist redist, int rank,
enum xt_msg_direction direction);
enum xt_msg_direction direction,
bool do_dup);
static int
redist_sab_get_msg_ranks(Xt_redist redist,
......@@ -246,10 +247,11 @@ static int redist_sab_get_num_msg(Xt_redist redist,
static MPI_Datatype
redist_sab_get_MPI_Datatype(Xt_redist redist, int rank,
enum xt_msg_direction direction)
enum xt_msg_direction direction,
bool do_dup)
{
return xt_exchanger_get_MPI_Datatype(xrsab(redist)->exchanger, rank,
direction);
direction, do_dup);
}
static int
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment