[pyoperators] 01/04: Imported Upstream version 0.13.5

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Thu Feb 12 19:16:07 UTC 2015


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch master
in repository pyoperators.

commit 47d7871ef4e0174ec3a3af85b9579df2a096b70b
Author: Ghislain Antony Vaillant <ghisvail at gmail.com>
Date:   Thu Feb 12 19:06:27 2015 +0000

    Imported Upstream version 0.13.5
---
 hooks.py                            |  38 +++++++------
 pyoperators/config.py               |  35 ++++++------
 pyoperators/core.py                 |  84 +++++++++++++++++-----------
 pyoperators/iterative/algorithms.py |  24 ++++----
 pyoperators/linear.py               |   4 +-
 pyoperators/utils/misc.py           |  61 +++++++++++++-------
 pyoperators/utils/testing.py        |   6 +-
 test/test_utils.py                  | 107 ++++++++++++++++++++----------------
 8 files changed, 209 insertions(+), 150 deletions(-)

diff --git a/hooks.py b/hooks.py
index 31bcb6e..8b23583 100644
--- a/hooks.py
+++ b/hooks.py
@@ -52,7 +52,6 @@ import shutil
 import sys
 from distutils.command.clean import clean
 from numpy.distutils.command.build import build
-from numpy.distutils.command.build_ext import build_ext
 from numpy.distutils.command.build_src import build_src
 from numpy.distutils.command.sdist import sdist
 from numpy.distutils.core import Command
@@ -82,7 +81,11 @@ def get_cmdclass():
                            self.distribution.get_version())
             build.run(self)
 
-    class BuildExtCommand(build_ext):
+    class BuildSrcCommand(build_src):
+        def initialize_options(self):
+            build_src.initialize_options(self)
+            self.f2py_opts = '--quiet'
+
         def run(self):
             has_fortran = False
             has_cython = False
@@ -90,21 +93,20 @@ def get_cmdclass():
                 has_fortran = has_fortran or has_f_sources(ext.sources)
                 for isource, source in enumerate(ext.sources):
                     if source.endswith('.pyx'):
-                        if USE_CYTHON:
-                            has_cython = True
-                        else:
+                        if not USE_CYTHON:
                             ext.sources[isource] = source[:-3] + 'c'
-            if has_cython:
-                self.extensions = cythonize(self.extensions, force=True)
+                        else:
+                            has_cython = True
             if has_fortran:
                 with open(os.path.join(root, '.f2py_f2cmap'), 'w') as f:
                     f.write(repr(F2PY_TABLE))
-            build_ext.run(self)
-
-    class BuildSrcCommand(build_src):
-        def initialize_options(self):
-            build_src.initialize_options(self)
-            self.f2py_opts = '--quiet'
+            if has_cython:
+                build_dir = None if self.inplace else self.build_src
+                new_extensions = cythonize(self.extensions, force=True,
+                                           build_dir=build_dir)
+                for i in range(len(self.extensions)):
+                    self.extensions[i] = new_extensions[i]
+            build_src.run(self)
 
     class SDistCommand(sdist):
         def make_release_tree(self, base_dir, files):
@@ -112,8 +114,13 @@ def get_cmdclass():
                            self.distribution.get_version())
             initfile = os.path.join(self.distribution.get_name(),
                                     '__init__.py')
+            new_files = []
+            for f in files:
+                if f.endswith('.pyx'):
+                    new_files.append(f[:-3] + 'c')
             if initfile not in files:
-                files.append(initfile)
+                new_files.append(initfile)
+            files.extend(new_files)
             sdist.make_release_tree(self, base_dir, files)
 
     class CleanCommand(clean):
