[med-svn] [python-mne] 104/376: ENH : refactoring time frequency for speed up in parallel settings

Yaroslav Halchenko debian at onerussian.com
Fri Nov 27 17:22:17 UTC 2015


This is an automated email from the git hooks/post-receive script.

yoh pushed a commit to annotated tag v0.1
in repository python-mne.

commit 47e17da9884d03ed1a66ed4c2e262f010a34280d
Author: Alexandre Gramfort <alexandre.gramfort at inria.fr>
Date:   Tue Mar 1 16:03:24 2011 -0500

    ENH : refactoring time frequency for speed up in parallel settings
---
 examples/time_frequency/plot_time_frequency.py |   2 +-
 mne/tests/test_tfr.py                          |  12 +-
 mne/tfr.py                                     | 152 ++++++++++++++-----------
 3 files changed, 93 insertions(+), 73 deletions(-)

diff --git a/examples/time_frequency/plot_time_frequency.py b/examples/time_frequency/plot_time_frequency.py
index 8498588..d113dc3 100644
--- a/examples/time_frequency/plot_time_frequency.py
+++ b/examples/time_frequency/plot_time_frequency.py
@@ -50,7 +50,7 @@ evoked_data = np.mean(epochs, axis=0) # compute evoked fields
 frequencies = np.arange(4, 30, 3) # define frequencies of interest
 Fs = raw['info']['sfreq'] # sampling in Hz
 power, phase_lock = time_frequency(epochs, Fs=Fs, frequencies=frequencies,
-                                   n_cycles=2)
+                                   n_cycles=2, n_jobs=1, use_fft=False)
 
 ###############################################################################
 # View time-frequency plots
diff --git a/mne/tests/test_tfr.py b/mne/tests/test_tfr.py
index e3f59fe..9d923ac 100644
--- a/mne/tests/test_tfr.py
+++ b/mne/tests/test_tfr.py
@@ -1,11 +1,10 @@
 import numpy as np
 import os.path as op
 
-from numpy.testing import assert_allclose
-
 import mne
 from mne import fiff
 from mne import time_frequency
+from mne.tfr import cwt_morlet
 
 raw_fname = op.join(op.dirname(__file__), '..', 'fiff', 'tests', 'data',
                 'test_raw.fif')
@@ -13,7 +12,7 @@ event_fname = op.join(op.dirname(__file__), '..', 'fiff', 'tests', 'data',
                 'test-eve.fif')
 
 def test_time_frequency():
-    """Test IO for STC files
+    """Test time frequency transform (PSD and phase lock)
     """
     # Set parameters
     event_id = 1
@@ -35,9 +34,8 @@ def test_time_frequency():
     data, times, channel_names = mne.read_epochs(raw, events, event_id,
                                     tmin, tmax, picks=picks, baseline=(None, 0))
     epochs = np.array([d['epoch'] for d in data]) # as 3D matrix
-    evoked_data = np.mean(epochs, axis=0) # compute evoked fields
 
-    frequencies = np.arange(4, 20, 5) # define frequencies of interest
+    frequencies = np.arange(6, 20, 5) # define frequencies of interest
     Fs = raw['info']['sfreq'] # sampling in Hz
     power, phase_lock = time_frequency(epochs, Fs=Fs, frequencies=frequencies,
                                        n_cycles=2, use_fft=True)
@@ -54,4 +52,6 @@ def test_time_frequency():
     assert power.shape == phase_lock.shape
     assert np.sum(phase_lock >= 1) == 0
     assert np.sum(phase_lock <= 0) == 0
-    
\ No newline at end of file
+
+    tfr = cwt_morlet(epochs[0], Fs, frequencies, use_fft=True, n_cycles=2)
+    assert tfr.shape == (len(picks), len(frequencies), len(times))
diff --git a/mne/tfr.py b/mne/tfr.py
index 05645a9..6a61cb4 100644
--- a/mne/tfr.py
+++ b/mne/tfr.py
@@ -69,71 +69,74 @@ def _centered(arr, newsize):
     return arr[tuple(myslice)]
 
 
-def _cwt_morlet_fft(x, Fs, freqs, mode="same", Ws=None):
+def _cwt_fft(X, Ws, mode="same"):
     """Compute cwt with fft based convolutions
+    Return a generator over signals.
     """
-    x = np.asarray(x)
-    freqs = np.asarray(freqs)
+    X = np.asarray(X)
 
     # Precompute wavelets for given frequency range to save time
-    n_samples = x.size
-    n_freqs = freqs.size
-
-    if Ws is None:
-        Ws = morlet(Fs, freqs)
+    n_signals, n_times = X.shape
+    n_freqs = len(Ws)
 
     Ws_max_size = max(W.size for W in Ws)
-    size = n_samples + Ws_max_size - 1
+    size = n_times + Ws_max_size - 1
     # Always use 2**n-sized FFT
     fsize = 2**np.ceil(np.log2(size))
-    fft_x = fftn(x, [fsize])
-
-    if mode == "full":
-        tfr = np.zeros((n_freqs, fsize), dtype=np.complex128)
-    elif mode == "same" or mode == "valid":
-        tfr = np.zeros((n_freqs, n_samples), dtype=np.complex128)
 
+    # precompute FFTs of Ws
+    fft_Ws = np.empty((n_freqs, fsize), dtype=np.complex128)
     for i, W in enumerate(Ws):
-        ret = ifftn(fft_x * fftn(W, [fsize]))[:n_samples + W.size - 1]
-        if mode == "valid":
-            sz = abs(W.size - n_samples) + 1
-            offset = (n_samples - sz) / 2
-            tfr[i, offset:(offset + sz)] = _centered(ret, sz)
-        else:
-            tfr[i] = _centered(ret, n_samples)
-    return tfr
-
-
-def _cwt_morlet_convolve(x, Fs, freqs, mode='same', Ws=None):
+        fft_Ws[i] = fftn(W, [fsize])
+
+    for k, x in enumerate(X):
+        if mode == "full":
+            tfr = np.zeros((n_freqs, fsize), dtype=np.complex128)
+        elif mode == "same" or mode == "valid":
+            tfr = np.zeros((n_freqs, n_times), dtype=np.complex128)
+
+        fft_x = fftn(x, [fsize])
+        for i, W in enumerate(Ws):
+            ret = ifftn(fft_x * fft_Ws[i])[:n_times + W.size - 1]
+            if mode == "valid":
+                sz = abs(W.size - n_times) + 1
+                offset = (n_times - sz) / 2
+                tfr[i, offset:(offset + sz)] = _centered(ret, sz)
+            else:
+                tfr[i, :] = _centered(ret, n_times)
+        yield tfr
+
+
+def _cwt_convolve(X, Ws, mode='same'):
     """Compute time freq decomposition with temporal convolutions
+    Return a generator over signals.
     """
-    x = np.asarray(x)
-    freqs = np.asarray(freqs)
+    X = np.asarray(X)
 
-    if Ws is None:
-        Ws = morlet(Fs, freqs)
+    n_signals, n_times = X.shape
+    n_freqs = len(Ws)
 
-    n_samples = x.size
     # Compute convolutions
-    tfr = np.zeros((freqs.size, len(x)), dtype=np.complex128)
-    for i, W in enumerate(Ws):
-        ret = np.convolve(x, W, mode=mode)
-        if mode == "valid":
-            sz = abs(W.size - n_samples) + 1
-            offset = (n_samples - sz) / 2
-            tfr[i, offset:(offset + sz)] = ret
-        else:
-            tfr[i] = ret
-    return tfr
-
-
-def cwt_morlet(x, Fs, freqs, use_fft=True, n_cycles=7.0):
+    for x in X:
+        tfr = np.zeros((n_freqs, n_times), dtype=np.complex128)
+        for i, W in enumerate(Ws):
+            ret = np.convolve(x, W, mode=mode)
+            if mode == "valid":
+                sz = abs(W.size - n_times) + 1
+                offset = (n_times - sz) / 2
+                tfr[i, offset:(offset + sz)] = ret
+            else:
+                tfr[i] = ret
+        yield tfr
+
+
+def cwt_morlet(X, Fs, freqs, use_fft=True, n_cycles=7.0):
     """Compute time freq decomposition with Morlet wavelets
 
     Parameters
     ----------
-    x : array
-        signal
+    X : array of shape [n_signals, n_times]
+        signals (one per line)
 
     Fs : float
         sampling Frequency
