From 25b6cf9f815827bbb841e6bf915b3efea9918a13 Mon Sep 17 00:00:00 2001
From: Fabian Wachsmann <k204210@l40147.lvt.dkrz.de>
Date: Wed, 12 Feb 2025 11:39:20 +0100
Subject: [PATCH] Added an example script to show how to virtualize regridding
 of tco grids for serving daas

---
 scripts/cdo_access.sh       |  11 +++
 scripts/cloudify_tco2reg.py | 188 ++++++++++++++++++++++++++++++++++++
 2 files changed, 199 insertions(+)
 create mode 100755 scripts/cdo_access.sh
 create mode 100644 scripts/cloudify_tco2reg.py

diff --git a/scripts/cdo_access.sh b/scripts/cdo_access.sh
new file mode 100755
index 0000000..9bd3c55
--- /dev/null
+++ b/scripts/cdo_access.sh
@@ -0,0 +1,11 @@
+cwd=$(pwd)
+port="9010"
+dsname="example"
+export SSL_CERT_FILE=${cwd}/cert.pem
+export HDF5_PLUGIN_PATH=/fastdata/k20200/k202186/public/hdf5/plugins/
+cdo="/work/bm0021/cdo_incl_cmor/cdo-test_cmortest_gcc/bin/cdo"
+zarr_prefix="#mode=zarr,s3"
+infile="https://${HOSTNAME}:${port}/datasets/${dsname}/zarr${zarr_prefix}"
+#
+$cdo sinfo $infile
+echo "$cdo sinfo $infile"
diff --git a/scripts/cloudify_tco2reg.py b/scripts/cloudify_tco2reg.py
new file mode 100644
index 0000000..93ece8e
--- /dev/null
+++ b/scripts/cloudify_tco2reg.py
@@ -0,0 +1,188 @@
+from cloudify.plugins.stacer import *
+from cloudify.plugins.geoanimation import *
+from cloudify.utils.daskhelper import *
+import xarray as xr
+import xpublish as xp
+import asyncio
+import nest_asyncio
+import sys
+import os
+import subprocess
+#import hdf5plugin
+import argparse
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description=(
+          "This xpublish script serves a virtual zarr dataset.\n"
+          "When a chunk is accessed, it processes data from any TCO gaussian reduced grid and regrids it to the matching regular gaussian grid.\n"
+          "It opens files or references and uses dask to regrid chunks by linear interpolation across longitudes."
+          )
+
+    )
+    
+    parser.add_argument("dataset_name", help="The name of the dataset displayed for users.")
+    parser.add_argument("mode", choices=["refs", "norefs"], help="Specify 'refs' or 'norefs'.")
+    parser.add_argument("paths", nargs='+', help="Paths to dataset files (single for 'refs', multiple for 'norefs').")
+    
+    args = parser.parse_args()
+    
+    # Validation: If 'refs' mode, ensure exactly one path is provided
+    if args.mode == "refs" and len(args.paths) != 1:
+        parser.error("Mode 'refs' requires exactly one path argument.")
+    
+    return args
+
+
+os.environ["HDF5_PLUGIN_PATH"]="/work/ik1017/hdf5plugin/plugins/"
+cwd = os.getcwd()
+ssl_keyfile=f"{cwd}/key.pem"
+ssl_certfile=f"{cwd}/cert.pem"
+
+if not os.path.isfile(ssl_keyfile) or not os.path.isfile(ssl_certfile):
+    cn = os.uname().nodename  # Equivalent to `!echo $HOSTNAME`
+
+    openssl_cmd = [
+            "openssl", "req", "-x509", "-newkey", "rsa:4096",
+            "-keyout", "key.pem", "-out", "cert.pem",
+            "-sha256", "-days", "3650", "-nodes",
+            "-subj", f"/C=XX/ST=Hamburg/L=Hamburg/O=Test/OU=Test/CN={cn}"
+            ]
+
+    subprocess.run(openssl_cmd, check=True)
+
+port=9010
+
+nest_asyncio.apply()
+chunks={}
+for coord in ["lon","lat"]:
+    chunk_size=os.environ.get(f"XPUBLISH_{coord.upper()}_CHUNK_SIZE",None)
+    if chunk_size:
+        chunks[coord]=int(chunk_size)
+
+chunks["time"]=1
+l_lossy=os.environ.get("L_LOSSY",False)
+
+def lossy_compress(partds):
+    import numcodecs
+    rounding = numcodecs.BitRound(keepbits=12)
+    return rounding.decode(rounding.encode(partds))
+
+def unstack(ds):
+    onlydimname=[a for a in ds.dims if a not in ["time","level","lev"]]
+    if not "lat" in ds.coords:
+        if not "latitude" in ds.coords:
+            raise ValueError("No latitude given")
+        else:
+            ds=ds.rename(latitude="lat")
+    if not "lon" in ds.coords:
+        if not "longitude" in ds.coords:
+            raise ValueError("No longitude given")
+        else:
+            ds=ds.rename(longitude="lon")            
+    if len(onlydimname)>1:
+        raise ValueError("More than one dim: "+onlydimname)
+    onlydimname=onlydimname[0]
+    return ds.rename({onlydimname:'latlon'}).set_index(latlon=("lat","lon")).unstack("latlon")
+
+def interp(ds):
+    global equator_lons
+    return ds.interpolate_na(dim="lon",method="linear",period=360.0).reindex(lon=equator_lons)
+
+if __name__ == "__main__":  # This avoids infinite subprocess creation
+    import dask
+    zarrcluster = asyncio.get_event_loop().run_until_complete(get_dask_cluster())
+    os.environ["ZARR_ADDRESS"]=zarrcluster.scheduler._address
+
+    args = parse_args()
+    print(f"Dataset Name: {args.dataset_name}")
+    print(f"Mode: {args.mode}")
+    print(f"Paths: {args.paths}")
+
+    dsname=args.dataset_name
+    refs=args.mode
+    glob_inp=args.paths
+
+    dsdict={}
+
+    if refs == "refs":
+        glob_inp=glob_inp[0]
+        source="reference::/"+glob_inp
+        fsmap = fsspec.get_mapper(
+                source,
+                remote_protocol="file",
+                lazy=True,
+                cache_size=0
+                )
+        ds=xr.open_dataset(
+                fsmap,
+                engine="zarr",
+                chunks=chunks,
+                consolidated=False        
+                )
+    else:
+        ds=xr.open_mfdataset(
+            glob_inp,
+            compat="override",
+            coords="minimal",
+            chunks=chunks,
+        )
+    todel=[
+            a
+            for a in ds.coords
+            if a not in ["lat","latitude","lon","longitude","time","lev","level"]
+            ]
+            
+    if todel:
+        for v in todel:
+            del ds[v]
+    if "height" in ds:
+        del ds["height"]
+    for dv in ds.variables:
+        if "time" in dv:
+            ds[dv]=ds[dv].load()
+            ds[dv].encoding["dtype"] = "float64"
+            ds[dv].encoding["compressor"] = None
+    ds=ds.set_coords([a for a in ds.data_vars if "bnds" in a])
+    if l_lossy:
+        ds = xr.apply_ufunc(
+            lossy_compress,
+            ds,
+            dask="parallelized", 
+            keep_attrs="drop_conflicts"
+        )
+    dvs=[]
+    l_el=False
+    for dv in ds.data_vars:
+        print(dv)
+        template_unstack=unstack(ds[dv].isel(time=0).load())
+        if not l_el:
+            equator_lons=template_unstack.sel(lat=0.0,method="nearest").dropna(dim="lon")["lon"]
+            l_el=True
+            latlonchunks={
+                    a:len(template_unstack[a])
+                    for a in template_unstack.dims
+                    }
+        template_unstack=template_unstack.chunk(**latlonchunks)
+        template_interp=interp(template_unstack)
+        template_unstack=template_unstack.expand_dims(**{"time":ds["time"]}).chunk(time=1)
+        template_interp=template_interp.expand_dims(**{"time":ds["time"]}).chunk(time=1)
+        unstacked=ds[dv].map_blocks(unstack,template=template_unstack)
+        interped=unstacked.map_blocks(interp,template=template_interp)
+        dsv = dask.optimize(interped)[0]
+        print("optimized")
+        del template_unstack, template_interp
+        dvs.append(dsv)
+    ds = xr.combine_by_coords(dvs)
+    print("combined")
+    ds = ds.drop_encoding()
+    dsdict[dsname]=ds
+    collection = xp.Rest(dsdict)
+    collection.register_plugin(Stac())
+    collection.register_plugin(PlotPlugin())
+    collection.serve(
+        host="0.0.0.0",
+        port=port,
+        ssl_keyfile=ssl_keyfile,
+        ssl_certfile=ssl_certfile
+    )
-- 
GitLab