[med-svn] [python-mne] 106/376: ENH : exposing epochs as an iterable object with lazy loading from disk

Yaroslav Halchenko debian at onerussian.com
Fri Nov 27 17:22:17 UTC 2015


This is an automated email from the git hooks/post-receive script.

yoh pushed a commit to annotated tag v0.1
in repository python-mne.

commit 948cfadfd3472daef96d8a5c8aa3d5fe901bb283
Author: Alexandre Gramfort <alexandre.gramfort at inria.fr>
Date:   Wed Mar 2 22:28:23 2011 -0500

    ENH : exposing epochs as an iterable object with lazy loading from disk
---
 examples/plot_read_epochs.py                   |  30 ++--
 examples/time_frequency/plot_time_frequency.py |  17 +-
 mne/__init__.py                                |   2 +-
 mne/epochs.py                                  | 214 +++++++++++++++----------
 mne/tests/test_epochs.py                       |   5 +-
 mne/tests/test_tfr.py                          |  11 +-
 mne/tfr.py                                     |  10 +-
 7 files changed, 166 insertions(+), 123 deletions(-)

diff --git a/examples/plot_read_epochs.py b/examples/plot_read_epochs.py
index 48bfe93..66f7dea 100644
--- a/examples/plot_read_epochs.py
+++ b/examples/plot_read_epochs.py
@@ -15,9 +15,6 @@ for both MEG and EEG data by averaging all the epochs.
 
 print __doc__
 
-import os
-import numpy as np
-
 import mne
 from mne import fiff
 from mne.datasets import sample
@@ -39,32 +36,31 @@ events = mne.read_events(event_fname)
 include = [] # or stim channels ['STI 014']
 exclude = raw['info']['bads'] + ['MEG 2443', 'EEG 053'] # bads + 2 more
 
+# EEG
+eeg_picks = fiff.pick_types(raw['info'], meg=False, eeg=True, stim=False,
+                                            include=include, exclude=exclude)
+eeg_epochs = mne.Epochs(raw, events, event_id,
+                            tmin, tmax, picks=eeg_picks, baseline=(None, 0))
+eeg_evoked_data = eeg_epochs.get_data().mean(axis=0) # as 3D matrix and average
+
+
 # MEG Magnetometers
 meg_mag_picks = fiff.pick_types(raw['info'], meg='mag', eeg=False, stim=False,
                                             include=include, exclude=exclude)
