Skip to content
Snippets Groups Projects
Commit 21ccef72 authored by Nils-Arne Dreier's avatar Nils-Arne Dreier
Browse files

fix: copy_metadata

parent 88a884e4
No related branches found
No related tags found
No related merge requests found
Pipeline #101321 waiting for manual action
......@@ -8,3 +8,10 @@ def get_var_group(v):
return z
else:
return z[parent_group_path]
def get_var_parent_group(v):
var_group = get_var_group(v)
parent_var_path = var_group.attrs["hiopy::parent"]
parent_group = zarr.open(v.store)[parent_var_path]
return parent_group
......@@ -13,7 +13,7 @@ from coyote import (
from ._data_handler import DataHandler
from ._distribute_work import distribute_work
from ._grids import def_grid, grid_id
from ._zarr_utils import get_var_group
from ._zarr_utils import get_var_group, get_var_parent_group
from .loco import LocoServer
from argparse import ArgumentParser
......@@ -86,16 +86,16 @@ def main():
for _name, item in group.groups():
yield from collect_data_vars(item)
data_vars = list(chain(*[collect_data_vars(z) for z in args.datasets]))
logging.info(f"Found {len(data_vars)} variables")
if len(data_vars) == 0:
all_data_vars = list(chain(*[collect_data_vars(z) for z in args.datasets]))
logging.info(f"Found {len(all_data_vars)} variables")
if len(all_data_vars) == 0:
raise RuntimeError("No variables found by the hiopy worker.")
# group the variables by the crs grid_mapping.
# This is used to distribute them through the processes and create the coyote instances
grouped_data_vars = {
gid: list(variables)
for gid, variables in groupby(sorted(data_vars, key=grid_id), key=grid_id)
for gid, variables in groupby(sorted(all_data_vars, key=grid_id), key=grid_id)
}
distributed_data_vars = distribute_work(grouped_data_vars, group_comm_size())
......@@ -148,8 +148,8 @@ def main():
src_comp, src_grid = v.attrs["hiopy::yac_source"]
else:
assert "hiopy::parent" in var_group.attrs, f"No source for field {v.name} specified"
parent_var_path = var_group.attrs["hiopy::parent"] + "/" + v.basename
source_var = zarr.open(store=v.store)[parent_var_path]
parent_group = get_var_parent_group(v)
source_var = parent_group[v.basename]
src_name = source_var.name
source_var_gid = grid_id(source_var)
src_comp = src_grid = f"{args.process_group}_{source_var_gid}"
......@@ -189,7 +189,11 @@ def main():
)
def get_source_triple(v):
if "hiopy::yac_source" in v.attrs:
var_group = get_var_group(v)
if "hiopy::parent" in var_group.attrs:
pgroup = get_var_parent_group(v)
return get_source_triple(pgroup[v.basename])
elif "hiopy::yac_source" in v.attrs:
src_comp, src_grid = v.attrs["hiopy::yac_source"]
src_field = v.attrs.get("hiopy::src_name", v.basename)
return src_comp, src_grid, src_field
......@@ -198,7 +202,7 @@ def main():
ensure_enddef()
if group_comm_rank() == 0:
for v in data_vars:
for v in all_data_vars:
if "hiopy::copy_metadata" in v.attrs:
comp, grid, field = get_source_triple(v)
md_str = get_field_metadata(comp, grid, field)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment