[python-arrayfire] 12/250: Adding operator overloading to the array class

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Mon Mar 28 22:59:25 UTC 2016


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/master
in repository python-arrayfire.

commit be27746a7f681b2f5e029a70188420315c98277c
Author: Pavan Yalamanchili <pavan at arrayfire.com>
Date:   Fri Jun 19 18:26:15 2015 -0400

    Adding operator overloading to the array class
---
 arrayfire/array.py | 211 ++++++++++++++++++++++++++++++++++++++++++++++++++++-
 arrayfire/data.py  |  25 +------
 arrayfire/util.py  |  12 +++
 3 files changed, 223 insertions(+), 25 deletions(-)

diff --git a/arrayfire/array.py b/arrayfire/array.py
index 61dfdd6..f3d68f4 100644
--- a/arrayfire/array.py
+++ b/arrayfire/array.py
@@ -1,6 +1,7 @@
 import array as host
 from .library import *
-from .util import dim4
+from .util import *
+from .data import *
 
 def create_array(buf, numdims, idims, dtype):
     out_arr = c_longlong(0)
@@ -8,6 +9,66 @@ def create_array(buf, numdims, idims, dtype):
     clib.af_create_array(pointer(out_arr), c_longlong(buf), numdims, pointer(c_dims), dtype)
     return out_arr
 
+def constant_array(val, d0, d1=None, d2=None, d3=None, dtype=f32):
+    if not isinstance(dtype, c_int):
+        raise TypeError("Invalid dtype")
+
+    out = c_longlong(0)
+    dims = dim4(d0, d1, d2, d3)
+
+    if isinstance(val, complex):
+        c_real = c_double(val.real)
+        c_imag = c_double(val.imag)
+
+        if (dtype != c32 and dtype != c64):
+            dtype = c32
+
+        clib.af_constant_complex(pointer(out), c_real, c_imag, 4, pointer(dims), dtype)
+    elif dtype == s64:
+        c_val = c_longlong(val.real)
+        clib.af_constant_long(pointer(out), c_val, 4, pointer(dims))
+    elif dtype == u64:
+        c_val = c_ulonglong(val.real)
+        clib.af_constant_ulong(pointer(out), c_val, 4, pointer(dims))
+    else:
+        c_val = c_double(val)
+        clib.af_constant(pointer(out), c_val, 4, pointer(dims), dtype)
+
+    return out
+
+
+def binary_func(lhs, rhs, c_func):
+    out = array()
+    other = rhs
+
+    if (is_valid_scalar(rhs)):
+        ldims = dim4_tuple(lhs.dims())
+        lty = lhs.type()
+        other = array()
+        other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], lty)
+    elif not isinstance(rhs, array):
+        TypeError("Invalid parameter to binary function")
+
+    c_func(pointer(out.arr), lhs.arr, other.arr, False)
+
+    return out
+
+def binary_funcr(lhs, rhs, c_func):
+    out = array()
+    other = lhs
+
+    if (is_valid_scalar(lhs)):
+        rdims = dim4_tuple(rhs.dims())
+        rty = rhs.type()
+        other = array()
+        other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], rty)
+    elif not isinstance(lhs, array):
+        TypeError("Invalid parameter to binary function")
+
+    c_func(pointer(out.arr), other.arr, rhs.arr, False)
+
+    return out
+
 class array(object):
 
     def __init__(self, src=None, dims=(0,)):
@@ -65,3 +126,151 @@ class array(object):
     def __del__(self):
         if (self.arr.value != 0):
             clib.af_release_array(self.arr)
+
+    def numdims(self):
+        nd = c_uint(0)
+        clib.af_get_numdims(pointer(nd), self.arr)
+        return nd.value
+
+    def dims(self):
+        d0 = c_longlong(0)
+        d1 = c_longlong(0)
+        d2 = c_longlong(0)
+        d3 = c_longlong(0)
+        clib.af_get_dims(pointer(d0), pointer(d1), pointer(d2), pointer(d3), self.arr)
+        dims = (d0.value,d1.value,d2.value,d3.value)
+        return dims[:self.numdims()]
+
+    def type(self):
+        dty = f32
+        clib.af_get_type(pointer(dty), self.arr)
+        return dty
+
+    def __add__(self, other):
+        return binary_func(self, other, clib.af_add)
+
+    def __iadd__(self, other):
+        self = binary_func(self, other, clib.af_add)
+        return self
+
+    def __radd__(self, other):
+        return binary_funcr(other, self, clib.af_add)
+
+    def __sub__(self, other):
+        return binary_func(self, other, clib.af_sub)
+
+    def __isub__(self, other):
+        self = binary_func(self, other, clib.af_sub)
+        return self
+
+    def __rsub__(self, other):
+        return binary_funcr(other, self, clib.af_sub)
+
+    def __mul__(self, other):
+        return binary_func(self, other, clib.af_mul)
+
+    def __imul__(self, other):
+        self = binary_func(self, other, clib.af_mul)
+        return self
+
+    def __rmul__(self, other):
+        return binary_funcr(other, self, clib.af_mul)
+
+    def __truediv__(self, other):
+        return binary_func(self, other, clib.af_div)
+
+    def __itruediv__(self, other):
+        self =  binary_func(self, other, clib.af_div)
+        return self
+
+    def __rtruediv__(self, other):
+        return binary_funcr(other, self, clib.af_div)
+
+    def __mod__(self, other):
+        return binary_func(self, other, clib.af_mod)
+
+    def __imod__(self, other):
+        self =  binary_func(self, other, clib.af_mod)
+        return self
+
+    def __rmod__(self, other):
+        return binary_funcr(other, self, clib.af_mod)
+
+    def __pow__(self, other):
+        return binary_func(self, other, clib.af_pow)
+
+    def __ipow__(self, other):
+        self =  binary_func(self, other, clib.af_pow)
+        return self
+
+    def __rpow__(self, other):
+        return binary_funcr(other, self, clib.af_pow)
+
+    def __lt__(self, other):
+        return binary_func(self, other, clib.af_lt)
+
+    def __gt__(self, other):
+        return binary_func(self, other, clib.af_gt)
+
+    def __le__(self, other):
+        return binary_func(self, other, clib.af_le)
+
+    def __ge__(self, other):
+        return binary_func(self, other, clib.af_ge)
+
+    def __eq__(self, other):
+        return binary_func(self, other, clib.af_eq)
+
+    def __ne__(self, other):
+        return binary_func(self, other, clib.af_neq)
+
+    def __and__(self, other):
+        return binary_func(self, other, clib.af_bitand)
+
+    def __iand__(self, other):
+        self = binary_func(self, other, clib.af_bitand)
+        return self
+
+    def __or__(self, other):
+        return binary_func(self, other, clib.af_bitor)
+
+    def __ior__(self, other):
+        self = binary_func(self, other, clib.af_bitor)
+        return self
+
+    def __xor__(self, other):
+        return binary_func(self, other, clib.af_bitxor)
+
+    def __ixor__(self, other):
+        self = binary_func(self, other, clib.af_bitxor)
+        return self
+
+    def __lshift__(self, other):
+        return binary_func(self, other, clib.af_bitshiftl)
+
+    def __ilshift__(self, other):
+        self = binary_func(self, other, clib.af_bitshiftl)
+        return self
+
+    def __rshift__(self, other):
+        return binary_func(self, other, clib.af_bitshiftr)
+
+    def __irshift__(self, other):
+        self = binary_func(self, other, clib.af_bitshiftr)
+        return self
+
+    def __neg__(self):
+        return 0 - self
+
+    def __pos__(self):
+        return self
+
+    def __invert__(self):
+        return self == 0
+
+    def __nonzero__(self):
+        return self != 0
+
+    # TODO:
+    # def __abs__(self):
+    #     return self
diff --git a/arrayfire/data.py b/arrayfire/data.py
index b462992..b4ad706 100644
--- a/arrayfire/data.py
+++ b/arrayfire/data.py
@@ -3,32 +3,9 @@ from .library import *
 from .array import *
 from .util import *
 
-
 def constant(val, d0, d1=None, d2=None, d3=None, dtype=f32):
-
-    if not isinstance(dtype, c_int):
-        raise TypeError("Invalid dtype")
-
     out = array()
-    dims = dim4(d0, d1, d2, d3)
-
-    if isinstance(val, complex):
-        c_real = c_double(val.real)
-        c_imag = c_double(val.imag)
-
-        if (dtype != c32 and dtype != c64):
-            dtype = c32
-
-        clib.af_constant_complex(pointer(out.arr), c_real, c_imag, 4, pointer(dims), dtype)
-    elif dtype == s64:
-        c_val = c_longlong(val.real)
-        clib.af_constant_long(pointer(out.arr), c_val, 4, pointer(dims))
-    elif dtype == u64:
-        c_val = c_ulonglong(val.real)
-        clib.af_constant_ulong(pointer(out.arr), c_val, 4, pointer(dims))
-    else:
-        c_val = c_double(val)
-        clib.af_constant(pointer(out.arr), c_val, 4, pointer(dims), dtype)
+    out.arr = constant_array(val, d0, d1, d2, d3, dtype)
     return out
 
 # Store builtin range function to be used later
diff --git a/arrayfire/util.py b/arrayfire/util.py
index cf0321d..0f2ed8f 100644
--- a/arrayfire/util.py
+++ b/arrayfire/util.py
@@ -10,5 +10,17 @@ def dim4(d0=1, d1=1, d2=1, d3=1):
 
     return out
 
+def dim4_tuple(dims):
+    assert(isinstance(dims, tuple))
+    out = [1]*4
+
+    for i, dim in enumerate(dims):
+        out[i] = dim
+
+    return tuple(out)
+
 def print_array(a):
     clib.af_print_array(a.arr)
+
+def is_valid_scalar(a):
+    return isinstance(a, float) or isinstance(a, int) or isinstance(a, complex)

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-arrayfire.git



More information about the debian-science-commits mailing list