From 0251869c73c383fc5bbb096ea1d69514e00f3684 Mon Sep 17 00:00:00 2001
From: Aaron Wienkers <aaron.wienkers@gmail.com>
Date: Thu, 31 Oct 2024 16:31:53 +0100
Subject: [PATCH] another optimisation to curl operator to reduce number of
 tasks in compute_curl and prevent memory overload

---
 pyicon/pyicon_calc_aw.py | 39 +++++++++++++++------------------------
 1 file changed, 15 insertions(+), 24 deletions(-)

diff --git a/pyicon/pyicon_calc_aw.py b/pyicon/pyicon_calc_aw.py
index e07999a..1946f84 100644
--- a/pyicon/pyicon_calc_aw.py
+++ b/pyicon/pyicon_calc_aw.py
@@ -203,26 +203,16 @@ class daskicon:
             vector p_vn_c remapped to edges
         """
         
-        edge2cell_coeff_cc_t = (self.ds_aux.edge2cell_coeff_cc_t
-                                        .drop_vars({'clat','clon'})                                         # N.B.: These additional coordinates confuse dask !
-                                        .transpose('cart','edge','nc_e').chunk({'nc_e':-1, 'cart':-1}))     # Reorder for better performance
-                         
-        ic = self.ds_IcD.adjacent_cell_of_edge 
-        
-        p_vn_c = p_vn_c.chunk({'cart':-1}).transpose('time','cart','cell')     # Reorder for better performance
-        
-        ptp_vn = xr.apply_ufunc(
-            lambda v, ind, coeff: (v[:,ind] * coeff).sum(axis=(0,2)),
-            p_vn_c,
-            ic, 
-            edge2cell_coeff_cc_t,
-            input_core_dims=[['cart','cell'], ['edge', 'nc_e'], ['cart', 'edge', 'nc_e']],
-            output_core_dims=[['edge']],
-            vectorize=True,
-            dask='parallelized',
-            output_dtypes=[p_vn_c.dtype],
-            dask_gufunc_kwargs={'output_sizes':{'edge': ic.edge.size}}
-        )
+        edge2cell_coeff_cc_t = self.ds_aux.edge2cell_coeff_cc_t.drop_vars({'clat','clon'})  # N.B.: These additional coordinates confuse dask !
+        
+        ic0 = self.ds_IcD.adjacent_cell_of_edge.isel(nc_e=0).compute()
+        ic1 = self.ds_IcD.adjacent_cell_of_edge.isel(nc_e=1).compute()
+        
+        p_vn_c = p_vn_c.drop_vars(p_vn_c.coords)   # The additional coordinates create a lot of overhead....
+        
+        ptp_vn = (p_vn_c.isel(cell=ic0) * edge2cell_coeff_cc_t.isel(nc_e=0) + 
+                  p_vn_c.isel(cell=ic1) * edge2cell_coeff_cc_t.isel(nc_e=1))  .sum(dim='cart')
+        
         
         return ptp_vn.chunk({'edge':-1}) # Ensure still in a single chunk...
 
@@ -373,15 +363,16 @@ class daskicon:
         """
         assert "edge" in vector.dims
 
-        rot_coeff = self.ds_aux.rot_coeff
+        rot_coeff = self.ds_aux.rot_coeff.chunk({'ne_v':-1})
         
-        eov = self.ds_IcD["edges_of_vertex"]
+        eov = self.ds_IcD.edges_of_vertex.compute()
         
         curl_vec = xr.apply_ufunc(
-            lambda v, ind: (v[ind] * rot_coeff).sum(dim='ne_v'),  # NOTE: This gives very large/wrong values when the vertex has _any_ neighbouring undefined cell....
+            lambda v, ind, coeff: (v[ind] * coeff).sum(axis=1),  # NOTE: This gives very large/wrong values when the vertex has _any_ neighbouring undefined cell....
             vector,
             eov,
-            input_core_dims=[['edge'], ['vertex','ne_v']],
+            rot_coeff,
+            input_core_dims=[['edge'], ['vertex','ne_v'], ['vertex','ne_v']],
             output_core_dims=[['vertex']],
             vectorize=True,
             dask='parallelized',
-- 
GitLab