From 08280b575efe5a0a6b21b1e0001b4f7524d78acf Mon Sep 17 00:00:00 2001
From: Davide Ori <dori@uni-koeln.de>
Date: Mon, 31 Mar 2025 09:07:02 +0200
Subject: [PATCH] modified send trajectory for profiles, trying to get
 arguments, and open trajectory file

---
 comin/mpi_send_trajectory.py | 79 ++++++++++++++++++++++++++++++------
 1 file changed, 66 insertions(+), 13 deletions(-)

diff --git a/comin/mpi_send_trajectory.py b/comin/mpi_send_trajectory.py
index b9fc4ba..91b62b3 100644
--- a/comin/mpi_send_trajectory.py
+++ b/comin/mpi_send_trajectory.py
@@ -1,10 +1,41 @@
 import comin
 import numpy as np
+import xarray as xr
 from mpi4py import MPI
 from scipy.spatial import KDTree
+import yaml
+import sys
 
 jg = 1
-variables = ["temp"]
+
+#from argparse import ArgumentParser
+#parser = ArgumentParser()
+#parser.add_argument("--num", type=str, help="this is an integer")
+#parser.add_argument("--conf", type=str, help="this is the required yaml configuration file")
+#args=parser.parse_args()
+ymlfile="/home/b/b381492/warmworld/icon-model/build/run/pamtra-insitu/comin/pamtra.yml"
+
+with open(ymlfile, 'r') as conf:
+    pamcfg = yaml.safe_load(conf)
+
+Number_of_platforms = len(pamcfg['platforms'])
+platform = pamcfg['platforms'][0] # TODO need to make it flexible
+platform_name = platform['name']
+trajfile = platform['trajectoryFile']
+print(f"{np.__version__} {np.__path__}", file=sys.stderr)
+#traj = xr.open_dataset(trajfile)
+#print(traj, file=sys.stderr)
+
+mic = pamcfg['microphysics']
+if mic == 'icon_2':
+    print('icon single moment microphysics with graupel', file=sys.stderr)
+    vars3D = ["temp"]#, "pres", "qv", "qi", "qc", "qs", "qr", "qg", "z_ifc", "u", "v", "w", "clc"]
+elif mic == 'icon_4':
+    print('icon double moment microphysics', file=sys.stderr)
+    vars3D = ["temp"]#, "pres", "qv", "qi", "qc", "qs", "qr", "qg", "qh", "qni", "qnc", "qns", "qnr", "qng", "qnh", "z_ifc", "u", "v", "w"]
+else:
+    raise ValueError("Unrecognized microphysical scheme {}".format(mic))
+vars2D = ["fr_land"]
 
 
 def trajectory(t):
@@ -22,7 +53,7 @@ domain = comin.descrdata_get_domain(jg)
 clon = np.asarray(domain.cells.clon)
 clat = np.asarray(domain.cells.clat)
 decomp_domain = np.asarray(domain.cells.decomp_domain)
-owner = decomp_domain.ravel() == 0
+owner = decomp_domain.ravel() == 0 # this is a mask nproma*numcells????
 xyz = np.c_[lonlat2xyz(clon.ravel()[owner], clat.ravel()[owner])]
 decomp_domain = np.asarray(domain.cells.decomp_domain)
 tree = KDTree(xyz)
@@ -30,17 +61,28 @@ tree = KDTree(xyz)
 plugin_comm = MPI.Comm.f2py(comin.parallel_get_plugin_mpi_comm())
 
 
-@comin.EP_SECONDARY_CONSTRUCTOR
+@comin.register_callback(comin.EP_SECONDARY_CONSTRUCTOR)
 def secondary_constructor():
-    global comin_vars
-    comin_vars = {v: comin.var_get([comin.EP_ATM_WRITE_OUTPUT_BEFORE],
+    global comin_vars3D
+    global comin_vars2D
+    comin_vars3D = {v: comin.var_get([comin.EP_ATM_WRITE_OUTPUT_BEFORE],
+                                   (v, jg), comin.COMIN_FLAG_READ)
+                    for v in vars3D}
+    comin_vars2D = {v: comin.var_get([comin.EP_ATM_WRITE_OUTPUT_BEFORE],
                                    (v, jg), comin.COMIN_FLAG_READ)
-                  for v in variables}
+                    for v in vars2D}
+
 
+def ravel_not_height(arr3D):
+    arr_sort = arr3D.transpose((1, 0, 2, 3, 4))
+    return arr_sort.reshape(arr_sort.shape[0], -1)
 
-@comin.EP_ATM_WRITE_OUTPUT_BEFORE
+
+@comin.register_callback(comin.EP_ATM_WRITE_OUTPUT_BEFORE)
 def write_output_before():
-    global comin_vars
+    global comin_vars2D
+    global comin_vars3D
+
     time = comin.current_get_datetime()
     lon, lat = trajectory(time)
 
@@ -48,12 +90,23 @@ def write_output_before():
 
     ## check whether is inside domain
 
-    data = {v: np.asarray(cv)[:, 0, ...].ravel()[owner][ii]
-            for v, cv in comin_vars.items()}
-    print(f"sending {data}")
-    plugin_comm.send(data, dest=0)  # assume that the destination is always 0
+    data = np.asarray(comin_vars2D['fr_land']).ravel()[owner]
+    #print(f"fr_land {data.shape}  {ii} ", file=sys.stderr)
+    data = np.asarray(comin_vars3D['temp'])#[:, 0, ...].ravel()[owner]
+    dato = ravel_not_height(data)[:, owner][:, ii]
+    #print(f"temp {data.shape}     {dato.shape}  {owner.shape} ", file=sys.stderr)
+
+    #data3D = {v: np.asarray(cv)[:, 0, ...].ravel()[owner][ii]
+    #          for v, cv in comin_vars3D.items()}
+    data3D = {v: ravel_not_height(np.asarray(cv))[:, owner][:, ii]
+              for v, cv in comin_vars3D.items()}
+    data2D = {v: np.asarray(cv).ravel()[owner][ii]
+              for v, cv in comin_vars2D.items()}
+    data3D.update(data2D)
+    #print(f"sending {data3D}", file=sys.stderr)
+    plugin_comm.send(data3D, dest=0)  # assume that the destination is always 0
 
 
-@comin.EP_DESTRUCTOR
+@comin.register_callback(comin.EP_DESTRUCTOR)
 def finalize():
     plugin_comm.send(None, dest=0)
-- 
GitLab