[med-svn] [python-mne] 76/353: rem. include, exclude from apply_forward functions, added pick_types_forward

Yaroslav Halchenko debian at onerussian.com
Fri Nov 27 17:24:34 UTC 2015


This is an automated email from the git hooks/post-receive script.

yoh pushed a commit to tag 0.4
in repository python-mne.

commit ffda0023578227c041c4173497e873d506e667fe
Author: Martin Luessi <mluessi at nmr.mgh.harvard.edu>
Date:   Tue Feb 7 13:45:29 2012 -0500

    rem. include, exclude from apply_forward functions, added pick_types_forward
---
 mne/fiff/__init__.py      |  3 +-
 mne/fiff/pick.py          | 35 +++++++++++++++++++++-
 mne/forward.py            | 74 +++++++++++++++++++++++++----------------------
 mne/tests/test_forward.py | 15 ++++------
 4 files changed, 82 insertions(+), 45 deletions(-)

diff --git a/mne/fiff/__init__.py b/mne/fiff/__init__.py
index cd881ad..7d5b666 100644
--- a/mne/fiff/__init__.py
+++ b/mne/fiff/__init__.py
@@ -11,7 +11,8 @@ from .evoked import Evoked, read_evoked, write_evoked
 from .raw import Raw, read_raw_segment, read_raw_segment_times, \
                  start_writing_raw, write_raw_buffer, finish_writing_raw
 from .pick import pick_types, pick_channels, pick_types_evoked, \
-                  pick_channels_regexp
+                  pick_channels_regexp, pick_channels_forward, \
+                  pick_types_forward
 
 from .compensator import get_current_comp
 from .proj import compute_spatial_vectors
diff --git a/mne/fiff/pick.py b/mne/fiff/pick.py
index 6d33b48..2352470 100644
--- a/mne/fiff/pick.py
+++ b/mne/fiff/pick.py
@@ -315,7 +315,7 @@ def pick_channels_forward(orig, include=[], exclude=[]):
     Returns
     -------
     res : dict
-        Evoked data restricted to selected channels. If include and
+        Forward solution restricted to selected channels. If include and
         exclude are None it returns orig without copy.
     """
 
@@ -351,6 +351,39 @@ def pick_channels_forward(orig, include=[], exclude=[]):
     return fwd
 
 
+def pick_types_forward(orig, meg=True, eeg=False, include=[], exclude=[]):
+    """Pick by channel type and names from a forward operator
+
+    Parameters
+    ----------
+    orig : dict
+        A forward solution
+    meg : bool or string
+        If True include all MEG channels. If False include none.
+        If string it can be 'mag' or 'grad' to select only magnetometers
+        or gradiometers. It can also be 'ref_meg' to get CTF
+        reference channels.
+    eeg : bool
+        If True include EEG channels
+    include : list of string
+        List of additional channels to include. If empty do not include any.
+
+    exclude : list of string
+        List of channels to exclude. If empty do not exclude any.
+
+    Returns
+    -------
+    res : dict
+        Forward solution restricted to selected channel types.
+    """
+    info = {'ch_names': orig['sol']['row_names'], 'chs': orig['chs'],
+            'nchan': orig['nchan']}
+    sel = pick_types(info, meg, eeg, include=include, exclude=exclude)
+    include_ch_names = [info['ch_names'][k] for k in sel]
+
+    return pick_channels_forward(orig, include_ch_names)
+
+
 def channel_indices_by_type(info):
     """Get indices of channels by type
     """
diff --git a/mne/forward.py b/mne/forward.py
index 6094255..e037ba4 100644
--- a/mne/forward.py
+++ b/mne/forward.py
@@ -16,7 +16,7 @@ from .fiff.open import fiff_open
 from .fiff.tree import dir_tree_find
 from .fiff.tag import find_tag, read_tag
 from .fiff.matrix import _read_named_matrix, _transpose_named_matrix
-from .fiff.pick import pick_channels_forward
+from .fiff.pick import pick_channels_forward, pick_info, pick_channels
 
 from .source_space import read_source_spaces_from_tree, find_source_space_hemi
 from .transforms import transform_source_space_to, invert_transform
@@ -465,14 +465,12 @@ def _stc_src_sel(src, stc):
     return src_sel
 
 
-def _fill_measurement_info(info, fwd, ch_names, sfreq):
+def _fill_measurement_info(info, fwd, sfreq):
     """ Fill the measurement info of a Raw or Evoked object
     """
+    sel = pick_channels(info['ch_names'], fwd['sol']['row_names'])
+    info = pick_info(info, sel)
     info['bads'] = []
-    info['ch_names'] = ch_names
-    info['chs'] = [deepcopy(ch) for ch in fwd['chs'] if ch['ch_name'] in
-                   ch_names]
-    info['nchan'] = len(info['chs'])
 
     info['filename'] = None
     info['meas_id'] = None  #XXX is this the right thing to do?
@@ -488,21 +486,21 @@ def _fill_measurement_info(info, fwd, ch_names, sfreq):
     info['sfreq'] = np.array(sfreq, dtype=np.float32)
     info['projs'] = []
 
+    return info
 
-def _apply_forward(fwd, stc, start=None, stop=None, include=[], exclude=[]):
+
+def _apply_forward(fwd, stc, start=None, stop=None):
     """ Apply forward model and return data, times, ch_names
     """
     if fwd['source_ori'] != FIFF.FIFFV_MNE_FIXED_ORI:
         raise ValueError('Only fixed-orientation forward operators are '
-                         'supported')
+                         'supported.')
 
     if np.all(stc.data > 0):
         warnings.warn('Source estimate only contains currents with positive '
                       'values. Use pick_normal=True when computing the '
                       'inverse to compute currents not current magnitudes.')
 
-    fwd = pick_channels_forward(fwd, include=include, exclude=exclude)
-
     src_sel = _stc_src_sel(fwd['src'], stc)
 
     gain = fwd['sol']['data'][:, src_sel]
@@ -512,19 +510,23 @@ def _apply_forward(fwd, stc, start=None, stop=None, include=[], exclude=[]):
     print '[done]'
 
     times = deepcopy(stc.times[start:stop])
-    ch_names = deepcopy(fwd['sol']['row_names'])
 
-    return data, times, ch_names
+    return data, times
 
 
-def apply_forward(fwd, stc, evoked_template, start=None, stop=None,
-                  include=[], exclude=[]):
+def apply_forward(fwd, stc, evoked_template, start=None, stop=None):
     """
     Project source space currents to sensor space using a forward operator.
 
