Commit c8fbe3e9 authored by Dion Häfner's avatar Dion Häfner

[distributed] add axis argument to reductions

parent f76d40c3
......@@ -281,7 +281,15 @@ def exchange_cyclic_boundaries(vs, arr):
@dist_context_only
@veros_method(inline=True)
def _reduce(vs, arr, op):
def _reduce(vs, arr, op, axis=None):
if axis is None:
comm = rs.mpi_comm
else:
assert axis in (0, 1)
pi = proc_rank_to_index(rst.proc_rank)
other_axis = 1 - axis
comm = rs.mpi_comm.Split(pi[other_axis], rst.proc_rank)
if np.isscalar(arr):
squeeze = True
arr = np.array([arr])
......@@ -291,7 +299,7 @@ def _reduce(vs, arr, op):
arr = ascontiguousarray(arr)
res = np.empty_like(arr)
rs.mpi_comm.Allreduce(
comm.Allreduce(
get_array_buffer(vs, arr),
get_array_buffer(vs, res),
op=op
......@@ -305,37 +313,37 @@ def _reduce(vs, arr, op):
@dist_context_only
@veros_method
def global_and(vs, arr):
def global_and(vs, arr, axis=None):
from mpi4py import MPI
return _reduce(vs, arr, MPI.LAND)
return _reduce(vs, arr, MPI.LAND, axis=axis)
@dist_context_only
@veros_method
def global_or(vs, arr):
def global_or(vs, arr, axis=None):
from mpi4py import MPI
return _reduce(vs, arr, MPI.LOR)
return _reduce(vs, arr, MPI.LOR, axis=axis)
@dist_context_only
@veros_method
def global_max(vs, arr):
def global_max(vs, arr, axis=None):
from mpi4py import MPI
return _reduce(vs, arr, MPI.MAX)
return _reduce(vs, arr, MPI.MAX, axis=axis)
@dist_context_only
@veros_method
def global_min(vs, arr):
def global_min(vs, arr, axis=None):
from mpi4py import MPI
return _reduce(vs, arr, MPI.MIN)
return _reduce(vs, arr, MPI.MIN, axis=axis)
@dist_context_only
@veros_method
def global_sum(vs, arr):
def global_sum(vs, arr, axis=None):
from mpi4py import MPI
return _reduce(vs, arr, MPI.SUM)
return _reduce(vs, arr, MPI.SUM, axis=axis)
@dist_context_only
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment