[arrayfire] 26/248: FEAT Added batch support for approx1 and approx2
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Tue Nov 17 15:53:50 UTC 2015
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch dfsg-clean
in repository arrayfire.
commit 8b94ac1e2162a24cc35ae7009abd9ae4fe86d354
Author: Shehzan Mohammed <shehzan at arrayfire.com>
Date: Mon Aug 31 15:45:21 2015 -0400
FEAT Added batch support for approx1 and approx2
* Added tests
* TODO Enable tests with gfor
---
src/api/c/approx.cpp | 8 +-
src/backend/cpu/approx.cpp | 163 +++++++++++++++++------------------
src/backend/cuda/kernel/approx.hpp | 26 +++---
src/backend/opencl/kernel/approx1.cl | 6 +-
src/backend/opencl/kernel/approx2.cl | 14 +--
test/approx1.cpp | 46 ++++++++++
test/approx2.cpp | 52 +++++++++++
7 files changed, 210 insertions(+), 105 deletions(-)
diff --git a/src/api/c/approx.cpp b/src/api/c/approx.cpp
index 1bc7723..c0bb02c 100644
--- a/src/api/c/approx.cpp
+++ b/src/api/c/approx.cpp
@@ -41,13 +41,16 @@ af_err af_approx1(af_array *out, const af_array in, const af_array pos,
ArrayInfo i_info = getInfo(in);
ArrayInfo p_info = getInfo(pos);
+ dim4 idims = i_info.dims();
+ dim4 pdims = p_info.dims();
+
af_dtype itype = i_info.getType();
ARG_ASSERT(1, i_info.isFloating()); // Only floating and complex types
ARG_ASSERT(2, p_info.isRealFloating()); // Only floating types
ARG_ASSERT(1, i_info.isSingle() == p_info.isSingle()); // Must have same precision
ARG_ASSERT(1, i_info.isDouble() == p_info.isDouble()); // Must have same precision
- DIM_ASSERT(2, p_info.isColumn()); // Only 1D input allowed
+ DIM_ASSERT(2, p_info.isColumn() || pdims[1] == idims[1]); // Only 1D input allowed or Same no. of cols
ARG_ASSERT(3, (method == AF_INTERP_LINEAR || method == AF_INTERP_NEAREST));
af_array output;
@@ -83,7 +86,8 @@ af_err af_approx2(af_array *out, const af_array in, const af_array pos0, const a
ARG_ASSERT(1, i_info.isSingle() == p_info.isSingle()); // Must have same precision
ARG_ASSERT(1, i_info.isDouble() == p_info.isDouble()); // Must have same precision
DIM_ASSERT(2, p_info.dims() == q_info.dims()); // POS0 and POS1 must have same dims
- DIM_ASSERT(2, p_info.ndims() < 3);// Allowing input batch but not positions. Output dims = (px, py, iz, iw)
+ DIM_ASSERT(2, p_info.dims()[2] == 1
+ || p_info.dims()[2] == i_info.dims()[2]); // Allowing input batch. Output dims = (px, py, iz, iw)
ARG_ASSERT(3, (method == AF_INTERP_LINEAR || method == AF_INTERP_NEAREST));
af_array output;
diff --git a/src/backend/cpu/approx.cpp b/src/backend/cpu/approx.cpp
index 69b943a..1522341 100644
--- a/src/backend/cpu/approx.cpp
+++ b/src/backend/cpu/approx.cpp
@@ -25,7 +25,7 @@ namespace cpu
const Ty *in, const af::dim4 &idims, const dim_t iElems,
const Tp *pos, const af::dim4 &pdims,
const af::dim4 &ostrides, const af::dim4 &istrides, const af::dim4 &pstrides,
- const float offGrid, const dim_t idx)
+ const float offGrid, const dim_t idx, const dim_t idy)
{
return;
}
@@ -38,30 +38,28 @@ namespace cpu
const Ty *in, const af::dim4 &idims, const dim_t iElems,
const Tp *pos, const af::dim4 &pdims,
const af::dim4 &ostrides, const af::dim4 &istrides, const af::dim4 &pstrides,
- const float offGrid, const dim_t idx)
+ const float offGrid, const dim_t idx, const dim_t idy)
{
- const dim_t pmId = idx;
+ const dim_t pmId = idx + (pdims[1] == 1 ? 0 : idy * pstrides[1]);
const Tp x = pos[pmId];
bool gFlag = false;
- if (x < 0 || idims[0] < x+1) {
+ if (x < 0 || idims[0] < x+1) { // No need to check y
gFlag = true;
}
for(dim_t idw = 0; idw < odims[3]; idw++) {
for(dim_t idz = 0; idz < odims[2]; idz++) {
- for(dim_t idy = 0; idy < odims[1]; idy++) {
- const dim_t omId = idw * ostrides[3] + idz * ostrides[2]
- + idy * ostrides[1] + idx;
- if(gFlag) {
- out[omId] = scalar<Ty>(offGrid);
- } else {
- dim_t ioff = idw * istrides[3] + idz * istrides[2]
- + idy * istrides[1];
- const dim_t iMem = round(x) + ioff;
-
- out[omId] = in[iMem];
- }
+ const dim_t omId = idw * ostrides[3] + idz * ostrides[2]
+ + idy * ostrides[1] + idx;
+ if(gFlag) {
+ out[omId] = scalar<Ty>(offGrid);
+ } else {
+ dim_t ioff = idw * istrides[3] + idz * istrides[2]
+ + idy * istrides[1];
+ const dim_t iMem = round(x) + ioff;
+
+ out[omId] = in[iMem];
}
}
}
@@ -75,9 +73,9 @@ namespace cpu
const Ty *in, const af::dim4 &idims, const dim_t iElems,
const Tp *pos, const af::dim4 &pdims,
const af::dim4 &ostrides, const af::dim4 &istrides, const af::dim4 &pstrides,
- const float offGrid, const dim_t idx)
+ const float offGrid, const dim_t idx, const dim_t idy)
{
- const dim_t pmId = idx;
+ const dim_t pmId = idx + (pdims[1] == 1 ? 0 : idy * pstrides[1]);
const Tp x = pos[pmId];
bool gFlag = false;
@@ -90,25 +88,23 @@ namespace cpu
for(dim_t idw = 0; idw < odims[3]; idw++) {
for(dim_t idz = 0; idz < odims[2]; idz++) {
- for(dim_t idy = 0; idy < odims[1]; idy++) {
- const dim_t omId = idw * ostrides[3] + idz * ostrides[2]
- + idy * ostrides[1] + idx;
- if(gFlag) {
- out[omId] = scalar<Ty>(offGrid);
- } else {
- dim_t ioff = idw * istrides[3] + idz * istrides[2] + idy * istrides[1] + grid_x;
-
- // Check if x and x + 1 are both valid indices
- bool cond = (x < idims[0] - 1);
- // Compute Left and Right Weighted Values
- Ty yl = ((Tp)1.0 - off_x) * in[ioff];
- Ty yr = cond ? (off_x) * in[ioff + 1] : scalar<Ty>(0);
- Ty yo = yl + yr;
- // Compute Weight used
- Tp wt = cond ? (Tp)1.0 : (Tp)(1.0 - off_x);
- // Write final value
- out[omId] = (yo / wt);
- }
+ const dim_t omId = idw * ostrides[3] + idz * ostrides[2]
+ + idy * ostrides[1] + idx;
+ if(gFlag) {
+ out[omId] = scalar<Ty>(offGrid);
+ } else {
+ dim_t ioff = idw * istrides[3] + idz * istrides[2] + idy * istrides[1] + grid_x;
+
+ // Check if x and x + 1 are both valid indices
+ bool cond = (x < idims[0] - 1);
+ // Compute Left and Right Weighted Values
+ Ty yl = ((Tp)1.0 - off_x) * in[ioff];
+ Ty yr = cond ? (off_x) * in[ioff + 1] : scalar<Ty>(0);
+ Ty yo = yl + yr;
+ // Compute Weight used
+ Tp wt = cond ? (Tp)1.0 : (Tp)(1.0 - off_x);
+ // Write final value
+ out[omId] = (yo / wt);
}
}
}
@@ -123,9 +119,11 @@ namespace cpu
const float offGrid)
{
approx1_op<Ty, Tp, method> op;
- for(dim_t x = 0; x < odims[0]; x++) {
- op(out, odims, oElems, in, idims, iElems, pos, pdims,
- ostrides, istrides, pstrides, offGrid, x);
+ for(dim_t y = 0; y < odims[1]; y++) {
+ for(dim_t x = 0; x < odims[0]; x++) {
+ op(out, odims, oElems, in, idims, iElems, pos, pdims,
+ ostrides, istrides, pstrides, offGrid, x, y);
+ }
}
}
@@ -169,7 +167,7 @@ namespace cpu
const Tp *pos, const af::dim4 &pdims, const Tp *qos, const af::dim4 &qdims,
const af::dim4 &ostrides, const af::dim4 &istrides,
const af::dim4 &pstrides, const af::dim4 &qstrides,
- const float offGrid, const dim_t idx, const dim_t idy)
+ const float offGrid, const dim_t idx, const dim_t idy, const dim_t idz)
{
return;
}
@@ -183,10 +181,10 @@ namespace cpu
const Tp *pos, const af::dim4 &pdims, const Tp *qos, const af::dim4 &qdims,
const af::dim4 &ostrides, const af::dim4 &istrides,
const af::dim4 &pstrides, const af::dim4 &qstrides,
- const float offGrid, const dim_t idx, const dim_t idy)
+ const float offGrid, const dim_t idx, const dim_t idy, const dim_t idz)
{
- const dim_t pmId = idy * pstrides[1] + idx;
- const dim_t qmId = idy * qstrides[1] + idx;
+ const dim_t pmId = (pdims[2] == 1 ? 0 : idz * pstrides[2]) + idy * pstrides[1] + idx;
+ const dim_t qmId = (qdims[2] == 1 ? 0 : idz * qstrides[2]) + idy * qstrides[1] + idx;
bool gFlag = false;
const Tp x = pos[pmId], y = qos[qmId];
@@ -195,18 +193,16 @@ namespace cpu
}
for(dim_t idw = 0; idw < odims[3]; idw++) {
- for(dim_t idz = 0; idz < odims[2]; idz++) {
- const dim_t omId = idw * ostrides[3] + idz * ostrides[2]
- + idy * ostrides[1] + idx;
- if(gFlag) {
- out[omId] = scalar<Ty>(offGrid);
- } else {
- const dim_t grid_x = round(x), grid_y = round(y); // nearest grid
- const dim_t imId = idw * istrides[3] +
- idz * istrides[2] +
- grid_y * istrides[1] + grid_x;
- out[omId] = in[imId];
- }
+ const dim_t omId = idw * ostrides[3] + idz * ostrides[2]
+ + idy * ostrides[1] + idx;
+ if(gFlag) {
+ out[omId] = scalar<Ty>(offGrid);
+ } else {
+ const dim_t grid_x = round(x), grid_y = round(y); // nearest grid
+ const dim_t imId = idw * istrides[3] +
+ idz * istrides[2] +
+ grid_y * istrides[1] + grid_x;
+ out[omId] = in[imId];
}
}
}
@@ -220,10 +216,10 @@ namespace cpu
const Tp *pos, const af::dim4 &pdims, const Tp *qos, const af::dim4 &qdims,
const af::dim4 &ostrides, const af::dim4 &istrides,
const af::dim4 &pstrides, const af::dim4 &qstrides,
- const float offGrid, const dim_t idx, const dim_t idy)
+ const float offGrid, const dim_t idx, const dim_t idy, const dim_t idz)
{
- const dim_t pmId = idy * pstrides[1] + idx;
- const dim_t qmId = idy * qstrides[1] + idx;
+ const dim_t pmId = (pdims[2] == 1 ? 0 : idz * pstrides[2]) + idy * pstrides[1] + idx;
+ const dim_t qmId = (qdims[2] == 1 ? 0 : idz * qstrides[2]) + idy * qstrides[1] + idx;
bool gFlag = false;
const Tp x = pos[pmId], y = qos[qmId];
@@ -248,27 +244,24 @@ namespace cpu
Ty zero = scalar<Ty>(0);
for(dim_t idw = 0; idw < odims[3]; idw++) {
- for(dim_t idz = 0; idz < odims[2]; idz++) {
- const dim_t omId = idw * ostrides[3] + idz * ostrides[2]
- + idy * ostrides[1] + idx;
- if(gFlag) {
- out[omId] = scalar<Ty>(offGrid);
- } else {
- dim_t ioff = idw * istrides[3] + idz * istrides[2]
- + grid_y * istrides[1] + grid_x;
-
- // Compute Weighted Values
- Ty y00 = wt00 * in[ioff];
- Ty y10 = (condY) ? wt10 * in[ioff + istrides[1]] : zero;
- Ty y01 = (condX) ? wt01 * in[ioff + 1] : zero;
- Ty y11 = (condX && condY) ? wt11 * in[ioff + istrides[1] + 1] : zero;
-
- Ty yo = y00 + y10 + y01 + y11;
-
- // Write Final Value
- out[omId] = (yo / wt);
-
- }
+ const dim_t omId = idw * ostrides[3] + idz * ostrides[2]
+ + idy * ostrides[1] + idx;
+ if(gFlag) {
+ out[omId] = scalar<Ty>(offGrid);
+ } else {
+ dim_t ioff = idw * istrides[3] + idz * istrides[2]
+ + grid_y * istrides[1] + grid_x;
+
+ // Compute Weighted Values
+ Ty y00 = wt00 * in[ioff];
+ Ty y10 = (condY) ? wt10 * in[ioff + istrides[1]] : zero;
+ Ty y01 = (condX) ? wt01 * in[ioff + 1] : zero;
+ Ty y11 = (condX && condY) ? wt11 * in[ioff + istrides[1] + 1] : zero;
+
+ Ty yo = y00 + y10 + y01 + y11;
+
+ // Write Final Value
+ out[omId] = (yo / wt);
}
}
}
@@ -283,10 +276,12 @@ namespace cpu
const float offGrid)
{
approx2_op<Ty, Tp, method> op;
- for(dim_t y = 0; y < odims[1]; y++) {
- for(dim_t x = 0; x < odims[0]; x++) {
- op(out, odims, oElems, in, idims, iElems, pos, pdims, qos, qdims,
- ostrides, istrides, pstrides, qstrides, offGrid, x, y);
+ for(dim_t z = 0; z < odims[2]; z++) {
+ for(dim_t y = 0; y < odims[1]; y++) {
+ for(dim_t x = 0; x < odims[0]; x++) {
+ op(out, odims, oElems, in, idims, iElems, pos, pdims, qos, qdims,
+ ostrides, istrides, pstrides, qstrides, offGrid, x, y, z);
+ }
}
}
}
diff --git a/src/backend/cuda/kernel/approx.hpp b/src/backend/cuda/kernel/approx.hpp
index 6c9dd7d..ced6c4f 100644
--- a/src/backend/cuda/kernel/approx.hpp
+++ b/src/backend/cuda/kernel/approx.hpp
@@ -32,8 +32,8 @@ namespace cuda
const float offGrid)
{
const int omId = idw * out.strides[3] + idz * out.strides[2]
- + idy * out.strides[1] + idx;
- const int pmId = idx;
+ + idy * out.strides[1] + idx;
+ const int pmId = idx + (pos.dims[1] == 1 ? 0 : idy * pos.strides[1]);
const Tp x = pos.ptr[pmId];
if (x < 0 || in.dims[0] < x+1) {
@@ -55,9 +55,11 @@ namespace cuda
CParam<Tp> pos, CParam<Tp> qos, const float offGrid)
{
const int omId = idw * out.strides[3] + idz * out.strides[2]
- + idy * out.strides[1] + idx;
- const int pmId = idy * pos.strides[1] + idx;
- const int qmId = idy * qos.strides[1] + idx;
+ + idy * out.strides[1] + idx;
+ const int pmId = (pos.dims[2] == 1 ? 0 : idz * pos.strides[2])
+ + idy * pos.strides[1] + idx;
+ const int qmId = (qos.dims[2] == 1 ? 0 : idz * qos.strides[2])
+ + idy * qos.strides[1] + idx;
const Tp x = pos.ptr[pmId], y = qos.ptr[qmId];
if (x < 0 || y < 0 || in.dims[0] < x+1 || in.dims[1] < y+1) {
@@ -67,7 +69,7 @@ namespace cuda
const int grid_x = round(x), grid_y = round(y); // nearest grid
const int imId = idw * in.strides[3] + idz * in.strides[2]
- + grid_y * in.strides[1] + grid_x;
+ + grid_y * in.strides[1] + grid_x;
Ty val = in.ptr[imId];
out.ptr[omId] = val;
@@ -83,8 +85,8 @@ namespace cuda
const float offGrid)
{
const int omId = idw * out.strides[3] + idz * out.strides[2]
- + idy * out.strides[1] + idx;
- const int pmId = idx;
+ + idy * out.strides[1] + idx;
+ const int pmId = idx + (pos.dims[1] == 1 ? 0 : idy * pos.strides[1]);
const Tp pVal = pos.ptr[pmId];
if (pVal < 0 || in.dims[0] < pVal+1) {
@@ -116,9 +118,11 @@ namespace cuda
CParam<Tp> pos, CParam<Tp> qos, const float offGrid)
{
const int omId = idw * out.strides[3] + idz * out.strides[2]
- + idy * out.strides[1] + idx;
- const int pmId = idy * pos.strides[1] + idx;
- const int qmId = idy * qos.strides[1] + idx;
+ + idy * out.strides[1] + idx;
+ const int pmId = (pos.dims[2] == 1 ? 0 : idz * pos.strides[2])
+ + idy * pos.strides[1] + idx;
+ const int qmId = (qos.dims[2] == 1 ? 0 : idz * qos.strides[2])
+ + idy * qos.strides[1] + idx;
const Tp x = pos.ptr[pmId], y = qos.ptr[qmId];
if (x < 0 || y < 0 || in.dims[0] < x+1 || in.dims[1] < y+1) {
diff --git a/src/backend/opencl/kernel/approx1.cl b/src/backend/opencl/kernel/approx1.cl
index 3531e2f..08a6677 100644
--- a/src/backend/opencl/kernel/approx1.cl
+++ b/src/backend/opencl/kernel/approx1.cl
@@ -40,7 +40,7 @@ void core_nearest1(const int idx, const int idy, const int idz, const int idw,
{
const int omId = idw * out.strides[3] + idz * out.strides[2]
+ idy * out.strides[1] + idx;
- const int pmId = idx;
+ const int pmId = idx + (pos.dims[1] == 1 ? 0 : idy * pos.strides[1]);
const Tp pVal = d_pos[pmId];
if (pVal < 0 || in.dims[0] < pVal+1) {
@@ -66,8 +66,8 @@ void core_linear1(const int idx, const int idy, const int idz, const int idw,
const float offGrid)
{
const int omId = idw * out.strides[3] + idz * out.strides[2]
- + idy * out.strides[1] + idx;
- const int pmId = idx;
+ + idy * out.strides[1] + idx;
+ const int pmId = idx + (pos.dims[1] == 1 ? 0 : idy * pos.strides[1]);
const Tp pVal = d_pos[pmId];
if (pVal < 0 || in.dims[0] < pVal+1) {
diff --git a/src/backend/opencl/kernel/approx2.cl b/src/backend/opencl/kernel/approx2.cl
index c540e1b..b6ba02a 100644
--- a/src/backend/opencl/kernel/approx2.cl
+++ b/src/backend/opencl/kernel/approx2.cl
@@ -40,9 +40,11 @@ void core_nearest2(const int idx, const int idy, const int idz, const int idw,
const float offGrid)
{
const int omId = idw * out.strides[3] + idz * out.strides[2]
- + idy * out.strides[1] + idx;
- const int pmId = idy * pos.strides[1] + idx;
- const int qmId = idy * qos.strides[1] + idx;
+ + idy * out.strides[1] + idx;
+ const int pmId = (pos.dims[2] == 1 ? 0 : idz * pos.strides[2])
+ + idy * pos.strides[1] + idx;
+ const int qmId = (qos.dims[2] == 1 ? 0 : idz * qos.strides[2])
+ + idy * qos.strides[1] + idx;
const Tp x = d_pos[pmId], y = d_qos[qmId];
if (x < 0 || y < 0 || in.dims[0] < x+1 || in.dims[1] < y+1) {
@@ -71,8 +73,10 @@ void core_linear2(const int idx, const int idy, const int idz, const int idw,
{
const int omId = idw * out.strides[3] + idz * out.strides[2]
+ idy * out.strides[1] + idx;
- const int pmId = idy * pos.strides[1] + idx;
- const int qmId = idy * qos.strides[1] + idx;
+ const int pmId = (pos.dims[2] == 1 ? 0 : idz * pos.strides[2])
+ + idy * pos.strides[1] + idx;
+ const int qmId = (qos.dims[2] == 1 ? 0 : idz * qos.strides[2])
+ + idy * qos.strides[1] + idx;
const Tp x = d_pos[pmId], y = d_qos[qmId];
if (x < 0 || y < 0 || in.dims[0] < x+1 || in.dims[1] < y+1) {
diff --git a/test/approx1.cpp b/test/approx1.cpp
index ad6eb3a..e0350d8 100644
--- a/test/approx1.cpp
+++ b/test/approx1.cpp
@@ -233,3 +233,49 @@ TEST(Approx1, CPP)
#undef BT
}
+
+TEST(Approx1, CPPNearestBatch) // Batched nearest-neighbour approx1 must match column-by-column calls.
+{
+ if (noDoubleTests<float>()) return; // helper defined outside this patch; presumably a type-support guard -- TODO confirm
+ // Inputs: a 600x10 signal and a 100x10 position array, i.e. one position column per input column.
+ af::array input = af::randu(600, 10);
+ af::array pos = af::randu(100, 10); // NOTE(review): randu positions presumably lie in [0, 1], so only the first input samples are hit -- confirm coverage is sufficient
+ // Batched path: a single approx1 call over all columns.
+ af::array outBatch = af::approx1(input, pos, AF_INTERP_NEAREST);
+ // Reference path: interpolate each column on its own and write it into the matching output column.
+ af::array outSerial(pos.dims());
+ for(int i = 0; i < pos.dims()[1]; i++) {
+ outSerial(af::span, i) = af::approx1(input(af::span, i), pos(af::span, i), AF_INTERP_NEAREST);
+ }
+ // Disabled gfor variant of the per-column reference; the commit message lists enabling it as a TODO.
+ //af::array outGFOR(pos.dims());
+ //gfor(af::seq i, 10) {
+ // outGFOR(af::span, i) = af::approx1(input(af::span, i), pos(af::span, i), AF_INTERP_NEAREST);
+ //}
+ // The summed absolute difference between batched and serial outputs must be within 1e-3 of zero.
+ ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outSerial)), 1e-3);
+ //ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outGFOR)), 1e-3);
+}
+
+TEST(Approx1, CPPLinearBatch) // Batched linear approx1 must match column-by-column calls.
+{
+ if (noDoubleTests<float>()) return; // helper defined outside this patch; presumably a type-support guard -- TODO confirm
+ // Inputs: a 10x10 iota-filled signal and a 10x10 position array, i.e. one position column per input column.
+ af::array input = af::iota(af::dim4(10, 10));
+ af::array pos = af::randu(10, 10); // NOTE(review): randu positions presumably lie in [0, 1], so only the first interval of each column is interpolated -- confirm coverage is sufficient
+ // Batched path: a single approx1 call over all columns.
+ af::array outBatch = af::approx1(input, pos, AF_INTERP_LINEAR);
+ // Reference path: interpolate each column on its own and write it into the matching output column.
+ af::array outSerial(pos.dims());
+ for(int i = 0; i < pos.dims()[1]; i++) {
+ outSerial(af::span, i) = af::approx1(input(af::span, i), pos(af::span, i), AF_INTERP_LINEAR);
+ }
+ // Disabled gfor variant of the per-column reference; the commit message lists enabling it as a TODO.
+ //af::array outGFOR(pos.dims());
+ //gfor(af::seq i, 10) {
+ // outGFOR(af::span, i) = af::approx1(input(af::span, i), pos(af::span, i), AF_INTERP_LINEAR);
+ //}
+ // The summed absolute difference between batched and serial outputs must be within 1e-3 of zero.
+ ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outSerial)), 1e-3);
+ //ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outGFOR)), 1e-3);
+}
diff --git a/test/approx2.cpp b/test/approx2.cpp
index 9c748e2..fc2c87f 100644
--- a/test/approx2.cpp
+++ b/test/approx2.cpp
@@ -248,3 +248,55 @@ TEST(Approx2, CPP)
#undef BT
}
+
+TEST(Approx2, CPPNearestBatch) // Batched nearest-neighbour approx2 must match slice-by-slice calls.
+{
+ if (noDoubleTests<float>()) return; // helper defined outside this patch; presumably a type-support guard -- TODO confirm
+ // Inputs: a 200x100x10 volume and two 100x100x10 position arrays, i.e. one position slice per input slice.
+ af::array input = af::randu(200, 100, 10);
+ af::array pos = af::randu(100, 100, 10); // NOTE(review): randu positions presumably lie in [0, 1] -- confirm coverage is sufficient
+ af::array qos = af::randu(100, 100, 10);
+ // Batched path: a single approx2 call over all slices along the third dimension.
+ af::array outBatch = af::approx2(input, pos, qos, AF_INTERP_NEAREST);
+ // Reference path: interpolate each third-dimension slice on its own and write it into the matching output slice.
+ af::array outSerial(pos.dims());
+ for(int i = 0; i < pos.dims()[2]; i++) {
+ outSerial(af::span, af::span, i) = af::approx2(input(af::span, af::span, i),
+ pos(af::span, af::span, i), qos(af::span, af::span, i), AF_INTERP_NEAREST);
+ }
+ // Disabled gfor variant of the per-slice reference; the commit message lists enabling it as a TODO.
+ //af::array outGFOR(pos.dims());
+ //gfor(af::seq i, 10) {
+ // outGFOR(af::span, af::span, i) = af::approx2(input(af::span, af::span, i),
+ // pos(af::span, af::span, i), qos(af::span, af::span, i), AF_INTERP_NEAREST);
+ //}
+ // The summed absolute difference between batched and serial outputs must be within 1e-3 of zero.
+ ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outSerial)), 1e-3);
+ //ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outGFOR)), 1e-3);
+}
+
+TEST(Approx2, CPPLinearBatch) // Batched linear approx2 must match slice-by-slice calls.
+{
+ if (noDoubleTests<float>()) return; // helper defined outside this patch; presumably a type-support guard -- TODO confirm
+ // Inputs: a 200x100x10 volume and two 100x100x10 position arrays, i.e. one position slice per input slice.
+ af::array input = af::randu(200, 100, 10);
+ af::array pos = af::randu(100, 100, 10); // NOTE(review): randu positions presumably lie in [0, 1] -- confirm coverage is sufficient
+ af::array qos = af::randu(100, 100, 10);
+ // Batched path: a single approx2 call over all slices along the third dimension.
+ af::array outBatch = af::approx2(input, pos, qos, AF_INTERP_LINEAR);
+ // Reference path: interpolate each third-dimension slice on its own and write it into the matching output slice.
+ af::array outSerial(pos.dims());
+ for(int i = 0; i < pos.dims()[2]; i++) {
+ outSerial(af::span, af::span, i) = af::approx2(input(af::span, af::span, i),
+ pos(af::span, af::span, i), qos(af::span, af::span, i), AF_INTERP_LINEAR);
+ }
+ // Disabled gfor variant of the per-slice reference; the commit message lists enabling it as a TODO.
+ //af::array outGFOR(pos.dims());
+ //gfor(af::seq i, 10) {
+ // outGFOR(af::span, af::span, i) = af::approx2(input(af::span, af::span, i),
+ // pos(af::span, af::span, i), qos(af::span, af::span, i), AF_INTERP_LINEAR);
+ //}
+ // The summed absolute difference between batched and serial outputs must be within 1e-3 of zero.
+ ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outSerial)), 1e-3);
+ //ASSERT_NEAR(0, af::sum<double>(af::abs(outBatch - outGFOR)), 1e-3);
+}
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/arrayfire.git
More information about the debian-science-commits
mailing list