[med-svn] [python-mne] 104/376: ENH : refactoring time frequency for speed up in parallel settings
Yaroslav Halchenko
debian at onerussian.com
Fri Nov 27 17:22:17 UTC 2015
This is an automated email from the git hooks/post-receive script.
yoh pushed a commit to annotated tag v0.1
in repository python-mne.
commit 47e17da9884d03ed1a66ed4c2e262f010a34280d
Author: Alexandre Gramfort <alexandre.gramfort at inria.fr>
Date: Tue Mar 1 16:03:24 2011 -0500
ENH : refactoring time frequency for speed up in parallel settings
---
examples/time_frequency/plot_time_frequency.py | 2 +-
mne/tests/test_tfr.py | 12 +-
mne/tfr.py | 152 ++++++++++++++-----------
3 files changed, 93 insertions(+), 73 deletions(-)
diff --git a/examples/time_frequency/plot_time_frequency.py b/examples/time_frequency/plot_time_frequency.py
index 8498588..d113dc3 100644
--- a/examples/time_frequency/plot_time_frequency.py
+++ b/examples/time_frequency/plot_time_frequency.py
@@ -50,7 +50,7 @@ evoked_data = np.mean(epochs, axis=0) # compute evoked fields
frequencies = np.arange(4, 30, 3) # define frequencies of interest
Fs = raw['info']['sfreq'] # sampling in Hz
power, phase_lock = time_frequency(epochs, Fs=Fs, frequencies=frequencies,
- n_cycles=2)
+ n_cycles=2, n_jobs=1, use_fft=False)
###############################################################################
# View time-frequency plots
diff --git a/mne/tests/test_tfr.py b/mne/tests/test_tfr.py
index e3f59fe..9d923ac 100644
--- a/mne/tests/test_tfr.py
+++ b/mne/tests/test_tfr.py
@@ -1,11 +1,10 @@
import numpy as np
import os.path as op
-from numpy.testing import assert_allclose
-
import mne
from mne import fiff
from mne import time_frequency
+from mne.tfr import cwt_morlet
raw_fname = op.join(op.dirname(__file__), '..', 'fiff', 'tests', 'data',
'test_raw.fif')
@@ -13,7 +12,7 @@ event_fname = op.join(op.dirname(__file__), '..', 'fiff', 'tests', 'data',
'test-eve.fif')
def test_time_frequency():
- """Test IO for STC files
+ """Test time frequency transform (PSD and phase lock)
"""
# Set parameters
event_id = 1
@@ -35,9 +34,8 @@ def test_time_frequency():
data, times, channel_names = mne.read_epochs(raw, events, event_id,
tmin, tmax, picks=picks, baseline=(None, 0))
epochs = np.array([d['epoch'] for d in data]) # as 3D matrix
- evoked_data = np.mean(epochs, axis=0) # compute evoked fields
- frequencies = np.arange(4, 20, 5) # define frequencies of interest
+ frequencies = np.arange(6, 20, 5) # define frequencies of interest
Fs = raw['info']['sfreq'] # sampling in Hz
power, phase_lock = time_frequency(epochs, Fs=Fs, frequencies=frequencies,
n_cycles=2, use_fft=True)
@@ -54,4 +52,6 @@ def test_time_frequency():
assert power.shape == phase_lock.shape
assert np.sum(phase_lock >= 1) == 0
assert np.sum(phase_lock <= 0) == 0
-
\ No newline at end of file
+
+ tfr = cwt_morlet(epochs[0], Fs, frequencies, use_fft=True, n_cycles=2)
+ assert tfr.shape == (len(picks), len(frequencies), len(times))
diff --git a/mne/tfr.py b/mne/tfr.py
index 05645a9..6a61cb4 100644
--- a/mne/tfr.py
+++ b/mne/tfr.py
@@ -69,71 +69,74 @@ def _centered(arr, newsize):
return arr[tuple(myslice)]
-def _cwt_morlet_fft(x, Fs, freqs, mode="same", Ws=None):
+def _cwt_fft(X, Ws, mode="same"):
"""Compute cwt with fft based convolutions
+ Return a generator over signals.
"""
- x = np.asarray(x)
- freqs = np.asarray(freqs)
+ X = np.asarray(X)
# Precompute wavelets for given frequency range to save time
- n_samples = x.size
- n_freqs = freqs.size
-
- if Ws is None:
- Ws = morlet(Fs, freqs)
+ n_signals, n_times = X.shape
+ n_freqs = len(Ws)
Ws_max_size = max(W.size for W in Ws)
- size = n_samples + Ws_max_size - 1
+ size = n_times + Ws_max_size - 1
# Always use 2**n-sized FFT
fsize = 2**np.ceil(np.log2(size))
- fft_x = fftn(x, [fsize])
-
- if mode == "full":
- tfr = np.zeros((n_freqs, fsize), dtype=np.complex128)
- elif mode == "same" or mode == "valid":
- tfr = np.zeros((n_freqs, n_samples), dtype=np.complex128)
+ # precompute FFTs of Ws
+ fft_Ws = np.empty((n_freqs, fsize), dtype=np.complex128)
for i, W in enumerate(Ws):
- ret = ifftn(fft_x * fftn(W, [fsize]))[:n_samples + W.size - 1]
- if mode == "valid":
- sz = abs(W.size - n_samples) + 1
- offset = (n_samples - sz) / 2
- tfr[i, offset:(offset + sz)] = _centered(ret, sz)
- else:
- tfr[i] = _centered(ret, n_samples)
- return tfr
-
-
-def _cwt_morlet_convolve(x, Fs, freqs, mode='same', Ws=None):
+ fft_Ws[i] = fftn(W, [fsize])
+
+ for k, x in enumerate(X):
+ if mode == "full":
+ tfr = np.zeros((n_freqs, fsize), dtype=np.complex128)
+ elif mode == "same" or mode == "valid":
+ tfr = np.zeros((n_freqs, n_times), dtype=np.complex128)
+
+ fft_x = fftn(x, [fsize])
+ for i, W in enumerate(Ws):
+ ret = ifftn(fft_x * fft_Ws[i])[:n_times + W.size - 1]
+ if mode == "valid":
+ sz = abs(W.size - n_times) + 1
+ offset = (n_times - sz) / 2
+ tfr[i, offset:(offset + sz)] = _centered(ret, sz)
+ else:
+ tfr[i, :] = _centered(ret, n_times)
+ yield tfr
+
+
+def _cwt_convolve(X, Ws, mode='same'):
"""Compute time freq decomposition with temporal convolutions
+ Return a generator over signals.
"""
- x = np.asarray(x)
- freqs = np.asarray(freqs)
+ X = np.asarray(X)
- if Ws is None:
- Ws = morlet(Fs, freqs)
+ n_signals, n_times = X.shape
+ n_freqs = len(Ws)
- n_samples = x.size
# Compute convolutions
- tfr = np.zeros((freqs.size, len(x)), dtype=np.complex128)
- for i, W in enumerate(Ws):
- ret = np.convolve(x, W, mode=mode)
- if mode == "valid":
- sz = abs(W.size - n_samples) + 1
- offset = (n_samples - sz) / 2
- tfr[i, offset:(offset + sz)] = ret
- else:
- tfr[i] = ret
- return tfr
-
-
-def cwt_morlet(x, Fs, freqs, use_fft=True, n_cycles=7.0):
+ for x in X:
+ tfr = np.zeros((n_freqs, n_times), dtype=np.complex128)
+ for i, W in enumerate(Ws):
+ ret = np.convolve(x, W, mode=mode)
+ if mode == "valid":
+ sz = abs(W.size - n_times) + 1
+ offset = (n_times - sz) / 2
+ tfr[i, offset:(offset + sz)] = ret
+ else:
+ tfr[i] = ret
+ yield tfr
+
+
+def cwt_morlet(X, Fs, freqs, use_fft=True, n_cycles=7.0):
"""Compute time freq decomposition with Morlet wavelets
Parameters
----------
- x : array
- signal
+ X : array of shape [n_signals, n_times]
+ signals (one per line)
Fs : float
sampling Frequency
@@ -143,35 +146,48 @@ def cwt_morlet(x, Fs, freqs, use_fft=True, n_cycles=7.0):
Returns
-------
- tfr : 2D array
- Time Frequency Decomposition (Frequencies x Timepoints)
+ tfr : 3D array
+ Time Frequency Decompositions (n_signals x n_frequencies x n_times)
"""
mode = 'same'
# mode = "valid"
+ n_signals, n_times = X.shape
+ n_frequencies = len(freqs)
# Precompute wavelets for given frequency range to save time
Ws = morlet(Fs, freqs, n_cycles=n_cycles)
if use_fft:
- return _cwt_morlet_fft(x, Fs, freqs, mode, Ws)
+ coefs = _cwt_fft(X, Ws, mode)
else:
- return _cwt_morlet_convolve(x, Fs, freqs, mode, Ws)
+ coefs = _cwt_convolve(X, Ws, mode)
+ tfrs = np.empty((n_signals, n_frequencies, n_times))
+ for k, tfr in enumerate(coefs):
+ tfrs[k] = tfr
-def _time_frequency_one_channel(epochs, c, Fs, frequencies, use_fft, n_cycles):
- """Aux of time_frequency for parallel computing"""
- n_epochs, _, n_times = epochs.shape
- n_frequencies = len(frequencies)
- psd_c = np.zeros((n_frequencies, n_times)) # PSD
- plf_c = np.zeros((n_frequencies, n_times), dtype=np.complex) # phase lock
+ return tfrs
- for e in range(n_epochs):
- tfr = cwt_morlet(epochs[e, c, :].ravel(), Fs, frequencies,
- use_fft=use_fft, n_cycles=n_cycles)
+def _time_frequency(X, Ws, use_fft):
+ """Aux of time_frequency for parallel computing over channels
+ """
+ n_epochs, n_times = X.shape
+ n_frequencies = len(Ws)
+ psd = np.zeros((n_frequencies, n_times)) # PSD
+ plf = np.zeros((n_frequencies, n_times), dtype=np.complex) # phase lock
+
+ mode = 'same'
+ if use_fft:
+ tfrs = _cwt_fft(X, Ws, mode)
+ else:
+ tfrs = _cwt_convolve(X, Ws, mode)
+
+ for tfr in tfrs:
tfr_abs = np.abs(tfr)
- psd_c += tfr_abs**2
- plf_c += tfr / tfr_abs
- return psd_c, plf_c
+ psd += tfr_abs**2
+ plf += tfr / tfr_abs
+
+ return psd, plf
def time_frequency(epochs, Fs, frequencies, use_fft=True, n_cycles=25,
@@ -213,6 +229,9 @@ def time_frequency(epochs, Fs, frequencies, use_fft=True, n_cycles=25,
n_frequencies = len(frequencies)
n_epochs, n_channels, n_times = epochs.shape
+ # Precompute wavelets for given frequency range to save time
+ Ws = morlet(Fs, frequencies, n_cycles=n_cycles)
+
try:
import joblib
except ImportError:
@@ -224,13 +243,14 @@ def time_frequency(epochs, Fs, frequencies, use_fft=True, n_cycles=25,
plf = np.empty((n_channels, n_frequencies, n_times), dtype=np.complex)
for c in range(n_channels):
- psd[c,:,:], plf[c,:,:] = _time_frequency_one_channel(epochs, c, Fs,
- frequencies, use_fft, n_cycles)
+ X = np.squeeze(epochs[:,c,:])
+ psd[c], plf[c] = _time_frequency(X, Ws, use_fft)
+
else:
from joblib import Parallel, delayed
psd_plf = Parallel(n_jobs=n_jobs)(
- delayed(_time_frequency_one_channel)(
- epochs, c, Fs, frequencies, use_fft, n_cycles)
+ delayed(_time_frequency)(
+ np.squeeze(epochs[:,c,:]), Ws, use_fft)
for c in range(n_channels))
psd = np.zeros((n_channels, n_frequencies, n_times))
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git
More information about the debian-med-commit mailing list