Commit c12ed89d authored by Thomas Jahns's avatar Thomas Jahns 🤸
Browse files

Merge buffers that form collinear vectors.

parent 98da572d
......@@ -372,7 +372,8 @@ scan_stripe(const int *disp, size_t disp_len, struct Xt_offset_ext *restrict v)
static int
match_simple_vec(size_t *pstart_, const struct Xt_offset_ext *v, size_t vlen,
MPI_Datatype old_type, int *disp, MPI_Datatype *dt,
MPI_Datatype old_type, MPI_Aint old_type_extent,
MPI_Aint *disp, MPI_Datatype *dt,
MPI_Comm comm) {
// we only accept non-trivial matches (nsteps>2) with stride /= 1
// using only one vector from v
......@@ -384,21 +385,18 @@ match_simple_vec(size_t *pstart_, const struct Xt_offset_ext *v, size_t vlen,
*pstart_ = p + 1;
*disp = vlen > 1 ? v[p].start : 0;
int disp_ = vlen > 1 ? v[p].start : 0;
*disp = disp_ * old_type_extent;
MPI_Datatype dt1;
xt_mpi_call(MPI_Type_vector(nstrides, 1, stride, old_type, &dt1), comm);
int start = v[p].start - *disp;
int start = v[p].start - disp_;
if (!start) {
*dt = dt1;
} else {
// (start != 0) => add offset:
MPI_Aint old_type_size, old_type_lb;
xt_mpi_call(MPI_Type_get_extent(old_type, &old_type_lb,
&old_type_size), comm);
MPI_Aint displacement = start * old_type_size;
MPI_Aint displacement = start * old_type_extent;
int bl2 = 1;
MPI_Datatype dt2;
xt_mpi_call(MPI_Type_create_hindexed(1, &bl2, &displacement, dt1, &dt2),
......@@ -416,7 +414,8 @@ match_simple_vec(size_t *pstart_, const struct Xt_offset_ext *v, size_t vlen,
*/
static bool
match_block_vec(size_t *pstart_, const struct Xt_offset_ext *v, size_t vlen,
MPI_Datatype old_type, int *disp, MPI_Datatype *dt,
MPI_Datatype old_type, MPI_Aint old_type_extent,
MPI_Aint *disp, MPI_Datatype *dt,
MPI_Comm comm) {
// using at least 3 vectors
size_t p = *pstart_, pstart = p;
......@@ -434,23 +433,20 @@ match_block_vec(size_t *pstart_, const struct Xt_offset_ext *v, size_t vlen,
size_t n = p - pstart;
if (n<3) return false;
*disp = n == vlen ? 0 : v[pstart].start;
int disp_ = n == vlen ? 0 : v[pstart].start;
*disp = disp_ * old_type_extent;
MPI_Datatype dt1;
xt_mpi_call(MPI_Type_vector((int)n, bl, vstride, old_type, &dt1), comm);
int start = v[pstart].start - *disp;
int start = v[pstart].start - disp_;
*pstart_ = p;
if (!start) {
*dt = dt1;
} else {
// (start != 0) => add offset:
MPI_Aint old_type_size, old_type_lb;
xt_mpi_call(MPI_Type_get_extent(old_type, &old_type_lb,
&old_type_size), comm);
MPI_Aint displacement = start * old_type_size;
MPI_Aint displacement = start * old_type_extent;
int bl2 = 1;
MPI_Datatype dt2;
xt_mpi_call(MPI_Type_create_hindexed(1, &bl2, &displacement, dt1, &dt2),
......@@ -463,12 +459,15 @@ match_block_vec(size_t *pstart_, const struct Xt_offset_ext *v, size_t vlen,
static bool
match_contiguous(size_t *pstart_, const struct Xt_offset_ext *v, size_t vlen,
MPI_Datatype old_type, int *restrict disp, MPI_Datatype *dt,
MPI_Datatype old_type, MPI_Aint old_type_extent,
MPI_Aint *restrict disp, MPI_Datatype *dt,
MPI_Comm comm) {
size_t p = *pstart_;
if (p >= vlen || v[p].stride != 1 || v[p].size < 2) return 0;
int d = v[p].start - (*disp = vlen > 1 ? v[p].start : 0);
int disp_ = vlen > 1 ? v[p].start : 0;
*disp = disp_ * old_type_extent;
int d = v[p].start - disp_;
if (!d)
xt_mpi_call(MPI_Type_contiguous(v[p].size, old_type, dt), comm) ;
......@@ -482,7 +481,8 @@ match_contiguous(size_t *pstart_, const struct Xt_offset_ext *v, size_t vlen,
static bool
match_indexed(size_t *pstart_, const struct Xt_offset_ext *v, size_t vlen,
MPI_Datatype old_type, int *disp, MPI_Datatype *dt,
MPI_Datatype old_type, MPI_Aint old_type_extent,
MPI_Aint *disp, MPI_Datatype *dt,
MPI_Comm comm) {
// we only accept non-trivial matches
size_t p = *pstart_, pstart = p;
......@@ -496,8 +496,8 @@ match_indexed(size_t *pstart_, const struct Xt_offset_ext *v, size_t vlen,
if (n < 2) return false;
int start = (*disp = n == vlen ? 0 : v[pstart].start);
int start = n == vlen ? 0 : v[pstart].start;
*disp = start * old_type_extent;
int *restrict bl = xmalloc(2 * n * sizeof (*bl)),
*restrict d = bl + n;
bool hom_bl = true;
......@@ -525,8 +525,9 @@ match_indexed(size_t *pstart_, const struct Xt_offset_ext *v, size_t vlen,
static int
gen_fallback_type(size_t set_start, size_t set_end,
const struct Xt_offset_ext *v,
size_t vlen, MPI_Datatype old_type, int *offset,
const struct Xt_offset_ext *v, size_t vlen,
MPI_Datatype old_type, MPI_Aint old_type_extent,
MPI_Aint *disp,
MPI_Datatype *dt, MPI_Comm comm) {
size_t ia = set_start;
size_t ib = set_end;
......@@ -538,16 +539,11 @@ gen_fallback_type(size_t set_start, size_t set_end,
if (n<1) return 0;
int start;
if (ia == 0 && ib == vlen) {
// generate absolute datatype
start = 0;
} else {
// generate relative datatype that gets embedded by the caller
start = v[ia].start;
}
// generate absolute datatype if ia == 0 && ib == vlen,
// else generate relative datatype that gets embedded by the caller
int start = (ia == 0 && ib == vlen) ? 0 : v[ia].start;
*offset = start;
*disp = start * old_type_extent;
int *restrict d = xmalloc(sizeof (*d) * (size_t)n);
size_t p=0;
......@@ -578,9 +574,13 @@ parse_stripe(const struct Xt_offset_ext *v, size_t vlen, MPI_Datatype old_type,
/* [set_start,set_end) describes the prefix of non-matching
* elements in v that then need to be handled with gen_fallback_type */
size_t set_start = 0, set_end = 0;
MPI_Datatype *restrict wdt = xmalloc(sizeof(*wdt) * (size_t)vlen
+ sizeof (int) * (size_t)vlen);
int *restrict wdisp = (int *)(wdt + vlen);
MPI_Aint old_type_lb, old_type_extent;
xt_mpi_call(MPI_Type_get_extent(old_type, &old_type_lb,
&old_type_extent), comm);
MPI_Aint *restrict wdisp
= xmalloc(sizeof(MPI_Datatype) * (size_t)vlen
+ sizeof (MPI_Aint) * (size_t)vlen);
MPI_Datatype *restrict wdt = (MPI_Datatype *)(wdisp + vlen);
/* [p,vlen) is the part of v that still needs matching performed */
/* m is the index of the next datatype and displacements to write
* to wdt and wdisp respectively */
......@@ -590,15 +590,19 @@ parse_stripe(const struct Xt_offset_ext *v, size_t vlen, MPI_Datatype old_type,
* and displacement corresponding to a match need to be written
* to wdt[m+1] and wdisp[m+1] or wdt[m] and wdisp[m] respectively */
size_t mm = m + (set_start < set_end);
if ( match_block_vec(&p, v, vlen, old_type, wdisp+mm, wdt+mm, comm) ||
match_indexed(&p, v, vlen, old_type, wdisp+mm, wdt+mm, comm) ||
match_simple_vec(&p, v, vlen, old_type, wdisp+mm, wdt+mm, comm) ||
match_contiguous(&p, v, vlen, old_type, wdisp+mm, wdt+mm, comm) ) {
if (match_block_vec(&p, v, vlen, old_type, old_type_extent,
wdisp+mm, wdt+mm, comm)
|| match_indexed(&p, v, vlen, old_type, old_type_extent,
wdisp+mm, wdt+mm, comm)
|| match_simple_vec(&p, v, vlen, old_type, old_type_extent,
wdisp+mm, wdt+mm, comm)
|| match_contiguous(&p, v, vlen, old_type, old_type_extent,
wdisp+mm, wdt+mm, comm) ) {
/* in case a match is found generate fallback datatype for
* non-matching, preceding extents */
if (set_start < set_end) {
gen_fallback_type(set_start, set_end, v, vlen, old_type, wdisp+m, wdt+m,
comm);
gen_fallback_type(set_start, set_end, v, vlen, old_type,
old_type_extent, wdisp+m, wdt+m, comm);
m++;
}
m++;
......@@ -609,8 +613,8 @@ parse_stripe(const struct Xt_offset_ext *v, size_t vlen, MPI_Datatype old_type,
}
}
if (set_start < set_end) {
gen_fallback_type(set_start, set_end, v, vlen, old_type, wdisp+m, wdt+m,
comm);
gen_fallback_type(set_start, set_end, v, vlen, old_type, old_type_extent,
wdisp+m, wdt+m, comm);
m++;
}
size_t wlen = m;
......@@ -621,25 +625,20 @@ parse_stripe(const struct Xt_offset_ext *v, size_t vlen, MPI_Datatype old_type,
xt_mpi_call(MPI_Type_dup(old_type, wdt), comm);
result_dt = wdt[0];
} else {
MPI_Aint old_type_lb, old_type_extent;
MPI_Aint *restrict wbdisp
= xmalloc((size_t)wlen * (sizeof (*wbdisp) + sizeof (int)));
int *restrict wblocklength = (int *)(wbdisp + wlen);
xt_mpi_call(MPI_Type_get_extent(old_type, &old_type_lb,
&old_type_extent), comm);
int *restrict wblocklength
= xmalloc((size_t)wlen * sizeof (*wblocklength));
for(size_t i=0; i<wlen; i++) {
wbdisp[i] = wdisp[i] * old_type_extent;
wblocklength[i] = 1;
}
xt_mpi_call(MPI_Type_create_struct((int)wlen, wblocklength, wbdisp,
xt_mpi_call(MPI_Type_create_struct((int)wlen, wblocklength, wdisp,
wdt, &result_dt), comm);
free(wbdisp);
free(wblocklength);
for (size_t i = 0; i < wlen; i++)
if (wdt[i] != old_type)
xt_mpi_call(MPI_Type_free(wdt+i), comm);
}
xt_mpi_call(MPI_Type_commit(&result_dt), comm);
free(wdt);
free(wdisp);
return result_dt;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment