From 7fc05075628116f972ebb9784f7cbd5a67fc5017 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nils=20Br=C3=BCggemann?= <nils.brueggemann@mpimet.mpg.de>
Date: Tue, 11 Mar 2025 17:14:14 +0100
Subject: [PATCH] pyic_view: Slightly improved chunking; chunking is switched
 off when zarr files are opened since it causes a large performance loss there.

---
 scripts/pyic_view.py | 26 +++++++++++++++++---------
 1 file changed, 17 insertions(+), 9 deletions(-)

diff --git a/scripts/pyic_view.py b/scripts/pyic_view.py
index 6595e9c..6d5ace4 100755
--- a/scripts/pyic_view.py
+++ b/scripts/pyic_view.py
@@ -106,7 +106,7 @@ def str_to_array(string):
 #  return fpath_ckdtree
 
 
-def get_data(ds, var_name, it, iz, res, lon_reg, lat_reg):
+def get_data(ds, var_name, it, iz, res, lon_reg, lat_reg, do_chunking=True):
     isel_dict = dict(time=it)
     if ds[var_name].ndim==3:
       depth_name = pyic.identify_depth_name(ds[var_name])
@@ -121,7 +121,8 @@ def get_data(ds, var_name, it, iz, res, lon_reg, lat_reg):
         chunks = dict(time=1)
     if depth_name!='none':
         chunks[depth_name] = 1
-    da = da.chunk(**chunks)
+    if do_chunking:
+      da = da.chunk(**chunks)
     dai = pyic.interp_to_rectgrid_xr(
         da.isel(**isel_dict), res=res,
         lon_reg=lon_reg, lat_reg=lat_reg,
@@ -142,6 +143,8 @@ class view(object):
         self.flist = flist
         self.path_grid = path_grid
         self.fig_size_fac = fig_size_fac
+
+        self.do_chunking = True
         self.colormaps = [
             "inferno", "viridis", "plasma", 
             "RdYlBu_r", "RdBu_r", "Blues_r", 
@@ -182,6 +185,7 @@ class view(object):
             self.transformer = "None"
 
         # Opean data set
+        self.message("Opening data set")
         self.load_data()
 
         # Default selections
@@ -198,6 +202,7 @@ class view(object):
         self.rect = None
 
         # Create figure and axis
+        self.message("Generating axes")
         self.fig, self.ax, self.cax = generate_axes(asp=0.5)
 
         #print('------')
@@ -205,6 +210,7 @@ class view(object):
         #print(self.ax.get_position())
 
         # TK canvas
+        self.message("Setting up TK")
         frame_plot = tk.Frame(root)
         frame_plot.pack(fill="both", expand=True)
         self.canvas = FigureCanvasTkAgg(self.fig, master=frame_plot)
@@ -234,7 +240,7 @@ class view(object):
         spacer = tk.Frame(frame1, width=30)  # 10 pixels wide
         spacer.pack(side="left")
         
-        # depth slide
+        # depth slider
         btn_dec_d = tk.Button(frame1, text="-", command=lambda: self.decrease_slider(self.slider_d))
         btn_dec_d.pack(side="left")
 
@@ -402,6 +408,9 @@ class view(object):
             self.canvas.mpl_disconnect(self.cid_motion)
 
     def load_data(self):
+        if self.flist[0].endswith('zarr'):
+            self.message('Detected zarr file and switching off chunking.')
+            self.do_chunking = False
         self.message('opening dataset')
         mfdset_kwargs = dict(
             combine='nested', concat_dim='time',
@@ -416,6 +425,7 @@ class view(object):
         self.ds = xr.open_mfdataset(
             self.flist, **mfdset_kwargs, 
         )
+        self.message('Done opening data files')
         delvars = [
             "clon_bnds", "clat_bnds", "elon_bnds", "elat_bnds",
             "vlon_bnds", "vlat_bnds",
@@ -464,7 +474,8 @@ class view(object):
         #self.fpath_ckdtree = get_fpath_ckdtree(self.ds, self.res, self.path_grid)
         self.dai = get_data(
             self.ds, self.var_name, self.it, self.iz, 
-            self.res, self.lon_reg, self.lat_reg
+            self.res, self.lon_reg, self.lat_reg,
+            do_chunking=self.do_chunking,
         )
         self.Lon, self.Lat = np.meshgrid(self.dai.lon.data, self.dai.lat.data)
         if self.proj=="None":
@@ -562,8 +573,6 @@ class view(object):
         self.lon_lat_reg_tk.set(f"{self.lon_reg[0]:.3g},{self.lon_reg[1]:.3g},{self.lat_reg[0]:.3g},{self.lat_reg[1]:.3g}")
         self.xlim, self.ylim = get_xlim_ylim(
             self.lon_reg, lat_reg_axlim, self.proj, self.transformer)
-        print(f'set_default_lon_lat_reg:')
-        print(self.lon_reg, self.lat_reg, self.xlim, self.ylim)
         
     def update_projection(self, *args):
         self.proj = self.proj_dict[self.selected_proj.get()]
@@ -601,11 +610,9 @@ class view(object):
 
     def increase_slider(self, slider):
         slider.set(slider.get() + 1)
-        self.update_data()
     
     def decrease_slider(self, slider):
         slider.set(slider.get() - 1)
-        self.update_data()
     
     # Function to update plot
     def update_data(self, *args):
@@ -628,7 +635,8 @@ class view(object):
         self.update_lon_lat_reg()
         self.dai = get_data(
             self.ds, self.var_name, self.it, self.iz, 
-            self.res, self.lon_reg, self.lat_reg
+            self.res, self.lon_reg, self.lat_reg,
+            do_chunking=self.do_chunking,
         )
         self.Lon, self.Lat = np.meshgrid(self.dai.lon, self.dai.lat)
         self.hm[0].set_array(self.dai.data.flatten())
-- 
GitLab