[med-svn] [python-mne] 292/353: added bootstrap and crop function to epochs

Yaroslav Halchenko debian at onerussian.com
Fri Nov 27 17:25:18 UTC 2015


This is an automated email from the git hooks/post-receive script.

yoh pushed a commit to tag 0.4
in repository python-mne.

commit bf7fa3447926b9962bb3cc00b5aec54be7d6d7d7
Author: Daniel Strohmeier <daniel.strohmeier at googlemail.com>
Date:   Wed Jul 18 15:44:30 2012 +0200

    added bootstrap and crop function to epochs
---
 mne/epochs.py            | 70 +++++++++++++++++++++++++++++++++++++++++++++---
 mne/tests/test_epochs.py | 53 ++++++++++++++++++++++++++++++++++++
 2 files changed, 119 insertions(+), 4 deletions(-)

diff --git a/mne/epochs.py b/mne/epochs.py
index a49f60d..ceb9618 100644
--- a/mne/epochs.py
+++ b/mne/epochs.py
@@ -384,10 +384,12 @@ class Epochs(object):
             if isinstance(key, slice):
                 epochs._data = self._data[key]
             else:
-                #make sure data remains a 3D array
-                #Note: np.atleast_3d() doesn't do what we want
-                epochs._data = np.array([self._data[key]])
-
+                if isinstance(key, list):
+                    key = np.array(key)
+                if np.ndim(key) == 0:
+                    epochs._data = self._data[key][np.newaxis, :, :]
+                else:
+                    epochs._data = self._data[key]
         return epochs
 
     def average(self, keep_only_data_channels=True):
@@ -441,6 +443,39 @@ class Epochs(object):
             evoked.info['nchan'] = len(data_picks)
             evoked.data = evoked.data[data_picks]
         return evoked
+    
+    def crop(self, tmin, tmax):
+        """Crops a time interval from epochs object.
+    
+        Parameters
+        ----------
+        tmin : float
+            Start time of selection in seconds
+        tmax : float
+            End time of selection in seconds
+    
+        Returns
+        -------
+        epochs : Epochs instance
+            The cropped epochs (this instance, modified in place)
+        """
+        if not self.preload:
+            raise RuntimeError('Modifying data of epochs is only supported '
+                                'when preloading is used. Use preload=True '
+                                'in the constructor.')
+        if tmin < self.tmin:
+            tmin = self.tmin
+        if tmax > self.tmax:
+            tmax = self.tmax
+            
+        sfreq = self.info['sfreq']
+        first_samp = int((tmin - self.tmin) * sfreq)
+        last_samp = int((tmax - self.tmax) * sfreq) - 1
+        
+        self.tmin = tmin
+        self.tmax = tmax
+        self._data = self._data[:, :, first_samp:last_samp]
+        return self
 
 
 def _is_good(e, ch_names, channel_type_idx, reject, flat):
@@ -477,3 +512,30 @@ def _is_good(e, ch_names, channel_type_idx, reject, flat):
                     return False
 
     return True
+
+
+def bootstrap(epochs, rng):
+    """Draw a bootstrap sample (with replacement) from the epochs
+
+    Parameters
+    ----------
+    epochs : Epochs instance
+        epochs data to be bootstrapped
+    rng : object with a randint method (e.g. numpy.random.RandomState)
+        Random number generator used to draw the epoch indices.
+
+    Returns
+    -------
+    epochs, idx : Epochs instance, array of int
+        The bootstrap samples and the epoch indices that were drawn
+    """
+    if not epochs.preload:
+        raise RuntimeError('Modifying data of epochs is only supported '
+                            'when preloading is used. Use preload=True '
+                            'in the constructor.')
+
+    epochs_bootstrap = copy.deepcopy(epochs)
+    n_events = len(epochs_bootstrap.events)
+    idx = rng.randint(0, n_events, n_events)
+    epochs_bootstrap = epochs_bootstrap[idx]
+    return epochs_bootstrap, idx
diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py
index 29bfc1c..d2eb23c 100644
--- a/mne/tests/test_epochs.py
+++ b/mne/tests/test_epochs.py
@@ -8,6 +8,7 @@ from numpy.testing import assert_array_equal, assert_array_almost_equal
 import numpy as np
 
 from .. import fiff, Epochs, read_events, pick_events
+from ..epochs import bootstrap
 
 raw_fname = op.join(op.dirname(__file__), '..', 'fiff', 'tests', 'data',
                      'test_raw.fif')
@@ -135,6 +136,23 @@ def test_indexing_slicing():
             data = epochs2_sliced[pos].get_data()
             assert_array_equal(data[0], data_normal[idx])
             pos += 1
+            
+        # using indexing with int
+        idx = np.random.randint(0, data_epochs2_sliced.shape[0], 1)
+        data = epochs2[idx].get_data()
+        assert_array_equal(data, data_normal[idx])
+        
+        # using indexing with array
+        idx = np.random.randint(0, data_epochs2_sliced.shape[0], 10)
+        data = epochs2[idx].get_data()
+        assert_array_equal(data, data_normal[idx])
+        
+        # using indexing with list of indices
+        #idx = list()
+        #for k in range(3):
+        #    idx.append(np.random.randint(0, data_epochs2_sliced.shape[0], 1))
+        #    data = epochs2[idx].get_data()
+        #    assert_array_equal(data, data_normal[idx])
 
 
 def test_comparision_with_c():
@@ -152,3 +170,38 @@ def test_comparision_with_c():
     assert_true(evoked.nave == c_evoked.nave)
     assert_array_almost_equal(evoked_data, c_evoked_data, 10)
     assert_array_almost_equal(evoked.times, c_evoked.times, 12)
+
+
+def test_crop():
+    """Test of crop of epochs
+    """
+    epochs = Epochs(raw, events[:20], event_id, tmin, tmax, picks=picks,
+                    baseline=(None, 0), preload=False,
+                    reject=reject, flat=flat)
+    epochs2 = Epochs(raw, events[:20], event_id, tmin, tmax,
+                    picks=picks, baseline=(None, 0), preload=True,
+                    reject=reject, flat=flat)
+    data_normal = epochs.get_data()
+
+    # indices for slicing
+    start_tsamp = tmin + 60 * epochs.info['sfreq']
+    end_tsamp = tmax - 60 * epochs.info['sfreq']
+    tmask = (epochs.times >= start_tsamp) & (epochs.times <= end_tsamp)
+    assert((start_tsamp) > tmin)
+    assert((end_tsamp) < tmax)
+    epochs2.crop(start_tsamp, end_tsamp)
+    data = epochs2.get_data()
+    assert_array_equal(data, data_normal[:, :, tmask])
+    
+
+def test_bootstrap():
+    """Test of bootstrapping of epochs
+    """
+    epochs = Epochs(raw, events[:20], event_id, tmin, tmax, picks=picks,
+                    baseline=(None, 0), preload=True,
+                    reject=reject, flat=flat)
+    data_normal = epochs._data
+    rng = np.random.RandomState(0)
+    epochs2, idx = bootstrap(epochs, rng)
+    n_events = len(epochs.events)
+    assert_array_equal(epochs2._data, data_normal[idx])

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git



More information about the debian-med-commit mailing list