Commit 533163f4 authored by Dion Häfner's avatar Dion Häfner

[diagnostics] fix overturning diagnostic in distributed runs

parent 2bc137e4
...@@ -41,6 +41,11 @@ ISONEUTRAL_VARIABLES = OrderedDict([ ...@@ -41,6 +41,11 @@ ISONEUTRAL_VARIABLES = OrderedDict([
]) ])
@veros_method(inline=True)
def zonal_sum(vs, arr):
return global_sum(vs, np.sum(arr, axis=0), axis=0)
class Overturning(VerosDiagnostic): class Overturning(VerosDiagnostic):
"""Isopycnal overturning diagnostic. Computes and writes vertical streamfunctions """Isopycnal overturning diagnostic. Computes and writes vertical streamfunctions
(zonally averaged). (zonally averaged).
...@@ -81,10 +86,10 @@ class Overturning(VerosDiagnostic): ...@@ -81,10 +86,10 @@ class Overturning(VerosDiagnostic):
self.sigma[...] = self.sigs + self.dsig * np.arange(self.nlevel) self.sigma[...] = self.sigs + self.dsig * np.arange(self.nlevel)
# precalculate area below z levels # precalculate area below z levels
self.zarea[2:-2, :] = np.cumsum(global_sum(vs, np.sum( self.zarea[2:-2, :] = np.cumsum(zonal_sum(vs,
vs.dxt[2:-2, np.newaxis, np.newaxis] vs.dxt[2:-2, np.newaxis, np.newaxis]
* vs.cosu[np.newaxis, 2:-2, np.newaxis] * vs.cosu[np.newaxis, 2:-2, np.newaxis]
* vs.maskV[2:-2, 2:-2, :], axis=0)) * vs.dzt[np.newaxis, :], axis=1) * vs.maskV[2:-2, 2:-2, :]) * vs.dzt[np.newaxis, :], axis=1)
self.initialize_output(vs, self.variables, self.initialize_output(vs, self.variables,
var_data={'sigma': self.sigma}, var_data={'sigma': self.sigma},
...@@ -115,21 +120,18 @@ class Overturning(VerosDiagnostic): ...@@ -115,21 +120,18 @@ class Overturning(VerosDiagnostic):
trans = allocate(vs, ('yu', self.nlevel)) trans = allocate(vs, ('yu', self.nlevel))
z_sig = allocate(vs, ('yu', self.nlevel)) z_sig = allocate(vs, ('yu', self.nlevel))
fac = (vs.dxt[2:-2, np.newaxis, np.newaxis]
* vs.cosu[np.newaxis, 2:-2, np.newaxis]
* vs.dzt[np.newaxis, np.newaxis, :]
* vs.maskV[2:-2, 2:-2, :])
for m in range(self.nlevel): for m in range(self.nlevel):
# NOTE: vectorized version would be O(N^4) in memory # NOTE: vectorized version would be O(N^4) in memory
# consider cythonizing if performance-critical # consider cythonizing if performance-critical
mask = sig_loc_face > self.sigma[m] mask = sig_loc_face > self.sigma[m]
trans[2:-2, m] = global_sum(vs, np.sum( trans[2:-2, m] = zonal_sum(vs, np.sum(vs.v[2:-2, 2:-2, :, vs.tau] * fac * mask, axis=2))
vs.v[2:-2, 2:-2, :, vs.tau] z_sig[2:-2, m] = zonal_sum(vs, np.sum(fac * mask, axis=2))
* vs.dxt[2:-2, np.newaxis, np.newaxis]
* vs.cosu[np.newaxis, 2:-2, np.newaxis]
* vs.dzt[np.newaxis, np.newaxis, :]
* vs.maskV[2:-2, 2:-2, :] * mask, axis=(0, 2)))
z_sig[2:-2, m] = global_sum(vs, np.sum(
vs.dzt[np.newaxis, np.newaxis, :]
* vs.dxt[2:-2, np.newaxis, np.newaxis]
* vs.cosu[np.newaxis, 2:-2, np.newaxis]
* vs.maskV[2:-2, 2:-2, :] * mask, axis=(0, 2)))
self.trans += trans self.trans += trans
if vs.enable_neutral_diffusion and vs.enable_skew_diffusion: if vs.enable_neutral_diffusion and vs.enable_skew_diffusion:
...@@ -138,38 +140,37 @@ class Overturning(VerosDiagnostic): ...@@ -138,38 +140,37 @@ class Overturning(VerosDiagnostic):
for m in range(self.nlevel): for m in range(self.nlevel):
# NOTE: see above # NOTE: see above
mask = sig_loc_face > self.sigma[m] mask = sig_loc_face > self.sigma[m]
bolus_trans[2:-2, m] = global_sum(vs, bolus_trans[2:-2, m] = zonal_sum(vs,
np.sum( np.sum(
(vs.B1_gm[2:-2, 2:-2, 1:] - vs.B1_gm[2:-2, 2:-2, :-1]) (vs.B1_gm[2:-2, 2:-2, 1:] - vs.B1_gm[2:-2, 2:-2, :-1])
* vs.dxt[2:-2, np.newaxis, np.newaxis] * vs.dxt[2:-2, np.newaxis, np.newaxis]
* vs.cosu[np.newaxis, 2:-2, np.newaxis] * vs.cosu[np.newaxis, 2:-2, np.newaxis]
* vs.maskV[2:-2, 2:-2, 1:] * vs.maskV[2:-2, 2:-2, 1:]
* mask[:, :, 1:], * mask[:, :, 1:],
axis=(0, 2) axis=2
)
+ np.sum(
vs.B1_gm[2:-2, 2:-2, 0]
* vs.dxt[2:-2, np.newaxis]
* vs.cosu[np.newaxis, 2:-2]
* vs.maskV[2:-2, 2:-2, 0]
* mask[:, :, 0],
axis=0
) )
+
vs.B1_gm[2:-2, 2:-2, 0]
* vs.dxt[2:-2, np.newaxis]
* vs.cosu[np.newaxis, 2:-2]
* vs.maskV[2:-2, 2:-2, 0]
* mask[:, :, 0]
) )
# streamfunction on geopotentials # streamfunction on geopotentials
self.vsf_depth[2:-2, :] += np.cumsum(global_sum(vs, np.sum( self.vsf_depth[2:-2, :] += np.cumsum(zonal_sum(vs,
vs.dxt[2:-2, np.newaxis, np.newaxis] vs.dxt[2:-2, np.newaxis, np.newaxis]
* vs.cosu[np.newaxis, 2:-2, np.newaxis] * vs.cosu[np.newaxis, 2:-2, np.newaxis]
* vs.v[2:-2, 2:-2, :, vs.tau] * vs.v[2:-2, 2:-2, :, vs.tau]
* vs.maskV[2:-2, 2:-2, :], axis=0)) * vs.dzt[np.newaxis, :], axis=1) * vs.maskV[2:-2, 2:-2, :]) * vs.dzt[np.newaxis, :], axis=1)
if vs.enable_neutral_diffusion and vs.enable_skew_diffusion: if vs.enable_neutral_diffusion and vs.enable_skew_diffusion:
# streamfunction for eddy driven velocity on geopotentials # streamfunction for eddy driven velocity on geopotentials
self.bolus_depth[2:-2, :] += global_sum(vs, np.sum( self.bolus_depth[2:-2, :] += zonal_sum(vs,
vs.dxt[2:-2, np.newaxis, np.newaxis] vs.dxt[2:-2, np.newaxis, np.newaxis]
* vs.cosu[np.newaxis, 2:-2, np.newaxis] * vs.cosu[np.newaxis, 2:-2, np.newaxis]
* vs.B1_gm[2:-2, 2:-2, :], axis=0)) * vs.B1_gm[2:-2, 2:-2, :])
# interpolate from isopycnals to depth # interpolate from isopycnals to depth
self.vsf_iso[2:-2, :] += self._interpolate_along_axis(vs, self.vsf_iso[2:-2, :] += self._interpolate_along_axis(vs,
z_sig[2:-2, :], trans[2:-2, :], z_sig[2:-2, :], trans[2:-2, :],
...@@ -181,7 +182,6 @@ class Overturning(VerosDiagnostic): ...@@ -181,7 +182,6 @@ class Overturning(VerosDiagnostic):
self.nitts += 1 self.nitts += 1
@veros_method @veros_method
def _interpolate_along_axis(self, vs, coords, arr, interp_coords, axis=0): def _interpolate_along_axis(self, vs, coords, arr, interp_coords, axis=0):
# TODO: clean up this mess # TODO: clean up this mess
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment