[med-svn] [python-mne] 307/376: fix preload inconsistency due to bad epochs

Yaroslav Halchenko debian at onerussian.com
Fri Nov 27 17:23:10 UTC 2015


This is an automated email from the git hooks/post-receive script.

yoh pushed a commit to annotated tag v0.1
in repository python-mne.

commit 0bf4b4af08cd2ac6aa67f889639222cd76c7c222
Author: Alexandre Gramfort <alexandre.gramfort at inria.fr>
Date:   Mon Jun 27 11:55:25 2011 -0400

    fix preload inconsistency due to bad epochs
---
 examples/time_frequency/plot_time_frequency.py |  8 ++--
 mne/epochs.py                                  | 55 +++++++++++++-------------
 mne/tests/test_epochs.py                       | 47 +++++++++++-----------
 3 files changed, 57 insertions(+), 53 deletions(-)

diff --git a/examples/time_frequency/plot_time_frequency.py b/examples/time_frequency/plot_time_frequency.py
index 21dcbff..601b16e 100644
--- a/examples/time_frequency/plot_time_frequency.py
+++ b/examples/time_frequency/plot_time_frequency.py
@@ -56,7 +56,7 @@ evoked_data = evoked_data[97:98, :]
 frequencies = np.arange(7, 30, 3)  # define frequencies of interest
 Fs = raw.info['sfreq']  # sampling in Hz
 power, phase_lock = induced_power(data, Fs=Fs, frequencies=frequencies,
-                                   n_cycles=2, n_jobs=1, use_fft=False)
+                                  n_cycles=2, n_jobs=1, use_fft=False)

 power /= np.mean(power[:, :, times < 0], axis=2)[:, :, None]  # baseline ratio

@@ -67,7 +67,7 @@ pl.clf()
 pl.subplots_adjust(0.1, 0.08, 0.96, 0.94, 0.2, 0.63)
 pl.subplot(3, 1, 1)
 pl.plot(times, evoked_data.T)
-pl.title('Evoked response (%s)' % raw.info['ch_names'][picks[0]])
+pl.title('Evoked response (%s)' % evoked.ch_names[97])  # ch. 97 as in evoked_data
 pl.xlabel('time (ms)')
 pl.ylabel('Magnetic Field (fT/cm)')
 pl.xlim(times[0], times[-1])
