[med-svn] [python-mne] 113/376: towards 2D permutation cluster stats WIP

Yaroslav Halchenko debian at onerussian.com
Fri Nov 27 17:22:18 UTC 2015


This is an automated email from the git hooks/post-receive script.

yoh pushed a commit to annotated tag v0.1
in repository python-mne.

commit b9017b33da5fb95dca197d2e8cddb05493ec03dd
Author: Alexandre Gramfort <alexandre.gramfort at inria.fr>
Date:   Mon Mar 7 11:10:10 2011 -0500

    towards 2D permutation cluster stats WIP
---
 examples/stats/plot_cluster_stats_evoked.py |   4 +-
 mne/stats/__init__.py                       |   2 +-
 mne/stats/cluster_level.py                  | 180 +++++++++++++++-------------
 mne/stats/parametric.py                     |  79 ++++++++++++
 mne/stats/tests/test_cluster_level.py       |   8 +-
 mne/tfr.py                                  |  24 ++++
 6 files changed, 205 insertions(+), 92 deletions(-)

diff --git a/examples/stats/plot_cluster_stats_evoked.py b/examples/stats/plot_cluster_stats_evoked.py
index c6dd8fd..53dbe64 100644
--- a/examples/stats/plot_cluster_stats_evoked.py
+++ b/examples/stats/plot_cluster_stats_evoked.py
@@ -19,7 +19,7 @@ import numpy as np
 
 import mne
 from mne import fiff
-from mne.stats import permutation_1d_cluster_test
+from mne.stats import permutation_cluster_test
 from mne.datasets import sample
 
 ###############################################################################
@@ -55,7 +55,7 @@ condition2 = np.squeeze(np.array([d['epoch'] for d in data2])) # as 3D matrix
 # Compute statistic
 threshold = 6.0
 T_obs, clusters, cluster_p_values, H0 = \
-                permutation_1d_cluster_test([condition1, condition2],
+                permutation_cluster_test([condition1, condition2],
                             n_permutations=1000, threshold=threshold, tail=1)
 
 ###############################################################################
diff --git a/mne/stats/__init__.py b/mne/stats/__init__.py
index 4b782c8..2d810de 100644
--- a/mne/stats/__init__.py
+++ b/mne/stats/__init__.py
@@ -1,2 +1,2 @@
 from .permutations import permutation_t_test
-from .cluster_level import permutation_1d_cluster_test
+from .cluster_level import permutation_cluster_test
diff --git a/mne/stats/cluster_level.py b/mne/stats/cluster_level.py
index 70d9973..f1a5bbb 100644
--- a/mne/stats/cluster_level.py
+++ b/mne/stats/cluster_level.py
@@ -7,24 +7,11 @@
 # License: Simplified BSD
 
 import numpy as np
-from scipy import stats, ndimage
+from scipy import ndimage
 from scipy.stats import percentileofscore
-from scikits.learn.feature_selection import univariate_selection
-
-def f_oneway(*args):
-    """Call scipy.stats.f_oneway, but return only f-value"""
-    return univariate_selection.f_oneway(*args)[0]
-    # return stats.f_oneway(*args)[0]
-
-# def best_component(x, threshold, tail=0):
-#     if tail == -1:
-#         x_in = x < threshold
-#     elif tail == 1:
-#         x_in = x > threshold
-#     else:
-#         x_in = np.abs(x) > threshold
-#     labels, n_labels = ndimage.label(x_in)
-#     return np.max(ndimage.measurements.sum(x, labels, index=range(1, n_labels+1)))
+
+from .parametric import f_oneway
+
 
 def _find_clusters(x, threshold, tail=0):
     """For a given 1d-array (test statistic), find all clusters which
@@ -41,21 +28,12 @@ def _find_clusters(x, threshold, tail=0):
 
     Returns
     -------
-    clusters: list of tuples
-        Each tuple is a pair of indices (begin/end of cluster)
+    clusters: list of slices or list of arrays (boolean masks)
+        We use slices for 1D signals and mask to multidimensional
+        arrays.
 
-    Example
-    -------
-    >>> _find_clusters([1, 2, 3, 1], 1.9, tail=1)
-    [(1, 3)]
-    >>> _find_clusters([2, 2, 3, 1], 1.9, tail=1)
-    [(0, 3)]
-    >>> _find_clusters([1, 2, 3, 2], 1.9, tail=1)
-    [(1, 4)]
-    >>> _find_clusters([1, -2, 3, 1], 1.9, tail=0)
-    [(1, 3)]
-    >>> _find_clusters([1, -2, -3, 1], -1.9, tail=-1)
-    [(1, 3)]
+    sums: array
+        Sum of x values in clusters
     """
     if not tail in [-1, 0, 1]:
         raise ValueError('invalid tail parameter')
@@ -70,10 +48,22 @@ def _find_clusters(x, threshold, tail=0):
         x_in = np.abs(x) > threshold
 
     labels, n_labels = ndimage.label(x_in)
-    return ndimage.find_objects(labels, n_labels)
 
+    if x.ndim == 1:
+        clusters = ndimage.find_objects(labels, n_labels)
+        sums = ndimage.measurements.sum(x, labels, index=range(1, n_labels+1))
+    else:
+        clusters = list()
+        sums = np.empty(n_labels)
+        for l in range(1, n_labels+1):
+            c = labels == l
+            clusters.append(c)
+            sums[l-1] = np.sum(x[c])
+
+    return clusters, sums
 
-def permutation_1d_cluster_test(X, stat_fun=f_oneway, threshold=1.67,
+
+def permutation_cluster_test(X, stat_fun=f_oneway, threshold=1.67,
                              n_permutations=1000, tail=0):
     """Cluster-level statistical permutation test
 
@@ -116,14 +106,13 @@ def permutation_1d_cluster_test(X, stat_fun=f_oneway, threshold=1.67,
     doi:10.1016/j.jneumeth.2007.03.024
     """
     X_full = np.concatenate(X, axis=0)
-    n_samples_total = X_full.shape[0]
     n_samples_per_condition = [x.shape[0] for x in X]
 
     # Step 1: Calculate Anova (or other stat_fun) for original data
     # -------------------------------------------------------------
     T_obs = stat_fun(*X)
 
-    clusters = _find_clusters(T_obs, threshold, tail)
+    clusters, cluster_stats = _find_clusters(T_obs, threshold, tail)
 
     # make list of indices for random data split
     splits_idx = np.append([0], np.cumsum(n_samples_per_condition))
@@ -133,19 +122,16 @@ def permutation_1d_cluster_test(X, stat_fun=f_oneway, threshold=1.67,
     # Step 2: If we have some clusters, repeat process on permutated data
     # -------------------------------------------------------------------
     if len(clusters) > 0:
-        cluster_stats = [np.sum(T_obs[c]) for c in clusters]
         cluster_pv = np.ones(len(clusters), dtype=np.float)
         H0 = np.zeros(n_permutations) # histogram
         for i_s in range(n_permutations):
             np.random.shuffle(X_full)
             X_shuffle_list = [X_full[s] for s in slices]
             T_obs_surr = stat_fun(*X_shuffle_list)
-            clusters_perm = _find_clusters(T_obs_surr, threshold, tail)
+            _, perm_clusters_sums = _find_clusters(T_obs_surr, threshold, tail)
 
-            if len(clusters_perm) > 0:
-                cluster_stats_perm = [np.sum(T_obs_surr[c])
-                                              for c in clusters_perm]
-                H0[i_s] = max(cluster_stats_perm)
+            if len(perm_clusters_sums) > 0:
+                H0[i_s] = np.max(perm_clusters_sums)
             else:
                 H0[i_s] = 0
 
@@ -160,47 +146,71 @@ def permutation_1d_cluster_test(X, stat_fun=f_oneway, threshold=1.67,
         return T_obs, np.array([]), np.array([]), np.array([])
 
 
-permutation_1d_cluster_test.__test__ = False
-
-if __name__ == "__main__":
-    noiselevel = 30
-
-    normfactor = np.hanning(20).sum()
-
-    condition1 = np.random.randn(50, 500) * noiselevel
-    for i in range(50):
-        condition1[i] = np.convolve(condition1[i], np.hanning(20),
-                                      mode="same") / normfactor
-
-    condition2 = np.random.randn(43, 500) * noiselevel
-    for i in range(43):
-        condition2[i] = np.convolve(condition2[i], np.hanning(20),
-                                      mode="same") / normfactor
-
-    pseudoekp = 5 * np.hanning(150)[None,:]
-    condition1[:, 100:250] += pseudoekp
-    condition2[:, 100:250] -= pseudoekp
-
-    fs, cluster_times, cluster_p_values, histogram = permutation_1d_cluster_test(
-                                [condition1, condition2], n_permutations=1000)
-
-    # # Plotting for a better understanding
-    # import pylab as pl
-    # pl.close('all')
-    # pl.subplot(211)
-    # pl.plot(condition1.mean(axis=0), label="Condition 1")
-    # pl.plot(condition2.mean(axis=0), label="Condition 2")
-    # pl.ylabel("signal [a.u.]")
-    # pl.subplot(212)
-    # for i_c, c in enumerate(cluster_times):
-    #     c = c[0]
-    #     if cluster_p_values[i_c] <= 0.05:
-    #         h = pl.axvspan(c.start, c.stop-1, color='r', alpha=0.3)
-    #     else:
-    #         pl.axvspan(c.start, c.stop-1, color=(0.3, 0.3, 0.3), alpha=0.3)
-    # hf = pl.plot(fs, 'g')
-    # pl.legend((h, ), ('cluster p-value < 0.05', ))
-    # pl.xlabel("time (ms)")
-    # pl.ylabel("f-values")
-    # pl.show()
-
+permutation_cluster_test.__test__ = False
+
+# if __name__ == "__main__":
+#     noiselevel = 30
+#     np.random.seed(0)
+# 
+#     # 1D
+#     normfactor = np.hanning(20).sum()
+#     condition1 = np.random.randn(50, 300) * noiselevel
+#     for i in range(50):
+#         condition1[i] = np.convolve(condition1[i], np.hanning(20),
+#                                       mode="same") / normfactor
+#     condition2 = np.random.randn(43, 300) * noiselevel
+#     for i in range(43):
+#         condition2[i] = np.convolve(condition2[i], np.hanning(20),
+#                                       mode="same") / normfactor
+#     pseudoekp = 5 * np.hanning(150)[None,:]
+#     condition1[:, 100:250] += pseudoekp
+#     condition2[:, 100:250] -= pseudoekp
+# 
+#     # Make it 2D
+#     condition1 = np.tile(condition1[:,100:275,None], (1, 1, 15))
+#     condition2 = np.tile(condition2[:,100:275,None], (1, 1, 15))
+#     shape1 = condition1[..., :3].shape
+#     shape2 = condition2[..., :3].shape
+#     condition1[..., :3] = np.random.randn(*shape1) * noiselevel
+#     condition2[..., :3] = np.random.randn(*shape2) * noiselevel
+#     condition1[..., -3:] = np.random.randn(*shape1) * noiselevel
+#     condition2[..., -3:] = np.random.randn(*shape2) * noiselevel
+# 
+#     fs, clusters, cluster_p_values, histogram = permutation_cluster_test(
+#                                 [condition1, condition2], n_permutations=1000)
+# 
+#     # # Plotting for a better understanding
+#     # import pylab as pl
+#     # pl.close('all')
+#     #
+#     # if condition1.ndim == 2:
+#     #     pl.subplot(211)
+#     #     pl.plot(condition1.mean(axis=0), label="Condition 1")
+#     #     pl.plot(condition2.mean(axis=0), label="Condition 2")
+#     #     pl.ylabel("signal [a.u.]")
+#     #     pl.subplot(212)
+#     #     for i_c, c in enumerate(clusters):
+#     #         c = c[0]
+#     #         if cluster_p_values[i_c] <= 0.05:
+#     #             h = pl.axvspan(c.start, c.stop-1, color='r', alpha=0.3)
+#     #         else:
+#     #             pl.axvspan(c.start, c.stop-1, color=(0.3, 0.3, 0.3), alpha=0.3)
+#     #     hf = pl.plot(fs, 'g')
+#     #     pl.legend((h, ), ('cluster p-value < 0.05', ))
+#     #     pl.xlabel("time (ms)")
+#     #     pl.ylabel("f-values")
+#     # else:
+#     #     fs_plot = np.nan * np.ones_like(fs)
+#     #     for c, p_val in zip(clusters, cluster_p_values):
+#     #         if p_val <= 0.05:
+#     #             fs_plot[c] = fs[c]
+#     #
+#     #     pl.imshow(fs.T, cmap=pl.cm.gray)
+#     #     pl.imshow(fs_plot.T, cmap=pl.cm.jet)
+#     #     # pl.imshow(fs.T, cmap=pl.cm.gray, alpha=0.6)
+#     #     # pl.imshow(fs_plot.T, cmap=pl.cm.jet, alpha=0.6)
+#     #     pl.xlabel('time')
+#     #     pl.ylabel('Freq')
+#     #     pl.colorbar()
+#     #
+#     # pl.show()
diff --git a/mne/stats/parametric.py b/mne/stats/parametric.py
new file mode 100644
index 0000000..a489a4e
--- /dev/null
+++ b/mne/stats/parametric.py
@@ -0,0 +1,79 @@
+import numpy as np
+from scipy import stats
+
+# The following function is a rewrite of scipy.stats.f_oneway.
+# Unlike scipy.stats.f_oneway it never concatenates the samples into a
+# single copy, and the inputs are left unchanged.
+def _f_oneway(*args):
+    """
+    Performs a 1-way ANOVA.
+
+    The one-way ANOVA tests the null hypothesis that 2 or more groups have
+    the same population mean.  The test is applied to samples from two or
+    more groups, possibly with differing sizes.
+
+    Parameters
+    ----------
+    sample1, sample2, ... : array_like
+        One argument per group; statistics are computed along axis 0.
+
+    Returns
+    -------
+    F-value : float or ndarray
+        The F statistic, one value per entry of the non-sample axes
+    p-value : float or ndarray
+        The associated p-value from the F-distribution
+
+    Notes
+    -----
+    The ANOVA test has important assumptions that must be satisfied in order
+    for the associated p-value to be valid.
+
+    1. The samples are independent
+    2. Each sample is from a normally distributed population
+    3. The population standard deviations of the groups are all equal.  This
+       property is known as homoscedasticity.
+
+    If these assumptions are not true for a given set of data, it may still be
+    possible to use the Kruskal-Wallis H-test (`stats.kruskal`_) although with
+    some loss of power.
+
+    The algorithm is from Heiman[2], pp.394-7.
+
+    See scipy.stats.f_oneway, which should give the same results while
+    copying the data.
+
+    References
+    ----------
+    .. [1] Lowry, Richard.  "Concepts and Applications of Inferential
+           Statistics". Chapter 14. http://faculty.vassar.edu/lowry/ch14pt1.html
+
+    .. [2] Heiman, G.W.  Research Methods in Statistics. 2002.
+
+    """
+    n_classes = len(args)
+    n_samples_per_class = np.array([len(a) for a in args])
+    n_samples = np.sum(n_samples_per_class)
+    ss_alldata = sum(np.sum(np.square(a), axis=0) for a in args)
+    sums_args = [np.sum(a, axis=0) for a in args]
+    square_of_sums_alldata = sum(sums_args) ** 2
+    square_of_sums_args = [s**2 for s in sums_args]
+    sstot = ss_alldata - square_of_sums_alldata / float(n_samples)
+    ssbn = 0
+    for sq_sum, n_k in zip(square_of_sums_args, n_samples_per_class):
+        ssbn += sq_sum / float(n_k)  # float(): avoid py2 integer division
+    ssbn -= square_of_sums_alldata / float(n_samples)
+    sswn = sstot - ssbn
+    dfbn = n_classes - 1
+    dfwn = n_samples - n_classes
+    msb = ssbn / float(dfbn)
+    msw = sswn / float(dfwn)
+    f = msb / msw
+    prob = stats.f.sf(f, dfbn, dfwn)  # upper tail; stats.fprob is deprecated
+    return f, prob
+
+
+def f_oneway(*args):
+    """Return only the F-value(s) of the one-way ANOVA done by _f_oneway"""
+    return _f_oneway(*args)[0]
+
diff --git a/mne/stats/tests/test_cluster_level.py b/mne/stats/tests/test_cluster_level.py
index 9b6b999..3b9e3bf 100644
--- a/mne/stats/tests/test_cluster_level.py
+++ b/mne/stats/tests/test_cluster_level.py
@@ -1,7 +1,7 @@
 import numpy as np
 from numpy.testing import assert_equal
 
-from ..cluster_level import permutation_1d_cluster_test
+from ..cluster_level import permutation_cluster_test
 
 
 def test_permutation_t_test():
@@ -23,17 +23,17 @@ def test_permutation_t_test():
     condition1[:, 100:250] += pseudoekp
     condition2[:, 100:250] -= pseudoekp
 
-    T_obs, clusters, cluster_p_values, hist = permutation_1d_cluster_test(
+    T_obs, clusters, cluster_p_values, hist = permutation_cluster_test(
                                 [condition1, condition2], n_permutations=500,
                                 tail=1)
     assert_equal(np.sum(cluster_p_values < 0.05), 1)
 
-    T_obs, clusters, cluster_p_values, hist = permutation_1d_cluster_test(
+    T_obs, clusters, cluster_p_values, hist = permutation_cluster_test(
                                 [condition1, condition2], n_permutations=500,
                                 tail=0)
     assert_equal(np.sum(cluster_p_values < 0.05), 1)
 
-    T_obs, clusters, cluster_p_values, hist = permutation_1d_cluster_test(
+    T_obs, clusters, cluster_p_values, hist = permutation_cluster_test(
                                 [condition1, condition2], n_permutations=500,
                                 tail=-1)
     assert_equal(np.sum(cluster_p_values < 0.05), 0)
diff --git a/mne/tfr.py b/mne/tfr.py
index 6a61cb4..65d42bd 100644
--- a/mne/tfr.py
+++ b/mne/tfr.py
@@ -189,6 +189,30 @@ def _time_frequency(X, Ws, use_fft):
 
     return psd, plf
 
+def single_trial_power(epochs, Fs, frequencies, use_fft=True, n_cycles=25,
+                       n_jobs=1):
+    """Compute time-frequency power on single epochs
+
+    Parameters
+    ----------
+    epochs : array, shape (n_epochs, n_channels, n_times)
+    Fs, frequencies, n_cycles : passed to morlet to build the wavelets
+    use_fft : if True transform with _cwt_fft, else with _cwt_convolve
+    n_jobs : currently unused
+
+    Returns
+    -------
+    power : array, shape (n_epochs, n_channels, len(frequencies), n_times)
+        Squared magnitude of the _cwt output for each epoch
+    """
+    n_epochs, n_channels, n_times = epochs.shape
+    Ws = morlet(Fs, frequencies, n_cycles=n_cycles)  # computed only once
+    power = np.empty((n_epochs, n_channels, len(frequencies), n_times),
+                     dtype=np.float64)
+    _cwt = _cwt_fft if use_fft else _cwt_convolve
+    for k, e in enumerate(epochs):
+        power[k] = np.abs(_cwt(e, Ws, 'same')) ** 2
+    return power
 
 def time_frequency(epochs, Fs, frequencies, use_fft=True, n_cycles=25,
                    n_jobs=1):

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git



More information about the debian-med-commit mailing list