[med-svn] [python-mne] 294/376: factoring joblib parallel code

Yaroslav Halchenko debian at onerussian.com
Fri Nov 27 17:23:08 UTC 2015


This is an automated email from the git hooks/post-receive script.

yoh pushed a commit to annotated tag v0.1
in repository python-mne.

commit bda9bd53e3382d342c66ad5d6548e98e9cd82515
Author: Alexandre Gramfort <alexandre.gramfort at inria.fr>
Date:   Wed Jun 8 12:16:05 2011 -0400

    factoring joblib parallel code
---
 examples/stats/plot_cluster_stats_evoked.py |  3 +-
 mne/minimum_norm/time_frequency.py          | 19 +-----
 mne/parallel.py                             | 50 +++++++++++++++
 mne/stats/cluster_level.py                  | 95 ++++++++++++++++++-----------
 mne/stats/permutations.py                   | 20 +-----
 mne/time_frequency/tfr.py                   | 27 +++-----
 6 files changed, 123 insertions(+), 91 deletions(-)

diff --git a/examples/stats/plot_cluster_stats_evoked.py b/examples/stats/plot_cluster_stats_evoked.py
index 34a4ad1..63d344c 100644
--- a/examples/stats/plot_cluster_stats_evoked.py
+++ b/examples/stats/plot_cluster_stats_evoked.py
@@ -58,7 +58,8 @@ condition2 = condition2[:, 0, :]  # take only one channel to get a 2D array
 threshold = 6.0
 T_obs, clusters, cluster_p_values, H0 = \
                 permutation_cluster_test([condition1, condition2],
-                            n_permutations=1000, threshold=threshold, tail=1)
+                            n_permutations=1000, threshold=threshold, tail=1,
+                            n_jobs=2)
 
 ###############################################################################
 # Plot
diff --git a/mne/minimum_norm/time_frequency.py b/mne/minimum_norm/time_frequency.py
index 1262669..b36adce 100644
--- a/mne/minimum_norm/time_frequency.py
+++ b/mne/minimum_norm/time_frequency.py
@@ -10,6 +10,7 @@ from ..source_estimate import SourceEstimate
 from ..time_frequency.tfr import cwt, morlet
 from ..baseline import rescale
 from .inverse import combine_xyz, prepare_inverse_operator
+from ..parallel import parallel_func
 
 
 def _compute_power(data, K, sel, Ws, source_ori, use_fft, Vh):
@@ -89,23 +90,7 @@ def source_induced_power(epochs, inverse_operator, bands, lambda2=1.0 / 9.0,
         Number of jobs to run in parallel
     """
 
-    if n_jobs == -1:
-        try:
-            import multiprocessing
-            n_jobs = multiprocessing.cpu_count()
-        except ImportError:
-            print "multiprocessing not installed. Cannot run in parallel."
-            n_jobs = 1
-
-    try:
-        from scikits.learn.externals.joblib import Parallel, delayed
-        parallel = Parallel(n_jobs)
-        my_compute_power = delayed(_compute_power)
-    except ImportError:
-        print "joblib not installed. Cannot run in parallel."
-        n_jobs = 1
-        my_compute_power = _compute_power
-        parallel = list
+    parallel, my_compute_power, n_jobs = parallel_func(_compute_power, n_jobs)
 
     #
     #   Set up the inverse according to the parameters
diff --git a/mne/parallel.py b/mne/parallel.py
new file mode 100644
index 0000000..12e4164
--- /dev/null
+++ b/mne/parallel.py
@@ -0,0 +1,50 @@
+"""Parallel utility function
+"""
+
+# Author: Alexandre Gramfort <gramfort at nmr.mgh.harvard.edu>
+#
+# License: Simplified BSD
+
+
+def parallel_func(func, n_jobs, verbose=5):
+    """Return a joblib Parallel object and delayed func, or serial fallbacks
+
+    Uses joblib only if it can be imported; otherwise calls run serially.
+
+    Parameters
+    ----------
+    func: callable
+        The function to run, wrapped with joblib's delayed when available
+    n_jobs: int
+        Number of jobs to run in parallel (-1 means use all CPUs)
+    verbose: int
+        Verbosity level passed to joblib.Parallel
+
+    Returns
+    -------
+    parallel: instance of joblib.Parallel or list
+        Call it on a generator of my_func(...) calls to get the results
+    my_func: callable
+        delayed(func) if joblib is available, else func itself
+    n_jobs: int
+        n_jobs with -1 replaced by the CPU count; 1 if joblib is missing
+    """
+    try:
+        from scikits.learn.externals.joblib import Parallel, delayed
+        parallel = Parallel(n_jobs, verbose=verbose)  # gets the raw n_jobs
+        my_func = delayed(func)
+
+        if n_jobs == -1:  # resolve to a concrete count for the caller
+            try:
+                import multiprocessing
+                n_jobs = multiprocessing.cpu_count()
+            except ImportError:
+                print "multiprocessing not installed. Cannot run in parallel."
+                n_jobs = 1  # NOTE(review): Parallel above still got -1
+
+    except ImportError:  # joblib missing: fall back to serial execution
+        print "joblib not installed. Cannot run in parallel."
+        n_jobs = 1
+        my_func = func
+        parallel = list  # list(generator) runs the calls one by one
+    return parallel, my_func, n_jobs
diff --git a/mne/stats/cluster_level.py b/mne/stats/cluster_level.py
index c1bda6c..eea5e10 100644
--- a/mne/stats/cluster_level.py
+++ b/mne/stats/cluster_level.py
@@ -10,6 +10,7 @@ import numpy as np
 from scipy import stats, sparse, ndimage
 
 from .parametric import f_oneway
+from ..parallel import parallel_func
 
 
 def _get_components(x_in, connectivity):
@@ -123,9 +124,23 @@ def _pval_from_histogram(T, H0, tail):
     return pval
 
 
+def _one_permutation(X_full, slices, stat_fun, tail, threshold, connectivity):
+    np.random.shuffle(X_full)  # in place, from the global np.random state
+    X_shuffle_list = [X_full[s] for s in slices]  # split back into groups
+    T_obs_surr = stat_fun(*X_shuffle_list)
+    _, perm_clusters_sums = _find_clusters(T_obs_surr, threshold, tail,
+                                           connectivity)
+    # NOTE(review): forked joblib workers may draw identical shuffles; confirm
+    if len(perm_clusters_sums) > 0:  # largest cluster stat goes into H0
+        return np.max(perm_clusters_sums)
+    else:
+        return 0  # no cluster in this permutation
+
+
 def permutation_cluster_test(X, stat_fun=f_oneway, threshold=1.67,
                              n_permutations=1000, tail=0,
-                             connectivity=None, verbose=True):
+                             connectivity=None, n_jobs=1,
+                             verbose=5):
     """Cluster-level statistical permutation test
 
     For a list of 2d-arrays of data, e.g. power values, calculate some
@@ -154,8 +169,10 @@ def permutation_cluster_test(X, stat_fun=f_oneway, threshold=1.67,
         Defines connectivity between features. The matrix is assumed to
         be symmetric and only the upper triangular half is used.
         Defaut is None, i.e, no connectivity.
-    verbose: boolean
-        If True print some text.
+    verbose : int
+        If > 0, print some text during computation.
+    n_jobs : int
+        Number of permutations to run in parallel (requires joblib package.)
 
     Returns
     -------
@@ -195,24 +212,16 @@ def permutation_cluster_test(X, stat_fun=f_oneway, threshold=1.67,
     slices = [slice(splits_idx[k], splits_idx[k + 1])
                                                 for k in range(len(X))]
 
+    parallel, my_one_permutation, _ = parallel_func(_one_permutation, n_jobs,
+                                                 verbose)
+
     # Step 2: If we have some clusters, repeat process on permuted data
     # -------------------------------------------------------------------
     if len(clusters) > 0:
-        H0 = np.zeros(n_permutations)  # histogram
-        for i_s in range(n_permutations):
-            if verbose:
-                print "Permutation %d / %d" % (i_s + 1, n_permutations) 
-            np.random.shuffle(X_full)
-            X_shuffle_list = [X_full[s] for s in slices]
-            T_obs_surr = stat_fun(*X_shuffle_list)
-            _, perm_clusters_sums = _find_clusters(T_obs_surr, threshold, tail,
-                                                   connectivity)
-
-            if len(perm_clusters_sums) > 0:
-                H0[i_s] = np.max(perm_clusters_sums)
-            else:
-                H0[i_s] = 0
-
+        H0 = parallel(my_one_permutation(X_full, slices, stat_fun, tail,
+                                threshold, connectivity)
+                                for _ in range(n_permutations))
+        H0 = np.array(H0)
         cluster_pv = _pval_from_histogram(cluster_stats, H0, tail)
         return T_obs, clusters, cluster_pv, H0
     else:
@@ -229,9 +238,28 @@ def ttest_1samp(X):
     return T
 
 
+def _one_1samp_permutation(n_samples, shape_ones, X_copy, threshold, tail,
+                           connectivity, stat_fun):
+    # new surrogate data: random +1/-1 signs of shape (n_samples, *shape_ones)
+    signs = np.sign(0.5 - np.random.rand(n_samples, *shape_ones))
+    X_copy *= signs  # in place, so serial calls compound the flips
+    # NOTE(review): forked joblib workers may inherit one RNG state; confirm
+    # Recompute statistic on randomized data
+    T_obs_surr = stat_fun(X_copy)
+    _, perm_clusters_sums = _find_clusters(T_obs_surr, threshold, tail,
+                                           connectivity)
+
+    if len(perm_clusters_sums) > 0:  # keep signed stat of largest |value|
+        idx_max = np.argmax(np.abs(perm_clusters_sums))
+        return perm_clusters_sums[idx_max]  # get max with sign info
+    else:
+        return 0.0  # no cluster in this permutation
+
+
 def permutation_cluster_1samp_test(X, threshold=1.67, n_permutations=1000,
                                    tail=0, stat_fun=ttest_1samp,
-                                   connectivity=None):
+                                   connectivity=None, n_jobs=1,
+                                   verbose=5):
     """Non-parametric cluster-level 1 sample T-test
 
     From a array of observations, e.g. signal amplitudes or power spectrum
@@ -259,6 +287,11 @@ def permutation_cluster_1samp_test(X, threshold=1.67, n_permutations=1000,
         Defines connectivity between features. The matrix is assumed to
         be symmetric and only the upper triangular half is used.
         Defaut is None, i.e, no connectivity.
+    verbose : int
+        If > 0, print some text during computation.
+    n_jobs : int
+        Number of permutations to run in parallel (requires joblib package.)
+
 
     Returns
     -------
@@ -294,26 +327,16 @@ def permutation_cluster_1samp_test(X, threshold=1.67, n_permutations=1000,
     clusters, cluster_stats = _find_clusters(T_obs, threshold, tail,
                                              connectivity)
 
+    parallel, my_one_1samp_permutation, _ = parallel_func(_one_1samp_permutation,
+                                                       n_jobs, verbose)
+
     # Step 2: If we have some clusters, repeat process on permuted data
     # -------------------------------------------------------------------
     if len(clusters) > 0:
-        H0 = np.empty(n_permutations)  # histogram
-        for i_s in range(n_permutations):
-            # new surrogate data with random sign flip
-            signs = np.sign(0.5 - np.random.rand(n_samples, *shape_ones))
-            X_copy *= signs
-
-            # Recompute statistic on randomized data
-            T_obs_surr = stat_fun(X_copy)
-            _, perm_clusters_sums = _find_clusters(T_obs_surr, threshold, tail,
-                                                   connectivity)
-
-            if len(perm_clusters_sums) > 0:
-                idx_max = np.argmax(np.abs(perm_clusters_sums))
-                H0[i_s] = perm_clusters_sums[idx_max]  # get max with sign info
-            else:
-                H0[i_s] = 0
-
+        H0 = parallel(my_one_1samp_permutation(n_samples, shape_ones, X_copy,
+                                    threshold, tail, connectivity, stat_fun)
+                                    for _ in range(n_permutations))
+        H0 = np.array(H0)
         cluster_pv = _pval_from_histogram(cluster_stats, H0, tail)
 
         return T_obs, clusters, cluster_pv, H0
diff --git a/mne/stats/permutations.py b/mne/stats/permutations.py
index 59c9e74..39e141c 100644
--- a/mne/stats/permutations.py
+++ b/mne/stats/permutations.py
@@ -9,6 +9,8 @@
 from math import sqrt
 import numpy as np
 
+from ..parallel import parallel_func
+
 
 def bin_perm_rep(ndim, a=0, b=1):
     """bin_perm_rep(ndim) -> ndim permutations with repetitions of (a,b).
@@ -128,23 +130,7 @@ def permutation_t_test(X, n_permutations=10000, tail=0, n_jobs=1):
     else:
         perms = np.sign(0.5 - np.random.rand(n_permutations, n_samples))
 
-    try:
-        from scikits.learn.externals.joblib import Parallel, delayed
-        parallel = Parallel(n_jobs)
-        my_max_stat = delayed(_max_stat)
-    except ImportError:
-        print "joblib not installed. Cannot run in parallel."
-        n_jobs = 1
-        my_max_stat = _max_stat
-        parallel = list
-
-    if n_jobs == -1:
-        try:
-            import multiprocessing
-            n_jobs = multiprocessing.cpu_count()
-        except ImportError:
-            print "multiprocessing not installed. Cannot run in parallel."
-            n_jobs = 1
+    parallel, my_max_stat, n_jobs = parallel_func(_max_stat, n_jobs)
 
     max_abs = np.concatenate(parallel(my_max_stat(X, X2, p, dof_scaling)
                                       for p in np.array_split(perms, n_jobs)))
diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py
index ff97ed4..6af02ec 100644
--- a/mne/time_frequency/tfr.py
+++ b/mne/time_frequency/tfr.py
@@ -12,6 +12,7 @@ import numpy as np
 from scipy import linalg
 from scipy.fftpack import fftn, ifftn
 from ..baseline import rescale
+from ..parallel import parallel_func
 
 
 def morlet(Fs, freqs, n_cycles=7, sigma=None):
@@ -276,15 +277,7 @@ def single_trial_power(epochs, Fs, frequencies, use_fft=True, n_cycles=7,
     # Precompute wavelets for given frequency range to save time
     Ws = morlet(Fs, frequencies, n_cycles=n_cycles)
 
-    try:
-        from scikits.learn.externals.joblib import Parallel, delayed
-        parallel = Parallel(n_jobs)
-        my_cwt = delayed(cwt)
-    except ImportError:
-        print "joblib not installed. Cannot run in parallel."
-        n_jobs = 1
-        my_cwt = cwt
-        parallel = list
+    parallel, my_cwt, _ = parallel_func(cwt, n_jobs)
 
     print "Computing time-frequency power on single epochs..."
 
@@ -347,13 +340,9 @@ def induced_power(epochs, Fs, frequencies, use_fft=True, n_cycles=7,
     # Precompute wavelets for given frequency range to save time
     Ws = morlet(Fs, frequencies, n_cycles=n_cycles)
 
-    try:
-        import joblib
-    except ImportError:
-        print "joblib not installed. Cannot run in parallel."
-        n_jobs = 1
+    parallel, my_time_frequency, _ = parallel_func(_time_frequency, n_jobs)
 
-    if n_jobs == 1:
+    if my_time_frequency is _time_frequency:  # not parallel
         psd = np.empty((n_channels, n_frequencies, n_times))
         plf = np.empty((n_channels, n_frequencies, n_times), dtype=np.complex)
 
@@ -362,11 +351,9 @@ def induced_power(epochs, Fs, frequencies, use_fft=True, n_cycles=7,
             psd[c], plf[c] = _time_frequency(X, Ws, use_fft)
 
     else:
-        from joblib import Parallel, delayed
-        psd_plf = Parallel(n_jobs=n_jobs)(
-                    delayed(_time_frequency)(
-                            np.squeeze(epochs[:, c, :]), Ws, use_fft)
-                    for c in range(n_channels))
+        psd_plf = parallel(my_time_frequency(np.squeeze(epochs[:, c, :]),
+                                             Ws, use_fft)
+                           for c in range(n_channels))
 
         psd = np.zeros((n_channels, n_frequencies, n_times))
         plf = np.zeros((n_channels, n_frequencies, n_times), dtype=np.complex)

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git



More information about the debian-med-commit mailing list