[med-svn] [python-mne] 07/52: added indexing and slicing operations for epoch

Yaroslav Halchenko debian at onerussian.com
Fri Nov 27 17:23:44 UTC 2015


This is an automated email from the git hooks/post-receive script.

yoh pushed a commit to annotated tag v0.2
in repository python-mne.

commit 866040876254090dae33d2abd2d8545f8ba7c335
Author: Martin Luessi <mluessi at nmr.mgh.harvard.edu>
Date:   Tue Sep 27 17:26:02 2011 -0400

    added indexing and slicing operations for epoch
---
 mne/epochs.py            | 85 +++++++++++++++++++++++++++++++++++++++++++++---
 mne/tests/test_epochs.py | 40 +++++++++++++++++++++++
 2 files changed, 120 insertions(+), 5 deletions(-)

diff --git a/mne/epochs.py b/mne/epochs.py
index c1a52da..a3832cc 100644
--- a/mne/epochs.py
+++ b/mne/epochs.py
@@ -79,6 +79,17 @@ class Epochs(object):
         Return Evoked object containing averaged epochs as a
         2D array [n_channels x n_times].
 
+    drop_bad_epochs() : None
+        Drop all epochs marked as bad. Should be used before indexing and
+        slicing operations.
+
+    Indexing and Slicing
+    --------------------
+    epochs = Epochs(...)
+
+    epochs[idx] : Return epoch with index idx (2D array, [n_channels, n_times])
+
+    epochs[start:stop] : Return Epochs object with a subset of epochs
     """
 
     def __init__(self, raw, events, event_id, tmin, tmax, baseline=(None, 0),
@@ -96,6 +107,7 @@ class Epochs(object):
         self.preload = preload
         self.reject = reject
         self.flat = flat
+        self.bad_dropped = False
 
         # Handle measurement info
         self.info = copy.deepcopy(raw.info)
@@ -183,7 +195,9 @@ class Epochs(object):
         self._reject_setup()
 
         if self.preload:
-            self._data = self._get_data_from_disk()
+            self._data, good_events = self._get_data_from_disk()
+            self.events = self.events[good_events,:]
+            self.bad_dropped = True
 
     def drop_picks(self, bad_picks):
         """Drop some picks
@@ -206,6 +220,28 @@ class Epochs(object):
         if self.preload:
             self._data = self._data[:, idx, :]
 
+    def drop_bad_epochs(self):
+        """Drop bad epochs.
+
+        Keeps only events whose epoch passes _is_good_epoch (reject/flat).
+        Should be used before indexing and slicing operations.
+        Warning: slow, since every epoch is read from disk one at a time.
+        """
+        if self.bad_dropped:  # already dropped (also set when preload=True)
+            return
+
+        good = []
+        n_events = len(self.events)
+        for idx in range(n_events):
+            epoch = self._get_epoch_from_disk(idx)
+            if self._is_good_epoch(epoch):
+                good.append(idx)
+
+        self.events = self.events[good,:]  # keep only the good events
+        self.bad_dropped = True
+
+        print "%d bad epochs dropped" % (n_events - len(good))
+
     def _get_epoch_from_disk(self, idx):
         """Load one epoch from disk"""
         sfreq = self.raw.info['sfreq']
@@ -235,18 +271,20 @@ class Epochs(object):
         data = np.empty((n_events, n_channels, n_times))
         cnt = 0
         n_reject = 0
+        event_idx = []
         for k in range(n_events):
             e = self._get_epoch_from_disk(k)
             if self._is_good_epoch(e):
                 data[cnt] = self._get_epoch_from_disk(k)
+                event_idx.append(k)
                 cnt += 1
             else:
                 n_reject += 1
         print "Rejecting %d epochs." % n_reject
-        return data[:cnt]
+        return data[:cnt], event_idx
 
     def _is_good_epoch(self, data):
-        """Determine is epoch is good
+        """Determine if epoch is good
         """
         n_times = len(self.times)
         if self.reject is None and self.flat is None:
@@ -268,7 +306,8 @@ class Epochs(object):
         if self.preload:
             return self._data
         else:
-            return self._get_data_from_disk()
+            data, _ = self._get_data_from_disk()
+            return data
 
     def _reject_setup(self):
         """Setup reject process
@@ -312,12 +351,48 @@ class Epochs(object):
         return epoch
 
     def __repr__(self):
-        s = "n_events : %s" % len(self.events)
+        if not self.bad_dropped:  # events may still include bad epochs
+            s = "n_events : %s (good & bad)" % len(self.events)
+        else:
+            s = "n_events : %s (all good)" % len(self.events)
         s += ", tmin : %s (s)" % self.tmin
         s += ", tmax : %s (s)" % self.tmax
         s += ", baseline : %s" % str(self.baseline)
         return "Epochs (%s)" % s
 
+    def __getslice__(self, start, end):
+        """Return an Epochs object with the subset of epochs [start:end].
+        """
+        if not self.bad_dropped:
+            print "Warning: bad epochs have not been dropped, indexing will " \
+                  "be inaccurate. Use drop_bad_epochs() or preload=True"
+        # shallow copy: attributes other than events/_data stay shared
+        epoch_slice = copy.copy(self)
+        epoch_slice.events = self.events[start:end]
+
+        if self.preload:
+            epoch_slice._data = self._data[start:end]
+
+        return epoch_slice
+
+    def __getitem__(self, index):
+        """Return epoch at index as a 2D array [n_channels x n_times].
+        """
+        if index < 0 or index >= len(self.events):
+            raise IndexError("Epoch index out of bounds")
+
+        if self.preload:
+            epoch = self._data[index]
+        else:
+            epoch = self._get_epoch_from_disk(index)
+            # without preload/drop_bad_epochs() events may include bad ones
+            if not self._is_good_epoch(epoch):
+                print "Warning: Bad epoch with index %d returned. Use " \
+                      "drop_bad_epochs() or preload=True to prevent this." \
+                      % (index)
+
+        return epoch
+
     def average(self):
         """Compute average of epochs
 
diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py
index 5d97441..94abf4f 100644
--- a/mne/tests/test_epochs.py
+++ b/mne/tests/test_epochs.py
@@ -70,6 +70,46 @@ def test_preload_epochs():
     data_no_preload = epochs.get_data()
     assert_array_equal(data_preload, data_no_preload)
 
+def test_indexing_slicing():
+    """Test of indexing and slicing operations
+    """
+    epochs = Epochs(raw, events[:20], event_id, tmin, tmax, picks=picks,
+                    baseline=(None, 0), preload=False,
+                    reject=reject, flat=flat)
+
+    data_normal = epochs.get_data()
+
+    n_good_events = data_normal.shape[0]
+
+    # indices for slicing
+    start_index = 1
+    end_index = n_good_events - 1
+
+    assert end_index - start_index > 0
+
+    for preload in [True, False]:
+        epochs2 = Epochs(raw, events[:20], event_id, tmin, tmax,
+                         picks=picks, baseline=(None, 0), preload=preload,
+                         reject=reject, flat=flat)
+
+        if not preload:
+            epochs2.drop_bad_epochs()
+
+        # get slice
+        epochs2_sliced = epochs2[start_index:end_index]
+        # the slice holds exactly end_index - start_index events
+        assert len(epochs2_sliced.events) == end_index - start_index
+
+        # using get_data()
+        data_epochs2_sliced = epochs2_sliced.get_data()
+        assert_array_equal(data_epochs2_sliced,
+                           data_normal[start_index:end_index])
+
+        # using indexing
+        for pos, idx in enumerate(range(start_index, end_index)):
+            assert_array_equal(epochs2_sliced[pos], data_normal[idx])
+
+
 
 def test_comparision_with_c():
     """Test of average obtained vs C code

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git



More information about the debian-med-commit mailing list