[arrayfire] 26/248: FEAT Added batch support for approx1 and approx2

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Tue Nov 17 15:53:50 UTC 2015


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch dfsg-clean
in repository arrayfire.

commit 8b94ac1e2162a24cc35ae7009abd9ae4fe86d354
Author: Shehzan Mohammed <shehzan at arrayfire.com>
Date:   Mon Aug 31 15:45:21 2015 -0400

    FEAT Added batch support for approx1 and approx2
    
    * Added tests
    * TODO: Enable the gfor-based comparison tests (currently commented out)
---
 src/api/c/approx.cpp                 |   8 +-
 src/backend/cpu/approx.cpp           | 163 +++++++++++++++++------------------
 src/backend/cuda/kernel/approx.hpp   |  26 +++---
 src/backend/opencl/kernel/approx1.cl |   6 +-
 src/backend/opencl/kernel/approx2.cl |  14 +--
 test/approx1.cpp                     |  46 ++++++++++
 test/approx2.cpp                     |  52 +++++++++++
 7 files changed, 210 insertions(+), 105 deletions(-)

diff --git a/src/api/c/approx.cpp b/src/api/c/approx.cpp
index 1bc7723..c0bb02c 100644
--- a/src/api/c/approx.cpp
+++ b/src/api/c/approx.cpp
@@ -41,13 +41,16 @@ af_err af_approx1(af_array *out, const af_array in, const af_array pos,
         ArrayInfo i_info = getInfo(in);
         ArrayInfo p_info = getInfo(pos);
 
+        dim4 idims = i_info.dims();
+        dim4 pdims = p_info.dims();
+
         af_dtype itype = i_info.getType();
 
         ARG_ASSERT(1, i_info.isFloating());                       // Only floating and complex types
         ARG_ASSERT(2, p_info.isRealFloating());                   // Only floating types
         ARG_ASSERT(1, i_info.isSingle() == p_info.isSingle());    // Must have same precision
         ARG_ASSERT(1, i_info.isDouble() == p_info.isDouble());    // Must have same precision
-        DIM_ASSERT(2, p_info.isColumn());                         // Only 1D input allowed
+        DIM_ASSERT(2, p_info.isColumn() || pdims[1] == idims[1]); // Only 1D input allowed or Same no. of cols
         ARG_ASSERT(3, (method == AF_INTERP_LINEAR || method == AF_INTERP_NEAREST));
 
         af_array output;
@@ -83,7 +86,8 @@ af_err af_approx2(af_array *out, const af_array in, const af_array pos0, const a
         ARG_ASSERT(1, i_info.isSingle() == p_info.isSingle());    // Must have same precision
         ARG_ASSERT(1, i_info.isDouble() == p_info.isDouble());    // Must have same precision
         DIM_ASSERT(2, p_info.dims() == q_info.dims());            // POS0 and POS1 must have same dims
-        DIM_ASSERT(2, p_info.ndims() < 3);// Allowing input batch but not positions. Output dims = (px, py, iz, iw)
+        DIM_ASSERT(2, p_info.dims()[2] == 1
+                   || p_info.dims()[2] == i_info.dims()[2]);      // Allowing input batch. Output dims = (px, py, iz, iw)
         ARG_ASSERT(3, (method == AF_INTERP_LINEAR || method == AF_INTERP_NEAREST));
 
         af_array output;
diff --git a/src/backend/cpu/approx.cpp b/src/backend/cpu/approx.cpp
index 69b943a..1522341 100644
--- a/src/backend/cpu/approx.cpp
+++ b/src/backend/cpu/approx.cpp
@@ -25,7 +25,7 @@ namespace cpu
                   const Ty *in,  const af::dim4 &idims, const dim_t iElems,
                   const Tp *pos, const af::dim4 &pdims,
                   const af::dim4 &ostrides, const af::dim4 &istrides, const af::dim4 &pstrides,
-                  const float offGrid, const dim_t idx)
+                  const float offGrid, const dim_t idx, const dim_t idy)
         {
             return;
         }
