[med-svn] [python-mne] 296/353: added atleast_1d support

Yaroslav Halchenko debian at onerussian.com
Fri Nov 27 17:25:19 UTC 2015


This is an automated email from the git hooks/post-receive script.

yoh pushed a commit to tag 0.4
in repository python-mne.

commit 5f0e95b33b09b442b9282d0b76d92ab37122b676
Author: Daniel Strohmeier <daniel.strohmeier at googlemail.com>
Date:   Thu Jul 19 10:00:51 2012 +0200

    added atleast_1d support
---
 mne/epochs.py            | 19 ++++++++-----------
 mne/tests/test_epochs.py | 19 +++++++++----------
 2 files changed, 17 insertions(+), 21 deletions(-)

diff --git a/mne/epochs.py b/mne/epochs.py
index f7199ef..7972576 100644
--- a/mne/epochs.py
+++ b/mne/epochs.py
@@ -387,14 +387,8 @@ class Epochs(object):
             if isinstance(key, slice):
                 epochs._data = self._data[key]
             else:
-                if isinstance(key, list):
-                    key = np.array(key)
-                    print key
-                    print np.ndim(key)
-                if np.ndim(key) == 0:
-                    epochs._data = self._data[key][np.newaxis, :, :]
-                else:
-                    epochs._data = self._data[key]
+                key = np.atleast_1d(key)
+                epochs._data = self._data[key]
         return epochs
 
     def average(self, keep_only_data_channels=True):
@@ -470,24 +464,27 @@ class Epochs(object):
             raise RuntimeError('Modifying data of epochs is only supported '
                                 'when preloading is used. Use preload=True '
                                 'in the constructor.')
+
         if tmin is None:
             tmin = self.tmin
         elif tmin < self.tmin:
             warnings.warn("tmin is not in epochs' time interval."
                           "tmin is set to epochs.tmin")
             tmin = self.tmin
+
         if tmax is None:
             tmax = self.tmax
         elif tmax > self.tmax:
             warnings.warn("tmax is not in epochs' time interval."
                           "tmax is set to epochs.tmax")
             tmax = self.tmax
-            
+
         tmask = (self.times >= tmin) & (self.times <= tmax)
+        tidx = np.where(tmask)[0]
 
         this_epochs = self if not copy else cp.deepcopy(self)
-        this_epochs.tmin = tmin
-        this_epochs.tmax = tmax
+        this_epochs.tmin = this_epochs.times[tidx[0]]
+        this_epochs.tmax = this_epochs.times[tidx[-1]]
         this_epochs.times = this_epochs.times[tmask]
         this_epochs._data = this_epochs._data[:, :, tmask]
         return this_epochs
diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py
index a7f745b..9bc4eaa 100644
--- a/mne/tests/test_epochs.py
+++ b/mne/tests/test_epochs.py
@@ -138,9 +138,8 @@ def test_indexing_slicing():
             pos += 1
             
         # using indexing with an int
-        idx = np.random.randint(0, data_epochs2_sliced.shape[0], 1)
-        data = epochs2[idx].get_data()
-        assert_array_equal(data, data_normal[idx])
+        data = epochs2[data_epochs2_sliced.shape[0]].get_data()
+        assert_array_equal(data, data_normal[[idx]])
         
         # using indexing with an array
         idx = np.random.randint(0, data_epochs2_sliced.shape[0], 10)
@@ -186,14 +185,14 @@ def test_crop():
                     reject=reject, flat=flat)
 
     # indices for slicing
-    start_tsamp = tmin + 60 * epochs.info['sfreq']
-    end_tsamp = tmax - 60 * epochs.info['sfreq']
-    tmask = (epochs.times >= start_tsamp) & (epochs.times <= end_tsamp)
-    assert_true(start_tsamp > tmin)
-    assert_true(end_tsamp < tmax)
-    epochs3 = epochs2.crop(start_tsamp, end_tsamp, copy=True)
+    tmin_window = tmin + 0.1
+    tmax_window = tmax - 0.1
+    tmask = (epochs.times >= tmin_window) & (epochs.times <= tmax_window)
+    assert_true(tmin_window > tmin)
+    assert_true(tmax_window < tmax)
+    epochs3 = epochs2.crop(tmin_window, tmax_window, copy=True)
     data3 = epochs3.get_data()
-    epochs2.crop(start_tsamp, end_tsamp)
+    epochs2.crop(tmin_window, tmax_window)
     data2 = epochs2.get_data()
     assert_array_equal(data2, data_normal[:, :, tmask])
     assert_array_equal(data3, data_normal[:, :, tmask])

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git



More information about the debian-med-commit mailing list