Commit 533163f4 authored by Dion Häfner's avatar Dion Häfner

[diagnostics] fix overturning diagnostic in distributed runs

parent 2bc137e4
......@@ -41,6 +41,11 @@ ISONEUTRAL_VARIABLES = OrderedDict([
])
@veros_method(inline=True)
def zonal_sum(vs, arr):
return global_sum(vs, np.sum(arr, axis=0), axis=0)
class Overturning(VerosDiagnostic):
"""Isopycnal overturning diagnostic. Computes and writes vertical streamfunctions
(zonally averaged).
......@@ -81,10 +86,10 @@ class Overturning(VerosDiagnostic):
self.sigma[...] = self.sigs + self.dsig * np.arange(self.nlevel)
# precalculate area below z levels
self.zarea[2:-2, :] = np.cumsum(global_sum(vs, np.sum(
self.zarea[2:-2, :] = np.cumsum(zonal_sum(vs,
vs.dxt[2:-2, np.newaxis, np.newaxis]
* vs.cosu[np.newaxis, 2:-2, np.newaxis]
* vs.maskV[2:-2, 2:-2, :], axis=0)) * vs.dzt[np.newaxis, :], axis=1)
* vs.maskV[2:-2, 2:-2, :]) * vs.dzt[np.newaxis, :], axis=1)
self.initialize_output(vs, self.variables,
var_data={'sigma': self.sigma},
......@@ -115,21 +120,18 @@ class Overturning(VerosDiagnostic):
trans = allocate(vs, ('yu', self.nlevel))
z_sig = allocate(vs, ('yu', self.nlevel))
fac = (vs.dxt[2:-2, np.newaxis, np.newaxis]
* vs.cosu[np.newaxis, 2:-2, np.newaxis]
* vs.dzt[np.newaxis, np.newaxis, :]
* vs.maskV[2:-2, 2:-2, :])
for m in range(self.nlevel):
# NOTE: vectorized version would be O(N^4) in memory
# consider cythonizing if performance-critical
mask = sig_loc_face > self.sigma[m]
trans[2:-2, m] = global_sum(vs, np.sum(
vs.v[2:-2, 2:-2, :, vs.tau]
* vs.dxt[2:-2, np.newaxis, np.newaxis]
* vs.cosu[np.newaxis, 2:-2, np.newaxis]
* vs.dzt[np.newaxis, np.newaxis, :]
* vs.maskV[2:-2, 2:-2, :] * mask, axis=(0, 2)))
z_sig[2:-2, m] = global_sum(vs, np.sum(
vs.dzt[np.newaxis, np.newaxis, :]
* vs.dxt[2:-2, np.newaxis, np.newaxis]
* vs.cosu[np.newaxis, 2:-2, np.newaxis]
* vs.maskV[2:-2, 2:-2, :] * mask, axis=(0, 2)))
trans[2:-2, m] = zonal_sum(vs, np.sum(vs.v[2:-2, 2:-2, :, vs.tau] * fac * mask, axis=2))
z_sig[2:-2, m] = zonal_sum(vs, np.sum(fac * mask, axis=2))
self.trans += trans
if vs.enable_neutral_diffusion and vs.enable_skew_diffusion:
......@@ -138,38 +140,37 @@ class Overturning(VerosDiagnostic):
for m in range(self.nlevel):
# NOTE: see above
mask = sig_loc_face > self.sigma[m]
bolus_trans[2:-2, m] = global_sum(vs,
bolus_trans[2:-2, m] = zonal_sum(vs,
np.sum(
(vs.B1_gm[2:-2, 2:-2, 1:] - vs.B1_gm[2:-2, 2:-2, :-1])
* vs.dxt[2:-2, np.newaxis, np.newaxis]
* vs.cosu[np.newaxis, 2:-2, np.newaxis]
* vs.maskV[2:-2, 2:-2, 1:]
* mask[:, :, 1:],
axis=(0, 2)
axis=2
)
+ np.sum(
+
vs.B1_gm[2:-2, 2:-2, 0]
* vs.dxt[2:-2, np.newaxis]
* vs.cosu[np.newaxis, 2:-2]
* vs.maskV[2:-2, 2:-2, 0]
* mask[:, :, 0],
axis=0
)
* mask[:, :, 0]
)
# streamfunction on geopotentials
self.vsf_depth[2:-2, :] += np.cumsum(global_sum(vs, np.sum(
self.vsf_depth[2:-2, :] += np.cumsum(zonal_sum(vs,
vs.dxt[2:-2, np.newaxis, np.newaxis]
* vs.cosu[np.newaxis, 2:-2, np.newaxis]
* vs.v[2:-2, 2:-2, :, vs.tau]
* vs.maskV[2:-2, 2:-2, :], axis=0)) * vs.dzt[np.newaxis, :], axis=1)
* vs.maskV[2:-2, 2:-2, :]) * vs.dzt[np.newaxis, :], axis=1)
if vs.enable_neutral_diffusion and vs.enable_skew_diffusion:
# streamfunction for eddy driven velocity on geopotentials
self.bolus_depth[2:-2, :] += global_sum(vs, np.sum(
self.bolus_depth[2:-2, :] += zonal_sum(vs,
vs.dxt[2:-2, np.newaxis, np.newaxis]
* vs.cosu[np.newaxis, 2:-2, np.newaxis]
* vs.B1_gm[2:-2, 2:-2, :], axis=0))
* vs.B1_gm[2:-2, 2:-2, :])
# interpolate from isopycnals to depth
self.vsf_iso[2:-2, :] += self._interpolate_along_axis(vs,
z_sig[2:-2, :], trans[2:-2, :],
......@@ -181,7 +182,6 @@ class Overturning(VerosDiagnostic):
self.nitts += 1
@veros_method
def _interpolate_along_axis(self, vs, coords, arr, interp_coords, axis=0):
# TODO: clean up this mess
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment