Commit fa63bdf7 authored by Moritz Hanke's avatar Moritz Hanke

adds new xt_xmap constructor xt_xmap_intersection_pos_new

parent eb3ff7c1
......@@ -61,6 +61,13 @@ struct Xt_com_list {
int rank;
};
struct Xt_com_pos {
// list of relative positions in memory to send or receive
int * transfer_pos;
int num_transfer_pos;
int rank;
};
/**
* constructor for an exchange map \n
* this operation is collective over all processes in comm \n
......@@ -119,6 +126,26 @@ xt_xmap_intersection_ext_new(
Xt_idxlist src_idxlist, Xt_idxlist dst_idxlist,
MPI_Comm comm);
/**
* constructor for an exchange map \n
* this operation is collective over all processes in comm \n
* it uses the provided intersection information to generate the exchange map
*
* @param[in] num_src_msg number of source messages
* @param[in] src_com array containing relative positions for all
* source messages and the destination rank
* @param[in] num_dst_msg number of destination messages
* @param[in] dst_com array containing relative positions for all
* destination messages and the source rank
* @param[in] comm MPI communicator that contains all processes
* that part in the exchange
*/
Xt_xmap
xt_xmap_intersection_pos_new(
int num_src_msg, const struct Xt_com_pos src_com[num_src_msg],
int num_dst_msg, const struct Xt_com_pos dst_com[num_dst_msg],
MPI_Comm comm);
#endif // XT_XMAP_INTERSECTION_H
/*
......
......@@ -956,6 +956,66 @@ xmap_intersection_spread(Xt_xmap xmap, int num_repetitions,
.displacements = dst_displacements });
}
static int get_max_pos_from_com_pos(
int num_msg, const struct Xt_com_pos com[num_msg]) {
int max_pos = 0;
for (int i = 0; i < num_msg; ++i) {
int curr_num_transfer_pos = com[i].num_transfer_pos;
int * curr_transfer_pos = com[i].transfer_pos;
for (int j = 0; j < curr_num_transfer_pos; ++j)
if (curr_transfer_pos[j] > max_pos) max_pos = curr_transfer_pos[j];
}
return max_pos;
}
static void init_exchange_data_from_com_pos(
int count, struct exchange_data *restrict msgs,
const struct Xt_com_pos *restrict com) {
for (int i = 0; i < count; ++i) {
int num_transfer_pos = com[i].num_transfer_pos;
int * transfer_pos =
xmalloc((size_t)num_transfer_pos * sizeof(*transfer_pos));
int rank = com[i].rank;
memcpy(transfer_pos, com[i].transfer_pos,
(size_t)num_transfer_pos * sizeof(*transfer_pos));
msgs[i].transfer_pos = transfer_pos;
msgs[i].transfer_pos_ext_cache = NULL;
msgs[i].num_transfer_pos = num_transfer_pos;
msgs[i].num_transfer_pos_ext =
(int)(count_pos_ext((size_t)num_transfer_pos, transfer_pos));
msgs[i].rank = rank;
}
}
Xt_xmap
xt_xmap_intersection_pos_new(
int num_src_msg, const struct Xt_com_pos src_com[num_src_msg],
int num_dst_msg, const struct Xt_com_pos dst_com[num_dst_msg],
MPI_Comm comm) {
// ensure that yaxt is initialized
assert(xt_initialized());
size_t num_msg = (size_t)num_src_msg + (size_t)num_dst_msg;
Xt_xmap_intersection xmap
= xmalloc(sizeof (*xmap) + num_msg * sizeof (struct exchange_data));
xmap->vtable = &xmap_intersection_vtable;
xmap->n_in = num_dst_msg;
xmap->n_out = num_src_msg;
xmap->max_dst_pos = get_max_pos_from_com_pos(num_dst_msg, dst_com);
xmap->max_src_pos = get_max_pos_from_com_pos(num_src_msg, src_com);
xmap->comm = comm = xt_mpi_comm_smart_dup(comm, &xmap->tag_offset);
init_exchange_data_from_com_pos(
num_dst_msg, xmap->msg, dst_com);
init_exchange_data_from_com_pos(
num_src_msg, xmap->msg + num_dst_msg, src_com);
return (Xt_xmap)xmap;
}
static int xmap_intersection_iterator_next(Xt_xmap_iter iter) {
Xt_xmap_iter_intersection iter_intersection = xmii(iter);
......
......@@ -50,7 +50,7 @@
!! @example test_xmap_intersection_parallel_f.f90
MODULE xt_xmap_intersection
USE iso_c_binding, ONLY: c_int, c_loc, c_null_ptr, c_ptr
USE iso_c_binding, ONLY: c_int, c_size_t, c_loc, c_null_ptr, c_ptr
USE xt_core, ONLY: xt_abort, xt_mpi_fint_kind
USE xt_idxlist_abstract, ONLY: xt_idxlist, xt_idxlist_f2c
USE xt_xmap_abstract, ONLY: xt_xmap, xt_xmap_c2f
......@@ -63,7 +63,19 @@ MODULE xt_xmap_intersection
INTEGER(c_int) :: rank
END TYPE xt_com_list
PUBLIC :: xt_xmap_intersection_new, xt_xmap_intersection_ext_new
TYPE, PUBLIC :: xt_com_pos
INTEGER, POINTER :: transfer_pos(:)
INTEGER :: rank
END TYPE xt_com_pos
TYPE, BIND(c) :: xt_com_pos_c
TYPE(c_ptr) :: transfer_pos
INTEGER(c_int) :: num_transfer_pos
INTEGER(c_int) :: rank
END TYPE xt_com_pos_c
PUBLIC :: xt_xmap_intersection_new, xt_xmap_intersection_ext_new, &
xt_xmap_intersection_pos_new
INTERFACE
FUNCTION xmi_new_f2c(num_src_intersections, src_com, &
......@@ -91,6 +103,16 @@ MODULE xt_xmap_intersection
INTEGER(xt_mpi_fint_kind), VALUE, INTENT(in) :: comm
TYPE(c_ptr) :: xmap
END FUNCTION xmi_ext_new_f2c
FUNCTION xmi_pos_new_f2c(num_src_msg, src_com, num_dst_msg, dst_com, comm) &
RESULT(xmap) &
BIND(c, name='xt_xmap_intersection_pos_new_f2c')
IMPORT :: c_int, c_ptr, xt_mpi_fint_kind
INTEGER(c_int), VALUE, INTENT(in) :: num_src_msg, num_dst_msg
TYPE(c_ptr), VALUE, INTENT(in) :: src_com, dst_com
INTEGER(xt_mpi_fint_kind), VALUE, INTENT(in) :: comm
TYPE(c_ptr) :: xmap
END FUNCTION xmi_pos_new_f2c
END INTERFACE
INTERFACE xt_xmap_intersection_new
......@@ -102,6 +124,12 @@ MODULE xt_xmap_intersection
MODULE PROCEDURE xmi_ext_new_i_a_i_a
MODULE PROCEDURE xmi_ext_new_a_a
END INTERFACE xt_xmap_intersection_ext_new
INTERFACE xt_xmap_intersection_pos_new
MODULE PROCEDURE xmi_pos_new_a_a
MODULE PROCEDURE xmi_pos_new_i_a_i_a
END INTERFACE xt_xmap_intersection_pos_new
CHARACTER(len=*), PARAMETER :: filename = 'xt_xmap_intersection_f.f90'
CONTAINS
FUNCTION xmi_new_i_a_i_a(num_src_intersections, src_com, &
......@@ -180,11 +208,92 @@ CONTAINS
CALL com_p_arg(src_com, src_com_a, src_com_p)
CALL com_p_arg(dst_com, dst_com_a, dst_com_p)
xmap = xt_xmap_c2f(xmi_ext_new_f2c(num_src_intersections_c, src_com_p, &
xmap = xt_xmap_c2f(xmi_ext_new_f2c(&
num_src_intersections_c, src_com_p, &
num_dst_intersections_c, dst_com_p, &
xt_idxlist_f2c(src_idxlist), xt_idxlist_f2c(dst_idxlist), comm))
END FUNCTION xmi_ext_new_a_a
FUNCTION xmi_pos_new_i_a_i_a( &
num_src_msg, src_com, num_dst_msg, dst_com, comm) RESULT(xmap)
TYPE(xt_com_pos), TARGET, INTENT(in) :: src_com(:), dst_com(:)
INTEGER, INTENT(in) :: num_src_msg, num_dst_msg
INTEGER, INTENT(in) :: comm
TYPE(xt_xmap) :: xmap
INTEGER(c_int) :: num_src_msg_c, num_dst_msg_c
TYPE(xt_com_pos_c), ALLOCATABLE :: src_com_c(:), dst_com_c(:)
INTEGER(c_int), ALLOCATABLE :: pos_buffer(:)
INTEGER(c_size_t) :: pos_buffer_offset
ALLOCATE(pos_buffer(get_total_num_transfer_pos(num_src_msg, src_com) + &
get_total_num_transfer_pos(num_dst_msg, dst_com)))
num_src_msg_c = INT(num_src_msg, c_int)
num_dst_msg_c = INT(num_dst_msg, c_int)
pos_buffer_offset = 1
CALL generate_xt_com_pos_c( &
num_src_msg, src_com, src_com_c, pos_buffer, pos_buffer_offset)
CALL generate_xt_com_pos_c( &
num_dst_msg, dst_com, dst_com_c, pos_buffer, pos_buffer_offset)
xmap = &
xt_xmap_c2f( &
xmi_pos_new_f2c(&
num_src_msg_c, C_LOC(src_com_c), num_dst_msg_c, C_LOC(dst_com_c), comm))
CONTAINS
FUNCTION get_total_num_transfer_pos(num_msg, com_pos) &
RESULT(total_num_transfer_pos)
INTEGER, INTENT(in) :: num_msg
TYPE(xt_com_pos), TARGET, INTENT(in) :: com_pos(:)
INTEGER :: i
INTEGER(c_size_t) :: total_num_transfer_pos
total_num_transfer_pos = 0
DO i = 1, num_msg
total_num_transfer_pos = &
total_num_transfer_pos + SIZE(com_pos(i)%transfer_pos, 1, c_size_t)
END DO
END FUNCTION get_total_num_transfer_pos
SUBROUTINE generate_xt_com_pos_c( &
num_msg, com_pos, com_pos_c, pos_buffer, pos_buffer_offset)
INTEGER, INTENT(in) :: num_msg
TYPE(xt_com_pos), TARGET, INTENT(in) :: com_pos(:)
TYPE(xt_com_pos_c), ALLOCATABLE, INTENT(inout) :: com_pos_c(:)
INTEGER(c_int), TARGET, INTENT(inout) :: pos_buffer(:)
INTEGER(c_size_t), INTENT(inout) :: pos_buffer_offset
INTEGER :: i
INTEGER(c_size_t) :: j, curr_num_transfer_pos
ALLOCATE(com_pos_c(num_msg))
DO i = 1, num_msg
curr_num_transfer_pos = SIZE(com_pos(i)%transfer_pos, 1, c_size_t)
DO j = 0, curr_num_transfer_pos - 1
pos_buffer(pos_buffer_offset + j) = &
INT(com_pos(i)%transfer_pos(j+1), c_int)
END DO
com_pos_c(i)%transfer_pos = C_LOC(pos_buffer(pos_buffer_offset))
com_pos_c(i)%num_transfer_pos = INT(curr_num_transfer_pos, c_int)
com_pos_c(i)%rank = INT(com_pos(i)%rank, c_int)
pos_buffer_offset = pos_buffer_offset + curr_num_transfer_pos
END DO
END SUBROUTINE generate_xt_com_pos_c
END FUNCTION xmi_pos_new_i_a_i_a
FUNCTION xmi_pos_new_a_a(src_com, dst_com, comm) RESULT(xmap)
TYPE(xt_com_pos), TARGET, INTENT(in) :: src_com(:), dst_com(:)
INTEGER, INTENT(in) :: comm
TYPE(xt_xmap) :: xmap
xmap = &
xmi_pos_new_i_a_i_a(SIZE(src_com), src_com, SIZE(dst_com), dst_com, comm)
END FUNCTION xmi_pos_new_a_a
SUBROUTINE com_p_arg(com, com_a, com_p)
TYPE(xt_com_list), TARGET, INTENT(in) :: com(:)
TYPE(xt_com_list), TARGET, ALLOCATABLE, INTENT(inout) :: com_a(:)
......
......@@ -87,7 +87,8 @@ MODULE yaxt
xt_reorder_type_kind, XT_REORDER_NONE, XT_REORDER_SEND_UP, &
XT_REORDER_RECV_UP, xt_xmap_update_positions, xt_xmap_spread
USE xt_xmap_intersection, ONLY: xt_xmap_intersection_new, &
xt_xmap_intersection_ext_new, xt_com_list
xt_xmap_intersection_ext_new, xt_com_list, &
xt_xmap_intersection_pos_new, xt_com_pos
USE xt_redist_base, ONLY: xt_redist, xt_redist_c2f, xt_redist_f2c, &
xt_redist_copy, &
xt_redist_delete, xt_redist_s_exchange1, xt_redist_s_exchange, &
......@@ -139,10 +140,11 @@ MODULE yaxt
xt_xmap_get_max_src_pos, xt_xmap_get_max_dst_pos, &
xt_xmap_dist_dir_intercomm_new, &
xt_xmap_intersection_new, xt_xmap_intersection_ext_new, &
xt_xmap_intersection_pos_new, &
xt_xmap_reorder, xt_reorder_type_kind, &
XT_REORDER_NONE, XT_REORDER_SEND_UP, XT_REORDER_RECV_UP, &
xt_xmap_update_positions, xt_xmap_spread, &
xt_com_list, &
xt_com_list, xt_com_pos, &
xt_xmap_iter, xt_xmap_get_in_iterator, xt_xmap_get_out_iterator, &
xt_xmap_iterator_next, xt_xmap_iterator_get_rank, &
xt_xmap_iterator_get_transfer_pos, &
......
......@@ -401,6 +401,15 @@ xt_xmap_intersection_ext_new_f2c(
src_idxlist, dst_idxlist, MPI_Comm_f2c(comm));
}
void *
xt_xmap_intersection_pos_new_f2c(
int num_src_msg, const void *src_com, int num_dst_msg, const void *dst_com,
MPI_Fint comm)
{
return xt_xmap_intersection_pos_new(
num_src_msg, src_com, num_dst_msg, dst_com, MPI_Comm_f2c(comm));
}
#ifndef HAVE_FC_IS_CONTIGUOUS
int
xt_com_list_contiguous(const struct Xt_com_list *p_com_a,
......
......@@ -57,6 +57,7 @@
#include "tests.h"
#include "xt/xt_xmap_intersection.h"
#include "xt/xt_xmap.h"
#include "core/ppm_xfuncs.h"
struct test_message {
......@@ -524,6 +525,47 @@ int main(int argc, char **argv)
xt_idxlist_delete(src_com.list);
}
{ // alltoall using xt_xmap_intersection_pos_new
struct Xt_com_pos * src_com =
xmalloc(2 * (size_t)comm_size * sizeof(*src_com));
struct Xt_com_pos * dst_com = src_com + comm_size;
int * transfer_pos = xmalloc((size_t)comm_size * sizeof(*transfer_pos));
struct test_message * ref_src_msg =
xmalloc(2 * (size_t)comm_size * sizeof(*ref_src_msg));
struct test_message * ref_dst_msg = ref_src_msg + comm_size;
for (int i = 0, offset = my_rank * comm_size; i < comm_size; ++i) {
transfer_pos[i] = i;
src_com[i].transfer_pos = transfer_pos + i;
src_com[i].num_transfer_pos = 1;
src_com[i].rank = i;
dst_com[i].transfer_pos = transfer_pos + i;
dst_com[i].num_transfer_pos = 1;
dst_com[i].rank = i;
ref_src_msg[i].pos = transfer_pos + i;
ref_src_msg[i].num_pos = 1;
ref_src_msg[i].rank = i;
ref_dst_msg[i].pos = transfer_pos + i;
ref_dst_msg[i].num_pos = 1;
ref_dst_msg[i].rank = i;
}
Xt_xmap xmap =
xt_xmap_intersection_pos_new(
comm_size, src_com, comm_size, dst_com, MPI_COMM_WORLD);
test_xmap(xmap, comm_size, ref_src_msg, comm_size, ref_dst_msg);
xt_xmap_delete(xmap);
free(ref_src_msg);
free(transfer_pos);
free(src_com);
}
xt_finalize();
MPI_Finalize();
......
......@@ -50,6 +50,7 @@ PROGRAM test_xmap_intersection_parallel
xt_idxlist, xt_idxvec_new, xt_idxlist_delete, xt_xmap, &
xt_idxempty_new, xi => xt_int_kind, &
xt_xmap_intersection_new, xt_xmap_intersection_ext_new, xt_com_list, &
xt_xmap_intersection_pos_new, xt_com_pos, &
xt_xmap_copy, xt_xmap_delete, xt_xmap_iter, &
xt_xmap_get_in_iterator, xt_xmap_get_out_iterator, &
xt_xmap_get_num_destinations, xt_xmap_get_num_sources, &
......@@ -113,6 +114,7 @@ PROGRAM test_xmap_intersection_parallel
CALL dedup_test
CALL reorder_test
CALL update_positions_and_spread_test
CALL alltoall_pos_test
IF (test_err_count() /= 0) &
CALL test_abort("non-zero error count!", filename, __LINE__)
......@@ -508,6 +510,36 @@ CONTAINS
CALL xt_idxlist_delete(src_com(1)%list)
END SUBROUTINE update_positions_and_spread_test
! checks xt_xmap_intersection_pos_new constructor
SUBROUTINE alltoall_pos_test
TYPE(xt_xmap) :: xmap
TYPE(xt_com_pos), TARGET :: src_com(comm_size), dst_com(comm_size)
INTEGER, TARGET :: transfer_pos(comm_size)
TYPE(test_message) :: recv_messages(comm_size), send_messages(comm_size)
INTEGER :: i
DO i = 1, comm_size
transfer_pos(i) = i
src_com(i)%transfer_pos => transfer_pos(i:i)
src_com(i)%rank = i - 1
dst_com(i)%transfer_pos => transfer_pos(i:i)
dst_com(i)%rank = i - 1
send_messages(i)%rank = i - 1
send_messages(i)%pos => transfer_pos(i:i)
recv_messages(i)%rank = i - 1
recv_messages(i)%pos => transfer_pos(i:i)
END DO
xmap = xt_xmap_intersection_pos_new(src_com, dst_com, mpi_comm_world)
! test
CALL test_xmap(xmap, send_messages, recv_messages)
! cleanup
CALL xt_xmap_delete(xmap)
END SUBROUTINE alltoall_pos_test
SUBROUTINE test_xmap_iter(iter, msgs)
TYPE(xt_xmap_iter), INTENT(inout) :: iter
TYPE(test_message), INTENT(in) :: msgs(:)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment