From 28081276395b98cab705222bb01251cd09413842 Mon Sep 17 00:00:00 2001 From: Pradipta Samanta <samanta@dkrz.de> Date: Fri, 3 Jan 2025 11:08:00 +0100 Subject: [PATCH] added nrows and ncols as arguments to the cpp routine of tdma_solver_vec --- src/support/mo_math_utilities.F90 | 4 ++-- src/support/mo_math_utilities.cpp | 6 +----- test/fortran/test_math_utilities.f90 | 2 +- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/support/mo_math_utilities.F90 b/src/support/mo_math_utilities.F90 index 5c00192..525d2a1 100644 --- a/src/support/mo_math_utilities.F90 +++ b/src/support/mo_math_utilities.F90 @@ -247,10 +247,10 @@ MODULE mo_math_utilities ! C++ binding for tdma_solver_vec INTERFACE - SUBROUTINE tdma_solver_vec(a, b, c, d, slev, elev, startidx, endidx, varout, opt_acc_queue) BIND(C, NAME="tdma_solver_vec") + SUBROUTINE tdma_solver_vec(a, b, c, d, slev, elev, startidx, endidx, nrows, ncols, varout, opt_acc_queue) BIND(C, NAME="tdma_solver_vec") IMPORT :: C_DOUBLE, C_INT REAL(C_DOUBLE), INTENT(IN) :: a(*), b(*), c(*), d(*) - INTEGER(C_INT), VALUE :: slev, elev, startidx, endidx + INTEGER(C_INT), VALUE :: slev, elev, startidx, endidx, nrows, ncols REAL(C_DOUBLE), INTENT(OUT) :: varout(*) INTEGER(C_INT), OPTIONAL :: opt_acc_queue END SUBROUTINE tdma_solver_vec diff --git a/src/support/mo_math_utilities.cpp b/src/support/mo_math_utilities.cpp index 45430b0..ff94b89 100644 --- a/src/support/mo_math_utilities.cpp +++ b/src/support/mo_math_utilities.cpp @@ -6,17 +6,13 @@ extern "C" { void tdma_solver_vec(double *a, double *b, double *c, double *d, int slev, int elev, int startidx, int endidx, - double* varout, int opt_acc_queue = -1) { + int nrows, int ncols, double *varout, int opt_acc_queue = -1) { // Start timing auto start_time = std::chrono::high_resolution_clock::now(); int acc_queue = (opt_acc_queue == -1) ? 1 : opt_acc_queue; // Use 1 as the default if opt_acc_queue is not provided - // Determine array sizes based on startidx and endidx - int nrows = endidx - startidx; - int ncols = elev - slev; - double* cp = new double[nrows * ncols]; double* dp = new double[nrows * ncols]; diff --git a/test/fortran/test_math_utilities.f90 b/test/fortran/test_math_utilities.f90 index 8dcf5b3..9f95ca7 100644 --- a/test/fortran/test_math_utilities.f90 +++ b/test/fortran/test_math_utilities.f90 @@ -264,7 +264,7 @@ CONTAINS #ifndef __USE_CPP_BINDINGS CALL tdma_solver_vec(a, b, c, d, 1, n, 1, n, x) #else - CALL tdma_solver_vec(a, b, c, d, 0, n, 0, n, x, -1) + CALL tdma_solver_vec(a, b, c, d, 0, n, 0, n, n, n, x, -1) #endif CALL CPU_TIME(end_time) -- GitLab