Skip to content
Snippets Groups Projects
Commit 22a0e5d4 authored by Aaron Wienkers's avatar Aaron Wienkers
Browse files

adds new nn and linear interpolator functions from arbitrary grid onto gridICON object grid

parent 0251869c
Branches master
No related tags found
No related merge requests found
Pipeline #85727 failed
......@@ -2,6 +2,8 @@ import warnings
import numpy as np
import xarray as xr
from itertools import product
from scipy.spatial import cKDTree
from scipy.spatial import Delaunay
import os
......@@ -40,7 +42,92 @@ class daskicon:
return
## Interpolator Functions:
# These Interpolators take arbitrary data and put it onto the present ICON grid of the daskicon object
def make_nn_interpolator(self, da):
points_icon_source = np.vstack((da.lon, da.lat)).T
points_icon_target = np.vstack((self.grid2d.clon*180./np.pi, self.grid2d.clat*180./np.pi)).T
tree = cKDTree(points_icon_source)
_, nn_indices_ao = tree.query(points_icon_target)
return nn_indices_ao
def interp_nn(self, da_source, interpolator):
indices = interpolator
da_on_target = da_source.isel(ncells=indices)
return da_on_target
def make_linear_interpolator(self, da):
points_icon_source = np.vstack((da.lon, da.lat)).T
points_icon_target = np.vstack((self.grid2d.clon*180./np.pi, self.grid2d.clat*180./np.pi)).T
tri = Delaunay(points_icon_source)
# Find simplices containing target points and get valid indices
simplices = tri.find_simplex(points_icon_target)
indices = np.where(simplices >= 0)[0]
valid_simplices = simplices[indices]
# Get the vertices and compute barycentric coordinates
vertices = tri.simplices[valid_simplices]
delta = points_icon_source[vertices] - points_icon_target[indices, None, :]
# Compute barycentric coordinates
weights = np.einsum('ijk,ik->ij',
delta[:, :-1] - delta[:, -1:],
-delta[:, -1]) / np.einsum('ijk,ijk->ij',
delta[:, :-1] - delta[:, -1:],
delta[:, :-1] - delta[:, -1:])
weights = np.c_[weights, 1 - weights.sum(axis=1)]
return indices, vertices, weights
def interp_linear(self, da_source, interpolator):
indices, vertices, weights = interpolator
n_points = max(indices.max() + 1, len(da_source.ncells))
def interp_timestep(values):
n_times = values.shape[0]
result = np.full((n_times, n_points), np.nan)
# Get values at vertices: shape becomes (n_times, n_target_points, n_vertices)
vertex_values = values[:, vertices]
# Multiply by weights and sum: (n_times, n_target_points)
result[:, indices] = np.sum(vertex_values * weights, axis=2)
return result
result = xr.apply_ufunc(
interp_timestep,
da_source.transpose('time','ncells'),
input_core_dims=[['ncells']],
output_core_dims=[['ncells_target']],
vectorize=False,
dask='parallelized',
output_dtypes=[np.float32],
dask_gufunc_kwargs={'output_sizes': {'ncells_target': n_points}}
)
result = result.rename({'ncells_target':'ncells'})
result = result.assign_coords({
'lat': ('ncells', self.grid2d.clat.data*180./np.pi),
'lon': ('ncells', self.grid2d.clon.data*180./np.pi)
})
return result
## High-level Vector Operations
def compute_curl_cells(self, ds_uv):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment