[med-svn] [python-mne] 273/353: ENH : big cleanup of simulation code + new plot function for sparse stc
Yaroslav Halchenko
debian at onerussian.com
Fri Nov 27 17:25:14 UTC 2015
This is an automated email from the git hooks/post-receive script.
yoh pushed a commit to tag 0.4
in repository python-mne.
commit eea3d6290276dd95eec9f3b812c260da98e4aa45
Author: Alexandre Gramfort <alexandre.gramfort at inria.fr>
Date: Mon Jul 16 15:24:40 2012 +0200
ENH : big cleanup of simulation code + new plot function for sparse stc
---
examples/plot_simulate_evoked_data.py | 85 ++++++++++++++++
mne/simulation/__init__.py | 4 +
mne/simulation/sim_evoked.py | 178 ++++++++++------------------------
mne/time_frequency/__init__.py | 2 +-
mne/time_frequency/ar.py | 30 ++++++
mne/utils.py | 18 ++++
mne/viz.py | 155 +++++++++++++++++++++++++++++
7 files changed, 342 insertions(+), 130 deletions(-)
diff --git a/examples/plot_simulate_evoked_data.py b/examples/plot_simulate_evoked_data.py
new file mode 100644
index 0000000..cf0f3ee
--- /dev/null
+++ b/examples/plot_simulate_evoked_data.py
@@ -0,0 +1,85 @@
+"""
+==============================
+Generate simulated evoked data
+==============================
+
+"""
+# Author: Daniel Strohmeier <daniel.strohmeier at tu-ilmenau.de>
+# Alexandre Gramfort <gramfort at nmr.mgh.harvard.edu>
+#
+# License: BSD (3-clause)
+
+import numpy as np
+import pylab as pl
+
+import mne
+from mne.fiff.pick import pick_types_evoked, pick_types_forward
+from mne.forward import apply_forward
+from mne.datasets import sample
+from mne.time_frequency import fir_filter_raw
+from mne.viz import plot_evoked, plot_sparse_source_estimates
+from mne.simulation.sim_evoked import source_signal, generate_stc, generate_noise_evoked, add_noise
+
+###############################################################################
+# Load real data as templates
+data_path = sample.data_path('.')
+
+raw = mne.fiff.Raw(data_path + '/MEG/sample/sample_audvis_raw.fif')
+proj = mne.read_proj(data_path + '/MEG/sample/ecg_proj.fif')
+raw.info['projs'] += proj
+raw.info['bads'] = ['MEG 2443', 'EEG 053'] # mark bad channels
+
+fwd_fname = data_path + '/MEG/sample/sample_audvis-meg-eeg-oct-6-fwd.fif'
+ave_fname = data_path + '/MEG/sample/sample_audvis-no-filter-ave.fif'
+cov_fname = data_path + '/MEG/sample/sample_audvis-cov.fif'
+
+fwd = mne.read_forward_solution(fwd_fname, force_fixed=True, surf_ori=True)
+fwd = pick_types_forward(fwd, meg=True, eeg=True, exclude=raw.info['bads'])
+
+noise_cov = mne.read_cov(cov_fname)
+
+evoked_template = mne.fiff.read_evoked(ave_fname, setno=0, baseline=None)
+evoked_template = pick_types_evoked(evoked_template, meg=True, eeg=True,
+ exclude=raw.info['bads'])
+
+tmin = -0.1
+sfreq = 1000 # Hz
+tstep = 1. / sfreq
+n_samples = 300
+timesamples = np.linspace(tmin, tmin + n_samples * tstep, n_samples)
+
+label_names = ['Aud-lh', 'Aud-rh']
+labels = [mne.read_label(data_path + '/MEG/sample/labels/%s.label' % ln)
+ for ln in label_names]
+
+mus = [[0.030, 0.060, 0.120], [0.040, 0.060, 0.140]]
+sigmas = [[0.01, 0.02, 0.03], [0.01, 0.02, 0.03]]
+amps = [[40 * 1e-9, 40 * 1e-9, 30 * 1e-9], [30 * 1e-9, 40 * 1e-9, 40 * 1e-9]]
+freqs = [[0, 0, 0], [0, 0, 0]]
+phis = [[0, 0, 0], [0, 0, 0]]
+
+SNR = 6
+dB = True
+
+stc_data = source_signal(mus, sigmas, amps, freqs, phis, timesamples)
+stc = generate_stc(fwd, labels, stc_data, tmin, tstep, random_state=0)
+evoked = apply_forward(fwd, stc, evoked_template)
+
+###############################################################################
+# Add noise
+picks = mne.fiff.pick_types(raw.info, meg=True)
+fir_filter = fir_filter_raw(raw, order=5, picks=picks, tmin=60, tmax=180)
+noise = generate_noise_evoked(evoked, noise_cov, n_samples, fir_filter)
+
+evoked_noise = add_noise(evoked, noise, SNR, timesamples, tmin=0.0, tmax=0.2, dB=dB)
+
+###############################################################################
+# Plot
+plot_sparse_source_estimates(fwd['src'], stc, bgcolor=(1, 1, 1),
+ opacity=0.5, high_resolution=True)
+
+pl.figure()
+pl.psd(evoked_noise.data[0])
+
+pl.figure()
+plot_evoked(evoked)
diff --git a/mne/simulation/__init__.py b/mne/simulation/__init__.py
new file mode 100644
index 0000000..918184b
--- /dev/null
+++ b/mne/simulation/__init__.py
@@ -0,0 +1,4 @@
+"""Data simulation code
+"""
+
+from .sim_evoked import select_source_in_label, generate_stc
\ No newline at end of file
diff --git a/mne/simulation/sim_evoked.py b/mne/simulation/sim_evoked.py
index b01a1a9..97e4b72 100644
--- a/mne/simulation/sim_evoked.py
+++ b/mne/simulation/sim_evoked.py
@@ -1,19 +1,16 @@
-import pdb
-import copy
+# Authors: Alexandre Gramfort <gramfort at nmr.mgh.harvard.edu>
+# Matti Hamalainen <msh at nmr.mgh.harvard.edu>
+#
+# License: BSD (3-clause)
import numpy as np
-import pylab as pl
-
from scipy import signal
-import mne
-from mne.fiff.pick import pick_types_evoked, pick_types_forward, pick_channels_cov
-from mne.forward import apply_forward
-from mne.label import read_label
-from mne.datasets import sample
-from mne.minimum_norm.inverse import _make_stc
-from mne.viz import plot_evoked, plot_sparse_source_estimates
-from mne.time_frequency import ar_raw
+import copy
+
+from ..fiff.pick import pick_channels_cov
+from ..minimum_norm.inverse import _make_stc
+from ..utils import check_random_state
def gaboratomr(timesamples, sigma, mu, k, phase):
@@ -67,69 +64,41 @@ def source_signal(mus, sigmas, amps, freqs, phis, timesamples):
signal : array
simulated source signal
"""
- signal = np.zeros(len(timesamples))
- for m, s, a, f, p in zip(mus, sigmas, amps, freqs, phis):
- signal += gaboratomr(timesamples, s, m, f, p) * a
- return signal
-
-
-def generate_fir_from_raw(raw, picks, order, tmin, tmax, proj=None):
- """Fits an AR model to raw data and creates FIR filter
-
- Parameters
- ----------
- raw : Raw object
- an instance of Raw
- picks : array of int
- indices of selected channels
- order : int
- order of the FIR filter
- tmin : float
- start time before event
- tmax : float
- end time after event
- projs : None | list
- The list of projection vectors
-
- Returns
- -------
- FIR : array
- filter coefficients
- """
- if proj is not None:
- raw.info['projs'] += proj
- picks = picks[:5]
- coefs = ar_raw(raw, order=order, picks=picks, tmin=tmin, tmax=tmax)
- mean_coefs = np.mean(coefs, axis=0) # mean model accross channels
- FIR = np.r_[1, -mean_coefs] # filter coefficient
- return FIR
+ data = np.zeros((len(mus), len(timesamples)))
+ for k in range(len(mus)):
+ for m, s, a, f, p in zip(mus[k], sigmas[k], amps[k], freqs[k], phis[k]):
+ data[k] += gaboratomr(timesamples, s, m, f, p) * a
+ return data
-def generate_noise(noise, noise_cov, nsamp, FIR=None):
- """Creates noise as a multivariate random process
- with specified cov matrix. No deepcopy of noise applied
+def generate_noise_evoked(evoked, noise_cov, n_samples, fir_filter=None, random_state=None):
+ """Creates noise as a multivariate random process with specified cov matrix.
Parameters
----------
- noise : evoked object
- an instance of evoked
+ evoked : evoked object
+ an instance of evoked used as template
+ (deep-copied; the template itself is not modified)
noise_cov : cov object
an instance of cov
- nsamp : int
- number of samples to generate
- FIR : None | array
+ n_samples : int
+ number of time samples to generate
+ fir_filter : None | array
FIR filter coefficients
+ used as the denominator ``a`` of scipy.signal.lfilter([1], a, ...)
+ (e.g. the output of mne.time_frequency.fir_filter_raw)
+ random_state : None | int | np.random.RandomState
+ To specify the random generator state.
Returns
-------
noise : evoked object
an instance of evoked
+ (a copy of ``evoked`` whose data is replaced by the noise)
"""
- noise_cov = pick_channels_cov(noise_cov, include=noise_template.info['ch_names'])
- rng = np.random.RandomState(0)
- noise.data = rng.multivariate_normal(np.zeros(noise.info['nchan']), noise_cov.data, nsamp).T
- if FIR is not None:
- noise.data = signal.lfilter([1], FIR, noise.data, axis=-1)
+ # Work on a copy so the template evoked keeps its data.
+ noise = copy.deepcopy(evoked)
+ # Keep only the covariance entries of the template's channels.
+ noise_cov = pick_channels_cov(noise_cov, include=noise.info['ch_names'])
+ rng = check_random_state(random_state)
+ # NOTE: despite its name, this is the zero mean vector (one entry per
+ # channel) passed to multivariate_normal, not a channel count.
+ n_channels = np.zeros(noise.info['nchan'])
+ # One Gaussian draw per time sample; .T gives (n_channels, n_samples).
+ noise.data = rng.multivariate_normal(n_channels, noise_cov.data, n_samples).T
+ if fir_filter is not None:
+ # b=[1], a=fir_filter: the all-pole (recursive) filter 1 / A(z)
+ # colors the white noise along time.
+ noise.data = signal.lfilter([1], fir_filter, noise.data, axis=-1)
return noise
@@ -165,26 +134,28 @@ def add_noise(evoked, noise, SNR, timesamples, tmin=None, tmax=None, dB=False):
tmax = np.max(timesamples)
tmask = (timesamples >= tmin) & (timesamples <= tmax)
if dB:
- SNRtemp = 20 * np.log10(np.sqrt(np.mean((evoked.data[:,tmask] ** 2).ravel()) / \
+ SNRtemp = 20 * np.log10(np.sqrt(np.mean((evoked.data[:, tmask] ** 2).ravel()) / \
np.mean((noise.data ** 2).ravel())))
noise.data = 10 ** ((SNRtemp - float(SNR)) / 20) * noise.data
else:
- SNRtemp = np.sqrt(np.mean((evoked.data[:,tmask] ** 2).ravel()) / \
+ SNRtemp = np.sqrt(np.mean((evoked.data[:, tmask] ** 2).ravel()) / \
np.mean((noise.data ** 2).ravel()))
noise.data = SNRtemp / SNR * noise.data
evoked.data += noise.data
return evoked
-def select_source_idxs(fwd, label_fname):
+def select_source_in_label(fwd, label, random_state=None):
"""Select source positions using a label
Parameters
----------
fwd : dict
a forward solution
- label_fname : str
- filename of the freesurfer label to read
+ label : dict
+ the label (read with mne.read_label)
+ random_state : None | int | np.random.RandomState
+ To specify the random generator state.
Returns
-------
@@ -196,10 +167,9 @@ def select_source_idxs(fwd, label_fname):
lh_vertno = list()
rh_vertno = list()
- label = read_label(label_fname)
- rng = np.random.RandomState(0)
+ rng = check_random_state(random_state)
- if label['hemi']=='lh':
+ if label['hemi'] == 'lh':
src_sel_lh = np.intersect1d(fwd['src'][0]['vertno'], label['vertices'])
idx_select = rng.randint(0, len(src_sel_lh), 1)
lh_vertno.append(src_sel_lh[idx_select][0])
@@ -211,63 +181,13 @@ def select_source_idxs(fwd, label_fname):
return lh_vertno, rh_vertno
-## load data_sets from mne-sample-data ##
-data_path = sample.data_path('.')
-
-fwd_fname = data_path + '/MEG/sample/sample_audvis-meg-eeg-oct-6-fwd.fif'
-fwd = mne.read_forward_solution(fwd_fname, force_fixed=True, surf_ori=True)
-exclude = ['MEG 2443', 'EEG 053']
-meg_include = True
-eeg_include = True
-fwd = pick_types_forward(fwd, meg=meg_include, eeg=eeg_include, exclude=exclude)
-
-cov_fname = data_path + '/MEG/sample/sample_audvis-cov.fif'
-noise_cov = mne.read_cov(cov_fname)
-
-tmin = -0.1
-#sfreq
-tstep = 0.001
-n_samples = 300
-timesamples = np.linspace(tmin, tmin + n_samples * tstep, n_samples)
-
-label = ['Aud-lh', 'Aud-rh']
-amps = [[40 * 1e-9, 40 * 1e-9, 30 * 1e-9], [30 * 1e-9, 40 * 1e-9, 40 * 1e-9]]
-mus = [[0.030, 0.060, 0.120], [0.040, 0.060, 0.140]]
-sigmas = [[0.01, 0.02, 0.03], [0.01, 0.02, 0.03]]
-freqs = [[0, 0, 0], [0, 0, 0]]
-phis = [[0, 0, 0], [0, 0, 0]]
-
-SNR = 6
-dB = True
-
-signals = list()
-vertno = [[], []]
-for k in range(len(label)):
- label_fname = data_path + '/MEG/sample/labels/%s.label' % label[k]
- lh_vertno, rh_vertno = select_source_idxs(fwd, label_fname)
- vertno[0] += lh_vertno
- vertno[1] += rh_vertno
- signals.append(source_signal(mus[k], sigmas[k], amps[k], freqs[k], phis[k], timesamples))
-signals = np.vstack(signals)
-stc = _make_stc(signals, tmin, tstep, vertno)
-plot_sparse_source_estimates(fwd['src'], stc, bgcolor=(1, 1, 1),
- opacity=0.5, high_resolution=True)
-
-ave_fname = data_path + '/MEG/sample/sample_audvis-no-filter-ave.fif'
-evoked_template = mne.fiff.read_evoked(ave_fname, setno=0, baseline=None)
-evoked_template = pick_types_evoked(evoked_template, meg=meg_include, eeg=eeg_include, exclude=exclude)
-evoked = apply_forward(fwd, stc, evoked_template, start=None, stop=None)
-
-noise_template = copy.deepcopy(evoked_template)
-raw = mne.fiff.Raw(data_path + '/MEG/sample/sample_audvis_raw.fif')
-proj = mne.read_proj(data_path + '/MEG/sample/ecg_proj.fif')
-raw.info['projs'] += proj
-raw.info['bads'] = ['MEG 2443', 'EEG 053'] # mark bad channels
-picks = mne.fiff.pick_types(raw.info, meg=True)
-FIR = generate_fir_from_raw(raw, picks, 5, tmin=60, tmax=180, proj=proj)
-noise = generate_noise(noise_template, noise_cov, n_samples, FIR=FIR)
-pl.figure()
-pl.psd(noise.data[0])
-evoked = add_noise(evoked, noise, SNR, timesamples, tmin=0.0, tmax=0.2, dB=dB)
-pl.figure()
-plot_evoked(evoked)
+def generate_stc(fwd, labels, stc_data, tmin, tstep, random_state=0):
+ rng = check_random_state(random_state)
+ vertno = [[], []]
+ for label in labels:
+ lh_vertno, rh_vertno = select_source_in_label(fwd, label, rng)
+ vertno[0] += lh_vertno
+ vertno[1] += rh_vertno
+ vertno = map(np.array, vertno)
+ stc = _make_stc(stc_data, tmin, tstep, vertno)
+ return stc
diff --git a/mne/time_frequency/__init__.py b/mne/time_frequency/__init__.py
index 4826123..be88845 100644
--- a/mne/time_frequency/__init__.py
+++ b/mne/time_frequency/__init__.py
@@ -3,4 +3,4 @@
from .tfr import induced_power, single_trial_power
from .psd import compute_raw_psd
-from .ar import yule_walker, ar_raw
+from .ar import yule_walker, ar_raw, fir_filter_raw
diff --git a/mne/time_frequency/ar.py b/mne/time_frequency/ar.py
index 3daafc1..17e7a5a 100644
--- a/mne/time_frequency/ar.py
+++ b/mne/time_frequency/ar.py
@@ -109,3 +109,33 @@ def ar_raw(raw, order, picks, tmin=None, tmax=None):
this_coefs, _ = yule_walker(d, order=order)
coefs[k, :] = this_coefs
return coefs
+
+
+def fir_filter_raw(raw, order, picks, tmin=None, tmax=None):
+ """Fits an AR model to raw data and creates corresponding FIR filter
+
+ The returned filter is the average filter for all the picked channels.
+
+ Parameters
+ ----------
+ raw : Raw object
+ an instance of Raw
+ order : int
+ order of the FIR filter
+ picks : array of int
+ indices of selected channels
+ tmin : float
+ The beginning of time interval in seconds.
+ tmax : float
+ The end of time interval in seconds.
+
+ Returns
+ -------
+ fir : array
+ filter coefficients
+ """
+ picks = picks[:5]
+ coefs = ar_raw(raw, order=order, picks=picks, tmin=tmin, tmax=tmax)
+ mean_coefs = np.mean(coefs, axis=0) # mean model accross channels
+ fir = np.r_[1, -mean_coefs] # filter coefficient
+ return fir
diff --git a/mne/utils.py b/mne/utils.py
index 20f3832..187de8e 100644
--- a/mne/utils.py
+++ b/mne/utils.py
@@ -247,3 +247,21 @@ try:
from scipy.signal import firwin2
except ImportError:
firwin2 = _firwin2
+
+
+def check_random_state(seed):
+ """Turn seed into a np.random.RandomState instance
+
+ If seed is None, return the RandomState singleton used by np.random.
+ If seed is an int, return a new RandomState instance seeded with seed.
+ If seed is already a RandomState instance, return it.
+ Otherwise raise ValueError.
+ """
+ if seed is None or seed is np.random:
+ return np.random.mtrand._rand
+ if isinstance(seed, (int, np.integer)):
+ return np.random.RandomState(seed)
+ if isinstance(seed, np.random.RandomState):
+ return seed
+ raise ValueError('%r cannot be used to seed a numpy.random.RandomState'
+ ' instance' % seed)
diff --git a/mne/viz.py b/mne/viz.py
index 2933a87..5273620 100644
--- a/mne/viz.py
+++ b/mne/viz.py
@@ -5,6 +5,7 @@
#
# License: Simplified BSD
+from itertools import cycle
import copy
import numpy as np
from scipy import linalg
@@ -85,6 +86,160 @@ def plot_evoked(evoked, picks=None, unit=True, show=True):
pl.show()
+COLORS = ['b', 'g', 'r', 'c', 'm', 'y', 'k', '#473C8B', '#458B74',
+ '#CD7F32', '#FF4040', '#ADFF2F', '#8E2323', '#FF1493']
+
+
+def plot_sparse_source_estimates(src, stcs, colors=None, linewidth=2,
+ fontsize=18, bgcolor=(.05, 0, .1), opacity=0.2,
+ brain_color=(0.7, ) * 3, show=True,
+ high_resolution=False, fig_name=None,
+ fig_number=None, labels=None,
+ modes=['cone', 'sphere'],
+ scale_factors=[1, 0.6],
+ **kwargs):
+ """Plot source estimates obtained with sparse solver
+
+ Active dipoles are represented in a "Glass" brain.
+ If the same source is active in multiple source estimates it is
+ displayed with a sphere otherwise with a cone in 3D.
+
+ Parameters
+ ----------
+ src: dict
+ The source space
+ stcs: instance of SourceEstimate or list of instances of SourceEstimate
+ The source estimates (up to 3)
+ colors: list
+ List of colors
+ linewidth: int
+ Line width in 2D plot
+ fontsize: int
+ Font size
+ bgcolor: tuple of length 3
+ Back ground color in 3D
+ opacity: float in [0, 1]
+ Opacity of brain mesh
+ brain_color: tuple of length 3
+ Brain color
+ show: bool
+ Show figures if True
+ fig_name:
+ Mayavi figure name
+ fig_number:
+ Pylab figure number
+ labels: ndarray or list of ndarrays
+ Labels to show sources in clusters. Sources with the same
+ label and the waveforms within each cluster are presented in
+ the same color. labels should be a list of ndarrays when
+ stcs is a list ie. one label for each stc.
+ kwargs: kwargs
+ kwargs pass to mlab.triangular_mesh
+ """
+ if not isinstance(stcs, list):
+ stcs = [stcs]
+ if labels is not None and not isinstance(labels, list):
+ labels = [labels]
+
+ if colors is None:
+ colors = COLORS
+
+ linestyles = ['-', '--', ':']
+
+ # Show 3D
+ lh_points = src[0]['rr']
+ rh_points = src[1]['rr']
+ points = np.r_[lh_points, rh_points]
+
+ lh_normals = src[0]['nn']
+ rh_normals = src[1]['nn']
+ normals = np.r_[lh_normals, rh_normals]
+
+ if high_resolution:
+ use_lh_faces = src[0]['tris']
+ use_rh_faces = src[1]['tris']
+ else:
+ use_lh_faces = src[0]['use_tris']
+ use_rh_faces = src[1]['use_tris']
+
+ use_faces = np.r_[use_lh_faces, lh_points.shape[0] + use_rh_faces]
+
+ points *= 170
+
+ vertnos = [np.r_[stc.lh_vertno, lh_points.shape[0] + stc.rh_vertno]
+ for stc in stcs]
+ unique_vertnos = np.unique(np.concatenate(vertnos).ravel())
+
+ try:
+ from mayavi import mlab
+ except ImportError:
+ from enthought.mayavi import mlab
+
+ from matplotlib.colors import ColorConverter
+ color_converter = ColorConverter()
+
+ f = mlab.figure(figure=fig_name, bgcolor=bgcolor, size=(800, 800))
+ mlab.clf()
+ f.scene.disable_render = True
+ surface = mlab.triangular_mesh(points[:, 0], points[:, 1], points[:, 2],
+ use_faces, color=brain_color, opacity=opacity,
+ **kwargs)
+
+ import pylab as pl
+ # Show time courses
+ pl.figure(fig_number)
+ pl.clf()
+
+ colors = cycle(colors)
+
+ print "Total number of active sources: %d" % len(unique_vertnos)
+
+ if labels is not None:
+ colors = [colors.next() for _ in
+ range(np.unique(np.concatenate(labels).ravel()).size)]
+
+ for v in unique_vertnos:
+ # get indices of stcs it belongs to
+ ind = [k for k, vertno in enumerate(vertnos) if v in vertno]
+ is_common = len(ind) > 1
+
+ if labels is None:
+ c = colors.next()
+ else:
+ # if vertex is in different stcs than take label from first one
+ c = colors[labels[ind[0]][vertnos[ind[0]] == v]]
+
+ mode = modes[1] if is_common else modes[0]
+ scale_factor = scale_factors[1] if is_common else scale_factors[0]
+ x, y, z = points[v]
+ nx, ny, nz = normals[v]
+ mlab.quiver3d(x, y, z, nx, ny, nz, color=color_converter.to_rgb(c),
+ mode=mode, scale_factor=scale_factor)
+
+ for k in ind:
+ vertno = vertnos[k]
+ mask = (vertno == v)
+ assert np.sum(mask) == 1
+ linestyle = linestyles[k]
+ pl.plot(1e3 * stc.times, 1e9 * stcs[k].data[mask].ravel(), c=c,
+ linewidth=linewidth, linestyle=linestyle)
+
+ pl.xlabel('Time (ms)', fontsize=18)
+ pl.ylabel('Source amplitude (nAm)', fontsize=18)
+
+ if fig_name is not None:
+ pl.title(fig_name)
+
+ if show:
+ pl.show()
+ mlab.show()
+
+ surface.actor.property.backface_culling = True
+ surface.actor.property.shading = True
+
+ return surface
+
+
def plot_cov(cov, info, exclude=[], colorbar=True, proj=False, show_svd=True,
show=True):
"""Plot Covariance data
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git
More information about the debian-med-commit
mailing list