@@ -143,35 +146,48 @@ def cwt_morlet(x, Fs, freqs, use_fft=True, n_cycles=7.0):
 
     Returns
     -------
-    tfr : 2D array
-        Time Frequency Decomposition (Frequencies x Timepoints)
+    tfr : 3D array
+        Time Frequency Decompositions (n_signals x n_frequencies x n_times)
     """
     mode = 'same'
     # mode = "valid"
+    n_signals, n_times = X.shape
+    n_frequencies = len(freqs)
 
     # Precompute wavelets for given frequency range to save time
     Ws = morlet(Fs, freqs, n_cycles=n_cycles)
 
     if use_fft:
-        return _cwt_morlet_fft(x, Fs, freqs, mode, Ws)
+        coefs = _cwt_fft(X, Ws, mode)
     else:
-        return _cwt_morlet_convolve(x, Fs, freqs, mode, Ws)
+        coefs = _cwt_convolve(X, Ws, mode)
 
+    tfrs = np.empty((n_signals, n_frequencies, n_times))
+    for k, tfr in enumerate(coefs):
+        tfrs[k] = tfr
 
-def _time_frequency_one_channel(epochs, c, Fs, frequencies, use_fft, n_cycles):
-    """Aux of time_frequency for parallel computing"""
-    n_epochs, _, n_times = epochs.shape
-    n_frequencies = len(frequencies)
-    psd_c = np.zeros((n_frequencies, n_times)) # PSD
-    plf_c = np.zeros((n_frequencies, n_times), dtype=np.complex) # phase lock
+    return tfrs
 
-    for e in range(n_epochs):
-        tfr = cwt_morlet(epochs[e, c, :].ravel(), Fs, frequencies,
-                                  use_fft=use_fft, n_cycles=n_cycles)
+def _time_frequency(X, Ws, use_fft):
+    """Aux of time_frequency for parallel computing over channels
+    """
+    n_epochs, n_times = X.shape
+    n_frequencies = len(Ws)
+    psd = np.zeros((n_frequencies, n_times)) # PSD
+    plf = np.zeros((n_frequencies, n_times), dtype=np.complex) # phase lock
+
+    mode = 'same'
+    if use_fft:
+        tfrs = _cwt_fft(X, Ws, mode)
+    else:
+        tfrs = _cwt_convolve(X, Ws, mode)
+
+    for tfr in tfrs:
         tfr_abs = np.abs(tfr)
-        psd_c += tfr_abs**2
-        plf_c += tfr / tfr_abs
-    return psd_c, plf_c
+        psd += tfr_abs**2
+        plf += tfr / tfr_abs
+
+    return psd, plf
 
 
 def time_frequency(epochs, Fs, frequencies, use_fft=True, n_cycles=25,
@@ -213,6 +229,9 @@ def time_frequency(epochs, Fs, frequencies, use_fft=True, n_cycles=25,
     n_frequencies = len(frequencies)
     n_epochs, n_channels, n_times = epochs.shape
 
+    # Precompute wavelets for given frequency range to save time
+    Ws = morlet(Fs, frequencies, n_cycles=n_cycles)
+
     try:
         import joblib
     except ImportError:
@@ -224,13 +243,14 @@ def time_frequency(epochs, Fs, frequencies, use_fft=True, n_cycles=25,
         plf = np.empty((n_channels, n_frequencies, n_times), dtype=np.complex)
 
         for c in range(n_channels):
-            psd[c,:,:], plf[c,:,:] = _time_frequency_one_channel(epochs, c, Fs,
-                                                frequencies, use_fft, n_cycles)
+            X = np.squeeze(epochs[:,c,:])
+            psd[c], plf[c] = _time_frequency(X, Ws, use_fft)
+
     else:
         from joblib import Parallel, delayed
         psd_plf = Parallel(n_jobs=n_jobs)(
-                    delayed(_time_frequency_one_channel)(
-                            epochs, c, Fs, frequencies, use_fft, n_cycles)
+                    delayed(_time_frequency)(
+                            np.squeeze(epochs[:,c,:]), Ws, use_fft)
                     for c in range(n_channels))
 
         psd = np.zeros((n_channels, n_frequencies, n_times))

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git



More information about the debian-med-commit mailing list