Commit f31a8392 authored by Moritz Hanke's avatar Moritz Hanke
Browse files

adds asynchronous exchange interface to xt_redist

parent 20c33a2d
......@@ -58,6 +58,7 @@
#include <mpi.h>
#include "xt/xt_core.h"
#include "xt/xt_request.h"
struct Xt_redist_msg {
......@@ -99,6 +100,24 @@ void xt_redist_delete(Xt_redist redist);
void xt_redist_s_exchange(Xt_redist redist, int num_arrays,
const void **src_data, void **dst_data);
/**
* asynchronous redistribution of data
*
* @param[in] redist redistribution structure
* @param[in] num_arrays number of base addresses in src_data and dst_data
* @param[in] src_data array containing the addresses of the first
* elements of the input data
* @param[in,out] dst_data array containing the addresses of the first
* elements of the output data
* @returns a request object that can be used to complete an asynchronous
* exchange
*
* @remark The above implies that NULL or any other invalid pointer
* must not be used in either @a src_data or @a dst_data.
*/
Xt_request xt_redist_a_exchange(Xt_redist redist, int num_arrays,
const void **src_data, void **dst_data);
/**
* synchronous redistribution of data - single array case
*
......@@ -111,6 +130,21 @@ void xt_redist_s_exchange(Xt_redist redist, int num_arrays,
*/
void xt_redist_s_exchange1(Xt_redist redist, const void *src_data, void *dst_data);
/**
* asynchronous redistribution of data - single array case
*
* @param[in] redist redistribution structure
* @param[in] src_data address of the first element of the input data
* @param[in,out] dst_data address of the first element of the output data
* @returns a request object that can be used to complete an asynchronous
* exchange
*
* @remark The above implies that NULL or any other invalid pointer
* must not be used in either @a src_data or @a dst_data.
*/
Xt_request xt_redist_a_exchange1(Xt_redist redist, const void *src_data,
void *dst_data);
/**
* gets a copy of the MPI_Datatype used for the data of the send operation with
* the given rank
......
......@@ -55,6 +55,7 @@
#include "xt/xt_core.h"
#include "xt/xt_redist.h"
#include "xt/xt_mpi.h"
#include "xt/xt_request.h"
#include "core/ppm_xfuncs.h"
#include "xt_redist_internal.h"
......@@ -74,11 +75,23 @@ void xt_redist_s_exchange(Xt_redist redist, int num_arrays,
redist->vtable->s_exchange(redist, num_arrays, src_data, dst_data);
}
Xt_request xt_redist_a_exchange(Xt_redist redist, int num_arrays,
const void **src_data, void **dst_data) {
return redist->vtable->a_exchange(redist, num_arrays, src_data, dst_data);
}
void xt_redist_s_exchange1(Xt_redist redist, const void *src_data, void *dst_data) {
redist->vtable->s_exchange1(redist, src_data, dst_data);
}
Xt_request xt_redist_a_exchange1(Xt_redist redist, const void *src_data,
void *dst_data) {
return redist->vtable->a_exchange1(redist, src_data, dst_data);
}
MPI_Datatype xt_redist_get_send_MPI_Datatype(Xt_redist redist, int rank) {
return redist->vtable->get_send_MPI_Datatype(redist, rank);
......
......@@ -63,6 +63,7 @@
#include "xt/xt_redist_collection.h"
#include "ensure_array_size.h"
#include "xt/xt_redist.h"
#include "xt/xt_request.h"
#include "xt_redist_internal.h"
#include "xt_exchanger.h"
......@@ -78,10 +79,18 @@ static void
redist_collection_s_exchange(Xt_redist redist, int num_src_arrays,
const void **src_data, void **dst_data);
static Xt_request
redist_collection_a_exchange(Xt_redist redist, int num_src_arrays,
const void **src_data, void **dst_data);
static void
redist_collection_s_exchange1(Xt_redist redist,
const void *src_data, void *dst_data);
static Xt_request
redist_collection_a_exchange1(Xt_redist redist,
const void *src_data, void *dst_data);
static MPI_Datatype
redist_collection_get_send_MPI_Datatype(Xt_redist redist, int rank);
......@@ -100,7 +109,9 @@ static const struct xt_redist_vtable redist_collection_vtable = {
.copy = redist_collection_copy,
.delete = redist_collection_delete,
.s_exchange = redist_collection_s_exchange,
.a_exchange = redist_collection_a_exchange,
.s_exchange1 = redist_collection_s_exchange1,
.a_exchange1 = redist_collection_a_exchange1,
.get_send_MPI_Datatype = redist_collection_get_send_MPI_Datatype,
.get_recv_MPI_Datatype = redist_collection_get_recv_MPI_Datatype,
.get_msg_ranks = redist_collection_get_msg_ranks,
......@@ -426,6 +437,37 @@ redist_collection_s_exchange(Xt_redist redist, int num_arrays,
xt_exchanger_delete(exchanger);
}
static Xt_request
redist_collection_a_exchange(Xt_redist redist, int num_arrays,
const void **src_data, void **dst_data) {
Xt_redist_collection redist_coll = xrc(redist);
if (num_arrays != (int)redist_coll->num_redists)
Xt_abort(redist_coll->comm, "ERROR: wrong number of arrays in "
"redist_collection_a_exchange", __FILE__, __LINE__);
Xt_exchanger exchanger = get_exchanger(src_data, dst_data,
redist_coll->send_msgs,
redist_coll->nsrc,
redist_coll->recv_msgs,
redist_coll->ndst,
redist_coll->num_redists,
&(redist_coll->cache),
redist_coll->cache_size,
redist_coll->comm,
redist_coll->tag_offset);
Xt_request request =
xt_exchanger_a_exchange(exchanger, src_data[0], dst_data[0]);
if (redist_coll->cache_size == 0)
xt_exchanger_delete(exchanger);
return request;
}
static void
copy_msgs(size_t num_redists, unsigned nmsgs,
const struct redist_collection_msg *restrict msgs_orig,
......@@ -551,6 +593,19 @@ redist_collection_s_exchange1(Xt_redist redist,
" this xt_redist type (Xt_redist_collection)", __FILE__, __LINE__);
}
static Xt_request
redist_collection_a_exchange1(Xt_redist redist,
const void *src_data, void *dst_data)
{
Xt_redist_collection redist_coll = xrc(redist);
if (redist_coll->num_redists == 1)
return redist_collection_a_exchange(redist, 1, &src_data, &dst_data);
else
Xt_abort(redist_coll->comm, "ERROR: a_exchange1 is not implemented for"
" this xt_redist type (Xt_redist_collection)", __FILE__, __LINE__);
}
static int
redist_collection_get_msg_ranks(Xt_redist redist,
enum xt_msg_direction direction,
......
......@@ -57,6 +57,7 @@
#include <mpi.h>
#include "xt/xt_redist.h"
#include "xt/xt_request.h"
enum xt_msg_direction {SEND, RECV};
......@@ -65,7 +66,9 @@ struct xt_redist_vtable {
Xt_redist (*copy)(Xt_redist);
void (*delete)(Xt_redist);
void (*s_exchange)(Xt_redist, int, const void **, void **);
Xt_request (*a_exchange)(Xt_redist, int, const void **, void **);
void (*s_exchange1)(Xt_redist, const void *, void *);
Xt_request (*a_exchange1)(Xt_redist, const void *, void *);
MPI_Datatype (*get_send_MPI_Datatype)(Xt_redist, int);
MPI_Datatype (*get_recv_MPI_Datatype)(Xt_redist, int);
int (*get_msg_ranks)(Xt_redist, enum xt_msg_direction, int *restrict *);
......
......@@ -58,6 +58,7 @@
#include "xt_redist_internal.h"
#include "xt/xt_xmap.h"
#include "xt/xt_idxlist.h"
#include "xt/xt_request.h"
#include "core/ppm_xfuncs.h"
#include "core/core.h"
#include "xt_exchanger.h"
......@@ -72,9 +73,16 @@ static void
redist_sab_s_exchange(Xt_redist redist, int num_arrays,
const void **src_data, void **dst_data);
static Xt_request
redist_sab_a_exchange(Xt_redist redist, int num_arrays,
const void **src_data, void **dst_data);
static void
redist_sab_s_exchange1(Xt_redist redist, const void *src_data, void *dst_data);
static Xt_request
redist_sab_a_exchange1(Xt_redist redist, const void *src_data, void *dst_data);
static MPI_Datatype
redist_sab_get_send_MPI_Datatype(Xt_redist redist, int rank);
......@@ -93,7 +101,9 @@ static const struct xt_redist_vtable redist_sab_vtable = {
.copy = redist_sab_copy,
.delete = redist_sab_delete,
.s_exchange = redist_sab_s_exchange,
.a_exchange = redist_sab_a_exchange,
.s_exchange1 = redist_sab_s_exchange1,
.a_exchange1 = redist_sab_a_exchange1,
.get_send_MPI_Datatype = redist_sab_get_send_MPI_Datatype,
.get_recv_MPI_Datatype = redist_sab_get_recv_MPI_Datatype,
.get_msg_ranks = redist_collection_get_msg_ranks,
......@@ -183,6 +193,21 @@ redist_sab_s_exchange(Xt_redist redist, int num_arrays,
"(Xt_redist_single_array_base)", __FILE__, __LINE__);
}
static Xt_request
redist_sab_a_exchange(Xt_redist redist, int num_arrays,
const void **src_data, void **dst_data)
{
Xt_redist_sab redist_rep = xrsab(redist);
if (num_arrays == 1)
return redist_sab_a_exchange1(redist, src_data[0], dst_data[0]);
else
Xt_abort(redist_rep->comm, "ERROR: multi-array a_exchange is not"
" implemented for this xt_redist type "
"(Xt_redist_single_array_base)", __FILE__, __LINE__);
return XT_REQUEST_NULL;
}
static void
redist_sab_s_exchange1(Xt_redist redist, const void *src_data, void *dst_data) {
......@@ -191,6 +216,14 @@ redist_sab_s_exchange1(Xt_redist redist, const void *src_data, void *dst_data) {
xt_exchanger_s_exchange(redist_sab->exchanger, src_data, dst_data);
}
static Xt_request
redist_sab_a_exchange1(Xt_redist redist, const void *src_data, void *dst_data) {
Xt_redist_sab redist_sab = xrsab(redist);
return xt_exchanger_a_exchange(redist_sab->exchanger, src_data, dst_data);
}
static MPI_Datatype
redist_sab_get_send_MPI_Datatype(Xt_redist redist, int rank) {
......
......@@ -96,25 +96,45 @@ int main(void) {
// test exchange
static const double src_data[nvalues] = {1,2,3,4,5};
double dst_data[nselect] = {-1,-1,-1};
double dst_data[nselect];
static const double ref_dst_data[nselect] = {1,3,5};
xt_redist_s_exchange1(redist_coll, src_data, dst_data);
for (int j = 0; j < 2; ++j) {
static const double ref_dst_data[nselect] = {1,3,5};
for (size_t i = 0; i < nselect; ++i) dst_data[i] = -1;
for (size_t i = 0; i < nselect; ++i)
if (ref_dst_data[i] != dst_data[i])
PUT_ERR("error in xt_redist_s_exchange\n");
if (j == 0) {
xt_redist_s_exchange1(redist_coll, src_data, dst_data);
} else {
Xt_request request =
xt_redist_a_exchange1(redist_coll, src_data, dst_data);
xt_request_wait(&request);
}
for (size_t i = 0; i < nselect; ++i)
if (ref_dst_data[i] != dst_data[i])
PUT_ERR("error in xt_redist_s_exchange\n");
}
Xt_redist redist_coll_copy = xt_redist_copy(redist_coll);
for (size_t i = 0; i < nselect; ++i)
dst_data[i] = -1;
xt_redist_s_exchange1(redist_coll_copy, src_data, dst_data);
for (int j = 0; j < 2; ++j) {
for (size_t i = 0; i < nselect; ++i) dst_data[i] = -1;
if (j == 0) {
xt_redist_s_exchange1(redist_coll_copy, src_data, dst_data);
} else {
Xt_request request =
xt_redist_a_exchange1(redist_coll_copy, src_data, dst_data);
xt_request_wait(&request);
}
for (size_t i = 0; i < nselect; ++i)
if (ref_dst_data[i] != dst_data[i])
PUT_ERR("error in xt_redist_s_exchange\n");
}
for (size_t i = 0; i < nselect; ++i)
if (ref_dst_data[i] != dst_data[i])
PUT_ERR("error in xt_redist_s_exchange\n");
// clean up
xt_redist_delete(redist_coll_copy);
......@@ -212,6 +232,7 @@ test_repeated_redist(int cache_size)
xt_redist_delete(redist);
// test exchange
for (int j = 0; j < 2; ++j)
{
static const double src_data[3][5]
= {{1,2,3,4,5},{6,7,8,9,10},{11,12,13,14,15}};
......@@ -220,7 +241,13 @@ test_repeated_redist(int cache_size)
const void *src_data_p[3] = {src_data[0],src_data[1],src_data[2]};
void *dst_data_p[3] = {dst_data[0],dst_data[1],dst_data[2]};
xt_redist_s_exchange(redist_coll, 3, src_data_p, dst_data_p);
if (j == 0) {
xt_redist_s_exchange(redist_coll, 3, src_data_p, dst_data_p);
} else {
Xt_request request =
xt_redist_a_exchange(redist_coll, 3, src_data_p, dst_data_p);
xt_request_wait(&request);
}
static const double ref_dst_data[3][3] = {{1,3,5},{6,8,10},{11,13,15}};
......@@ -231,6 +258,7 @@ test_repeated_redist(int cache_size)
}
// test exchange with changed displacements
for (int j = 0; j < 2; ++j)
{
static const double src_data[3][5]
= {{1,2,3,4,5},{6,7,8,9,10},{11,12,13,14,15}};
......@@ -239,7 +267,13 @@ test_repeated_redist(int cache_size)
const void *src_data_p[3] = {src_data[1],src_data[0],src_data[2]};
void *dst_data_p[3] = {dst_data[1],dst_data[0],dst_data[2]};
xt_redist_s_exchange(redist_coll, 3, src_data_p, dst_data_p);
if (j == 0) {
xt_redist_s_exchange(redist_coll, 3, src_data_p, dst_data_p);
} else {
Xt_request request =
xt_redist_a_exchange(redist_coll, 3, src_data_p, dst_data_p);
xt_request_wait(&request);
}
static const double ref_dst_data[3][3] = {{1,3,5},{6,8,10},{11,13,15}};
......@@ -250,6 +284,7 @@ test_repeated_redist(int cache_size)
}
// test exchange with original displacements
for (int j = 0; j < 2; ++j)
{
static const double src_data[3][5]
= {{1,2,3,4,5},{6,7,8,9,10},{11,12,13,14,15}};
......@@ -258,7 +293,13 @@ test_repeated_redist(int cache_size)
const void *src_data_p[3] = {src_data[0],src_data[1],src_data[2]};
void *dst_data_p[3] = {dst_data[0],dst_data[1],dst_data[2]};
xt_redist_s_exchange(redist_coll, 3, src_data_p, dst_data_p);
if (j == 0) {
xt_redist_s_exchange(redist_coll, 3, src_data_p, dst_data_p);
} else {
Xt_request request =
xt_redist_a_exchange(redist_coll, 3, src_data_p, dst_data_p);
xt_request_wait(&request);
}
static const double ref_dst_data[3][3] = {{1,3,5},{6,8,10},{11,13,15}};
......@@ -277,7 +318,7 @@ enum { num_redists = 3 };
enum { nvalues = 5, nselect = nvalues/2+(nvalues&1) };
static void
run_displacement_check(Xt_redist redist_coll)
run_displacement_check(Xt_redist redist_coll, int sync)
{
static const double src_data[num_redists][nvalues]
= {{1,2,3,4,5},{6,7,8,9,10},{11,12,13,14,15}};
......@@ -302,7 +343,13 @@ run_displacement_check(Xt_redist redist_coll)
src_data_p[2] = src_data_+k;
dst_data_p[2] = dst_data_+k;
xt_redist_s_exchange(redist_coll, num_redists, src_data_p, dst_data_p);
if (sync) {
xt_redist_s_exchange(redist_coll, num_redists, src_data_p, dst_data_p);
} else {
Xt_request request =
xt_redist_a_exchange(redist_coll, num_redists, src_data_p, dst_data_p);
xt_request_wait(&request);
}
static const double ref_dst_data[num_redists][nselect]
= {{1,3,5},{6,8,10},{11,13,15}};
......@@ -344,9 +391,11 @@ test_displacement_variations(void)
// test exchange
run_displacement_check(redist_coll);
run_displacement_check(redist_coll, 0);
run_displacement_check(redist_coll, 1);
Xt_redist redist_coll_copy = xt_redist_copy(redist_coll);
run_displacement_check(redist_coll_copy);
run_displacement_check(redist_coll_copy, 0);
run_displacement_check(redist_coll_copy, 1);
// clean up
xt_redist_delete(redist_coll_copy);
......
......@@ -61,7 +61,8 @@
static void
exchange_4redist(Xt_redist redist, MPI_Comm comm,
const Xt_int *index_vector_a, const Xt_int *index_vector_b);
const Xt_int *index_vector_a, const Xt_int *index_vector_b,
int sync);
enum {
list_a = 0,
list_b = 1,
......@@ -71,7 +72,8 @@ enum { num_redists = 4 };
static void
rr_exchange(Xt_redist redist,
const Xt_int src_indices_[5], const Xt_int dst_indices_[2][5]);
const Xt_int src_indices_[5], const Xt_int dst_indices_[2][5],
int sync);
int main(void) {
......@@ -178,11 +180,15 @@ int main(void) {
xt_redist_delete(redists[i]);
exchange_4redist(redist, MPI_COMM_WORLD,
index_vector[list_a], index_vector[list_b]);
index_vector[list_a], index_vector[list_b], 0);
exchange_4redist(redist, MPI_COMM_WORLD,
index_vector[list_a], index_vector[list_b], 1);
Xt_redist redist_copy = xt_redist_copy(redist);
xt_redist_delete(redist);
exchange_4redist(redist_copy, MPI_COMM_WORLD,
index_vector[list_a], index_vector[list_b]);
index_vector[list_a], index_vector[list_b], 0);
exchange_4redist(redist_copy, MPI_COMM_WORLD,
index_vector[list_a], index_vector[list_b], 1);
// clean up
for (size_t i = 0; i < 2; ++i)
......@@ -235,11 +241,13 @@ int main(void) {
xt_redist_delete(redists[0]);
xt_redist_delete(redists[1]);
rr_exchange(redist, src_indices_, (const Xt_int (*)[5])dst_indices_);
rr_exchange(redist, src_indices_, (const Xt_int (*)[5])dst_indices_, 0);
rr_exchange(redist, src_indices_, (const Xt_int (*)[5])dst_indices_, 1);
Xt_redist redist_copy = xt_redist_copy(redist);
xt_redist_delete(redist);
rr_exchange(redist_copy, src_indices_, (const Xt_int (*)[5])dst_indices_);
rr_exchange(redist_copy, src_indices_, (const Xt_int (*)[5])dst_indices_, 0);
rr_exchange(redist_copy, src_indices_, (const Xt_int (*)[5])dst_indices_, 1);
// clean up
xt_redist_delete(redist_copy);
......@@ -259,7 +267,8 @@ check_4redist_result(int size, void *results[4],
static void
exchange_4redist(Xt_redist redist, MPI_Comm comm,
const Xt_int *index_vector_a, const Xt_int *index_vector_b)
const Xt_int *index_vector_a, const Xt_int *index_vector_b,
int sync)
{
int rank, size;
xt_mpi_call(MPI_Comm_rank(comm, &rank), comm);
......@@ -285,7 +294,13 @@ exchange_4redist(Xt_redist redist, MPI_Comm comm,
const void *input[num_redists]
= { index_vector_a, index_vector_b, index_vector_a, index_vector_b };
xt_redist_s_exchange(redist, num_redists, input, results);
if (sync) {
xt_redist_s_exchange(redist, num_redists, input, results);
} else {
Xt_request request =
xt_redist_a_exchange(redist, num_redists, input, results);
xt_request_wait(&request);
}
check_4redist_result(size, results, index_vector_a,
index_vector_b);
/*
......@@ -296,7 +311,13 @@ exchange_4redist(Xt_redist redist, MPI_Comm comm,
results[0] = buf;
/* ...and repeat exchange */
xt_redist_s_exchange(redist, num_redists, input, results);
if (sync) {
xt_redist_s_exchange(redist, num_redists, input, results);
} else {
Xt_request request =
xt_redist_a_exchange(redist, num_redists, input, results);
xt_request_wait(&request);
}
check_4redist_result(size, results, index_vector_a, index_vector_b);
free(buf);
......@@ -327,13 +348,19 @@ check_4redist_result(int size, void *results[4],
static void
rr_exchange(Xt_redist redist,
const Xt_int src_indices_[5], const Xt_int dst_indices_[2][5])
const Xt_int src_indices_[5], const Xt_int dst_indices_[2][5],
int sync)
{
Xt_int results_[2][5] = { {-1,-1,-1,-1,-1}, {-1,-1,-1,-1,-1} };
void *results[2] = {results_[0], results_[1]};
const void *input[2] = {src_indices_, src_indices_};
xt_redist_s_exchange(redist, 2, input, results);
if (sync) {
xt_redist_s_exchange(redist, 2, input, results);
} else {
Xt_request request = xt_redist_a_exchange(redist, 2, input, results);
xt_request_wait(&request);
}
// check results
for (int i = 0; i < 5; ++i) {
......
......@@ -136,13 +136,12 @@ int main(void) {
xt_redist_delete(redist);
// test exchange
{
for (size_t j = 0; j < 3; ++j)
for (size_t i = 0; i < 3; ++i)
dst_data[j][i] = -1.0;
xt_redist_s_exchange1(redist_coll, src_data, dst_data);
static const double ref_dst_data[3][3] = {{1,3,5},{6,8,10},{11,13,15}};
check_redist(redist_coll, src_data,
sizeof (dst_data), dst_data, ref_dst_data);
......
......@@ -233,7 +233,7 @@ int main(void) {
xt_xmap_delete(xmaps[0]);
xt_xmap_delete(xmaps[1]);
Xt_int results_1[5] = {-1,-1,-1,-1,-1}, results_2[5] = {-1,-1,-1,-1,-1};
Xt_int results_1[5], results_2[5];
MPI_Aint src_displacements[2] = {0, 0};
MPI_Aint dst_displacements[2]
......@@ -252,15 +252,29 @@ int main(void) {
xt_redist_delete(redists[0]);
xt_redist_delete(redists[1]);
xt_redist_s_exchange1(redist, (void*)src_indices_, (void*)results_1);
// check results
for (int i = 0; i < 5; ++i) {
if (results_1[i] != dst_indices_[0][i])
PUT_ERR("error on xt_redist_s_exchange\n");
if (results_2[i] != dst_indices_[1][i])
PUT_ERR("error on xt_redist_s_exchange\n");
for (int j = 0; j < 2; ++j) {
for (int i = 0; i < 5; ++i) {
results_1[i] = -1;
results_2[i] = -1;
}
if (j == 0) {
xt_redist_s_exchange1(redist, (void*)src_indices_, (void*)results_1);
} else {
Xt_request request =
xt_redist_a_exchange1(redist, (void*)src_indices_, (void*)results_1);
xt_request_wait(&request);
}
// check results
for (int i = 0; i < 5; ++i) {
if (results_1[i] != dst_indices_[0][i])
PUT_ERR("error on xt_redist_s_exchange\n");
if (results_2[i] != dst_indices_[1][i])
PUT_ERR("error on xt_redist_s_exchange\n");
}