Skip to content
Snippets Groups Projects
Commit 3cf7c7e1 authored by Moritz Hanke's avatar Moritz Hanke
Browse files

removes intersection from xmap data structure

parent 16027912
No related branches found
No related tags found
No related merge requests found
......@@ -57,8 +57,6 @@ static const struct Xt_xmap_vtable xmap_all2all_vtable = {
.get_max_dst_pos = xmap_all2all_get_max_dst_pos};
struct exchange_data {
// list of global indices to send or receive from partner rank
Xt_idxlist intersect;
// list of relative positions in memory to send or receive
Xt_idxlist xlist;
int rank;
......@@ -130,9 +128,15 @@ static void xmap_all2all_get_source_ranks(Xt_xmap xmap, int * ranks) {
ranks[i] = xmap_all2all->src_msg[i].rank;
}
static void exchange_idxlists(struct Xt_xmap_all2all * xmap,
static void exchange_idxlists(Xt_idxlist ** src_intersections,
int ** src_ranks,
Xt_count * num_src_intersections,
Xt_idxlist ** dst_intersections,
int ** dst_ranks,
Xt_count * num_dst_intersections,
Xt_idxlist src_idxlist_local,
Xt_idxlist dst_idxlist_local)
Xt_idxlist dst_idxlist_local,
MPI_Comm comm)
{
/*
......@@ -149,12 +153,12 @@ static void exchange_idxlists(struct Xt_xmap_all2all * xmap,
unsigned i;
int comm_size;
mpi_err_handler(MPI_Comm_size(xmap->comm, &comm_size), xmap->comm);
mpi_err_handler(MPI_Comm_size(comm, &comm_size), comm);
// compute size of local index lists
long long src_pack_size = xt_idxlist_get_pack_size(src_idxlist_local, xmap->comm);
long long src_pack_size = xt_idxlist_get_pack_size(src_idxlist_local, comm);
if (src_pack_size<0) die("src index list is too large");
long long dst_pack_size = xt_idxlist_get_pack_size(dst_idxlist_local, xmap->comm);
long long dst_pack_size = xt_idxlist_get_pack_size(dst_idxlist_local, comm);
if (dst_pack_size<0) die("dst index list is too large");
if (src_pack_size + dst_pack_size >= INT_MAX) die("local src+dst index lists are too large");
......@@ -167,10 +171,8 @@ static void exchange_idxlists(struct Xt_xmap_all2all * xmap,
// pack local index lists
int position = 0;
xt_idxlist_pack(src_idxlist_local, send_buffer, send_buffer_size, &position,
xmap->comm);
xt_idxlist_pack(dst_idxlist_local, send_buffer, send_buffer_size, &position,
xmap->comm);
xt_idxlist_pack(src_idxlist_local, send_buffer, send_buffer_size, &position, comm);
xt_idxlist_pack(dst_idxlist_local, send_buffer, send_buffer_size, &position, comm);
// exchange buffer sizes
int send_buffer_header = send_buffer_size;
......@@ -179,8 +181,7 @@ static void exchange_idxlists(struct Xt_xmap_all2all * xmap,
recv_buffer_header = xmalloc(comm_size * sizeof(*recv_buffer_header));
mpi_err_handler(MPI_Allgather(&send_buffer_header, 1, MPI_INT,
recv_buffer_header,
1, MPI_INT, xmap->comm), xmap->comm);
recv_buffer_header, 1, MPI_INT, comm), comm);
void * recv_buffer;
int * displ;
......@@ -201,8 +202,7 @@ static void exchange_idxlists(struct Xt_xmap_all2all * xmap,
// exchange buffers
mpi_err_handler(MPI_Allgatherv(send_buffer, send_buffer_size, MPI_PACKED,
recv_buffer, recv_buffer_header, displ,
MPI_PACKED,
xmap->comm), xmap->comm);
MPI_PACKED, comm), comm);
src_idxlists_all = xmalloc(comm_size * sizeof(*src_idxlists_all));
......@@ -211,39 +211,42 @@ static void exchange_idxlists(struct Xt_xmap_all2all * xmap,
// unpack buffers
for (i = 0, position = 0; i < comm_size; ++i) {
src_idxlists_all[i] = xt_idxlist_unpack(recv_buffer, recv_buffer_size,
&position, xmap->comm);
&position, comm);
dst_idxlists_all[i] = xt_idxlist_unpack(recv_buffer, recv_buffer_size,
&position, xmap->comm);
&position, comm);
}
// allocate memory for intersections
xmap->dst_msg = xmalloc(comm_size * sizeof (*xmap->dst_msg));
xmap->ndst = 0;
xmap->src_msg = xmalloc(comm_size * sizeof (*xmap->src_msg));
xmap->nsrc = 0;
if (!xmap->dst_msg || !xmap->src_msg)
abort();
(*dst_intersections) = xmalloc(comm_size * sizeof (**dst_intersections));
(*dst_ranks) = xmalloc(comm_size * sizeof(**dst_ranks));
(*num_dst_intersections) = 0;
(*src_intersections) = xmalloc(comm_size * sizeof (**src_intersections));
(*src_ranks) = xmalloc(comm_size * sizeof(**src_ranks));
(*num_src_intersections) = 0;
// compute intersections
for (i = 0; i < comm_size; ++i)
{
Xt_idxlist intersect
= xt_idxlist_get_intersection(src_idxlists_all[i], dst_idxlist_local);
if (xt_idxlist_get_num_indices(intersect) > 0)
{
xmap->dst_msg[xmap->ndst].intersect = intersect;
xmap->dst_msg[xmap->ndst].rank = i;
xmap->ndst++;
for (i = 0; i < comm_size; ++i) {
Xt_idxlist intersect;
intersect = xt_idxlist_get_intersection(src_idxlists_all[i],
dst_idxlist_local);
if (xt_idxlist_get_num_indices(intersect) > 0) {
(*dst_intersections)[(*num_dst_intersections)] = intersect;
(*dst_ranks)[(*num_dst_intersections)] = i;
(*num_dst_intersections)++;
}
else
xt_idxlist_delete(intersect);
intersect = xt_idxlist_get_intersection(src_idxlist_local,
dst_idxlists_all[i]);
if (xt_idxlist_get_num_indices(intersect) > 0)
{
xmap->src_msg[xmap->nsrc].rank = i;
xmap->src_msg[xmap->nsrc].intersect = intersect;
xmap->nsrc++;
if (xt_idxlist_get_num_indices(intersect) > 0) {
(*src_intersections)[(*num_src_intersections)] = intersect;
(*src_ranks)[(*num_src_intersections)] = i;
(*num_src_intersections)++;
}
else
xt_idxlist_delete(intersect);
......@@ -266,16 +269,23 @@ static void exchange_idxlists(struct Xt_xmap_all2all * xmap,
}
static void generate_src_xlist(struct Xt_xmap_all2all * xmap,
Xt_idxlist * intersections,
int * ranks,
Xt_count num_intersections,
Xt_idxlist mypart_idxlist) {
Xt_count intersection_size[xmap->nsrc];
Xt_count intersection_size[num_intersections];
int max_intersection_size = 0;
const int single_match_only = 0;
xmap->nsrc = num_intersections;
xmap->src_msg = xmalloc(num_intersections * sizeof(*(xmap->src_msg)));
// find max size of intersections:
for (int i = 0; i < xmap->nsrc; ++i) {
intersection_size[i] = xt_idxlist_get_num_indices(xmap->src_msg[i].intersect);
if (intersection_size[i] > max_intersection_size) max_intersection_size = intersection_size[i];
intersection_size[i] = xt_idxlist_get_num_indices(intersections[i]);
if (intersection_size[i] > max_intersection_size)
max_intersection_size = intersection_size[i];
}
Xt_idx *intersection_idxvec = xmalloc(max_intersection_size * sizeof(Xt_idx));
......@@ -283,32 +293,38 @@ static void generate_src_xlist(struct Xt_xmap_all2all * xmap,
int retval;
for (int i = 0; i < xmap->nsrc; ++i) {
xt_idxlist_get_indices(xmap->src_msg[i].intersect, intersection_idxvec);
xt_idxlist_get_indices(intersections[i], intersection_idxvec);
retval = xt_idxlist_get_position_of_indices(mypart_idxlist, intersection_idxvec,
intersection_size[i], intersection_pos, single_match_only);
intersection_size[i], intersection_pos,
single_match_only);
assert(retval != 1);
xmap->src_msg[i].xlist = xt_idxvec_new(intersection_pos, intersection_size[i]);
xmap->src_msg[i].rank = ranks[i];
}
free(intersection_pos);
free(intersection_idxvec);
}
static void generate_dst_xlist(struct Xt_xmap_all2all * xmap,
Xt_idxlist * intersections,
int * ranks,
Xt_count num_intersections,
Xt_idxlist mypart_idxlist) {
Xt_count intersection_size[xmap->ndst];
Xt_count intersection_size[num_intersections];
int max_intersection_size = 0;
const int single_match_only = 1;
xmap->ndst = num_intersections;
xmap->dst_msg = xmalloc(num_intersections * sizeof(*(xmap->dst_msg)));
// find max size of intersections:
for (int i = 0; i < xmap->ndst; ++i) {
intersection_size[i] = xt_idxlist_get_num_indices(xmap->dst_msg[i].intersect);
if (intersection_size[i] > max_intersection_size) max_intersection_size = intersection_size[i];
intersection_size[i] = xt_idxlist_get_num_indices(intersections[i]);
if (intersection_size[i] > max_intersection_size)
max_intersection_size = intersection_size[i];
}
Xt_idx *intersection_idxvec = xmalloc(max_intersection_size * sizeof(Xt_idx));
......@@ -316,26 +332,34 @@ static void generate_dst_xlist(struct Xt_xmap_all2all * xmap,
int retval;
for (int i = 0; i < xmap->ndst; ++i) {
xt_idxlist_get_indices(xmap->dst_msg[i].intersect, intersection_idxvec);
xt_idxlist_get_indices(intersections[i], intersection_idxvec);
retval = xt_idxlist_get_position_of_indices(mypart_idxlist, intersection_idxvec,
intersection_size[i], intersection_pos, single_match_only);
intersection_size[i], intersection_pos,
single_match_only);
assert(retval != 1);
xmap->dst_msg[i].xlist = xt_idxvec_new(intersection_pos, intersection_size[i]);
xmap->dst_msg[i].rank = ranks[i];
}
free(intersection_pos);
free(intersection_idxvec);
}
static void generate_xlist(struct Xt_xmap_all2all * xmap,
Xt_idxlist * src_intersections,
int * src_ranks,
Xt_count num_src_intersections,
Xt_idxlist * dst_intersections,
int * dst_ranks,
Xt_count num_dst_intersections,
Xt_idxlist src_idxlist_local,
Xt_idxlist dst_idxlist_local) {
generate_src_xlist(xmap, src_idxlist_local);
generate_dst_xlist(xmap, dst_idxlist_local);
generate_src_xlist(xmap, src_intersections, src_ranks,
num_src_intersections, src_idxlist_local);
generate_dst_xlist(xmap, dst_intersections, dst_ranks,
num_dst_intersections, dst_idxlist_local);
}
Xt_xmap xt_xmap_all2all_new(Xt_idxlist src_idxlist, Xt_idxlist dst_idxlist, MPI_Comm comm) {
......@@ -348,11 +372,26 @@ Xt_xmap xt_xmap_all2all_new(Xt_idxlist src_idxlist, Xt_idxlist dst_idxlist, MPI_
mpi_err_handler(MPI_Comm_dup(comm, &(xmap->comm)), comm);
Xt_idxlist * src_intersections = NULL, * dst_intersections = NULL;
int * src_ranks = NULL, * dst_ranks = NULL;
Xt_count num_src_intersections = 0, num_dst_intersections = 0;
// exchange index lists between all processes in comm
exchange_idxlists(xmap, src_idxlist, dst_idxlist);
exchange_idxlists(&src_intersections, &src_ranks, &num_src_intersections,
&dst_intersections, &dst_ranks, &num_dst_intersections,
src_idxlist, dst_idxlist, xmap->comm);
// generate exchange lists
generate_xlist(xmap, src_idxlist, dst_idxlist);
generate_xlist(xmap, src_intersections, src_ranks, num_src_intersections,
dst_intersections, dst_ranks, num_dst_intersections, src_idxlist,
dst_idxlist);
for (Xt_count i = 0; i < num_src_intersections; ++i)
xt_idxlist_delete(src_intersections[i]);
for (Xt_count i = 0; i < num_dst_intersections; ++i)
xt_idxlist_delete(dst_intersections[i]);
free(src_intersections), free(dst_intersections);
free(src_ranks), free(dst_ranks);
// we could also calculate the (more precise) max pos using only xmap data
// but using this simple estimate we are still okay for usage checks
......@@ -381,15 +420,11 @@ static void xmap_all2all_delete(Xt_xmap xmap) {
unsigned i;
for (i = 0; i < xmap_all2all->ndst; ++i) {
xt_idxlist_delete(xmap_all2all->dst_msg[i].intersect);
for (i = 0; i < xmap_all2all->ndst; ++i)
xt_idxlist_delete(xmap_all2all->dst_msg[i].xlist);
}
for (i = 0; i < xmap_all2all->nsrc; ++i) {
xt_idxlist_delete(xmap_all2all->src_msg[i].intersect);
for (i = 0; i < xmap_all2all->nsrc; ++i)
xt_idxlist_delete(xmap_all2all->src_msg[i].xlist);
}
free(xmap_all2all->dst_msg);
free(xmap_all2all->src_msg);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment