[arrayfire] 29/248: Added any dimension batching and gfor support for approx1 and approx2

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Tue Nov 17 15:53:51 UTC 2015


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch dfsg-clean
in repository arrayfire.

commit 871e114039f1e9b68ea62dff9e3df02b1a191003
Author: Shehzan Mohammed <shehzan at arrayfire.com>
Date:   Mon Aug 31 16:52:07 2015 -0400

    Added any dimension batching and gfor support for approx1 and approx2
---
 src/api/c/approx.cpp                 |  28 +++--
 src/backend/cpu/approx.cpp           | 193 +++++++++++++++++++----------------
 src/backend/cuda/kernel/approx.hpp   |  59 +++++++----
 src/backend/opencl/kernel/approx.hpp |  17 ++-
 src/backend/opencl/kernel/approx1.cl |  14 +--
 src/backend/opencl/kernel/approx2.cl |  29 +++---
 test/approx1.cpp                     |  20 ++--
 test/approx2.cpp                     |  24 ++---
 8 files changed, 220 insertions(+), 164 deletions(-)

diff --git a/src/api/c/approx.cpp b/src/api/c/approx.cpp
index c0bb02c..7c2935a 100644
--- a/src/api/c/approx.cpp
+++ b/src/api/c/approx.cpp
@@ -50,7 +50,9 @@ af_err af_approx1(af_array *out, const af_array in, const af_array pos,
         ARG_ASSERT(2, p_info.isRealFloating());                   // Only floating types
         ARG_ASSERT(1, i_info.isSingle() == p_info.isSingle());    // Must have same precision
         ARG_ASSERT(1, i_info.isDouble() == p_info.isDouble());    // Must have same precision
-        DIM_ASSERT(2, p_info.isColumn() || pdims[1] == idims[1]); // Only 1D input allowed or Same no. of cols
+        // POS should either be (x, 1, 1, 1) or (1, idims[1], idims[2], idims[3])
+        DIM_ASSERT(2, p_info.isColumn() ||
+                      (pdims[1] == idims[1] && pdims[2] == idims[2] && pdims[3] == idims[3]));
         ARG_ASSERT(3, (method == AF_INTERP_LINEAR || method == AF_INTERP_NEAREST));
 
         af_array output;
@@ -77,17 +79,23 @@ af_err af_approx2(af_array *out, const af_array in, const af_array pos0, const a
         ArrayInfo p_info = getInfo(pos0);
         ArrayInfo q_info = getInfo(pos1);
 
+        dim4 idims = i_info.dims();
+        dim4 pdims = p_info.dims();
+        dim4 qdims = q_info.dims();
+
         af_dtype itype = i_info.getType();
 
-        ARG_ASSERT(1, i_info.isFloating());                       // Only floating and complex types
-        ARG_ASSERT(2, p_info.isRealFloating());                   // Only floating types
-        ARG_ASSERT(3, q_info.isRealFloating());                   // Only floating types
-        ARG_ASSERT(1, p_info.getType() == q_info.getType());      // Must have same type
-        ARG_ASSERT(1, i_info.isSingle() == p_info.isSingle());    // Must have same precision
-        ARG_ASSERT(1, i_info.isDouble() == p_info.isDouble());    // Must have same precision
-        DIM_ASSERT(2, p_info.dims() == q_info.dims());            // POS0 and POS1 must have same dims
-        DIM_ASSERT(2, p_info.dims()[2] == 1
-                   || p_info.dims()[2] == i_info.dims()[2]);      // Allowing input batch. Output dims = (px, py, iz, iw)
+        ARG_ASSERT(1, i_info.isFloating());                     // Only floating and complex types
+        ARG_ASSERT(2, p_info.isRealFloating());                 // Only floating types
+        ARG_ASSERT(3, q_info.isRealFloating());                 // Only floating types
+        ARG_ASSERT(1, p_info.getType() == q_info.getType());    // Must have same type
+        ARG_ASSERT(1, i_info.isSingle() == p_info.isSingle());  // Must have same precision
+        ARG_ASSERT(1, i_info.isDouble() == p_info.isDouble());  // Must have same precision
+        DIM_ASSERT(2, pdims == qdims);                          // POS0 and POS1 must have same dims
+
+        // POS should either be (x, y, 1, 1) or (x, y, idims[2], idims[3])
+        DIM_ASSERT(2, (pdims[2] == 1        && pdims[3] == 1) ||
+                      (pdims[2] == idims[2] && pdims[3] == idims[3]));
         ARG_ASSERT(3, (method == AF_INTERP_LINEAR || method == AF_INTERP_NEAREST));
 
         af_array output;
diff --git a/src/backend/cpu/approx.cpp b/src/backend/cpu/approx.cpp
index 78d8cf3..f9e8fdd 100644
--- a/src/backend/cpu/approx.cpp
+++ b/src/backend/cpu/approx.cpp
@@ -25,7 +25,8 @@ namespace cpu
                   const Ty *in,  const af::dim4 &idims, const dim_t iElems,
                   const Tp *pos, const af::dim4 &pdims,
                   const af::dim4 &ostrides, const af::dim4 &istrides, const af::dim4 &pstrides,
-                  const float offGrid, const dim_t idx, const dim_t idy)
+                  const float offGrid, const bool pBatch,
+                  const dim_t idx, const dim_t idy, const dim_t idz, const dim_t idw)
         {
             return;
         }
