Skip to content
Snippets Groups Projects
Commit 5ffec2ed authored by Nils-Arne Dreier's avatar Nils-Arne Dreier
Browse files

refactor: distribute work by throughput not by data size

parent 5e084e33
No related branches found
No related tags found
1 merge request!47[hiopy] refactor variable handling and distribution
......@@ -5,6 +5,8 @@ from math import ceil, prod
import numpy as np
from ._zarr_utils import get_time_axis
logger = logging.getLogger(__name__)
......@@ -45,11 +47,12 @@ def _estimate_size(*data_vars):
Total estimated size in bytes of the data variables.
"""
size = 0
size = 0.0
for dv in data_vars:
time_chunk = dv.chunks[0]
time_shape = dv.shape[0]
size += time_chunk * (prod(dv.shape) // time_shape) * dv.dtype.itemsize
time_axis = get_time_axis(dv)
if time_axis is not None:
dt = time_axis[1] - time_axis[0]
size += prod(dv.shape[1:]) * dv.dtype.itemsize / dt
return size
......@@ -75,7 +78,7 @@ def _distribute_chunks(var, size):
assert nchunks // size > 0, "Cannot distribute chunks (too many ranks)"
# Map ranks to chunks
ranks_to_chunks = _map_weighted(range(size), range(nchunks))
ranks_to_chunks = _map_weighted(range(int(size)), range(nchunks))
return [
slice(cell_chunk_size * c[0], min(ncells, cell_chunk_size * (c[-1] + 1)))
for i, c in sorted(ranks_to_chunks.items())
......
......@@ -5,3 +5,16 @@ def get_var_group(v):
root = zarr.Group(store=v.store)
last_slash_idx = v.name.rindex("/")
return root[v.name[:last_slash_idx]]
def get_time_axis(v):
group = get_var_group(v)
time_axis_name = v.attrs["_ARRAY_DIMENSIONS"][0]
time_axis = group[time_axis_name]
if time_axis.attrs["axis"] == "T":
assert (
"seconds since " in time_axis.attrs["units"]
), "Currently the time must be given in seconds"
return time_axis
else:
return None
#!/usr/bin/env python3
import numpy as np
import zarr
from hiopy._distribute_work import distribute_work
try:
......@@ -10,22 +11,24 @@ except ImportError:
pp = print
class DummyDataVar:
def __init__(self, name, shape, chunksize=None):
self.name = name
self.shape = shape
self.chunks = (1, chunksize or shape[1])
self.attrs = {"_ARRAY_DIMENSIONS": ["time", "cell"]}
self.dtype = np.dtype("float64")
def _create_test_vars():
store = {}
time = zarr.array(data=np.arange(10, dtype=int), path="time", store=store)
time.attrs["axis"] = "T"
time.attrs["units"] = "seconds since 2000-01-01"
def __repr__(self):
return self.name
var1 = zarr.zeros(shape=[10, 10], path="var1", store=store, chunks=(2, 5))
var2 = zarr.zeros(shape=[10, 10], path="var2", store=store, chunks=(2, 5))
for v in (var1, var2):
v.attrs["_ARRAY_DIMENSIONS"] = ("time", "cell")
pp(store)
return var1, var2
def test_sequential():
print("Test sequential case")
var1 = DummyDataVar("var1", [10, 10])
var2 = DummyDataVar("var2", [10, 10])
var1, var2 = _create_test_vars()
result = distribute_work({"g1": [var1], "g2": [var2]}, 1)
pp(result)
assert len(result) == 1
......@@ -34,8 +37,7 @@ def test_sequential():
def test_dist_groups():
print("Test distributing groups")
var1 = DummyDataVar("var1", [10, 10])
var2 = DummyDataVar("var2", [10, 10])
var1, var2 = _create_test_vars()
result = distribute_work({"g1": [var1], "g2": [var2]}, 2)
pp(result)
assert len(result) == 2
......@@ -45,12 +47,12 @@ def test_dist_groups():
def test_dist_chunks():
print("Test distributing chunks")
var = DummyDataVar("var", [10, 10], 5)
result = distribute_work({"g": [var]}, 2)
var1, var2 = _create_test_vars()
result = distribute_work({"g": [var1, var2]}, 2)
pp(result)
assert len(result) == 2
assert result[0] == [("g", [var], slice(0, 5))]
assert result[1] == [("g", [var], slice(5, 10))]
assert result[0] == [("g", [var1, var2], slice(0, 5))]
assert result[1] == [("g", [var1, var2], slice(5, 10))]
if __name__ == "__main__":
......
......@@ -21,7 +21,7 @@ from coyote import (
from ._data_handler import DataHandler
from ._distribute_work import distribute_work
from ._grids import def_grid, grid_id
from ._zarr_utils import get_var_group
from ._zarr_utils import get_time_axis, get_var_group
from .loco import LocoServer
......@@ -125,11 +125,7 @@ def main():
for v in data_vars:
# compute timestep
var_group = get_var_group(v)
time_dim_name = v.attrs["_ARRAY_DIMENSIONS"][0]
time_coordinate = var_group[time_dim_name]
assert (
"seconds since " in time_coordinate.attrs["units"]
), "Currently the time must be given in seconds"
time_coordinate = get_time_axis(v)
dt = time_coordinate[1] - time_coordinate[0]
# compute time start index
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment