From f6ac82117b5ab528b1bebfb8d843b8fb371d9cbd Mon Sep 17 00:00:00 2001
From: Pradipta Samanta <samanta@dkrz.de>
Date: Thu, 2 Jan 2025 22:24:47 +0100
Subject: [PATCH] added the cpp version of tdma_solver_vec

made it compile
---
 CMakeLists.txt                    |   2 +
 src/support/CMakeLists.txt        |   1 +
 src/support/mo_math_utilities.F90 | 164 +++++++++++++++++-------------
 src/support/mo_math_utilities.cpp |  77 ++++++++++++++
 4 files changed, 172 insertions(+), 72 deletions(-)
 create mode 100644 src/support/mo_math_utilities.cpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 2f32fcf..8fb4acf 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -16,6 +16,8 @@ project(
   VERSION 1.0.0
   LANGUAGES Fortran CXX)
 
+set(CMAKE_CXX_STANDARD 17)
+
 option(BUILD_SHARED_LIBS "Build shared libraries" ON)
 option(BUILD_TESTING "Build tests" ON)
 option(BUILD_ICONMATH_INTERPOLATION "Build interpolation library" ON)
diff --git a/src/support/CMakeLists.txt b/src/support/CMakeLists.txt
index c0fc287..35e1c71 100644
--- a/src/support/CMakeLists.txt
+++ b/src/support/CMakeLists.txt
@@ -17,6 +17,7 @@ add_library(
   mo_lib_loopindices.F90
   mo_math_constants.f90
   mo_math_types.f90
+  mo_math_utilities.cpp
   mo_math_utilities.F90
   mo_random_number_generators.F90)
 
diff --git a/src/support/mo_math_utilities.F90 b/src/support/mo_math_utilities.F90
index 0add18c..5c00192 100644
--- a/src/support/mo_math_utilities.F90
+++ b/src/support/mo_math_utilities.F90
@@ -22,6 +22,7 @@
 ! #endif
 MODULE mo_math_utilities
 
+  USE, INTRINSIC :: ISO_C_BINDING
   USE mo_iconlib_kind, ONLY: wp, dp, sp
   USE mo_math_constants, ONLY: pi, pi_2, dbl_eps
   USE mo_gridman_constants, ONLY: SUCCESS, TORUS_MAX_LAT
@@ -165,7 +166,98 @@ MODULE mo_math_utilities
 
   CHARACTER(LEN=*), PARAMETER :: modname = 'mo_math_utilities'
 
+  !-------------------------------------------------------------------------
+  !>
+  !! TDMA tridiagonal matrix solver for a_i*x_(i-1) + b_i*x_i + c_i*x_(i+1) = d_i
+  !!
+  !!       a - sub-diagonal (means it is the diagonal below the main diagonal)
+  !!       b - the main diagonal
+  !!       c - sup-diagonal (means it is the diagonal above the main diagonal)
+  !!       d - right part
+  !!  varout - the answer (identical to x in description above)
+  !!    slev - start level (top)
+  !!    elev - end level (bottom)
+
+! Preprocessor directive to conditionally include the tdma_solver_vec implementation
+#ifndef __USE_CPP_BINDINGS
+
+  CONTAINS
+
+  SUBROUTINE tdma_solver_vec(a, b, c, d, slev, elev, startidx, endidx, varout, opt_acc_queue)
+    INTEGER, INTENT(IN) :: slev, elev
+    INTEGER, INTENT(IN) :: startidx, endidx
+    REAL(wp), INTENT(IN) :: a(:, :), b(:, :), c(:, :), d(:, :)
+    REAL(wp), INTENT(OUT) :: varout(:, :)
+    INTEGER, OPTIONAL, INTENT(IN) :: opt_acc_queue
+
+    !
+    ! local
+    REAL(wp):: m, c_p(SIZE(a, 1), SIZE(a, 2)), d_p(SIZE(a, 1), SIZE(a, 2))
+    INTEGER :: i
+    INTEGER :: jc
+    INTEGER :: acc_queue
+
+    IF (PRESENT(opt_acc_queue)) THEN
+      acc_queue = opt_acc_queue
+    ELSE
+      acc_queue = 1
+    END IF
+
+    ! initialize c-prime and d-prime
+    !$ACC PARALLEL DEFAULT(PRESENT) CREATE(c_p, d_p) ASYNC(acc_queue)
+    !$ACC LOOP GANG(STATIC: 1) VECTOR
+    DO jc = startidx, endidx
+      c_p(jc, slev) = c(jc, slev)/b(jc, slev)
+      d_p(jc, slev) = d(jc, slev)/b(jc, slev)
+    END DO
+    ! solve for vectors c-prime and d-prime
+    !$ACC LOOP SEQ
+!NEC$ outerloop_unroll(4)
+    DO i = slev + 1, elev
+      !$ACC LOOP GANG(STATIC: 1) VECTOR PRIVATE(m)
+      DO jc = startidx, endidx
+        m = 1._wp/(b(jc, i) - c_p(jc, i - 1)*a(jc, i))
+        c_p(jc, i) = c(jc, i)*m
+        d_p(jc, i) = (d(jc, i) - d_p(jc, i - 1)*a(jc, i))*m
+      END DO
+    END DO
+    ! initialize varout
+    !$ACC LOOP GANG(STATIC: 1) VECTOR
+    DO jc = startidx, endidx
+      varout(jc, elev) = d_p(jc, elev)
+    END DO
+    ! solve for varout from the vectors c-prime and d-prime
+    !$ACC LOOP SEQ
+!NEC$ outerloop_unroll(4)
+    DO i = elev - 1, slev, -1
+      !$ACC LOOP GANG(STATIC: 1) VECTOR
+      DO jc = startidx, endidx
+        varout(jc, i) = d_p(jc, i) - c_p(jc, i)*varout(jc, i + 1)
+      END DO
+    END DO
+    !$ACC END PARALLEL
+
+    IF (.NOT. PRESENT(opt_acc_queue)) THEN
+      !$ACC WAIT(acc_queue)
+    END IF
+
+  END SUBROUTINE tdma_solver_vec
+
+#else
+
+  ! C++ binding for tdma_solver_vec
+  INTERFACE
+    SUBROUTINE tdma_solver_vec(a, b, c, d, slev, elev, startidx, endidx, varout, opt_acc_queue) BIND(C, NAME="tdma_solver_vec")
+      IMPORT :: C_DOUBLE, C_INT
+      REAL(C_DOUBLE), INTENT(IN) :: a(*), b(*), c(*), d(*)
+      INTEGER(C_INT), VALUE :: slev, elev, startidx, endidx
+      REAL(C_DOUBLE), INTENT(OUT) :: varout(*)
+      INTEGER(C_INT), OPTIONAL :: opt_acc_queue
+    END SUBROUTINE tdma_solver_vec
+  END INTERFACE
+
 CONTAINS
