[python-arrayfire] 128/250: Adding support for replacing nan values for reductions
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Mon Mar 28 22:59:40 UTC 2016
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch debian/master
in repository python-arrayfire.
commit 3a64cac6e532de137b311475c9dcc1b604112fc4
Author: Pavan Yalamanchili <pavan at arrayfire.com>
Date: Tue Nov 10 08:20:55 2015 -0500
Adding support for replacing nan values for reductions
---
arrayfire/algorithm.py | 49 +++++++++++++++++++++++++++++++++++++++--------
tests/simple/algorithm.py | 7 +++++++
2 files changed, 48 insertions(+), 8 deletions(-)
diff --git a/arrayfire/algorithm.py b/arrayfire/algorithm.py
index ceeaa48..8236637 100644
--- a/arrayfire/algorithm.py
+++ b/arrayfire/algorithm.py
@@ -22,12 +22,29 @@ def _parallel_dim(a, dim, c_func):
def _reduce_all(a, c_func):
real = ct.c_double(0)
imag = ct.c_double(0)
+
safe_call(c_func(ct.pointer(real), ct.pointer(imag), a.arr))
+
real = real.value
imag = imag.value
return real if imag == 0 else real + imag * 1j
-def sum(a, dim=None):
+def _nan_parallel_dim(a, dim, c_func, nan_val):
+ out = Array()
+ safe_call(c_func(ct.pointer(out.arr), a.arr, ct.c_int(dim), ct.c_double(nan_val)))
+ return out
+
+def _nan_reduce_all(a, c_func, nan_val):
+ real = ct.c_double(0)
+ imag = ct.c_double(0)
+
+ safe_call(c_func(ct.pointer(real), ct.pointer(imag), a.arr, ct.c_double(nan_val)))
+
+ real = real.value
+ imag = imag.value
+ return real if imag == 0 else real + imag * 1j
+
+def sum(a, dim=None, nan_val=None):
"""
Calculate the sum of all the elements along a specified dimension.
@@ -37,6 +54,8 @@ def sum(a, dim=None):
Multi dimensional arrayfire array.
dim: optional: int. default: None
Dimension along which the sum is required.
+ nan_val: optional: scalar. default: None
+ The value that replaces NaN in the array
Returns
-------
@@ -44,12 +63,18 @@ def sum(a, dim=None):
The sum of all elements in `a` along dimension `dim`.
If `dim` is `None`, sum of the entire Array is returned.
"""
- if dim is not None:
- return _parallel_dim(a, dim, backend.get().af_sum)
+ if (nan_val is not None):
+ if dim is not None:
+ return _nan_parallel_dim(a, dim, backend.get().af_sum_nan, nan_val)
+ else:
+ return _nan_reduce_all(a, backend.get().af_sum_nan_all, nan_val)
else:
- return _reduce_all(a, backend.get().af_sum_all)
+ if dim is not None:
+ return _parallel_dim(a, dim, backend.get().af_sum)
+ else:
+ return _reduce_all(a, backend.get().af_sum_all)
-def product(a, dim=None):
+def product(a, dim=None, nan_val=None):
"""
Calculate the product of all the elements along a specified dimension.
@@ -59,6 +84,8 @@ def product(a, dim=None):
Multi dimensional arrayfire array.
dim: optional: int. default: None
Dimension along which the product is required.
+ nan_val: optional: scalar. default: None
+ The value that replaces NaN in the array
Returns
-------
@@ -66,10 +93,16 @@ def product(a, dim=None):
The product of all elements in `a` along dimension `dim`.
If `dim` is `None`, product of the entire Array is returned.
"""
- if dim is not None:
- return _parallel_dim(a, dim, backend.get().af_product)
+ if (nan_val is not None):
+ if dim is not None:
+ return _nan_parallel_dim(a, dim, backend.get().af_product_nan, nan_val)
+ else:
+ return _nan_reduce_all(a, backend.get().af_product_nan_all, nan_val)
else:
- return _reduce_all(a, backend.get().af_product_all)
+ if dim is not None:
+ return _parallel_dim(a, dim, backend.get().af_product)
+ else:
+ return _reduce_all(a, backend.get().af_product_all)
def min(a, dim=None):
"""
diff --git a/tests/simple/algorithm.py b/tests/simple/algorithm.py
index 0d75d35..f68e354 100644
--- a/tests/simple/algorithm.py
+++ b/tests/simple/algorithm.py
@@ -47,6 +47,13 @@ def simple_algorithm(verbose = False):
display_func(af.sort(a, is_ascending=True))
display_func(af.sort(a, is_ascending=False))
+ b = (a > 0.1) * a
+ c = (a > 0.4) * a
+ d = b / c
+ print_func(af.sum(d));
+ print_func(af.sum(d, nan_val=0.0));
+ display_func(af.sum(d, dim=0, nan_val=0.0));
+
val,idx = af.sort_index(a, is_ascending=True)
display_func(val)
display_func(idx)
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-arrayfire.git
More information about the debian-science-commits mailing list