@@ -38,30 +38,28 @@ namespace cpu
                   const Ty *in,  const af::dim4 &idims, const dim_t iElems,
                   const Tp *pos, const af::dim4 &pdims,
                   const af::dim4 &ostrides, const af::dim4 &istrides, const af::dim4 &pstrides,
-                  const float offGrid, const dim_t idx)
+                  const float offGrid, const dim_t idx, const dim_t idy)
         {
-            const dim_t pmId = idx;
+            const dim_t pmId = idx + (pdims[1] == 1 ? 0 : idy * pstrides[1]);
 
             const Tp x = pos[pmId];
             bool gFlag = false;
-            if (x < 0 || idims[0] < x+1) {
+            if (x < 0 || idims[0] < x+1) {  // No need to check y
                 gFlag = true;
             }
 
             for(dim_t idw = 0; idw < odims[3]; idw++) {
                 for(dim_t idz = 0; idz < odims[2]; idz++) {
-                    for(dim_t idy = 0; idy < odims[1]; idy++) {
-                        const dim_t omId = idw * ostrides[3] + idz * ostrides[2]
-                                            + idy * ostrides[1] + idx;
-                        if(gFlag) {
-                            out[omId] = scalar<Ty>(offGrid);
-                        } else {
-                            dim_t ioff = idw * istrides[3] + idz * istrides[2]
-                                          + idy * istrides[1];
-                            const dim_t iMem = round(x) + ioff;
-
-                            out[omId] = in[iMem];
-                        }
+                    const dim_t omId = idw * ostrides[3] + idz * ostrides[2]
+                                     + idy * ostrides[1] + idx;
+                    if(gFlag) {
+                        out[omId] = scalar<Ty>(offGrid);
+                    } else {
+                        dim_t ioff = idw * istrides[3] + idz * istrides[2]
+                                   + idy * istrides[1];
+                        const dim_t iMem = round(x) + ioff;
+
+                        out[omId] = in[iMem];
                     }
                 }
             }
@@ -75,9 +73,9 @@ namespace cpu
                   const Ty *in,  const af::dim4 &idims, const dim_t iElems,
                   const Tp *pos, const af::dim4 &pdims,
                   const af::dim4 &ostrides, const af::dim4 &istrides, const af::dim4 &pstrides,
-                  const float offGrid, const dim_t idx)
+                  const float offGrid, const dim_t idx, const dim_t idy)
         {
-            const dim_t pmId = idx;
+            const dim_t pmId = idx + (pdims[1] == 1 ? 0 : idy * pstrides[1]);
 
             const Tp x = pos[pmId];
             bool gFlag = false;
@@ -90,25 +88,23 @@ namespace cpu
 
             for(dim_t idw = 0; idw < odims[3]; idw++) {
                 for(dim_t idz = 0; idz < odims[2]; idz++) {
-                    for(dim_t idy = 0; idy < odims[1]; idy++) {
-                        const dim_t omId = idw * ostrides[3] + idz * ostrides[2]
-                                            + idy * ostrides[1] + idx;
-                        if(gFlag) {
-                            out[omId] = scalar<Ty>(offGrid);
-                        } else {
-                            dim_t ioff = idw * istrides[3] + idz * istrides[2] + idy * istrides[1] + grid_x;
-
-                            // Check if x and x + 1 are both valid indices
-                            bool cond = (x < idims[0] - 1);
-                            // Compute Left and Right Weighted Values
-                            Ty yl = ((Tp)1.0 - off_x) * in[ioff];
-                            Ty yr = cond ? (off_x) * in[ioff + 1] : scalar<Ty>(0);
-                            Ty yo = yl + yr;
-                            // Compute Weight used
-                            Tp wt = cond ? (Tp)1.0 : (Tp)(1.0 - off_x);
-                            // Write final value
-                            out[omId] = (yo / wt);
-                        }
+                    const dim_t omId = idw * ostrides[3] + idz * ostrides[2]
+                                     + idy * ostrides[1] + idx;
+                    if(gFlag) {
+                        out[omId] = scalar<Ty>(offGrid);
+                    } else {
+                        dim_t ioff = idw * istrides[3] + idz * istrides[2] + idy * istrides[1] + grid_x;
+
+                        // Check if x and x + 1 are both valid indices
+                        bool cond = (x < idims[0] - 1);
+                        // Compute Left and Right Weighted Values
+                        Ty yl = ((Tp)1.0 - off_x) * in[ioff];
+                        Ty yr = cond ? (off_x) * in[ioff + 1] : scalar<Ty>(0);
+                        Ty yo = yl + yr;
+                        // Compute Weight used
+                        Tp wt = cond ? (Tp)1.0 : (Tp)(1.0 - off_x);
+                        // Write final value
+                        out[omId] = (yo / wt);
                     }
                 }
             }
@@ -123,9 +119,11 @@ namespace cpu
             const float offGrid)
     {
         approx1_op<Ty, Tp, method> op;
-        for(dim_t x = 0; x < odims[0]; x++) {
-            op(out, odims, oElems, in, idims, iElems, pos, pdims,
-               ostrides, istrides, pstrides, offGrid, x);
+        for(dim_t y = 0; y < odims[1]; y++) {
+            for(dim_t x = 0; x < odims[0]; x++) {
+                op(out, odims, oElems, in, idims, iElems, pos, pdims,
+                    ostrides, istrides, pstrides, offGrid, x, y);
+            }
         }
     }
 
@@ -169,7 +167,7 @@ namespace cpu
                   const Tp *pos, const af::dim4 &pdims, const Tp *qos, const af::dim4 &qdims,
                   const af::dim4 &ostrides, const af::dim4 &istrides,
                   const af::dim4 &pstrides, const af::dim4 &qstrides,
-                  const float offGrid, const dim_t idx, const dim_t idy)
+                  const float offGrid, const dim_t idx, const dim_t idy, const dim_t idz)
         {
             return;
         }
@@ -183,10 +181,10 @@ namespace cpu
                   const Tp *pos, const af::dim4 &pdims, const Tp *qos, const af::dim4 &qdims,
                   const af::dim4 &ostrides, const af::dim4 &istrides,
                   const af::dim4 &pstrides, const af::dim4 &qstrides,
-                  const float offGrid, const dim_t idx, const dim_t idy)
+                  const float offGrid, const dim_t idx, const dim_t idy, const dim_t idz)
         {
-            const dim_t pmId = idy * pstrides[1] + idx;
-            const dim_t qmId = idy * qstrides[1] + idx;
+            const dim_t pmId = (pdims[2] == 1 ? 0 : idz * pstrides[2]) + idy * pstrides[1] + idx;
+            const dim_t qmId = (qdims[2] == 1 ? 0 : idz * qstrides[2]) + idy * qstrides[1] + idx;
 
             bool gFlag = false;
             const Tp x = pos[pmId], y = qos[qmId];
@@ -195,18 +193,16 @@ namespace cpu
             }
 
             for(dim_t idw = 0; idw < odims[3]; idw++) {
-                for(dim_t idz = 0; idz < odims[2]; idz++) {
-                    const dim_t omId = idw * ostrides[3] + idz * ostrides[2]
-                                        + idy * ostrides[1] + idx;
-                    if(gFlag) {
-                        out[omId] = scalar<Ty>(offGrid);
-                    } else {
-                        const dim_t grid_x = round(x), grid_y = round(y); // nearest grid
-                        const dim_t imId = idw * istrides[3] +
-                                              idz * istrides[2] +
-                                              grid_y * istrides[1] + grid_x;
-                        out[omId] = in[imId];
-                    }
+                const dim_t omId = idw * ostrides[3] + idz * ostrides[2]
+                                 + idy * ostrides[1] + idx;
+                if(gFlag) {
+                    out[omId] = scalar<Ty>(offGrid);
+                } else {
+                    const dim_t grid_x = round(x), grid_y = round(y); // nearest grid
+                    const dim_t imId = idw * istrides[3] +
+                                       idz * istrides[2] +
+                                       grid_y * istrides[1] + grid_x;
+                    out[omId] = in[imId];
                 }
             }
         }
@@ -220,10 +216,10 @@ namespace cpu
                   const Tp *pos, const af::dim4 &pdims, const Tp *qos, const af::dim4 &qdims,
                   const af::dim4 &ostrides, const af::dim4 &istrides,
                   const af::dim4 &pstrides, const af::dim4 &qstrides,
-                  const float offGrid, const dim_t idx, const dim_t idy)
+                  const float offGrid, const dim_t idx, const dim_t idy, const dim_t idz)
         {
-            const dim_t pmId = idy * pstrides[1] + idx;
-            const dim_t qmId = idy * qstrides[1] + idx;
+            const dim_t pmId = (pdims[2] == 1 ? 0 : idz * pstrides[2]) + idy * pstrides[1] + idx;
+            const dim_t qmId = (qdims[2] == 1 ? 0 : idz * qstrides[2]) + idy * qstrides[1] + idx;
 
             bool gFlag = false;
             const Tp x = pos[pmId], y = qos[qmId];
@@ -248,27 +244,24 @@ namespace cpu
             Ty zero = scalar<Ty>(0);
 
             for(dim_t idw = 0; idw < odims[3]; idw++) {
-                for(dim_t idz = 0; idz < odims[2]; idz++) {
-                    const dim_t omId = idw * ostrides[3] + idz * ostrides[2]
-                                        + idy * ostrides[1] + idx;
-                    if(gFlag) {
-                        out[omId] = scalar<Ty>(offGrid);
-                    } else {
-                        dim_t ioff = idw * istrides[3] + idz * istrides[2]
-                                   + grid_y * istrides[1] + grid_x;
-
-                        // Compute Weighted Values
-                        Ty y00 =                    wt00 * in[ioff];
-                        Ty y10 = (condY) ?          wt10 * in[ioff + istrides[1]]     : zero;
-                        Ty y01 = (condX) ?          wt01 * in[ioff + 1]                   : zero;
-                        Ty y11 = (condX && condY) ? wt11 * in[ioff + istrides[1] + 1] : zero;
-
-                        Ty yo = y00 + y10 + y01 + y11;
-
-                        // Write Final Value
-                        out[omId] = (yo / wt);
-
-                    }
+                const dim_t omId = idw * ostrides[3] + idz * ostrides[2]
+                                 + idy * ostrides[1] + idx;
+                if(gFlag) {
+                    out[omId] = scalar<Ty>(offGrid);
+                } else {
+                    dim_t ioff = idw * istrides[3] + idz * istrides[2]
+                            + grid_y * istrides[1] + grid_x;
+
+                    // Compute Weighted Values
+                    Ty y00 =                    wt00 * in[ioff];
+                    Ty y10 = (condY) ?          wt10 * in[ioff + istrides[1]]     : zero;
+                    Ty y01 = (condX) ?          wt01 * in[ioff + 1]                   : zero;
+                    Ty y11 = (condX && condY) ? wt11 * in[ioff + istrides[1] + 1] : zero;
+
+                    Ty yo = y00 + y10 + y01 + y11;
+
+                    // Write Final Value
+                    out[omId] = (yo / wt);
                 }
             }
         }
@@ -283,10 +276,12 @@ namespace cpu
             const float offGrid)
     {
         approx2_op<Ty, Tp, method> op;
-        for(dim_t y = 0; y < odims[1]; y++) {
-            for(dim_t x = 0; x < odims[0]; x++) {
-                op(out, odims, oElems, in, idims, iElems, pos, pdims, qos, qdims,
-                    ostrides, istrides, pstrides, qstrides, offGrid, x, y);
+        for(dim_t z = 0; z < odims[2]; z++) {
+            for(dim_t y = 0; y < odims[1]; y++) {
+                for(dim_t x = 0; x < odims[0]; x++) {
+                    op(out, odims, oElems, in, idims, iElems, pos, pdims, qos, qdims,
+                            ostrides, istrides, pstrides, qstrides, offGrid, x, y, z);
+                }
             }
         }
     }
diff --git a/src/backend/cuda/kernel/approx.hpp b/src/backend/cuda/kernel/approx.hpp
index 6c9dd7d..ced6c4f 100644
--- a/src/backend/cuda/kernel/approx.hpp
+++ b/src/backend/cuda/kernel/approx.hpp
@@ -32,8 +32,8 @@ namespace cuda
                            const float offGrid)
         {
             const int omId = idw * out.strides[3] + idz * out.strides[2]
-                                + idy * out.strides[1] + idx;
-            const int pmId = idx;
+                           + idy * out.strides[1] + idx;
+            const int pmId = idx + (pos.dims[1] == 1 ? 0 : idy * pos.strides[1]);
 
             const Tp x = pos.ptr[pmId];
             if (x < 0 || in.dims[0] < x+1) {
@@ -55,9 +55,11 @@ namespace cuda
                            CParam<Tp> pos, CParam<Tp> qos, const float offGrid)
         {
             const int omId = idw * out.strides[3] + idz * out.strides[2]
-                                + idy * out.strides[1] + idx;
-            const int pmId = idy * pos.strides[1] + idx;
-            const int qmId = idy * qos.strides[1] + idx;
+                           + idy * out.strides[1] + idx;
+            const int pmId = (pos.dims[2] == 1 ? 0 : idz * pos.strides[2])
+                            + idy * pos.strides[1] + idx;
+            const int qmId = (qos.dims[2] == 1 ? 0 : idz * qos.strides[2])
+                            + idy * qos.strides[1] + idx;
 
             const Tp x = pos.ptr[pmId], y = qos.ptr[qmId];
             if (x < 0 || y < 0 || in.dims[0] < x+1 || in.dims[1] < y+1) {
@@ -67,7 +69,7 @@ namespace cuda
 
             const int grid_x = round(x), grid_y = round(y); // nearest grid
             const int imId = idw * in.strides[3] + idz * in.strides[2]
-                             + grid_y * in.strides[1] + grid_x;
+                        + grid_y * in.strides[1] + grid_x;
 
             Ty val = in.ptr[imId];
             out.ptr[omId] = val;
@@ -83,8 +85,8 @@ namespace cuda
                           const float offGrid)
         {
             const int omId = idw * out.strides[3] + idz * out.strides[2]
-                                + idy * out.strides[1] + idx;
-            const int pmId = idx;
+                           + idy * out.strides[1] + idx;
+            const int pmId = idx + (pos.dims[1] == 1 ? 0 : idy * pos.strides[1]);
 
             const Tp pVal = pos.ptr[pmId];
             if (pVal < 0 || in.dims[0] < pVal+1) {
@@ -116,9 +118,11 @@ namespace cuda
                            CParam<Tp> pos, CParam<Tp> qos, const float offGrid)
         {
             const int omId = idw * out.strides[3] + idz * out.strides[2]
-                                + idy * out.strides[1] + idx;
-            const int pmId = idy * pos.strides[1] + idx;
-            const int qmId = idy * qos.strides[1] + idx;
+                           + idy * out.strides[1] + idx;
+            const int pmId = (pos.dims[2] == 1 ? 0 : idz * pos.strides[2])
+                           + idy * pos.strides[1] + idx;
+            const int qmId = (qos.dims[2] == 1 ? 0 : idz * qos.strides[2])
+                           + idy * qos.strides[1] + idx;
 
             const Tp x = pos.ptr[pmId], y = qos.ptr[qmId];
             if (x < 0 || y < 0 || in.dims[0] < x+1 || in.dims[1] < y+1) {
diff --git a/src/backend/opencl/kernel/approx1.cl b/src/backend/opencl/kernel/approx1.cl
index 3531e2f..08a6677 100644
--- a/src/backend/opencl/kernel/approx1.cl
+++ b/src/backend/opencl/kernel/approx1.cl
@@ -40,7 +40,7 @@ void core_nearest1(const int idx, const int idy, const int idz, const int idw,
 {
     const int omId = idw * out.strides[3] + idz * out.strides[2]
                    + idy * out.strides[1] + idx;
-    const int pmId = idx;
+    const int pmId = idx + (pos.dims[1] == 1 ? 0 : idy * pos.strides[1]);
 
     const Tp pVal = d_pos[pmId];
     if (pVal < 0 || in.dims[0] < pVal+1) {
@@ -66,8 +66,8 @@ void core_linear1(const int idx, const int idy, const int idz, const int idw,
                    const float offGrid)
 {
     const int omId = idw * out.strides[3] + idz * out.strides[2]
-                        + idy * out.strides[1] + idx;
-    const int pmId = idx;
+                   + idy * out.strides[1] + idx;
+    const int pmId = idx + (pos.dims[1] == 1 ? 0 : idy * pos.strides[1]);
 
     const Tp pVal = d_pos[pmId];
     if (pVal < 0 || in.dims[0] < pVal+1) {
diff --git a/src/backend/opencl/kernel/approx2.cl b/src/backend/opencl/kernel/approx2.cl
index c540e1b..b6ba02a 100644
--- a/src/backend/opencl/kernel/approx2.cl
+++ b/src/backend/opencl/kernel/approx2.cl
@@ -40,9 +40,11 @@ void core_nearest2(const int idx, const int idy, const int idz, const int idw,
                    const float offGrid)
 {
     const int omId = idw * out.strides[3] + idz * out.strides[2]
-                        + idy * out.strides[1] + idx;
-    const int pmId = idy * pos.strides[1] + idx;
-    const int qmId = idy * qos.strides[1] + idx;
+                   + idy * out.strides[1] + idx;
+    const int pmId = (pos.dims[2] == 1 ? 0 : idz * pos.strides[2])
+                    + idy * pos.strides[1] + idx;
+    const int qmId = (qos.dims[2] == 1 ? 0 : idz * qos.strides[2])
+                    + idy * qos.strides[1] + idx;
 
     const Tp x = d_pos[pmId], y = d_qos[qmId];
     if (x < 0 || y < 0 || in.dims[0] < x+1 || in.dims[1] < y+1) {
@@ -71,8 +73,10 @@ void core_linear2(const int idx, const int idy, const int idz, const int idw,
 {
     const int omId = idw * out.strides[3] + idz * out.strides[2]
                         + idy * out.strides[1] + idx;
-    const int pmId = idy * pos.strides[1] + idx;
-    const int qmId = idy * qos.strides[1] + idx;
+    const int pmId = (pos.dims[2] == 1 ? 0 : idz * pos.strides[2])
+                    + idy * pos.strides[1] + idx;
+    const int qmId = (qos.dims[2] == 1 ? 0 : idz * qos.strides[2])
+                    + idy * qos.strides[1] + idx;
 
     const Tp x = d_pos[pmId], y = d_qos[qmId];
     if (x < 0 || y < 0 || in.dims[0] < x+1 || in.dims[1] < y+1) {
diff --git a/test/approx1.cpp b/test/approx1.cpp
index ad6eb3a..e0350d8 100644
--- a/test/approx1.cpp
+++ b/test/approx1.cpp
@@ -233,3 +233,49 @@ TEST(Approx1, CPP)
 
 #undef BT
 }
+
+TEST(Approx1, CPPNearestBatch)
+{
+    if (noDoubleTests<float>()) return;
+
+    af::array input = af::randu(600, 10);
+    af::array pos   = af::randu(100, 10);
+
+    af::array outBatch = af::approx1(input, pos, AF_INTERP_NEAREST);
+
+    af::array outSerial(pos.dims());
+    for(int i = 0; i < pos.dims()[1]; i++) {
+        outSerial(af::span, i) = af::approx1(input(af::span, i), pos(af::span, i), AF_INTERP_NEAREST);
+    }
+
+    //af::array outGFOR(pos.dims());
+    //gfor(af::seq i, 10) {
+    //    outGFOR(af::span, i) = af::approx1(input(af::span, i), pos(af::span, i), AF_INTERP_NEAREST);
+    //}
+
+    ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outSerial)), 1e-3);
+    //ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outGFOR)), 1e-3);
+}
+
+TEST(Approx1, CPPLinearBatch)
+{
+    if (noDoubleTests<float>()) return;
+
+    af::array input = af::iota(af::dim4(10, 10));
+    af::array pos   = af::randu(10, 10);
+
+    af::array outBatch = af::approx1(input, pos, AF_INTERP_LINEAR);
+
+    af::array outSerial(pos.dims());
+    for(int i = 0; i < pos.dims()[1]; i++) {
+        outSerial(af::span, i) = af::approx1(input(af::span, i), pos(af::span, i), AF_INTERP_LINEAR);
+    }
+
+    //af::array outGFOR(pos.dims());
+    //gfor(af::seq i, 10) {
+    //    outGFOR(af::span, i) = af::approx1(input(af::span, i), pos(af::span, i), AF_INTERP_LINEAR);
+    //}
+
+    ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outSerial)), 1e-3);
+    //ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outGFOR)), 1e-3);
+}
diff --git a/test/approx2.cpp b/test/approx2.cpp
index 9c748e2..fc2c87f 100644
--- a/test/approx2.cpp
+++ b/test/approx2.cpp
@@ -248,3 +248,55 @@ TEST(Approx2, CPP)
 
 #undef BT
 }
+
+TEST(Approx2, CPPNearestBatch)
+{
+    if (noDoubleTests<float>()) return;
+
+    af::array input = af::randu(200, 100, 10);
+    af::array pos   = af::randu(100, 100, 10);
+    af::array qos   = af::randu(100, 100, 10);
+
+    af::array outBatch = af::approx2(input, pos, qos, AF_INTERP_NEAREST);
+
+    af::array outSerial(pos.dims());
+    for(int i = 0; i < pos.dims()[2]; i++) {
+        outSerial(af::span, af::span, i) = af::approx2(input(af::span, af::span, i),
+            pos(af::span, af::span, i), qos(af::span, af::span, i), AF_INTERP_NEAREST);
+    }
+
+    //af::array outGFOR(pos.dims());
+    //gfor(af::seq i, 10) {
+    //    outGFOR(af::span, af::span, i) = af::approx2(input(af::span, af::span, i),
+    //        pos(af::span, af::span, i), qos(af::span, af::span, i), AF_INTERP_NEAREST);
+    //}
+
+    ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outSerial)), 1e-3);
+    //ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outGFOR)), 1e-3);
+}
+
+TEST(Approx2, CPPLinearBatch)
+{
+    if (noDoubleTests<float>()) return;
+
+    af::array input = af::randu(200, 100, 10);
+    af::array pos   = af::randu(100, 100, 10);
+    af::array qos   = af::randu(100, 100, 10);
+
+    af::array outBatch = af::approx2(input, pos, qos, AF_INTERP_LINEAR);
+
+    af::array outSerial(pos.dims());
+    for(int i = 0; i < pos.dims()[2]; i++) {
+        outSerial(af::span, af::span, i) = af::approx2(input(af::span, af::span, i),
+            pos(af::span, af::span, i), qos(af::span, af::span, i), AF_INTERP_LINEAR);
+    }
+
+    //af::array outGFOR(pos.dims());
+    //gfor(af::seq i, 10) {
+    //    outGFOR(af::span, af::span, i) = af::approx2(input(af::span, af::span, i),
+    //        pos(af::span, af::span, i), qos(af::span, af::span, i), AF_INTERP_LINEAR);
+    //}
+
+    ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outSerial)), 1e-3);
+    //ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outGFOR)), 1e-3);
+}

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/arrayfire.git



More information about the debian-science-commits mailing list