[med-svn] [python-mne] 282/376: ENH : adding cluster level stats with connectivity matrix (for mesh)

Yaroslav Halchenko debian at onerussian.com
Fri Nov 27 17:23:06 UTC 2015


This is an automated email from the git hooks/post-receive script.

yoh pushed a commit to annotated tag v0.1
in repository python-mne.

commit 6f1c7d63f0f3a565be0e6b038eead3d0b2091b0a
Author: Alexandre Gramfort <alexandre.gramfort at inria.fr>
Date:   Tue May 31 20:08:48 2011 -0400

    ENH : adding cluster level stats with connectivity matrix (for mesh)
---
 mne/stats/cluster_level.py            | 68 +++++++++++++++++++++++++----------
 mne/stats/tests/test_cluster_level.py | 29 ++++++++++-----
 2 files changed, 70 insertions(+), 27 deletions(-)

diff --git a/mne/stats/cluster_level.py b/mne/stats/cluster_level.py
index aa82aec..5bd3e49 100644
--- a/mne/stats/cluster_level.py
+++ b/mne/stats/cluster_level.py
@@ -7,13 +7,12 @@
 # License: Simplified BSD
 
 import numpy as np
-from scipy import ndimage
-from scipy import stats
+from scipy import stats, sparse, ndimage
 
 from .parametric import f_oneway
 
 
-def _find_clusters(x, threshold, tail=0):
+def _find_clusters(x, threshold, tail=0, connectivity=None):
     """For a given 1d-array (test statistic), find all clusters which
     are above/below a certain threshold. Returns a list of 2-tuples.
 
@@ -49,18 +48,39 @@ def _find_clusters(x, threshold, tail=0):
 
     labels, n_labels = ndimage.label(x_in)
 
-    if x.ndim == 1:
-        clusters = ndimage.find_objects(labels, n_labels)
-        sums = ndimage.measurements.sum(x, labels,
-                                        index=range(1, n_labels + 1))
+    if connectivity is None:
+        if x.ndim == 1:
+            clusters = ndimage.find_objects(labels, n_labels)
+            sums = ndimage.measurements.sum(x, labels,
+                                            index=range(1, n_labels + 1))
+        else:
+            clusters = list()
+            sums = np.empty(n_labels)
+            for l in range(1, n_labels + 1):
+                c = labels == l
+                clusters.append(c)
+                sums[l - 1] = np.sum(x[c])
     else:
+        if x.ndim > 1:
+            raise Exception("Data should be 1D when using a connectivity "
+                            "to define clusters.")
+        from scikits.learn.utils._csgraph import cs_graph_components
+        mask = np.logical_and(x_in[connectivity.row], x_in[connectivity.col])
+        if np.sum(mask) == 0:
+            return [], np.empty(0)
+        connectivity = sparse.coo_matrix((connectivity.data[mask],
+                                         (connectivity.row[mask],
+                                          connectivity.col[mask])))
+        _, components = cs_graph_components(connectivity)
+        labels = np.unique(components)
         clusters = list()
-        sums = np.empty(n_labels)
-        for l in range(1, n_labels + 1):
-            c = labels == l
-            clusters.append(c)
-            sums[l - 1] = np.sum(x[c])
-
+        sums = list()
+        for l in labels:
+            c = (components == l)
+            if np.any(x_in[c]):
+                clusters.append(c)
+                sums.append(np.sum(x[c]))
+        sums = np.array(sums)
     return clusters, sums
 
 
@@ -86,7 +106,8 @@ def _pval_from_histogram(T, H0, tail):
 
 
 def permutation_cluster_test(X, stat_fun=f_oneway, threshold=1.67,
-                             n_permutations=1000, tail=0):
+                             n_permutations=1000, tail=0,
+                             connectivity=None):
     """Cluster-level statistical permutation test
 
     For a list of 2d-arrays of data, e.g. power values, calculate some
@@ -111,6 +132,10 @@ def permutation_cluster_test(X, stat_fun=f_oneway, threshold=1.67,
         If tail is -1, the statistic is thresholded below threshold.
         If tail is 0, the statistic is thresholded on both sides of
         the distribution.
+    connectivity : sparse matrix.
+        Defines connectivity between features. The matrix is assumed to
+        be symmetric and only the upper triangular half is used.
+        Default is None, i.e., no connectivity.
 
     Returns
     -------
@@ -139,7 +164,8 @@ def permutation_cluster_test(X, stat_fun=f_oneway, threshold=1.67,
     # -------------------------------------------------------------
     T_obs = stat_fun(*X)
 
-    clusters, cluster_stats = _find_clusters(T_obs, threshold, tail)
+    clusters, cluster_stats = _find_clusters(T_obs, threshold, tail,
+                                             connectivity)
 
     # make list of indices for random data split
     splits_idx = np.append([0], np.cumsum(n_samples_per_condition))
@@ -154,7 +180,8 @@ def permutation_cluster_test(X, stat_fun=f_oneway, threshold=1.67,
             np.random.shuffle(X_full)
             X_shuffle_list = [X_full[s] for s in slices]
             T_obs_surr = stat_fun(*X_shuffle_list)
-            _, perm_clusters_sums = _find_clusters(T_obs_surr, threshold, tail)
+            _, perm_clusters_sums = _find_clusters(T_obs_surr, threshold, tail,
+                                                   connectivity)
 
             if len(perm_clusters_sums) > 0:
                 H0[i_s] = np.max(perm_clusters_sums)
@@ -178,7 +205,8 @@ def ttest_1samp(X):
 
 
 def permutation_cluster_1samp_test(X, threshold=1.67, n_permutations=1000,
-                                   tail=0, stat_fun=ttest_1samp):
+                                   tail=0, stat_fun=ttest_1samp,
+                                   connectivity=None):
     """Non-parametric cluster-level 1 sample T-test
 
     From a array of observations, e.g. signal amplitudes or power spectrum
@@ -230,7 +258,8 @@ def permutation_cluster_1samp_test(X, threshold=1.67, n_permutations=1000,
     # -------------------------------------------------------------
     T_obs = stat_fun(X)
 
-    clusters, cluster_stats = _find_clusters(T_obs, threshold, tail)
+    clusters, cluster_stats = _find_clusters(T_obs, threshold, tail,
+                                             connectivity)
 
     # Step 2: If we have some clusters, repeat process on permuted data
     # -------------------------------------------------------------------
@@ -243,7 +272,8 @@ def permutation_cluster_1samp_test(X, threshold=1.67, n_permutations=1000,
 
             # Recompute statistic on randomized data
             T_obs_surr = stat_fun(X_copy)
-            _, perm_clusters_sums = _find_clusters(T_obs_surr, threshold, tail)
+            _, perm_clusters_sums = _find_clusters(T_obs_surr, threshold, tail,
+                                                   connectivity)
 
             if len(perm_clusters_sums) > 0:
                 idx_max = np.argmax(np.abs(perm_clusters_sums))
diff --git a/mne/stats/tests/test_cluster_level.py b/mne/stats/tests/test_cluster_level.py
index a92130f..f747325 100644
--- a/mne/stats/tests/test_cluster_level.py
+++ b/mne/stats/tests/test_cluster_level.py
@@ -16,15 +16,13 @@ condition2 = np.random.randn(33, 350) * noiselevel
 for c in condition2:
     c[:] = np.convolve(c, np.hanning(20), mode="same") / normfactor
 
-pseudoekp = 5 * np.hanning(150)[None,:]
+pseudoekp = 5 * np.hanning(150)[None, :]
 condition1[:, 100:250] += pseudoekp
 condition2[:, 100:250] -= pseudoekp
 
 
 def test_cluster_permutation_test():
-    """Test cluster level permutations tests.
-    """
-
+    """Test cluster level permutations tests."""
     T_obs, clusters, cluster_p_values, hist = permutation_cluster_test(
                                 [condition1, condition2], n_permutations=500,
                                 tail=1)
@@ -37,10 +35,8 @@ def test_cluster_permutation_test():
 
 
 def test_cluster_permutation_t_test():
-    """Test cluster level permutations T-test.
-    """
-
-    my_condition1 = condition1[:,:,None] # to test 2D also
+    """Test cluster level permutations T-test."""
+    my_condition1 = condition1[:, :, None]  # to test 2D also
     T_obs, clusters, cluster_p_values, hist = permutation_cluster_1samp_test(
                                 my_condition1, n_permutations=500, tail=0)
     assert_equal(np.sum(cluster_p_values < 0.05), 1)
@@ -55,3 +51,20 @@ def test_cluster_permutation_t_test():
     assert_array_equal(T_obs_pos, -T_obs_neg)
     assert_array_equal(cluster_p_values_pos < 0.05,
                        cluster_p_values_neg < 0.05)
+
+
+def test_cluster_permutation_t_test_with_connectivity():
+    """Check a grid_to_graph connectivity reproduces the default clusters."""
+    try:
+        from scikits.learn.feature_extraction.image import grid_to_graph
+    except ImportError:
+        pass  # scikits.learn unavailable: test is skipped silently
+    else:
+        out = permutation_cluster_1samp_test(condition1, n_permutations=500)
+        connectivity = grid_to_graph(1, condition1.shape[1])
+        out_connectivity = permutation_cluster_1samp_test(condition1,
+                                 n_permutations=500, connectivity=connectivity)
+        assert_array_equal(out[0], out_connectivity[0])  # same T_obs
+        for a, b in zip(out_connectivity[1], out[1]):
+            assert np.sum(out[0][a]) == np.sum(out[0][b])  # same cluster sum
+            assert np.all(a[b])  # mask a covers cluster b's extent

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git



More information about the debian-med-commit mailing list