@@ -79,7 +79,7 @@ pl.imshow(20 * np.log10(power[0]), extent=[times[0], times[-1],
           aspect='auto', origin='lower')
 pl.xlabel('Time (s)')
 pl.ylabel('Frequency (Hz)')
-pl.title('Induced power (%s)' % raw.info['ch_names'][picks[0]])
+pl.title('Induced power (%s)' % evoked.ch_names[97])
 pl.colorbar()

 pl.subplot(3, 1, 3)
@@ -88,6 +88,6 @@ pl.imshow(phase_lock[0], extent=[times[0], times[-1],
           aspect='auto', origin='lower')
 pl.xlabel('Time (s)')
 pl.ylabel('Frequency (Hz)')
-pl.title('Phase-lock (%s)' % raw.info['ch_names'][picks[0]])
+pl.title('Phase-lock (%s)' % evoked.ch_names[97])
 pl.colorbar()
 pl.show()
diff --git a/mne/epochs.py b/mne/epochs.py
index 924deca..e62a7ec 100644
--- a/mne/epochs.py
+++ b/mne/epochs.py
@@ -72,9 +72,6 @@ class Epochs(object):

     Methods
     -------
-    get_epoch(i) : self
-        Return the ith epoch as a 2D array [n_channels x n_times].
-
     get_data() : self
         Return all epochs as a 3D array [n_epochs x n_channels x n_times].

@@ -194,19 +191,6 @@ class Epochs(object):
         if self.preload:
             self._data = self._data[:, idx, :]

-    def get_epoch(self, idx):
-        """Load one epoch
-
-        Returns
-        -------
-        data : array of shape [n_channels, n_times]
-            One epoch data
-        """
-        if self.preload:
-            return self._data[idx]
-        else:
-            return self._get_epoch_from_disk(idx)
-
     def _get_epoch_from_disk(self, idx):
         """Load one epoch from disk"""
         sfreq = self.raw.info['sfreq']
@@ -238,16 +222,26 @@ class Epochs(object):
         n_reject = 0
         for k in range(n_events):
             e = self._get_epoch_from_disk(k)
-            if ((self.reject is not None or self.flat is not None) and
-                not _is_good(e, self.ch_names, self._channel_type_idx,
-                         self.reject, self.flat)) or e.shape[1] < n_times:
-                n_reject += 1
-            else:
+            if self._is_good_epoch(e):
                 data[cnt] = self._get_epoch_from_disk(k)
                 cnt += 1
+            else:
+                n_reject += 1
         print "Rejecting %d epochs." % n_reject
         return data[:cnt]

+    def _is_good_epoch(self, data):
+        """Return True if epoch is complete and passes reject/flat
+        """
+        n_times = len(self.times)
+        if data.shape[1] < n_times:
+            return False  # too short (i.e. at end of data), always rejected
+        elif self.reject is None and self.flat is None:
+            return True
+        else:
+            return _is_good(data, self.ch_names, self._channel_type_idx,
+                            self.reject, self.flat)
+
     def get_data(self):
         """Get all epochs as a 3D array

@@ -285,14 +279,21 @@ class Epochs(object):
         return self

     def next(self):
-        """To iteration over epochs easy.
+        """Return the next epoch, skipping bad ones when reading from disk.
         """
-        if self._current >= len(self.events):
-            raise StopIteration
-
-        epoch = self.get_epoch(self._current)
+        if self.preload:
+            if self._current >= len(self._data):
+                raise StopIteration
+            epoch = self._data[self._current]
+            self._current += 1
+        else:
+            epoch = None
+            while epoch is None or not self._is_good_epoch(epoch):
+                if self._current >= len(self.events):
+                    raise StopIteration
+                epoch = self._get_epoch_from_disk(self._current)
+                self._current += 1

-        self._current += 1
         return epoch

     def __repr__(self):
diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py
index 325c65c..eda6db6 100644
--- a/mne/tests/test_epochs.py
+++ b/mne/tests/test_epochs.py
@@ -3,6 +3,7 @@
 # License: BSD (3-clause)

 import os.path as op
+from numpy.testing import assert_array_equal

 import mne
 from mne import fiff
@@ -12,19 +13,18 @@ raw_fname = op.join(op.dirname(__file__), '..', 'fiff', 'tests', 'data',
 event_name = op.join(op.dirname(__file__), '..', 'fiff', 'tests', 'data',
                      'test-eve.fif')

+event_id, tmin, tmax = 1, -0.2, 0.5
+raw = fiff.Raw(raw_fname)
+events = mne.read_events(event_name)
+picks = fiff.pick_types(raw.info, meg=True, eeg=True, stim=False,
+                        eog=True, include=['STI 014'])
+
+reject = dict(grad=1000e-12, mag=4e-12, eeg=80e-6, eog=150e-6)
+flat = dict(grad=1e-15, mag=1e-15)

 def test_read_epochs():
     """Reading epochs from raw files
     """
-    event_id = 1
-    tmin = -0.2
-    tmax = 0.5
-
-    # Setup for reading the raw data
-    raw = fiff.Raw(raw_fname)
-    events = mne.read_events(event_name)
-    picks = fiff.pick_types(raw.info, meg=True, eeg=False, stim=False,
-                            eog=True, include=['STI 014'])
     epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                         baseline=(None, 0))
     epochs.average()
@@ -40,21 +40,9 @@ def test_read_epochs():
 def test_reject_epochs():
     """Test of epochs rejection
     """
-    event_id = 1
-    tmin = -0.2
-    tmax = 0.5
-
-    # Setup for reading the raw data
-    raw = fiff.Raw(raw_fname)
-    events = mne.read_events(event_name)
-
-    picks = fiff.pick_types(raw.info, meg=True, eeg=True, stim=True,
-                            eog=True, include=['STI 014'])
     epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                         baseline=(None, 0),
-                        reject=dict(grad=1000e-12, mag=4e-12, eeg=80e-6,
-                                    eog=150e-6),
-                        flat=dict(grad=1e-15, mag=1e-15))
+                        reject=reject, flat=flat)
     data = epochs.get_data()
     n_events = len(epochs.events)
     n_clean_epochs = len(data)
@@ -63,3 +51,18 @@ def test_reject_epochs():
     #   --saveavetag -ave --ave test.ave --filteroff
     assert n_events > n_clean_epochs
     assert n_clean_epochs == 3
+
+
+def test_preload_epochs():
+    """Preloaded and on-demand epochs must give identical data
+    """
+    epochs = mne.Epochs(raw, events[:12], event_id, tmin, tmax, picks=picks,
+                        baseline=(None, 0), preload=True,
+                        reject=reject, flat=flat)
+    data_preload = epochs.get_data()
+
+    epochs = mne.Epochs(raw, events[:12], event_id, tmin, tmax, picks=picks,
+                        baseline=(None, 0), preload=False,
+                        reject=reject, flat=flat)
+    data_no_preload = epochs.get_data()
+    assert_array_equal(data_preload, data_no_preload)

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git



More information about the debian-med-commit mailing list