[arrayfire] 168/284: dot in CUDA/OpenCL now uses mul followed by reduction

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Sun Feb 7 18:59:30 UTC 2016


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/experimental
in repository arrayfire.

commit 507ec929888bf137db79648778701db7b1ca5532
Author: Shehzan Mohammed <shehzan at arrayfire.com>
Date:   Thu Jan 7 18:17:06 2016 -0500

    dot in CUDA/OpenCL now uses mul followed by reduction
---
 src/backend/cuda/blas.cpp   | 58 +++++++++++++++++++-------------------
 src/backend/opencl/blas.cpp | 68 ++++++++++++++++++++++-----------------------
 2 files changed, 63 insertions(+), 63 deletions(-)

diff --git a/src/backend/cuda/blas.cpp b/src/backend/cuda/blas.cpp
index 85f48da..1e5dd5d 100644
--- a/src/backend/cuda/blas.cpp
+++ b/src/backend/cuda/blas.cpp
@@ -18,6 +18,9 @@
 #include <math.hpp>
 #include <err_common.hpp>
 #include <cublasManager.hpp>
+#include <arith.hpp>
+#include <reduce.hpp>
+#include <complex.hpp>
 
 namespace cuda
 {
@@ -197,40 +200,37 @@ Array<T> matmul(const Array<T> &lhs, const Array<T> &rhs,
 
 }
 
-template<typename T, bool conjugate, bool both_conjugate>
-Array<T> dot_(const Array<T> &lhs, const Array<T> &rhs,
-              af_mat_prop optLhs, af_mat_prop optRhs)
-{
-    int N = lhs.dims()[0];
-
-    T out;
-
-    CUBLAS_CHECK((dot_func<T, conjugate>()(
-                 getHandle(),
-                 N,
-                 lhs.get(), lhs.strides()[0],
-                 rhs.get(), rhs.strides()[0],
-                 &out)));
-
-    if(both_conjugate)
-        return createValueArray(af::dim4(1), conj(out));
-    else
-        return createValueArray(af::dim4(1), out);
-}
+// Keeping this around for future reference
+//template<typename T, bool conjugate, bool both_conjugate>
+//Array<T> dot_(const Array<T> &lhs, const Array<T> &rhs,
+//              af_mat_prop optLhs, af_mat_prop optRhs)
+//{
+//    int N = lhs.dims()[0];
+//
+//    T out;
+//
+//    CUBLAS_CHECK((dot_func<T, conjugate>()(
+//                 getHandle(),
+//                 N,
+//                 lhs.get(), lhs.strides()[0],
+//                 rhs.get(), rhs.strides()[0],
+//                 &out)));
+//
+//    if(both_conjugate)
+//        return createValueArray(af::dim4(1), conj(out));
+//    else
+//        return createValueArray(af::dim4(1), out);
+//}
 
 template<typename T>
 Array<T> dot(const Array<T> &lhs, const Array<T> &rhs,
              af_mat_prop optLhs, af_mat_prop optRhs)
 {
-    if(optLhs == AF_MAT_CONJ && optRhs == AF_MAT_CONJ) {
-        return dot_<T, false, true>(lhs, rhs, optLhs, optRhs);
-    } else if (optLhs == AF_MAT_CONJ && optRhs == AF_MAT_NONE) {
-        return dot_<T, true, false>(lhs, rhs, optLhs, optRhs);
-    } else if (optLhs == AF_MAT_NONE && optRhs == AF_MAT_CONJ) {
-        return dot_<T, true, false>(rhs, lhs, optRhs, optLhs);
-    } else {
-        return dot_<T, false, false>(lhs, rhs, optLhs, optRhs);
-    }
+    const Array<T> lhs_ = (optLhs == AF_MAT_NONE ? lhs : conj<T>(lhs));
+    const Array<T> rhs_ = (optRhs == AF_MAT_NONE ? rhs : conj<T>(rhs));
+
+    const Array<T> temp = arithOp<T, af_mul_t>(lhs_, rhs_, lhs_.dims());
+    return reduce<af_add_t, T, T>(temp, 0, false, 0);
 }
 
 template<typename T>
diff --git a/src/backend/opencl/blas.cpp b/src/backend/opencl/blas.cpp
index f9f8af1..15e2373 100644
--- a/src/backend/opencl/blas.cpp
+++ b/src/backend/opencl/blas.cpp
@@ -19,6 +19,9 @@
 #include <err_clblas.hpp>
 #include <math.hpp>
 #include <transpose.hpp>
+#include <arith.hpp>
+#include <reduce.hpp>
+#include <complex.hpp>
 
 #include <cpu/cpu_blas.hpp>
 
@@ -174,45 +177,42 @@ Array<T> matmul(const Array<T> &lhs, const Array<T> &rhs,
     return out;
 }
 
-template<typename T, bool conjugate, bool both_conjugate>
-Array<T> dot_(const Array<T> &lhs, const Array<T> &rhs,
-              af_mat_prop optLhs, af_mat_prop optRhs)
-{
-    initBlas();
-
-    int N = lhs.dims()[0];
-    dot_func<T, conjugate> dot;
-    cl::Event event;
-    Array<T> out = createEmptyArray<T>(af::dim4(1));
-    cl::Buffer scratch(getContext(), CL_MEM_READ_WRITE, sizeof(T) * N);
-    CLBLAS_CHECK(
-        dot(N,
-            (*out.get())(), out.getOffset(),
-            (*lhs.get())(),  lhs.getOffset(), lhs.strides()[0],
-            (*rhs.get())(),  rhs.getOffset(), rhs.strides()[0],
-            scratch(),
-            1, &getQueue()(), 0, nullptr, &event())
-        );
-
-    if(both_conjugate)
-        transpose_inplace<T>(out, true);
-
-    return out;
-}
+// Keeping this around for future reference
+//template<typename T, bool conjugate, bool both_conjugate>
+//Array<T> dot_(const Array<T> &lhs, const Array<T> &rhs,
+//              af_mat_prop optLhs, af_mat_prop optRhs)
+//{
+//    initBlas();
+//
+//    int N = lhs.dims()[0];
+//    dot_func<T, conjugate> dot;
+//    cl::Event event;
+//    Array<T> out = createEmptyArray<T>(af::dim4(1));
+//    cl::Buffer scratch(getContext(), CL_MEM_READ_WRITE, sizeof(T) * N);
+//    CLBLAS_CHECK(
+//        dot(N,
+//            (*out.get())(), out.getOffset(),
+//            (*lhs.get())(),  lhs.getOffset(), lhs.strides()[0],
+//            (*rhs.get())(),  rhs.getOffset(), rhs.strides()[0],
+//            scratch(),
+//            1, &getQueue()(), 0, nullptr, &event())
+//        );
+//
+//    if(both_conjugate)
+//        transpose_inplace<T>(out, true);
+//
+//    return out;
+//}
 
 template<typename T>
 Array<T> dot(const Array<T> &lhs, const Array<T> &rhs,
              af_mat_prop optLhs, af_mat_prop optRhs)
 {
-    if(optLhs == AF_MAT_CONJ && optRhs == AF_MAT_CONJ) {
-        return dot_<T, false, true>(lhs, rhs, optLhs, optRhs);
-    } else if (optLhs == AF_MAT_CONJ && optRhs == AF_MAT_NONE) {
-        return dot_<T, true, false>(lhs, rhs, optLhs, optRhs);
-    } else if (optLhs == AF_MAT_NONE && optRhs == AF_MAT_CONJ) {
-        return dot_<T, true, false>(rhs, lhs, optRhs, optLhs);
-    } else {
-        return dot_<T, false, false>(lhs, rhs, optLhs, optRhs);
-    }
+    const Array<T> lhs_ = (optLhs == AF_MAT_NONE ? lhs : conj<T>(lhs));
+    const Array<T> rhs_ = (optRhs == AF_MAT_NONE ? rhs : conj<T>(rhs));
+
+    const Array<T> temp = arithOp<T, af_mul_t>(lhs_, rhs_, lhs_.dims());
+    return reduce<af_add_t, T, T>(temp, 0, false, 0);
 }
 
 #define INSTANTIATE_BLAS(TYPE)                                                          \

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/arrayfire.git



More information about the debian-science-commits mailing list