@@ -38,9 +39,11 @@ namespace cpu
                   const Ty *in,  const af::dim4 &idims, const dim_t iElems,
                   const Tp *pos, const af::dim4 &pdims,
                   const af::dim4 &ostrides, const af::dim4 &istrides, const af::dim4 &pstrides,
-                  const float offGrid, const dim_t idx, const dim_t idy)
+                  const float offGrid, const bool pBatch,
+                  const dim_t idx, const dim_t idy, const dim_t idz, const dim_t idw)
         {
-            const dim_t pmId = idx + (pdims[1] == 1 ? 0 : idy * pstrides[1]);
+            dim_t pmId = idx;
+            if(pBatch) pmId += idw * pstrides[3] + idz * pstrides[2] + idy * pstrides[1];
 
             const Tp x = pos[pmId];
             bool gFlag = false;
@@ -48,20 +51,16 @@ namespace cpu
                 gFlag = true;
             }
 
-            for(dim_t idw = 0; idw < odims[3]; idw++) {
-                for(dim_t idz = 0; idz < odims[2]; idz++) {
-                    const dim_t omId = idw * ostrides[3] + idz * ostrides[2]
-                                     + idy * ostrides[1] + idx;
-                    if(gFlag) {
-                        out[omId] = scalar<Ty>(offGrid);
-                    } else {
-                        dim_t ioff = idw * istrides[3] + idz * istrides[2]
-                                   + idy * istrides[1];
-                        const dim_t iMem = round(x) + ioff;
-
-                        out[omId] = in[iMem];
-                    }
-                }
+            const dim_t omId = idw * ostrides[3] + idz * ostrides[2]
+                             + idy * ostrides[1] + idx;
+            if(gFlag) {
+                out[omId] = scalar<Ty>(offGrid);
+            } else {
+                dim_t ioff = idw * istrides[3] + idz * istrides[2]
+                           + idy * istrides[1];
+                const dim_t iMem = round(x) + ioff;
+
+                out[omId] = in[iMem];
             }
         }
     };
@@ -73,9 +72,11 @@ namespace cpu
                   const Ty *in,  const af::dim4 &idims, const dim_t iElems,
                   const Tp *pos, const af::dim4 &pdims,
                   const af::dim4 &ostrides, const af::dim4 &istrides, const af::dim4 &pstrides,
-                  const float offGrid, const dim_t idx, const dim_t idy)
+                  const float offGrid, const bool pBatch,
+                  const dim_t idx, const dim_t idy, const dim_t idz, const dim_t idw)
         {
-            const dim_t pmId = idx + (pdims[1] == 1 ? 0 : idy * pstrides[1]);
+            dim_t pmId = idx;
+            if(pBatch) pmId += idw * pstrides[3] + idz * pstrides[2] + idy * pstrides[1];
 
             const Tp x = pos[pmId];
             bool gFlag = false;
@@ -86,27 +87,23 @@ namespace cpu
             const dim_t grid_x = floor(x);  // nearest grid
             const Tp off_x = x - grid_x; // fractional offset
 
-            for(dim_t idw = 0; idw < odims[3]; idw++) {
-                for(dim_t idz = 0; idz < odims[2]; idz++) {
-                    const dim_t omId = idw * ostrides[3] + idz * ostrides[2]
-                                     + idy * ostrides[1] + idx;
-                    if(gFlag) {
-                        out[omId] = scalar<Ty>(offGrid);
-                    } else {
-                        dim_t ioff = idw * istrides[3] + idz * istrides[2] + idy * istrides[1] + grid_x;
-
-                        // Check if x and x + 1 are both valid indices
-                        bool cond = (x < idims[0] - 1);
-                        // Compute Left and Right Weighted Values
-                        Ty yl = ((Tp)1.0 - off_x) * in[ioff];
-                        Ty yr = cond ? (off_x) * in[ioff + 1] : scalar<Ty>(0);
-                        Ty yo = yl + yr;
-                        // Compute Weight used
-                        Tp wt = cond ? (Tp)1.0 : (Tp)(1.0 - off_x);
-                        // Write final value
-                        out[omId] = (yo / wt);
-                    }
-                }
+            const dim_t omId = idw * ostrides[3] + idz * ostrides[2]
+                             + idy * ostrides[1] + idx;
+            if(gFlag) {
+                out[omId] = scalar<Ty>(offGrid);
+            } else {
+                dim_t ioff = idw * istrides[3] + idz * istrides[2] + idy * istrides[1] + grid_x;
+
+                // Check if x and x + 1 are both valid indices
+                bool cond = (x < idims[0] - 1);
+                // Compute Left and Right Weighted Values
+                Ty yl = ((Tp)1.0 - off_x) * in[ioff];
+                Ty yr = cond ? (off_x) * in[ioff + 1] : scalar<Ty>(0);
+                Ty yo = yl + yr;
+                // Compute Weight used
+                Tp wt = cond ? (Tp)1.0 : (Tp)(1.0 - off_x);
+                // Write final value
+                out[omId] = (yo / wt);
             }
         }
     };
@@ -119,10 +116,18 @@ namespace cpu
             const float offGrid)
     {
         approx1_op<Ty, Tp, method> op;
-        for(dim_t y = 0; y < odims[1]; y++) {
-            for(dim_t x = 0; x < odims[0]; x++) {
-                op(out, odims, oElems, in, idims, iElems, pos, pdims,
-                    ostrides, istrides, pstrides, offGrid, x, y);
+        bool pBatch = false;
+        if(!(pdims[1] == 1 && pdims[2] == 1 && pdims[3] == 1))
+            pBatch = true;
+
+        for(dim_t w = 0; w < odims[3]; w++) {
+            for(dim_t z = 0; z < odims[2]; z++) {
+                for(dim_t y = 0; y < odims[1]; y++) {
+                    for(dim_t x = 0; x < odims[0]; x++) {
+                        op(out, odims, oElems, in, idims, iElems, pos, pdims,
+                           ostrides, istrides, pstrides, offGrid, pBatch, x, y, z, w);
+                    }
+                }
             }
         }
     }
@@ -167,7 +172,8 @@ namespace cpu
                   const Tp *pos, const af::dim4 &pdims, const Tp *qos, const af::dim4 &qdims,
                   const af::dim4 &ostrides, const af::dim4 &istrides,
                   const af::dim4 &pstrides, const af::dim4 &qstrides,
-                  const float offGrid, const dim_t idx, const dim_t idy, const dim_t idz)
+                  const float offGrid, const bool pBatch,
+                  const dim_t idx, const dim_t idy, const dim_t idz, const dim_t idw)
         {
             return;
         }
@@ -181,10 +187,15 @@ namespace cpu
                   const Tp *pos, const af::dim4 &pdims, const Tp *qos, const af::dim4 &qdims,
                   const af::dim4 &ostrides, const af::dim4 &istrides,
                   const af::dim4 &pstrides, const af::dim4 &qstrides,
-                  const float offGrid, const dim_t idx, const dim_t idy, const dim_t idz)
+                  const float offGrid, const bool pBatch,
+                  const dim_t idx, const dim_t idy, const dim_t idz, const dim_t idw)
         {
-            const dim_t pmId = (pdims[2] == 1 ? 0 : idz * pstrides[2]) + idy * pstrides[1] + idx;
-            const dim_t qmId = (qdims[2] == 1 ? 0 : idz * qstrides[2]) + idy * qstrides[1] + idx;
+            dim_t pmId = idy * pstrides[1] + idx;
+            dim_t qmId = idy * qstrides[1] + idx;
+            if(pBatch) {
+                pmId += idw * pstrides[3] + idz * pstrides[2];
+                qmId += idw * qstrides[3] + idz * qstrides[2];
+            }
 
             bool gFlag = false;
             const Tp x = pos[pmId], y = qos[qmId];
@@ -192,18 +203,15 @@ namespace cpu
                 gFlag = true;
             }
 
-            for(dim_t idw = 0; idw < odims[3]; idw++) {
-                const dim_t omId = idw * ostrides[3] + idz * ostrides[2]
-                                 + idy * ostrides[1] + idx;
-                if(gFlag) {
-                    out[omId] = scalar<Ty>(offGrid);
-                } else {
-                    const dim_t grid_x = round(x), grid_y = round(y); // nearest grid
-                    const dim_t imId = idw * istrides[3] +
-                                       idz * istrides[2] +
-                                       grid_y * istrides[1] + grid_x;
-                    out[omId] = in[imId];
-                }
+            const dim_t omId = idw * ostrides[3] + idz * ostrides[2]
+                             + idy * ostrides[1] + idx;
+            if(gFlag) {
+                out[omId] = scalar<Ty>(offGrid);
+            } else {
+                const dim_t grid_x = round(x), grid_y = round(y); // nearest grid
+                const dim_t imId = idw * istrides[3] + idz * istrides[2] +
+                                grid_y * istrides[1] + grid_x;
+                out[omId] = in[imId];
             }
         }
     };
@@ -216,10 +224,15 @@ namespace cpu
                   const Tp *pos, const af::dim4 &pdims, const Tp *qos, const af::dim4 &qdims,
                   const af::dim4 &ostrides, const af::dim4 &istrides,
                   const af::dim4 &pstrides, const af::dim4 &qstrides,
