Skip to content
Snippets Groups Projects
Commit 1c56757c authored by Nils-Arne Dreier
Browse files

fix: avoid reopening the zarr store again and again

parent 5bff601f
No related branches found
No related tags found
1 merge request: !52 Migrate to Zarr 3
......@@ -2,12 +2,13 @@ import zarr
def get_var_group(v):
store = zarr.open(v.store)
parent_group_path = "/".join(v.path.split("/")[:-1])
if parent_group_path == "":
return store
if not hasattr(v, "root"):
v.root = zarr.open(v.store)
group_path = "/".join(v.path.split("/")[:-1])
if group_path == "":
return v.root
else:
return store[parent_group_path]
return v.root[group_path]
def get_time_axis(v):
......@@ -25,5 +26,5 @@ def get_time_axis(v):
def get_var_parent_group(v):
var_group = get_var_group(v)
parent_var_path = var_group.attrs["hiopy::parent"]
parent_group = zarr.open(v.store)[parent_var_path]
parent_group = v.root[parent_var_path]
return parent_group
......@@ -15,10 +15,7 @@ def add_time(dataset, startdate, enddate, dt, name="time"):
time_data = (np.arange(startdate + dt, enddate + dt, dt) - startdate) // np.timedelta64(1, "s")
for g in _collect_groups(dataset):
time = g.create_array(
name="time", fill_value=None, shape=time_data.shape, dtype=np.longlong
)
time.append(data=time_data)
time = g.create_dataset(name, data=time_data, fill_value=None, shape=time_data.shape)
time.attrs["_ARRAY_DIMENSIONS"] = (name,)
time.attrs["axis"] = "T"
time.attrs["calendar"] = "proleptic_gregorian"
......
......@@ -37,16 +37,18 @@ def add_coordinates(
lat_list, lon_list = zip(*coordinates)
lon = dataset.create_array(
name=coord_names[0], data=np.array(lon_list), shape=(len(coordinates),)
name=coord_names[0], dtype=np.float32, shape=(len(coordinates),)
)
lon[:] = np.array(lon_list)
lon.attrs["_ARRAY_DIMENSIONS"] = [coord_names[0]]
lon.attrs["long_name"] = "longitude"
lon.attrs["units"] = "degree"
lon.attrs["standard_name"] = "grid_longitude"
lat = dataset.create_array(
name=coord_names[1], data=np.array(lat_list), shape=(len(coordinates),)
name=coord_names[1], dtype=np.float32, shape=(len(coordinates),)
)
lat[:] = np.array(lat_list)
lat.attrs["_ARRAY_DIMENSIONS"] = [coord_names[1]]
lat.attrs["long_name"] = "latitude"
lat.attrs["units"] = "degree"
......
......@@ -13,7 +13,7 @@ from coyote import (
from ._data_handler import DataHandler
from ._distribute_work import distribute_work
from ._grids import def_grid, grid_id
from ._zarr_utils import get_var_group, get_var_parent_group
from ._zarr_utils import get_var_group, get_var_parent_group, get_time_axis
from .loco import LocoServer
from argparse import ArgumentParser
......@@ -79,15 +79,15 @@ def main():
)
# find all variables considered to be written in the input datasets:
def collect_data_vars(group):
def collect_data_vars(group, root):
for _name, item in group.arrays():
if "hiopy::enable" in item.attrs and item.attrs["hiopy::enable"]:
item.root = root
yield item
for _name, item in group.groups():
item.parent = group
yield from collect_data_vars(item)
yield from collect_data_vars(item, root)
all_data_vars = list(chain(*[collect_data_vars(z) for z in args.datasets]))
all_data_vars = list(chain(*[collect_data_vars(z, z) for z in args.datasets]))
logging.info(f"Found {len(all_data_vars)} variables")
if len(all_data_vars) == 0:
raise RuntimeError("No variables found by the hiopy worker.")
......@@ -134,7 +134,7 @@ def main():
- np.datetime64(var_group["time"].attrs["units"][len("seconds since ") :])
) / np.timedelta64(1, "s")
t0_idx = np.searchsorted(var_group["time"], t0)
assert var_group["time"][t0_idx] == t0, "start_datetime not found in time axis"
assert var_group["time"][t0_idx] == t0, f"start_datetime {t0} not found in time axis at index {t0_idx} which has value {var_group['time'][t0_idx]}"
# see YAC_REDUCTION_TIME_NONE etc. (TODO: pass constants through coyote)
time_methods2yac = {"point": 0, "sum": 1, "mean": 2, "min": 3, "max": 4}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment