[arrayfire] 02/284: Convert CPU blas to use async queues

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Sun Feb 7 18:59:12 UTC 2016


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/experimental
in repository arrayfire.

commit b94c3df4e1c5288eb1bb985f372e3f8e5910fe3d
Author: Umar Arshad <umar at arrayfire.com>
Date:   Sun Aug 9 01:19:34 2015 -0400

    Convert CPU blas to use async queues
---
 src/backend/cpu/blas.cpp | 73 ++++++++++++++++++++++++++----------------------
 1 file changed, 40 insertions(+), 33 deletions(-)

diff --git a/src/backend/cpu/blas.cpp b/src/backend/cpu/blas.cpp
index 0bbd399..8887202 100644
--- a/src/backend/cpu/blas.cpp
+++ b/src/backend/cpu/blas.cpp
@@ -13,6 +13,8 @@
 #include <cassert>
 #include <err_cpu.hpp>
 #include <err_common.hpp>
+#include <platform.hpp>
+#include <async_queue.hpp>
 
 namespace cpu
 {
@@ -131,36 +133,38 @@ Array<T> matmul(const Array<T> &lhs, const Array<T> &rhs,
     int N = rDims[bColDim];
     int K = lDims[aColDim];
 
-    //FIXME: Leaks on errors.
-    Array<T> out = createEmptyArray<T>(af::dim4(M, N, 1, 1));
-    auto alpha = getScale<T, 1>();
-    auto beta  = getScale<T, 0>();
-
-    dim4 lStrides = lhs.strides();
-    dim4 rStrides = rhs.strides();
     using BT  =       typename blas_base<T>::type;
     using CBT = const typename blas_base<T>::type;
 
-    if(rDims[bColDim] == 1) {
-        N = lDims[aColDim];
-        gemv_func<T>()(
-            CblasColMajor, lOpts,
-            lDims[0], lDims[1],
-            alpha,
-            reinterpret_cast<CBT*>(lhs.get()), lStrides[1],
-            reinterpret_cast<CBT*>(rhs.get()), rStrides[0],
-            beta,
-            reinterpret_cast<BT*>(out.get()), 1);
-    } else {
-        gemm_func<T>()(
-            CblasColMajor, lOpts, rOpts,
-            M, N, K,
-            alpha,
-            reinterpret_cast<CBT*>(lhs.get()), lStrides[1],
-            reinterpret_cast<CBT*>(rhs.get()), rStrides[1],
-            beta,
-            reinterpret_cast<BT*>(out.get()), out.dims()[0]);
-    }
+    Array<T> out = createEmptyArray<T>(af::dim4(M, N, 1, 1));
+    auto func = [=] (Array<T> output, const Array<T> left, const Array<T> right) {
+        auto alpha = getScale<T, 1>();
+        auto beta  = getScale<T, 0>();
+
+        dim4 lStrides = left.strides();
+        dim4 rStrides = right.strides();
+
+        if(rDims[bColDim] == 1) {
+            gemv_func<T>()(
+                CblasColMajor, lOpts,
+                lDims[0], lDims[1],
+                alpha,
+                reinterpret_cast<CBT*>(left.get()), lStrides[1],
+                reinterpret_cast<CBT*>(right.get()), rStrides[0],
+                beta,
+                reinterpret_cast<BT*>(output.get()), 1);
+        } else {
+            gemm_func<T>()(
+                CblasColMajor, lOpts, rOpts,
+                M, N, K,
+                alpha,
+                reinterpret_cast<CBT*>(left.get()), lStrides[1],
+                reinterpret_cast<CBT*>(right.get()), rStrides[1],
+                beta,
+                reinterpret_cast<BT*>(output.get()), output.dims()[0]);
+        }
+    };
+    getQueue().enqueue(func, out, lhs, rhs);
 
     return out;
 }
@@ -172,7 +176,7 @@ template<> cfloat  conj<cfloat> (cfloat  c) { return std::conj(c); }
 template<> cdouble conj<cdouble>(cdouble c) { return std::conj(c); }
 
 template<typename T, bool conjugate, bool both_conjugate>
-Array<T> dot_(const Array<T> &lhs, const Array<T> &rhs,
+void dot_(Array<T> output, const Array<T> &lhs, const Array<T> &rhs,
               af_mat_prop optLhs, af_mat_prop optRhs)
 {
     int N = lhs.dims()[0];
@@ -186,22 +190,25 @@ Array<T> dot_(const Array<T> &lhs, const Array<T> &rhs,
 
     if(both_conjugate) out = cpu::conj(out);
 
-    return createValueArray(af::dim4(1), out);
+    *output.get() = out;
+
 }
 
 template<typename T>
 Array<T> dot(const Array<T> &lhs, const Array<T> &rhs,
              af_mat_prop optLhs, af_mat_prop optRhs)
 {
+    Array<T> out = createEmptyArray<T>(af::dim4(1));
     if(optLhs == AF_MAT_CONJ && optRhs == AF_MAT_CONJ) {
-        return dot_<T, false, true>(lhs, rhs, optLhs, optRhs);
+        getQueue().enqueue(dot_<T, false, true>, out, lhs, rhs, optLhs, optRhs);
     } else if (optLhs == AF_MAT_CONJ && optRhs == AF_MAT_NONE) {
-        return dot_<T, true, false>(lhs, rhs, optLhs, optRhs);
+        getQueue().enqueue(dot_<T, true, false>,out, lhs, rhs, optLhs, optRhs);
     } else if (optLhs == AF_MAT_NONE && optRhs == AF_MAT_CONJ) {
-        return dot_<T, true, false>(rhs, lhs, optRhs, optLhs);
+        getQueue().enqueue(dot_<T, true, false>,out, rhs, lhs, optRhs, optLhs);
     } else {
-        return dot_<T, false, false>(lhs, rhs, optLhs, optRhs);
+        getQueue().enqueue(dot_<T, false, false>,out, lhs, rhs, optLhs, optRhs);
     }
+    return out;
 }
 
 #undef BT

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/arrayfire.git



More information about the debian-science-commits mailing list