-                  const float offGrid, const dim_t idx, const dim_t idy, const dim_t idz)
+                  const float offGrid, const bool pBatch,
+                  const dim_t idx, const dim_t idy, const dim_t idz, const dim_t idw)
         {
-            const dim_t pmId = (pdims[2] == 1 ? 0 : idz * pstrides[2]) + idy * pstrides[1] + idx;
-            const dim_t qmId = (qdims[2] == 1 ? 0 : idz * qstrides[2]) + idy * qstrides[1] + idx;
+            dim_t pmId = idy * pstrides[1] + idx;
+            dim_t qmId = idy * qstrides[1] + idx;
+            if(pBatch) {
+                pmId += idw * pstrides[3] + idz * pstrides[2];
+                qmId += idw * qstrides[3] + idz * qstrides[2];
+            }
 
             bool gFlag = false;
             const Tp x = pos[pmId], y = qos[qmId];
@@ -243,26 +256,24 @@ namespace cpu
             Tp wt = wt00 + wt10 + wt01 + wt11;
             Ty zero = scalar<Ty>(0);
 
-            for(dim_t idw = 0; idw < odims[3]; idw++) {
-                const dim_t omId = idw * ostrides[3] + idz * ostrides[2]
-                                 + idy * ostrides[1] + idx;
-                if(gFlag) {
-                    out[omId] = scalar<Ty>(offGrid);
-                } else {
-                    dim_t ioff = idw * istrides[3] + idz * istrides[2]
-                            + grid_y * istrides[1] + grid_x;
-
-                    // Compute Weighted Values
-                    Ty y00 =                    wt00 * in[ioff];
-                    Ty y10 = (condY) ?          wt10 * in[ioff + istrides[1]]     : zero;
-                    Ty y01 = (condX) ?          wt01 * in[ioff + 1]                   : zero;
-                    Ty y11 = (condX && condY) ? wt11 * in[ioff + istrides[1] + 1] : zero;
-
-                    Ty yo = y00 + y10 + y01 + y11;
-
-                    // Write Final Value
-                    out[omId] = (yo / wt);
-                }
+            const dim_t omId = idw * ostrides[3] + idz * ostrides[2]
+                             + idy * ostrides[1] + idx;
+            if(gFlag) {
+                out[omId] = scalar<Ty>(offGrid);
+            } else {
+                dim_t ioff = idw * istrides[3] + idz * istrides[2]
+                        + grid_y * istrides[1] + grid_x;
+
+                // Compute Weighted Values
+                Ty y00 =                    wt00 * in[ioff];
+                Ty y10 = (condY) ?          wt10 * in[ioff + istrides[1]]     : zero;
+                Ty y01 = (condX) ?          wt01 * in[ioff + 1]               : zero;
+                Ty y11 = (condX && condY) ? wt11 * in[ioff + istrides[1] + 1] : zero;
+
+                Ty yo = y00 + y10 + y01 + y11;
+
+                // Write Final Value
+                out[omId] = (yo / wt);
             }
         }
     };
