diff --git a/pyicon/pyicon_calc_aw.py b/pyicon/pyicon_calc_aw.py index 1946f84df0b77c82706f2f0adac3e8f8c69a6a0b..6247398623af84cfc53411a88f997d3de4e0f3c1 100644 --- a/pyicon/pyicon_calc_aw.py +++ b/pyicon/pyicon_calc_aw.py @@ -2,6 +2,8 @@ import warnings import numpy as np import xarray as xr from itertools import product +from scipy.spatial import cKDTree +from scipy.spatial import Delaunay import os @@ -40,7 +42,92 @@ class daskicon: return + ## Interpolator Functions: + # These Interpolators take arbitrary data and put it onto the present ICON grid of the daskicon object + def make_nn_interpolator(self, da): + + points_icon_source = np.vstack((da.lon, da.lat)).T + points_icon_target = np.vstack((self.grid2d.clon*180./np.pi, self.grid2d.clat*180./np.pi)).T + + tree = cKDTree(points_icon_source) + _, nn_indices_ao = tree.query(points_icon_target) + + return nn_indices_ao + + def interp_nn(self, da_source, interpolator): + + indices = interpolator + da_on_target = da_source.isel(ncells=indices) + return da_on_target + + def make_linear_interpolator(self, da): + + points_icon_source = np.vstack((da.lon, da.lat)).T + points_icon_target = np.vstack((self.grid2d.clon*180./np.pi, self.grid2d.clat*180./np.pi)).T + + tri = Delaunay(points_icon_source) + + # Find simplices containing target points and get valid indices + simplices = tri.find_simplex(points_icon_target) + indices = np.where(simplices >= 0)[0] + valid_simplices = simplices[indices] + + # Get the vertices and compute barycentric coordinates + vertices = tri.simplices[valid_simplices] + delta = points_icon_source[vertices] - points_icon_target[indices, None, :] + + # Compute barycentric coordinates + weights = np.einsum('ijk,ik->ij', + delta[:, :-1] - delta[:, -1:], + -delta[:, -1]) / np.einsum('ijk,ijk->ij', + delta[:, :-1] - delta[:, -1:], + delta[:, :-1] - delta[:, -1:]) + weights = np.c_[weights, 1 - weights.sum(axis=1)] + + return indices, vertices, weights + + def interp_linear(self, da_source, interpolator): + + indices, vertices, weights = interpolator + n_points = max(indices.max() + 1, len(da_source.ncells)) + + def interp_timestep(values): + + n_times = values.shape[0] + result = np.full((n_times, n_points), np.nan) + + # Get values at vertices: shape becomes (n_times, n_target_points, n_vertices) + vertex_values = values[:, vertices] + + # Multiply by weights and sum: (n_times, n_target_points) + result[:, indices] = np.sum(vertex_values * weights, axis=2) + + return result + + + result = xr.apply_ufunc( + interp_timestep, + da_source.transpose('time','ncells'), + input_core_dims=[['ncells']], + output_core_dims=[['ncells_target']], + vectorize=False, + dask='parallelized', + output_dtypes=[np.float32], + dask_gufunc_kwargs={'output_sizes': {'ncells_target': n_points}} + ) + + result = result.rename({'ncells_target':'ncells'}) + result = result.assign_coords({ + 'lat': ('ncells', self.grid2d.clat.data*180./np.pi), + 'lon': ('ncells', self.grid2d.clon.data*180./np.pi) + }) + + return result + + + + ## High-level Vector Operations def compute_curl_cells(self, ds_uv):