[python-arrayfire] 47/250: FEAT: Adding broadcast to arrayfire
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Mon Mar 28 22:59:29 UTC 2016
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch debian/master
in repository python-arrayfire.
commit 870ee506075be0c9614a2d8fda20ee9d203c4d87
Author: Pavan Yalamanchili <pavan at arrayfire.com>
Date: Tue Jul 28 16:56:43 2015 -0400
FEAT: Adding broadcast to arrayfire
---
arrayfire/__init__.py | 2 ++
arrayfire/arith.py | 7 ++++---
arrayfire/array.py | 5 +++--
arrayfire/broadcast.py | 26 ++++++++++++++++++++++++++
tests/simple_arith.py | 7 +++++++
5 files changed, 42 insertions(+), 5 deletions(-)
diff --git a/arrayfire/__init__.py b/arrayfire/__init__.py
index 7062420..e1bf098 100644
--- a/arrayfire/__init__.py
+++ b/arrayfire/__init__.py
@@ -22,6 +22,7 @@ from .image import *
from .features import *
from .vision import *
from .graphics import *
+from .broadcast import *
# do not export default modules as part of arrayfire
del ct
@@ -34,6 +35,7 @@ del uidx
del seq
del index
del cell
+del bcast
#do not export internal functions
del binary_func
diff --git a/arrayfire/arith.py b/arrayfire/arith.py
index 24add97..b94bf36 100644
--- a/arrayfire/arith.py
+++ b/arrayfire/arith.py
@@ -9,6 +9,7 @@
from .library import *
from .array import *
+from .broadcast import *
def arith_binary_func(lhs, rhs, c_func):
out = array()
@@ -20,21 +21,21 @@ def arith_binary_func(lhs, rhs, c_func):
TypeError("Atleast one input needs to be of type arrayfire.array")
elif (is_left_array and is_right_array):
- safe_call(c_func(ct.pointer(out.arr), lhs.arr, rhs.arr, False))
+ safe_call(c_func(ct.pointer(out.arr), lhs.arr, rhs.arr, bcast.get()))
elif (is_number(rhs)):
ldims = dim4_tuple(lhs.dims())
lty = lhs.type()
other = array()
other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], lty)
- safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, False))
+ safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, bcast.get()))
else:
rdims = dim4_tuple(rhs.dims())
rty = rhs.type()
other = array()
other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], rty)
- safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, False))
+ safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, bcast.get()))
return out
diff --git a/arrayfire/array.py b/arrayfire/array.py
index 98dfff8..2253f1e 100644
--- a/arrayfire/array.py
+++ b/arrayfire/array.py
@@ -10,6 +10,7 @@
import inspect
from .library import *
from .util import *
+from .broadcast import *
def create_array(buf, numdims, idims, dtype):
out_arr = ct.c_longlong(0)
@@ -63,7 +64,7 @@ def binary_func(lhs, rhs, c_func):
elif not isinstance(rhs, array):
raise TypeError("Invalid parameter to binary function")
- safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, False))
+ safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, bcast.get()))
return out
@@ -79,7 +80,7 @@ def binary_funcr(lhs, rhs, c_func):
elif not isinstance(lhs, array):
raise TypeError("Invalid parameter to binary function")
- c_func(ct.pointer(out.arr), other.arr, rhs.arr, False)
+ c_func(ct.pointer(out.arr), other.arr, rhs.arr, bcast.get())
return out
diff --git a/arrayfire/broadcast.py b/arrayfire/broadcast.py
new file mode 100644
index 0000000..15e4228
--- /dev/null
+++ b/arrayfire/broadcast.py
@@ -0,0 +1,26 @@
+#######################################################
+# Copyright (c) 2015, ArrayFire
+# All rights reserved.
+#
+# This file is distributed under 3-clause BSD license.
+# The complete license agreement can be obtained at:
+# http://arrayfire.com/licenses/BSD-3-Clause
+########################################################
+
+
class bcast(object):
    """Global (class-level) switch consulted by the binary arithmetic helpers.

    The helpers in ``array.py`` and ``arith.py`` pass ``bcast.get()`` as the
    final flag argument of the C binary functions.  Changing this flag
    therefore affects every binary operation performed afterwards.
    """

    # Shared state: a class attribute, not per-instance.
    _flag = False

    @staticmethod
    def get():
        """Return the current broadcast flag."""
        return bcast._flag

    @staticmethod
    def set(flag):
        """Set the broadcast flag to *flag*."""
        bcast._flag = flag

    @staticmethod
    def toggle():
        """Invert the current broadcast flag."""
        bcast._flag ^= True


def broadcast(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` with the broadcast flag enabled.

    The flag value that was active before the call is restored afterwards,
    even if *func* raises.  Nested ``broadcast`` calls therefore keep
    broadcasting enabled.  A toggle-based implementation would switch it
    back off inside the nested call and leave it flipped after an exception.

    Returns whatever *func* returns.
    """
    previous = bcast.get()
    bcast.set(True)
    try:
        return func(*args, **kwargs)
    finally:
        # Always restore, so an exception cannot leak the enabled state
        # into unrelated arithmetic later on.
        bcast.set(previous)
diff --git a/tests/simple_arith.py b/tests/simple_arith.py
index b0aa017..e6941f1 100755
--- a/tests/simple_arith.py
+++ b/tests/simple_arith.py
@@ -188,3 +188,10 @@ af.display(af.lgamma(a))
af.display(af.iszero(a))
af.display(af.isinf(a/b))
af.display(af.isnan(a/a))
+
+a = af.randu(5, 1)
+b = af.randu(5, 5)
+c = af.broadcast(lambda x,y: x+y, a, b)
+af.display(a)
+af.display(b)
+af.display(c)
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-arrayfire.git
More information about the debian-science-commits
mailing list