+#endif
 
   !-------------------------------------------------------------------------
   ! Variant for double-precision (or working-precision=dp) lon+lat in ICON
@@ -2041,78 +2133,6 @@ CONTAINS
 
   END SUBROUTINE tdma_solver
 
-  !-------------------------------------------------------------------------
-  !>
-  !! TDMA tridiagonal matrix solver for a_i*x_(i-1) + b_i*x_i + c_i*x_(i+1) = d_i
-  !!
-  !!       a - sub-diagonal (means it is the diagonal below the main diagonal)
-  !!       b - the main diagonal
-  !!       c - sup-diagonal (means it is the diagonal above the main diagonal)
-  !!       d - right part
-  !!  varout - the answer (identical to x in description above)
-  !!    slev - start level (top)
-  !!    elev - end level (bottom)
-  SUBROUTINE tdma_solver_vec(a, b, c, d, slev, elev, startidx, endidx, varout, opt_acc_queue)
-    INTEGER, INTENT(IN) :: slev, elev
-    INTEGER, INTENT(IN) :: startidx, endidx
-    REAL(wp), INTENT(IN) :: a(:, :), b(:, :), c(:, :), d(:, :)
-    REAL(wp), INTENT(OUT) :: varout(:, :)
-    INTEGER, OPTIONAL, INTENT(IN) :: opt_acc_queue
-
-    !
-    ! local
-    REAL(wp):: m, c_p(SIZE(a, 1), SIZE(a, 2)), d_p(SIZE(a, 1), SIZE(a, 2))
-    INTEGER :: i
-    INTEGER :: jc
-    INTEGER :: acc_queue
-
-    IF (PRESENT(opt_acc_queue)) THEN
-      acc_queue = opt_acc_queue
-    ELSE
-      acc_queue = 1
-    END IF
-
-    ! initialize c-prime and d-prime
-    !$ACC PARALLEL DEFAULT(PRESENT) CREATE(c_p, d_p) ASYNC(acc_queue)
-    !$ACC LOOP GANG(STATIC: 1) VECTOR
-    DO jc = startidx, endidx
-      c_p(jc, slev) = c(jc, slev)/b(jc, slev)
-      d_p(jc, slev) = d(jc, slev)/b(jc, slev)
-    END DO
-    ! solve for vectors c-prime and d-prime
-    !$ACC LOOP SEQ
-!NEC$ outerloop_unroll(4)
-    DO i = slev + 1, elev
-      !$ACC LOOP GANG(STATIC: 1) VECTOR PRIVATE(m)
-      DO jc = startidx, endidx
-        m = 1._wp/(b(jc, i) - c_p(jc, i - 1)*a(jc, i))
-        c_p(jc, i) = c(jc, i)*m
-        d_p(jc, i) = (d(jc, i) - d_p(jc, i - 1)*a(jc, i))*m
-      END DO
-    END DO
-    ! initialize varout
-    !$ACC LOOP GANG(STATIC: 1) VECTOR
-    DO jc = startidx, endidx
-      varout(jc, elev) = d_p(jc, elev)
-    END DO
-    ! solve for varout from the vectors c-prime and d-prime
-    !$ACC LOOP SEQ
-!NEC$ outerloop_unroll(4)
-    DO i = elev - 1, slev, -1
-      !$ACC LOOP GANG(STATIC: 1) VECTOR
-      DO jc = startidx, endidx
-        varout(jc, i) = d_p(jc, i) - c_p(jc, i)*varout(jc, i + 1)
-      END DO
-    END DO
-    !$ACC END PARALLEL
-
-    IF (.NOT. PRESENT(opt_acc_queue)) THEN
-      !$ACC WAIT(acc_queue)
-    END IF
-
-  END SUBROUTINE tdma_solver_vec
-  !-------------------------------------------------------------------------
-
   !-------------------------------------------------------------------------
   !
   !> Helper functions for computing the vertical layer structure
