[med-svn] [python-mne] 295/353: added comments on the pull request

Yaroslav Halchenko debian at onerussian.com
Fri Nov 27 17:25:19 UTC 2015


This is an automated email from the git hooks/post-receive script.

yoh pushed a commit to tag 0.4
in repository python-mne.

commit a7331a3e2677e4579582020898c29394d87ac089
Author: Daniel Strohmeier <daniel.strohmeier at googlemail.com>
Date:   Wed Jul 18 17:59:33 2012 +0200

    Address pull request review comments: crop() gains optional tmin/tmax and a copy flag, bootstrap() takes random_state
---
 mne/epochs.py            | 67 +++++++++++++++++++++++++++++-------------------
 mne/tests/test_epochs.py | 47 ++++++++++++++++++---------------
 2 files changed, 66 insertions(+), 48 deletions(-)

diff --git a/mne/epochs.py b/mne/epochs.py
index fdc723d..f7199ef 100644
--- a/mne/epochs.py
+++ b/mne/epochs.py
@@ -1,17 +1,20 @@
 # Authors: Alexandre Gramfort <gramfort at nmr.mgh.harvard.edu>
 #          Matti Hamalainen <msh at nmr.mgh.harvard.edu>
+#          Daniel Strohmeier <daniel.strohmeier at tu-ilmenau.de>
 #
 # License: BSD (3-clause)
 
-import copy
+import copy as cp
+import warnings
+
 import numpy as np
+
 import fiff
-import warnings
 from .fiff import Evoked
 from .fiff.pick import pick_types, channel_indices_by_type
 from .fiff.proj import activate_proj, make_eeg_average_ref_proj
 from .baseline import rescale
-
+from .utils import check_random_state
 
 class Epochs(object):
     """List of Epochs
@@ -122,7 +125,7 @@ class Epochs(object):
         self._bad_dropped = False
 
         # Handle measurement info
-        self.info = copy.deepcopy(raw.info)
+        self.info = cp.deepcopy(raw.info)
         if picks is not None:
             self.info['chs'] = [self.info['chs'][k] for k in picks]
             self.info['ch_names'] = [self.info['ch_names'][k] for k in picks]
@@ -377,7 +380,7 @@ class Epochs(object):
             warnings.warn("Bad epochs have not been dropped, indexing will be "
                           "inaccurate. Use drop_bad_epochs() or preload=True")
 
-        epochs = copy.copy(self)  # XXX : should use deepcopy but breaks ...
+        epochs = cp.copy(self)  # XXX : should use deepcopy but breaks ...
         epochs.events = np.atleast_2d(self.events[key])
 
         if self.preload:
@@ -386,6 +389,8 @@ class Epochs(object):
             else:
                 if isinstance(key, list):
                     key = np.array(key)
+                    print key
+                    print np.ndim(key)
                 if np.ndim(key) == 0:
                     epochs._data = self._data[key][np.newaxis, :, :]
                 else:
@@ -407,7 +412,7 @@ class Epochs(object):
             The averaged epochs
         """
         evoked = Evoked(None)
-        evoked.info = copy.deepcopy(self.info)
+        evoked.info = cp.deepcopy(self.info)
         n_channels = len(self.ch_names)
         n_times = len(self.times)
         if self.preload:
@@ -444,7 +449,7 @@ class Epochs(object):
             evoked.data = evoked.data[data_picks]
         return evoked
     
-    def crop(self, tmin, tmax):
+    def crop(self, tmin=None, tmax=None, copy=False):
         """Crops a time interval from epochs object.
     
         Parameters
@@ -453,7 +458,9 @@ class Epochs(object):
             Start time of selection in seconds
         tmax : float
             End time of selection in seconds
-    
+        copy : bool
+            If False epochs is cropped in place
+
         Returns
         -------
         epochs : Epochs instance
@@ -463,19 +470,27 @@ class Epochs(object):
             raise RuntimeError('Modifying data of epochs is only supported '
                                 'when preloading is used. Use preload=True '
                                 'in the constructor.')
-        if tmin < self.tmin:
+        if tmin is None:
             tmin = self.tmin
-        if tmax > self.tmax:
+        elif tmin < self.tmin:
+            warnings.warn("tmin is not in epochs' time interval."
+                          "tmin is set to epochs.tmin")
+            tmin = self.tmin
+        if tmax is None:
+            tmax = self.tmax
+        elif tmax > self.tmax:
+            warnings.warn("tmax is not in epochs' time interval."
+                          "tmax is set to epochs.tmax")
             tmax = self.tmax
             
-        sfreq = self.info['sfreq']
-        first_samp = int((tmin - self.tmin) * sfreq)
-        last_samp = int((tmax - self.tmax) * sfreq) - 1
-        
-        self.tmin = tmin
-        self.tmax = tmax
-        self._data = self._data[:, :, first_samp:last_samp]
-        return self
+        tmask = (self.times >= tmin) & (self.times <= tmax)
+
+        this_epochs = self if not copy else cp.deepcopy(self)
+        this_epochs.tmin = tmin
+        this_epochs.tmax = tmax
+        this_epochs.times = this_epochs.times[tmask]
+        this_epochs._data = this_epochs._data[:, :, tmask]
+        return this_epochs
 
 
 def _is_good(e, ch_names, channel_type_idx, reject, flat):
@@ -514,15 +529,15 @@ def _is_good(e, ch_names, channel_type_idx, reject, flat):
     return True
 
 
-def bootstrap(epochs, rng, return_idx=False):
-    """Compute average of epochs selected by bootstrapping
+def bootstrap(epochs, random_state=None):
+    """Compute epochs selected by bootstrapping
 
     Parameters
     ----------
     epochs : Epochs instance
         epochs data to be bootstrapped
-    rng : 
-        random number generator.
+    random_state : None | int | np.random.RandomState
+        To specify the random generator state
     return_idx : bool
         If True the selected indices are provided as an output
 
@@ -536,11 +551,9 @@ def bootstrap(epochs, rng, return_idx=False):
                             'when preloading is used. Use preload=True '
                             'in the constructor.')
 
-    epochs_bootstrap = copy.deepcopy(epochs)
+    rng = check_random_state(random_state)
+    epochs_bootstrap = cp.deepcopy(epochs)
     n_events = len(epochs_bootstrap.events)
     idx = rng.randint(0, n_events, n_events)
     epochs_bootstrap = epochs_bootstrap[idx]
-    if return_idx:
-        return epochs_bootstrap, idx
-    else:
-        return epochs_bootstrap
+    return epochs_bootstrap
diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py
index c751db3..a7f745b 100644
--- a/mne/tests/test_epochs.py
+++ b/mne/tests/test_epochs.py
@@ -137,22 +137,23 @@ def test_indexing_slicing():
             assert_array_equal(data[0], data_normal[idx])
             pos += 1
             
-        # using indexing with int
+        # using indexing with an int
         idx = np.random.randint(0, data_epochs2_sliced.shape[0], 1)
         data = epochs2[idx].get_data()
         assert_array_equal(data, data_normal[idx])
         
-        # using indexing with array
+        # using indexing with an array
         idx = np.random.randint(0, data_epochs2_sliced.shape[0], 10)
         data = epochs2[idx].get_data()
         assert_array_equal(data, data_normal[idx])
         
-        # using indexing with list of indices
-        #idx = list()
-        #for k in range(3):
-        #    idx.append(np.random.randint(0, data_epochs2_sliced.shape[0], 1))
-        #    data = epochs2[idx].get_data()
-        #    assert_array_equal(data, data_normal[idx])
+        # using indexing with a list of indices
+        idx = [0]
+        data = epochs2[idx].get_data()
+        assert_array_equal(data, data_normal[idx])
+        idx = [0, 1]
+        data = epochs2[idx].get_data()
+        assert_array_equal(data, data_normal[idx])
 
 
 def test_comparision_with_c():
@@ -175,33 +176,37 @@ def test_comparision_with_c():
 def test_crop():
     """Test of crop of epochs
     """
-    epochs = Epochs(raw, events[:20], event_id, tmin, tmax, picks=picks,
+    epochs = Epochs(raw, events[:5], event_id, tmin, tmax, picks=picks,
                     baseline=(None, 0), preload=False,
                     reject=reject, flat=flat)
-    epochs2 = Epochs(raw, events[:20], event_id, tmin, tmax,
+    data_normal = epochs.get_data()
+
+    epochs2 = Epochs(raw, events[:5], event_id, tmin, tmax,
                     picks=picks, baseline=(None, 0), preload=True,
                     reject=reject, flat=flat)
-    data_normal = epochs.get_data()
 
     # indices for slicing
     start_tsamp = tmin + 60 * epochs.info['sfreq']
     end_tsamp = tmax - 60 * epochs.info['sfreq']
     tmask = (epochs.times >= start_tsamp) & (epochs.times <= end_tsamp)
-    assert((start_tsamp) > tmin)
-    assert((end_tsamp) < tmax)
+    assert_true(start_tsamp > tmin)
+    assert_true(end_tsamp < tmax)
+    epochs3 = epochs2.crop(start_tsamp, end_tsamp, copy=True)
+    data3 = epochs3.get_data()
     epochs2.crop(start_tsamp, end_tsamp)
-    data = epochs2.get_data()
-    assert_array_equal(data, data_normal[:, :, tmask])
-    
+    data2 = epochs2.get_data()
+    assert_array_equal(data2, data_normal[:, :, tmask])
+    assert_array_equal(data3, data_normal[:, :, tmask])
+
 
 def test_bootstrap():
-    """Test of crop of epochs
+    """Test of bootstrapping of epochs
     """
-    epochs = Epochs(raw, events[:20], event_id, tmin, tmax, picks=picks,
+    epochs = Epochs(raw, events[:5], event_id, tmin, tmax, picks=picks,
                     baseline=(None, 0), preload=True,
                     reject=reject, flat=flat)
     data_normal = epochs._data
-    rng = np.random.RandomState(0)
-    epochs2, idx = bootstrap(epochs, rng, return_idx=True)
+    epochs2 = bootstrap(epochs, random_state=0)
     n_events = len(epochs.events)
-    assert_array_equal(epochs2._data, data_normal[idx])
+    assert_true(len(epochs2.events) == len(epochs.events))
+    assert_true(epochs._data.shape == epochs2._data.shape)

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git



More information about the debian-med-commit mailing list