[python-arrayfire] 24/250: Adding tests for array class

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Mon Mar 28 22:59:27 UTC 2016


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/master
in repository python-arrayfire.

commit 8d9b0b956e9e430f7730b6e95c52aedd7554f51a
Author: Pavan Yalamanchili <pavan at arrayfire.com>
Date:   Mon Jun 22 13:26:03 2015 -0400

    Adding tests for array class
    
    - Also relaxed the dtype checks in the data-creation functions so they accept a plain int (converted to c_int) as well as a c_int
---
 arrayfire/array.py    |  8 ++++++--
 arrayfire/data.py     | 26 +++++++++++++++++++++-----
 tests/simple_array.py | 15 +++++++++++++++
 3 files changed, 42 insertions(+), 7 deletions(-)

diff --git a/arrayfire/array.py b/arrayfire/array.py
index 7ba83dc..fe4407c 100644
--- a/arrayfire/array.py
+++ b/arrayfire/array.py
@@ -11,8 +11,12 @@ def create_array(buf, numdims, idims, dtype):
     return out_arr
 
 def constant_array(val, d0, d1=None, d2=None, d3=None, dtype=f32):
+
     if not isinstance(dtype, c_int):
-        raise TypeError("Invalid dtype")
+        if isinstance(dtype, int):
+            dtype = c_int(dtype)
+        else:
+            raise TypeError("Invalid dtype")
 
     out = c_longlong(0)
     dims = dim4(d0, d1, d2, d3)
@@ -130,7 +134,7 @@ class array(object):
     def type(self):
         dty = c_int(f32.value)
         safe_call(clib.af_get_type(pointer(dty), self.arr))
-        return dty
+        return dty.value
 
     def __add__(self, other):
         return binary_func(self, other, clib.af_add)
diff --git a/arrayfire/data.py b/arrayfire/data.py
index a1e91fe..b480839 100644
--- a/arrayfire/data.py
+++ b/arrayfire/data.py
@@ -12,8 +12,12 @@ def constant(val, d0, d1=None, d2=None, d3=None, dtype=f32):
 brange = range
 
 def range(d0, d1=None, d2=None, d3=None, dim=-1, dtype=f32):
+
     if not isinstance(dtype, c_int):
-        raise TypeError("Invalid dtype")
+        if isinstance(dtype, int):
+            dtype = c_int(dtype)
+        else:
+            raise TypeError("Invalid dtype")
 
     out = array()
     dims = dim4(d0, d1, d2, d3)
@@ -24,7 +28,10 @@ def range(d0, d1=None, d2=None, d3=None, dim=-1, dtype=f32):
 
 def iota(d0, d1=None, d2=None, d3=None, dim=-1, tile_dims=None, dtype=f32):
     if not isinstance(dtype, c_int):
-        raise TypeError("Invalid dtype")
+        if isinstance(dtype, int):
+            dtype = c_int(dtype)
+        else:
+            raise TypeError("Invalid dtype")
 
     out = array()
     dims = dim4(d0, d1, d2, d3)
@@ -42,7 +49,10 @@ def iota(d0, d1=None, d2=None, d3=None, dim=-1, tile_dims=None, dtype=f32):
 def randu(d0, d1=None, d2=None, d3=None, dtype=f32):
 
     if not isinstance(dtype, c_int):
-        raise TypeError("Invalid dtype")
+        if isinstance(dtype, int):
+            dtype = c_int(dtype)
+        else:
+            raise TypeError("Invalid dtype")
 
     out = array()
     dims = dim4(d0, d1, d2, d3)
@@ -53,7 +63,10 @@ def randu(d0, d1=None, d2=None, d3=None, dtype=f32):
 def randn(d0, d1=None, d2=None, d3=None, dtype=f32):
 
     if not isinstance(dtype, c_int):
-        raise TypeError("Invalid dtype")
+        if isinstance(dtype, int):
+            dtype = c_int(dtype)
+        else:
+            raise TypeError("Invalid dtype")
 
     out = array()
     dims = dim4(d0, d1, d2, d3)
@@ -72,7 +85,10 @@ def get_seed():
 def identity(d0, d1=None, d2=None, d3=None, dtype=f32):
 
     if not isinstance(dtype, c_int):
-        raise TypeError("Invalid dtype")
+        if isinstance(dtype, int):
+            dtype = c_int(dtype)
+        else:
+            raise TypeError("Invalid dtype")
 
     out = array()
     dims = dim4(d0, d1, d2, d3)
diff --git a/tests/simple_array.py b/tests/simple_array.py
new file mode 100644
index 0000000..8d5001b
--- /dev/null
+++ b/tests/simple_array.py
@@ -0,0 +1,15 @@
+#!/usr/bin/python
+import arrayfire as af
+import array as host
+
+a = af.array([1, 2, 3])
+af.print_array(a)
+print(a.numdims(), a.dims(), a.type())
+
+a = af.array(host.array('d', [4, 5, 6]))
+af.print_array(a)
+print(a.numdims(), a.dims(), a.type())
+
+a = af.array(host.array('l', [7, 8, 9] * 4), (2, 5))
+af.print_array(a)
+print(a.numdims(), a.dims(), a.type())

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/python-arrayfire.git



More information about the debian-science-commits mailing list