+    The sensor space data is computed for all channels present in fwd. Use
+    pick_channels_forward or pick_types_forward to restrict the solution to a
+    subset of channels.
+
     The function returns an Evoked object, which is constructed from
     evoked_template. The evoked_template should be from the same MEG system on
-    which the original data was acquired.
+    which the original data was acquired. An exception will be raised if the
+    forward operator contains channels that are not present in the template.
+
 
     Parameters
     ----------
@@ -538,11 +540,6 @@ def apply_forward(fwd, stc, evoked_template, start=None, stop=None,
         Index of first time sample (index not time is seconds).
     stop: int, optional
         Index of first time sample not to include (index not time is seconds).
-    include: list, optional
-        List of names of channels to include in output. If empty all channels
-        are included.
-    exclude: list, optional
-        List of names of channels to exclude. If empty include all channels.
 
     Returns
     -------
@@ -554,9 +551,14 @@ def apply_forward(fwd, stc, evoked_template, start=None, stop=None,
     apply_forward_raw: Compute sensor space data and return a Raw object.
     """
 
+    # make sure evoked_template contains all channels in fwd
+    for ch_name in fwd['sol']['row_names']:
+        if ch_name not in evoked_template.ch_names:
+            raise ValueError('Channel %s of forward operator not present in '
+                             'evoked_template.' % ch_name)
+
     # project the source estimate to the sensor space
-    data, times, ch_names = _apply_forward(fwd, stc, start, stop,
-                                           include, exclude)
+    data, times = _apply_forward(fwd, stc, start, stop)
 
     # store sensor data in an Evoked object using the template
     evoked = deepcopy(evoked_template)
@@ -570,19 +572,23 @@ def apply_forward(fwd, stc, evoked_template, start=None, stop=None,
     evoked.last = evoked.first + evoked.data.shape[1] - 1
 
     # fill the measurement info
-    _fill_measurement_info(evoked.info, fwd, ch_names, sfreq)
+    evoked.info = _fill_measurement_info(evoked.info, fwd, sfreq)
 
     return evoked
 
 
-def apply_forward_raw(fwd, stc, raw_template, start=None, stop=None,
-                      include=[], exclude=[]):
+def apply_forward_raw(fwd, stc, raw_template, start=None, stop=None):
     """
     Project source space currents to sensor space using a forward operator.
 
+    The sensor space data is computed for all channels present in fwd. Use
+    pick_channels_forward or pick_types_forward to restrict the solution to a
+    subset of channels.
+
     The function returns a Raw object, which is constructed from raw_template.
     The raw_template should be from the same MEG system on which the original
-    data was acquired.
+    data was acquired. An exception will be raised if the forward operator
+    contains channels that are not present in the template.
 
     Parameters
     ----------
@@ -596,11 +602,6 @@ def apply_forward_raw(fwd, stc, raw_template, start=None, stop=None,
         Index of first time sample (index not time is seconds).
     stop: int, optional
         Index of first time sample not to include (index not time is seconds).
-    include: list, optional
-        List of names of channels to include in output. If empty all channels
-        are included.
-    exclude: list, optional
-        List of names of channels to exclude. If empty include all channels.
 
     Returns
     -------
@@ -612,9 +613,14 @@ def apply_forward_raw(fwd, stc, raw_template, start=None, stop=None,
     apply_forward: Compute sensor space data and return an Evoked object.
     """
 
+    # make sure raw_template contains all channels in fwd
+    for ch_name in fwd['sol']['row_names']:
+        if ch_name not in raw_template.ch_names:
+            raise ValueError('Channel %s of forward operator not present in '
+                             'raw_template.' % ch_name)
+
     # project the source estimate to the sensor space
-    data, times, ch_names = _apply_forward(fwd, stc, start, stop,
-                                           include, exclude)
+    data, times = _apply_forward(fwd, stc, start, stop)
 
     # store sensor data in Raw object using the template
     raw = deepcopy(raw_template)
@@ -628,6 +634,6 @@ def apply_forward_raw(fwd, stc, raw_template, start=None, stop=None,
     raw.last_samp = raw.first_samp + raw._data.shape[1] - 1
 
     # fill the measurement info
-    _fill_measurement_info(raw.info, fwd, ch_names, sfreq)
+    raw.info = _fill_measurement_info(raw.info, fwd, sfreq)
 
     return raw
diff --git a/mne/tests/test_forward.py b/mne/tests/test_forward.py
index 0151aac..e28d66d 100644
--- a/mne/tests/test_forward.py
+++ b/mne/tests/test_forward.py
@@ -4,7 +4,8 @@ import numpy as np
 from numpy.testing import assert_array_almost_equal, assert_equal
 
 from ..datasets import sample
-from ..fiff import Raw, Evoked, pick_channels
+from ..fiff import Raw, Evoked, pick_types, pick_types_forward, \
+                   pick_channels_forward
 from ..minimum_norm.inverse import _make_stc
 from .. import read_forward_solution, apply_forward, apply_forward_raw,\
                SourceEstimate
@@ -41,6 +42,7 @@ def test_apply_forward():
     t_start = 0.123
 
     fwd = read_forward_solution(fname, force_fixed=True)
+    fwd = pick_types_forward(fwd, meg=True)
 
     vertno = [fwd['src'][0]['vertno'], fwd['src'][1]['vertno']]
     stc_data = np.ones((len(vertno[0]) + len(vertno[1]), n_times))
@@ -53,10 +55,7 @@ def test_apply_forward():
     data = evoked.data
     times = evoked.times
 
-    sel = pick_channels(fwd['sol']['row_names'],
-                        include=evoked.info['ch_names'])
-
-    gain_sum = np.sum(fwd['sol']['data'][sel, :], axis=1)
+    gain_sum = np.sum(fwd['sol']['data'], axis=1)
 
     # do some tests
     assert_array_almost_equal(evoked.info['sfreq'], sfreq)
@@ -75,6 +74,7 @@ def test_apply_forward_raw():
     t_start = 0.123
 
     fwd = read_forward_solution(fname, force_fixed=True)
+    fwd = pick_types_forward(fwd, meg=True)
 
     vertno = [fwd['src'][0]['vertno'], fwd['src'][1]['vertno']]
     stc_data = np.ones((len(vertno[0]) + len(vertno[1]), n_times))
@@ -86,10 +86,7 @@ def test_apply_forward_raw():
 
     data, times = raw_proj[:, :]
 
-    sel = pick_channels(fwd['sol']['row_names'],
-                        include=raw_proj.info['ch_names'])
-
-    gain_sum = np.sum(fwd['sol']['data'][sel, :], axis=1)
+    gain_sum = np.sum(fwd['sol']['data'], axis=1)
 
     # do some tests
     assert_array_almost_equal(raw_proj.info['sfreq'], sfreq)

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git



More information about the debian-med-commit mailing list