Commit c809feb0 authored by Nils-Arne Dreier

refactor: do not implicitly split-up by variables into multiple groups

parent ee901b56
1 merge request: !47 [hiopy] refactor variable handling and distribution
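In short: distribute_work previously returned per-rank work items as 4-tuples, where the second entry indexed an implicitly created sub-group of variables; it now returns 3-tuples that keep all variables of a group together and only slice along the cell dimension. A minimal sketch of the two shapes (var1 is a placeholder variable object, mirroring the test expectations below):

# old work item: (group id, variable-sub-group index, cell slice, variables)
old_item = ("g1", 0, slice(None), [var1])

# new work item: (group id, all variables of the group, cell slice)
new_item = ("g1", [var1], slice(None))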
#!/usr/bin/env python3
import logging
from math import ceil, prod
import numpy as np
logger = logging.getLogger(__name__)
def _map_weighted(a, b, wa=None, wb=None):
"""
@@ -101,6 +104,13 @@ def distribute_work(grouped_data_vars, size):
     group, gsize = max(group_sizes.items(), key=lambda x: x[1])
     variables = grouped_data_vars.pop(group)
+    assert all(
+        variables[0].shape == v.shape for v in variables
+    ), "All variables in a group need to have the same shape"
+    assert all(
+        variables[0].chunks == v.chunks for v in variables
+    ), "All variables in a group need to have the same chunk shape"
     if gsize > bytes_per_rank:
         # Compute the number of ranks required for this group
         nranks = gsize // bytes_per_rank
@@ -108,31 +118,26 @@ def distribute_work(grouped_data_vars, size):
         cell_chunk_size = variables[0].chunks[-1]
         ncells = variables[0].shape[-1]
         nchunks = ceil(ncells / cell_chunk_size)
+        if nchunks < nranks:
+            logger.warning(
+                f"Not enough chunks to distribute work in group {group}."
+                " Consider using fewer processes or splitting up the group."
+            )
+            nranks = nchunks
-        num_var_groups = min(len(variables), ceil(nranks / nchunks))
-        var_groups = _map_weighted(
-            range(num_var_groups), variables, None, [_estimate_size(v) for v in variables]
-        )
-        ranks_per_group = _map_weighted(range(num_var_groups), range(nranks))
-        this_group = [
-            [(group, varidx, sl, mvars)]
-            for varidx, mvars in var_groups.items()
-            for sl in _distribute_chunks(variables[0], len(ranks_per_group[varidx]))
-        ]
+        this_group = [[(group, variables, sl)] for sl in _distribute_chunks(variables[0], nranks)]
         return [*this_group, *distribute_work(grouped_data_vars, size - nranks)]
     else:
         del group_sizes[group]
-        result = [(group, 0, slice(None), variables)]
+        result = [(group, variables, slice(None))]
         # Add additional groups to this rank until it reaches the byte limit
         while gsize < bytes_per_rank:
             group, next_gsize = max(group_sizes.items(), key=lambda x: x[1])
             if gsize + next_gsize > bytes_per_rank:
                 break
-            result.append((group, 0, slice(None), grouped_data_vars.pop(group)))
+            result.append((group, grouped_data_vars.pop(group), slice(None)))
             del group_sizes[group]
             gsize += next_gsize
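A small worked example of the clamping that replaces the implicit variable split (numbers are hypothetical; _distribute_chunks itself is not part of this diff):

from math import ceil

ncells, cell_chunk_size = 10, 5           # e.g. 10 cells stored in chunks of 5
nchunks = ceil(ncells / cell_chunk_size)  # 2 chunks along the cell dimension

nranks = 3            # suppose the size estimate requested 3 ranks
if nchunks < nranks:  # only 2 chunks, so at most 2 ranks can get work
    nranks = nchunks  # clamped to 2; the real code logs the warning above

# each of the 2 ranks then receives whole chunks, e.g. slice(0, 5) and slice(5, 10)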
@@ -61,7 +61,7 @@ class LocoServer:
         with open("backend_map.conf", "w") as f:
             f.write("map $request_uri $backend {\n")
             for r, data_vars in enumerate(distributed_data_vars):
-                for _group, _vgroup, chunk_slice, dvars in data_vars:
+                for _group, dvars, chunk_slice in data_vars:
                     for v in dvars:
                         if chunk_slice == slice(None):
                             cell_regex = r"\d+"
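For context, the map block written here lets nginx choose a backend per request URI: each variable held by a rank becomes one entry, and a rank holding the full cell dimension (chunk_slice == slice(None)) matches any chunk index via \d+. A hypothetical sketch of the generated backend_map.conf; the URI layout and addresses are made up, only the map header and the \d+ pattern come from the code:

map $request_uri $backend {
    ~^/g1/var1/0\.\d+$ 127.0.0.1:8081;  # rank that holds all cells of var1
    default 127.0.0.1:8080;
}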
@@ -11,57 +11,49 @@ except ImportError:
 class DummyDataVar:
-    def __init__(self, shape, chunksize=None):
+    def __init__(self, name, shape, chunksize=None):
+        self.name = name
         self.shape = shape
         self.chunks = (1, chunksize or shape[1])
         self.attrs = {"_ARRAY_DIMENSIONS": ["time", "cell"]}
         self.dtype = np.dtype("float64")
 
+    def __repr__(self):
+        return self.name
 
 
 def test_sequential():
     print("Test sequential case")
-    var1 = DummyDataVar([10, 10])
-    var2 = DummyDataVar([10, 10])
+    var1 = DummyDataVar("var1", [10, 10])
+    var2 = DummyDataVar("var2", [10, 10])
     result = distribute_work({"g1": [var1], "g2": [var2]}, 1)
     pp(result)
     assert len(result) == 1
-    assert result[0] == [("g1", 0, slice(None), [var1]), ("g2", 0, slice(None), [var2])]
+    assert result[0] == [("g1", [var1], slice(None)), ("g2", [var2], slice(None))]
 
 
 def test_dist_groups():
     print("Test distributing groups")
-    var1 = DummyDataVar([10, 10])
-    var2 = DummyDataVar([10, 10])
+    var1 = DummyDataVar("var1", [10, 10])
+    var2 = DummyDataVar("var2", [10, 10])
     result = distribute_work({"g1": [var1], "g2": [var2]}, 2)
     pp(result)
     assert len(result) == 2
-    assert result[0] == [("g1", 0, slice(None), [var1])]
-    assert result[1] == [("g2", 0, slice(None), [var2])]
-
-
-def test_dist_vars():
-    print("Test distributing variables")
-    var1 = DummyDataVar([10, 10])
-    var2 = DummyDataVar([10, 10])
-    result = distribute_work({"g": [var1, var2]}, 2)
-    pp(result)
-    assert len(result) == 2
-    assert result[0] == [("g", 0, slice(None), [var1])]
-    assert result[1] == [("g", 1, slice(None), [var2])]
+    assert result[0] == [("g1", [var1], slice(None))]
+    assert result[1] == [("g2", [var2], slice(None))]
 
 
 def test_dist_chunks():
     print("Test distributing chunks")
-    var = DummyDataVar([10, 10], 5)
+    var = DummyDataVar("var", [10, 10], 5)
     result = distribute_work({"g": [var]}, 2)
     pp(result)
     assert len(result) == 2
-    assert result[0] == [("g", 0, slice(0, 5), [var])]
-    assert result[1] == [("g", 0, slice(5, 10), [var])]
+    assert result[0] == [("g", [var], slice(0, 5))]
+    assert result[1] == [("g", [var], slice(5, 10))]
 
 
 if __name__ == "__main__":
     test_sequential()
     test_dist_groups()
-    test_dist_vars()
     test_dist_chunks()
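The added __repr__ is what keeps the pp(result) output readable: work items now print with variable names instead of default object representations. Assuming pp is pprint.pp (Python 3.8+), roughly:

pp([("g", [DummyDataVar("var", [10, 10], 5)], slice(0, 5))])
# prints: [('g', [var], slice(0, 5, None))]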
@@ -107,8 +107,7 @@ def main():
     loco_server.write_nginx_config(distributed_data_vars, group_comm_rank())
     my_data_vars = distributed_data_vars[group_comm_rank()]
-    # my_data_vars list of tuples: (gid, vgroup, slice, list of variables)
-    # the vgroup is the group of variables if they need to be split (for the same grid)
+    # my_data_vars list of tuples: (gid, data_vars, slice)
     if group_comm_rank() == 0:
         logging.debug(distributed_data_vars)
@@ -116,12 +115,11 @@ def main():
     frac_masks = {}  # the numpy array must be kept alive until the end of the program
     coyote_instances = {
-        (gid, vgroup): Coyote(f"{args.process_group}_{gid}_{vgroup}")
-        for gid, vgroup, chunk_slice, data_vars in my_data_vars
+        gid: Coyote(f"{args.process_group}_{gid}") for gid, data_vars, chunk_slice in my_data_vars
     }
 
-    for gid, vgroup, chunk_slice, data_vars in my_data_vars:
-        coyote = coyote_instances[(gid, vgroup)]
+    for gid, data_vars, chunk_slice in my_data_vars:
+        coyote = coyote_instances[gid]
         # all vars in data_vars define the same grid
         def_grid(coyote, data_vars[0], chunk_slice, data_vars[0].group)
@@ -153,15 +151,15 @@ def main():
if "hiopy::yac_source" in v.attrs:
src_comp, src_grid = v.attrs["hiopy::yac_source"]
else:
# find the component with that variables
gid, src_vgroup = [
(gid, vg)
for gid, vg, _, gvars in chain(*distributed_data_vars)
for groupvar in gvars
if groupvar.name.split("/")[-2] == v.group.attrs["hiopy::parent"]
and groupvar.name.split("/")[-1] == v.basename
parent_group_name = v.group.attrs["hiopy::parent"]
parent_var_name = parent_group_name + "/" + src_name
parent_var, parent_group = [
(ds[parent_var_name], ds[parent_group_name])
for ds in args.datasets
if parent_var_name in ds
][0]
src_comp = src_grid = f"{args.process_group}_{gid}_{src_vgroup}"
parent_var_gid = grid_id(parent_var, parent_group)
src_comp = src_grid = f"{args.process_group}_{parent_var_gid}"
time_method = v.attrs.get("hiopy::time_method", "point")
nnn = v.attrs.get("hiopy::nnn", 1)
frac_mask_name = v.attrs.get("hiopy::frac_mask", None)
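The new lookup resolves the source component from the dataset hierarchy instead of scanning all distributed work items. A minimal sketch of the path arithmetic, with hypothetical names ("parent_grid", "tas") and a plain dict standing in for a dataset; grid_id is the helper used in the real code:

v_group_attrs = {"hiopy::parent": "parent_grid"}  # hypothetical attrs
src_name = "tas"                                  # hypothetical source variable name

parent_group_name = v_group_attrs["hiopy::parent"]    # "parent_grid"
parent_var_name = parent_group_name + "/" + src_name  # "parent_grid/tas"

# the first dataset containing the variable wins, as in the [0] above
datasets = [{"parent_grid": "group-object", "parent_grid/tas": "var-object"}]
parent_var, parent_group = [
    (ds[parent_var_name], ds[parent_group_name])
    for ds in datasets
    if parent_var_name in ds
][0]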