[med-svn] [python-mne] 112/376: improve speed of cluster permutation WIP

Yaroslav Halchenko debian at onerussian.com
Fri Nov 27 17:22:18 UTC 2015


This is an automated email from the git hooks/post-receive script.

yoh pushed a commit to annotated tag v0.1
in repository python-mne.

commit e123ba48feab4d056c3bcdb6dd3f2f0b771c9225
Author: Alexandre Gramfort <alexandre.gramfort at inria.fr>
Date:   Sun Mar 6 22:35:02 2011 -0500

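    improve speed of cluster permutation WIP

    Speed up permutation_1d_cluster_test:
    - find clusters with scipy.ndimage.label/find_objects; _find_clusters
      now returns the slice tuples produced by find_objects instead of
      (start, stop) index pairs
    - shuffle the pooled data array in place on each permutation instead
      of splitting a fresh random index permutation into per-condition lists
    - compute F-values with scikits.learn's f_oneway (adds a dependency on
      scikits.learn)
    - add a __main__ demo on simulated two-condition data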
---
 mne/stats/cluster_level.py | 99 ++++++++++++++++++++++++++++++++++------------
 1 file changed, 74 insertions(+), 25 deletions(-)

diff --git a/mne/stats/cluster_level.py b/mne/stats/cluster_level.py
index b93c3f7..70d9973 100644
--- a/mne/stats/cluster_level.py
+++ b/mne/stats/cluster_level.py
@@ -7,14 +7,24 @@
 # License: Simplified BSD
 
 import numpy as np
-from scipy import stats
+from scipy import stats, ndimage
 from scipy.stats import percentileofscore
-
+from scikits.learn.feature_selection import univariate_selection
 
 def f_oneway(*args):
     """Call scipy.stats.f_oneway, but return only f-value"""
-    return stats.f_oneway(*args)[0]
-
+    return univariate_selection.f_oneway(*args)[0]
+    # return stats.f_oneway(*args)[0]
+
+# def best_component(x, threshold, tail=0):
+#     if tail == -1:
+#         x_in = x < threshold
+#     elif tail == 1:
+#         x_in = x > threshold
+#     else:
+#         x_in = np.abs(x) > threshold
+#     labels, n_labels = ndimage.label(x_in)
+#     return np.max(ndimage.measurements.sum(x, labels, index=range(1, n_labels+1)))
 
 def _find_clusters(x, threshold, tail=0):
     """For a given 1d-array (test statistic), find all clusters which
@@ -52,19 +62,15 @@ def _find_clusters(x, threshold, tail=0):
 
     x = np.asanyarray(x)
 
-    x = np.concatenate([np.array([threshold]), x, np.array([threshold])])
     if tail == -1:
-        x_in = (x < threshold).astype(np.int)
+        x_in = x < threshold
     elif tail == 1:
-        x_in = (x > threshold).astype(np.int)
+        x_in = x > threshold
     else:
-        x_in = (np.abs(x) > threshold).astype(np.int)
+        x_in = np.abs(x) > threshold
 
-    x_switch = np.diff(x_in)
-    in_points = np.where(x_switch > 0)[0]
-    out_points = np.where(x_switch < 0)[0]
-    clusters = zip(in_points, out_points)
-    return clusters
+    labels, n_labels = ndimage.label(x_in)
+    return ndimage.find_objects(labels, n_labels)
 
 
 def permutation_1d_cluster_test(X, stat_fun=f_oneway, threshold=1.67,
@@ -119,25 +125,26 @@ def permutation_1d_cluster_test(X, stat_fun=f_oneway, threshold=1.67,
 
     clusters = _find_clusters(T_obs, threshold, tail)
 
-    splits_idx = np.cumsum(n_samples_per_condition)[:-1]
+    # make list of indices for random data split
+    splits_idx = np.append([0], np.cumsum(n_samples_per_condition))
+    slices = [slice(splits_idx[k], splits_idx[k+1])
+                                                for k in range(len(X))]
+
     # Step 2: If we have some clusters, repeat process on permutated data
     # -------------------------------------------------------------------
     if len(clusters) > 0:
-        cluster_stats = [np.sum(T_obs[c[0]:c[1]]) for c in clusters]
+        cluster_stats = [np.sum(T_obs[c]) for c in clusters]
         cluster_pv = np.ones(len(clusters), dtype=np.float)
         H0 = np.zeros(n_permutations) # histogram
         for i_s in range(n_permutations):
-            # make list of indices for random data split
-            indices_lists = np.split(np.random.permutation(n_samples_total),
-                                     splits_idx)
-
-            X_shuffle_list = [X_full[indices] for indices in indices_lists]
+            np.random.shuffle(X_full)
+            X_shuffle_list = [X_full[s] for s in slices]
             T_obs_surr = stat_fun(*X_shuffle_list)
             clusters_perm = _find_clusters(T_obs_surr, threshold, tail)
 
             if len(clusters_perm) > 0:
-                cluster_stats_perm = [np.sum(T_obs_surr[c[0]:c[1]])
-                                      for c in clusters_perm]
+                cluster_stats_perm = [np.sum(T_obs_surr[c])
+                                              for c in clusters_perm]
                 H0[i_s] = max(cluster_stats_perm)
             else:
                 H0[i_s] = 0
@@ -145,9 +152,8 @@ def permutation_1d_cluster_test(X, stat_fun=f_oneway, threshold=1.67,
         # for each cluster in original data, calculate p-value as percentile
         # of its cluster statistics within all cluster statistics in surrogate
         # data
-        cluster_pv[:] = [percentileofscore(H0,
-                                           cluster_stats[i_cl])
-                          for i_cl in range(len(clusters))]
+        cluster_pv[:] = [percentileofscore(H0, cluster_stats[i_cl])
+                                             for i_cl in range(len(clusters))]
         cluster_pv[:] = (100.0 - cluster_pv[:]) / 100.0 # from pct to fraction
         return T_obs, clusters, cluster_pv, H0
     else:
@@ -155,3 +161,46 @@ def permutation_1d_cluster_test(X, stat_fun=f_oneway, threshold=1.67,
 
 
 permutation_1d_cluster_test.__test__ = False
+
+if __name__ == "__main__":
+    noiselevel = 30
+
+    normfactor = np.hanning(20).sum()
+
+    condition1 = np.random.randn(50, 500) * noiselevel
+    for i in range(50):
+        condition1[i] = np.convolve(condition1[i], np.hanning(20),
+                                      mode="same") / normfactor
+
+    condition2 = np.random.randn(43, 500) * noiselevel
+    for i in range(43):
+        condition2[i] = np.convolve(condition2[i], np.hanning(20),
+                                      mode="same") / normfactor
+
+    pseudoekp = 5 * np.hanning(150)[None,:]
+    condition1[:, 100:250] += pseudoekp
+    condition2[:, 100:250] -= pseudoekp
+
+    fs, cluster_times, cluster_p_values, histogram = permutation_1d_cluster_test(
+                                [condition1, condition2], n_permutations=1000)
+
+    # # Plotting for a better understanding
+    # import pylab as pl
+    # pl.close('all')
+    # pl.subplot(211)
+    # pl.plot(condition1.mean(axis=0), label="Condition 1")
+    # pl.plot(condition2.mean(axis=0), label="Condition 2")
+    # pl.ylabel("signal [a.u.]")
+    # pl.subplot(212)
+    # for i_c, c in enumerate(cluster_times):
+    #     c = c[0]
+    #     if cluster_p_values[i_c] <= 0.05:
+    #         h = pl.axvspan(c.start, c.stop-1, color='r', alpha=0.3)
+    #     else:
+    #         pl.axvspan(c.start, c.stop-1, color=(0.3, 0.3, 0.3), alpha=0.3)
+    # hf = pl.plot(fs, 'g')
+    # pl.legend((h, ), ('cluster p-value < 0.05', ))
+    # pl.xlabel("time (ms)")
+    # pl.ylabel("f-values")
+    # pl.show()
+

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git



More information about the debian-med-commit mailing list