@@ -276,11 +287,17 @@ namespace cpu
             const float offGrid)
     {
         approx2_op<Ty, Tp, method> op;
-        for(dim_t z = 0; z < odims[2]; z++) {
-            for(dim_t y = 0; y < odims[1]; y++) {
-                for(dim_t x = 0; x < odims[0]; x++) {
-                    op(out, odims, oElems, in, idims, iElems, pos, pdims, qos, qdims,
-                            ostrides, istrides, pstrides, qstrides, offGrid, x, y, z);
+        bool pBatch = false;
+        if(!(pdims[2] == 1 && pdims[3] == 1))
+            pBatch = true;
+
+        for(dim_t w = 0; w < odims[3]; w++) {
+            for(dim_t z = 0; z < odims[2]; z++) {
+                for(dim_t y = 0; y < odims[1]; y++) {
+                    for(dim_t x = 0; x < odims[0]; x++) {
+                        op(out, odims, oElems, in, idims, iElems, pos, pdims, qos, qdims,
+                           ostrides, istrides, pstrides, qstrides, offGrid, pBatch, x, y, z, w);
+                    }
                 }
             }
         }
diff --git a/src/backend/cuda/kernel/approx.hpp b/src/backend/cuda/kernel/approx.hpp
index fae137a..89d5733 100644
--- a/src/backend/cuda/kernel/approx.hpp
+++ b/src/backend/cuda/kernel/approx.hpp
@@ -29,11 +29,12 @@ namespace cuda
         __device__ inline static
         void core_nearest1(const dim_t idx, const dim_t idy, const dim_t idz, const dim_t idw,
                            Param<Ty> out, CParam<Ty> in, CParam<Tp> pos,
-                           const float offGrid)
+                           const float offGrid, const bool pBatch)
         {
             const dim_t omId = idw * out.strides[3] + idz * out.strides[2]
                              + idy * out.strides[1] + idx;
-            const dim_t pmId = idx + (pos.dims[1] == 1 ? 0 : idy * pos.strides[1]);
+            dim_t pmId = idx;
+            if(pBatch) pmId += idw * pos.strides[3] + idz * pos.strides[2] + idy * pos.strides[1];
 
             const Tp x = pos.ptr[pmId];
             if (x < 0 || in.dims[0] < x+1) {
@@ -52,14 +53,16 @@ namespace cuda
         __device__ inline static
         void core_nearest2(const dim_t idx, const dim_t idy, const dim_t idz, const dim_t idw,
                            Param<Ty> out, CParam<Ty> in,
-                           CParam<Tp> pos, CParam<Tp> qos, const float offGrid)
+                           CParam<Tp> pos, CParam<Tp> qos, const float offGrid, const bool pBatch)
         {
             const dim_t omId = idw * out.strides[3] + idz * out.strides[2]
                              + idy * out.strides[1] + idx;
-            const dim_t pmId = (pos.dims[2] == 1 ? 0 : idz * pos.strides[2])
-                             + idy * pos.strides[1] + idx;
-            const dim_t qmId = (qos.dims[2] == 1 ? 0 : idz * qos.strides[2])
-                             + idy * qos.strides[1] + idx;
+            dim_t pmId = idy * pos.strides[1] + idx;
+            dim_t qmId = idy * qos.strides[1] + idx;
+            if(pBatch) {
+                pmId += idw * pos.strides[3] + idz * pos.strides[2];
+                qmId += idw * qos.strides[3] + idz * qos.strides[2];
+            }
 
             const Tp x = pos.ptr[pmId], y = qos.ptr[qmId];
             if (x < 0 || y < 0 || in.dims[0] < x+1 || in.dims[1] < y+1) {
@@ -82,11 +85,12 @@ namespace cuda
         __device__ inline static
         void core_linear1(const dim_t idx, const dim_t idy, const dim_t idz, const dim_t idw,
                           Param<Ty> out, CParam<Ty> in, CParam<Tp> pos,
-                          const float offGrid)
+                          const float offGrid, const bool pBatch)
         {
             const dim_t omId = idw * out.strides[3] + idz * out.strides[2]
                              + idy * out.strides[1] + idx;
-            const dim_t pmId = idx + (pos.dims[1] == 1 ? 0 : idy * pos.strides[1]);
+            dim_t pmId = idx;
+            if(pBatch) pmId += idw * pos.strides[3] + idz * pos.strides[2] + idy * pos.strides[1];
 
             const Tp pVal = pos.ptr[pmId];
             if (pVal < 0 || in.dims[0] < pVal+1) {
@@ -115,14 +119,17 @@ namespace cuda
         __device__ inline static
         void core_linear2(const dim_t idx, const dim_t idy, const dim_t idz, const dim_t idw,
                            Param<Ty> out, CParam<Ty> in,
-                           CParam<Tp> pos, CParam<Tp> qos, const float offGrid)
+                           CParam<Tp> pos, CParam<Tp> qos, const float offGrid, const bool pBatch)
         {
             const dim_t omId = idw * out.strides[3] + idz * out.strides[2]
                              + idy * out.strides[1] + idx;
-            const dim_t pmId = (pos.dims[2] == 1 ? 0 : idz * pos.strides[2])
-                             + idy * pos.strides[1] + idx;
-            const dim_t qmId = (qos.dims[2] == 1 ? 0 : idz * qos.strides[2])
-                             + idy * qos.strides[1] + idx;
+            dim_t pmId = idy * pos.strides[1] + idx;
+            dim_t qmId = idy * qos.strides[1] + idx;
+            if(pBatch) {
+                pmId += idw * pos.strides[3] + idz * pos.strides[2];
+                qmId += idw * qos.strides[3] + idz * qos.strides[2];
+            }
+
 
             const Tp x = pos.ptr[pmId], y = qos.ptr[qmId];
             if (x < 0 || y < 0 || in.dims[0] < x+1 || in.dims[1] < y+1) {
@@ -165,7 +172,7 @@ namespace cuda
         template<typename Ty, typename Tp, af_interp_type method>
         __global__
         void approx1_kernel(Param<Ty> out, CParam<Ty> in, CParam<Tp> pos,
-                            const float offGrid, const dim_t blocksMatX)
+                            const float offGrid, const dim_t blocksMatX, const bool pBatch)
         {
             const dim_t idw = blockIdx.y / out.dims[2];
             const dim_t idz = blockIdx.y - idw * out.dims[2];
@@ -180,10 +187,10 @@ namespace cuda
 
             switch(method) {
                 case AF_INTERP_NEAREST:
-                    core_nearest1(idx, idy, idz, idw, out, in, pos, offGrid);
+                    core_nearest1(idx, idy, idz, idw, out, in, pos, offGrid, pBatch);
                     break;
                 case AF_INTERP_LINEAR:
-                    core_linear1(idx, idy, idz, idw, out, in, pos, offGrid);
+                    core_linear1(idx, idy, idz, idw, out, in, pos, offGrid, pBatch);
                     break;
                 default:
                     break;
@@ -194,7 +201,7 @@ namespace cuda
         __global__
         void approx2_kernel(Param<Ty> out, CParam<Ty> in,
                       CParam<Tp> pos, CParam<Tp> qos, const float offGrid,
-                      const dim_t blocksMatX, const dim_t blocksMatY)
+                      const dim_t blocksMatX, const dim_t blocksMatY, const bool pBatch)
         {
             const dim_t idz = blockIdx.x / blocksMatX;
             const dim_t idw = blockIdx.y / blocksMatY;
@@ -211,10 +218,10 @@ namespace cuda
 
             switch(method) {
                 case AF_INTERP_NEAREST:
-                    core_nearest2(idx, idy, idz, idw, out, in, pos, qos, offGrid);
+                    core_nearest2(idx, idy, idz, idw, out, in, pos, qos, offGrid, pBatch);
                     break;
                 case AF_INTERP_LINEAR:
-                    core_linear2(idx, idy, idz, idw, out, in, pos, qos, offGrid);
+                    core_linear2(idx, idy, idz, idw, out, in, pos, qos, offGrid, pBatch);
                     break;
                 default:
                     break;
@@ -232,8 +239,12 @@ namespace cuda
             dim_t blocksPerMat = divup(out.dims[0], threads.x);
             dim3 blocks(blocksPerMat * out.dims[1], out.dims[2] * out.dims[3]);
 
+            bool pBatch = false;
+            if(!(pos.dims[1] == 1 && pos.dims[2] == 1 && pos.dims[3] == 1))
+                pBatch = true;
+
             CUDA_LAUNCH((approx1_kernel<Ty, Tp, method>), blocks, threads,
-                    out, in, pos, offGrid, blocksPerMat);
+                         out, in, pos, offGrid, blocksPerMat, pBatch);
             POST_LAUNCH_CHECK();
         }
 
@@ -246,8 +257,12 @@ namespace cuda
             dim_t blocksPerMatY = divup(out.dims[1], threads.y);
             dim3 blocks(blocksPerMatX * out.dims[2], blocksPerMatY * out.dims[3]);
 
+            bool pBatch = false;
+            if(!(pos.dims[2] == 1 && pos.dims[3] == 1))
+                pBatch = true;
+
             CUDA_LAUNCH((approx2_kernel<Ty, Tp, method>), blocks, threads,
-                    out, in, pos, qos, offGrid, blocksPerMatX, blocksPerMatY);
+                         out, in, pos, qos, offGrid, blocksPerMatX, blocksPerMatY, pBatch);
             POST_LAUNCH_CHECK();
         }
     }
diff --git a/src/backend/opencl/kernel/approx.hpp b/src/backend/opencl/kernel/approx.hpp
index f893097..c12f8b2 100644
--- a/src/backend/opencl/kernel/approx.hpp
+++ b/src/backend/opencl/kernel/approx.hpp
@@ -87,7 +87,7 @@ namespace opencl
 
 
                 auto approx1Op = make_kernel<Buffer, const KParam, const Buffer, const KParam,
-                                       const Buffer, const KParam, const float, const dim_t>
+                                       const Buffer, const KParam, const float, const dim_t, const int>
                                       (*approxKernels[device]);
 
                 NDRange local(THREADS, 1, 1);
@@ -96,9 +96,14 @@ namespace opencl
                                out.info.dims[2] * out.info.dims[3] * local[0],
                                1);
 
+                // Passing bools to opencl kernels is not allowed
+                int pBatch = 0;
+                if(!(pos.info.dims[1] == 1 && pos.info.dims[2] == 1 && pos.info.dims[3] == 1))
+                    pBatch = 1;
+
                 approx1Op(EnqueueArgs(getQueue(), global, local),
                           *out.data, out.info, *in.data, in.info,
-                          *pos.data, pos.info, offGrid, blocksPerMat);
+                          *pos.data, pos.info, offGrid, blocksPerMat, pBatch);
 
                 CL_DEBUG_FINISH(getQueue());
             } catch (cl::Error err) {
@@ -152,7 +157,7 @@ namespace opencl
 
                 auto approx2Op = make_kernel<Buffer, const KParam, const Buffer, const KParam,
                                        const Buffer, const KParam, const Buffer, const KParam,
-                                       const float, const dim_t, const dim_t>
+                                       const float, const dim_t, const dim_t, const int>
                                        (*approxKernels[device]);
 
                 NDRange local(TX, TY, 1);
@@ -162,13 +167,17 @@ namespace opencl
                                blocksPerMatY * local[1] * out.info.dims[3],
                                1);
 
+                // Passing bools to opencl kernels is not allowed
+                int pBatch = 0;
+                if(!(pos.info.dims[2] == 1 && pos.info.dims[3] == 1))
+                    pBatch = 1;
 
                 approx2Op(EnqueueArgs(getQueue(), global, local),
                           *out.data, out.info,
                           *in.data, in.info,
                           *pos.data, pos.info,
                           *qos.data, qos.info,
-                          offGrid, blocksPerMatX, blocksPerMatY);
+                          offGrid, blocksPerMatX, blocksPerMatY, pBatch);
                 CL_DEBUG_FINISH(getQueue());
             } catch (cl::Error err) {
                 CL_TO_AF_ERROR(err);
diff --git a/src/backend/opencl/kernel/approx1.cl b/src/backend/opencl/kernel/approx1.cl
index 41acb04..5693fc3 100644
--- a/src/backend/opencl/kernel/approx1.cl
+++ b/src/backend/opencl/kernel/approx1.cl
@@ -36,11 +36,12 @@ void core_nearest1(const dim_t idx, const dim_t idy, const dim_t idz, const dim_
                    __global       Ty *d_out, const KParam out,
                    __global const Ty *d_in,  const KParam in,
                    __global const Tp *d_pos, const KParam pos,
-                   const float offGrid)
+                   const float offGrid, const bool pBatch)
 {
     const dim_t omId = idw * out.strides[3] + idz * out.strides[2]
                      + idy * out.strides[1] + idx;
-    const dim_t pmId = idx + (pos.dims[1] == 1 ? 0 : idy * pos.strides[1]);
+    dim_t pmId = idx;
+    if(pBatch) pmId += idw * pos.strides[3] + idz * pos.strides[2] + idy * pos.strides[1];
 
     const Tp pVal = d_pos[pmId];
     if (pVal < 0 || in.dims[0] < pVal+1) {
@@ -63,11 +64,12 @@ void core_linear1(const dim_t idx, const dim_t idy, const dim_t idz, const dim_t
                    __global       Ty *d_out, const KParam out,
                    __global const Ty *d_in,  const KParam in,
                    __global const Tp *d_pos, const KParam pos,
-                   const float offGrid)
+                   const float offGrid, const bool pBatch)
 {
     const dim_t omId = idw * out.strides[3] + idz * out.strides[2]
                      + idy * out.strides[1] + idx;
-    const dim_t pmId = idx + (pos.dims[1] == 1 ? 0 : idy * pos.strides[1]);
+    dim_t pmId = idx;
+    if(pBatch) pmId += idw * pos.strides[3] + idz * pos.strides[2] + idy * pos.strides[1];
 
     const Tp pVal = d_pos[pmId];
     if (pVal < 0 || in.dims[0] < pVal+1) {
@@ -104,7 +106,7 @@ __kernel
 void approx1_kernel(__global       Ty *d_out, const KParam out,
                     __global const Ty *d_in,  const KParam in,
                     __global const Tp *d_pos, const KParam pos,
-                    const float offGrid, const dim_t blocksMatX)
+                    const float offGrid, const dim_t blocksMatX, const int pBatch)
 {
     const dim_t idw = get_group_id(1) / out.dims[2];
     const dim_t idz = get_group_id(1)  - idw * out.dims[2];
@@ -119,5 +121,5 @@ void approx1_kernel(__global       Ty *d_out, const KParam out,
        idw >= out.dims[3])
         return;
 
-    INTERP(idx, idy, idz, idw, d_out, out, d_in + in.offset, in, d_pos + pos.offset, pos, offGrid);
+    INTERP(idx, idy, idz, idw, d_out, out, d_in + in.offset, in, d_pos + pos.offset, pos, offGrid, pBatch);
 }
diff --git a/src/backend/opencl/kernel/approx2.cl b/src/backend/opencl/kernel/approx2.cl
index 4db2508..1066f55 100644
--- a/src/backend/opencl/kernel/approx2.cl
+++ b/src/backend/opencl/kernel/approx2.cl
@@ -37,14 +37,16 @@ void core_nearest2(const dim_t idx, const dim_t idy, const dim_t idz, const dim_
                    __global const Ty *d_in,  const KParam in,
                    __global const Tp *d_pos, const KParam pos,
                    __global const Tp *d_qos, const KParam qos,
-                   const float offGrid)
+                   const float offGrid, const bool pBatch)
 {
     const dim_t omId = idw * out.strides[3] + idz * out.strides[2]
                      + idy * out.strides[1] + idx;
-    const dim_t pmId = (pos.dims[2] == 1 ? 0 : idz * pos.strides[2])
-                     + idy * pos.strides[1] + idx;
-    const dim_t qmId = (qos.dims[2] == 1 ? 0 : idz * qos.strides[2])
-                     + idy * qos.strides[1] + idx;
+    dim_t pmId = idy * pos.strides[1] + idx;
+    dim_t qmId = idy * qos.strides[1] + idx;
+    if(pBatch) {
+        pmId += idw * pos.strides[3] + idz * pos.strides[2];
+        qmId += idw * qos.strides[3] + idz * qos.strides[2];
+    }
 
     const Tp x = d_pos[pmId], y = d_qos[qmId];
     if (x < 0 || y < 0 || in.dims[0] < x+1 || in.dims[1] < y+1) {
@@ -69,14 +71,16 @@ void core_linear2(const dim_t idx, const dim_t idy, const dim_t idz, const dim_t
                   __global const Ty *d_in,  const KParam in,
                   __global const Tp *d_pos, const KParam pos,
                   __global const Tp *d_qos, const KParam qos,
-                  const float offGrid)
+                  const float offGrid, const bool pBatch)
 {
     const dim_t omId = idw * out.strides[3] + idz * out.strides[2]
                      + idy * out.strides[1] + idx;
-    const dim_t pmId = (pos.dims[2] == 1 ? 0 : idz * pos.strides[2])
-                     + idy * pos.strides[1] + idx;
-    const dim_t qmId = (qos.dims[2] == 1 ? 0 : idz * qos.strides[2])
-                     + idy * qos.strides[1] + idx;
+    dim_t pmId = idy * pos.strides[1] + idx;
+    dim_t qmId = idy * qos.strides[1] + idx;
+    if(pBatch) {
+        pmId += idw * pos.strides[3] + idz * pos.strides[2];
+        qmId += idw * qos.strides[3] + idz * qos.strides[2];
+    }
 
     const Tp x = d_pos[pmId], y = d_qos[qmId];
     if (x < 0 || y < 0 || in.dims[0] < x+1 || in.dims[1] < y+1) {
@@ -122,7 +126,8 @@ void approx2_kernel(__global       Ty *d_out, const KParam out,
                     __global const Ty *d_in,  const KParam in,
                     __global const Tp *d_pos, const KParam pos,
                     __global const Tp *d_qos, const KParam qos,
-                    const float offGrid, const dim_t blocksMatX, const dim_t blocksMatY)
+                    const float offGrid, const dim_t blocksMatX, const dim_t blocksMatY,
+                    const int pBatch)
 {
     const dim_t idz = get_group_id(0) / blocksMatX;
     const dim_t idw = get_group_id(1) / blocksMatY;
@@ -140,5 +145,5 @@ void approx2_kernel(__global       Ty *d_out, const KParam out,
         return;
 
     INTERP(idx, idy, idz, idw, d_out, out, d_in + in.offset, in,
-           d_pos + pos.offset, pos, d_qos + qos.offset, qos, offGrid);
+           d_pos + pos.offset, pos, d_qos + qos.offset, qos, offGrid, pBatch);
 }
diff --git a/test/approx1.cpp b/test/approx1.cpp
index e0350d8..59101b2 100644
--- a/test/approx1.cpp
+++ b/test/approx1.cpp
@@ -248,13 +248,13 @@ TEST(Approx1, CPPNearestBatch)
         outSerial(af::span, i) = af::approx1(input(af::span, i), pos(af::span, i), AF_INTERP_NEAREST);
     }
 
-    //af::array outGFOR(pos.dims());
-    //gfor(af::seq i, 10) {
-    //    outGFOR(af::span, i) = af::approx1(input(af::span, i), pos(af::span, i), AF_INTERP_NEAREST);
-    //}
+    af::array outGFOR(pos.dims());
+    gfor(af::seq i, 10) {
+        outGFOR(af::span, i) = af::approx1(input(af::span, i), pos(af::span, i), AF_INTERP_NEAREST);
+    }
 
     ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outSerial)), 1e-3);