-meg_mag_data, times, channel_names = mne.read_epochs(raw, events, event_id,
+meg_mag_epochs = mne.Epochs(raw, events, event_id,
                            tmin, tmax, picks=meg_mag_picks, baseline=(None, 0))
-meg_mag_epochs = np.array([d['epoch'] for d in meg_mag_data]) # as 3D matrix
-meg_mag_evoked_data = np.mean(meg_mag_epochs, axis=0) # compute evoked fields
+meg_mag_evoked_data = meg_mag_epochs.get_data().mean(axis=0)
 
 # MEG
 meg_grad_picks = fiff.pick_types(raw['info'], meg='grad', eeg=False,
                                 stim=False, include=include, exclude=exclude)
-meg_grad_data, times, channel_names = mne.read_epochs(raw, events, event_id,
+meg_grad_epochs = mne.Epochs(raw, events, event_id,
                         tmin, tmax, picks=meg_grad_picks, baseline=(None, 0))
-meg_grad_epochs = np.array([d['epoch'] for d in meg_grad_data]) # as 3D matrix
-meg_grad_evoked_data = np.mean(meg_grad_epochs, axis=0) # compute evoked fields
-
-# EEG
-eeg_picks = fiff.pick_types(raw['info'], meg=False, eeg=True, stim=False,
-                                            include=include, exclude=exclude)
-eeg_data, times, channel_names = mne.read_epochs(raw, events, event_id,
-                            tmin, tmax, picks=eeg_picks, baseline=(None, 0))
-eeg_epochs = np.array([d['epoch'] for d in eeg_data]) # as 3D matrix
-eeg_evoked_data = np.mean(eeg_epochs, axis=0) # compute evoked potentials
+meg_grad_evoked_data = meg_grad_epochs.get_data().mean(axis=0)
 
 ###############################################################################
 # View evoked response
+times = eeg_epochs.times
 import pylab as pl
 pl.clf()
 pl.subplot(3, 1, 1)
diff --git a/examples/time_frequency/plot_time_frequency.py b/examples/time_frequency/plot_time_frequency.py
index 9c4558b..0ba40a8 100644
--- a/examples/time_frequency/plot_time_frequency.py
+++ b/examples/time_frequency/plot_time_frequency.py
@@ -42,14 +42,17 @@ picks = fiff.pick_types(raw['info'], meg='grad', eeg=False,
                                 stim=False, include=include, exclude=exclude)
 
 picks = [picks[97]]
-data, times, channel_names = mne.read_epochs(raw, events, event_id,
-                                tmin, tmax, picks=picks, baseline=(None, 0))
-epochs = np.array([d['epoch'] for d in data]) # as 3D matrix
-evoked_data = np.mean(epochs, axis=0) # compute evoked fields
+epochs = mne.Epochs(raw, events, event_id,
+                    tmin, tmax, picks=picks, baseline=(None, 0))
+data = epochs.get_data() # as 3D matrix
+evoked_data = np.mean(data, axis=0) # compute evoked fields
+
+times = 1e3 * epochs.times # change unit to ms
+evoked_data *= 1e13 # change unit to fT / cm
 
 frequencies = np.arange(7, 30, 3) # define frequencies of interest
 Fs = raw['info']['sfreq'] # sampling in Hz
-power, phase_lock = time_frequency(epochs, Fs=Fs, frequencies=frequencies,
+power, phase_lock = time_frequency(data, Fs=Fs, frequencies=frequencies,
                                    n_cycles=2, n_jobs=1, use_fft=False)
 
 ###############################################################################
@@ -58,11 +61,11 @@ import pylab as pl
 pl.clf()
 pl.subplots_adjust(0.1, 0.08, 0.96, 0.94, 0.2, 0.63)
 pl.subplot(3, 1, 1)
-pl.plot(1e3 * times, 1e13 * evoked_data.T)
+pl.plot(times, evoked_data.T)
 pl.title('Evoked response (%s)' % raw['info']['ch_names'][picks[0]])
 pl.xlabel('time (ms)')
 pl.ylabel('Magnetic Field (fT/cm)')
-pl.xlim(1e3 * times[0], 1e3 * times[-1])
+pl.xlim(times[0], times[-1])
 pl.ylim(-150, 300)
 
 pl.subplot(3, 1, 2)
diff --git a/mne/__init__.py b/mne/__init__.py
index 9f2fef4..64e7010 100644
--- a/mne/__init__.py
+++ b/mne/__init__.py
@@ -6,7 +6,7 @@ from .forward import read_forward_solution
 from .stc import read_stc, write_stc
 from .bem_surfaces import read_bem_surfaces
 from .inverse import read_inverse_operator, compute_inverse
-from .epochs import read_epochs
+from .epochs import Epochs
 from .tfr import time_frequency
 from .label import label_time_courses, read_label
 import fiff
diff --git a/mne/epochs.py b/mne/epochs.py
index bf7c931..9e67db2 100644
--- a/mne/epochs.py
+++ b/mne/epochs.py
@@ -7,9 +7,8 @@ import numpy as np
 import fiff
 
 
-def read_epochs(raw, events, event_id, tmin, tmax, picks=None,
-                 keep_comp=False, dest_comp=0, baseline=None):
-    """Read epochs from a raw dataset
+class Epochs(object):
+    """List of Epochs
 
     Parameters
     ----------
@@ -40,91 +39,110 @@ def read_epochs(raw, events, event_id, tmin, tmax, picks=None,
         If baseline is equal ot (None, None) all the time
         interval is used.
 
-    Returns
-    -------
-    data : list of epochs
-        An epoch is a dict with key:
-            epoch    the epoch, channel by channel
-            event    event #
-            tmin     starting time in the raw data file (initial skip omitted)
-            tmax     ending stime in the raw data file (initial skip omitted)
-
-    times : array
-        The time points of the samples, in seconds
-
-    ch_names : list of strings
-        Names of the channels included
+    preload : boolean
+        Load all epochs from disk when creating the object
+        or wait before accessing each epoch (more memory
+        efficient but can be slower).
 
-    Notes
-    -----
-    NOTE 1: The purpose of this function is to demonstrate the raw data reading
-    routines. You may need to modify this for your purposes
-
-    NOTE 2: You need to run mne_process_raw once as
+    Methods
+    -------
+    get_epoch(i) : self
+        Return the ith epoch as a 2D array [n_channels x n_times].
 
-    mne_process_raw --raw <fname> --projoff
+    get_data() : self
+        Return all epochs as a 3D array [n_epochs x n_channels x n_times].
 
-    to create the fif-format event file (or open the file in mne_browse_raw).
     """
 
-    ch_names = [raw['info']['ch_names'][k] for k in picks]
-    sfreq = raw['info']['sfreq']
-
-    #   Set up projection
-    if raw['info']['projs'] is None:
-        print 'No projector specified for these data'
-        raw['proj'] = []
-    else:
-        #   Activate the projection items
-        for proj in raw['info']['projs']:
-            proj['active'] = True
-
-        print '%d projection items activated' % len(raw['info']['projs'])
-
-        #   Create the projector
-        proj, nproj = fiff.proj.make_projector_info(raw['info'])
-        if nproj == 0:
-            print 'The projection vectors do not apply to these channels'
-            raw['proj'] = None
+    def __init__(self, raw, events, event_id, tmin, tmax,
+                picks=None, keep_comp=False,
+                dest_comp=0, baseline=(None, 0),
+                preload=True):
+        self.raw = raw
+        self.event_id = event_id
+        self.tmin = tmin
+        self.tmax = tmax
+        self.picks = picks
+        self.keep_comp = keep_comp
+        self.dest_comp = dest_comp
+        self.baseline = baseline
+        self.preload = preload
+
+        if picks is None:
+            picks = range(len(raw['info']['ch_names']))
+            self.ch_names = raw['info']['ch_names']
         else:
-            print 'Created an SSP operator (subspace dimension = %d)' % nproj
-            raw['proj'] = proj
+            self.ch_names = [raw['info']['ch_names'][k] for k in picks]
 
-    #   Set up the CTF compensator
-    current_comp = fiff.get_current_comp(raw['info'])
-    if current_comp > 0:
-        print 'Current compensation grade : %d' % current_comp
+        #   Set up projection
+        if raw['info']['projs'] is None:
+            print 'No projector specified for these data'
+            raw['proj'] = []
+        else:
+            #   Activate the projection items
+            for proj in raw['info']['projs']:
+                proj['active'] = True
 
-    if keep_comp:
-        dest_comp = current_comp
+            print '%d projection items activated' % len(raw['info']['projs'])
 
-    if current_comp != dest_comp:
-        raw.comp = fiff.raw.make_compensator(raw['info'], current_comp,
-                                             dest_comp)
-        print 'Appropriate compensator added to change to grade %d.' % (
+            #   Create the projector
+            proj, nproj = fiff.proj.make_projector_info(raw['info'])
+            if nproj == 0:
+                print 'The projection vectors do not apply to these channels'
+                raw['proj'] = None
+            else:
+                print ('Created an SSP operator (subspace dimension = %d)'
+                                                                    % nproj)
+                raw['proj'] = proj
+
+        #   Set up the CTF compensator
+        current_comp = fiff.get_current_comp(raw['info'])
+        if current_comp > 0:
+            print 'Current compensation grade : %d' % current_comp
+
+        if keep_comp:
+            dest_comp = current_comp
+
+        if current_comp != dest_comp:
+            raw.comp = fiff.raw.make_compensator(raw['info'], current_comp,
+                                                 dest_comp)
+            print 'Appropriate compensator added to change to grade %d.' % (
                                                                     dest_comp)
 
-    #    Select the desired events
-    selected = np.logical_and(events[:, 1] == 0, events[:, 2] == event_id)
-    n_events = np.sum(selected)
-    if n_events > 0:
-        print '%d matching events found' % n_events
-    else:
-        raise ValueError, 'No desired events found.'
+        #    Select the desired events
+        selected = np.logical_and(events[:, 1] == 0, events[:, 2] == event_id)
+        self.events = events[selected]
+        n_events = len(self.events)
+
+        if n_events > 0:
+            print '%d matching events found' % n_events
+        else:
+            raise ValueError, 'No desired events found.'
+
+        # Handle times
+        sfreq = raw['info']['sfreq']
+        self.times = np.arange(int(tmin*sfreq), int(tmax*sfreq),
+                          dtype=np.float) / sfreq
+
+        if self.preload:
+            self._data = self._get_data()
 
-    data = list()
+    def __len__(self):
+        return len(self.events)
 
-    for p, event_samp in enumerate(events[selected, 0]):
-        #       Read a data segment
-        start = int(event_samp + tmin*sfreq)
-        stop = int(event_samp + tmax*sfreq)
-        epoch, _ = raw[picks, start:stop]
+    def get_epoch(self, idx):
+        """Load one epoch from disk"""
+        sfreq = self.raw['info']['sfreq']
+        event_samp = self.events[idx, 0]
 
-        if p == 0:
-            times = np.arange(start - event_samp, stop - event_samp,
-                              dtype=np.float) / sfreq
+        # Read a data segment
+        start = int(event_samp + self.tmin*sfreq)
+        stop = start + len(self.times)
+        epoch, _ = self.raw[self.picks, start:stop]
 
         # Run baseline correction
+        times = self.times
+        baseline = self.baseline
         if baseline is not None:
             print "Applying baseline correction ..."
             bmin = baseline[0]
@@ -137,23 +155,49 @@ def read_epochs(raw, events, event_id, tmin, tmax, picks=None,
                 imax = len(times)
             else:
                 imax = int(np.where(times <= bmax)[0][-1]) + 1
-            epoch -= np.mean(epoch[:, imin:imax], axis=1)[:,None]
+            epoch -= np.mean(epoch[:, imin:imax], axis=1)[:, None]
         else:
             print "No baseline correction applied..."
 
-        d = dict()
-        d['epoch'] = epoch
-        d['event'] = event_id
-        d['tmin'] = (float(start) - float(raw['first_samp'])) / sfreq
-        d['tmax'] = (float(stop) - float(raw['first_samp'])) / sfreq
-        data.append(d)
+        return epoch
+
+    def _get_data(self):
+        """Load all data from disk
+        """
+        n_channels = len(self.ch_names)
+        n_times = len(self.times)
+        n_events = len(self.events)
+        data = np.empty((n_events, n_channels, n_times))
+        for k, e in enumerate(self):
+            data[k] = e
+        return data
+
+    def get_data(self):
+        """Get all epochs as a 3D array
+
+        Returns
+        -------
+        data : array of shape [n_epochs, n_channels, n_times]
+            The epochs data
+        """
+        if self.preload:
+            return self._data
+        else:
+            return self._get_data()
 
+    def __iter__(self):
+        """To iteration over epochs easy.
+        """
+        self._current = 0
+        return self
 
-    print 'Read %d epochs, %d samples each.' % (len(data),
-                                                data[0]['epoch'].shape[1])
+    def next(self):
+        """To iteration over epochs easy.
+        """
+        if self._current >= len(self.events):
+            raise StopIteration
 
-    #   Remember to close the file descriptor
-    # raw['fid'].close()
-    # print 'File closed.'
+        epoch = self.get_epoch(self._current)
 
-    return data, times, ch_names
+        self._current += 1
+        return epoch
diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py
index af5ee6b..2be1da0 100644
--- a/mne/tests/test_epochs.py
+++ b/mne/tests/test_epochs.py
@@ -31,6 +31,5 @@ def test_read_epochs():
     picks = fiff.pick_types(raw['info'], want_meg, want_eeg, want_stim,
                             include, raw['info']['bads'])
 
-    data, times, channel_names = mne.read_epochs(raw, events, event_id,
-                                                    tmin, tmax, picks=picks,
-                                                    baseline=(None, 0))
+    epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks,
+                        baseline=(None, 0))
diff --git a/mne/tests/test_tfr.py b/mne/tests/test_tfr.py
index 9d923ac..d9ca5e4 100644
--- a/mne/tests/test_tfr.py
+++ b/mne/tests/test_tfr.py
@@ -31,13 +31,14 @@ def test_time_frequency():
                                     stim=False, include=include, exclude=exclude)
 
     picks = picks[:2]
-    data, times, channel_names = mne.read_epochs(raw, events, event_id,
+    epochs = mne.read_epochs(raw, events, event_id,
                                     tmin, tmax, picks=picks, baseline=(None, 0))
-    epochs = np.array([d['epoch'] for d in data]) # as 3D matrix
+    data = epochs.get_data()
+    times = epochs.times
 
     frequencies = np.arange(6, 20, 5) # define frequencies of interest
     Fs = raw['info']['sfreq'] # sampling in Hz
-    power, phase_lock = time_frequency(epochs, Fs=Fs, frequencies=frequencies,
+    power, phase_lock = time_frequency(data, Fs=Fs, frequencies=frequencies,
                                        n_cycles=2, use_fft=True)
 
     assert power.shape == (len(picks), len(frequencies), len(times))
@@ -45,7 +46,7 @@ def test_time_frequency():
     assert np.sum(phase_lock >= 1) == 0
     assert np.sum(phase_lock <= 0) == 0
 
-    power, phase_lock = time_frequency(epochs, Fs=Fs, frequencies=frequencies,
+    power, phase_lock = time_frequency(data, Fs=Fs, frequencies=frequencies,
                                        n_cycles=2, use_fft=False)
 
     assert power.shape == (len(picks), len(frequencies), len(times))
@@ -53,5 +54,5 @@ def test_time_frequency():
     assert np.sum(phase_lock >= 1) == 0
     assert np.sum(phase_lock <= 0) == 0
 
-    tfr = cwt_morlet(epochs[0], Fs, frequencies, use_fft=True, n_cycles=2)
+    tfr = cwt_morlet(data[0], Fs, frequencies, use_fft=True, n_cycles=2)
     assert tfr.shape == (len(picks), len(frequencies), len(times))
diff --git a/mne/tfr.py b/mne/tfr.py
index 6a61cb4..7e32542 100644
--- a/mne/tfr.py
+++ b/mne/tfr.py
@@ -190,7 +190,7 @@ def _time_frequency(X, Ws, use_fft):
     return psd, plf
 
 
-def time_frequency(epochs, Fs, frequencies, use_fft=True, n_cycles=25,
+def time_frequency(data, Fs, frequencies, use_fft=True, n_cycles=25,
                    n_jobs=1):
     """Compute time induced power and inter-trial phase-locking factor
 
@@ -198,7 +198,7 @@ def time_frequency(epochs, Fs, frequencies, use_fft=True, n_cycles=25,
 
     Parameters
     ----------
-    epochs : array
+    data : array
         3D array of shape [n_epochs, n_channels, n_times]
 
     Fs : float
@@ -227,7 +227,7 @@ def time_frequency(epochs, Fs, frequencies, use_fft=True, n_cycles=25,
         Phase locking factor in [0, 1] (Channels x Frequencies x Timepoints)
     """
     n_frequencies = len(frequencies)
-    n_epochs, n_channels, n_times = epochs.shape
+    n_epochs, n_channels, n_times = data.shape
 
     # Precompute wavelets for given frequency range to save time
     Ws = morlet(Fs, frequencies, n_cycles=n_cycles)
@@ -243,14 +243,14 @@ def time_frequency(epochs, Fs, frequencies, use_fft=True, n_cycles=25,
         plf = np.empty((n_channels, n_frequencies, n_times), dtype=np.complex)
 
         for c in range(n_channels):
-            X = np.squeeze(epochs[:,c,:])
+            X = np.squeeze(data[:,c,:])
             psd[c], plf[c] = _time_frequency(X, Ws, use_fft)
 
     else:
         from joblib import Parallel, delayed
         psd_plf = Parallel(n_jobs=n_jobs)(
                     delayed(_time_frequency)(
-                            np.squeeze(epochs[:,c,:]), Ws, use_fft)
+                            np.squeeze(data[:,c,:]), Ws, use_fft)
                     for c in range(n_channels))
 
         psd = np.zeros((n_channels, n_frequencies, n_times))

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git



More information about the debian-med-commit mailing list