[med-svn] [mne-python] 02/03: Imported Upstream version 0.7rc2
Andreas Tille
tille at debian.org
Fri Nov 22 21:15:27 UTC 2013
This is an automated email from the git hooks/post-receive script.
tille pushed a commit to branch master
in repository mne-python.
commit 1ed39d77cb86e5ca11062cd197eeccdec0e9a096
Author: Andreas Tille <tille at debian.org>
Date: Fri Nov 22 22:10:13 2013 +0100
Imported Upstream version 0.7rc2
---
Makefile | 1 +
examples/plot_from_raw_to_epochs_to_evoked.py | 1 -
mne/beamformer/_lcmv.py | 4 +-
mne/epochs.py | 24 ++----
mne/fiff/compensator.py | 13 +++-
mne/fiff/ctf.py | 20 +++--
mne/fiff/evoked.py | 7 +-
mne/fiff/raw.py | 23 +++---
mne/fiff/tests/test_compensator.py | 53 ++++++++++++-
mne/fiff/tests/test_raw.py | 20 ++++-
mne/fixes.py | 105 +++++++++++---------------
mne/forward/forward.py | 33 ++++----
mne/forward/tests/test_make_forward.py | 16 +++-
mne/realtime/epochs.py | 17 ++---
mne/source_estimate.py | 66 +++-------------
mne/tests/test_fixes.py | 70 +++++++++++++++--
mne/tests/test_source_estimate.py | 25 ------
mne/tests/test_utils.py | 58 +++++++++++++-
mne/tests/test_viz.py | 43 ++++++++++-
mne/utils.py | 22 ++----
mne/viz.py | 45 +++++------
21 files changed, 395 insertions(+), 271 deletions(-)
diff --git a/Makefile b/Makefile
index a90fe4a..530400a 100755
--- a/Makefile
+++ b/Makefile
@@ -37,6 +37,7 @@ $(CURDIR)/examples/MNE-sample-data/MEG/sample/sample_audvis_raw.fif:
ln -s ${PWD}/examples/MNE-sample-data ${PWD}/MNE-sample-data -f
test: in sample_data
+ rm -f .coverage
$(NOSETESTS) mne
test-no-sample: in
diff --git a/examples/plot_from_raw_to_epochs_to_evoked.py b/examples/plot_from_raw_to_epochs_to_evoked.py
index 4646a73..599813e 100644
--- a/examples/plot_from_raw_to_epochs_to_evoked.py
+++ b/examples/plot_from_raw_to_epochs_to_evoked.py
@@ -56,7 +56,6 @@ evoked.save('sample_audvis_eeg-ave.fif') # save evoked data to disk
# View evoked response
times = 1e3 * epochs.times # time in miliseconds
import matplotlib.pyplot as plt
-plt.figure()
evoked.plot()
plt.xlim([times[0], times[-1]])
plt.xlabel('time (ms)')
diff --git a/mne/beamformer/_lcmv.py b/mne/beamformer/_lcmv.py
index b599ff7..4a909b4 100644
--- a/mne/beamformer/_lcmv.py
+++ b/mne/beamformer/_lcmv.py
@@ -647,9 +647,7 @@ def tf_lcmv(epochs, forward, noise_covs, tmin, tmax, tstep, win_lengths,
n_jobs=n_jobs)
epochs_band = Epochs(raw_band, epochs.events, epochs.event_id,
tmin=epochs.tmin, tmax=epochs.tmax,
- picks=raw_picks, keep_comp=epochs.keep_comp,
- dest_comp=epochs.dest_comp,
- proj=epochs.proj, preload=True)
+ picks=raw_picks, proj=epochs.proj, preload=True)
del raw_band
if subtract_evoked:
diff --git a/mne/epochs.py b/mne/epochs.py
index 7b7df23..d82f20a 100644
--- a/mne/epochs.py
+++ b/mne/epochs.py
@@ -44,10 +44,9 @@ class _BaseEpochs(ProjMixin):
directly. See Epochs below for an explanation of the parameters.
"""
def __init__(self, info, event_id, tmin, tmax, baseline=(None, 0),
- picks=None, name='Unknown', keep_comp=False, dest_comp=0,
- reject=None, flat=None, decim=1, reject_tmin=None,
- reject_tmax=None, detrend=None, add_eeg_ref=True,
- verbose=None):
+ picks=None, name='Unknown', reject=None, flat=None,
+ decim=1, reject_tmin=None, reject_tmax=None, detrend=None,
+ add_eeg_ref=True, verbose=None):
self.verbose = verbose
self.name = name
@@ -80,8 +79,6 @@ class _BaseEpochs(ProjMixin):
self.tmin = tmin
self.tmax = tmax
- self.keep_comp = keep_comp
- self.dest_comp = dest_comp
self.baseline = baseline
self.reject = reject
self.reject_tmin = reject_tmin
@@ -105,11 +102,6 @@ class _BaseEpochs(ProjMixin):
if len(picks) == 0:
raise ValueError("Picks cannot be empty.")
- # XXX : deprecate CTF compensator
- if dest_comp is not None or keep_comp is not None:
- raise ValueError('current_comp and keep_comp are deprecated.'
- ' Use the compensation parameter in Raw.')
-
# Handle times
if tmin >= tmax:
raise ValueError('tmin has to be smaller than tmax')
@@ -590,10 +582,9 @@ class Epochs(_BaseEpochs):
"""
@verbose
def __init__(self, raw, events, event_id, tmin, tmax, baseline=(None, 0),
- picks=None, name='Unknown', keep_comp=None, dest_comp=None,
- preload=False, reject=None, flat=None, proj=True,
- decim=1, reject_tmin=None, reject_tmax=None, detrend=None,
- add_eeg_ref=True, verbose=None):
+ picks=None, name='Unknown', preload=False, reject=None,
+ flat=None, proj=True, decim=1, reject_tmin=None,
+ reject_tmax=None, detrend=None, add_eeg_ref=True, verbose=None):
if raw is None:
return
@@ -612,7 +603,6 @@ class Epochs(_BaseEpochs):
# call _BaseEpochs constructor
super(Epochs, self).__init__(info, event_id, tmin, tmax,
baseline=baseline, picks=picks, name=name,
- keep_comp=keep_comp, dest_comp=dest_comp,
reject=reject, flat=flat, decim=decim,
reject_tmin=reject_tmin,
reject_tmax=reject_tmax, detrend=detrend,
@@ -1669,7 +1659,7 @@ def read_epochs(fname, proj=True, add_eeg_ref=True, verbose=None):
comment = tag.data
elif kind == FIFF.FIFF_EPOCH:
tag = read_tag(fid, pos)
- data = tag.data
+ data = tag.data.astype(np.float)
elif kind == FIFF.FIFF_MNE_BASELINE_MIN:
tag = read_tag(fid, pos)
bmin = float(tag.data)
diff --git a/mne/fiff/compensator.py b/mne/fiff/compensator.py
index 40f9c04..91b38cc 100644
--- a/mne/fiff/compensator.py
+++ b/mne/fiff/compensator.py
@@ -19,6 +19,16 @@ def get_current_comp(info):
return comp
+def set_current_comp(info, comp):
+ """Set the current compensation in effect in the data
+ """
+ comp_now = get_current_comp(info)
+ for k, chan in enumerate(info['chs']):
+ if chan['kind'] == FIFF.FIFFV_MEG_CH:
+ rem = chan['coil_type'] - (comp_now << 16)
+ chan['coil_type'] = int(rem + (comp << 16))
+
+
def _make_compensator(info, kind):
"""Auxiliary function for make_compensator
"""
@@ -36,7 +46,7 @@ def _make_compensator(info, kind):
'data' % col_name)
elif len(ind) > 1:
raise ValueError('Ambiguous channel %s' % col_name)
- presel[col, ind] = 1.0
+ presel[col, ind[0]] = 1.0
# Create the postselector
postsel = np.zeros((info['nchan'], this_data['nrow']))
@@ -47,7 +57,6 @@ def _make_compensator(info, kind):
raise ValueError('Ambiguous channel %s' % ch_name)
elif len(ind) == 1:
postsel[c, ind[0]] = 1.0
-
this_comp = np.dot(postsel, np.dot(this_data['data'], presel))
return this_comp
diff --git a/mne/fiff/ctf.py b/mne/fiff/ctf.py
index d24d297..abe13b2 100644
--- a/mne/fiff/ctf.py
+++ b/mne/fiff/ctf.py
@@ -4,6 +4,8 @@
#
# License: BSD (3-clause)
+from copy import deepcopy
+
import numpy as np
from .constants import FIFF
@@ -95,7 +97,7 @@ def _read_named_matrix(fid, node, matkind):
else:
mat['col_names'] = None
- mat['data'] = data
+ mat['data'] = data.astype(np.float)
return mat
@@ -196,8 +198,7 @@ def read_ctf_comp(fid, node, chs, verbose=None):
idx = ch_names.index(mat['row_names'][row])
row_cals[row] = chs[idx]['range'] * chs[idx]['cal']
- mat['data'] = np.dot(np.diag(row_cals), np.dot(mat['data'],
- np.diag(col_cals)))
+ mat['data'] = row_cals[:, None] * mat['data'] * col_cals[None, :]
one['rowcals'] = row_cals
one['colcals'] = col_cals
@@ -242,11 +243,14 @@ def write_ctf_comp(fid, comps):
# Write the compensation kind
write_int(fid, FIFF.FIFF_MNE_CTF_COMP_KIND, comp['ctfkind'])
write_int(fid, FIFF.FIFF_MNE_CTF_COMP_CALIBRATED,
- comp['save_calibrated'])
-
- # Write an uncalibrated or calibrated matrix
- comp['data']['data'] = (comp['rowcals'][:, None] * comp['data']['data']
- * comp['colcals'][None, :])
+ comp['save_calibrated'])
+
+ if not comp['save_calibrated']:
+ # Undo calibration
+ comp = deepcopy(comp)
+ data = ((1. / comp['rowcals'][:, None]) * comp['data']['data']
+ * (1. / comp['colcals'][None, :]))
+ comp['data']['data'] = data
write_named_matrix(fid, FIFF.FIFF_MNE_CTF_COMP_DATA, comp['data'])
end_block(fid, FIFF.FIFFB_MNE_CTF_COMP_DATA)
diff --git a/mne/fiff/evoked.py b/mne/fiff/evoked.py
index d56a394..e8fa6e4 100644
--- a/mne/fiff/evoked.py
+++ b/mne/fiff/evoked.py
@@ -241,13 +241,14 @@ class Evoked(ProjMixin):
if nepoch == 1:
# Only one epoch
- all_data = epoch[0].data
+ all_data = epoch[0].data.astype(np.float)
# May need a transpose if the number of channels is one
if all_data.shape[1] == 1 and info['nchan'] == 1:
- all_data = all_data.T
+ all_data = all_data.T.astype(np.float)
else:
# Put the old style epochs together
- all_data = np.concatenate([e.data[None, :] for e in epoch], axis=0)
+ all_data = np.concatenate([e.data[None, :] for e in epoch],
+ axis=0).astype(np.float)
if all_data.shape[1] != nsamp:
fid.close()
diff --git a/mne/fiff/raw.py b/mne/fiff/raw.py
index c0a7957..5c0032a 100644
--- a/mne/fiff/raw.py
+++ b/mne/fiff/raw.py
@@ -24,7 +24,7 @@ from .tag import read_tag
from .pick import pick_types, channel_type
from .proj import (setup_proj, activate_proj, proj_equal, ProjMixin,
_has_eeg_average_ref_proj, make_eeg_average_ref_proj)
-from .compensator import get_current_comp, make_compensator
+from .compensator import get_current_comp, set_current_comp, make_compensator
from ..filter import (low_pass_filter, high_pass_filter, band_pass_filter,
notch_filter, band_stop_filter, resample)
@@ -103,6 +103,7 @@ class Raw(ProjMixin):
self.cals = raws[0].cals
self.rawdirs = [r.rawdir for r in raws]
self.comp = copy.deepcopy(raws[0].comp)
+ self._orig_comp_grade = raws[0]._orig_comp_grade
self.fids = [r.fid for r in raws]
self.info = copy.deepcopy(raws[0].info)
self.verbose = verbose
@@ -302,6 +303,7 @@ class Raw(ProjMixin):
raw.cals = cals
raw.rawdir = rawdir
raw.comp = None
+ raw._orig_comp_grade = None
# Set up the CTF compensator
current_comp = get_current_comp(info)
@@ -313,6 +315,8 @@ class Raw(ProjMixin):
if raw.comp is not None:
logger.info('Appropriate compensator added to change to '
'grade %d.' % (compensation))
+ raw._orig_comp_grade = current_comp
+ set_current_comp(info, compensation)
logger.info(' Range : %d ... %d = %9.3f ... %9.3f secs' % (
raw.first_samp, raw.last_samp,
@@ -966,6 +970,12 @@ class Raw(ProjMixin):
info = self.info
projector = None
+ # set the correct compensation grade and make inverse compensator
+ inv_comp = None
+ if self.comp is not None:
+ inv_comp = linalg.inv(self.comp)
+ set_current_comp(info, self._orig_comp_grade)
+
outfid, cals = start_writing_raw(fname, info, picks, type_dict[format],
reset_range=reset_dict[format])
#
@@ -990,12 +1000,6 @@ class Raw(ProjMixin):
#
# Read and write all the data
#
-
- # Take care of CTF compensation
- inv_comp = None
- if self.comp is not None:
- inv_comp = linalg.inv(self.comp)
-
if first_samp != 0:
write_int(outfid, FIFF.FIFF_FIRST_SAMPLE, first_samp)
for first in range(start, stop, buffer_size):
@@ -1563,9 +1567,10 @@ class Raw(ProjMixin):
for ri in range(len(self._raw_lengths)):
mult.append(np.diag(self.cals.ravel()))
if self.comp is not None:
- mult[ri] = np.dot(self.comp[idx, :], mult[ri])
+ mult[ri] = np.dot(self.comp, mult[ri])
if projector is not None:
mult[ri] = np.dot(projector, mult[ri])
+ mult[ri] = mult[ri][idx]
# deal with having multiple files accessed by the raw object
cumul_lens = np.concatenate(([0], np.array(self._raw_lengths,
@@ -1642,7 +1647,7 @@ class Raw(ProjMixin):
one = one.T.astype(dtype)
# use proj + cal factors in mult
if mult is not None:
- one = np.dot(mult[fi], one)
+ one[idx] = np.dot(mult[fi], one)
else: # apply just the calibration factors
# this logic is designed to limit memory copies
if isinstance(idx, slice):
diff --git a/mne/fiff/tests/test_compensator.py b/mne/fiff/tests/test_compensator.py
index 5a94431..baca7f4 100644
--- a/mne/fiff/tests/test_compensator.py
+++ b/mne/fiff/tests/test_compensator.py
@@ -4,13 +4,19 @@
import os.path as op
from nose.tools import assert_true
+import numpy as np
+from numpy.testing import assert_allclose
-from mne.fiff.compensator import make_compensator
-from mne.fiff import Raw
+from mne import Epochs
+from mne.fiff.compensator import make_compensator, get_current_comp
+from mne.fiff import Raw, pick_types, read_evoked
+from mne.utils import _TempDir, requires_mne, run_subprocess
base_dir = op.join(op.dirname(__file__), 'data')
ctf_comp_fname = op.join(base_dir, 'test_ctf_comp_raw.fif')
+tempdir = _TempDir()
+
def test_compensation():
"""Test compensation
@@ -20,3 +26,46 @@ def test_compensation():
assert_true(comp1.shape == (340, 340))
comp2 = make_compensator(raw.info, 3, 1, exclude_comp_chs=True)
assert_true(comp2.shape == (311, 340))
+
+ # make sure that changing the comp doesn't modify the original data
+ raw2 = Raw(ctf_comp_fname, compensation=2)
+ assert_true(get_current_comp(raw2.info) == 2)
+ fname = op.join(tempdir, 'ctf-raw.fif')
+ raw2.save(fname)
+ raw2 = Raw(fname, compensation=None)
+ data, _ = raw[:, :]
+ data2, _ = raw2[:, :]
+ assert_allclose(data, data2, rtol=1e-9, atol=1e-20)
+ for ch1, ch2 in zip(raw.info['chs'], raw2.info['chs']):
+ assert_true(ch1['coil_type'] == ch2['coil_type'])
+
+
+ at requires_mne
+def test_compensation_mne():
+ """Test comensation by comparing with MNE
+ """
+ def make_evoked(fname, comp):
+ raw = Raw(fname, compensation=comp)
+ picks = pick_types(raw.info, meg=True, ref_meg=True)
+ events = np.array([[0, 0, 1]], dtype=np.int)
+ evoked = Epochs(raw, events, 1, 0, 20e-3, picks=picks).average()
+ return evoked
+
+ def compensate_mne(fname, comp):
+ tmp_fname = '%s-%d.fif' % (fname[:-4], comp)
+ cmd = ['mne_compensate_data', '--in', fname,
+ '--out', tmp_fname, '--grad', str(comp)]
+ run_subprocess(cmd)
+ return read_evoked(tmp_fname)
+
+ # save evoked response with default compensation
+ fname_default = op.join(tempdir, 'ctf_default-ave.fif')
+ make_evoked(ctf_comp_fname, None).save(fname_default)
+
+ for comp in [0, 1, 2, 3]:
+ evoked_py = make_evoked(ctf_comp_fname, comp)
+ evoked_c = compensate_mne(fname_default, comp)
+ picks_py = pick_types(evoked_py.info, meg=True, ref_meg=True)
+ picks_c = pick_types(evoked_c.info, meg=True, ref_meg=True)
+ assert_allclose(evoked_py.data[picks_py], evoked_c.data[picks_c],
+ rtol=1e-3, atol=1e-17)
diff --git a/mne/fiff/tests/test_raw.py b/mne/fiff/tests/test_raw.py
index ad59f44..2cf4d18 100644
--- a/mne/fiff/tests/test_raw.py
+++ b/mne/fiff/tests/test_raw.py
@@ -16,7 +16,8 @@ from nose.tools import assert_true, assert_raises, assert_equal
from mne.fiff import (Raw, pick_types, pick_channels, concatenate_raws, FIFF,
get_chpi_positions, set_eeg_reference)
from mne import concatenate_events, find_events
-from mne.utils import _TempDir, requires_nitime, requires_pandas
+from mne.utils import (_TempDir, requires_nitime, requires_pandas, requires_mne,
+ run_subprocess)
warnings.simplefilter('always') # enable b/c these tests throw warnings
@@ -818,6 +819,23 @@ def test_compensation_raw():
assert_allclose(data1, data5, rtol=1e-12, atol=1e-22)
+ at requires_mne
+def test_compensation_raw_mne():
+ """Test Raw compensation by comparing with MNE
+ """
+ def compensate_mne(fname, grad):
+ tmp_fname = op.join(tempdir, 'mne_ctf_test_raw.fif')
+ cmd = ['mne_process_raw', '--raw', fname, '--save', tmp_fname,
+ '--grad', str(grad), '--projoff', '--filteroff']
+ run_subprocess(cmd)
+ return Raw(tmp_fname, preload=True)
+
+ for grad in [0, 2, 3]:
+ raw_py = Raw(ctf_comp_fname, preload=True, compensation=grad)
+ raw_c = compensate_mne(ctf_comp_fname, grad)
+ assert_allclose(raw_py._data, raw_c._data, rtol=1e-6, atol=1e-17)
+
+
def test_set_eeg_reference():
""" Test rereference eeg data"""
raw = Raw(fif_fname, preload=True)
diff --git a/mne/fixes.py b/mne/fixes.py
index 1ab35d0..e739b58 100644
--- a/mne/fixes.py
+++ b/mne/fixes.py
@@ -16,8 +16,10 @@ import collections
from operator import itemgetter
import inspect
+import warnings
import numpy as np
import scipy
+from scipy import linalg
from math import ceil, log
from numpy.fft import irfft
from scipy.signal import filtfilt as sp_filtfilt
@@ -25,47 +27,29 @@ from distutils.version import LooseVersion
from functools import partial
import copy_reg
-try:
- Counter = collections.Counter
-except AttributeError:
- class Counter(collections.defaultdict):
- """Partial replacement for Python 2.7 collections.Counter."""
- def __init__(self, iterable=(), **kwargs):
- super(Counter, self).__init__(int, **kwargs)
- self.update(iterable)
-
- def most_common(self):
- return sorted(self.iteritems(), key=itemgetter(1), reverse=True)
-
- def update(self, other):
- """Adds counts for elements in other"""
- if isinstance(other, self.__class__):
- for x, n in other.iteritems():
- self[x] += n
- else:
- for x in other:
- self[x] += 1
+class _Counter(collections.defaultdict):
+ """Partial replacement for Python 2.7 collections.Counter."""
+ def __init__(self, iterable=(), **kwargs):
+ super(_Counter, self).__init__(int, **kwargs)
+ self.update(iterable)
-def lsqr(X, y, tol=1e-3):
- import scipy.sparse.linalg as sp_linalg
- from ..utils.extmath import safe_sparse_dot
+ def most_common(self):
+ return sorted(self.iteritems(), key=itemgetter(1), reverse=True)
- if hasattr(sp_linalg, 'lsqr'):
- # scipy 0.8 or greater
- return sp_linalg.lsqr(X, y)
- else:
- n_samples, n_features = X.shape
- if n_samples > n_features:
- coef, _ = sp_linalg.cg(safe_sparse_dot(X.T, X),
- safe_sparse_dot(X.T, y),
- tol=tol)
+ def update(self, other):
+ """Adds counts for elements in other"""
+ if isinstance(other, self.__class__):
+ for x, n in other.iteritems():
+ self[x] += n
else:
- coef, _ = sp_linalg.cg(safe_sparse_dot(X, X.T), y, tol=tol)
- coef = safe_sparse_dot(X.T, coef)
+ for x in other:
+ self[x] += 1
- residues = y - safe_sparse_dot(X, coef)
- return coef, None, None, residues
+try:
+ Counter = collections.Counter
+except AttributeError:
+ Counter = _Counter
def _unique(ar, return_index=False, return_inverse=False):
@@ -110,15 +94,7 @@ def _unique(ar, return_index=False, return_inverse=False):
flag = np.concatenate(([True], ar[1:] != ar[:-1]))
return ar[flag]
-np_version = []
-for x in np.__version__.split('.'):
- try:
- np_version.append(int(x))
- except ValueError:
- # x may be of the form dev-1ea1592
- np_version.append(x)
-
-if np_version[:2] < (1, 5):
+if LooseVersion(np.__version__) < LooseVersion('1.5'):
unique = _unique
else:
unique = np.unique
@@ -133,7 +109,7 @@ def _bincount(X, weights=None, minlength=None):
out[:len(result)] = result
return out
-if np_version[:2] < (1, 6):
+if LooseVersion(np.__version__) < LooseVersion('1.6'):
bincount = _bincount
else:
bincount = np.bincount
@@ -205,26 +181,29 @@ def _unravel_index(indices, dims):
return tuple(unraveled_coords.T)
-if np_version[:2] < (1, 4):
+if LooseVersion(np.__version__) < LooseVersion('1.4'):
unravel_index = _unravel_index
else:
unravel_index = np.unravel_index
-def qr_economic(A, **kwargs):
- """Compat function for the QR-decomposition in economic mode
-
+def _qr_economic_old(A, **kwargs):
+ """
+ Compat function for the QR-decomposition in economic mode
Scipy 0.9 changed the keyword econ=True to mode='economic'
"""
- import scipy.linalg
- # trick: triangular solve has introduced in 0.9
- if hasattr(scipy.linalg, 'solve_triangular'):
- return scipy.linalg.qr(A, mode='economic', **kwargs)
- else:
- import warnings
- with warnings.catch_warnings():
- warnings.simplefilter("ignore", DeprecationWarning)
- return scipy.linalg.qr(A, econ=True, **kwargs)
+ with warnings.catch_warnings(True):
+ return linalg.qr(A, econ=True, **kwargs)
+
+
+def _qr_economic_new(A, **kwargs):
+ return linalg.qr(A, mode='economic', **kwargs)
+
+
+if LooseVersion(scipy.__version__) < LooseVersion('0.9'):
+ qr_economic = _qr_economic_old
+else:
+ qr_economic = _qr_economic_new
def savemat(file_name, mdict, oned_as="column", **kwargs):
@@ -372,7 +351,8 @@ def _firwin2(numtaps, freq, gain, nfreqs=None, window='hamming', nyq=1.0):
if nfreqs is not None and numtaps >= nfreqs:
raise ValueError('ntaps must be less than nfreqs, but firwin2 was '
- 'called with ntaps=%d and nfreqs=%s' % (numtaps, nfreqs))
+ 'called with ntaps=%d and nfreqs=%s'
+ % (numtaps, nfreqs))
if freq[0] != 0 or freq[-1] != nyq:
raise ValueError('freq must start with 0 and end with `nyq`.')
@@ -385,7 +365,7 @@ def _firwin2(numtaps, freq, gain, nfreqs=None, window='hamming', nyq=1.0):
if numtaps % 2 == 0 and gain[-1] != 0.0:
raise ValueError("A filter with an even number of coefficients must "
- "have zero gain at the Nyquist rate.")
+ "have zero gain at the Nyquist rate.")
if nfreqs is None:
nfreqs = 1 + 2 ** int(ceil(log(numtaps, 2)))
@@ -539,9 +519,8 @@ copy_reg.pickle(partial, _reduce_partial)
def normalize_colors(vmin, vmax, clip=False):
"""Helper to handle matplotlib API"""
- import matplotlib.pyplot as plt
+ import matplotlib.pyplot as plt
if 'Normalize' in vars(plt):
return plt.Normalize(vmin, vmax, clip=clip)
else:
return plt.normalize(vmin, vmax, clip=clip)
-
diff --git a/mne/forward/forward.py b/mne/forward/forward.py
index df480df..0ceb791 100644
--- a/mne/forward/forward.py
+++ b/mne/forward/forward.py
@@ -1373,21 +1373,24 @@ def do_forward_solution(subject, meas, fname=None, src=None, spacing=None,
raise ValueError('subject must be a string')
# check for meas to exist as string, or try to make evoked
- if not isinstance(meas, basestring):
- # See if we need to make a meas file
- if isinstance(meas, Raw):
- events = make_fixed_length_events(meas, 1)[0][np.newaxis, :]
- meas = Epochs(meas, events, 1, 0, 1, proj=False)
- if isinstance(meas, Epochs):
- meas = meas.average()
- if isinstance(meas, Evoked):
- meas_data = meas
- meas = op.join(temp_dir, 'evoked.fif')
- write_evoked(meas, meas_data)
- if not isinstance(meas, basestring):
- raise ValueError('meas must be string, Raw, Epochs, or Evoked')
- if not op.isfile(meas):
- raise IOError('measurement file "%s" could not be found' % meas)
+ meas_data = None
+ if isinstance(meas, basestring):
+ if not op.isfile(meas):
+ raise IOError('measurement file "%s" could not be found' % meas)
+ elif isinstance(meas, Raw):
+ events = np.array([[0, 0, 1]], dtype=np.int)
+ end = 1. / meas.info['sfreq']
+ meas_data = Epochs(meas, events, 1, 0, end, proj=False).average()
+ elif isinstance(meas, Epochs):
+ meas_data = meas.average()
+ elif isinstance(meas, Evoked):
+ meas_data = meas
+ else:
+ raise ValueError('meas must be string, Raw, Epochs, or Evoked')
+
+ if meas_data is not None:
+ meas = op.join(temp_dir, 'evoked.fif')
+ write_evoked(meas, meas_data)
# deal with trans/mri
if mri is not None and trans is not None:
diff --git a/mne/forward/tests/test_make_forward.py b/mne/forward/tests/test_make_forward.py
index 116a141..7095532 100644
--- a/mne/forward/tests/test_make_forward.py
+++ b/mne/forward/tests/test_make_forward.py
@@ -3,12 +3,14 @@ import os.path as op
from subprocess import CalledProcessError
from nose.tools import assert_raises
+import numpy as np
from numpy.testing import (assert_equal, assert_allclose)
from mne.datasets import sample
+from mne.fiff import Raw
from mne.fiff.kit import read_raw_kit
from mne.fiff.bti import read_raw_bti
-from mne import (read_forward_solution, make_forward_solution,
+from mne import (Epochs, read_forward_solution, make_forward_solution,
do_forward_solution, setup_source_space, read_trans,
convert_forward_solution)
from mne.utils import requires_mne, _TempDir
@@ -145,6 +147,18 @@ def test_make_forward_solution_kit():
eeg=False, meg=True, subjects_dir=subjects_dir)
_compare_forwards(fwd, fwd_py, 274, 108)
+ # CTF with compensation changed in python
+ ctf_raw = Raw(fname_ctf_raw, compensation=2)
+
+ fwd_py = make_forward_solution(ctf_raw.info, mindist=0.0,
+ src=src, eeg=False, meg=True,
+ bem=fname_bem, mri=fname_mri)
+
+ fwd = do_forward_solution('sample', ctf_raw, src=fname_src,
+ mindist=0.0, bem=fname_bem, mri=fname_mri,
+ eeg=False, meg=True, subjects_dir=subjects_dir)
+ _compare_forwards(fwd, fwd_py, 274, 108)
+
@sample.requires_sample_data
def test_make_forward_solution():
diff --git a/mne/realtime/epochs.py b/mne/realtime/epochs.py
index c24725c..b356c06 100644
--- a/mne/realtime/epochs.py
+++ b/mne/realtime/epochs.py
@@ -56,8 +56,6 @@ class RtEpochs(_BaseEpochs):
are requested and the receive queue is empty.
name : string
Comment that describes the Evoked data created.
- keep_comp : boolean
- Apply CTF gradient compensation.
baseline : None (default) or tuple of length 2
The time interval to apply baseline correction.
If None do not apply it. If baseline is (a, b)
@@ -126,10 +124,9 @@ class RtEpochs(_BaseEpochs):
@verbose
def __init__(self, client, event_id, tmin, tmax, stim_channel='STI 014',
sleep_time=0.1, baseline=(None, 0), picks=None,
- name='Unknown', keep_comp=None, dest_comp=None, reject=None,
- flat=None, proj=True, decim=1, reject_tmin=None,
- reject_tmax=None, detrend=None, add_eeg_ref=True,
- isi_max=2., verbose=None):
+ name='Unknown', reject=None, flat=None, proj=True,
+ decim=1, reject_tmin=None, reject_tmax=None, detrend=None,
+ add_eeg_ref=True, isi_max=2., verbose=None):
info = client.get_measurement_info()
@@ -140,10 +137,10 @@ class RtEpochs(_BaseEpochs):
# call _BaseEpochs constructor
super(RtEpochs, self).__init__(info, event_id, tmin, tmax,
- baseline=baseline, picks=picks, name=name, keep_comp=keep_comp,
- dest_comp=dest_comp, reject=reject, flat=flat,
- decim=decim, reject_tmin=reject_tmin, reject_tmax=reject_tmax,
- detrend=detrend, add_eeg_ref=add_eeg_ref, verbose=verbose)
+ baseline=baseline, picks=picks, name=name, reject=reject,
+ flat=flat, decim=decim, reject_tmin=reject_tmin,
+ reject_tmax=reject_tmax, detrend=detrend,
+ add_eeg_ref=add_eeg_ref, verbose=verbose)
self.proj = proj
self._projector, self.info = setup_proj(self.info, add_eeg_ref,
diff --git a/mne/source_estimate.py b/mne/source_estimate.py
index 39c9ec7..3e2a631 100644
--- a/mne/source_estimate.py
+++ b/mne/source_estimate.py
@@ -345,48 +345,6 @@ def _make_stc(data, vertices, tmin=None, tstep=None, subject=None):
return stc
-class _NotifyArray(np.ndarray):
- """Array class that executes a callback when it is modified
- """
- def __new__(cls, input_array, modify_callback=None):
- obj = np.asarray(input_array).view(cls)
- obj.modify_callback = modify_callback
- return obj
-
- def __array_finalize__(self, obj):
- if obj is None:
- # an empty constructor was used
- return
-
- # try to copy the callback
- self.modify_callback = getattr(obj, 'modify_callback', None)
-
- def _modified_(self):
- """Execute the callback if it is set"""
- if self.modify_callback is not None:
- self.modify_callback()
-
- def __getattribute__(self, name):
- # catch ndarray methods that modify the array inplace
- if name in ['fill', 'itemset', 'resize', 'sort']:
- self._modified_()
-
- return object.__getattribute__(self, name)
-
- def __setitem__(self, item, value):
- self._modified_()
- np.ndarray.__setitem__(self, item, value)
-
- def __array_wrap__(self, out_arr, context=None):
- # this method is called whenever a numpy ufunc (+, +=..) is called
- # the last entry in context is the array that receives the result
- if (context is not None and len(context[1]) == 3
- and context[1][2] is self):
- self._modified_()
-
- return np.ndarray.__array_wrap__(self, out_arr, context)
-
-
def _verify_source_estimate_compat(a, b):
"""Make sure two SourceEstimates are compatible for arith. operations"""
compat = False
@@ -481,20 +439,17 @@ class _BaseSourceEstimate(object):
self.verbose = verbose
self._kernel = kernel
self._sens_data = sens_data
+ self._kernel_removed = False
self.times = None
self._update_times()
self.subject = _check_subject(None, subject, False)
def _remove_kernel_sens_data_(self):
- """Remove kernel and sensor space data
-
- Note: self._data is also computed if it is None
+ """Remove kernel and sensor space data and compute self._data
"""
if self._kernel is not None or self._sens_data is not None:
- # we can no longer use the kernel and sens_data
- logger.info('STC data modified: removing kernel and sensor data')
- if self._data is None:
- self._data = np.dot(self._kernel, self._sens_data)
+ self._kernel_removed = True
+ self._data = np.dot(self._kernel, self._sens_data)
self._kernel = None
self._sens_data = None
@@ -517,7 +472,6 @@ class _BaseSourceEstimate(object):
if self._kernel is not None and self._sens_data is not None:
self._sens_data = self._sens_data[:, mask]
- self._data = None # will be recomputed when data is accessed
else:
self._data = self._data[:, mask]
@@ -564,11 +518,9 @@ class _BaseSourceEstimate(object):
@property
def data(self):
if self._data is None:
- # compute the solution the first time the data is accessed
- # return a "notify array", so we can later remove the kernel
- # and sensor data if the user modifies self._data
- self._data = _NotifyArray(np.dot(self._kernel, self._sens_data),
- modify_callback=self._remove_kernel_sens_data_)
+ # compute the solution the first time the data is accessed and
+ # remove the kernel and sensor data
+ self._remove_kernel_sens_data_()
return self._data
@property
@@ -788,6 +740,10 @@ class _BaseSourceEstimate(object):
fun_args = tuple()
if self._kernel is None and self._sens_data is None:
+ if self._kernel_removed:
+ warnings.warn('Performance can be improved by not accessing '
+ 'the data attribute before calling this method.')
+
# transform source space data directly
data_t = transform_fun(self.data[idx, tmin_idx:tmax_idx],
*fun_args, **kwargs)
diff --git a/mne/tests/test_fixes.py b/mne/tests/test_fixes.py
index 8ef53ce..b39f2ed 100644
--- a/mne/tests/test_fixes.py
+++ b/mne/tests/test_fixes.py
@@ -7,13 +7,68 @@ import numpy as np
from nose.tools import assert_equal
from numpy.testing import assert_array_equal
+from distutils.version import LooseVersion
from scipy import signal
-from ..fixes import _in1d, _tril_indices, _copysign, _unravel_index
+from ..fixes import (_in1d, _tril_indices, _copysign, _unravel_index,
+ _Counter, _unique, _bincount)
from ..fixes import _firwin2 as mne_firwin2
from ..fixes import _filtfilt as mne_filtfilt
+def test_counter():
+ """Test Counter replacement"""
+ import collections
+ try:
+ Counter = collections.Counter
+ except:
+ pass
+ else:
+ a = Counter([1, 2, 1, 3])
+ b = _Counter([1, 2, 1, 3])
+ for key, count in zip([1, 2, 3], [2, 1, 1]):
+ assert_equal(a[key], b[key])
+
+
+def test_unique():
+ """Test unique() replacement
+ """
+ # skip test for np version < 1.5
+ if LooseVersion(np.__version__) < LooseVersion('1.5'):
+ return
+ for arr in [np.array([]), np.random.rand(10), np.ones(10)]:
+ # basic
+ assert_array_equal(np.unique(arr), _unique(arr))
+ # with return_index=True
+ x1, x2 = np.unique(arr, return_index=True, return_inverse=False)
+ y1, y2 = _unique(arr, return_index=True, return_inverse=False)
+ assert_array_equal(x1, y1)
+ assert_array_equal(x2, y2)
+ # with return_inverse=True
+ x1, x2 = np.unique(arr, return_index=False, return_inverse=True)
+ y1, y2 = _unique(arr, return_index=False, return_inverse=True)
+ assert_array_equal(x1, y1)
+ assert_array_equal(x2, y2)
+ # with both:
+ x1, x2, x3 = np.unique(arr, return_index=True, return_inverse=True)
+ y1, y2, y3 = _unique(arr, return_index=True, return_inverse=True)
+ assert_array_equal(x1, y1)
+ assert_array_equal(x2, y2)
+ assert_array_equal(x3, y3)
+
+
+def test_bincount():
+ """Test bincount() replacement
+ """
+ # skip test for np version < 1.6
+ if LooseVersion(np.__version__) < LooseVersion('1.6'):
+ return
+ for minlength in [None, 100]:
+ x = _bincount(np.ones(10, int), None, minlength)
+ y = np.bincount(np.ones(10, int), None, minlength)
+ assert_array_equal(x, y)
+
+
def test_in1d():
"""Test numpy.in1d() replacement"""
a = np.arange(10)
@@ -40,12 +95,12 @@ def test_tril_indices():
def test_unravel_index():
"""Test numpy.unravel_index() replacement"""
assert_equal(_unravel_index(2, (2, 3)), (0, 2))
- assert_equal(_unravel_index(2,(2,2)), (1,0))
- assert_equal(_unravel_index(254,(17,94)), (2,66))
- assert_equal(_unravel_index((2*3 + 1)*6 + 4, (4,3,6)), (2,1,4))
- assert_array_equal(_unravel_index(np.array([22, 41, 37]), (7,6)),
- [[3, 6, 6],[4, 5, 1]])
- assert_array_equal(_unravel_index(1621, (6,7,8,9)), (3,1,4,1))
+ assert_equal(_unravel_index(2, (2, 2)), (1, 0))
+ assert_equal(_unravel_index(254, (17, 94)), (2, 66))
+ assert_equal(_unravel_index((2 * 3 + 1) * 6 + 4, (4, 3, 6)), (2, 1, 4))
+ assert_array_equal(_unravel_index(np.array([22, 41, 37]), (7, 6)),
+ [[3, 6, 6], [4, 5, 1]])
+ assert_array_equal(_unravel_index(1621, (6, 7, 8, 9)), (3, 1, 4, 1))
def test_copysign():
@@ -64,6 +119,7 @@ def test_firwin2():
taps2 = signal.firwin2(150, [0.0, 0.5, 1.0], [1.0, 1.0, 0.0])
assert_array_equal(taps1, taps2)
+
def test_filtfilt():
"""Test IIR filtfilt replacement
"""
diff --git a/mne/tests/test_source_estimate.py b/mne/tests/test_source_estimate.py
index 18f866a..3aac3ae 100644
--- a/mne/tests/test_source_estimate.py
+++ b/mne/tests/test_source_estimate.py
@@ -453,31 +453,6 @@ def test_transform():
assert_array_equal(stc.data, data_t)
-def test_notify_array_source_estimate():
- """Test that modifying the stc data removes the kernel and sensor data"""
- # make up some data
- n_sensors, n_vertices, n_times = 10, 20, 4
- kernel = np.random.randn(n_vertices, n_sensors)
- sens_data = np.random.randn(n_sensors, n_times)
- vertices = np.arange(n_vertices)
-
- stc = VolSourceEstimate((kernel, sens_data), vertices=vertices,
- tmin=0., tstep=1.)
-
- assert_true(stc._data is None)
- assert_true(stc._kernel is not None)
- assert_true(stc._sens_data is not None)
-
- # now modify the data in some way
- data_half = stc.data[:, n_times / 2:]
- data_half[0] = 1.0
- data_half.fill(1.0)
-
- # the kernel and sensor data can no longer be used: they have been removed
- assert_true(stc._kernel is None)
- assert_true(stc._sens_data is None)
-
-
@requires_sklearn
def test_spatio_temporal_tris_connectivity():
"""Test spatio-temporal connectivity from triangles"""
diff --git a/mne/tests/test_utils.py b/mne/tests/test_utils.py
index 8e646bb..cad5074 100644
--- a/mne/tests/test_utils.py
+++ b/mne/tests/test_utils.py
@@ -8,7 +8,8 @@ import urllib2
from ..utils import (set_log_level, set_log_file, _TempDir,
get_config, set_config, deprecated, _fetch_file,
- sum_squared, requires_mem_gb)
+ sum_squared, requires_mem_gb, estimate_rank,
+ _url_to_local_path, sizeof_fmt)
from ..fiff import Evoked, show_fiff
warnings.simplefilter('always') # enable b/c these tests throw warnings
@@ -27,6 +28,21 @@ def clean_lines(lines):
return [l if 'Reading ' not in l else 'Reading test file' for l in lines]
+def test_tempdir():
+ """Test TempDir
+ """
+ tempdir2 = _TempDir()
+ assert_true(op.isdir(tempdir2))
+ tempdir2.cleanup()
+ assert_true(not op.isdir(tempdir2))
+
+
+def test_estimate_rank():
+ data = np.eye(10)
+ data[0, 0] = 0
+ assert_equal(estimate_rank(data), 9)
+
+
def test_logging():
"""Test logging (to file)
"""
@@ -135,6 +151,12 @@ def deprecated_func():
pass
+ at deprecated('message')
+class deprecated_class(object):
+ def __init__(self):
+ pass
+
+
@requires_mem_gb(10000)
def big_mem_func():
pass
@@ -151,6 +173,9 @@ def test_deprecated():
with warnings.catch_warnings(True) as w:
deprecated_func()
assert_true(len(w) == 1)
+ with warnings.catch_warnings(True) as w:
+ deprecated_class()
+ assert_true(len(w) == 1)
def test_requires_mem_gb():
@@ -183,9 +208,19 @@ def test_fetch_file():
except urllib2.URLError:
from nose.plugins.skip import SkipTest
raise SkipTest('No internet connection, skipping download test.')
- url = "http://github.com/mne-tools/mne-python/blob/master/README.rst"
- archive_name = op.join(tempdir, "download_test")
- _fetch_file(url, archive_name, print_destination=False)
+
+ urls = ['http://github.com/mne-tools/mne-python/blob/master/README.rst',
+ 'ftp://surfer.nmr.mgh.harvard.edu/pub/data/bert.recon.md5sum.txt']
+ for url in urls:
+ archive_name = op.join(tempdir, "download_test")
+ _fetch_file(url, archive_name, print_destination=False)
+ assert_raises(Exception, _fetch_file, 'http://0.0',
+ op.join(tempdir, 'test'))
+ resume_name = op.join(tempdir, "download_resume")
+ # touch file
+ with file(resume_name + '.part', 'w'):
+ os.utime(resume_name + '.part', None)
+ _fetch_file(url, resume_name, print_destination=False, resume=True)
def test_sum_squared():
@@ -193,3 +228,18 @@ def test_sum_squared():
"""
X = np.random.randint(0, 50, (3, 3))
assert_equal(np.sum(X ** 2), sum_squared(X))
+
+
+def test_sizeof_fmt():
+ """Test sizeof_fmt
+ """
+ assert_equal(sizeof_fmt(0), '0 bytes')
+ assert_equal(sizeof_fmt(1), '1 byte')
+ assert_equal(sizeof_fmt(1000), '1000 bytes')
+
+
+def test_url_to_local_path():
+ """Test URL to local path
+ """
+ assert_equal(_url_to_local_path('http://google.com/home/why.html', '.'),
+ op.join('.', 'home', 'why.html'))
diff --git a/mne/tests/test_viz.py b/mne/tests/test_viz.py
index 99d7450..583c398 100644
--- a/mne/tests/test_viz.py
+++ b/mne/tests/test_viz.py
@@ -13,7 +13,7 @@ from mne.viz import (plot_topo, plot_topo_tfr, plot_topo_power,
plot_sparse_source_estimates, plot_source_estimates,
plot_cov, mne_analyze_colormap, plot_image_epochs,
plot_connectivity_circle, circular_layout, plot_drop_log,
- compare_fiff)
+ compare_fiff, plot_source_spectrogram)
from mne.datasets import sample
from mne.source_space import read_source_spaces
from mne.preprocessing import ICA
@@ -193,6 +193,14 @@ def test_plot_epochs():
epochs = _get_epochs()
epochs.plot([0, 1], picks=[0, 2, 3], scalings=None, title_str='%s')
epochs[0].plot(picks=[0, 2, 3], scalings=None, title_str='%s')
+ # test clicking: should increase coverage on
+ # 3200-3226, 3235, 3237, 3239-3242, 3245-3255, 3260-3280
+ fig = plt.gcf()
+ fig.canvas.button_press_event(10, 10, 'left')
+ # now let's add a bad channel
+ epochs.info['bads'] = [epochs.ch_names[0]] # include a bad one
+ epochs.plot([0, 1], picks=[0, 2, 3], scalings=None, title_str='%s')
+ epochs[0].plot(picks=[0, 2, 3], scalings=None, title_str='%s')
plt.close('all')
@@ -352,7 +360,17 @@ def test_plot_raw():
"""
raw = _get_raw()
events = _get_events()
- raw.plot(events=events, show_options=True)
+ fig = raw.plot(events=events, show_options=True)
+ # test mouse clicks (XXX not complete yet)
+ fig.canvas.button_press_event(0.5, 0.5, 1)
+ # test keypresses
+ fig.canvas.key_press_event('escape')
+ fig.canvas.key_press_event('down')
+ fig.canvas.key_press_event('up')
+ fig.canvas.key_press_event('right')
+ fig.canvas.key_press_event('left')
+ fig.canvas.key_press_event('o')
+ fig.canvas.key_press_event('escape')
plt.close('all')
@@ -396,8 +414,8 @@ def test_plot_topomap():
assert_raises(RuntimeError, plot_evoked_topomap, evoked, np.repeat(.1, 50))
assert_raises(ValueError, plot_evoked_topomap, evoked, [-3e12, 15e6])
- # projs
- projs = read_proj(ecg_fname)[:7]
+ projs = read_proj(ecg_fname)
+ projs = [p for p in projs if p['desc'].lower().find('eeg') < 0]
plot_projs_topomap(projs)
plt.close('all')
@@ -424,3 +442,20 @@ def test_plot_ica_topomap():
ica.info = None
assert_raises(RuntimeError, ica.plot_topomap, 1)
plt.close('all')
+
+
+ at sample.requires_sample_data
+def test_plot_source_spectrogram():
+ """Test plotting of source spectrogram
+ """
+ sample_src = read_source_spaces(op.join(data_dir, 'subjects', 'sample',
+ 'bem', 'sample-oct-6-src.fif'))
+
+ # dense version
+ vertices = [s['vertno'] for s in sample_src]
+ n_time = 5
+ n_verts = sum(len(v) for v in vertices)
+ stc_data = np.ones((n_verts, n_time))
+ stc = SourceEstimate(stc_data, vertices, 1, 1)
+ plot_source_spectrogram([stc, stc], [[1, 2], [3, 4]])
+ assert_raises(ValueError, plot_source_spectrogram, [], [])
diff --git a/mne/utils.py b/mne/utils.py
index 69583ee..04e053c 100644
--- a/mne/utils.py
+++ b/mne/utils.py
@@ -108,13 +108,10 @@ class _TempDir(str):
We cannot simply use __del__() method for cleanup here because the rmtree
function may be cleaned up before this object, so we use the atexit module
- instead. Passing del_after and print_del kwargs to the constructor are
- helpful primarily for debugging purposes.
+ instead.
"""
- def __new__(self, del_after=True, print_del=False):
+ def __new__(self):
new = str.__new__(self, tempfile.mkdtemp())
- self._del_after = del_after
- self._print_del = print_del
return new
def __init__(self):
@@ -122,10 +119,7 @@ class _TempDir(str):
atexit.register(self.cleanup)
def cleanup(self):
- if self._del_after is True:
- if self._print_del is True:
- print 'Deleting %s ...' % self._path
- rmtree(self._path, ignore_errors=True)
+ rmtree(self._path, ignore_errors=True)
def estimate_rank(data, tol=1e-4, return_singular=False,
@@ -212,14 +206,14 @@ def run_subprocess(command, *args, **kwargs):
logger.info("Running subprocess: %s" % str(command))
p = subprocess.Popen(command, *args, **kwargs)
- stdout, stderr = p.communicate()
+ stdout_, stderr = p.communicate()
- if stdout.strip():
- logger.info("stdout:\n%s" % stdout)
+ if stdout_.strip():
+ logger.info("stdout:\n%s" % stdout_)
if stderr.strip():
logger.info("stderr:\n%s" % stderr)
- output = (stdout, stderr)
+ output = (stdout_, stderr)
if p.returncode:
print output
raise subprocess.CalledProcessError(p.returncode, command, output)
@@ -498,7 +492,7 @@ def requires_tvtk(function):
def dec(*args, **kwargs):
skip = False
try:
- from tvtk.api import tvtk
+ from tvtk.api import tvtk # analysis:ignore
except ImportError:
skip = True
diff --git a/mne/viz.py b/mne/viz.py
index 2e4e339..b290ff4 100644
--- a/mne/viz.py
+++ b/mne/viz.py
@@ -1167,7 +1167,7 @@ def plot_evoked(evoked, picks=None, exclude='bads', unit=True, show=True,
ch_types_used.append(t)
axes_init = axes # remember if axes where given as input
-
+
fig = None
if axes is None:
fig, axes = plt.subplots(n_channel_types, 1)
@@ -1179,7 +1179,7 @@ def plot_evoked(evoked, picks=None, exclude='bads', unit=True, show=True,
if axes_init is not None:
fig = axes[0].get_figure()
-
+
if not len(axes) == n_channel_types:
raise ValueError('Number of axes (%g) must match number of channel '
'types (%g)' % (len(axes), n_channel_types))
@@ -2697,7 +2697,7 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=None,
if show:
plt.show(block=block)
-
+
return fig
@@ -3183,17 +3183,6 @@ def _prepare_trellis(n_cells, max_col):
return fig, axes
-def _plot_epochs_get_data(epochs, epoch_idx, n_channels, times, picks,
- scalings, types):
- """Aux function
- """
- data = np.zeros((len(epoch_idx), n_channels, len(times)))
- for ii, epoch in enumerate(epochs.get_data()[epoch_idx][:, picks]):
- for jj, (this_type, this_channel) in enumerate(zip(types, epoch)):
- data[ii, jj] = this_channel / scalings[this_type]
- return data
-
-
def _draw_epochs_axes(epoch_idx, good_ch_idx, bad_ch_idx, data, times, axes,
title_str, axes_handler):
"""Aux functioin"""
@@ -3245,10 +3234,8 @@ def _epochs_navigation_onclick(event, params):
p['idx_handler'].rotate(here)
p['axes_handler'].rotate(here)
this_idx = p['idx_handler'][0]
- data = _plot_epochs_get_data(p['epochs'], this_idx, p['n_channels'],
- p['times'], p['picks'], p['scalings'],
- p['types'])
- _draw_epochs_axes(this_idx, p['good_ch_idx'], p['bad_ch_idx'], data,
+ _draw_epochs_axes(this_idx, p['good_ch_idx'], p['bad_ch_idx'],
+ p['data'][this_idx],
p['times'], p['axes'], p['title_str'],
p['axes_handler'])
# XXX don't ask me why
@@ -3338,13 +3325,16 @@ def plot_epochs(epochs, epoch_idx=None, picks=None, scalings=None,
raise RuntimeError('No appropriate channels found. Please'
' check your picks')
times = epochs.times * 1e3
- n_channels = len(picks)
+ n_channels = epochs.info['nchan']
types = [channel_type(epochs.info, idx) for idx in
picks]
# preallocation needed for min / max scaling
- data = _plot_epochs_get_data(epochs, idx_handler[0], n_channels,
- times, picks, scalings, types)
+ data = np.zeros((len(epochs.events), n_channels, len(times)))
+ for ii, epoch in enumerate(epochs.get_data()):
+ for jj, (this_type, this_channel) in enumerate(zip(types, epoch)):
+ data[ii, jj] = this_channel / scalings[this_type]
+
n_events = len(epochs.events)
epoch_idx = epoch_idx[:n_events]
idx_handler = deque(create_chunks(epoch_idx, 20))
@@ -3359,9 +3349,9 @@ def plot_epochs(epochs, epoch_idx=None, picks=None, scalings=None,
else:
good_ch_idx = np.arange(n_channels)
- fig, axes = _prepare_trellis(len(data), max_col=5)
+ fig, axes = _prepare_trellis(len(data[idx_handler[0]]), max_col=5)
axes_handler = deque(range(len(idx_handler)))
- for ii, data_, ax in zip(idx_handler[0], data, axes):
+ for ii, data_, ax in zip(idx_handler[0], data[idx_handler[0]], axes):
ax.plot(times, data_[good_ch_idx].T, color='k')
if bad_ch_idx is not None:
ax.plot(times, data_[bad_ch_idx].T, color='r')
@@ -3389,11 +3379,9 @@ def plot_epochs(epochs, epoch_idx=None, picks=None, scalings=None,
'fig': fig,
'idx_handler': idx_handler,
'epochs': epochs,
- 'n_channels': n_channels,
'picks': picks,
'times': times,
'scalings': scalings,
- 'types': types,
'good_ch_idx': good_ch_idx,
'bad_ch_idx': bad_ch_idx,
'axes': axes,
@@ -3402,7 +3390,8 @@ def plot_epochs(epochs, epoch_idx=None, picks=None, scalings=None,
'reject-quit': mpl.widgets.Button(ax3, 'reject-quit'),
'title_str': title_str,
'reject_idx': [],
- 'axes_handler': axes_handler
+ 'axes_handler': axes_handler,
+ 'data': data
}
fig.canvas.mpl_connect('button_press_event',
partial(_epochs_axes_onclick, params=params))
@@ -3436,6 +3425,8 @@ def plot_source_spectrogram(stcs, freq_bins, source_index=None, colorbar=False,
import matplotlib.pyplot as plt
# Gathering results for each time window
+ if len(stcs) == 0:
+ raise ValueError('cannot plot spectrogram if len(stcs) == 0')
source_power = np.array([stc.data for stc in stcs])
# Finding the source with maximum source power
@@ -3454,7 +3445,7 @@ def plot_source_spectrogram(stcs, freq_bins, source_index=None, colorbar=False,
gap_bounds = []
for i in range(len(freq_bins) - 1):
lower_bound = freq_bins[i][1]
- upper_bound = freq_bins[i+1][0]
+ upper_bound = freq_bins[i + 1][0]
if lower_bound != upper_bound:
freq_bounds.remove(lower_bound)
gap_bounds.append((lower_bound, upper_bound))
--
Alioth's /git/debian-med/git-commit-notice on /srv/git.debian.org/git/debian-med/mne-python.git
More information about the debian-med-commit
mailing list