[arrayfire] 176/284: Add OpenCL-CPU fallback for solve

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Sun Feb 7 18:59:31 UTC 2016


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/experimental
in repository arrayfire.

commit ffb191cbce56297e391f2f12ec45351dd8ebf1d8
Author: Shehzan Mohammed <shehzan at arrayfire.com>
Date:   Fri Jan 8 13:08:01 2016 -0500

    Add OpenCL-CPU fallback for solve
---
 src/backend/opencl/cpu/cpu_solve.cpp | 187 +++++++++++++++++++++++++++++++++++
 src/backend/opencl/cpu/cpu_solve.hpp |  23 +++++
 src/backend/opencl/solve.cpp         |  11 +++
 3 files changed, 221 insertions(+)

diff --git a/src/backend/opencl/cpu/cpu_solve.cpp b/src/backend/opencl/cpu/cpu_solve.cpp
new file mode 100644
index 0000000..824bce2
--- /dev/null
+++ b/src/backend/opencl/cpu/cpu_solve.cpp
@@ -0,0 +1,187 @@
+/*******************************************************
+ * Copyright (c) 2014, ArrayFire
+ * All rights reserved.
+ *
+ * This file is distributed under 3-clause BSD license.
+ * The complete license agreement can be obtained at:
+ * http://arrayfire.com/licenses/BSD-3-Clause
+ ********************************************************/
+
+#include <cpu/cpu_lapack_helper.hpp>
+#include <cpu/cpu_solve.hpp>
+#include <err_common.hpp>
+#include <copy.hpp>
+#include <math.hpp>
+
+#include <af/dim4.hpp>
+
+namespace opencl
+{
+namespace cpu
+{
+
+// Function-pointer types for the LAPACK entry points used in this file, one
+// per element type T. Parameter names follow the LAPACKE prototypes.
+
+// ?gesv: called by solve() when the coefficient matrix is square.
+template<typename T>
+using gesv_func_def = int (*)(ORDER_TYPE layout, int n, int nrhs,
+                              T *a, int lda, int *ipiv,
+                              T *b, int ldb);
+
+// ?gels: called by solve() when the coefficient matrix is not square.
+template<typename T>
+using gels_func_def = int (*)(ORDER_TYPE layout, char trans,
+                              int m, int n, int nrhs,
+                              T *a, int lda, T *b, int ldb);
+
+// ?getrs: called by solveLU() with a factorized matrix and its pivots.
+template<typename T>
+using getrs_func_def = int (*)(ORDER_TYPE layout, char trans, int n, int nrhs,
+                               const T *a, int lda, const int *ipiv,
+                               T *b, int ldb);
+
+// ?trtrs: called by triangleSolve() for triangular systems.
+template<typename T>
+using trtrs_func_def = int (*)(ORDER_TYPE layout, char uplo, char trans, char diag,
+                               int n, int nrhs, const T *a, int lda, T *b, int ldb);
+
+
+// Declares the primary template NAME_func<T>(); only the SOLVE_FUNC specializations are defined.
+#define SOLVE_FUNC_DEF(NAME) template<typename T> NAME##_func_def<T> NAME##_func();
+
+// Defines NAME_func<ELEM>() to return the address of LAPACK_NAME(<P><NAME>), e.g. s + gesv -> sgesv.
+#define SOLVE_FUNC(NAME, ELEM, P)                                   \
+    template<> NAME##_func_def<ELEM> NAME##_func<ELEM>()            \
+    { return &LAPACK_NAME(P##NAME); }
+
+SOLVE_FUNC_DEF(gesv)
+SOLVE_FUNC(gesv, float,   s)
+SOLVE_FUNC(gesv, double,  d)
+SOLVE_FUNC(gesv, cfloat,  c)
+SOLVE_FUNC(gesv, cdouble, z)
+
+SOLVE_FUNC_DEF(gels)
+SOLVE_FUNC(gels, float,   s)
+SOLVE_FUNC(gels, double,  d)
+SOLVE_FUNC(gels, cfloat,  c)
+SOLVE_FUNC(gels, cdouble, z)
+
+SOLVE_FUNC_DEF(getrs)
+SOLVE_FUNC(getrs, float,   s)
+SOLVE_FUNC(getrs, double,  d)
+SOLVE_FUNC(getrs, cfloat,  c)
+SOLVE_FUNC(getrs, cdouble, z)
+
+SOLVE_FUNC_DEF(trtrs)
+SOLVE_FUNC(trtrs, float,   s)
+SOLVE_FUNC(trtrs, double,  d)
+SOLVE_FUNC(trtrs, cfloat,  c)
+SOLVE_FUNC(trtrs, cdouble, z)
+
+// Solves A * X = b on the host through LAPACK ?getrs, where A and `pivot`
+// are handed to LAPACK as an already computed factorization and its pivots.
+// b is not modified: the solution is written into a copy, which is returned.
+// `options` is not consulted by this implementation.
+template<typename T>
+Array<T> solveLU(const Array<T> &A, const Array<int> &pivot,
+                 const Array<T> &b, const af_mat_prop options)
+{
+    Array<T> X = copyArray<T>(b);       // getrs overwrites its right-hand side
+
+    // Host pointers for LAPACK; each one is released by unmapPtr() below.
+    T   *luHost  = getMappedPtr<T>(A.get());
+    T   *xHost   = getMappedPtr<T>(X.get());
+    int *pivHost = getMappedPtr<int>(pivot.get());
+
+    const int order = A.dims()[0];
+    const int nrhs  = b.dims()[1];
+    getrs_func<T>()(AF_LAPACK_COL_MAJOR, 'N', order, nrhs,
+                    luHost, A.strides()[1], pivHost, xHost, X.strides()[1]);
+
+    unmapPtr(A.get(), luHost);
+    unmapPtr(X.get(), xHost);
+    unmapPtr(pivot.get(), pivHost);
+    return X;
+}
+
+// Solves the triangular system A * X = b on the host through LAPACK ?trtrs.
+// AF_MAT_UPPER in `options` selects the upper triangle (lower otherwise) and
+// AF_MAT_DIAG_UNIT requests a unit diagonal. The result is a modified copy of b.
+template<typename T>
+Array<T> triangleSolve(const Array<T> &A, const Array<T> &b, const af_mat_prop options)
+{
+    Array<T> X = copyArray<T>(b);
+    const int order = X.dims()[0];
+    const int nrhs  = X.dims()[1];
+
+    const char uplo  = (options & AF_MAT_UPPER)     ? 'U' : 'L';
+    const char trans = 'N';     // transpose flag: A is always used as given
+    const char diag  = (options & AF_MAT_DIAG_UNIT) ? 'U' : 'N';
+
+    T *triHost = getMappedPtr<T>(A.get());
+    T *xHost   = getMappedPtr<T>(X.get());
+    trtrs_func<T>()(AF_LAPACK_COL_MAJOR, uplo, trans, diag, order, nrhs,
+                    triHost, A.strides()[1], xHost, X.strides()[1]);
+    unmapPtr(A.get(), triHost);
+    unmapPtr(X.get(), xHost);
+
+    return X;
+}
+
+
+// Host (LAPACK) implementation of solve(): computes X for the system a * X = b.
+//  - AF_MAT_UPPER / AF_MAT_LOWER set in `options`: handled by triangleSolve().
+//  - square a (M == N): ?gesv; otherwise ?gels, with the result trimmed to N x K.
+// a and b are left untouched; LAPACK operates on copies.
+template<typename T>
+Array<T> solve(const Array<T> &a, const Array<T> &b, const af_mat_prop options)
+{
+    if (options & AF_MAT_UPPER ||
+        options & AF_MAT_LOWER) {
+        return triangleSolve<T>(a, b, options);
+    }
+
+    int M = a.dims()[0];
+    int N = a.dims()[1];
+    int K = b.dims()[1];
+
+    // ?gels reads M rows of B and writes N rows back, so b is copied into a
+    // zero-padded buffer with max(M, N) rows.
+    Array<T> A = copyArray<T>(a);
+    Array<T> B = padArray<T, T>(b, dim4(max(M, N), K), scalar<T>(0));
+
+    T *aPtr = getMappedPtr<T>(A.get());
+    T *bPtr = getMappedPtr<T>(B.get());
+
+    if(M == N) {
+        std::vector<int> pivot(N);
+        // data() rather than &front(): front() on an empty vector (N == 0) is undefined.
+        gesv_func<T>()(AF_LAPACK_COL_MAJOR, N, K,
+                       aPtr, A.strides()[1], pivot.data(), bPtr, B.strides()[1]);
+    } else {
+        // Leading dimensions must describe the buffers LAPACK actually sees (the
+        // copies A and B), not the strides of the caller's possibly strided `a`.
+        gels_func<T>()(AF_LAPACK_COL_MAJOR, 'N', M, N, K,
+                       aPtr, A.strides()[1], bPtr, B.strides()[1]);
+        B.resetDims(dim4(N, K));
+    }
+
+    unmapPtr(A.get(), aPtr);
+    unmapPtr(B.get(), bPtr);
+    return B;
+}
+
+// Explicitly instantiate the entry points declared in cpu_solve.hpp for each element type.
+#define INSTANTIATE_SOLVE(ELEM)                                                         \
+    template Array<ELEM> solve<ELEM>(const Array<ELEM> &, const Array<ELEM> &, const af_mat_prop); \
+    template Array<ELEM> solveLU<ELEM>(const Array<ELEM> &, const Array<int> &,         \
+                                       const Array<ELEM> &, const af_mat_prop);
+
+INSTANTIATE_SOLVE(float)
+INSTANTIATE_SOLVE(cfloat)
+INSTANTIATE_SOLVE(double)
+INSTANTIATE_SOLVE(cdouble)
+
+}
+}
diff --git a/src/backend/opencl/cpu/cpu_solve.hpp b/src/backend/opencl/cpu/cpu_solve.hpp
new file mode 100644
index 0000000..6c3de64
--- /dev/null
+++ b/src/backend/opencl/cpu/cpu_solve.hpp
@@ -0,0 +1,23 @@
+/*******************************************************
+ * Copyright (c) 2014, ArrayFire
+ * All rights reserved.
+ *
+ * This file is distributed under 3-clause BSD license.
+ * The complete license agreement can be obtained at:
+ * http://arrayfire.com/licenses/BSD-3-Clause
+ ********************************************************/
+
+#pragma once    // the default arguments below may be specified only once per translation unit
+#include <Array.hpp>
+
+namespace opencl
+{
+namespace cpu
+{
+    // Host-LAPACK versions of opencl::solve / opencl::solveLU, called when OpenCLCPUOffload() is true.
+    template<typename T>
+    Array<T> solve(const Array<T> &a, const Array<T> &b, const af_mat_prop options = AF_MAT_NONE);
+    template<typename T>
+    Array<T> solveLU(const Array<T> &a, const Array<int> &pivot, const Array<T> &b, const af_mat_prop options = AF_MAT_NONE);
+}
+}
diff --git a/src/backend/opencl/solve.cpp b/src/backend/opencl/solve.cpp
index 6d2bea4..4fede07 100644
--- a/src/backend/opencl/solve.cpp
+++ b/src/backend/opencl/solve.cpp
@@ -25,6 +25,9 @@
 #include <algorithm>
 #include <string>
 
+#include <platform.hpp>
+#include <cpu/cpu_solve.hpp>
+
 namespace opencl
 {
 
@@ -32,6 +35,10 @@ template<typename T>
 Array<T> solveLU(const Array<T> &A, const Array<int> &pivot,
                  const Array<T> &b, const af_mat_prop options)
 {
+    if(OpenCLCPUOffload()) {
+        return cpu::solveLU(A, pivot, b, options);
+    }
+
     int N = A.dims()[0];
     int NRHS = b.dims()[1];
 
@@ -296,6 +303,10 @@ template<typename T>
 Array<T> solve(const Array<T> &a, const Array<T> &b, const af_mat_prop options)
 {
     try {
+        if(OpenCLCPUOffload()) {
+            return cpu::solve(a, b, options);
+        }
+
         initBlas();
 
         if (options & AF_MAT_UPPER ||

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/arrayfire.git



More information about the debian-science-commits mailing list