diff --git a/src/support/mo_math_utilities.cpp b/src/support/mo_math_utilities.cpp
new file mode 100644
index 0000000..a8ccce4
--- /dev/null
+++ b/src/support/mo_math_utilities.cpp
@@ -0,0 +1,77 @@
+#include <vector>
+#include <iostream>
+#include <chrono> // For timing
+
+extern "C" {
+
+void tdma_solver_vec(double *a, double *b, double *c, double *d,
+                     int slev, int elev, int startidx, int endidx,
+                     double* varout, int opt_acc_queue = -1) {
+
+    int acc_queue = (opt_acc_queue == -1) ? 1 : opt_acc_queue; // Use 1 as the default if opt_acc_queue is not provided
+
+    // Determine array sizes based on startidx and endidx
+    int nrows = endidx - startidx;
+    int ncols = elev - slev;
+
+    // Temporary arrays for c-prime and d-prime
+    std::vector<double> cp(nrows * ncols, 0.0);
+    std::vector<double> dp(nrows * ncols, 0.0);
+
+    // Helper function to access 2D arrays stored as 1D
+    auto idx = [&](int row, int col) { return col * nrows + row; };
+
+    // Start timing
+    auto start_time = std::chrono::high_resolution_clock::now();
+
+    // OpenACC Parallel Region
+    #pragma acc parallel default(present) create(cp[:nrows*ncols], dp[:nrows*ncols]) async(acc_queue)
+    {
+        // Initialize c-prime and d-prime
+        #pragma acc loop gang(static: 1) vector
+        for (int jc = startidx; jc < endidx; ++jc) {
+            cp[idx(jc, slev)] = c[idx(jc, slev)] / b[idx(jc, slev)];
+            dp[idx(jc, slev)] = d[idx(jc, slev)] / b[idx(jc, slev)];
+        }
+
+        // Solve for vectors c-prime and d-prime
+        #pragma acc loop seq
+        for (int i = slev + 1; i < elev; ++i) {
+            #pragma acc loop gang(static: 1) vector
+            for (int jc = startidx; jc < endidx; ++jc) {
+                double m = 1.0 / (b[idx(jc, i)] - cp[idx(jc, i - 1)] * a[idx(jc, i)]);
+                cp[idx(jc, i)] = c[idx(jc, i)] * m;
+                dp[idx(jc, i)] = (d[idx(jc, i)] - dp[idx(jc, i - 1)] * a[idx(jc, i)]) * m;
+            }
+        }
+
+        // Initialize varout
+        #pragma acc loop gang(static: 1) vector
+        for (int jc = startidx; jc < endidx; ++jc) {
+            varout[idx(jc, elev-1)] = dp[idx(jc, elev-1)];
+        }
+
+        // Solve for varout from the vectors c-prime and d-prime
+        #pragma acc loop seq
+        for (int i = elev - 2; i >= slev; --i) {
+            #pragma acc loop gang(static: 1) vector
+            for (int jc = startidx; jc < endidx; ++jc) {
+                varout[idx(jc, i)] = dp[idx(jc, i)] - cp[idx(jc, i)] * varout[idx(jc, i + 1)];
+            }
+        }
+    }
+
+    printf("tdma_solver_vec: completed using C++\n");
+
+    // Wait for OpenACC asynchronous operations to complete if acc_queue is not optional
+    if (opt_acc_queue == -1) {
+        #pragma acc wait(acc_queue)
+    }
+
+    // End timing
+    auto end_time = std::chrono::high_resolution_clock::now();
+    std::chrono::duration<double> elapsed_time = end_time - start_time;
+
+    std::cout << "Elapsed time for tdma_solver_vec (C++): " << elapsed_time.count() << " seconds" << std::endl;
+}
+}
-- 
GitLab