-    //ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outGFOR)), 1e-3);
+    ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outGFOR)), 1e-3);
 }
 
 TEST(Approx1, CPPLinearBatch)
@@ -271,11 +271,11 @@ TEST(Approx1, CPPLinearBatch)
         outSerial(af::span, i) = af::approx1(input(af::span, i), pos(af::span, i), AF_INTERP_LINEAR);
     }
 
-    //af::array outGFOR(pos.dims());
-    //gfor(af::seq i, 10) {
-    //    outGFOR(af::span, i) = af::approx1(input(af::span, i), pos(af::span, i), AF_INTERP_LINEAR);
-    //}
+    af::array outGFOR(pos.dims());
+    gfor(af::seq i, 10) {
+        outGFOR(af::span, i) = af::approx1(input(af::span, i), pos(af::span, i), AF_INTERP_LINEAR);
+    }
 
     ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outSerial)), 1e-3);
-    //ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outGFOR)), 1e-3);
+    ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outGFOR)), 1e-3);
 }
diff --git a/test/approx2.cpp b/test/approx2.cpp
index fc2c87f..3cfd3ea 100644
--- a/test/approx2.cpp
+++ b/test/approx2.cpp
@@ -265,14 +265,14 @@ TEST(Approx2, CPPNearestBatch)
             pos(af::span, af::span, i), qos(af::span, af::span, i), AF_INTERP_NEAREST);
     }
 
-    //af::array outGFOR(pos.dims());
-    //gfor(af::seq i, 10) {
-    //    outGFOR(af::span, af::span, i) = af::approx2(input(af::span, af::span, i),
-    //        pos(af::span, af::span, i), qos(af::span, af::span, i), AF_INTERP_NEAREST);
-    //}
+    af::array outGFOR(pos.dims());
+    gfor(af::seq i, 10) {
+        outGFOR(af::span, af::span, i) = af::approx2(input(af::span, af::span, i),
+            pos(af::span, af::span, i), qos(af::span, af::span, i), AF_INTERP_NEAREST);
+    }
 
     ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outSerial)), 1e-3);
-    //ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outGFOR)), 1e-3);
+    ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outGFOR)), 1e-3);
 }
 
 TEST(Approx2, CPPLinearBatch)
@@ -291,12 +291,12 @@ TEST(Approx2, CPPLinearBatch)
             pos(af::span, af::span, i), qos(af::span, af::span, i), AF_INTERP_LINEAR);
     }
 
-    //af::array outGFOR(pos.dims());
-    //gfor(af::seq i, 10) {
-    //    outGFOR(af::span, af::span, i) = af::approx2(input(af::span, af::span, i),
-    //        pos(af::span, af::span, i), qos(af::span, af::span, i), AF_INTERP_LINEAR);
-    //}
+    af::array outGFOR(pos.dims());
+    gfor(af::seq i, 10) {
+        outGFOR(af::span, af::span, i) = af::approx2(input(af::span, af::span, i),
+            pos(af::span, af::span, i), qos(af::span, af::span, i), AF_INTERP_LINEAR);
+    }
 
     ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outSerial)), 1e-3);
-    //ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outGFOR)), 1e-3);
+    ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outGFOR)), 1e-3);
 }

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/arrayfire.git



More information about the debian-science-commits mailing list