[med-svn] [python-mne] 273/353: ENH : big cleanup of simulation code + new plot function for sparse stc

Yaroslav Halchenko debian at onerussian.com
Fri Nov 27 17:25:14 UTC 2015


This is an automated email from the git hooks/post-receive script.

yoh pushed a commit to tag 0.4
in repository python-mne.

commit eea3d6290276dd95eec9f3b812c260da98e4aa45
Author: Alexandre Gramfort <alexandre.gramfort at inria.fr>
Date:   Mon Jul 16 15:24:40 2012 +0200

    ENH : big cleanup of simulation code + new plot function for sparse stc
---
 examples/plot_simulate_evoked_data.py |  85 ++++++++++++++++
 mne/simulation/__init__.py            |   4 +
 mne/simulation/sim_evoked.py          | 178 ++++++++++------------------------
 mne/time_frequency/__init__.py        |   2 +-
 mne/time_frequency/ar.py              |  30 ++++++
 mne/utils.py                          |  18 ++++
 mne/viz.py                            | 155 +++++++++++++++++++++++++++++
 7 files changed, 342 insertions(+), 130 deletions(-)

diff --git a/examples/plot_simulate_evoked_data.py b/examples/plot_simulate_evoked_data.py
new file mode 100644
index 0000000..cf0f3ee
--- /dev/null
+++ b/examples/plot_simulate_evoked_data.py
@@ -0,0 +1,94 @@
+"""
+==============================
+Generate simulated evoked data
+==============================
+
+Simulate one source in each of the left and right auditory cortices,
+project the source time courses to the sensors with a forward solution
+and add noise at a given SNR. The noise is drawn from the noise
+covariance and temporally colored with an AR model fitted to raw data.
+
+"""
+# Author: Daniel Strohmeier <daniel.strohmeier at tu-ilmenau.de>
+#         Alexandre Gramfort <gramfort at nmr.mgh.harvard.edu>
+#
+# License: BSD (3-clause)
+
+import numpy as np
+import pylab as pl
+
+import mne
+from mne.fiff.pick import pick_types_evoked, pick_types_forward
+from mne.forward import apply_forward
+from mne.datasets import sample
+from mne.time_frequency import fir_filter_raw
+from mne.viz import plot_evoked, plot_sparse_source_estimates
+from mne.simulation.sim_evoked import (source_signal, generate_stc,
+                                       generate_noise_evoked, add_noise)
+
+###############################################################################
+# Load real data as templates
+data_path = sample.data_path('.')
+
+raw = mne.fiff.Raw(data_path + '/MEG/sample/sample_audvis_raw.fif')
+proj = mne.read_proj(data_path + '/MEG/sample/ecg_proj.fif')
+raw.info['projs'] += proj
+raw.info['bads'] = ['MEG 2443', 'EEG 053']  # mark bad channels
+
+fwd_fname = data_path + '/MEG/sample/sample_audvis-meg-eeg-oct-6-fwd.fif'
+ave_fname = data_path + '/MEG/sample/sample_audvis-no-filter-ave.fif'
+cov_fname = data_path + '/MEG/sample/sample_audvis-cov.fif'
+
+fwd = mne.read_forward_solution(fwd_fname, force_fixed=True, surf_ori=True)
+fwd = pick_types_forward(fwd, meg=True, eeg=True, exclude=raw.info['bads'])
+
+noise_cov = mne.read_cov(cov_fname)
+
+evoked_template = mne.fiff.read_evoked(ave_fname, setno=0, baseline=None)
+evoked_template = pick_types_evoked(evoked_template, meg=True, eeg=True,
+                                    exclude=raw.info['bads'])
+
+tmin = -0.1
+sfreq = 1000  # Hz
+tstep = 1. / sfreq
+n_samples = 300
+# Sample times starting at tmin and spaced exactly tstep apart
+timesamples = tmin + np.arange(n_samples) * tstep
+
+label_names = ['Aud-lh', 'Aud-rh']
+labels = [mne.read_label(data_path + '/MEG/sample/labels/%s.label' % ln)
+          for ln in label_names]
+
+mus = [[0.030, 0.060, 0.120], [0.040, 0.060, 0.140]]
+sigmas = [[0.01, 0.02, 0.03], [0.01, 0.02, 0.03]]
+amps = [[40 * 1e-9, 40 * 1e-9, 30 * 1e-9], [30 * 1e-9, 40 * 1e-9, 40 * 1e-9]]
+freqs = [[0, 0, 0], [0, 0, 0]]
+phis = [[0, 0, 0], [0, 0, 0]]
+
+SNR = 6
+dB = True
+
+stc_data = source_signal(mus, sigmas, amps, freqs, phis, timesamples)
+stc = generate_stc(fwd, labels, stc_data, tmin, tstep, random_state=0)
+evoked = apply_forward(fwd, stc, evoked_template)
+
+###############################################################################
+# Add noise
+picks = mne.fiff.pick_types(raw.info, meg=True)
+fir_filter = fir_filter_raw(raw, order=5, picks=picks, tmin=60, tmax=180)
+noise = generate_noise_evoked(evoked, noise_cov, n_samples, fir_filter)
+
+# add_noise adds the noise to evoked in place and returns it
+evoked_noise = add_noise(evoked, noise, SNR, timesamples, tmin=0.0, tmax=0.2,
+                         dB=dB)
+
+###############################################################################
+# Plot
+plot_sparse_source_estimates(fwd['src'], stc, bgcolor=(1, 1, 1),
+                             opacity=0.5, high_resolution=True)
+
+pl.figure()
+pl.psd(evoked_noise.data[0])
+
+pl.figure()
+plot_evoked(evoked_noise)
diff --git a/mne/simulation/__init__.py b/mne/simulation/__init__.py
new file mode 100644
index 0000000..918184b
--- /dev/null
+++ b/mne/simulation/__init__.py
@@ -0,0 +1,4 @@
+"""Data simulation code
+"""
+
+from .sim_evoked import select_source_in_label, generate_stc
\ No newline at end of file
diff --git a/mne/simulation/sim_evoked.py b/mne/simulation/sim_evoked.py
index b01a1a9..97e4b72 100644
--- a/mne/simulation/sim_evoked.py
+++ b/mne/simulation/sim_evoked.py
@@ -1,19 +1,16 @@
-import pdb
-import copy
+# Authors: Alexandre Gramfort <gramfort at nmr.mgh.harvard.edu>
+#          Matti Hamalainen <msh at nmr.mgh.harvard.edu>
+#
+# License: BSD (3-clause)
 
 import numpy as np
-import pylab as pl
-
 from scipy import signal
 
-import mne
-from mne.fiff.pick import pick_types_evoked, pick_types_forward, pick_channels_cov
-from mne.forward import apply_forward
-from mne.label import read_label
-from mne.datasets import sample
-from mne.minimum_norm.inverse import _make_stc
-from mne.viz import plot_evoked, plot_sparse_source_estimates
-from mne.time_frequency import ar_raw
+import copy
+
+from ..fiff.pick import pick_channels_cov
+from ..minimum_norm.inverse import _make_stc
+from ..utils import check_random_state
 
 
 def gaboratomr(timesamples, sigma, mu, k, phase):
@@ -67,69 +64,41 @@ def source_signal(mus, sigmas, amps, freqs, phis, timesamples):
     signal : array
         simulated source signal
     """
-    signal = np.zeros(len(timesamples))
-    for m, s, a, f, p in zip(mus, sigmas, amps, freqs, phis):
-        signal += gaboratomr(timesamples, s, m, f, p) * a
-    return signal
-
-
-def generate_fir_from_raw(raw, picks, order, tmin, tmax, proj=None):
-    """Fits an AR model to raw data and creates FIR filter
-
-    Parameters
-    ----------
-    raw : Raw object
-        an instance of Raw
-    picks : array of int
-        indices of selected channels
-    order : int
-        order of the FIR filter
-    tmin : float
-        start time before event
-    tmax : float
-        end time after event
-    projs : None | list
-        The list of projection vectors
-
-    Returns
-    -------
-    FIR : array
-        filter coefficients
-    """
-    if proj is not None:
-        raw.info['projs'] += proj
-    picks = picks[:5]
-    coefs = ar_raw(raw, order=order, picks=picks, tmin=tmin, tmax=tmax)
-    mean_coefs = np.mean(coefs, axis=0)  # mean model accross channels
-    FIR = np.r_[1, -mean_coefs]  # filter coefficient
-    return FIR
+    data = np.zeros((len(mus), len(timesamples)))
+    for k in range(len(mus)):
+        for m, s, a, f, p in zip(mus[k], sigmas[k], amps[k], freqs[k], phis[k]):
+            data[k] += gaboratomr(timesamples, s, m, f, p) * a
+    return data
 
 
-def generate_noise(noise, noise_cov, nsamp, FIR=None):
-    """Creates noise as a multivariate random process
-    with specified cov matrix. No deepcopy of noise applied
+def generate_noise_evoked(evoked, noise_cov, n_samples, fir_filter=None, random_state=None):
+    """Create multivariate Gaussian noise with a given covariance matrix
 
     Parameters
     ----------
-    noise : evoked object
-        an instance of evoked
+    evoked : evoked object
+        an instance of evoked used as template (it is not modified)
     noise_cov : cov object
         an instance of cov
-    nsamp : int
-        number of samples to generate
-    FIR : None | array
+    n_samples : int
+        number of time samples to generate
+    fir_filter : None | array
         FIR filter coefficients
+    random_state : None | int | np.random.RandomState
+        To specify the random generator state.
 
     Returns
     -------
     noise : evoked object
         an instance of evoked
     """
-    noise_cov = pick_channels_cov(noise_cov, include=noise_template.info['ch_names'])
-    rng = np.random.RandomState(0)
-    noise.data = rng.multivariate_normal(np.zeros(noise.info['nchan']), noise_cov.data, nsamp).T
-    if FIR is not None:
-        noise.data = signal.lfilter([1], FIR, noise.data, axis=-1)
+    noise = copy.deepcopy(evoked)
+    noise_cov = pick_channels_cov(noise_cov, include=noise.info['ch_names'])
+    rng = check_random_state(random_state)
+    zero_mean = np.zeros(noise.info['nchan'])
+    noise.data = rng.multivariate_normal(zero_mean, noise_cov.data, n_samples).T
+    if fir_filter is not None:  # temporally color the noise
+        noise.data = signal.lfilter([1], fir_filter, noise.data, axis=-1)
     return noise
 
 
@@ -165,26 +134,28 @@ def add_noise(evoked, noise, SNR, timesamples, tmin=None, tmax=None, dB=False):
         tmax = np.max(timesamples)
     tmask = (timesamples >= tmin) & (timesamples <= tmax)
     if dB:
-        SNRtemp = 20 * np.log10(np.sqrt(np.mean((evoked.data[:,tmask] ** 2).ravel()) / \
+        SNRtemp = 20 * np.log10(np.sqrt(np.mean((evoked.data[:, tmask] ** 2).ravel()) / \
                                          np.mean((noise.data ** 2).ravel())))
         noise.data = 10 ** ((SNRtemp - float(SNR)) / 20) * noise.data
     else:
-        SNRtemp = np.sqrt(np.mean((evoked.data[:,tmask] ** 2).ravel()) / \
+        SNRtemp = np.sqrt(np.mean((evoked.data[:, tmask] ** 2).ravel()) / \
                                          np.mean((noise.data ** 2).ravel()))
         noise.data = SNRtemp / SNR * noise.data
     evoked.data += noise.data
     return evoked
 
 
-def select_source_idxs(fwd, label_fname):
+def select_source_in_label(fwd, label, random_state=None):
     """Select source positions using a label
 
     Parameters
     ----------
     fwd : dict
         a forward solution
-    label_fname : str
-        filename of the freesurfer label to read
+    label : dict
+        the label (read with mne.read_label)
+    random_state : None | int | np.random.RandomState
+        To specify the random generator state.
 
     Returns
     -------
@@ -196,10 +167,9 @@ def select_source_idxs(fwd, label_fname):
     lh_vertno = list()
     rh_vertno = list()
 
-    label = read_label(label_fname)
-    rng = np.random.RandomState(0)
+    rng = check_random_state(random_state)
 
-    if label['hemi']=='lh':
+    if label['hemi'] == 'lh':
         src_sel_lh = np.intersect1d(fwd['src'][0]['vertno'], label['vertices'])
         idx_select = rng.randint(0, len(src_sel_lh), 1)
         lh_vertno.append(src_sel_lh[idx_select][0])
@@ -211,63 +181,49 @@ def select_source_idxs(fwd, label_fname):
     return lh_vertno, rh_vertno
 
 
-## load data_sets from mne-sample-data ##
-data_path = sample.data_path('.')
-
-fwd_fname = data_path + '/MEG/sample/sample_audvis-meg-eeg-oct-6-fwd.fif'
-fwd = mne.read_forward_solution(fwd_fname, force_fixed=True, surf_ori=True)
-exclude = ['MEG 2443', 'EEG 053']
-meg_include = True
-eeg_include = True
-fwd = pick_types_forward(fwd, meg=meg_include, eeg=eeg_include, exclude=exclude)
-
-cov_fname = data_path + '/MEG/sample/sample_audvis-cov.fif'
-noise_cov = mne.read_cov(cov_fname)
-
-tmin = -0.1
-#sfreq
-tstep = 0.001
-n_samples = 300
-timesamples = np.linspace(tmin, tmin + n_samples * tstep, n_samples)
-
-label = ['Aud-lh', 'Aud-rh']
-amps = [[40 * 1e-9, 40 * 1e-9, 30 * 1e-9], [30 * 1e-9, 40 * 1e-9, 40 * 1e-9]]
-mus = [[0.030, 0.060, 0.120], [0.040, 0.060, 0.140]]
-sigmas = [[0.01, 0.02, 0.03], [0.01, 0.02, 0.03]]
-freqs = [[0, 0, 0], [0, 0, 0]]
-phis = [[0, 0, 0], [0, 0, 0]]
-
-SNR = 6
-dB = True
-
-signals = list()
-vertno = [[], []]
-for k in range(len(label)):
-    label_fname = data_path + '/MEG/sample/labels/%s.label' % label[k]
-    lh_vertno, rh_vertno = select_source_idxs(fwd, label_fname)
-    vertno[0] += lh_vertno
-    vertno[1] += rh_vertno
-    signals.append(source_signal(mus[k], sigmas[k], amps[k], freqs[k], phis[k], timesamples))
-signals = np.vstack(signals)
-stc = _make_stc(signals, tmin, tstep, vertno)
-plot_sparse_source_estimates(fwd['src'], stc, bgcolor=(1, 1, 1),
-                                opacity=0.5, high_resolution=True)
-
-ave_fname = data_path + '/MEG/sample/sample_audvis-no-filter-ave.fif'
-evoked_template = mne.fiff.read_evoked(ave_fname, setno=0, baseline=None)
-evoked_template = pick_types_evoked(evoked_template, meg=meg_include, eeg=eeg_include, exclude=exclude)
-evoked = apply_forward(fwd, stc, evoked_template, start=None, stop=None)
-
-noise_template = copy.deepcopy(evoked_template)
-raw = mne.fiff.Raw(data_path + '/MEG/sample/sample_audvis_raw.fif')
-proj = mne.read_proj(data_path + '/MEG/sample/ecg_proj.fif')
-raw.info['projs'] += proj
-raw.info['bads'] = ['MEG 2443', 'EEG 053']  # mark bad channels
-picks = mne.fiff.pick_types(raw.info, meg=True)
-FIR = generate_fir_from_raw(raw, picks, 5, tmin=60, tmax=180, proj=proj)
-noise = generate_noise(noise_template, noise_cov, n_samples, FIR=FIR)
-pl.figure()
-pl.psd(noise.data[0])
-evoked = add_noise(evoked, noise, SNR, timesamples, tmin=0.0, tmax=0.2, dB=dB)
-pl.figure()
-plot_evoked(evoked)
+def generate_stc(fwd, labels, stc_data, tmin, tstep, random_state=0):
+    """Generate a sparse source estimate with one source per label
+
+    Parameters
+    ----------
+    fwd : dict
+        a forward solution
+    labels : list of dict
+        the labels (read with mne.read_label). A source is drawn at
+        random in each label with select_source_in_label.
+    stc_data : array, shape (n_labels, n_times)
+        the source time courses. Row k is used for the source drawn in
+        labels[k], whatever the hemisphere order of the labels.
+    tmin : float
+        the time of the first sample in seconds
+    tstep : float
+        the time step between two samples in seconds
+    random_state : None | int | np.random.RandomState
+        To specify the random generator state.
+
+    Returns
+    -------
+    stc : instance of SourceEstimate
+        the simulated source estimate
+    """
+    stc_data = np.asarray(stc_data)
+    if len(stc_data) != len(labels):
+        raise ValueError('stc_data must have one row per label (got %d rows '
+                         'for %d labels)' % (len(stc_data), len(labels)))
+    rng = check_random_state(random_state)
+    # (vertex, row of stc_data) pairs, one list per hemisphere
+    sources = [[], []]
+    for row, label in enumerate(labels):
+        lh_vertno, rh_vertno = select_source_in_label(fwd, label, rng)
+        sources[0] += [(vert, row) for vert in lh_vertno]
+        sources[1] += [(vert, row) for vert in rh_vertno]
+    # The data rows of a source estimate follow its vertices: left
+    # hemisphere first, then right hemisphere. Sort the vertices within
+    # each hemisphere and reorder the time courses accordingly, so each
+    # source keeps the time course of its own label.
+    sources = [sorted(hemi_sources) for hemi_sources in sources]
+    vertno = [np.array([vert for vert, _ in hemi_sources], dtype=int)
+              for hemi_sources in sources]
+    rows = [row for hemi_sources in sources for _, row in hemi_sources]
+    stc = _make_stc(stc_data[rows], tmin, tstep, vertno)
+    return stc
diff --git a/mne/time_frequency/__init__.py b/mne/time_frequency/__init__.py
index 4826123..be88845 100644
--- a/mne/time_frequency/__init__.py
+++ b/mne/time_frequency/__init__.py
@@ -3,4 +3,4 @@
 
 from .tfr import induced_power, single_trial_power
 from .psd import compute_raw_psd
-from .ar import yule_walker, ar_raw
+from .ar import yule_walker, ar_raw, fir_filter_raw
diff --git a/mne/time_frequency/ar.py b/mne/time_frequency/ar.py
index 3daafc1..17e7a5a 100644
--- a/mne/time_frequency/ar.py
+++ b/mne/time_frequency/ar.py
@@ -109,3 +109,33 @@ def ar_raw(raw, order, picks, tmin=None, tmax=None):
         this_coefs, _ = yule_walker(d, order=order)
         coefs[k, :] = this_coefs
     return coefs
+
+
+def fir_filter_raw(raw, order, picks, tmin=None, tmax=None):
+    """Fits an AR model to raw data and creates corresponding FIR filter
+
+    The AR coefficients are averaged over the first 5 picked channels.
+
+    Parameters
+    ----------
+    raw : Raw object
+        an instance of Raw
+    order : int
+        order of the AR model
+    picks : array of int
+        indices of selected channels (only the first 5 are used)
+    tmin : float | None
+        The beginning of time interval in seconds.
+    tmax : float | None
+        The end of time interval in seconds.
+
+    Returns
+    -------
+    fir : array
+        filter coefficients: 1 followed by the negated mean AR coefficients
+    """
+    picks = picks[:5]  # keep only the first 5 channels
+    coefs = ar_raw(raw, order=order, picks=picks, tmin=tmin, tmax=tmax)
+    mean_coefs = np.mean(coefs, axis=0)  # mean model across channels
+    fir = np.r_[1, -mean_coefs]  # filter coefficients
+    return fir
diff --git a/mne/utils.py b/mne/utils.py
index 20f3832..187de8e 100644
--- a/mne/utils.py
+++ b/mne/utils.py
@@ -247,3 +247,21 @@ try:
     from scipy.signal import firwin2
 except ImportError:
     firwin2 = _firwin2
+
+
+def check_random_state(seed):
+    """Turn seed into a np.random.RandomState instance
+
+    If seed is None or np.random, return the RandomState singleton used by
+    np.random. If seed is an int, return a new RandomState seeded with it.
+    If seed is already a RandomState instance, return it.
+    Otherwise raise ValueError.
+    """
+    if seed is None or seed is np.random:
+        return np.random.mtrand._rand
+    if isinstance(seed, (int, np.integer)):
+        return np.random.RandomState(seed)
+    if isinstance(seed, np.random.RandomState):
+        return seed
+    raise ValueError('%r cannot be used to seed a numpy.random.RandomState'
+                     ' instance' % (seed,))
diff --git a/mne/viz.py b/mne/viz.py
index 2933a87..5273620 100644
--- a/mne/viz.py
+++ b/mne/viz.py
@@ -5,6 +5,7 @@
 #
 # License: Simplified BSD
 
+from itertools import cycle
 import copy
 import numpy as np
 from scipy import linalg
@@ -85,6 +86,183 @@ def plot_evoked(evoked, picks=None, unit=True, show=True):
         pl.show()
 
 
+COLORS = ['b', 'g', 'r', 'c', 'm', 'y', 'k', '#473C8B', '#458B74',
+          '#CD7F32', '#FF4040', '#ADFF2F', '#8E2323', '#FF1493']
+
+
+def plot_sparse_source_estimates(src, stcs, colors=None, linewidth=2,
+                                fontsize=18, bgcolor=(.05, 0, .1), opacity=0.2,
+                                brain_color=(0.7, ) * 3, show=True,
+                                high_resolution=False, fig_name=None,
+                                fig_number=None, labels=None,
+                                modes=('cone', 'sphere'),
+                                scale_factors=(1, 0.6),
+                                **kwargs):
+    """Plot source estimates obtained with sparse solver
+
+    Active dipoles are represented in a "Glass" brain.
+    If the same source is active in multiple source estimates it is
+    displayed with modes[1] (a sphere by default), otherwise with
+    modes[0] (a cone by default) in 3D.
+
+    Parameters
+    ----------
+    src : list of dict
+        The source space (left hemisphere first, then right hemisphere)
+    stcs : instance of SourceEstimate or list of instances of SourceEstimate
+        The source estimates (up to 3, one line style each)
+    colors : list | None
+        List of colors. If None, COLORS is used.
+    linewidth : int
+        Line width in 2D plot
+    fontsize : int
+        Font size of the axis labels in 2D plot
+    bgcolor : tuple of length 3
+        Background color in 3D
+    opacity : float in [0, 1]
+        Opacity of brain mesh
+    brain_color : tuple of length 3
+        Brain color
+    show : bool
+        Show figures if True
+    high_resolution : bool
+        If True, the brain mesh uses all the triangles of the source
+        space ('tris') instead of the decimated ones ('use_tris').
+    fig_name : str | None
+        Mayavi figure name, also used as title of the 2D plot
+    fig_number : int | None
+        Pylab figure number
+    labels : ndarray or list of ndarrays
+        Labels to show sources in clusters. Sources with the same
+        label and the waveforms within each cluster are presented in
+        the same color. labels should be a list of ndarrays when stcs
+        is a list, labels[k] holding one value per source of stcs[k].
+    modes : sequence of 2 str
+        Mayavi glyph modes for sources active in a single stc and for
+        sources active in several stcs, respectively.
+    scale_factors : sequence of 2 float
+        Glyph scale factors, in the same order as modes.
+    kwargs : kwargs
+        kwargs passed to mlab.triangular_mesh
+
+    Returns
+    -------
+    surface : mayavi surface
+        The brain mesh returned by mlab.triangular_mesh
+    """
+    if not isinstance(stcs, list):
+        stcs = [stcs]
+    if labels is not None and not isinstance(labels, list):
+        labels = [labels]
+
+    if colors is None:
+        colors = COLORS
+
+    linestyles = ['-', '--', ':']
+
+    # Show 3D
+    lh_points = src[0]['rr']
+    rh_points = src[1]['rr']
+    points = np.r_[lh_points, rh_points]
+
+    lh_normals = src[0]['nn']
+    rh_normals = src[1]['nn']
+    normals = np.r_[lh_normals, rh_normals]
+
+    if high_resolution:
+        use_lh_faces = src[0]['tris']
+        use_rh_faces = src[1]['tris']
+    else:
+        use_lh_faces = src[0]['use_tris']
+        use_rh_faces = src[1]['use_tris']
+
+    use_faces = np.r_[use_lh_faces, lh_points.shape[0] + use_rh_faces]
+
+    # NOTE(review): empirical display scaling of the coordinates; np.r_
+    # made a copy above so the source space itself is not modified
+    points *= 170
+
+    # indices of the active sources in the concatenated lh + rh points
+    vertnos = [np.r_[stc.lh_vertno, lh_points.shape[0] + stc.rh_vertno]
+               for stc in stcs]
+    unique_vertnos = np.unique(np.concatenate(vertnos).ravel())
+
+    try:
+        from mayavi import mlab
+    except ImportError:
+        from enthought.mayavi import mlab
+
+    from matplotlib.colors import ColorConverter
+    color_converter = ColorConverter()
+
+    f = mlab.figure(figure=fig_name, bgcolor=bgcolor, size=(800, 800))
+    mlab.clf()
+    # do not render the scene while the objects are being added
+    f.scene.disable_render = True
+    surface = mlab.triangular_mesh(points[:, 0], points[:, 1], points[:, 2],
+                                   use_faces, color=brain_color,
+                                   opacity=opacity, **kwargs)
+
+    import pylab as pl
+    # Show time courses
+    pl.figure(fig_number)
+    pl.clf()
+
+    print("Total number of active sources: %d" % len(unique_vertnos))
+
+    color_cycle = cycle(colors)
+    if labels is not None:
+        # one color per distinct label value, in sorted label order
+        unique_labels = np.unique(np.concatenate(labels).ravel())
+        label_colors = [next(color_cycle) for _ in range(unique_labels.size)]
+
+    for v in unique_vertnos:
+        # get indices of stcs it belongs to
+        ind = [k for k, vertno in enumerate(vertnos) if v in vertno]
+        is_common = len(ind) > 1
+
+        if labels is None:
+            c = next(color_cycle)
+        else:
+            # if vertex is in different stcs then take label from first one
+            this_label = labels[ind[0]][vertnos[ind[0]] == v][0]
+            c = label_colors[int(np.searchsorted(unique_labels, this_label))]
+
+        mode = modes[1] if is_common else modes[0]
+        scale_factor = scale_factors[1] if is_common else scale_factors[0]
+        x, y, z = points[v]
+        nx, ny, nz = normals[v]
+        mlab.quiver3d(x, y, z, nx, ny, nz, color=color_converter.to_rgb(c),
+                      mode=mode, scale_factor=scale_factor)
+
+        for k in ind:
+            vertno = vertnos[k]
+            mask = (vertno == v)
+            assert np.sum(mask) == 1
+            linestyle = linestyles[k]
+            # use the time axis of the stc being plotted
+            pl.plot(1e3 * stcs[k].times, 1e9 * stcs[k].data[mask].ravel(),
+                    c=c, linewidth=linewidth, linestyle=linestyle)
+
+    pl.xlabel('Time (ms)', fontsize=fontsize)
+    pl.ylabel('Source amplitude (nAm)', fontsize=fontsize)
+
+    if fig_name is not None:
+        pl.title(fig_name)
+
+    # configure the mesh and re-enable rendering before showing (mlab.show
+    # may block), so that the displayed scene is complete
+    surface.actor.property.backface_culling = True
+    surface.actor.property.shading = True
+    f.scene.disable_render = False
+
+    if show:
+        pl.show()
+        mlab.show()
+
+    return surface
+
+
 def plot_cov(cov, info, exclude=[], colorbar=True, proj=False, show_svd=True,
              show=True):
     """Plot Covariance data

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git



More information about the debian-med-commit mailing list