[python-arrayfire] 47/250: FEAT: Adding broadcast to arrayfire

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Mon Mar 28 22:59:29 UTC 2016


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/master
in repository python-arrayfire.

commit 870ee506075be0c9614a2d8fda20ee9d203c4d87
Author: Pavan Yalamanchili <pavan at arrayfire.com>
Date:   Tue Jul 28 16:56:43 2015 -0400

    FEAT: Adding broadcast to arrayfire
---
 arrayfire/__init__.py  |  2 ++
 arrayfire/arith.py     |  7 ++++---
 arrayfire/array.py     |  5 +++--
 arrayfire/broadcast.py | 26 ++++++++++++++++++++++++++
 tests/simple_arith.py  |  7 +++++++
 5 files changed, 42 insertions(+), 5 deletions(-)

diff --git a/arrayfire/__init__.py b/arrayfire/__init__.py
index 7062420..e1bf098 100644
--- a/arrayfire/__init__.py
+++ b/arrayfire/__init__.py
@@ -22,6 +22,7 @@ from .image      import *
 from .features   import *
 from .vision     import *
 from .graphics   import *
+from .broadcast  import *
 
 # do not export default modules as part of arrayfire
 del ct
@@ -34,6 +35,7 @@ del uidx
 del seq
 del index
 del cell
+del bcast
 
 #do not export internal functions
 del binary_func
diff --git a/arrayfire/arith.py b/arrayfire/arith.py
index 24add97..b94bf36 100644
--- a/arrayfire/arith.py
+++ b/arrayfire/arith.py
@@ -9,6 +9,7 @@
 
 from .library import *
 from .array import *
+from .broadcast import *
 
 def arith_binary_func(lhs, rhs, c_func):
     out = array()
@@ -20,21 +21,21 @@ def arith_binary_func(lhs, rhs, c_func):
         TypeError("Atleast one input needs to be of type arrayfire.array")
 
     elif (is_left_array and is_right_array):
-        safe_call(c_func(ct.pointer(out.arr), lhs.arr, rhs.arr, False))
+        safe_call(c_func(ct.pointer(out.arr), lhs.arr, rhs.arr, bcast.get()))
 
     elif (is_number(rhs)):
         ldims = dim4_tuple(lhs.dims())
         lty = lhs.type()
         other = array()
         other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], lty)
-        safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, False))
+        safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, bcast.get()))
 
     else:
         rdims = dim4_tuple(rhs.dims())
         rty = rhs.type()
         other = array()
         other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], rty)
-        safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, False))
+        safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, bcast.get()))
 
     return out
 
diff --git a/arrayfire/array.py b/arrayfire/array.py
index 98dfff8..2253f1e 100644
--- a/arrayfire/array.py
+++ b/arrayfire/array.py
@@ -10,6 +10,7 @@
 import inspect
 from .library import *
 from .util import *
+from .broadcast import *
 
 def create_array(buf, numdims, idims, dtype):
     out_arr = ct.c_longlong(0)
@@ -63,7 +64,7 @@ def binary_func(lhs, rhs, c_func):
     elif not isinstance(rhs, array):
         raise TypeError("Invalid parameter to binary function")
 
-    safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, False))
+    safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, bcast.get()))
 
     return out
 
@@ -79,7 +80,7 @@ def binary_funcr(lhs, rhs, c_func):
     elif not isinstance(lhs, array):
         raise TypeError("Invalid parameter to binary function")
 
-    c_func(ct.pointer(out.arr), other.arr, rhs.arr, False)
+    c_func(ct.pointer(out.arr), other.arr, rhs.arr, bcast.get())
 
     return out
 
diff --git a/arrayfire/broadcast.py b/arrayfire/broadcast.py
new file mode 100644
index 0000000..15e4228
--- /dev/null
+++ b/arrayfire/broadcast.py
@@ -0,0 +1,55 @@
+#######################################################
+# Copyright (c) 2015, ArrayFire
+# All rights reserved.
+#
+# This file is distributed under 3-clause BSD license.
+# The complete license agreement can be obtained at:
+# http://arrayfire.com/licenses/BSD-3-Clause
+########################################################
+
+"""
+Broadcast switch for arrayfire binary operations.
+"""
+
+
+class bcast(object):
+    """
+    Holder for the flag passed as the last argument of the C binary-operation
+    calls made by the arithmetic wrappers (e.g. arith_binary_func, binary_func).
+
+    The methods are static so they can be called on the class itself
+    (``bcast.get()``) under both Python 2 and Python 3.
+    """
+
+    _flag = False
+
+    @staticmethod
+    def get():
+        """Return the current broadcast flag."""
+        return bcast._flag
+
+    @staticmethod
+    def set(flag):
+        """Set the broadcast flag to ``flag``."""
+        bcast._flag = flag
+
+    @staticmethod
+    def toggle():
+        """Invert the broadcast flag."""
+        bcast._flag ^= True
+
+
+def broadcast(func, *args):
+    """
+    Return ``func(*args)`` evaluated with broadcasting enabled.
+
+    The previous flag value is restored afterwards, even if ``func`` raises,
+    so an exception or a nested ``broadcast`` call cannot leave the flag in
+    the wrong state.
+    """
+    prev = bcast.get()
+    bcast.set(True)
+    try:
+        return func(*args)
+    finally:
+        bcast.set(prev)
diff --git a/tests/simple_arith.py b/tests/simple_arith.py
index b0aa017..e6941f1 100755
--- a/tests/simple_arith.py
+++ b/tests/simple_arith.py
@@ -188,3 +188,10 @@ af.display(af.lgamma(a))
 af.display(af.iszero(a))
 af.display(af.isinf(a/b))
 af.display(af.isnan(a/a))
+
+a = af.randu(5, 1)
+b = af.randu(5, 5)
+c = af.broadcast(lambda x,y: x+y, a, b)
+af.display(a)
+af.display(b)
+af.display(c)

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-arrayfire.git



More information about the debian-science-commits mailing list