Commit 56ba5be7 authored by Thomas Jahns 🤸
Browse files

Fix incorrect handling of position extents.

* This bug was introduced in 5862f0a3.

* Extend intersection xmap test, so similar bugs cannot creep in as
  easily again.
parent 0f474404
...@@ -285,11 +285,14 @@ generate_ext_msg_infos(int num_msgs, Xt_xmap_iter iter, ...@@ -285,11 +285,14 @@ generate_ext_msg_infos(int num_msgs, Xt_xmap_iter iter,
{ {
if (num_msgs > 0) { if (num_msgs > 0) {
/* partial sums of ext sizes */ /* partial sums of ext sizes */
int *psum_ext_size int *restrict psum_ext_size
= xmalloc(((size_t)num_ext + 1) * sizeof (psum_ext_size[0])); = xmalloc(((size_t)num_ext + 1) * sizeof (psum_ext_size[0]));
psum_ext_size[0] = 0; int accum = 0;
for (size_t i = 0; i < (size_t)num_ext; ++i) for (size_t i = 0; i < (size_t)num_ext; ++i) {
psum_ext_size[i + 1] = psum_ext_size[i] + extents[i].size; psum_ext_size[i] = accum;
accum += extents[i].size;
}
psum_ext_size[num_ext] = accum;
struct Xt_redist_msg *curr_msg = msgs; struct Xt_redist_msg *curr_msg = msgs;
do { do {
......
...@@ -338,13 +338,13 @@ generate_dir_transfer_pos_dst( ...@@ -338,13 +338,13 @@ generate_dir_transfer_pos_dst(
.all_dst_covered = all_bits_set == ~0UL }; .all_dst_covered = all_bits_set == ~0UL };
} }
struct tps_result {
int resCount; struct pos_count_max {
int max_pos; int count, max_pos;
}; };
/* compute list positions for send direction */ /* compute list positions for send direction */
static struct tps_result static struct pos_count_max
generate_dir_transfer_pos_src(int num_intersections, generate_dir_transfer_pos_src(int num_intersections,
const struct Xt_com_list const struct Xt_com_list
intersections[num_intersections], intersections[num_intersections],
...@@ -423,8 +423,8 @@ generate_dir_transfer_pos_src(int num_intersections, ...@@ -423,8 +423,8 @@ generate_dir_transfer_pos_src(int num_intersections,
free(new_intersection_idxvec); free(new_intersection_idxvec);
free(intersection_pos); free(intersection_pos);
return (struct tps_result){ .max_pos = max_pos_, return (struct pos_count_max){ .max_pos = max_pos_,
.resCount = new_num_intersections }; .count = new_num_intersections };
} }
static Xt_int * static Xt_int *
...@@ -565,12 +565,12 @@ generate_transfer_pos(struct Xt_xmap_intersection_ *xmap, ...@@ -565,12 +565,12 @@ generate_transfer_pos(struct Xt_xmap_intersection_ *xmap,
= xrealloc(num_src_indices_to_remove_per_intersection, = xrealloc(num_src_indices_to_remove_per_intersection,
(size_t)num_src_intersections * sizeof(int)); (size_t)num_src_intersections * sizeof(int));
struct tps_result tpsr struct pos_count_max tpsr
= generate_dir_transfer_pos_src( = generate_dir_transfer_pos_src(
num_src_intersections, src_com, src_idxlist_local, xmap->msg + xmap->n_in, num_src_intersections, src_com, src_idxlist_local, xmap->msg + xmap->n_in,
src_indices_to_remove, num_src_indices_to_remove_per_intersection); src_indices_to_remove, num_src_indices_to_remove_per_intersection);
xmap->max_src_pos = tpsr.max_pos; xmap->max_src_pos = tpsr.max_pos;
xmap->n_out = tpsr.resCount; xmap->n_out = tpsr.count;
free(src_indices_to_remove); free(src_indices_to_remove);
free(num_src_indices_to_remove_per_intersection); free(num_src_indices_to_remove_per_intersection);
...@@ -975,6 +975,45 @@ xmap_intersection_spread(Xt_xmap xmap, int num_repetitions, ...@@ -975,6 +975,45 @@ xmap_intersection_spread(Xt_xmap xmap, int num_repetitions,
.displacements = dst_displacements }); .displacements = dst_displacements });
} }
/* Copy the leading run of pos to pos_copy and describe that run.
 *
 * A run is the longest prefix of pos in which every element differs
 * from pos[0] by exactly its offset, either counting up or counting
 * down.  The direction is chosen from the sign mask of pos[1] - pos[0].
 * NOTE(review): isign_mask is defined elsewhere; the arithmetic below
 * presumably relies on it yielding either 0 or all bits set -- confirm.
 *
 * num_pos   number of readable elements in pos, must be at least 1
 *           because pos[0] is read unconditionally
 * pos       positions to inspect
 * pos_copy  receives the elements of the run and, if the run ends
 *           before num_pos, also the first element that broke it
 *
 * Returns the first value of the run, its length and its direction
 * mask; the direction is 0 for a run consisting of a single element.
 */
static inline struct pos_run copy_get_pos_run_len(
  size_t num_pos, const int *restrict pos,
  int *restrict pos_copy)
{
  const int first = pos[0];
  pos_copy[0] = first;
  size_t run_len = 1;
  int dir_mask = 0;
  if (num_pos > 1) {
    dir_mask = isign_mask(pos[1] - pos[0]);
    for (; run_len < num_pos; ++run_len) {
      /* the element is stored before it is tested, so the value that
       * terminates the run is already present in pos_copy */
      const int value = pos_copy[run_len] = pos[run_len];
      const int ofs = (int)run_len;
      /* first + ofs when the mask is clear, first - ofs when it is set */
      const int expected = first + (~dir_mask & ofs) + (dir_mask & -ofs);
      if (value != expected)
        break;
    }
    /* a run of length one has no direction */
    if (run_len == 1)
      dir_mask = 0;
  }
  return (struct pos_run){ .start = first, .len = run_len,
                           .direction = dir_mask };
}
/* Count the runs (position extents) needed to represent pos, copy pos
 * to pos_copy on the way and track the largest position encountered.
 *
 * max_pos   running maximum to start from
 * num_pos   number of elements in pos and pos_copy
 * pos       positions to analyse
 * pos_copy  destination of the copy, filled by copy_get_pos_run_len
 *
 * Returns the number of runs found in .count and, in .max_pos, the
 * larger of the incoming max_pos and the per-run maxima computed below.
 */
static struct pos_count_max
max_count_pos_ext_and_copy(int max_pos, size_t num_pos, const int *restrict pos,
                           int *restrict pos_copy)
{
  size_t num_runs = 0;
  for (size_t done = 0; done < num_pos; ++num_runs) {
    const struct pos_run run
      = copy_get_pos_run_len(num_pos - done, pos + done, pos_copy + done);
    done += run.len;
    /* select run.start when the direction mask is set and
     * run.start + len - 1 when it is clear */
    const int ascending_end = run.start + (int)run.len - 1;
    const int run_max
      = (run.start & run.direction) | (ascending_end & ~run.direction);
    if (max_pos < run_max)
      max_pos = run_max;
  }
  return (struct pos_count_max){ .count = (int)num_runs, .max_pos = max_pos };
}
static void init_exchange_data_from_com_pos( static void init_exchange_data_from_com_pos(
int count, struct exchange_data *restrict msgs, int count, struct exchange_data *restrict msgs,
const struct Xt_com_pos *restrict com, int *max_pos) { const struct Xt_com_pos *restrict com, int *max_pos) {
...@@ -984,18 +1023,15 @@ static void init_exchange_data_from_com_pos( ...@@ -984,18 +1023,15 @@ static void init_exchange_data_from_com_pos(
int num_transfer_pos = com[i].num_transfer_pos; int num_transfer_pos = com[i].num_transfer_pos;
int *restrict transfer_pos = int *restrict transfer_pos =
xmalloc((size_t)num_transfer_pos * sizeof(*transfer_pos)); xmalloc((size_t)num_transfer_pos * sizeof(*transfer_pos));
int rank = com[i].rank;
const int *restrict com_transfer_pos = com[i].transfer_pos;
for (int j = 0; j < num_transfer_pos; ++j)
if (com_transfer_pos[j] > max_pos_) max_pos_ = com_transfer_pos[j];
msgs[i].transfer_pos = transfer_pos; msgs[i].transfer_pos = transfer_pos;
msgs[i].transfer_pos_ext_cache = NULL; msgs[i].transfer_pos_ext_cache = NULL;
msgs[i].num_transfer_pos = num_transfer_pos; msgs[i].num_transfer_pos = num_transfer_pos;
msgs[i].num_transfer_pos_ext = msgs[i].rank = com[i].rank;
(int)(count_pos_ext((size_t)num_transfer_pos, transfer_pos)); struct pos_count_max max_count
msgs[i].rank = rank; = max_count_pos_ext_and_copy(max_pos_, (size_t)num_transfer_pos,
memcpy(transfer_pos, com_transfer_pos, com[i].transfer_pos, transfer_pos);
(size_t)num_transfer_pos * sizeof(*transfer_pos)); msgs[i].num_transfer_pos_ext = max_count.count;
if (max_count.max_pos > max_pos_) max_pos_ = max_count.max_pos;
} }
*max_pos = max_pos_; *max_pos = max_pos_;
} }
......
...@@ -76,6 +76,9 @@ static Xt_xmap (*xmi_new)( ...@@ -76,6 +76,9 @@ static Xt_xmap (*xmi_new)(
int ndst_com, const struct Xt_com_list dst_com[ndst_com], int ndst_com, const struct Xt_com_list dst_com[ndst_com],
Xt_idxlist src_idxlist, Xt_idxlist dst_idxlist, MPI_Comm comm); Xt_idxlist src_idxlist, Xt_idxlist dst_idxlist, MPI_Comm comm);
static void
test_strided_block_pos_alltoall(MPI_Comm comm, int nblk, int blksz);
static void static void
parse_options(int *argc, char ***argv); parse_options(int *argc, char ***argv);
...@@ -525,51 +528,58 @@ int main(int argc, char **argv) ...@@ -525,51 +528,58 @@ int main(int argc, char **argv)
xt_idxlist_delete(src_com.list); xt_idxlist_delete(src_com.list);
} }
{ // alltoall using xt_xmap_intersection_pos_new test_strided_block_pos_alltoall(MPI_COMM_WORLD, 1, 1);
test_strided_block_pos_alltoall(MPI_COMM_WORLD, 2, 1);
struct Xt_com_pos * src_com = test_strided_block_pos_alltoall(MPI_COMM_WORLD, 5, 2000);
xmalloc(2 * (size_t)comm_size * sizeof(*src_com));
struct Xt_com_pos * dst_com = src_com + comm_size;
int * transfer_pos = xmalloc((size_t)comm_size * sizeof(*transfer_pos));
struct test_message * ref_src_msg =
xmalloc(2 * (size_t)comm_size * sizeof(*ref_src_msg));
struct test_message * ref_dst_msg = ref_src_msg + comm_size;
for (int i = 0; i < comm_size; ++i) {
transfer_pos[i] = i;
src_com[i].transfer_pos = transfer_pos + i;
src_com[i].num_transfer_pos = 1;
src_com[i].rank = i;
dst_com[i].transfer_pos = transfer_pos + i;
dst_com[i].num_transfer_pos = 1;
dst_com[i].rank = i;
ref_src_msg[i].pos = transfer_pos + i;
ref_src_msg[i].num_pos = 1;
ref_src_msg[i].rank = i;
ref_dst_msg[i].pos = transfer_pos + i;
ref_dst_msg[i].num_pos = 1;
ref_dst_msg[i].rank = i;
}
Xt_xmap xmap = xt_finalize();
xt_xmap_intersection_pos_new( MPI_Finalize();
comm_size, src_com, comm_size, dst_com, MPI_COMM_WORLD);
test_xmap(xmap, comm_size, ref_src_msg, comm_size, ref_dst_msg);
xt_xmap_delete(xmap); return TEST_EXIT_CODE;
}
free(ref_src_msg); static void
free(transfer_pos); test_strided_block_pos_alltoall(MPI_Comm comm, int nblk, int blksz)
free(src_com); {
int comm_rank, comm_size;
xt_mpi_call(MPI_Comm_rank(comm, &comm_rank), comm);
xt_mpi_call(MPI_Comm_size(comm, &comm_size), comm);
struct Xt_com_pos *src_com =
xmalloc(2 * (size_t)comm_size * sizeof(*src_com)),
*dst_com = src_com + comm_size;
int (*transfer_pos)[nblk][blksz]
= xmalloc((size_t)nblk * (size_t)blksz
* (size_t)comm_size * sizeof(*transfer_pos));
struct test_message *ref_src_msg =
xmalloc(2 * (size_t)comm_size * sizeof(*ref_src_msg)),
*ref_dst_msg = ref_src_msg + comm_size;
for (int rank = 0; rank < comm_size; ++rank) {
for (int j = 0; j < nblk; ++j)
for (int i = 0; i < blksz; ++i)
transfer_pos[rank][j][i] = i + j*blksz*comm_size;
ref_src_msg[rank].pos = src_com[rank].transfer_pos
= (int *)(transfer_pos + rank);
ref_src_msg[rank].num_pos = src_com[rank].num_transfer_pos = nblk*blksz;
ref_src_msg[rank].rank = src_com[rank].rank = rank;
ref_dst_msg[rank].pos = dst_com[rank].transfer_pos
= (int *)(transfer_pos + rank);
ref_dst_msg[rank].num_pos = dst_com[rank].num_transfer_pos = nblk*blksz;
ref_dst_msg[rank].rank = dst_com[rank].rank = rank;
} }
xt_finalize(); Xt_xmap xmap =
MPI_Finalize(); xt_xmap_intersection_pos_new(comm_size, src_com, comm_size, dst_com, comm);
return TEST_EXIT_CODE; test_xmap(xmap, comm_size, ref_src_msg, comm_size, ref_dst_msg);
xt_xmap_delete(xmap);
free(ref_src_msg);
free(transfer_pos);
free(src_com);
} }
static void test_xmap_iter(Xt_xmap_iter iter, int num_msgs, static void test_xmap_iter(Xt_xmap_iter iter, int num_msgs,
...@@ -603,6 +613,27 @@ static void test_xmap_iter(Xt_xmap_iter iter, int num_msgs, ...@@ -603,6 +613,27 @@ static void test_xmap_iter(Xt_xmap_iter iter, int num_msgs,
mismatch |= (pos[j] != msgs[i].pos[j]); mismatch |= (pos[j] != msgs[i].pos[j]);
if (mismatch) if (mismatch)
PUT_ERR("ERROR: xt_xmap_iterator_get_transfer_pos\n"); PUT_ERR("ERROR: xt_xmap_iterator_get_transfer_pos\n");
int num_transfer_pos_ext
= xt_xmap_iterator_get_num_transfer_pos_ext(iter);
const struct Xt_pos_ext *restrict pos_ext
= xt_xmap_iterator_get_transfer_pos_ext(iter);
mismatch = false;
size_t ofs = 0;
for (size_t pe = 0; pe < (size_t)num_transfer_pos_ext; ++pe) {
int pos_ext_size = abs(pos_ext[pe].size);
if (pos_ext[pe].size > 0)
for (int j = 0; j < pos_ext_size; ++j)
mismatch |= (pos[ofs+(size_t)j] != pos_ext[pe].start + j);
else
for (int j = 0; j < pos_ext_size; ++j)
mismatch |= (pos[ofs+(size_t)j] != pos_ext[pe].start - j);
ofs += (size_t)pos_ext_size;
}
if (mismatch || (int)ofs != msgs[i].num_pos)
PUT_ERR("ERROR: xt_xmap_iterator_get_transfer_pos_ext\n");
++i; ++i;
} while (xt_xmap_iterator_next(iter)); } while (xt_xmap_iterator_next(iter));
......
...@@ -57,6 +57,8 @@ PROGRAM test_xmap_intersection_parallel ...@@ -57,6 +57,8 @@ PROGRAM test_xmap_intersection_parallel
xt_xmap_get_num_destinations, xt_xmap_get_num_sources, & xt_xmap_get_num_destinations, xt_xmap_get_num_sources, &
xt_xmap_iterator_get_rank, xt_xmap_iterator_get_num_transfer_pos, & xt_xmap_iterator_get_rank, xt_xmap_iterator_get_num_transfer_pos, &
xt_xmap_iterator_get_transfer_pos, xt_xmap_iterator_next, & xt_xmap_iterator_get_transfer_pos, xt_xmap_iterator_next, &
xt_pos_ext, xt_xmap_iterator_get_num_transfer_pos_ext, &
xt_xmap_iterator_get_transfer_pos_ext, &
xt_xmap_iterator_delete, xt_xmap_reorder, xt_reorder_type_kind, & xt_xmap_iterator_delete, xt_xmap_reorder, xt_reorder_type_kind, &
xt_reorder_none, xt_reorder_send_up, xt_reorder_recv_up, & xt_reorder_none, xt_reorder_send_up, xt_reorder_recv_up, &
xt_sort_permutation, xt_xmap_update_positions, xt_xmap_spread xt_sort_permutation, xt_xmap_update_positions, xt_xmap_spread
...@@ -539,9 +541,10 @@ CONTAINS ...@@ -539,9 +541,10 @@ CONTAINS
TYPE(xt_xmap_iter), INTENT(inout) :: iter TYPE(xt_xmap_iter), INTENT(inout) :: iter
TYPE(test_message), INTENT(in) :: msgs(:) TYPE(test_message), INTENT(in) :: msgs(:)
INTEGER :: num_msgs, num_pos, i, j INTEGER :: num_msgs, num_pos, num_pos_ext, i, j, ofs, pe, pos_ext_size
INTEGER, POINTER :: pos(:) INTEGER, POINTER :: pos(:)
LOGICAL :: iter_is_null TYPE(xt_pos_ext), POINTER :: pos_ext(:)
LOGICAL :: iter_is_null, mismatch
num_msgs = SIZE(msgs) num_msgs = SIZE(msgs)
iter_is_null = xt_is_null(iter) iter_is_null = xt_is_null(iter)
...@@ -566,11 +569,35 @@ CONTAINS ...@@ -566,11 +569,35 @@ CONTAINS
filename, __LINE__) filename, __LINE__)
END IF END IF
pos => xt_xmap_iterator_get_transfer_pos(iter) pos => xt_xmap_iterator_get_transfer_pos(iter)
mismatch = .FALSE.
DO j = 1, num_pos DO j = 1, num_pos
IF (pos(j) /= msgs(i)%pos(j)) & mismatch = mismatch .OR. pos(j) /= msgs(i)%pos(j)
CALL test_abort('ERROR: xt_xmap_iterator_get_transfer_pos', & END DO
filename, __LINE__) IF (mismatch) &
CALL test_abort('ERROR: xt_xmap_iterator_get_transfer_pos', &
filename, __LINE__)
num_pos_ext = xt_xmap_iterator_get_num_transfer_pos_ext(iter)
pos_ext => xt_xmap_iterator_get_transfer_pos_ext(iter)
ofs = 0
mismatch = .FALSE.
DO pe = 1, num_pos_ext
pos_ext_size = ABS(pos_ext(pe)%size)
IF (pos_ext(pe)%size > 0) THEN
DO j = 1, pos_ext_size
mismatch = mismatch .OR. pos(ofs+j) /= pos_ext(pe)%start + j - 1
END DO
ELSE
DO j = 1, pos_ext_size
mismatch = mismatch .OR. pos(ofs+j) /= pos_ext(pe)%start - j + 1
END DO
END IF
ofs = ofs + pos_ext_size
END DO END DO
IF (mismatch .OR. ofs /= num_pos) &
CALL test_abort('ERROR: xt_xmap_iterator_get_transfer_pos_ext', &
filename, __LINE__)
IF (.NOT. xt_xmap_iterator_next(iter)) EXIT IF (.NOT. xt_xmap_iterator_next(iter)) EXIT
i = i + 1 i = i + 1
END DO END DO
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment