Skip to content
Snippets Groups Projects
Commit 285fc919 authored by Nils-Arne Dreier's avatar Nils-Arne Dreier
Browse files

refactor: get zarr group from store instead of monkey-patching the Array

This also improves how the parent variable is found within the same store
parent ac4dac03
No related branches found
No related tags found
No related merge requests found
import numpy as np
from ._zarr_utils import get_var_group
# computes an id that identifies the grid of a variable
def grid_id(var, zgroup):
def grid_id(var):
"""
var: dtype zarr.array
zgroup: dtype zarr.group
"""
assert "grid_mapping" in var.attrs
zgroup = get_var_group(var)
crs = zgroup[var.attrs["grid_mapping"]]
var_group = var.attrs.get("hiopy::var_group", "")
spatial_chunk_shape = var.chunk_shape[-1]
spatial_chunk_shape = var.chunks[-1]
# healpix
if "grid_mapping_name" in crs.attrs and crs.attrs["grid_mapping_name"] == "healpix":
......@@ -41,7 +45,8 @@ def grid_id(var, zgroup):
raise RuntimeError("Unknown grid")
def def_grid(coyote, var, chunk_slice, zgroup):
def def_grid(coyote, var, chunk_slice):
zgroup = get_var_group(var)
crs = zgroup[var.attrs["grid_mapping"]]
# healpix
if "grid_mapping_name" in crs.attrs and crs.attrs["grid_mapping_name"] == "healpix":
......
......@@ -11,6 +11,7 @@ from coyote import Coyote, group_comm_rank, group_comm_size, init, run, start_da
from ._data_handler import DataHandler
from ._distribute_work import distribute_work
from ._grids import def_grid, grid_id
from ._zarr_utils import get_var_group
from .loco import LocoServer
......@@ -71,7 +72,6 @@ def main():
def collect_data_vars(group):
for _name, item in group.arrays():
if "hiopy::enable" in item.attrs and item.attrs["hiopy::enable"]:
item.group = group # amend the variable object with the group
yield item
for _name, item in group.groups():
yield from collect_data_vars(item)
......@@ -85,9 +85,7 @@ def main():
# This is used to distribute them through the processes and create the coyote instances
grouped_data_vars = {
gid: list(variables)
for gid, variables in groupby(
sorted(data_vars, key=lambda v: grid_id(v, v.group)), key=lambda v: grid_id(v, v.group)
)
for gid, variables in groupby(sorted(data_vars, key=grid_id), key=grid_id)
}
distributed_data_vars = distribute_work(grouped_data_vars, group_comm_size())
......@@ -109,14 +107,15 @@ def main():
for gid, data_vars, chunk_slice in my_data_vars:
coyote = coyote_instances[gid]
# all vars in data_vars define the same grid
def_grid(coyote, data_vars[0], chunk_slice, data_vars[0].group)
def_grid(coyote, data_vars[0], chunk_slice)
data_handlers = []
for v in data_vars:
# compute timestep
var_group = get_var_group(v)
time_dim_name = v.attrs["_ARRAY_DIMENSIONS"][0]
time_coordinate = v.group[time_dim_name]
time_coordinate = var_group[time_dim_name]
assert (
"seconds since " in time_coordinate.attrs["units"]
), "Currently the time must be given in seconds"
......@@ -124,10 +123,10 @@ def main():
# compute time start index
t0 = (
np.datetime64(start_datetime())
- np.datetime64(v.group.time.attrs["units"][len("seconds since ") :])
- np.datetime64(var_group.time.attrs["units"][len("seconds since ") :])
) / np.timedelta64(1, "s")
t0_idx = np.searchsorted(v.group.time, t0)
assert v.group.time[t0_idx] == t0, "start_datetime not found in time axis"
t0_idx = np.searchsorted(var_group.time, t0)
assert var_group.time[t0_idx] == t0, "start_datetime not found in time axis"
dt = time_coordinate[t0_idx + 1] - time_coordinate[t0_idx]
......@@ -139,22 +138,18 @@ def main():
if "hiopy::yac_source" in v.attrs:
src_comp, src_grid = v.attrs["hiopy::yac_source"]
else:
parent_group_name = v.group.attrs["hiopy::parent"]
parent_var_name = parent_group_name + "/" + src_name
parent_var, parent_group = [
(ds[parent_var_name], ds[parent_group_name])
for ds in args.datasets
if parent_var_name in ds
][0]
parent_var_gid = grid_id(parent_var, parent_group)
src_comp = src_grid = f"{args.process_group}_{parent_var_gid}"
assert "hiopy::parent" in var_group.attrs, f"No source for field {v.name} specified"
parent_var_name = var_group.attrs["hiopy::parent"] + "/" + v.name.split("/")[-1]
source_var = zarr.Group(store=v.store)[parent_var_name]
source_var_gid = grid_id(source_var)
src_comp = src_grid = f"{args.process_group}_{source_var_gid}"
time_method = v.attrs.get("hiopy::time_method", "point")
nnn = v.attrs.get("hiopy::nnn", 1)
frac_mask_name = v.attrs.get("hiopy::frac_mask", None)
frac_mask = []
if frac_mask_name is not None:
if frac_mask_name not in frac_masks:
frac_masks[frac_mask_name] = np.array(v.group[frac_mask_name][chunk_slice])
frac_masks[frac_mask_name] = np.array(var_group[frac_mask_name][chunk_slice])
frac_mask = frac_masks[frac_mask_name]
logging.info(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment