Skip to content
Snippets Groups Projects
Commit d267d9b9 authored by Nils-Arne Dreier's avatar Nils-Arne Dreier
Browse files

fix: distribute work

parent 68a8045e
No related branches found
No related tags found
1 merge request: !16 "Fix and test distribute work"
......@@ -25,7 +25,7 @@ def _map_weighted(a, b, wa=None, wb=None):
partial_sum_b = np.cumsum(wb)
factor = partial_sum_a[-1] / partial_sum_b[-1] # Normalization factor
indices = np.searchsorted(factor * partial_sum_b, partial_sum_a, side="left")
indices = [0, *np.searchsorted(factor * partial_sum_b, partial_sum_a, side="right")]
return {a[ai]: b[i:j] for ai, (i, j) in enumerate(zip(indices[:-1], indices[1:]))}
......@@ -88,6 +88,7 @@ def distribute_work(grouped_data_vars, size):
A list of tuples, each representing the distribution of work to a specific rank.
"""
if size == 0:
assert len(grouped_data_vars) == 0, "Cannot distribute non-empty dataset on 0 processes"
return []
# Estimate sizes for each group of variables
......@@ -99,7 +100,7 @@ def distribute_work(grouped_data_vars, size):
if gsize > bytes_per_rank:
# Compute the number of ranks required for this group
nranks = min(size - 1, gsize // bytes_per_rank)
nranks = gsize // bytes_per_rank
cell_chunk_size = dict(zip(variables[0].attrs["_ARRAY_DIMENSIONS"], variables[0].chunks))["cell"]
ncells = dict(zip(variables[0].attrs["_ARRAY_DIMENSIONS"], variables[0].shape))["cell"]
......@@ -127,6 +128,7 @@ def distribute_work(grouped_data_vars, size):
if gsize + next_gsize > bytes_per_rank:
break
result.append((group, 0, slice(None), grouped_data_vars.pop(group)))
del group_sizes[group]
gsize += next_gsize
return [result, *distribute_work(grouped_data_vars, size - 1)]
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment