Forked from
Nils Brüggemann / pyicon
192 commits behind, 91 commits ahead of the upstream repository.
pyic_view.py 19.54 KiB
#!/usr/bin/env python
import tkinter as tk
from tkinter import ttk
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import pyicon as pyic
import cartopy.crs as ccrs
import glob
from pyicon import params
from ipdb import set_trace as mybreak
import cmocean
from pyproj import Proj, CRS, Transformer
def generate_axes(asp, generate_figure=True):
    """Create a map axes and a colorbar axes for a given map aspect ratio.

    The map axes is fitted into a fixed reference frame: 0.75 of the figure
    height, with the width of a frame whose physical aspect ratio is 0.5.
    Maps flatter than that (``asp < 0.5``) use the full frame width and are
    centred vertically. Taller maps use the full frame height and are
    right-aligned against the colorbar, which sits to the right of the frame.

    Parameters
    ----------
    asp : float
        Desired height / width ratio of the map axes in physical (inch) units.
    generate_figure : bool, optional
        If True, create a new 6 x 3 inch figure; otherwise add the axes to the
        current figure (``plt.gcf()``).

    Returns
    -------
    fig, ax, cax
        The figure, the map axes and the colorbar axes.

    Raises
    ------
    ValueError
        If ``asp`` is not a finite positive number.
    """
    # Validate before touching any figure so a bad value leaves no side effects.
    if not np.isfinite(asp) or asp <= 0:
        raise ValueError(
            f"Aspect ratio must be a finite positive number, got {asp!r}")

    figsize = (6, 3)
    fig = plt.figure(figsize=figsize) if generate_figure else plt.gcf()
    figh = fig.get_figheight()
    figw = fig.get_figwidth()

    # Reference frame for the map axes (figure-relative units).
    x0, y0 = 0.1, 0.16
    axh0 = 0.75
    asp0 = 0.5
    axw0 = axh0 * figh / figw / asp0
    ax = fig.add_subplot(position=(x0, y0, axw0, axh0))

    # Colorbar axes to the right of the reference frame.
    daxcax = 0.02
    caxw = 0.04
    cax = fig.add_subplot(position=(x0 + axw0 + daxcax, y0, caxw, axh0))
    cax.set_xticks([])
    cax.yaxis.tick_right()
    cax.yaxis.set_label_position("right")

    if asp < asp0:
        # Flatter than the frame: keep the full width, centre vertically.
        axw = axw0
        axh = asp / figh * (axw * figw)
        x00 = x0
        y00 = y0 + axh0 / 2. - axh / 2.
    else:
        # Taller than the frame: keep the full height, align right.
        axh = axh0
        axw = axh * figh / figw / asp
        x00 = x0 + axw0 - axw
        y00 = y0
    ax.set_position([x00, y00, axw, axh])
    return fig, ax, cax
def str_to_array(string):
    """Parse a comma-separated list of numbers (e.g. "-180, 180") to floats.

    Space characters are ignored. Raises ValueError if an item is not a
    number (including the empty string).
    """
    without_spaces = "".join(string.split(" "))
    return np.asarray(without_spaces.split(","), dtype=float)
def get_fpath_ckdtree(data, res, path_grid, gname='auto', fpath_tgrid='auto'):
    """Build the path of the ckdtree file for interpolation to a regular grid.

    Parameters
    ----------
    data : xarray object
        ICON data; only used to identify the grid when ``gname == 'auto'``.
    res : float
        Target grid resolution, written with two decimals into the file name.
    path_grid : str
        Grid directory, or ``'auto'`` for ``pyicon.params['path_grid']``.
    gname : str, optional
        Grid name. ``'auto'`` identifies it with ``pyic.identify_grid``
        (first directly, then via ``data.attrs['uuidOfHGrid']``) and falls
        back to ``'none'`` if both attempts fail.
    fpath_tgrid : str, optional
        Currently unused; kept for backward compatibility.

    Returns
    -------
    str
        Path of the global 180W-180E / 90S-90N ckdtree file.
    """
    if path_grid == 'auto':
        path_grid = params['path_grid']

    # Only identify the grid when the name is actually needed.
    if gname == 'auto':
        try:
            Dgrid = pyic.identify_grid(data, path_grid)
        except Exception:
            # This doesn't always work; retry with the grid UUID attribute.
            try:
                Dgrid = pyic.identify_grid(
                    data, path_grid, uuidOfHGrid=data.attrs['uuidOfHGrid']
                )
            except Exception:
                Dgrid = dict()
        try:
            gname = Dgrid["name"]
        except KeyError:
            gname = "none"

    fpath_ckdtree = f'{path_grid}/{gname}/ckdtree/rectgrids/{gname}_res{res:3.2f}_180W-180E_90S-90N.nc'
    print(fpath_ckdtree)
    return fpath_ckdtree
def get_data(ds, var_name, it, iz, fpath_ckdtree, lon_reg, lat_reg):
    """Select one time step / level of a variable and interpolate it.

    Parameters
    ----------
    ds : xarray.Dataset
        Data set containing ``var_name``.
    var_name : str
        Variable to extract.
    it, iz : int
        Index of the time step and of the depth level. ``it`` is ignored for
        variables without a ``time`` dimension. ``iz`` is only used when one
        dimension is left besides the (presumably horizontal) last one.
    fpath_ckdtree : str
        ckdtree file passed to ``pyic.interp_to_rectgrid_xr``.
    lon_reg, lat_reg : sequence of two floats
        Region passed to ``pyic.interp_to_rectgrid_xr``.

    Returns
    -------
    xarray.DataArray
        Interpolated field with exact zeros replaced by NaN. The selected
        depth dimension name (or ``'none'``) is stored in
        ``attrs['depth_name']``.
    """
    da = ds[var_name]
    isel_dict = dict()
    # Time-invariant variables have no time dimension to select from.
    if 'time' in da.dims:
        isel_dict['time'] = it
    # Two dimensions left after the time selection: treat the extra one as
    # depth (assumes layout (depth, ncells) -- TODO confirm for all outputs).
    if da.ndim - len(isel_dict) == 2:
        depth_name = pyic.identify_depth_name(da)
        isel_dict[depth_name] = iz
    else:
        depth_name = 'none'
    dai = pyic.interp_to_rectgrid_xr(
        da.isel(**isel_dict), fpath_ckdtree,
        lon_reg=lon_reg, lat_reg=lat_reg,
    )
    # Exact zeros are masked (presumably land / fill values -- verify).
    dai = dai.where(dai != 0.)
    # Set after `where` so the attribute does not depend on attr propagation.
    dai.attrs["depth_name"] = depth_name
    return dai
class view(object):
    """Interactive Tkinter viewer for horizontal fields of ICON data.

    The files in ``flist`` are opened as one data set. The selected variable
    is interpolated to a regular lon/lat grid via a ckdtree file (see
    ``get_fpath_ckdtree``) and shown in a matplotlib figure embedded in a Tk
    window. Constructing the object runs ``root.mainloop()``, so the
    constructor returns only after the Tk main loop exits.

    Parameters
    ----------
    flist : list of str
        Files passed to ``xarray.open_mfdataset``.
    path_grid : str
        Grid directory, or ``'auto'`` to use ``pyicon.params['path_grid']``.
    fig_size_fac : float, optional
        Factor applied to the base font size of 6.
    """

    # Fixed map aspect ratios (height / width) for projections whose extent
    # does not follow the lon/lat region; other choices use the region.
    _proj_aspect = {
        "+proj=stere +lat_0=90 +lon_0=0": 1.0,
        "+proj=stere +lat_0=-90 +lon_0=0": 1.0,
        "+proj=eqearth": 0.4867169753874043,
        "+proj=moll": 0.5,
    }
    # Fixed (xlim, ylim) in projected coordinates for the polar views; all
    # other projections use the extent of the finite plotted coordinates.
    _proj_xylim = {
        "+proj=stere +lat_0=90 +lon_0=0": (
            [-4660515.349812048, 4660515.349812048],
            [-4658959.2511977535, 4658959.2511977535],
        ),
        "+proj=stere +lat_0=-90 +lon_0=0": (
            [-5965970.154575175, 5965970.154575175],
            [-5963978.177895851, 5963978.177895851],
        ),
    }

    def __init__(self, flist, path_grid, fig_size_fac=1.0):
        # Initialize Tkinter
        print('setup TKinter')
        root = tk.Tk()
        root.title("pyicon view")

        self.flist = flist
        self.path_grid = path_grid
        self.fig_size_fac = fig_size_fac

        self.colormaps = [
            "inferno", "viridis", "plasma",
            "RdYlBu_r", "RdBu_r", "Blues_r",
            "cmo.thermal", "cmo.haline", "cmo.curl",
        ]
        self.res_all = [1., 0.3, 0.1, 0.02]
        self.proj_all = [
            "None",
            "+proj=latlong",
            "+proj=stere +lat_0=90 +lon_0=0",
            "+proj=stere +lat_0=-90 +lon_0=0",
            "+proj=eqearth",
            "+proj=moll",
        ]
        self.font_size = 6*self.fig_size_fac
        self.res = 0.3
        self.it = 0
        self.iz = 0
        self.proj = self.proj_all[0]
        # Only used for real projections; "None" plots plain lon/lat.
        self.transformer = None
        if self.proj != "None":
            self.transformer = Transformer.from_pipeline(self.proj)

        # Open data set
        self.load_data()

        # Default selections
        self.selected_var = tk.StringVar(value=self.var_names[0])
        self.selected_cmap = tk.StringVar(value=self.colormaps[0])
        self.color_limits = tk.StringVar(value="auto")  # Default color limits
        self.lon_lat_reg_tk = tk.StringVar(value="-180,180,-90,90")
        self.selected_res = tk.StringVar(value="0.3")
        self.selected_proj = tk.StringVar(value="None")

        # Create figure and axes
        self.fig, self.ax, self.cax = generate_axes(asp=0.5)

        # Zoom state (see activate_zoom)
        self.press_event = None
        self.rect = None
        self.cid_press = None
        self.cid_release = None
        self.cid_motion = None

        # TK canvas; spans all four widget columns used below
        self.canvas = FigureCanvasTkAgg(self.fig, master=root)
        self.canvas.get_tk_widget().grid(row=0, column=0, columnspan=4)
        # Print coordinates and value of clicked points
        self.canvas.mpl_connect("button_press_event", self.on_click)

        # Create sliders
        print('Setup sliders')
        self.slider_t = tk.Scale(
            root, from_=0, to=len(self.ds.time)-1,
            orient="horizontal", label="time", command=self.update_plot,
        )
        self.slider_t.grid(row=1, column=0, columnspan=2, sticky="ew")
        # Data sets without a 'depth' dimension get a single level.
        nz = self.ds.sizes.get('depth', 1)
        self.slider_d = tk.Scale(
            root, from_=0, to=nz-1,
            orient="horizontal", label="depth", command=self.update_plot,
        )
        self.slider_d.grid(row=1, column=2, columnspan=2, sticky="ew")

        # Create dropdown menus
        print('Setup var dropdown')
        var_menu = ttk.Combobox(
            root, textvariable=self.selected_var,
            values=self.var_names, state="readonly",
        )
        var_menu.grid(row=2, column=0)
        var_menu.bind("<<ComboboxSelected>>", self.update_plot)

        print('Setup cmap dropdown')
        cmap_menu = ttk.Combobox(
            root, textvariable=self.selected_cmap,
            values=self.colormaps, state="readonly",
        )
        cmap_menu.grid(row=2, column=1)
        cmap_menu.bind("<<ComboboxSelected>>", self.update_cmap)

        # Color limit entry (applied when pressing Enter)
        print('Setup color limits')
        clim_entry = tk.Entry(root, textvariable=self.color_limits)
        clim_entry.grid(row=2, column=2)
        clim_entry.bind("<Return>", self.update_clim)

        # lon_lat_reg entry (new axes when pressing Enter)
        print('Setup lon_reg')
        reg_entry = tk.Entry(root, textvariable=self.lon_lat_reg_tk)
        reg_entry.grid(row=3, column=1)
        reg_entry.bind("<Return>", self.make_new_axis)

        # res dropdown
        print('Setup res dropdown')
        res_menu = ttk.Combobox(
            root, textvariable=self.selected_res,
            values=self.res_all, state="readonly",
        )
        res_menu.grid(row=3, column=0)
        res_menu.bind("<<ComboboxSelected>>", self.make_new_axis)

        # Button to activate zoom mode
        self.zoom_button = ttk.Button(root, text="Enable Zoom", command=self.activate_zoom)
        self.zoom_button.grid(row=3, column=2)

        # proj dropdown
        print('Setup proj dropdown')
        proj_menu = ttk.Combobox(
            root, textvariable=self.selected_proj,
            values=self.proj_all, state="readonly",
        )
        proj_menu.grid(row=3, column=3)
        proj_menu.bind("<<ComboboxSelected>>", self.make_new_axis)

        # initial plot
        self.plot_data()
        self.canvas.draw()

        # Start Tkinter loop
        print('Go into mainloop')
        root.mainloop()

    # --- zoom
    def activate_zoom(self):
        """Activate a one-shot zoom: drag a rectangle in the map to zoom in."""
        if self.cid_press is not None:
            # Zoom mode is already active; do not register the handlers twice.
            return
        self.cid_press = self.canvas.mpl_connect("button_press_event", self.on_press)
        self.cid_release = self.canvas.mpl_connect("button_release_event", self.on_release)
        self.cid_motion = self.canvas.mpl_connect("motion_notify_event", self.on_motion)

    def _deactivate_zoom(self):
        """Disconnect the zoom handlers and drop any unfinished rectangle."""
        for cid in (self.cid_press, self.cid_release, self.cid_motion):
            if cid is not None:
                self.canvas.mpl_disconnect(cid)
        self.cid_press = self.cid_release = self.cid_motion = None
        if self.rect is not None:
            self.rect.remove()
            self.rect = None
        self.press_event = None

    def on_press(self, event):
        """Store the start corner of the zoom rectangle (map axes only)."""
        # event.xdata refers to whichever axes was hit, e.g. the colorbar.
        if event.inaxes is not self.ax or event.xdata is None or event.ydata is None:
            return
        self.press_event = (event.xdata, event.ydata)
        self.rect = self.ax.add_patch(plt.Rectangle(
            self.press_event, 0, 0, fill=False, color="red", linestyle="dashed"))
        self.canvas.draw()

    def on_motion(self, event):
        """Resize the zoom rectangle while dragging inside the map axes."""
        if self.press_event is None or event.inaxes is not self.ax:
            return
        x0, y0 = self.press_event
        self.rect.set_width(event.xdata - x0)
        self.rect.set_height(event.ydata - y0)
        # draw_idle coalesces the many redraw requests issued while dragging
        self.canvas.draw_idle()

    def on_release(self, event):
        """Zoom into the dragged rectangle, then leave zoom mode.

        A release outside the map axes, or a drag without extent, ends zoom
        mode without changing the limits.
        """
        if self.press_event is None:
            return
        if event.inaxes is self.ax and event.xdata is not None and event.ydata is not None:
            x0, y0 = self.press_event
            x1, y1 = event.xdata, event.ydata
            if x0 != x1 and y0 != y1:
                # Ensure correct ordering of coordinates
                self.ax.set_xlim(min(x0, x1), max(x0, x1))
                self.ax.set_ylim(min(y0, y1), max(y0, y1))
        # Remove the rectangle and the handlers, then redraw
        self._deactivate_zoom()
        self.canvas.draw()

    # --- data and plotting
    def load_data(self):
        """Open ``self.flist`` as one data set and drop the listed grid variables.

        Sets ``self.ds``, ``self.var_names`` (the remaining data variables)
        and ``self.var_name`` (the first of them).
        """
        print('opening dataset')
        mfdset_kwargs = dict(
            combine='nested', concat_dim='time',
            data_vars='minimal', coords='minimal',
            compat='override', join='override',
            parallel=True,
        )
        self.ds = xr.open_mfdataset(
            self.flist, **mfdset_kwargs,
            chunks=dict(time=1, depth=1, depth_2=1)
        )
        delvars = [
            "clon_bnds", "clat_bnds", "elon_bnds", "elat_bnds",
            "vlon_bnds", "vlat_bnds",
            "clon", "clat", "elon", "elat",
            "lev"
        ]
        # errors="ignore" skips names that are not in the data set
        self.ds = self.ds.drop_vars(delvars, errors="ignore")
        self.var_names = list(self.ds)
        print(f"variables in data set: {self.var_names}")
        self.var_name = self.var_names[0]

    def plot_data(self):
        """Interpolate the current selection and draw it into self.ax / self.cax.

        Uses the region from the lon/lat entry and the current resolution and
        projection; stores the mesh (``Lon``, ``Lat``, ``X``, ``Y``), the data
        (``dai``) and the plot handle (``hm``) for later updates.
        """
        # get updated limits
        self.update_lon_lat_reg()
        # get updated data
        self.fpath_ckdtree = get_fpath_ckdtree(self.ds, self.res, self.path_grid)
        self.dai = get_data(
            self.ds, self.var_name, self.it, self.iz,
            self.fpath_ckdtree, self.lon_reg, self.lat_reg
        )
        self.Lon, self.Lat = np.meshgrid(self.dai.lon.data, self.dai.lat.data)
        if self.proj == "None":
            self.X, self.Y = self.Lon, self.Lat
        else:
            self.X, self.Y = self.transformer.transform(self.Lon, self.Lat, direction='FORWARD')

        # make plot
        self.hm = pyic.shade(
            self.X, self.Y, self.dai.data,
            ax=self.ax, cax=self.cax)
        if self.proj in self._proj_xylim:
            xlim, ylim = self._proj_xylim[self.proj]
        else:
            # Only finite projected coordinates can serve as axis limits.
            valid = np.isfinite(self.X) & np.isfinite(self.Y)
            xlim = [self.X[valid].min(), self.X[valid].max()]
            ylim = [self.Y[valid].min(), self.Y[valid].max()]
        self.ax.set_xlim(xlim)
        self.ax.set_ylim(ylim)
        self.ax.set_facecolor('0.7')
        self.update_cmap()

        # set titles
        self.ht_var = self.cax.set_ylabel('', fontsize=self.font_size)
        self.ht_depth = self.ax.set_title('', loc='left', fontsize=self.font_size)
        self.ht_time = self.ax.set_title('', loc='right', fontsize=self.font_size)
        self.ht_point = self.ax.text(0., -0.15, '',
                                     transform=self.ax.transAxes, fontsize=self.font_size)
        for text in self.fig.findobj(plt.Text):
            text.set_fontsize(self.font_size)
        self.update_title()

    def make_new_axis(self, *args):
        """Rebuild the axes for the selected resolution, projection and region.

        An invalid region leaves the current plot and state untouched.
        """
        res = float(self.selected_res.get())
        proj = self.selected_proj.get()
        old_reg = (getattr(self, 'lon_reg', None), getattr(self, 'lat_reg', None))
        try:
            self.update_lon_lat_reg()
        except ValueError as err:
            print(f'Invalid lon/lat region: {err}')
            return

        asp = self._proj_aspect.get(proj)
        if asp is None:
            dlon = self.lon_reg[1] - self.lon_reg[0]
            dlat = self.lat_reg[1] - self.lat_reg[0]
            asp = dlat / dlon if dlon > 0 else np.nan
        if not np.isfinite(asp) or asp <= 0:
            # Restore the region the current plot was made with.
            self.lon_reg, self.lat_reg = old_reg
            print(f'Invalid lon/lat region: {self.lon_lat_reg_tk.get()}')
            return

        self.res = res
        self.proj = proj
        if self.proj != "None":
            self.transformer = Transformer.from_pipeline(self.proj)

        # An unfinished zoom refers to the old axes; drop it.
        self._deactivate_zoom()
        for old_ax in (self.ax, self.cax):
            try:
                old_ax.remove()
            except Exception as err:  # best effort: axes may already be detached
                print(f'Could not remove old axes: {err}')
        # generate_axes adds fresh map and colorbar axes to the current figure
        self.fig, self.ax, self.cax = generate_axes(asp, generate_figure=False)
        self.plot_data()
        self.canvas.draw()

    def update_plot(self, *args):
        """Show the selected time step, depth level and variable.

        Region, resolution and projection stay those of the last
        ``plot_data`` call, so the new field matches the existing mesh.
        """
        # Get current slider values
        self.it = int(self.slider_t.get())
        self.iz = int(self.slider_d.get())
        # Get selected variable
        var_name = self.selected_var.get()
        var_changed = var_name != self.var_name
        self.var_name = var_name
        print(f'{self.var_name}: it = {self.it}; iz = {self.iz}')
        # Get data and plot
        self.dai = get_data(
            self.ds, self.var_name, self.it, self.iz,
            self.fpath_ckdtree, self.lon_reg, self.lat_reg
        )
        self.Lon, self.Lat = np.meshgrid(self.dai.lon.data, self.dai.lat.data)
        self.hm[0].set_array(self.dai.data.flatten())
        if var_changed:
            # Re-apply the clim entry to the new variable.
            self.update_clim()
        self.update_title()
        self.canvas.draw()

    def update_title(self):
        """Update the depth, time and variable/unit labels."""
        depth_name = self.dai.attrs.get("depth_name", "none")
        if depth_name != 'none':
            self.ht_depth.set_text(
                f"{depth_name} = {self.ds[depth_name][self.iz].data}")
        else:
            # 2D variable: clear a label left over from a 3D variable
            self.ht_depth.set_text('')
        self.ht_time.set_text(
            f"time = {str(self.ds.time[self.it].data)[:16]}")
        attrs = self.ds[self.var_name].attrs
        var_longname = attrs.get('long_name', self.var_name)
        unit = f" / {attrs['units']}" if 'units' in attrs else ""
        self.ht_var.set_text(f"{var_longname}{unit}")

    def update_clim(self, *args):
        """Apply the color limits typed into the clim entry (see get_clim)."""
        clim_str = self.color_limits.get()
        try:
            clim = self.get_clim(clim_str, self.dai)
            self.hm[0].set_clim(clim[0], clim[1])
            print(f'Updated clim to {clim}')
            self.canvas.draw()
        except ValueError:
            print(f'Invalid value for clim: {clim_str}')

    def update_cmap(self, *args):
        """Apply the selected colormap ('cmo.*' names are taken from cmocean)."""
        cmap_name = self.selected_cmap.get()
        print(f"Updating cmap to {cmap_name}")
        if cmap_name.startswith('cmo'):
            cmap = getattr(cmocean.cm, cmap_name.split('.')[-1])
        else:
            cmap = getattr(plt.cm, cmap_name)
        self.hm[0].set_cmap(cmap)
        self.canvas.draw()

    def update_lon_lat_reg(self, *args):
        """Parse the 'lon0,lon1,lat0,lat1' entry into self.lon_reg / self.lat_reg.

        Raises ValueError if the entry does not hold exactly four numbers.
        """
        lon_lat_reg_str = self.lon_lat_reg_tk.get()
        lon_lat_reg = str_to_array(lon_lat_reg_str)
        if lon_lat_reg.size != 4:
            raise ValueError(
                f'Expected "lon0,lon1,lat0,lat1" but got {lon_lat_reg_str!r}')
        self.lon_reg = [lon_lat_reg[0], lon_lat_reg[1]]
        self.lat_reg = [lon_lat_reg[2], lon_lat_reg[3]]
        print(f'lon_reg = {self.lon_reg}, lat_reg = {self.lat_reg}')

    def get_clim(self, clim, data):
        """Translate a clim string into ``np.array([vmin, vmax])``.

        ``'auto'`` -> data min / max, ``'sym'`` -> +-max(|data|),
        ``'a'`` -> [-a, a], ``'a,b'`` -> [a, b].

        Raises ValueError for anything else or when vmin > vmax.
        """
        if isinstance(clim, str) and clim.strip() == 'auto':
            return np.array([float(data.min()), float(data.max())])
        if isinstance(clim, str) and clim.strip() == 'sym':
            vmax = float(np.abs(data).max())
            return np.array([-vmax, vmax])
        values = np.array(clim.split(','), dtype=float)
        if values.size == 1:
            values = np.array([-1., 1.]) * values[0]
        if values.size != 2 or values[0] > values[1]:
            raise ValueError(f'Invalid clim: {clim!r}')
        return values

    def on_click(self, event):
        """Show lon, lat and value of the grid point nearest to a map click."""
        # Ignore clicks outside the map axes, including the colorbar, whose
        # data coordinates would otherwise be taken as map coordinates.
        if event.inaxes is not self.ax or event.xdata is None or event.ydata is None:
            return
        if self.proj != "None":
            lon_click, lat_click = self.transformer.transform(
                event.xdata, event.ydata,
                direction='INVERSE',
            )
        else:
            lon_click, lat_click = event.xdata, event.ydata
        # Nearest point by squared distance in lon/lat degrees
        ind = np.argmin(
            (self.Lon.flatten()-lon_click)**2 + (self.Lat.flatten()-lat_click)**2
        )
        data_click = self.dai.data.flatten()[ind]
        txt = f"lon:{lon_click:.2f}, lat: {lat_click:.2f}, data: {data_click:.4f}"
        self.ht_point.set_text(txt)
        print(txt)
        self.canvas.draw()
def main():
    """Parse the command line and launch the interactive viewer."""
    import argparse

    help_text = "\n".join([
        "",
        "Opens an interactive GUI to visualize horizontal ICON data.",
        "Usage notes:",
        "------------",
        "Basic usage:",
        "pyic_view.py netcdf_file_or_list.nc [options]",
        "Argument list:",
        "--------------",
        "",
    ])

    # --- read input arguments
    parser = argparse.ArgumentParser(
        description=help_text,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    # --- necessary arguments
    parser.add_argument(
        'fpath_data', nargs='+', metavar='fpath_data', type=str,
        help='Path to ICON data file.',
    )
    parser.add_argument(
        '--size', type=float, default=1.0,
        help='Factor that determines the figure size',
    )
    args = parser.parse_args()

    files = sorted(args.fpath_data)
    print(files)

    # Initial plot (blocks inside the Tk main loop)
    print('Initialize plot')
    view(files, path_grid='auto', fig_size_fac=args.size)
if __name__ == "__main__":
    # Launch the GUI only when run as a script, not when imported.
    main()