Skip to content
Snippets Groups Projects
Commit fdbb4215 authored by Nils-Arne Dreier's avatar Nils-Arne Dreier
Browse files

refactor: put data handler in separate file

parent a7b17b97
No related branches found
No related tags found
No related merge requests found
import logging
import numpy as np
class DataHandler:
    """Buffer incoming per-timestep data and write it to a zarr-like
    variable one whole time-chunk at a time.

    One handler owns one variable and one contiguous slice of the "cell"
    dimension (time is assumed to be the first dimension, cell the last —
    see the buffer layout below).  Incoming events are staged in an
    in-memory ring buffer and flushed whenever a time-chunk boundary is
    reached; a trailing partial chunk is flushed by :meth:`flush`, which
    is also invoked from ``__del__`` as a best-effort fallback.
    """

    def __init__(self, var, t0, cell_slice):
        """
        Args:
            var: zarr-like array providing ``attrs["_ARRAY_DIMENSIONS"]``,
                ``shape``, ``chunks``, ``dtype`` and item assignment.
            t0: index of the first timestep this handler will receive.
            cell_slice: contiguous (step 1) slice into the "cell" dimension.
        """
        self.var = var
        self.cell_slice = cell_slice
        shape = dict(zip(var.attrs["_ARRAY_DIMENSIONS"], var.shape))
        mi, ma, st = cell_slice.indices(shape["cell"])
        assert st == 1, "non-contiguous cell slices are not supported"
        self.t = t0                 # next timestep index to be received
        self.t_flushed = self.t     # first timestep not yet written to var
        chunks = dict(zip(var.attrs["_ARRAY_DIMENSIONS"], var.chunks))
        self.time_chunk_size = chunks["time"]
        # Ring buffer over one time-chunk: row (t % time_chunk_size)
        # holds the data of timestep t.
        self.buf = np.empty([self.time_chunk_size, *var.shape[1:-1], ma - mi],
                            dtype=var.dtype)

    def __call__(self, event):
        """Store one timestep of data; write the chunk out when complete."""
        self.buf[self.t % self.time_chunk_size, ...] = event.data
        logging.info(f"received {self.var.name} at {self.t % self.time_chunk_size+1}/{self.time_chunk_size}")
        self.t += 1
        if self.t % self.time_chunk_size == 0:  # chunk complete
            self.flush()

    def flush(self):
        """Write all buffered, not-yet-flushed timesteps to the variable.

        Idempotent: calling it with nothing pending is a no-op.
        """
        if self.t == self.t_flushed:
            return  # nothing pending
        logging.info(f"writing {self.var.name} for timesteps {self.t_flushed}:{self.t}")
        # Rows for timesteps t_flushed .. t-1 start at t_flushed %
        # time_chunk_size (NOT at 0 — t0 need not be chunk-aligned).  No
        # wrap-around is possible because at most one chunk is buffered.
        start = self.t_flushed % self.time_chunk_size
        n = self.t - self.t_flushed
        self.var[self.t_flushed:self.t, ..., self.cell_slice] = \
            self.buf[start:start + n, ...]
        self.t_flushed = self.t

    def __del__(self):
        # Best-effort flush of a trailing partial chunk.  __del__ is not
        # guaranteed to run at interpreter shutdown, so callers should
        # prefer calling flush() explicitly.
        self.flush()
......@@ -7,6 +7,7 @@ from argparse import ArgumentParser
import logging
from itertools import chain, groupby
from _distribute_work import distribute_work
from _data_handler import DataHandler
parser = ArgumentParser()
parser.add_argument("datasets", type=zarr.open_consolidated, nargs="+",
......@@ -52,32 +53,6 @@ for (nside, chunk_slice), data_vars in my_data_vars:
logging.debug(f"{nside=}, {chunk_slice=}: {[dv.name for dv in data_vars]}")
class handler:
    """Stage per-timestep data for one variable / cell range in memory and
    write it to the backing zarr array one time-chunk at a time."""

    def __init__(self, var, t0, cell_slice):
        # map dimension names onto sizes and chunk sizes
        dims = var.attrs["_ARRAY_DIMENSIONS"]
        shape = dict(zip(dims, var.shape))
        chunks = dict(zip(dims, var.chunks))
        lo, hi, step = cell_slice.indices(shape["cell"])
        assert step == 1, "non-contiguous cell slices are not supported"
        self.var = var
        self.cell_slice = cell_slice
        self.t = t0
        self.t_flushed = t0
        self.time_chunk_size = chunks["time"]
        # staging space for exactly one time-chunk
        buf_shape = [self.time_chunk_size, *var.shape[1:-1], hi - lo]
        self.buf = np.empty(buf_shape, dtype=var.dtype)

    def __call__(self, event):
        slot = self.t % self.time_chunk_size
        self.buf[slot, ...] = event.data
        logging.info(f"received {self.var.name} at {self.t % self.time_chunk_size+1}/{self.time_chunk_size}")
        self.t += 1
        if self.t % self.time_chunk_size == 0:  # chunk boundary reached
            logging.info(f"writing {self.var.name} for timesteps {self.t_flushed}:{self.t}")
            pending = self.t - self.t_flushed
            self.var[self.t_flushed:self.t, ..., self.cell_slice] = self.buf[-pending:, ...]
            self.t_flushed = self.t

    def __del__(self):
        # write out whatever has not been flushed yet
        pending = self.t - self.t_flushed
        self.var[self.t_flushed:self.t, ..., self.cell_slice] = self.buf[:pending, ...]
frac_masks = {} # the numpy array must keep alive until the end of the program
coyote_instances = {nside: Coyote(f"hiopy_healpix_{nside.bit_length() - 1}")
......@@ -128,7 +103,7 @@ for (nside, chunk_slice), data_vars in my_data_vars:
coyote.add_field(src_comp,
src_grid,
src_name,
handler(v, t0_idx, chunk_slice),
DataHandler(v, t0_idx, chunk_slice),
f"PT{dt}S",
collection_size,
yac_time_reduction=time_methods2yac[time_method],
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.