@@ -190,7 +197,6 @@ def get_cmdclass():
             pass
 
     return {'build': BuildCommand,
-            'build_ext': BuildExtCommand,
             'build_src': BuildSrcCommand,
             'clean': CleanCommand,
             'coverage': CoverageCommand,
@@ -340,7 +346,7 @@ def _get_version_git(default):
             suffix = 'pre'
     if name != '':
         name += '.'
-    return '{}{}{:02}-g{}{}'.format(name, suffix, rev, commit, dirty)
+    return '{}{}{:02}{}'.format(name, suffix, rev, dirty)
 
 
 def _get_version_init_file(name):
diff --git a/pyoperators/config.py b/pyoperators/config.py
index 2f38535..e509d03 100644
--- a/pyoperators/config.py
+++ b/pyoperators/config.py
@@ -1,10 +1,10 @@
-import os
-import site
-from .warnings import warn, PyOperatorsWarning
+import os as _os
+import site as _site
+from .warnings import warn as _warn, PyOperatorsWarning as _PyOperatorsWarning
 
 
 def getenv(key, default, cls):
-    val = os.getenv(key, '').strip()
+    val = _os.getenv(key, '').strip()
     if len(val) == 0:
         return cls(default)
     try:
@@ -12,25 +12,25 @@ def getenv(key, default, cls):
             val = int(val)
         val = cls(val)
     except ValueError:
-        warn("Invalid environment variable {0}='{1}'".format(key, val),
-             PyOperatorsWarning)
+        _warn("Invalid environment variable {0}='{1}'".format(key, val),
+              _PyOperatorsWarning)
         return cls(default)
     return val
 
 
 # PyOperators local path, used for example to store the FFTW wisdom files
-LOCAL_PATH = os.getenv('PYOPERATORS_PATH')
+LOCAL_PATH = _os.getenv('PYOPERATORS_PATH')
 if LOCAL_PATH is None:
-    LOCAL_PATH = os.path.join(site.USER_BASE, 'share', 'pyoperators')
-if not os.path.exists(LOCAL_PATH):
+    LOCAL_PATH = _os.path.join(_site.USER_BASE, 'share', 'pyoperators')
+if not _os.path.exists(LOCAL_PATH):
     try:
-        os.makedirs(LOCAL_PATH)
+        _os.makedirs(LOCAL_PATH)
     except IOError:
-        warn("User path '{0}' cannot be created.".format(LOCAL_PATH),
-             PyOperatorsWarning)
-elif not os.access(LOCAL_PATH, os.W_OK):
-    warn("User path '{0}' is not writable.".format(LOCAL_PATH),
-         PyOperatorsWarning)
+        _warn("User path '{0}' cannot be created.".format(LOCAL_PATH),
+              _PyOperatorsWarning)
+elif not _os.access(LOCAL_PATH, _os.W_OK):
+    _warn("User path '{0}' is not writable.".format(LOCAL_PATH),
+          _PyOperatorsWarning)
 
 # force garbage collection when deleted operators' nbytes exceed this
 # threshold.
@@ -42,7 +42,10 @@ MEMORY_ALIGNMENT = getenv('PYOPERATORS_MEMORY_ALIGNMENT', 32, int)
 # the requested size
 MEMORY_TOLERANCE = getenv('PYOPERATORS_MEMORY_TOLERANCE', 1.2, float)
 
+# on some supercomputers, importing mpi4py on a login node exits python without
+# raising an ImportError.
 NO_MPI = getenv('PYOPERATORS_NO_MPI', False, bool)
+
 VERBOSE = getenv('PYOPERATORS_VERBOSE', False, bool)
 
-#del os, site, PyOperatorsWarning, warn, getenv
+del getenv
diff --git a/pyoperators/core.py b/pyoperators/core.py
index 22ffa17..4525704 100644
--- a/pyoperators/core.py
+++ b/pyoperators/core.py
@@ -635,6 +635,8 @@ class Operator(object):
             target = Target()
         target.__class__ = self.__class__
         for k, v in self.__dict__.items():
+            if k in ('_C', '_T', '_H', '_I'):
+                continue
             if isinstance(v, types.MethodType) and v.__self__ is self:
                 target.__dict__[k] = types.MethodType(v.__func__, target)
             else:
@@ -663,7 +665,7 @@ class Operator(object):
         elif 'C' in rules:
             C = _copy_direct(self, rules['C'](self))
         else:
-            C = _copy_direct(
+            C = _copy_direct_all(
                 self, Operator(direct=self.conjugate,
                                name=self.__name__ + '.C',
                                flags={'linear': flags.linear,
@@ -689,7 +691,7 @@ class Operator(object):
         elif flags.orthogonal and 'I' in rules:
             T = _copy_reverse(self, rules['I'](self))
         elif self.transpose is not None:
-            T = _copy_reverse(
+            T = _copy_reverse_all(
                 self, Operator(direct=self.transpose,
                                name=self.__name__ + '.T', flags=new_flags))
         else:
@@ -706,7 +708,7 @@ class Operator(object):
         elif flags.unitary and 'I' in rules:
             H = _copy_reverse(self, rules['I'](self))
         elif self.adjoint is not None:
-            H = _copy_reverse(
+            H = _copy_reverse_all(
                 self, Operator(direct=self.adjoint,
                                name=self.__name__ + '.H', flags=new_flags))
         else:
@@ -717,17 +719,17 @@ class Operator(object):
                 if flags.real:
                     T = H
                 else:
-                    T = _copy_reverse(
+                    T = _copy_reverse_all(
                         self, Operator(direct=H.conjugate, name=
                                        self.__name__ + '.T', flags=new_flags))
             else:
-                T = _copy_reverse(
+                T = _copy_reverse_all(
                     self, Operator(name=self.__name__ + '.T', flags=new_flags))
                 if flags.real:
                     H = T
 
         if H is None:
-            H = _copy_reverse(
+            H = _copy_reverse_all(
                 self, Operator(direct=T.conjugate if T is not None else None,
                                name=self.__name__ + '.H', flags=new_flags))
 
@@ -740,7 +742,7 @@ class Operator(object):
         elif 'I' in rules:
             I = _copy_reverse(self, rules['I'](self))
         else:
-            I = _copy_reverse(
+            I = _copy_reverse_all(
                 self, Operator(direct=self.inverse,
                                name=self.__name__ + '.I',
                                flags={'linear': flags.linear,
@@ -771,7 +773,7 @@ class Operator(object):
                 func = I.conjugate
             else:
                 func = None
-            IC = _copy_reverse(
+            IC = _copy_reverse_all(
                 self, Operator(direct=func, name=self.__name__ + '.I.C',
                                flags=new_flags))
 
@@ -786,7 +788,7 @@ class Operator(object):
         elif 'IT' in rules:
             IT = _copy_direct(self, rules['IT'](self))
         elif self.inverse_transpose is not None:
-            IT = _copy_direct(
+            IT = _copy_direct_all(
                 self, Operator(direct=self.inverse_transpose,
                                name=self.__name__ + '.I.T', flags=new_flags))
         else:
@@ -807,7 +809,7 @@ class Operator(object):
         elif 'IH' in rules:
             IH = _copy_direct(self, rules['IH'](self))
         elif self.inverse_adjoint is not None:
-            IH = _copy_direct(
+            IH = _copy_direct_all(
                 self, Operator(direct=self.inverse_adjoint,
                                name=self.__name__ + '.I.H', flags=new_flags))
         else:
@@ -818,19 +820,19 @@ class Operator(object):
                 if flags.real:
                     IT = IH
                 else:
-                    IT = _copy_direct(
+                    IT = _copy_direct_all(
                         self, Operator(direct=IH.conjugate,
                                        name=self.__name__ + '.I.T',
                                        flags=new_flags))
             else:
-                IT = _copy_direct(
+                IT = _copy_direct_all(
                     self, Operator(name=self.__name__ + '.I.T',
                                    flags=new_flags))
                 if flags.real:
                     IH = IT
 
         if IH is None:
-            IH = _copy_direct(
+            IH = _copy_direct_all(
                 self, Operator(direct=IT.conjugate if IT is not None else None,
                                name=self.__name__ + '.I.H', flags=new_flags))
 
@@ -1263,7 +1265,7 @@ class Operator(object):
             not isinstance(other, np.matrix)):
                 return self(other)
         try:
-            other = asoperator(other)
+            other = asoperator(other, constant=not self.flags.linear)
         except TypeError:
             return NotImplemented
         if not self.flags.linear or not other.flags.linear:
@@ -1280,7 +1282,7 @@ class Operator(object):
             not isinstance(other, np.matrix)):
                 return self.T(other)
         try:
-            other = asoperator(other)
+            other = asoperator(other, constant=not self.flags.linear)
         except TypeError:
             return NotImplemented
         if not self.flags.linear or not other.flags.linear:
@@ -4160,28 +4162,48 @@ class VariableTranspose(Operator):
 
 def _copy_direct(source, target):
     keywords = {}
-    for attr in OPERATOR_ATTRIBUTES:
-        if attr != 'flags':
-            v = getattr(source, attr)
-            if attr in ('reshapein', 'reshapeout', 'toshapein', 'toshapeout',
-                        'validatein', 'validateout'):
-                if v == getattr(Operator, attr).__get__(source, type(source)):
-                    continue
-            keywords[attr] = v
+    for attr in set(OPERATOR_ATTRIBUTES) - {
+            'flags', 'reshapein', 'reshapeout', 'toshapein', 'toshapeout',
+            'validatein', 'validateout'}:
+        v = getattr(source, attr)
+        keywords[attr] = v
+    Operator.__init__(target, **keywords)
+    return target
+
+
+def _copy_direct_all(source, target):
+    keywords = {}
+    for attr in set(OPERATOR_ATTRIBUTES) - {'flags'}:
+        v = getattr(source, attr)
+        if attr in ('reshapein', 'reshapeout', 'toshapein', 'toshapeout',
+                    'validatein', 'validateout'):
+            if v == getattr(Operator, attr).__get__(source, type(source)):
+                continue
+        keywords[attr] = v
     Operator.__init__(target, **keywords)
     return target
 
 
 def _copy_reverse(source, target):
     keywords = {}
-    for attr in OPERATOR_ATTRIBUTES:
-        if attr != 'flags':
-            v = getattr(source, attr)
-            if attr in ('reshapein', 'reshapeout', 'toshapein', 'toshapeout',
-                        'validatein', 'validateout'):
-                if v == getattr(Operator, attr).__get__(source, type(source)):
-                    continue
-            keywords[_swap_inout(attr)] = v
+    for attr in set(OPERATOR_ATTRIBUTES) - {
+            'flags', 'reshapein', 'reshapeout', 'toshapein', 'toshapeout',
+            'validatein', 'validateout'}:
+        v = getattr(source, attr)
+        keywords[_swap_inout(attr)] = v
+    Operator.__init__(target, **keywords)
+    return target
+
+
+def _copy_reverse_all(source, target):
+    keywords = {}
+    for attr in set(OPERATOR_ATTRIBUTES) - {'flags'}:
+        v = getattr(source, attr)
+        if attr in ('reshapein', 'reshapeout', 'toshapein', 'toshapeout',
+                    'validatein', 'validateout'):
+            if v == getattr(Operator, attr).__get__(source, type(source)):
+                continue
+        keywords[_swap_inout(attr)] = v
     Operator.__init__(target, **keywords)
     return target
 
diff --git a/pyoperators/iterative/algorithms.py b/pyoperators/iterative/algorithms.py
index 8317285..97afd98 100644
--- a/pyoperators/iterative/algorithms.py
+++ b/pyoperators/iterative/algorithms.py
@@ -3,11 +3,6 @@ Implements iterative algorithm class.
 """
 import numpy as np
 from copy import copy
-try:
-    import pylab
-except:
-    pass
-
 from .linesearch import optimal_step
 from .criterions import norm2, quadratic_criterion, huber_criterion
 
@@ -196,7 +191,7 @@ class Callback(object):
             are stored with numpy savez function.
         shape: 2-tuple
             Shape of the solution.
-            If not empty tuple, pylab plot or imshow are called to display
+            If not empty tuple, plot or imshow are called to display
             current solution (solution should be 1D or 2D).
 
         Returns
@@ -224,22 +219,23 @@ class Callback(object):
                 }
             np.savez(self.savefile, **var_dict)
     def imshow(self, algo):
+        import matplotlib.pyplot as mp
         if algo.iter_ == 1:
-            self.im = pylab.imshow(algo.current_solution.reshape(self.shape))
+            self.im = mp.imshow(algo.current_solution.reshape(self.shape))
         else:
             self.im.set_data(algo.current_solution.reshape(self.shape))
-        pylab.draw()
-        pylab.show()
+        mp.draw()
+        mp.show()
     def plot(self, algo):
-        import pylab
+        import matplotlib.pyplot as mp
         if algo.iter_ == 1:
-            self.im = pylab.plot(algo.current_solution)[0]
+            self.im = mp.plot(algo.current_solution)[0]
         else:
             y = algo.current_solution
             self.im.set_ydata(y)
-            pylab.ylim((y.min(), y.max()))
-        pylab.draw()
-        pylab.show()
+            mp.ylim((y.min(), y.max()))
+        mp.draw()
+        mp.show()
     def __call__(self, algo):
         if self.verbose:
             self.print_status(algo)
diff --git a/pyoperators/linear.py b/pyoperators/linear.py
index 88a0dbf..910d175 100644
--- a/pyoperators/linear.py
+++ b/pyoperators/linear.py
@@ -22,7 +22,7 @@ from .flags import (
 from .memory import empty
 from .utils import (
     broadcast_shapes, cast, complex_dtype, float_dtype, inspect_special_values,
-    isalias, izip_broadcast, pi, product, strshape, tointtuple, ufuncs)
+    isalias, pi, product, strshape, tointtuple, ufuncs, zip_broadcast)
 from .warnings import warn, PyOperatorsWarning
 
 __all__ = [
@@ -1435,7 +1435,7 @@ class SymmetricBandToeplitzOperator(Operator):
                 self.fplan.update_arrays(rbuffer, cbuffer)
                 self.bplan.update_arrays(cbuffer, rbuffer)
 
-                for x_, out_, kernel in izip_broadcast(x, out, self.kernel):
+                for x_, out_, kernel in zip_broadcast(x, out, self.kernel):
                     rbuffer[:lpad] = 0
                     rbuffer[lpad:lpad+self.nsamples] = x_
                     rbuffer[lpad+self.nsamples:] = 0
diff --git a/pyoperators/utils/misc.py b/pyoperators/utils/misc.py
index 74ca54a..049f012 100644
--- a/pyoperators/utils/misc.py
+++ b/pyoperators/utils/misc.py
@@ -67,7 +67,8 @@ __all__ = ['all_eq',
            'tointtuple',
            'uninterruptible',
            'uninterruptible_if',
-           'zero']
+           'zero',
+           'zip_broadcast']
 
 
 # decorators
@@ -505,21 +506,6 @@ def isscalarlike(x):
     return np.isscalar(x) or isinstance(x, np.ndarray) and x.ndim == 0
 
 
-def izip_broadcast(*args):
-    """
-    Like izip, except that arguments which are containers of length 1 are
-    repeated.
-
-    """
-    def wrap(a):
-        if hasattr(a, '__len__') and len(a) == 1:
-            return itertools.repeat(a[0])
-        return a
-    if any(not hasattr(a, '__len__') or len(a) != 1 for a in args):
-        args = [wrap(arg) for arg in args]
-    return zip(*args)
-
-
 def last(l, f):
     """
     Return last item in list that verifies a certain condition, or raise
@@ -775,13 +761,13 @@ def settingerr(*args, **keywords):
 
 def split(n, m, rank=None):
     """
-    Return an iterator through the slices that partition a list of n elements
-    in m almost same-size groups. If a rank is provided, only the slice
-    for the rank is returned.
+    Iterate through the slices that partition a list of n elements in m almost
+    same-size groups. If a rank is provided, only the slice for the rank
+    is returned.
 
     Example
     -------
-    >>> split(1000, 2)
+    >>> tuple(split(1000, 2))
     (slice(0, 500, None), slice(500, 1000, None))
     >>> split(1000, 2, 1)
     slice(500, 1000, None)
@@ -801,7 +787,7 @@ def split(n, m, rank=None):
             start += work
             rank += 1
 
-    return tuple(generator())
+    return generator()
 
 
 def strelapsed(t0, msg='Elapsed time'):
@@ -1085,3 +1071,36 @@ def uninterruptible_if(condition):
 def zero(dtype):
     """ Return 0 with a given dtype. """
     return np.zeros((), dtype=dtype)[()]
+
+
+@deprecated("use 'zip_broadcast' instead.")
+def izip_broadcast(*args):
+    return zip_broadcast(*args)
+
+
+def zip_broadcast(*args, **keywords):
+    """
+    zip_broadcast(seq1 [, seq2 [...], iter_str=False|True]) ->
+        [(seq1[0], seq2[0] ...), (...)]
+
+    Like zip, except that arguments which are non iterable or containers
+    of length 1 are repeated. If the keyword iter_str is False, string
+    arguments are, unlike zip, not considered as iterable (default is True).
+
+    """
+    if len(keywords) > 1 or len(keywords) == 1 and 'iter_str' not in keywords:
+        raise TypeError('Invalid keyword(s).')
+    iter_str = keywords.get('iter_str', True)
+    n = max(1 if not isinstance(_, collections.Iterable) or
+            isinstance(_, str) and not iter_str
+            else len(_) if hasattr(_, '__len__') else sys.maxint for _ in args)
+
+    def wrap(a):
+        if not isinstance(a, collections.Iterable) or \
+           isinstance(a, str) and not iter_str:
+            return itertools.repeat(a, n)
+        if hasattr(a, '__len__') and len(a) == 1:
+            return itertools.repeat(a[0], n)
+        return a
+    args = [wrap(arg) for arg in args]
+    return zip(*args)
diff --git a/pyoperators/utils/testing.py b/pyoperators/utils/testing.py
index ad4a379..49899eb 100644
--- a/pyoperators/utils/testing.py
+++ b/pyoperators/utils/testing.py
@@ -1,9 +1,6 @@
-import collections
 import functools
 import numpy as np
 from collections import Container, Mapping
-
-from nose.plugins.skip import SkipTest
 from numpy.testing import assert_equal, assert_allclose
 
 from .misc import settingerr, strenum
@@ -227,6 +224,7 @@ assert_raises.__doc__ = np.testing.assert_raises.__doc__
 
 
 def skiptest(func):
+    from nose.plugins.skip import SkipTest
     @functools.wraps(func)
     def _():
         raise SkipTest()
@@ -234,6 +232,7 @@ def skiptest(func):
 
 
 def skiptest_if(condition):
+    from nose.plugins.skip import SkipTest
     def decorator(func):
         @functools.wraps(func)
         def _():
@@ -245,6 +244,7 @@ def skiptest_if(condition):
 
 
 def skiptest_unless_module(module):
+    from nose.plugins.skip import SkipTest
     def decorator(func):
         @functools.wraps(func)
         def _():
diff --git a/test/test_utils.py b/test/test_utils.py
index 79f1a08..70d0df1 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -9,7 +9,7 @@ from pyoperators import Operator
 from pyoperators.utils import (
     broadcast_shapes, cast, complex_dtype, first, first_is_not, float_dtype,
     ifirst, ifirst_is_not, ilast, ilast_is_not, groupbykey,
-    inspect_special_values, interruptible, isscalarlike, izip_broadcast, last,
+    inspect_special_values, interruptible, isscalarlike, zip_broadcast, last,
     last_is_not, least_greater_multiple, one, omp_num_threads, pi,
     pool_threading, product, reshape_broadcast, setting, settingerr, split,
     strenum, strplural, strshape, Timer, uninterruptible, zero)
@@ -239,51 +239,6 @@ def test_is_not_scalar():
         yield func, x
 
 
-def test_izip_broadcast1():
-    def g():
-        i = 0
-        while True:
-            yield i
-            i += 1
-    a = [1]
-    b = (np.sin,)
-    c = np.arange(3).reshape((1, 3))
-    d = ('x', 'y', [])
-    e = ['a', 'b', 'c']
-    f = np.arange(6).reshape((3, 2))
-
-    aa = []; bb = []; cc = []; dd = []; ee = []; ff = []; gg = []
-    for a_, b_, c_, d_, e_, f_, g_ in izip_broadcast(a, b, c, d, e, f, g()):
-        aa.append(a_)
-        bb.append(b_)
-        cc.append(c_)
-        dd.append(d_)
-        ee.append(e_)
-        ff.append(f_)
-        gg.append(g_)
-    assert_eq(aa, 3 * a)
-    assert_eq(bb, list(3 * b))
-    assert_eq(cc, [[0, 1, 2], [0, 1, 2], [0, 1, 2]])
-    assert_eq(dd, list(_ for _ in d))
-    assert_eq(ee, list(_ for _ in e))
-    assert_eq(ff, list(_ for _ in f))
-    assert_eq(gg, [0, 1, 2])
-
-
-def test_izip_broadcast2():
-    a = [1]
-    b = (np.sin,)
-    c = np.arange(3).reshape((1, 3))
-    aa = []; bb = []; cc = []
-    for a_, b_, c_ in izip_broadcast(a, b, c):
-        aa.append(a_)
-        bb.append(b_)
-        cc.append(c_)
-    assert_eq(aa, a)
-    assert_eq(tuple(bb), b)
-    assert_eq(cc, c)
-
-
 def test_least_greater_multiple():
     def func(lgm, expected):
         assert_eq(lgm, expected)
@@ -417,7 +372,7 @@ def test_settingerr():
 
 def test_split():
     def func(n, m):
-        slices = split(n, m)
+        slices = tuple(split(n, m))
         assert_eq(len(slices), m)
         x = np.zeros(n, int)
         for s in slices:
@@ -527,3 +482,61 @@ def test_timer3():
     time.sleep(0.01)
     assert_equal(t._level, 0)
     assert abs(t.elapsed - 0.01) < 0.001
+
+
+def test_zip_broadcast1():
+    def g():
+        i = 0
+        while True:
+            yield i
+            i += 1
+    a = [1]
+    b = (np.sin,)
+    c = np.arange(3).reshape((1, 3))
+    d = ('x', 'y', [])
+    e = ['a', 'b', 'c']
+    f = np.arange(6).reshape((3, 2))
+    h = 3
+
+    aa = []; bb = []; cc = []; dd = []; ee = []; ff = []; gg = []; hh = []
+    for a_, b_, c_, d_, e_, f_, g_, h_ in zip_broadcast(
+            a, b, c, d, e, f, g(), h):
+        aa.append(a_)
+        bb.append(b_)
+        cc.append(c_)
+        dd.append(d_)
+        ee.append(e_)
+        ff.append(f_)
+        gg.append(g_)
+        hh.append(h_)
+    assert_eq(aa, 3 * a)
+    assert_eq(bb, list(3 * b))
+    assert_eq(cc, [[0, 1, 2], [0, 1, 2], [0, 1, 2]])
+    assert_eq(dd, list(_ for _ in d))
+    assert_eq(ee, list(_ for _ in e))
+    assert_eq(ff, list(_ for _ in f))
+    assert_eq(gg, [0, 1, 2])
+    assert_eq(hh, [3, 3, 3])
+
+
+def test_zip_broadcast2():
+    a = [1]
+    b = (np.sin,)
+    c = np.arange(3).reshape((1, 3))
+    aa = []; bb = []; cc = []
+    for a_, b_, c_ in zip_broadcast(a, b, c):
+        aa.append(a_)
+        bb.append(b_)
+        cc.append(c_)
+    assert_eq(aa, a)
+    assert_eq(tuple(bb), b)
+    assert_eq(cc, c)
+
+
+def test_zip_broadcast3():
+    a = 'abc'
+    b = [1, 2, 3]
+    assert_eq(tuple(zip_broadcast(a, b)), tuple(zip(a, b)))
+    assert_eq(tuple(zip_broadcast(a, b, iter_str=True)), tuple(zip(a, b)))
+    assert_eq(tuple(zip_broadcast(a, b, iter_str=False)),
+              (('abc', 1), ('abc', 2), ('abc', 3)))

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/pyoperators.git



More information about the debian-science-commits mailing list