[python-arrayfire] 128/250: Adding support for replacing NaN values in reductions

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Mon Mar 28 22:59:40 UTC 2016


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/master
in repository python-arrayfire.

commit 3a64cac6e532de137b311475c9dcc1b604112fc4
Author: Pavan Yalamanchili <pavan at arrayfire.com>
Date:   Tue Nov 10 08:20:55 2015 -0500

    Adding support for replacing NaN values in reductions
---
 arrayfire/algorithm.py    | 49 +++++++++++++++++++++++++++++++++++++++--------
 tests/simple/algorithm.py |  7 +++++++
 2 files changed, 48 insertions(+), 8 deletions(-)

diff --git a/arrayfire/algorithm.py b/arrayfire/algorithm.py
index ceeaa48..8236637 100644
--- a/arrayfire/algorithm.py
+++ b/arrayfire/algorithm.py
@@ -22,12 +22,29 @@ def _parallel_dim(a, dim, c_func):
 def _reduce_all(a, c_func):
     real = ct.c_double(0)
     imag = ct.c_double(0)
+
     safe_call(c_func(ct.pointer(real), ct.pointer(imag), a.arr))
+
     real = real.value
     imag = imag.value
     return real if imag == 0 else real + imag * 1j
 
-def sum(a, dim=None):
+def _nan_parallel_dim(a, dim, c_func, nan_val):
+    out = Array()
+    safe_call(c_func(ct.pointer(out.arr), a.arr, ct.c_int(dim), ct.c_double(nan_val)))
+    return out
+
+def _nan_reduce_all(a, c_func, nan_val):
+    real = ct.c_double(0)
+    imag = ct.c_double(0)
+
+    safe_call(c_func(ct.pointer(real), ct.pointer(imag), a.arr, ct.c_double(nan_val)))
+
+    real = real.value
+    imag = imag.value
+    return real if imag == 0 else real + imag * 1j
+
+def sum(a, dim=None, nan_val=None):
     """
     Calculate the sum of all the elements along a specified dimension.
 
@@ -37,6 +54,8 @@ def sum(a, dim=None):
          Multi dimensional arrayfire array.
     dim: optional: int. default: None
          Dimension along which the sum is required.
+    nan_val: optional: scalar. default: None
+         The value that replaces NaN in the array
 
     Returns
     -------
@@ -44,12 +63,18 @@ def sum(a, dim=None):
          The sum of all elements in `a` along dimension `dim`.
          If `dim` is `None`, sum of the entire Array is returned.
     """
-    if dim is not None:
-        return _parallel_dim(a, dim, backend.get().af_sum)
+    if (nan_val is not None):
+        if dim is not None:
+            return _nan_parallel_dim(a, dim, backend.get().af_sum_nan, nan_val)
+        else:
+            return _nan_reduce_all(a, backend.get().af_sum_nan_all, nan_val)
     else:
-        return _reduce_all(a, backend.get().af_sum_all)
+        if dim is not None:
+            return _parallel_dim(a, dim, backend.get().af_sum)
+        else:
+            return _reduce_all(a, backend.get().af_sum_all)
 
-def product(a, dim=None):
+def product(a, dim=None, nan_val=None):
     """
     Calculate the product of all the elements along a specified dimension.
 
@@ -59,6 +84,8 @@ def product(a, dim=None):
          Multi dimensional arrayfire array.
     dim: optional: int. default: None
          Dimension along which the product is required.
+    nan_val: optional: scalar. default: None
+         The value that replaces NaN in the array
 
     Returns
     -------
@@ -66,10 +93,16 @@ def product(a, dim=None):
          The product of all elements in `a` along dimension `dim`.
          If `dim` is `None`, product of the entire Array is returned.
     """
-    if dim is not None:
-        return _parallel_dim(a, dim, backend.get().af_product)
+    if (nan_val is not None):
+        if dim is not None:
+            return _nan_parallel_dim(a, dim, backend.get().af_product_nan, nan_val)
+        else:
+            return _nan_reduce_all(a, backend.get().af_product_nan_all, nan_val)
     else:
-        return _reduce_all(a, backend.get().af_product_all)
+        if dim is not None:
+            return _parallel_dim(a, dim, backend.get().af_product)
+        else:
+            return _reduce_all(a, backend.get().af_product_all)
 
 def min(a, dim=None):
     """
diff --git a/tests/simple/algorithm.py b/tests/simple/algorithm.py
index 0d75d35..f68e354 100644
--- a/tests/simple/algorithm.py
+++ b/tests/simple/algorithm.py
@@ -47,6 +47,13 @@ def simple_algorithm(verbose = False):
     display_func(af.sort(a, is_ascending=True))
     display_func(af.sort(a, is_ascending=False))
 
+    b = (a > 0.1) * a
+    c = (a > 0.4) * a
+    d = b / c
+    print_func(af.sum(d));
+    print_func(af.sum(d, nan_val=0.0));
+    display_func(af.sum(d, dim=0, nan_val=0.0));
+
     val,idx = af.sort_index(a, is_ascending=True)
     display_func(val)
     display_func(idx)

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-arrayfire.git



More information about the debian-science-commits mailing list