Commit 56ba5be7 authored by Thomas Jahns's avatar Thomas Jahns 🤸
Browse files

Fix incorrect handling of position extents.

* This bug was introduced in 5862f0a3.

* Extend intersection xmap test, so similar bugs cannot creep in as
  easily again.
parent 0f474404
......@@ -285,11 +285,14 @@ generate_ext_msg_infos(int num_msgs, Xt_xmap_iter iter,
{
if (num_msgs > 0) {
/* partial sums of ext sizes */
int *psum_ext_size
int *restrict psum_ext_size
= xmalloc(((size_t)num_ext + 1) * sizeof (psum_ext_size[0]));
psum_ext_size[0] = 0;
for (size_t i = 0; i < (size_t)num_ext; ++i)
psum_ext_size[i + 1] = psum_ext_size[i] + extents[i].size;
int accum = 0;
for (size_t i = 0; i < (size_t)num_ext; ++i) {
psum_ext_size[i] = accum;
accum += extents[i].size;
}
psum_ext_size[num_ext] = accum;
struct Xt_redist_msg *curr_msg = msgs;
do {
......
......@@ -338,13 +338,13 @@ generate_dir_transfer_pos_dst(
.all_dst_covered = all_bits_set == ~0UL };
}
struct tps_result {
int resCount;
int max_pos;
struct pos_count_max {
int count, max_pos;
};
/* compute list positions for send direction */
static struct tps_result
static struct pos_count_max
generate_dir_transfer_pos_src(int num_intersections,
const struct Xt_com_list
intersections[num_intersections],
......@@ -423,8 +423,8 @@ generate_dir_transfer_pos_src(int num_intersections,
free(new_intersection_idxvec);
free(intersection_pos);
return (struct tps_result){ .max_pos = max_pos_,
.resCount = new_num_intersections };
return (struct pos_count_max){ .max_pos = max_pos_,
.count = new_num_intersections };
}
static Xt_int *
......@@ -565,12 +565,12 @@ generate_transfer_pos(struct Xt_xmap_intersection_ *xmap,
= xrealloc(num_src_indices_to_remove_per_intersection,
(size_t)num_src_intersections * sizeof(int));
struct tps_result tpsr
struct pos_count_max tpsr
= generate_dir_transfer_pos_src(
num_src_intersections, src_com, src_idxlist_local, xmap->msg + xmap->n_in,
src_indices_to_remove, num_src_indices_to_remove_per_intersection);
xmap->max_src_pos = tpsr.max_pos;
xmap->n_out = tpsr.resCount;
xmap->n_out = tpsr.count;
free(src_indices_to_remove);
free(num_src_indices_to_remove_per_intersection);
......@@ -975,6 +975,45 @@ xmap_intersection_spread(Xt_xmap xmap, int num_repetitions,
.displacements = dst_displacements });
}
/* how many pos values have monotonically either positively or
* negatively consecutive values and copy to pos_copy */
static inline struct pos_run copy_get_pos_run_len(
size_t num_pos, const int *restrict pos,
int *restrict pos_copy)
{
size_t i = 0, j = 1;
int direction = 0;
int start = pos_copy[0] = pos[0];
if (j < num_pos) {
direction = isign_mask(pos[1] - pos[0]);
while (j < num_pos
&& (pos_copy[j] = pos[j]) == start + (~direction & (int)(j - i)) +
(direction & -(int)(j - i))) {
pos_copy[j] = pos[j];
++j;
}
direction = direction & ((j == 1) - 1);
}
return (struct pos_run){ .start = start, .len = j, .direction = direction };
}
/* compute number of position extents that would be required
to represent positions array and copy to pos_copy */
static struct pos_count_max
max_count_pos_ext_and_copy(int max_pos, size_t num_pos, const int *restrict pos,
int *restrict pos_copy)
{
size_t i = 0, num_pos_ext = 0;
while (i < num_pos) {
struct pos_run run = copy_get_pos_run_len(num_pos - i, pos + i, pos_copy + i);
i += run.len;
int max_of_run = (run.start & run.direction) | ((run.start + (int)run.len - 1) & ~run.direction);
if (max_of_run > max_pos) max_pos = max_of_run;
++num_pos_ext;
}
return (struct pos_count_max){ .count = (int)num_pos_ext, .max_pos = max_pos };
}
static void init_exchange_data_from_com_pos(
int count, struct exchange_data *restrict msgs,
const struct Xt_com_pos *restrict com, int *max_pos) {
......@@ -984,18 +1023,15 @@ static void init_exchange_data_from_com_pos(
int num_transfer_pos = com[i].num_transfer_pos;
int *restrict transfer_pos =
xmalloc((size_t)num_transfer_pos * sizeof(*transfer_pos));
int rank = com[i].rank;
const int *restrict com_transfer_pos = com[i].transfer_pos;
for (int j = 0; j < num_transfer_pos; ++j)
if (com_transfer_pos[j] > max_pos_) max_pos_ = com_transfer_pos[j];
msgs[i].transfer_pos = transfer_pos;
msgs[i].transfer_pos_ext_cache = NULL;
msgs[i].num_transfer_pos = num_transfer_pos;
msgs[i].num_transfer_pos_ext =
(int)(count_pos_ext((size_t)num_transfer_pos, transfer_pos));
msgs[i].rank = rank;
memcpy(transfer_pos, com_transfer_pos,
(size_t)num_transfer_pos * sizeof(*transfer_pos));
msgs[i].rank = com[i].rank;
struct pos_count_max max_count
= max_count_pos_ext_and_copy(max_pos_, (size_t)num_transfer_pos,
com[i].transfer_pos, transfer_pos);
msgs[i].num_transfer_pos_ext = max_count.count;
if (max_count.max_pos > max_pos_) max_pos_ = max_count.max_pos;
}
*max_pos = max_pos_;
}
......
......@@ -76,6 +76,9 @@ static Xt_xmap (*xmi_new)(
int ndst_com, const struct Xt_com_list dst_com[ndst_com],
Xt_idxlist src_idxlist, Xt_idxlist dst_idxlist, MPI_Comm comm);
static void
test_strided_block_pos_alltoall(MPI_Comm comm, int nblk, int blksz);
static void
parse_options(int *argc, char ***argv);
......@@ -525,51 +528,58 @@ int main(int argc, char **argv)
xt_idxlist_delete(src_com.list);
}
{ // alltoall using xt_xmap_intersection_pos_new
struct Xt_com_pos * src_com =
xmalloc(2 * (size_t)comm_size * sizeof(*src_com));
struct Xt_com_pos * dst_com = src_com + comm_size;
int * transfer_pos = xmalloc((size_t)comm_size * sizeof(*transfer_pos));
struct test_message * ref_src_msg =
xmalloc(2 * (size_t)comm_size * sizeof(*ref_src_msg));
struct test_message * ref_dst_msg = ref_src_msg + comm_size;
for (int i = 0; i < comm_size; ++i) {
transfer_pos[i] = i;
src_com[i].transfer_pos = transfer_pos + i;
src_com[i].num_transfer_pos = 1;
src_com[i].rank = i;
dst_com[i].transfer_pos = transfer_pos + i;
dst_com[i].num_transfer_pos = 1;
dst_com[i].rank = i;
ref_src_msg[i].pos = transfer_pos + i;
ref_src_msg[i].num_pos = 1;
ref_src_msg[i].rank = i;
ref_dst_msg[i].pos = transfer_pos + i;
ref_dst_msg[i].num_pos = 1;
ref_dst_msg[i].rank = i;
}
test_strided_block_pos_alltoall(MPI_COMM_WORLD, 1, 1);
test_strided_block_pos_alltoall(MPI_COMM_WORLD, 2, 1);
test_strided_block_pos_alltoall(MPI_COMM_WORLD, 5, 2000);
Xt_xmap xmap =
xt_xmap_intersection_pos_new(
comm_size, src_com, comm_size, dst_com, MPI_COMM_WORLD);
test_xmap(xmap, comm_size, ref_src_msg, comm_size, ref_dst_msg);
xt_finalize();
MPI_Finalize();
xt_xmap_delete(xmap);
return TEST_EXIT_CODE;
}
free(ref_src_msg);
free(transfer_pos);
free(src_com);
static void
test_strided_block_pos_alltoall(MPI_Comm comm, int nblk, int blksz)
{
int comm_rank, comm_size;
xt_mpi_call(MPI_Comm_rank(comm, &comm_rank), comm);
xt_mpi_call(MPI_Comm_size(comm, &comm_size), comm);
struct Xt_com_pos *src_com =
xmalloc(2 * (size_t)comm_size * sizeof(*src_com)),
*dst_com = src_com + comm_size;
int (*transfer_pos)[nblk][blksz]
= xmalloc((size_t)nblk * (size_t)blksz
* (size_t)comm_size * sizeof(*transfer_pos));
struct test_message *ref_src_msg =
xmalloc(2 * (size_t)comm_size * sizeof(*ref_src_msg)),
*ref_dst_msg = ref_src_msg + comm_size;
for (int rank = 0; rank < comm_size; ++rank) {
for (int j = 0; j < nblk; ++j)
for (int i = 0; i < blksz; ++i)
transfer_pos[rank][j][i] = i + j*blksz*comm_size;
ref_src_msg[rank].pos = src_com[rank].transfer_pos
= (int *)(transfer_pos + rank);
ref_src_msg[rank].num_pos = src_com[rank].num_transfer_pos = nblk*blksz;
ref_src_msg[rank].rank = src_com[rank].rank = rank;
ref_dst_msg[rank].pos = dst_com[rank].transfer_pos
= (int *)(transfer_pos + rank);
ref_dst_msg[rank].num_pos = dst_com[rank].num_transfer_pos = nblk*blksz;
ref_dst_msg[rank].rank = dst_com[rank].rank = rank;
}
xt_finalize();
MPI_Finalize();
Xt_xmap xmap =
xt_xmap_intersection_pos_new(comm_size, src_com, comm_size, dst_com, comm);
return TEST_EXIT_CODE;
test_xmap(xmap, comm_size, ref_src_msg, comm_size, ref_dst_msg);
xt_xmap_delete(xmap);
free(ref_src_msg);
free(transfer_pos);
free(src_com);
}
static void test_xmap_iter(Xt_xmap_iter iter, int num_msgs,
......@@ -603,6 +613,27 @@ static void test_xmap_iter(Xt_xmap_iter iter, int num_msgs,
mismatch |= (pos[j] != msgs[i].pos[j]);
if (mismatch)
PUT_ERR("ERROR: xt_xmap_iterator_get_transfer_pos\n");
int num_transfer_pos_ext
= xt_xmap_iterator_get_num_transfer_pos_ext(iter);
const struct Xt_pos_ext *restrict pos_ext
= xt_xmap_iterator_get_transfer_pos_ext(iter);
mismatch = false;
size_t ofs = 0;
for (size_t pe = 0; pe < (size_t)num_transfer_pos_ext; ++pe) {
int pos_ext_size = abs(pos_ext[pe].size);
if (pos_ext[pe].size > 0)
for (int j = 0; j < pos_ext_size; ++j)
mismatch |= (pos[ofs+(size_t)j] != pos_ext[pe].start + j);
else
for (int j = 0; j < pos_ext_size; ++j)
mismatch |= (pos[ofs+(size_t)j] != pos_ext[pe].start - j);
ofs += (size_t)pos_ext_size;
}
if (mismatch || (int)ofs != msgs[i].num_pos)
PUT_ERR("ERROR: xt_xmap_iterator_get_transfer_pos_ext\n");
++i;
} while (xt_xmap_iterator_next(iter));
......
......@@ -57,6 +57,8 @@ PROGRAM test_xmap_intersection_parallel
xt_xmap_get_num_destinations, xt_xmap_get_num_sources, &
xt_xmap_iterator_get_rank, xt_xmap_iterator_get_num_transfer_pos, &
xt_xmap_iterator_get_transfer_pos, xt_xmap_iterator_next, &
xt_pos_ext, xt_xmap_iterator_get_num_transfer_pos_ext, &
xt_xmap_iterator_get_transfer_pos_ext, &
xt_xmap_iterator_delete, xt_xmap_reorder, xt_reorder_type_kind, &
xt_reorder_none, xt_reorder_send_up, xt_reorder_recv_up, &
xt_sort_permutation, xt_xmap_update_positions, xt_xmap_spread
......@@ -539,9 +541,10 @@ CONTAINS
TYPE(xt_xmap_iter), INTENT(inout) :: iter
TYPE(test_message), INTENT(in) :: msgs(:)
INTEGER :: num_msgs, num_pos, i, j
INTEGER :: num_msgs, num_pos, num_pos_ext, i, j, ofs, pe, pos_ext_size
INTEGER, POINTER :: pos(:)
LOGICAL :: iter_is_null
TYPE(xt_pos_ext), POINTER :: pos_ext(:)
LOGICAL :: iter_is_null, mismatch
num_msgs = SIZE(msgs)
iter_is_null = xt_is_null(iter)
......@@ -566,11 +569,35 @@ CONTAINS
filename, __LINE__)
END IF
pos => xt_xmap_iterator_get_transfer_pos(iter)
mismatch = .FALSE.
DO j = 1, num_pos
IF (pos(j) /= msgs(i)%pos(j)) &
CALL test_abort('ERROR: xt_xmap_iterator_get_transfer_pos', &
filename, __LINE__)
mismatch = mismatch .OR. pos(j) /= msgs(i)%pos(j)
END DO
IF (mismatch) &
CALL test_abort('ERROR: xt_xmap_iterator_get_transfer_pos', &
filename, __LINE__)
num_pos_ext = xt_xmap_iterator_get_num_transfer_pos_ext(iter)
pos_ext => xt_xmap_iterator_get_transfer_pos_ext(iter)
ofs = 0
mismatch = .FALSE.
DO pe = 1, num_pos_ext
pos_ext_size = ABS(pos_ext(pe)%size)
IF (pos_ext(pe)%size > 0) THEN
DO j = 1, pos_ext_size
mismatch = mismatch .OR. pos(ofs+j) /= pos_ext(pe)%start + j - 1
END DO
ELSE
DO j = 1, pos_ext_size
mismatch = mismatch .OR. pos(ofs+j) /= pos_ext(pe)%start - j + 1
END DO
END IF
ofs = ofs + pos_ext_size
END DO
IF (mismatch .OR. ofs /= num_pos) &
CALL test_abort('ERROR: xt_xmap_iterator_get_transfer_pos_ext', &
filename, __LINE__)
IF (.NOT. xt_xmap_iterator_next(iter)) EXIT
i = i + 1
END DO
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment