[med-svn] [python-mne] 12/52: Multiple fixes: - indexing and slicing now always return an Epochs object - fixed bugs that occur when Epochs only has one event (len(self.events) is 3 for single event) - still using shallow copy to avoid copying raw

Yaroslav Halchenko debian at onerussian.com
Fri Nov 27 17:23:44 UTC 2015


This is an automated email from the git hooks/post-receive script.

yoh pushed a commit to annotated tag v0.2
in repository python-mne.

commit 6799065d5f906de43c2ffa2bf297a26b7aa419d5
Author: Martin Luessi <mluessi at nmr.mgh.harvard.edu>
Date:   Wed Sep 28 15:26:21 2011 -0400

     Multiple fixes:
     - indexing and slicing now always return an Epochs object
     - fixed bugs that occur when Epochs only has one event (len(self.events) is 3 for single event)
     - still using shallow copy to avoid copying raw
---
 mne/epochs.py            | 82 +++++++++++++++++++++++++-----------------------
 mne/tests/test_epochs.py |  6 ++--
 2 files changed, 45 insertions(+), 43 deletions(-)

diff --git a/mne/epochs.py b/mne/epochs.py
index f4c44e4..f61f393 100644
--- a/mne/epochs.py
+++ b/mne/epochs.py
@@ -88,9 +88,9 @@ class Epochs(object):
     -------
     epochs = Epochs(...)
 
-    epochs[idx] : Return epoch with index idx (2D array, [n_channels, n_times])
-
-    epochs[start:stop] : Return Epochs object with a subset of epochs
+    epochs[idx] : Epochs
+        Return Epochs object with a subset of epochs (supports single
+        index and python style slicing)
     """
 
     def __init__(self, raw, events, event_id, tmin, tmax, baseline=(None, 0),
@@ -177,7 +177,7 @@ class Epochs(object):
         #    Select the desired events
         selected = np.logical_and(events[:, 1] == 0, events[:, 2] == event_id)
         self.events = events[selected]
-        n_events = len(self.events)
+        n_events = len(self)
 
         if n_events > 0:
             print '%d matching events found' % n_events
@@ -232,7 +232,7 @@ class Epochs(object):
             return
 
         good_events = []
-        n_events = len(self.events)
+        n_events = len(self)
         for idx in range(n_events):
             epoch = self._get_epoch_from_disk(idx)
             if self._is_good_epoch(epoch):
@@ -246,7 +246,12 @@ class Epochs(object):
     def _get_epoch_from_disk(self, idx):
         """Load one epoch from disk"""
         sfreq = self.raw.info['sfreq']
-        event_samp = self.events[idx, 0]
+
+        if self.events.ndim == 1:
+            #single event
+            event_samp = self.events[0]
+        else:
+            event_samp = self.events[idx, 0]
 
         # Read a data segment
         first_samp = self.raw.first_samp
@@ -268,15 +273,15 @@ class Epochs(object):
         """
         n_channels = len(self.ch_names)
         n_times = len(self.times)
-        n_events = len(self.events)
+        n_events = len(self)
         data = np.empty((n_events, n_channels, n_times))
         cnt = 0
         n_reject = 0
         event_idx = []
         for k in range(n_events):
-            e = self._get_epoch_from_disk(k)
-            if self._is_good_epoch(e):
-                data[cnt] = self._get_epoch_from_disk(k)
+            epoch = self._get_epoch_from_disk(k)
+            if self._is_good_epoch(epoch):
+                data[cnt] = epoch
                 event_idx.append(k)
                 cnt += 1
             else:
@@ -342,7 +347,7 @@ class Epochs(object):
             epoch = self._data[self._current]
             self._current += 1
         else:
-            if self._current >= len(self.events):
+            if self._current >= len(self):
                 raise StopIteration
             epoch = self._get_epoch_from_disk(self._current)
             self._current += 1
@@ -353,9 +358,9 @@ class Epochs(object):
 
     def __repr__(self):
         if not self.bad_dropped:
-            s = "n_events : %s (good & bad)" % len(self.events)
+            s = "n_events : %s (good & bad)" % len(self)
         else:
-            s = "n_events : %s (all good)" % len(self.events)
+            s = "n_events : %s (all good)" % len(self)
         s += ", tmin : %s (s)" % self.tmin
         s += ", tmax : %s (s)" % self.tmax
         s += ", baseline : %s" % str(self.baseline)
@@ -364,38 +369,35 @@ class Epochs(object):
     def __len__(self):
         """Return length (number of events)
         """
-        return len(self.events)
+        if self.events.ndim == 1:
+            return 1
+        else:
+            return len(self.events)
 
-    def __getitem__(self, index):
-        """Return epoch at index or an Epochs object with a slice of epochs
+    def __getitem__(self, key):
+        """Return an Epochs object with a subset of epochs
         """
-        if isinstance(index, slice):
-            # return Epochs object with slice of epochs
-            if not self.bad_dropped:
-                    warnings.warn("Bad epochs have not been dropped, indexing "
-                                  "will be inccurate. Use drop_bad_epochs() "
-                                  "or preload=True")
-
-            epoch_slice = copy.copy(self)
-            epoch_slice.events = self.events[index]
-
-            if self.preload:
-                epoch_slice._data = self._data[index]
+        print key
+        if not self.bad_dropped:
+                warnings.warn("Bad epochs have not been dropped, indexing "
+                              "will be inccurate. Use drop_bad_epochs() "
+                              "or preload=True")
 
-            return epoch_slice
+        epochs = copy.copy(self)
+        epochs.events = self.events[key]
 
-        # return single epoch as 2D array
         if self.preload:
-            epoch = epoch = self._data[index]
-        else:
-            epoch = self._get_epoch_from_disk(index)
-
-            if not self._is_good_epoch(epoch):
-                warnings.warn("Bad epoch with index %d returned. "
-                              "Use drop_bad_epochs() or preload=True "
-                              "to prevent this." % (index))
+            if isinstance(key, slice):
+                epochs._data = self._data[key]
+            else:
+                #make sure data remains a 3D array
+                n_channels = len(self.ch_names)
+                n_times = len(self.times)
+                data = np.empty((1, n_channels, n_times))
+                data[0, :, :] = self._data[key]
+                epochs._data = data
 
-        return epoch
+        return epochs
 
     def average(self):
         """Compute average of epochs
@@ -409,7 +411,7 @@ class Epochs(object):
         evoked.info = copy.deepcopy(self.info)
         n_channels = len(self.ch_names)
         n_times = len(self.times)
-        n_events = len(self.events)
+        n_events = len(self)
         if self.preload:
             data = np.mean(self._data, axis=0)
         else:
diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py
index 545d91a..120cf87 100644
--- a/mne/tests/test_epochs.py
+++ b/mne/tests/test_epochs.py
@@ -96,10 +96,9 @@ def test_indexing_slicing():
         if not preload:
             epochs2.drop_bad_epochs()
 
-        # get slice
+        # using slicing
         epochs2_sliced = epochs2[start_index:end_index]
 
-        # using get_data()
         data_epochs2_sliced = epochs2_sliced.get_data()
         assert_array_equal(data_epochs2_sliced, \
                            data_normal[start_index:end_index])
@@ -107,7 +106,8 @@ def test_indexing_slicing():
         # using indexing
         pos = 0
         for idx in range(start_index, end_index):
-            assert_array_equal(epochs2_sliced[pos], data_normal[idx])
+            data = epochs2_sliced[pos].get_data()
+            assert_array_equal(data[0], data_normal[idx])
             pos += 1
 
 

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git



More information about the debian-med-commit mailing list