[med-svn] [python-mne] 02/15: New upstream version 0.13+dfsg
Yaroslav Halchenko
debian at onerussian.com
Fri Nov 18 22:17:29 UTC 2016
This is an automated email from the git hooks/post-receive script.
yoh pushed a commit to branch master
in repository python-mne.
commit 8b4f854e092b32acb230b80bdb5ea7ba2a080df2
Author: Yaroslav Halchenko <debian at onerussian.com>
Date: Wed Sep 28 20:36:58 2016 -0400
New upstream version 0.13+dfsg
---
.coveragerc | 1 +
.mailmap | 9 +
.travis.yml | 39 +-
Makefile | 5 +-
README.rst | 112 +-
appveyor.yml | 30 +-
circle.yml | 88 +-
codecov.yml | 14 +
doc/Makefile | 7 +-
doc/_static/style.css | 3 +
doc/_templates/class.rst | 2 +-
doc/cite.rst | 2 +-
doc/conf.py | 17 +-
doc/contributing.rst | 370 +++--
doc/faq.rst | 3 +-
doc/getting_started.rst | 2 +-
doc/index.rst | 15 +-
doc/manual/appendix/bem_model.rst | 10 +-
doc/manual/cookbook.rst | 2 -
doc/manual/decoding.rst | 4 +
doc/manual/io.rst | 24 +-
doc/manual/pitfalls.rst | 13 -
doc/manual/source_localization/inverse.rst | 2 +-
doc/python_reference.rst | 34 +-
doc/sphinxext/gen_commands.py | 13 +-
doc/sphinxext/numpy_ext/__init__.py | 0
doc/sphinxext/numpy_ext/docscrape.py | 512 ------
doc/sphinxext/numpy_ext/docscrape_sphinx.py | 240 ---
doc/sphinxext/numpy_ext/numpydoc.py | 192 ---
doc/tutorials.rst | 18 +-
doc/whats_new.rst | 271 +++-
.../plot_mne_inverse_coherence_epochs.py | 6 +-
.../plot_decoding_spatio_temporal_source.py | 5 +-
...plot_decoding_time_generalization_conditions.py | 8 +-
.../plot_decoding_unsupervised_spatial_filter.py | 67 +
examples/decoding/plot_decoding_xdawn_eeg.py | 4 +-
examples/decoding/plot_ems_filtering.py | 104 +-
examples/inverse/plot_dics_beamformer.py | 14 +-
examples/inverse/plot_dics_source_power.py | 12 +-
examples/inverse/plot_lcmv_beamformer.py | 4 +-
examples/inverse/plot_morph_data.py | 1 -
examples/inverse/plot_tf_dics.py | 10 +-
.../plot_time_frequency_mixed_norm_inverse.py | 7 +-
examples/io/plot_elekta_epochs.py | 68 +
examples/preprocessing/plot_run_ica.py | 20 +-
examples/preprocessing/plot_xdawn_denoising.py | 5 +-
examples/realtime/ftclient_rt_average.py | 1 +
examples/realtime/plot_compute_rt_decoder.py | 4 +-
examples/realtime/rt_feedback_server.py | 6 +-
examples/time_frequency/README.txt | 3 +-
.../plot_compute_raw_data_spectrum.py | 2 +-
.../plot_source_label_time_frequency.py | 10 +-
make/install_python.ps1 | 93 --
mne/__init__.py | 11 +-
mne/annotations.py | 55 +-
mne/baseline.py | 38 +-
mne/beamformer/_dics.py | 16 +-
mne/beamformer/_lcmv.py | 2 +-
mne/beamformer/tests/test_dics.py | 48 +-
mne/beamformer/tests/test_lcmv.py | 18 +-
mne/bem.py | 146 +-
mne/channels/channels.py | 257 ++-
mne/channels/data/layouts/KIT-UMD-3.lout | 158 ++
mne/channels/data/neighbors/KIT-UMD-1_neighb.mat | Bin 0 -> 4750 bytes
mne/channels/data/neighbors/KIT-UMD-2_neighb.mat | Bin 0 -> 4832 bytes
mne/channels/data/neighbors/KIT-UMD-3_neighb.mat | Bin 0 -> 4794 bytes
mne/channels/data/neighbors/__init__.py | 3 +
mne/channels/interpolation.py | 7 +-
mne/channels/layout.py | 78 +-
mne/channels/montage.py | 188 ++-
mne/channels/tests/test_channels.py | 34 +-
mne/channels/tests/test_interpolation.py | 10 +-
mne/channels/tests/test_layout.py | 81 +-
mne/channels/tests/test_montage.py | 103 +-
mne/chpi.py | 149 +-
mne/commands/mne_browse_raw.py | 10 +
mne/commands/mne_coreg.py | 20 +-
mne/commands/mne_flash_bem.py | 6 +-
mne/commands/mne_kit2fiff.py | 6 +-
mne/commands/mne_make_scalp_surfaces.py | 20 +-
mne/commands/mne_show_fiff.py | 6 +
mne/commands/tests/test_commands.py | 44 +-
mne/connectivity/spectral.py | 13 +-
mne/connectivity/tests/test_spectral.py | 51 +-
mne/coreg.py | 168 +-
mne/cov.py | 221 ++-
mne/cuda.py | 9 +-
mne/data/coil_def.dat | 17 +-
mne/data/mne_analyze.sel | 8 +-
mne/datasets/__init__.py | 1 +
mne/datasets/brainstorm/__init__.py | 3 +-
mne/datasets/brainstorm/bst_auditory.py | 3 +-
.../{bst_resting.py => bst_phantom_ctf.py} | 19 +-
.../{bst_resting.py => bst_phantom_elekta.py} | 20 +-
mne/datasets/brainstorm/bst_raw.py | 3 +-
mne/datasets/brainstorm/bst_resting.py | 3 +-
mne/datasets/megsim/urls.py | 2 +-
mne/datasets/multimodal/__init__.py | 4 +
.../{somato/somato.py => multimodal/multimodal.py} | 15 +-
mne/datasets/sample/sample.py | 3 +-
mne/datasets/somato/somato.py | 3 +-
mne/datasets/spm_face/spm_data.py | 3 +-
mne/datasets/testing/_testing.py | 3 +-
mne/datasets/tests/test_datasets.py | 10 +
mne/datasets/utils.py | 96 +-
mne/decoding/__init__.py | 6 +-
mne/decoding/base.py | 114 +-
mne/decoding/csp.py | 385 +++--
mne/decoding/ems.py | 149 +-
mne/decoding/search_light.py | 629 ++++++++
mne/decoding/tests/test_csp.py | 134 +-
mne/decoding/tests/test_ems.py | 46 +-
mne/decoding/tests/test_search_light.py | 170 ++
mne/decoding/tests/test_time_frequency.py | 41 +
mne/decoding/tests/test_time_gen.py | 25 +-
mne/decoding/tests/test_transformer.py | 141 +-
mne/decoding/time_frequency.py | 152 ++
mne/decoding/time_gen.py | 97 +-
mne/decoding/transformer.py | 363 ++++-
mne/defaults.py | 29 +-
mne/dipole.py | 220 ++-
mne/epochs.py | 1010 +++++-------
mne/event.py | 562 ++++++-
mne/evoked.py | 677 +++-----
mne/externals/h5io/_h5io.py | 186 ++-
mne/externals/tempita/_looper.py | 4 +-
mne/filter.py | 1664 ++++++++++++++------
mne/fixes.py | 1107 ++++---------
mne/forward/__init__.py | 2 +-
mne/forward/_compute_forward.py | 2 +-
mne/forward/_field_interpolation.py | 5 +-
mne/forward/_make_forward.py | 3 +-
mne/forward/forward.py | 116 +-
mne/forward/tests/test_field_interpolation.py | 14 +-
mne/forward/tests/test_forward.py | 5 +-
mne/forward/tests/test_make_forward.py | 5 +-
mne/gui/__init__.py | 2 +-
mne/gui/_coreg_gui.py | 214 ++-
mne/gui/_fiducials_gui.py | 53 +-
mne/gui/_file_traits.py | 97 +-
mne/gui/_kit2fiff_gui.py | 277 +++-
mne/gui/help/kit2fiff.json | 5 +-
mne/gui/tests/test_coreg_gui.py | 73 +-
mne/gui/tests/test_fiducials_gui.py | 2 +-
mne/gui/tests/test_file_traits.py | 6 +-
mne/gui/tests/test_kit2fiff_gui.py | 18 +-
mne/inverse_sparse/mxne_inverse.py | 2 +-
mne/io/__init__.py | 3 +-
mne/io/array/array.py | 2 +-
mne/io/array/tests/test_array.py | 42 +-
mne/io/base.py | 784 +++++----
mne/io/brainvision/brainvision.py | 315 +++-
mne/io/brainvision/tests/data/test.vhdr | 24 +-
.../brainvision/tests/data/test_highpass_hz.vhdr | 103 ++
mne/io/brainvision/tests/data/test_lowpass_s.vhdr | 103 ++
.../data/{test.vhdr => test_mixed_highpass.vhdr} | 103 +-
.../tests/data/test_mixed_highpass_hz.vhdr | 103 ++
.../data/{test.vhdr => test_mixed_lowpass.vhdr} | 103 +-
.../tests/data/test_mixed_lowpass_s.vhdr | 103 ++
.../test_old_layout_latin1_software_filter.eeg | Bin 0 -> 29116 bytes
.../test_old_layout_latin1_software_filter.vhdr | 156 ++
.../test_old_layout_latin1_software_filter.vmrk | 14 +
...vhdr => test_partially_disabled_hw_filter.vhdr} | 63 +-
mne/io/brainvision/tests/test_brainvision.py | 198 ++-
mne/io/bti/bti.py | 11 +-
mne/io/bti/tests/test_bti.py | 98 +-
mne/io/cnt/cnt.py | 88 +-
mne/io/cnt/tests/test_cnt.py | 7 +-
mne/io/compensator.py | 37 +-
mne/io/constants.py | 109 +-
mne/io/ctf/info.py | 20 +
mne/io/ctf/tests/test_ctf.py | 1 +
mne/io/edf/edf.py | 17 +-
mne/io/edf/tests/test_edf.py | 48 +-
mne/io/eeglab/eeglab.py | 96 +-
mne/io/eeglab/tests/test_eeglab.py | 45 +-
mne/io/fiff/raw.py | 162 +-
mne/io/fiff/tests/test_raw_fiff.py | 625 +++++---
mne/io/kit/constants.py | 61 +-
mne/io/kit/coreg.py | 7 +-
mne/io/kit/kit.py | 65 +-
mne/io/kit/tests/data/test.elp | 37 +
mne/io/kit/tests/data/test.hsp | 514 ++++++
mne/io/kit/tests/data/test_umd-raw.sqd | Bin 0 -> 99692 bytes
mne/io/kit/tests/test_coreg.py | 15 +
mne/io/kit/tests/test_kit.py | 83 +-
mne/io/meas_info.py | 131 +-
mne/io/nicolet/nicolet.py | 2 +-
mne/io/open.py | 20 +-
mne/io/pick.py | 45 +-
mne/io/proj.py | 23 +-
mne/io/reference.py | 53 +-
mne/io/tag.py | 2 +-
mne/io/tests/data/test-ave-2.log | 6 +-
mne/io/tests/data/test-ave.log | 3 +-
mne/io/tests/test_compensator.py | 56 +-
mne/io/tests/test_meas_info.py | 83 +-
mne/io/tests/test_pick.py | 29 +-
mne/io/tests/test_proc_history.py | 15 +-
mne/io/tests/test_raw.py | 12 +-
mne/io/tests/test_reference.py | 112 +-
mne/io/write.py | 13 +-
mne/label.py | 79 +-
mne/minimum_norm/tests/test_inverse.py | 34 +-
mne/minimum_norm/tests/test_time_frequency.py | 23 +-
mne/minimum_norm/time_frequency.py | 8 +-
mne/preprocessing/__init__.py | 2 +-
mne/preprocessing/_fine_cal.py | 84 +
mne/preprocessing/ecg.py | 60 +-
mne/preprocessing/eog.py | 15 +-
mne/preprocessing/ica.py | 423 +++--
mne/preprocessing/infomax_.py | 127 +-
mne/preprocessing/maxfilter.py | 4 +-
mne/preprocessing/maxwell.py | 335 ++--
mne/preprocessing/ssp.py | 7 +-
mne/preprocessing/tests/test_ecg.py | 24 +-
mne/preprocessing/tests/test_eeglab_infomax.py | 283 ++--
mne/preprocessing/tests/test_eog.py | 6 +-
mne/preprocessing/tests/test_fine_cal.py | 40 +
mne/preprocessing/tests/test_ica.py | 164 +-
mne/preprocessing/tests/test_infomax.py | 40 +-
mne/preprocessing/tests/test_maxwell.py | 292 ++--
mne/preprocessing/tests/test_ssp.py | 28 +-
mne/preprocessing/tests/test_stim.py | 13 +-
mne/preprocessing/tests/test_xdawn.py | 204 ++-
mne/preprocessing/xdawn.py | 641 +++++---
mne/proj.py | 4 +-
mne/realtime/epochs.py | 15 +-
mne/realtime/fieldtrip_client.py | 2 +
mne/realtime/stim_server_client.py | 8 +-
mne/realtime/tests/test_fieldtrip_client.py | 3 +
mne/realtime/tests/test_mockclient.py | 24 +-
mne/realtime/tests/test_stim_client_server.py | 4 +-
mne/report.py | 38 +-
mne/selection.py | 76 +-
mne/simulation/evoked.py | 25 +-
mne/simulation/raw.py | 8 +-
mne/simulation/source.py | 145 +-
mne/simulation/tests/test_evoked.py | 28 +-
mne/simulation/tests/test_raw.py | 70 +-
mne/simulation/tests/test_source.py | 59 +-
mne/source_estimate.py | 254 ++-
mne/source_space.py | 69 +-
mne/stats/cluster_level.py | 17 +-
mne/stats/parametric.py | 15 +-
mne/stats/regression.py | 3 +-
mne/stats/tests/test_cluster_level.py | 8 +-
mne/stats/tests/test_regression.py | 15 +-
mne/surface.py | 159 +-
mne/tests/test_annotations.py | 43 +-
mne/tests/test_bem.py | 58 +-
mne/tests/test_chpi.py | 109 +-
mne/tests/test_cov.py | 114 +-
mne/tests/test_dipole.py | 64 +-
mne/tests/test_docstring_parameters.py | 26 +-
mne/tests/test_epochs.py | 631 ++++----
mne/tests/test_event.py | 149 +-
mne/tests/test_evoked.py | 171 +-
mne/tests/test_filter.py | 376 +++--
mne/tests/test_fixes.py | 184 +--
mne/tests/test_import_nesting.py | 5 +-
mne/tests/test_label.py | 73 +-
mne/tests/test_line_endings.py | 8 +-
mne/tests/test_proj.py | 41 +-
mne/tests/test_report.py | 39 +-
mne/tests/test_selection.py | 4 +-
mne/tests/test_source_estimate.py | 18 +-
mne/tests/test_source_space.py | 15 +-
mne/tests/test_surface.py | 39 +-
mne/tests/test_transforms.py | 2 +-
mne/tests/test_utils.py | 223 ++-
mne/time_frequency/__init__.py | 10 +-
mne/time_frequency/_stockwell.py | 10 +-
mne/time_frequency/csd.py | 292 +++-
mne/time_frequency/multitaper.py | 60 +-
mne/time_frequency/psd.py | 179 +--
mne/time_frequency/tests/test_ar.py | 8 +-
mne/time_frequency/tests/test_csd.py | 301 +++-
mne/time_frequency/tests/test_psd.py | 171 +-
mne/time_frequency/tests/test_stockwell.py | 51 +-
mne/time_frequency/tests/test_tfr.py | 219 ++-
mne/time_frequency/tfr.py | 1557 ++++++++++--------
mne/transforms.py | 12 +-
mne/utils.py | 479 +++++-
mne/viz/_3d.py | 130 +-
mne/viz/__init__.py | 6 +-
mne/viz/circle.py | 7 +-
mne/viz/epochs.py | 102 +-
mne/viz/evoked.py | 583 ++++++-
mne/viz/ica.py | 367 +++--
mne/viz/misc.py | 116 +-
mne/viz/raw.py | 253 ++-
mne/viz/tests/test_3d.py | 18 +-
mne/viz/tests/test_decoding.py | 8 +-
mne/viz/tests/test_epochs.py | 26 +-
mne/viz/tests/test_evoked.py | 73 +-
mne/viz/tests/test_ica.py | 117 +-
mne/viz/tests/test_misc.py | 31 +-
mne/viz/tests/test_raw.py | 78 +-
mne/viz/tests/test_topo.py | 42 +-
mne/viz/tests/test_topomap.py | 63 +-
mne/viz/tests/test_utils.py | 35 +-
mne/viz/topo.py | 47 +-
mne/viz/topomap.py | 300 +++-
mne/viz/utils.py | 594 ++++++-
setup.cfg | 1 +
setup.py | 1 +
tutorials/plot_artifacts_correction_filtering.py | 69 +-
tutorials/plot_artifacts_correction_ica.py | 136 +-
.../plot_artifacts_correction_maxwell_filtering.py | 2 +-
tutorials/plot_artifacts_correction_rejection.py | 8 +-
tutorials/plot_artifacts_correction_ssp.py | 3 +-
tutorials/plot_background_filtering.py | 951 +++++++++++
tutorials/plot_brainstorm_auditory.py | 37 +-
tutorials/plot_brainstorm_phantom_ctf.py | 112 ++
tutorials/plot_brainstorm_phantom_elekta.py | 106 ++
tutorials/plot_compute_covariance.py | 7 +-
tutorials/plot_dipole_fit.py | 2 +-
tutorials/plot_eeg_erp.py | 30 +-
tutorials/plot_epoching_and_averaging.py | 7 +-
tutorials/plot_epochs_to_data_frame.py | 8 +-
tutorials/plot_forward.py | 18 +-
tutorials/plot_ica_from_raw.py | 5 +-
tutorials/plot_info.py | 2 +-
tutorials/plot_introduction.py | 51 +-
tutorials/plot_mne_dspm_source_localization.py | 29 +-
tutorials/plot_modifying_data_inplace.py | 28 +-
tutorials/plot_object_epochs.py | 7 +-
tutorials/plot_object_raw.py | 44 +-
tutorials/plot_point_spread.py | 171 ++
tutorials/plot_sensors_decoding.py | 20 +-
tutorials/plot_sensors_time_frequency.py | 6 +-
...plot_stats_cluster_1samp_test_time_frequency.py | 73 +-
tutorials/plot_stats_cluster_methods.py | 3 +-
tutorials/plot_stats_cluster_spatio_temporal.py | 6 +-
.../plot_stats_cluster_spatio_temporal_2samp.py | 5 +-
...ster_spatio_temporal_repeated_measures_anova.py | 6 +-
tutorials/plot_stats_cluster_time_frequency.py | 89 +-
...uster_time_frequency_repeated_measures_anova.py | 86 +-
.../plot_stats_spatio_temporal_cluster_sensors.py | 3 +-
tutorials/plot_visualize_epochs.py | 14 +-
tutorials/plot_visualize_evoked.py | 33 +-
tutorials/plot_visualize_raw.py | 54 +-
343 files changed, 22134 insertions(+), 11303 deletions(-)
diff --git a/.coveragerc b/.coveragerc
index 6b9b8a5..5df0314 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -6,3 +6,4 @@ omit =
*/mne/externals/*
*/bin/*
*/setup.py
+ */mne/fixes*
diff --git a/.mailmap b/.mailmap
index c951ce4..cb4759f 100644
--- a/.mailmap
+++ b/.mailmap
@@ -8,6 +8,7 @@ Martin Luessi <mluessi at nmr.mgh.harvard.edu> martin <martin at think.(none)>
Matti Hamalainen <msh at nmr.mgh.harvard.edu> Matti Hamalainen <msh at parsley.nmr.mgh.harvard.edu>
Matti Hamalainen <msh at nmr.mgh.harvard.edu> mshamalainen <msh at nmr.mgh.harvard.edu>
Christian Brodbeck <christianmbrodbeck at gmail.com> christianmbrodbeck <christianmbrodbeck at gmail.com>
+Christian Brodbeck <christianmbrodbeck at gmail.com> Christian Brodbeck <christianbrodbeck at nyu.edu>
Louis Thibault <louist87 at gmail.com> = <louist87 at gmail.com>
Louis Thibault <louist87 at gmail.com> Louis Thibault <louist at ltpc.(none)>
Eric Larson <larson.eric.d at gmail.com> Eric Larson <larson.eric.d at gmail.com>
@@ -53,6 +54,8 @@ Jean-Remi King <jeanremi.kibng+github at gmail.com> kingjr <jeanremi.kibng+github at g
Jean-Remi King <jeanremi.kibng+github at gmail.com> UMR9752 <jeanremi.king+github at gmail.com>
Jean-Remi King <jeanremi.kibng+github at gmail.com> UMR9752 <umr9752 at umr9752-desktop.(none)>
Jean-Remi King <jeanremi.kibng+github at gmail.com> kingjr <jeanremi.king+github at gmail.com>
+Jean-Remi King <jeanremi.kibng+github at gmail.com> Jean-Rémi KING <jeanremi.king at gmail.com>
+Jean-Remi King <jeanremi.kibng+github at gmail.com> kingjr <jeanremi.king at gmail.com>
Roan LaPlante <aestrivex at gmail.com> aestrivex <aestrivex at gmail.com>
Mark Wronkiewicz <wronk.mark at gmail.com> wronk <wronk.mark at gmail.com>
Basile Pinsard <basile.pinsard at umontreal.ca>
@@ -73,3 +76,9 @@ Daniel McCloy <dan.mccloy at gmail.com> drammock <dan.mccloy at gmail.com>
Fede Raimondo <slashack at gmail.com> Fede <slashack at gmail.com>
Emily Stephen <emilyps14 at gmail.com> emilyps14 <emilyps14 at gmail.com>
Marian Dovgialo <mdovgialo at fabrizzio.zfb.fuw.edu.pl>
+Guillaume Dumas <deep at introspection.eu> deep-introspection <deep at introspection.eu>
+Guillaume Dumas <deep at introspection.eu> Guillaume Dumas <deep-introspection at users.noreply.github.com>
+Félix Raimundo <gamaz3ps at gmail.com> Felix Raimundo <gamaz3ps at gmail.com>
+Asish Panda <asishrocks95 at gmail.com> kaichogami <asishrocks95 at gmail.com>
+Mikolaj Magnuski <mmagnuski at swps.edu.pl> mmagnuski <mmagnuski at swps.edu.pl>
+Alexandre Barachant <alexandre.barachant at gmail.com> alexandre barachant <alexandre.barachant at gmail.com>
\ No newline at end of file
diff --git a/.travis.yml b/.travis.yml
index b942197..90a13cb 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -14,9 +14,6 @@ env:
# Note that we don't run coverage on Py3k anyway because it slows our tests
# by a factor of 2 (!), so we make this our "from install dir" run.
#
- # If we change the old-version run to be a different Python version
- # from 2.6, then we need to update mne.utils.clean_warning_registry.
- #
# Run one test (3.5) with a non-default stim channel to make sure our
# tests are explicit about channels.
#
@@ -24,10 +21,11 @@ env:
#
# Conda currently has packaging bug with mayavi/traits/numpy where 1.10 can't be used
# but breaks sklearn on install; hopefully eventually the NUMPY=1.9 on 2.7 full can be removed
+ # Mayavi=4.3 on old 2.7 installs, but doesn't work properly due to a traits bug
- PYTHON=2.7 DEPS=full TEST_LOCATION=src NUMPY="=1.9" SCIPY="=0.17"
- PYTHON=2.7 DEPS=nodata TEST_LOCATION=src MNE_DONTWRITE_HOME=true MNE_FORCE_SERIAL=true MNE_SKIP_NETWORK_TEST=1 # also runs flake8
- PYTHON=3.5 DEPS=full TEST_LOCATION=install MNE_STIM_CHANNEL=STI101
- - PYTHON=2.6 DEPS=full TEST_LOCATION=src NUMPY="=1.7" SCIPY="=0.11" MPL="=1.1" LIBPNG="=1.5" SKLEARN="=0.11" PANDAS="=0.8"
+ - PYTHON=2.7 DEPS=full TEST_LOCATION=src NUMPY="=1.8" SCIPY="=0.12" MPL="=1.3" SKLEARN="=0.14" PANDAS="=0.12"
- PYTHON=2.7 DEPS=minimal TEST_LOCATION=src
# Setup anaconda
@@ -49,7 +47,7 @@ install:
# We have to replicate e.g. numpy$NUMPY to ensure the recommended (higher) versions
# are not automatically installed below with multiple "conda install" calls!
- if [ "${DEPS}" == "full" ]; then
- curl http://lester.ilabs.uw.edu/files/minimal_cmds.tar.gz | tar xz;
+ curl https://staff.washington.edu/larsoner/minimal_cmds.tar.gz | tar xz;
export MNE_ROOT="${PWD}/minimal_cmds";
export NEUROMAG2FT_ROOT="${PWD}/minimal_cmds/bin";
source ${MNE_ROOT}/bin/mne_setup_sh;
@@ -59,22 +57,18 @@ install:
conda install --yes --quiet $ENSURE_PACKAGES ipython;
else
conda install --yes --quiet $ENSURE_PACKAGES ipython==1.1.0 statsmodels pandas$PANDAS;
- pip install -q nitime;
- if [ "${PYTHON}" == "2.7" ]; then
- conda install --yes --quiet $ENSURE_PACKAGES mayavi traits;
- pip install -q pysurfer faulthandler;
+ pip install nitime faulthandler;
+ if [ "${NUMPY}" != "=1.8" ]; then
+ conda install --yes --quiet $ENSURE_PACKAGES mayavi$MAYAVI;
+ pip install pysurfer;
fi;
fi;
fi;
- if [ "${DEPS}" == "nodata" ]; then
- pip install -q flake8;
- wget -q https://github.com/lucasdemarchi/codespell/archive/v1.8.tar.gz;
- tar xzf v1.8.tar.gz;
- cp codespell-1.8/codespell.py ~/miniconda/envs/testenv/bin;
- rm v1.8.tar.gz;
- rm -r codespell-1.8;
+ conda install --yes $ENSURE_PACKAGES sphinx;
+ pip install flake8 codespell numpydoc;
fi;
- - pip install -q coveralls nose-timer
+ - pip install -q codecov nose-timer
# check our versions for the major packages
- NP_VERSION=`python -c 'import numpy; print(numpy.__version__)'`
- if [ -n "$NUMPY" ] && [ "${NUMPY:(-3)}" != "${NP_VERSION::3}" ]; then
@@ -106,6 +100,7 @@ install:
python -c 'import mne; mne.datasets.testing.data_path(verbose=True)';
if [ "${DEPS}" == "full" ]; then
export FREESURFER_HOME=$(python -c 'import mne; print(mne.datasets.testing.data_path())');
+ export MNE_SKIP_FS_FLASH_CALL=1;
fi;
else
export MNE_SKIP_TESTING_DATASET_TESTS=true;
@@ -126,13 +121,10 @@ install:
ln -s ${SRC_DIR}/setup.cfg ${MNE_DIR}/../setup.cfg;
ln -s ${SRC_DIR}/.coveragerc ${MNE_DIR}/../.coveragerc;
cd ${MNE_DIR}/../;
+ COVERAGE=;
else
cd ${SRC_DIR};
- fi;
- - if [ "${PYTHON}" != "3.5" ]; then
COVERAGE=--with-coverage;
- else
- COVERAGE=;
fi;
script:
@@ -146,9 +138,8 @@ script:
after_success:
# Need to run from source dir to execute "git" commands
- # Coverage not collected for 3.5, so don't report it
- - if [ "${TEST_LOCATION}" == "src" ] && [ "${PYTHON}" != "3.5" ]; then
- echo "Running coveralls";
+ - if [ "${TEST_LOCATION}" == "src" ]; then
+ echo "Running codecov";
cd ${SRC_DIR};
- coveralls;
+ codecov;
fi;
diff --git a/Makefile b/Makefile
index a01dc85..9353cf1 100755
--- a/Makefile
+++ b/Makefile
@@ -5,8 +5,7 @@
PYTHON ?= python
NOSETESTS ?= nosetests
CTAGS ?= ctags
-# The *.fif had to be there twice to be properly ignored (!)
-CODESPELL_SKIPS ?= "*.fif,*.fif,*.eve,*.gz,*.tgz,*.zip,*.mat,*.stc,*.label,*.w,*.bz2,*.annot,*.sulc,*.log,*.local-copy,*.orig_avg,*.inflated_avg,*.gii,*.pyc,*.doctree,*.pickle,*.inv,*.png,*.edf,*.touch,*.thickness,*.nofix,*.volume,*.defect_borders,*.mgh,lh.*,rh.*,COR-*,FreeSurferColorLUT.txt,*.examples,.xdebug_mris_calc,bad.segments,BadChannels,*.hist,empty_file,*.orig"
+CODESPELL_SKIPS ?= "*.fif,*.eve,*.gz,*.tgz,*.zip,*.mat,*.stc,*.label,*.w,*.bz2,*.annot,*.sulc,*.log,*.local-copy,*.orig_avg,*.inflated_avg,*.gii,*.pyc,*.doctree,*.pickle,*.inv,*.png,*.edf,*.touch,*.thickness,*.nofix,*.volume,*.defect_borders,*.mgh,lh.*,rh.*,COR-*,FreeSurferColorLUT.txt,*.examples,.xdebug_mris_calc,bad.segments,BadChannels,*.hist,empty_file,*.orig,*.js,*.map,*.ipynb"
CODESPELL_DIRS ?= mne/ doc/ tutorials/ examples/
all: clean inplace test test-doc
@@ -104,7 +103,7 @@ codespell: # running manually
@codespell.py -w -i 3 -q 3 -S $(CODESPELL_SKIPS) -D ./dictionary.txt $(CODESPELL_DIRS)
codespell-error: # running on travis
- @codespell.py -i 0 -q 7 -S $(CODESPELL_SKIPS) -D ./dictionary.txt $(CODESPELL_DIRS) | tee /dev/tty | wc -l | xargs test 0 -eq
+ @codespell.py -i 0 -q 7 -S $(CODESPELL_SKIPS) -D ./dictionary.txt $(CODESPELL_DIRS)
manpages:
@echo "I: generating manpages"
diff --git a/README.rst b/README.rst
index 543268a..f07d1ae 100644
--- a/README.rst
+++ b/README.rst
@@ -1,7 +1,7 @@
.. -*- mode: rst -*-
-|Travis|_ |Appveyor|_ |Coveralls|_ |Zenodo|_
+|Travis|_ |Appveyor|_ |Codecov|_ |Zenodo|_
.. |Travis| image:: https://api.travis-ci.org/mne-tools/mne-python.png?branch=master
.. _Travis: https://travis-ci.org/mne-tools/mne-python
@@ -9,95 +9,74 @@
.. |Appveyor| image:: https://ci.appveyor.com/api/projects/status/reccwk3filrasumg/branch/master?svg=true
.. _Appveyor: https://ci.appveyor.com/project/Eric89GXL/mne-python/branch/master
-.. |Coveralls| image:: https://coveralls.io/repos/mne-tools/mne-python/badge.png?branch=master
-.. _Coveralls: https://coveralls.io/r/mne-tools/mne-python?branch=master
+.. |Codecov| image:: https://codecov.io/gh/mne-tools/mne-python/branch/master/graph/badge.svg
+.. _Codecov: https://codecov.io/gh/mne-tools/mne-python
.. |Zenodo| image:: https://zenodo.org/badge/5822/mne-tools/mne-python.svg
.. _Zenodo: https://zenodo.org/badge/latestdoi/5822/mne-tools/mne-python
-`mne-python <http://mne-tools.github.io/>`_
+`MNE-Python <http://mne-tools.github.io/>`_
=======================================================
This package is designed for sensor- and source-space analysis of [M/E]EG
data, including frequency-domain and time-frequency analyses, MVPA/decoding
-and non-parametric statistics. This package is presently evolving quickly and
-thanks to the adopted open development environment user contributions can
-be easily incorporated.
+and non-parametric statistics. This package generally evolves quickly and
+user contributions can easily be incorporated thanks to the open
+development environment.
Get more information
^^^^^^^^^^^^^^^^^^^^
-This page only contains bare-bones instructions for installing mne-python.
-
-If you're familiar with MNE and you're looking for information on using
-mne-python specifically, jump right to the `mne-python homepage
-<http://mne-tools.github.io/stable/python_reference.html>`_. This website includes
-`tutorials <http://mne-tools.github.io/stable/tutorials.html>`_,
-helpful `examples <http://mne-tools.github.io/stable/auto_examples/index.html>`_, and
-a handy `function reference <http://mne-tools.github.io/stable/python_reference.html>`_,
-among other things.
-
-If you're unfamiliar with MNE, you can visit the
-`MNE homepage <http://martinos.org/mne>`_ for full user documentation.
+If you're unfamiliar with MNE or MNE-Python, you can visit the
+`MNE homepage <http://mne-tools.github.io/>`_ for full user documentation.
Get the latest code
^^^^^^^^^^^^^^^^^^^
-To get the latest code using git, simply type::
+To get the latest code using `git <https://git-scm.com/>`_, simply type:
+
+.. code-block:: bash
- git clone git://github.com/mne-tools/mne-python.git
+ $ git clone git://github.com/mne-tools/mne-python.git
-If you don't have git installed, you can download a zip
-of the latest code: https://github.com/mne-tools/mne-python/archive/master.zip
+If you don't have git installed, you can download a
+`zip of the latest code <https://github.com/mne-tools/mne-python/archive/master.zip>`_.
Install mne-python
^^^^^^^^^^^^^^^^^^
-As any Python packages, to install MNE-Python, after obtaining the source code
-(e.g. from git), go in the mne-python source code directory and do::
-
- python setup.py install
-
-or if you don't have admin access to your python setup (permission denied
-when install) use::
-
- python setup.py install --user
-
-You can also install the latest release version with easy_install::
-
- easy_install -U mne
-
-or with pip::
-
- pip install mne
-
-for an update of an already installed version use::
+As with most Python packages, to install the latest stable version of
+MNE-Python, you can do:
- pip install mne --upgrade
+.. code-block:: bash
-or for the latest development version (the most up to date)::
+ $ pip install mne
- pip install -e git+https://github.com/mne-tools/mne-python#egg=mne-dev --user
+For more complete instructions and more advanced install methods (e.g. for
+the latest development version), see the
+`getting started page <http://mne-tools.github.io/stable/getting_started.html>`_.
Dependencies
^^^^^^^^^^^^
-The required dependencies to build the software are python >= 2.6,
-NumPy >= 1.6, SciPy >= 0.7.2 and matplotlib >= 0.98.4.
+The minimum required dependencies to run the software are:
-Some isolated functions require pandas >= 0.7.3.
-Decoding relies on scikit-learn >= 0.15.
+ - Python >= 2.7
+ - NumPy >= 1.8
+ - SciPy >= 0.12
+ - matplotlib >= 1.3
-To run the tests you will also need nose >= 0.10.
-and the MNE sample dataset (will be downloaded automatically
-when you run an example ... but be patient).
+For full functionality, some functions require:
-To use NVIDIA CUDA for resampling and FFT FIR filtering, you will also need
-to install the NVIDIA CUDA SDK, pycuda, and scikits.cuda. The difficulty of this
-varies by platform; consider reading the following site for help getting pycuda
-to work (typically the most difficult to configure):
+ - scikit-learn >= 0.18
+ - nibabel >= 2.1.0
+ - pandas >= 0.12
-http://wiki.tiker.net/PyCuda/Installation/
+To use NVIDIA CUDA for resampling and FFT FIR filtering, you will also need
+to install the NVIDIA CUDA SDK, pycuda, and scikits.cuda. See the
+`getting started page <http://mne-tools.github.io/stable/getting_started.html>`_
+for more information.
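
As a quick way to verify these dependencies, MNE can report the versions it
finds; this mirrors the ``mne.sys_info()`` call used in the CI configuration
further below. A minimal sketch, assuming an installed MNE (0.12 or later,
where ``sys_info`` exists):

.. code-block:: python

    # Print Python/platform info and the versions of numpy, scipy,
    # matplotlib, scikit-learn, etc. that MNE detects.
    import mne
    mne.sys_info()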
Contribute to mne-python
^^^^^^^^^^^^^^^^^^^^^^^^
@@ -111,25 +90,6 @@ Mailing list
http://mail.nmr.mgh.harvard.edu/mailman/listinfo/mne_analysis
-Running the test suite
-^^^^^^^^^^^^^^^^^^^^^^
-
-To run the test suite, you need nosetests and the coverage modules.
-Run the test suite using::
-
- nosetests
-
-from the root of the project.
-
-Making a release and uploading it to PyPI
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-This command is only run by project manager, to make a release, and
-upload in to PyPI::
-
- python setup.py sdist bdist_egg register upload
-
-
Licensing
^^^^^^^^^
diff --git a/appveyor.yml b/appveyor.yml
index 3953842..3486ee1 100644
--- a/appveyor.yml
+++ b/appveyor.yml
@@ -1,29 +1,17 @@
-# CI on Windows via appveyor
-# This file was based on Olivier Grisel's python-appveyor-demo
-
environment:
-
+ global:
+ PYTHON: "C:\\conda"
+ MINICONDA_VERSION: "latest"
+ CONDA_DEPENDENCIES: "setuptools numpy scipy matplotlib scikit-learn nose mayavi pandas h5py PIL patsy pyside"
matrix:
- - PYTHON: "C:\\Python27-conda64"
- PYTHON_VERSION: "2.7"
- PYTHON_ARCH: "64"
+ - PYTHON_VERSION: "2.7"
+ PYTHON_ARCH: "64"
install:
- # Install miniconda Python
- - "powershell ./make/install_python.ps1"
-
- # Prepend newly installed Python to the PATH of this build (this cannot be
- # done from inside the powershell script as it would require to restart
- # the parent CMD process).
+ - "git clone git://github.com/astropy/ci-helpers.git"
+ - "powershell ci-helpers/appveyor/install-miniconda.ps1"
- "SET PATH=%PYTHON%;%PYTHON%\\Scripts;%PATH%"
-
- # Check that we have the expected version and architecture for Python
- - "python --version"
- - "python -c \"import struct; print(struct.calcsize('P') * 8)\""
-
- # Install the dependencies of the project (skip nibabel for speed)
- - "conda create -n testenv --yes --quiet setuptools numpy scipy matplotlib scikit-learn nose mayavi pandas h5py PIL patsy pyside"
- - "activate testenv"
+ - "activate test"
- "pip install nose-timer nibabel nitime"
- "python setup.py develop"
- "SET MNE_SKIP_NETWORK_TESTS=1"
diff --git a/circle.yml b/circle.yml
index 5a933a2..7092492 100644
--- a/circle.yml
+++ b/circle.yml
@@ -2,15 +2,21 @@ machine:
environment:
# We need to set this variable to let Anaconda take precedence
PATH: "/home/ubuntu/miniconda/envs/circleenv/bin:/home/ubuntu/miniconda/bin:$PATH"
- MNE_DATA: "/home/ubuntu/mne_data"
DISPLAY: ":99.0"
dependencies:
cache_directories:
- - "/home/ubuntu/miniconda"
- - "/home/ubuntu/.mne"
- - "/home/ubuntu/mne_data"
- - "/home/ubuntu/mne-tools.github.io"
+ - "~/miniconda"
+ - "~/.mne"
+ - "~/mne_data/MNE-sample-data"
+ - "~/mne_data/MNE-testing-data"
+ - "~/mne_data/MNE-misc-data"
+ - "~/mne_data/MNE-spm-face"
+ - "~/mne_data/MNE-somato-data"
+ - "~/mne_data/MNE-brainstorm-data"
+ - "~/mne_data/MEGSIM"
+ - "~/mne_data/MNE-eegbci-data"
+ - "~/mne-tools.github.io"
# Various dependencies
pre:
# Get a running Python
@@ -26,27 +32,72 @@ dependencies:
chmod +x ~/miniconda.sh;
~/miniconda.sh -b -p /home/ubuntu/miniconda;
conda update --yes --quiet conda;
- conda create -n circleenv --yes pip python=2.7 pip numpy scipy scikit-learn mayavi matplotlib sphinx pillow six IPython pandas;
+ conda create -n circleenv --yes pip python=2.7 pip;
sed -i "s/ENABLE_USER_SITE = .*/ENABLE_USER_SITE = False/g" /home/ubuntu/miniconda/envs/circleenv/lib/python2.7/site.py;
else
echo "Conda already set up.";
fi
+ - conda install -n circleenv --yes numpy scipy scikit-learn mayavi matplotlib sphinx pillow six IPython pandas;
- ls -al /home/ubuntu/miniconda;
- ls -al /home/ubuntu/miniconda/bin;
- echo $PATH;
+ - echo $CIRCLE_BRANCH
- which python;
- which pip;
+ - pip install --upgrade pyface;
- git clone https://github.com/sphinx-gallery/sphinx-gallery.git;
- cd sphinx-gallery && pip install -r requirements.txt && python setup.py develop;
- - cd /home/ubuntu && git clone https://github.com/enthought/pyface.git && cd pyface && python setup.py develop;
- - pip install sphinx_bootstrap_theme PySurfer nilearn neo;
+ - pip install sphinx_bootstrap_theme git+git://github.com/nipy/PySurfer.git nilearn neo numpydoc;
override:
+ # Figure out if we should run a full, pattern, or noplot version
- cd /home/ubuntu/mne-python && python setup.py develop;
- - if [ "$CIRCLE_BRANCH" == "master" ]; then
- mkdir -p ~/mne_data;
+ - git branch -a
+ - PATTERN="";
+ if [ "$CIRCLE_BRANCH" == "master" ] || [[ `git log -1 --pretty=%B` == *"[circle full]"* ]]; then
+ echo html_dev > build.txt;
+ elif [ "$CIRCLE_BRANCH" == "maint/0.13" ]; then
+ echo html_stable > build.txt;
+ else
+ FNAMES=$(git diff --name-only $CIRCLE_BRANCH $(git merge-base $CIRCLE_BRANCH origin/master));
+ echo FNAMES="$FNAMES";
+ for FNAME in $FNAMES; do
+ if [[ `expr match $FNAME "\(tutorials\|examples\)/.*plot_.*\.py"` ]] ; then
+ echo "Checking example $FNAME ...";
+ PATTERN=`basename $FNAME`"\\|"$PATTERN;
+ if [[ $(cat $FNAME | grep -x ".*datasets.*sample.*" | wc -l) -gt 0 ]]; then
+ python -c "import mne; print(mne.datasets.sample.data_path())";
+ fi;
+ if [[ $(cat $FNAME | grep -x ".*brainstorm.*bst_auditory.*" | wc -l) -gt 0 ]]; then
+ python -c "import mne; print(mne.datasets.brainstorm.bst_auditory.data_path())" --accept-brainstorm-license;
+ fi;
+ if [[ $(cat $FNAME | grep -x ".*brainstorm.*bst_raw.*" | wc -l) -gt 0 ]]; then
+ python -c "import mne; print(mne.datasets.brainstorm.bst_raw.data_path())" --accept-brainstorm-license;
+ fi;
+ if [[ $(cat $FNAME | grep -x ".*brainstorm.*bst_phantom_ctf.*" | wc -l) -gt 0 ]]; then
+ python -c "import mne; print(mne.datasets.brainstorm.bst_phantom_ctf.data_path())" --accept-brainstorm-license;
+ fi;
+ if [[ $(cat $FNAME | grep -x ".*brainstorm.*bst_phantom_elekta.*" | wc -l) -gt 0 ]]; then
+ python -c "import mne; print(mne.datasets.brainstorm.bst_phantom_elekta.data_path())" --accept-brainstorm-license;
+ fi;
+ fi;
+ done;
+ echo PATTERN="$PATTERN";
+ echo NEED_SAMPLE="$NEED_SAMPLE";
+ if [[ $PATTERN ]]; then
+ PATTERN="\(${PATTERN::-2}\)";
+ echo html_dev-pattern > build.txt;
+ else
+ echo html_dev-noplot > build.txt;
+ fi;
+ fi;
+ echo "$PATTERN" > pattern.txt;
+ - echo BUILD="$(cat build.txt)"
+ - mkdir -p ~/mne_data;
+ - ls -al ~/mne_data;
+ - if [[ $(cat build.txt) == "html_dev" ]] || [[ $(cat build.txt) == "html_stable" ]]; then
python -c "import mne; mne.datasets._download_all_example_data()";
- fi
+ fi;
- python -c "import mne; mne.sys_info()";
- >
if [ ! -d "/home/ubuntu/mne-tools.github.io" ]; then
@@ -56,12 +107,7 @@ dependencies:
test:
override:
- - if [ "$CIRCLE_BRANCH" == "master" ]; then
- make test-doc;
- else
- cd doc && make html_dev-noplot;
- fi
- - if [ "$CIRCLE_BRANCH" == "master" ]; then cd doc && make html_dev; fi:
+ - if [[ $(cat build.txt) == "html_dev-noplot" ]]; then cd doc && make html_dev-noplot; elif [[ $(cat build.txt) == "html_dev-pattern" ]]; then cd doc && PATTERN=$(cat ../pattern.txt) make html_dev-pattern; else make test-doc; cd doc; make $(cat ../build.txt); fi:
timeout: 1500
general:
@@ -81,3 +127,11 @@ deployment:
- cd ../mne-tools.github.io && git checkout master && git pull origin master
- cd doc/_build/html && cp -rf * ~/mne-tools.github.io/dev
- cd ../mne-tools.github.io && git add -A && git commit -m 'Automated update of dev docs.' && git push origin master
+ stable:
+ branch: maint/0.12
+ commands:
+ - git config --global user.email "circle at mne.com"
+ - git config --global user.name "Circle Ci"
+ - cd ../mne-tools.github.io && git checkout master && git pull origin master
+ - cd doc/_build/html_stable && cp -rf * ~/mne-tools.github.io/stable
+ - cd ../mne-tools.github.io && git add -A && git commit -m 'Automated update of stable docs.' && git push origin master
diff --git a/codecov.yml b/codecov.yml
new file mode 100644
index 0000000..1e7b888
--- /dev/null
+++ b/codecov.yml
@@ -0,0 +1,14 @@
+coverage:
+ precision: 2
+ round: down
+ range: "70...100"
+ status:
+ project:
+ default:
+ target: auto
+ threshold: 0.01
+ patch: false
+ changes: false
+comment:
+ layout: "header, diff, sunburst, uncovered"
+ behavior: default
diff --git a/doc/Makefile b/doc/Makefile
index d236bd1..6bb71a3 100644
--- a/doc/Makefile
+++ b/doc/Makefile
@@ -49,7 +49,7 @@ html_dev:
@echo "Build finished. The HTML pages are in _build/html"
html_dev-pattern:
- BUILD_DEV_HTML=1 $(SPHINXBUILD) -D plot_gallery=1 -D sphinx_gallery_conf.filename_pattern=$(PATTERN) -b html $(ALLSPHINXOPTS) _build/html
+ BUILD_DEV_HTML=1 $(SPHINXBUILD) -D plot_gallery=1 -D raise_gallery=1 -D abort_on_example_error=1 -D sphinx_gallery_conf.filename_pattern=$(PATTERN) -b html $(ALLSPHINXOPTS) _build/html
@echo
@echo "Build finished. The HTML pages are in _build/html"
@@ -113,3 +113,8 @@ doctest:
$(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) _build/doctest
@echo "Testing of doctests in the sources finished, look at the " \
"results in _build/doctest/output.txt."
+
+view:
+ @python -c "import webbrowser; webbrowser.open_new_tab('file://$(PWD)/_build/html/index.html')"
+
+show: view
diff --git a/doc/_static/style.css b/doc/_static/style.css
index e1e9e25..a1da6d9 100644
--- a/doc/_static/style.css
+++ b/doc/_static/style.css
@@ -133,3 +133,6 @@ dt:target code {
background-color: inherit !important;
}
+.label { /* Necessary for multiple refs, from bootstrap.min.css:7 */
+ color: #2c3e50;
+}
diff --git a/doc/_templates/class.rst b/doc/_templates/class.rst
index 6e17cfa..11b72d5 100644
--- a/doc/_templates/class.rst
+++ b/doc/_templates/class.rst
@@ -4,9 +4,9 @@
.. currentmodule:: {{ module }}
.. autoclass:: {{ objname }}
+ :special-members: __contains__,__getitem__,__iter__,__len__,__add__,__sub__,__mul__,__div__,__neg__,__hash__
{% block methods %}
- .. automethod:: __init__
{% endblock %}
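
The special members added above are worth documenting because MNE classes
implement them for everyday use. A hedged sketch of the kind of calls they
enable, assuming the 0.13-era ``Epochs`` API, a hypothetical file name, and
that ``'auditory'`` is one of the epochs' event names:

.. code-block:: python

    import mne

    # 'sample-epo.fif' and 'auditory' are illustrative assumptions.
    epochs = mne.read_epochs('sample-epo.fif')
    n_epochs = len(epochs)    # __len__: number of epochs
    aud = epochs['auditory']  # __getitem__: select epochs by event name
    for data in epochs[:2]:   # __iter__: yields one data array per epoch
        print(data.shape)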
diff --git a/doc/cite.rst b/doc/cite.rst
index 393e7a2..3c4e0df 100644
--- a/doc/cite.rst
+++ b/doc/cite.rst
@@ -3,7 +3,7 @@
How to cite MNE
---------------
-If you use in your research the implementations provided by the MNE software you should cite:
+If you use the implementations provided by the MNE software in your research, you should cite:
[1] A. Gramfort, M. Luessi, E. Larson, D. Engemann, D. Strohmeier, C. Brodbeck, L. Parkkonen, M. Hämäläinen, `MNE software for processing MEG and EEG data <http://www.ncbi.nlm.nih.gov/pubmed/24161808>`_, NeuroImage, Volume 86, 1 February 2014, Pages 446-460, ISSN 1053-8119, `[DOI] <http://dx.doi.org/10.1016/j.neuroimage.2013.10.027>`__
diff --git a/doc/conf.py b/doc/conf.py
index 481f1a0..8f93729 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -26,6 +26,8 @@ sys.path.append(os.path.abspath(os.path.join(curdir, '..', 'mne')))
sys.path.append(os.path.abspath(os.path.join(curdir, 'sphinxext')))
import mne
+if not os.path.isdir('_images'):
+ os.mkdir('_images')
# -- General configuration ------------------------------------------------
@@ -34,7 +36,9 @@ import mne
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
-import numpy_ext.numpydoc
+from numpydoc import numpydoc, docscrape
+docscrape.ClassDoc.extra_public_methods = mne.utils._doc_special_members
+
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
@@ -46,7 +50,7 @@ extensions = [
'sphinx_gallery.gen_gallery',
]
-extensions += ['numpy_ext.numpydoc']
+extensions += ['numpydoc']
extensions += ['gen_commands'] # auto generate the doc for the python commands
# extensions += ['flow_diagram] # generate flow chart in cookbook
@@ -264,7 +268,11 @@ latex_use_parts = True
trim_doctests_flags = True
# Example configuration for intersphinx: refer to the Python standard library.
-intersphinx_mapping = {'http://docs.python.org/': None}
+intersphinx_mapping = {
+ 'python': ('http://docs.python.org/', None),
+ 'numpy': ('http://docs.scipy.org/doc/numpy-dev/', None),
+ 'scipy': ('http://scipy.github.io/devdocs/', None),
+}
examples_dirs = ['../examples', '../tutorials']
gallery_dirs = ['auto_examples', 'auto_tutorials']
@@ -290,4 +298,7 @@ sphinx_gallery_conf = {
'gallery_dirs': gallery_dirs,
'find_mayavi_figures': find_mayavi_figures,
'default_thumb_file': os.path.join('_static', 'mne_helmet.png'),
+ 'mod_example_dir': 'generated',
}
+
+numpydoc_class_members_toctree = False
diff --git a/doc/contributing.rst b/doc/contributing.rst
index 0656389..47383c3 100644
--- a/doc/contributing.rst
+++ b/doc/contributing.rst
@@ -5,7 +5,7 @@ Contribute to MNE
.. contents:: Contents
:local:
- :depth: 2
+ :depth: 1
.. We want to thank all MNE Software users at the Martinos Center and
.. in other institutions for their collaboration during the creation
@@ -26,12 +26,12 @@ What you will need
------------------
#. A good python editor: Atom_ and `Sublime Text`_ are modern general-purpose
-text editors and are available on all three major platforms. Both provide
-plugins that facilitate editing python code and help avoid bugs and style errors.
-See for example linterflake8_ for Atom_.
-The Spyder_ IDE is espectially suitable for those migrating from Matlab.
-EPD_ and Anaconda_ both ship Spyder and all its dependencies.
-As always, Vim or Emacs will suffice as well.
+ text editors and are available on all three major platforms. Both provide
+ plugins that facilitate editing python code and help avoid bugs and style
+ errors. See for example linterflake8_ for Atom_.
+ The Spyder_ IDE is especially suitable for those migrating from Matlab.
+ EPD_ and Anaconda_ both ship Spyder and all its dependencies.
+ As always, Vim or Emacs will suffice as well.
#. Basic scientific tools in python: numpy_, scipy_, matplotlib_
@@ -46,6 +46,9 @@ As always, Vim or Emacs will suffice as well.
system. If you are on Windows, you can install these applications inside a
Unix virtual machine.
+#. Documentation building packages ``numpydoc``, ``sphinx_bootstrap_theme`` and
+ ``sphinx_gallery``.
+
General code guidelines
-----------------------
@@ -53,17 +56,21 @@ General code guidelines
`pyflakes`_, such as `Spyder`_. Standard python style guidelines are
followed, with very few exceptions.
- You can also manually check pyflakes and pep8 warnings as::
+ You can also manually check pyflakes and pep8 warnings as:
+
+ .. code-block:: bash
- pip install pyflakes
- pip install pep8
- pyflakes path/to/module.py
- pep8 path/to/module.py
+ $ pip install pyflakes
+ $ pip install pep8
+ $ pyflakes path/to/module.py
+ $ pep8 path/to/module.py
- AutoPEP8 can then help you fix some of the easy redundant errors::
+ AutoPEP8 can then help you fix some of the easy redundant errors:
- pip install autopep8
- autopep8 path/to/pep8.py
+ .. code-block:: bash
+
+ $ pip install autopep8
+ $ autopep8 path/to/pep8.py
* mne-python adheres to the same docstring formatting as seen on
`numpy style`_.
@@ -79,26 +86,34 @@ General code guidelines
The ambition is to achieve around 85% coverage with tests.
* After changes have been made, **ensure all tests pass**. This can be done
- by running the following from the ``mne-python`` root directory::
+ by running the following from the ``mne-python`` root directory:
+
+ .. code-block:: bash
- make
+ $ make
- To run individual tests, you can also run any of the following::
+ To run individual tests, you can also run any of the following:
- make clean
- make inplace
- make test-doc
- make inplace
- nosetests
+ .. code-block:: bash
+
+ $ make clean
+ $ make inplace
+ $ make test-doc
+ $ make inplace
+ $ nosetests
To explicitly download and extract the mne-python testing dataset (~320 MB)
- run::
+ run:
+
+ .. code-block:: bash
make testing_data
- Alternatively::
+ Alternatively:
- python -c "import mne; mne.datasets.testing.data_path(verbose=True)"
+ .. code-block:: bash
+
+ $ python -c "import mne; mne.datasets.testing.data_path(verbose=True)"
downloads the test data as well. Having a complete testing dataset is
necessary for running the tests. To run the examples you'll need
@@ -113,30 +128,60 @@ General code guidelines
>>> run_tests_if_main()
For more details see troubleshooting_.
-
+
* Update relevant documentation. Update :doc:`whats_new.rst <whats_new>` for new features and :doc:`python_reference.rst <python_reference>` for new classes and standalone functions. :doc:`whats_new.rst <whats_new>` is organized in chronological order with the last feature at the end of the document.
- To ensure that these files were rendered correctly, run the following command::
+Checking and building documentation
+-----------------------------------
- make html-noplot
+All changes to the codebase must be properly documented.
+To ensure that documentation is rendered correctly, the best bet is to
+follow the existing examples for class and function docstrings,
+and examples and tutorials.
- This will build the docs without building all the examples, which can save some time.
+Our documentation (including docstrings in code) uses the reStructuredText
+format; see the `Sphinx documentation`_ to learn more about editing it. Our
+code follows the `NumPy docstring standard`_.
+To test documentation locally, you will need to install (e.g., via ``pip``):
-More mne-python specific guidelines
------------------------------------
+ * sphinx
+ * sphinx-gallery
+ * sphinx_bootstrap_theme
+ * numpydoc
+
+Then to build the documentation locally, within the ``mne/doc`` directory do:
+
+.. code-block:: bash
+
+ $ make html-noplot
+
+This will build the docs without building all the examples, which can save
+some time. If you are working on examples or tutorials, you can build
+specific examples with e.g.:
+
+.. code-block:: bash
+
+ $ PATTERN=plot_background_filtering.py make html_dev-pattern
+
+Consult the `sphinx gallery documentation`_ for more details.
+
+MNE-Python specific coding guidelines
+-------------------------------------
* Please, ideally address one and only one issue per pull request (PR).
* Avoid unnecessary cosmetic changes if they are not the goal of the PR, this will help keep the diff clean and facilitate reviewing.
* Use underscores to separate words in non class names: n_samples rather than nsamples.
* Use CamelCase for class names.
* Use relative imports for references inside mne-python.
-* Use nested imports for ``matplotlib``, ``sklearn``, and ``pandas``.
+* Use nested imports (i.e., within a function or method instead of at the top of a file) for ``matplotlib``, ``sklearn``, and ``pandas`` (see the sketch after this list).
* Use ``RdBu_r`` colormap for signed data and ``Reds`` for unsigned data in visualization functions and examples.
* All visualization functions must accept a ``show`` parameter and return a ``fig`` handle.
-* Efforts to improve test timing without decreasing coverage is well appreciated. To see the top-30 tests in order of decreasing timing, run the following command::
+* Efforts to improve test timing without decreasing coverage are well appreciated. To see the top-30 tests in order of decreasing timing, run the following command:
+
+ .. code-block:: bash
- nosetests --with-timer --timer-top-n 30
+ $ nosetests --with-timer --timer-top-n 30
* Instance methods that update the state of the object should return self.
* Use single quotes whenever possible.
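
To make a few of these guidelines concrete, here is a minimal sketch; the
function and variable names are hypothetical, not part of the MNE API:

.. code-block:: python

    def plot_channel_variance(data, show=True):
        """Plot per-channel variance (hypothetical example function)."""
        import matplotlib.pyplot as plt  # nested import, per the guideline
        fig, ax = plt.subplots()
        ax.plot(data.var(axis=1))  # data: array of shape (n_channels, n_times)
        ax.set_xlabel('Channel index')  # single quotes whenever possible
        if show:
            plt.show()
        return fig  # visualization functions return a fig handle

    # Hypothetical usage:
    # fig = plot_channel_variance(np.random.randn(10, 1000), show=False)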
@@ -165,19 +210,16 @@ improvements to the documentation, or new functionality, can be done via
adapted for our use here!]
The only absolutely necessary configuration step is identifying yourself and
-your contact info::
+your contact info:
- git config --global user.name "Your Name"
- git config --global user.email you at yourdomain.example.com
+.. code-block:: bash
+
+ $ git config --global user.name "Your Name"
+ $ git config --global user.email you at yourdomain.example.com
If you are going to :ref:`setup-github` eventually, this email address should
be the same as the one used to sign up for a GitHub account. For more
-information about configuring your git installation, see:
-
-.. toctree::
- :maxdepth: 1
-
- customizing_git
+information about configuring your git installation, see :ref:`customizing-git`.
The following sections cover the installation of the git software, the basic
configuration, and links to resources to learn more about using git.
@@ -185,7 +227,7 @@ However, you can also directly go to the `GitHub help pages
<https://help.github.com/>`_ which offer a great introduction to git and
GitHub.
-In the present document, we refer to the mne-python ``master`` branch, as the
+In the present document, we refer to the MNE-Python ``master`` branch, as the
*trunk*.
.. _forking:
@@ -229,25 +271,33 @@ in principle also fork a different one, such as ``mne-matlab```):
Setting up the fork and the working directory
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-Briefly, this is done using::
+Briefly, this is done using:
+
+.. code-block:: bash
- git clone git at github.com:your-user-name/mne-python.git
- cd mne-python
- git remote add upstream git://github.com/mne-tools/mne-python.git
+ $ git clone git at github.com:your-user-name/mne-python.git
+ $ cd mne-python
+ $ git remote add upstream git://github.com/mne-tools/mne-python.git
These steps can be broken out to be more explicit as:
-#. Clone your fork to the local computer::
+#. Clone your fork to the local computer:
- git clone git at github.com:your-user-name/mne-python.git
+ .. code-block:: bash
-#. Change directory to your new repo::
+ $ git clone git at github.com:your-user-name/mne-python.git
- cd mne-python
+#. Change directory to your new repo:
- Then type::
+ .. code-block:: bash
- git branch -a
+ $ cd mne-python
+
+ Then type:
+
+ .. code-block:: bash
+
+ $ git branch -a
to show you all branches. You'll get something like::
@@ -260,10 +310,12 @@ These steps can be broken out to be more explicit as:
see the URLs for the remote. They will point to your GitHub fork.
Now you want to connect to the mne-python repository, so you can
- merge in changes from the trunk::
+ merge in changes from the trunk:
+
+ .. code-block:: bash
- cd mne-python
- git remote add upstream git://github.com/mne-tools/mne-python.git
+ $ cd mne-python
+ $ git remote add upstream git://github.com/mne-tools/mne-python.git
``upstream`` here is just the arbitrary name we're using to refer to the
main mne-python_ repository.
@@ -286,9 +338,11 @@ These steps can be broken out to be more explicit as:
#. Install mne with editing permissions to the installed folder:
To be able to conveniently edit your files after installing mne-python,
- install using the following setting::
+ install using the following setting:
- $ python setup.py develop --user
+ .. code-block:: bash
+
+ $ python setup.py develop --user
To make changes in the code, edit the relevant files and restart the
ipython kernel for changes to take effect.
@@ -297,9 +351,11 @@ These steps can be broken out to be more explicit as:
Make sure before starting to code that all unit tests pass and the
html files in the ``doc/`` directory can be built without errors. To build
- the html files, first go the ``doc/`` directory and then type::
+ the html files, first go the ``doc/`` directory and then type:
+
+ .. code-block:: bash
- $ make html
+ $ make html
Once it is compiled for the first time, subsequent compiles will only
recompile what has changed. That's it! You are now ready to hack away.
@@ -349,9 +405,11 @@ details.
Updating the mirror of trunk
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-From time to time you should fetch the upstream (trunk) changes from GitHub::
+From time to time you should fetch the upstream (trunk) changes from GitHub:
+
+.. code-block:: bash
- git fetch upstream
+ $ git fetch upstream
This will pull down any commits you don't have, and set the remote branches to
point to the right commit. For example, 'trunk' is the branch referred to by
@@ -374,27 +432,31 @@ Choose an informative name for the branch to remind yourself and the rest of
us what the changes in the branch are for. For example ``add-ability-to-fly``,
or ``buxfix-for-issue-42``.
-::
+.. code-block:: bash
- # Update the mirror of trunk
- git fetch upstream
+ # Update the mirror of trunk
+ $ git fetch upstream
- # Make new feature branch starting at current trunk
- git branch my-new-feature upstream/master
- git checkout my-new-feature
+ # Make new feature branch starting at current trunk
+ $ git branch my-new-feature upstream/master
+ $ git checkout my-new-feature
Generally, you will want to keep your feature branches on your public GitHub_
fork. To do this, you `git push`_ this new branch up to your
github repo. Generally (if you followed the instructions in these pages, and
by default), git will have a link to your GitHub repo, called ``origin``. You
-push up to your own repo on GitHub with::
+push up to your own repo on GitHub with:
+
+.. code-block:: bash
- git push origin my-new-feature
+ $ git push origin my-new-feature
In git > 1.7 you can ensure that the link is correctly set by using the
-``--set-upstream`` option::
+``--set-upstream`` option:
- git push --set-upstream origin my-new-feature
+.. code-block:: bash
+
+ $ git push --set-upstream origin my-new-feature
From now on git will know that ``my-new-feature`` is related to the
``my-new-feature`` branch in the GitHub repo.
@@ -407,11 +469,11 @@ The editing workflow
Overview
^^^^^^^^
-::
+.. code-block:: bash
- git add my_new_file
- git commit -am 'FIX: some message'
- git push
+ $ git add my_new_file
+ $ git commit -am 'FIX: some message'
+ $ git push
In more detail
^^^^^^^^^^^^^^
@@ -506,19 +568,19 @@ Some other things you might want to do
Delete a branch on GitHub
^^^^^^^^^^^^^^^^^^^^^^^^^
-::
+.. code-block:: bash
# change to the master branch (if you still have one, otherwise change to another branch)
- git checkout master
+ $ git checkout master
# delete branch locally
- git branch -D my-unwanted-branch
+ $ git branch -D my-unwanted-branch
# delete branch on GitHub
- git push origin :my-unwanted-branch
+ $ git push origin :my-unwanted-branch
(Note the colon ``:`` before ``my-unwanted-branch``. See also:
-https://help.github.com/remotes
+https://help.github.com/remotes)
Several people sharing a single repository
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -537,30 +599,38 @@ collaborator:
.. image:: _static/pull_button.png
-Now all those people can do::
+Now all those people can do:
+
+.. code-block:: bash
- git clone git at githhub.com:your-user-name/mne-python.git
+ $ git clone git at githhub.com:your-user-name/mne-python.git
Remember that links starting with ``git@`` use the ssh protocol and are
read-write; links starting with ``git://`` are read-only.
Your collaborators can then commit directly into that repo with the
-usual::
+usual:
- git commit -am 'ENH: much better code'
- git push origin master # pushes directly into your repo
+.. code-block:: bash
+
+ $ git commit -am 'ENH: much better code'
+ $ git push origin master # pushes directly into your repo
Explore your repository
^^^^^^^^^^^^^^^^^^^^^^^
To see a graphical representation of the repository branches and
-commits::
+commits:
+
+.. code-block:: bash
- gitk --all
+ $ gitk --all
-To see a linear list of commits for this branch::
+To see a linear list of commits for this branch:
- git log
+.. code-block:: bash
+
+ $ git log
You can also look at the `network graph visualizer`_ for your GitHub
repo.
@@ -604,28 +674,34 @@ your history will look like this::
See `rebase without tears`_ for more detail.
-To do a rebase on trunk::
+To do a rebase on trunk:
+
+.. code-block:: bash
# Update the mirror of trunk
- git fetch upstream
+ $ git fetch upstream
# Go to the feature branch
- git checkout cool-feature
+ $ git checkout cool-feature
# Make a backup in case you mess up
- git branch tmp cool-feature
+ $ git branch tmp cool-feature
# Rebase cool-feature onto trunk
- git rebase --onto upstream/master upstream/master cool-feature
+ $ git rebase --onto upstream/master upstream/master cool-feature
In this situation, where you are already on branch ``cool-feature``, the last
-command can be written more succinctly as::
+command can be written more succinctly as:
- git rebase upstream/master
+.. code-block:: bash
-When all looks good you can delete your backup branch::
+ $ git rebase upstream/master
- git branch -D tmp
+When all looks good you can delete your backup branch:
+
+.. code-block:: bash
+
+ $ git branch -D tmp
If it doesn't look good you may need to have a look at
:ref:`recovering-from-mess-up`.
@@ -639,9 +715,11 @@ merge`_.
If your feature branch is already on GitHub and you rebase, you will have to force
push the branch; a normal push would give an error. If the branch you rebased is
called ``cool-feature`` and your GitHub fork is available as the remote called ``origin``,
-you use this command to force-push::
+you use this command to force-push:
- git push -f origin cool-feature
+.. code-block:: bash
+
+ $ git push -f origin cool-feature
Note that this will overwrite the branch on GitHub, i.e. this is one of the few ways
you can actually lose commits with git.
@@ -657,19 +735,25 @@ Recovering from mess-ups
Sometimes, you mess up merges or rebases. Luckily, in git it is relatively
straightforward to recover from such mistakes.
-If you mess up during a rebase::
+If you mess up during a rebase:
+
+.. code-block:: bash
- git rebase --abort
+ $ git rebase --abort
-If you notice you messed up after the rebase::
+If you notice you messed up after the rebase:
+
+.. code-block:: bash
# Reset branch back to the saved point
- git reset --hard tmp
+ $ git reset --hard tmp
+
+If you forgot to make a backup branch:
-If you forgot to make a backup branch::
+.. code-block:: bash
# Look at the reflog of the branch
- git reflog show cool-feature
+ $ git reflog show cool-feature
8630830 cool-feature@{0}: commit: BUG: io: close file handles immediately
278dd2a cool-feature@{1}: rebase finished: refs/heads/my-feature-branch onto 11ee694744f2552d
@@ -677,7 +761,7 @@ If you forgot to make a backup branch::
...
# Reset the branch to where it was before the botched rebase
- git reset --hard cool-feature@{2}
+ $ git reset --hard cool-feature@{2}
Otherwise, googling the issue may be helpful (especially links to Stack
Overflow).
@@ -696,9 +780,11 @@ made several false starts you would like the posterity not to see.
This can be done via *interactive rebasing*.
-Suppose that the commit history looks like this::
+Suppose that the commit history looks like this:
+
+.. code-block:: bash
- git log --oneline
+ $ git log --oneline
eadc391 Fix some remaining bugs
a815645 Modify it so that it works
2dec1ac Fix a few bugs + disable
@@ -713,12 +799,14 @@ want to make the following changes:
* Rewrite the commit message for ``13d7934`` to something more sensible.
* Combine the commits ``2dec1ac``, ``a815645``, ``eadc391`` into a single one.
-We do as follows::
+We do as follows:
+
+.. code-block:: bash
# make a backup of the current state
- git branch tmp HEAD
+ $ git branch tmp HEAD
# interactive rebase
- git rebase -i 6ad92e5
+ $ git rebase -i 6ad92e5
This will open an editor with the following text in it::
@@ -773,14 +861,18 @@ Fetching a pull request
^^^^^^^^^^^^^^^^^^^^^^^
To fetch a pull request on the main repository to your local working
-directory as a new branch, just do::
+directory as a new branch, just do:
+
+.. code-block:: bash
- git fetch upstream pull/<pull request number>/head:<local-branch>
+ $ git fetch upstream pull/<pull request number>/head:<local-branch>
As an example, to pull the realtime pull request which has a url
-``https://github.com/mne-tools/mne-python/pull/615/``, do::
+``https://github.com/mne-tools/mne-python/pull/615/``, do:
- git fetch upstream pull/615/head:realtime
+.. code-block:: bash
+
+ $ git fetch upstream pull/615/head:realtime
If you want to fetch a pull request to your own fork, replace
``upstream`` with ``origin``. That's it!
@@ -794,33 +886,6 @@ The builds when the pull request is in `WIP` state can be safely skipped. The im
This will help prevent clogging up Travis and Appveyor and also save the environment.
-Documentation
--------------
-
-Adding an example to example gallery
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-Add the example to the correct subfolder in the ``examples/`` directory and
-prefix the file with ``plot_``. To make sure that the example renders correctly,
-run ``make html`` in the ``doc/`` folder
-
-Building a subset of examples
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-To build only a subset of examples, it is possible to provide a regular expression
-which searches on the full pathname of the file. For example, you can do::
-
- make html_dev-pattern PATTERN='/decoding/plot_'
-
-It will run only the examples in the ``decoding`` folder. Consult the `sphinx gallery documentation`_
-for more details.
-
-Editing \*.rst files
-^^^^^^^^^^^^^^^^^^^^
-
-These are reStructuredText files. Consult the `Sphinx documentation`_ to learn
-more about editing them.
-
.. _troubleshooting:
Troubleshooting
@@ -845,12 +910,15 @@ restart the ipython kernel.
ICE default IO error handler doing an exit()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-If the make test command fails with the error ``ICE default IO error
-handler doing an exit()``, try backing up or removing .ICEauthority::
+If the make test command fails with the error
+``ICE default IO error handler doing an exit()``, try backing up or removing
+.ICEauthority:
- mv ~/.ICEauthority ~/.ICEauthority.bak
+.. code-block:: bash
-.. include:: links.inc
+ $ mv ~/.ICEauthority ~/.ICEauthority.bak
+.. include:: links.inc
.. _Sphinx documentation: http://sphinx-doc.org/rest.html
-.. _sphinx gallery documentation: http://sphinx-gallery.readthedocs.org/en/latest/advanced_configuration.html
\ No newline at end of file
+.. _sphinx gallery documentation: http://sphinx-gallery.readthedocs.org/en/latest/advanced_configuration.html
+.. _NumPy docstring standard: https://github.com/numpy/numpy/blob/master/doc/HOWTO_DOCUMENT.rst.txt
diff --git a/doc/faq.rst b/doc/faq.rst
index 9424b1c..469e669 100644
--- a/doc/faq.rst
+++ b/doc/faq.rst
@@ -84,8 +84,7 @@ If you want to write your own data to disk (e.g., subject behavioral
scores), we strongly recommend using `h5io <https://github.com/h5io/h5io>`_,
which is based on the
`HDF5 format <https://en.wikipedia.org/wiki/Hierarchical_Data_Format>`_ and
-`h5py <http://www.h5py.org/>`_,
-to save data in a fast, future-compatible, standard format.
+h5py_, to save data in a fast, future-compatible, standard format.
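As a quick, hedged sketch of this workflow (the file name and payload below
are made up for illustration):

.. code-block:: python

    from h5io import write_hdf5, read_hdf5

    # save a nested Python structure (dicts, lists, arrays) to HDF5
    write_hdf5('subject_scores.h5', dict(subject='sample', scores=[0.8, 0.9]))
    data = read_hdf5('subject_scores.h5')  # reads back the same structure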
Resampling and decimating data
diff --git a/doc/getting_started.rst b/doc/getting_started.rst
index ada2ab3..06d0648 100644
--- a/doc/getting_started.rst
+++ b/doc/getting_started.rst
@@ -24,7 +24,7 @@ visualization, and analysis.
can read and convert CTF, BTI/4D, KIT and various EEG formats to
FIF (see :ref:`IO functions <ch_convert>`).
- If you have being using MNE-C, there is no need to convert your fif
+ If you have been using MNE-C, there is no need to convert your fif
files to a new system or database -- MNE-Python works nicely with
the historical fif files.
diff --git a/doc/index.rst b/doc/index.rst
index 8f7c4a5..88966cf 100644
--- a/doc/index.rst
+++ b/doc/index.rst
@@ -1,5 +1,7 @@
.. title:: MNE
+.. include:: links.inc
+
.. raw:: html
<div class="container"><div class="row">
@@ -32,11 +34,10 @@ providing comprehensive tools and workflows for
6. Applying machine learning algorithms
7. Visualization of sensor- and source-space data
-MNE includes a comprehensive `Python <https://www.python.org/>`_ package
-supplemented by tools compiled from C code for the LINUX and Mac OSX
-operating systems, as well as a MATLAB toolbox.
+MNE includes a comprehensive Python_ package supplemented by tools compiled
+from C code for the LINUX and Mac OSX operating systems, as well as a MATLAB toolbox.
-**From raw data to source estimates in about 30 lines of code** (:ref:`try it yourself! <getting_started>`):
+**From raw data to source estimates in about 30 lines of code** (Try it :ref:`by installing it <getting_started>` or `in an experimental online demo <http://mybinder.org/repo/mne-tools/mne-binder/notebooks/plot_introduction.ipynb>`_!):
.. code:: python
@@ -82,6 +83,10 @@ Direct financial support for the project has been provided by:
Institute.
- (FR) IDEX Paris-Saclay, ANR-11-IDEX-0003-02, via the
`Center for Data Science <http://www.datascience-paris-saclay.fr/>`_.
+- (FR) European Research Council (ERC) Starting Grant (ERC-YStG-263584).
+- (FR) French National Research Agency (ANR-14-NEUC-0002-01).
+- (FR) European Research Council (ERC) Starting Grant (ERC-YStG-676943).
+- Amazon Web Services - Research Grant issued to Denis A. Engemann
.. raw:: html
@@ -123,7 +128,7 @@ Direct financial support for the project has been provided by:
<h2>Community</h2>
-* `Analysis talk: join the MNE mailing list <MNE mailing list>`_
+* Analysis talk: join the `MNE mailing list`_
* `Feature requests and bug reports on GitHub <https://github.com/mne-tools/mne-python/issues/>`_
diff --git a/doc/manual/appendix/bem_model.rst b/doc/manual/appendix/bem_model.rst
index 1e13f48..8fdb0eb 100644
--- a/doc/manual/appendix/bem_model.rst
+++ b/doc/manual/appendix/bem_model.rst
@@ -78,9 +78,9 @@ following steps:
- Inspecting the meshes with tkmedit, see :ref:`BABHJBED`.
.. note:: Different methods can be employed for the creation of the
- individual surfaces. For example, it may turn out that the
+ individual surfaces. For example, it may turn out that the
watershed algorithm produces a better quality skin surface than
- the segmentation approach based on the FLASH images. If this is
+ the segmentation approach based on the FLASH images. If this is
the case, ``outer_skin.surf`` can be set to point to the corresponding
watershed output file while the other surfaces can be picked from
the FLASH segmentation data.
@@ -159,6 +159,12 @@ Before running mne_flash_bem do the following:
- ``ln -s`` <*FLASH 30 series dir*> ``flash30``
+- Some partition formats (e.g. FAT32) do not support symbolic links. In this case, copy the files to the appropriate series directories:
+
+ - ``cp`` <*FLASH 5 series dir*> ``flash05``
+
+ - ``cp`` <*FLASH 30 series dir*> ``flash30``
+
- Set the ``SUBJECTS_DIR`` and ``SUBJECT`` environment
variables
diff --git a/doc/manual/cookbook.rst b/doc/manual/cookbook.rst
index aca9131..fab76e1 100644
--- a/doc/manual/cookbook.rst
+++ b/doc/manual/cookbook.rst
@@ -384,8 +384,6 @@ ways:
stationary with respect to background brain activity. This can also
use :func:`mne.compute_raw_covariance`.
-See :ref:`covariance` for more information.
-
.. _CIHCFJEI:
Calculating the inverse operator
diff --git a/doc/manual/decoding.rst b/doc/manual/decoding.rst
index ddd584c..b1a0baf 100644
--- a/doc/manual/decoding.rst
+++ b/doc/manual/decoding.rst
@@ -27,6 +27,10 @@ Scikit-learn API enforces the requirement that data arrays must be 2D. A common
To recover the original 3D data, an ``inverse_transform`` can be used. The ``epochs_vectorizer`` is particularly useful when constructing a pipeline object (used mainly for parameter search and cross validation). The ``epochs_vectorizer`` is the first estimator in the pipeline, enabling downstream estimators to be the more advanced ones implemented in Scikit-learn.
+Vectorizer
+^^^^^^^^^^
+The Scikit-learn API provides functionality to chain transformers and estimators by using :class:`sklearn.pipeline.Pipeline`. We can construct decoding pipelines and perform cross-validation and grid-search. However, scikit-learn transformers and estimators generally expect 2D data (n_samples * n_features), whereas MNE transformers typically output data with a higher dimensionality (e.g. n_samples * n_channels * n_frequencies * n_times). A Vectorizer therefore needs to be applied between the [...]
+
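A minimal, hypothetical sketch of this pattern (random placeholder data, and
an arbitrarily chosen scikit-learn classifier):

.. code-block:: python

    import numpy as np
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import StandardScaler
    from sklearn.linear_model import LogisticRegression
    from mne.decoding import Vectorizer

    X = np.random.randn(30, 5, 100)  # n_epochs x n_channels x n_times
    y = np.random.randint(0, 2, 30)  # binary labels, one per epoch

    # Vectorizer flattens the trailing dimensions so that downstream
    # estimators see the 2D (n_samples, n_features) data they expect
    clf = make_pipeline(Vectorizer(), StandardScaler(), LogisticRegression())
    clf.fit(X, y)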
PSDEstimator
^^^^^^^^^^^^
This estimator computes the power spectral density (PSD) using the multitaper method. It takes a 3D array as input, converts it into 2D, and computes the PSD.
diff --git a/doc/manual/io.rst b/doc/manual/io.rst
index 16d6141..1352a0a 100644
--- a/doc/manual/io.rst
+++ b/doc/manual/io.rst
@@ -46,9 +46,11 @@ Neuromag Raw FIF files can be loaded using :func:`mne.io.read_raw_fif`.
``mne.io.read_raw_fif(..., allow_maxshield=True)``.
.. note::
- This file format also supports EEG data. An average reference will be added
- by default on reading EEG data. To change this behavior call the readers
- like this: ``mne.io.read_raw_fif(..., add_eeg_ref=False)``
+ This file format also supports EEG data. In 0.13, an average reference
+ will be added by default on reading EEG data. To change this behavior,
+ use the argument ``add_eeg_ref=False``, which will become the default
+ in 0.14. The argument will be removed in 0.15 in favor of
+ :func:`mne.set_eeg_reference` and :meth:`mne.io.Raw.set_eeg_reference`.
Importing 4-D Neuroimaging / BTI data
@@ -349,6 +351,20 @@ to the fif format with help of the :ref:`mne_eximia2fiff` script.
It creates a BrainVision ``vhdr`` file and calls :ref:`mne_brain_vision2fiff`.
+Setting EEG references
+######################
+
+The preferred method for applying an EEG reference in MNE is
+:func:`mne.set_eeg_reference`, or equivalent instance methods like
+:meth:`raw.set_eeg_reference() <mne.io.Raw.set_eeg_reference>`. By default,
+an average reference is used. Instead of applying the average reference to
+the data directly, an average EEG reference projector is created that is
+applied like any other SSP projection operator.
+
+There are also other functions that can be useful for other referencing
+operations. See :func:`mne.set_bipolar_reference` and
+:func:`mne.add_reference_channels` for more information.
+
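A hedged sketch of typical usage (the file name here is hypothetical):

.. code-block:: python

    import mne

    raw = mne.io.read_raw_fif('sample_raw.fif', preload=True)
    raw.set_eeg_reference()  # adds an average-reference SSP projector
    raw.apply_proj()         # applies the projector to the data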
Reading Electrode locations and Headshapes for EEG recordings
#############################################################
@@ -373,7 +389,7 @@ Arbitrary (e.g., simulated or manually read in) raw data can be constructed
from memory by making use of :class:`mne.io.RawArray`, :class:`mne.EpochsArray`
or :class:`mne.EvokedArray` in combination with :func:`mne.create_info`.
-This functionality is illustrated in :ref:`example_io_plot_objects_from_arrays.py` .
+This functionality is illustrated in :ref:`sphx_glr_auto_examples_io_plot_objects_from_arrays.py`.
Using 3rd party libraries such as NEO (https://pythonhosted.org/neo/) in combination
with these functions, abundant electrophysiological file formats can be easily loaded
into MNE.
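For instance, a minimal sketch with made-up channel names and sampling rate:

.. code-block:: python

    import numpy as np
    import mne

    sfreq = 100.  # sampling rate in Hz (made up)
    data = np.random.randn(2, int(sfreq * 5)) * 1e-6  # 2 channels, 5 s, in V
    info = mne.create_info(ch_names=['EEG 001', 'EEG 002'], sfreq=sfreq,
                           ch_types='eeg')
    raw = mne.io.RawArray(data, info)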
diff --git a/doc/manual/pitfalls.rst b/doc/manual/pitfalls.rst
index b0e43e6..c2c400a 100644
--- a/doc/manual/pitfalls.rst
+++ b/doc/manual/pitfalls.rst
@@ -8,19 +8,6 @@
Pitfalls
########
-Evoked Arithmetic
-=================
-
-Two evoked objects can be contrasted using::
-
- >>> evoked = evoked_cond1 - evoked_cond2
-
-Note, however that the number of trials used to obtain the averages for
-``evoked_cond1`` and ``evoked_cond2`` are taken into account when computing
-``evoked``. That is, what you get is a weighted average, not a simple
-element-by-element subtraction. To do a uniform (not weighted) average, use
-the function :func:`mne.combine_evoked`.
-
Float64 vs float32
==================
diff --git a/doc/manual/source_localization/inverse.rst b/doc/manual/source_localization/inverse.rst
index 6e2d866..6e383e0 100644
--- a/doc/manual/source_localization/inverse.rst
+++ b/doc/manual/source_localization/inverse.rst
@@ -184,7 +184,7 @@ Using the UNIX tools :ref:`mne_inverse_operator`, the values
:math:`\varepsilon_k` can be adjusted with the regularization options
``--magreg`` , ``--gradreg`` , and ``--eegreg`` specified at the time of the
inverse operator decomposition, see :ref:`inverse_operator`. The convenience script
-:ref:`mne_do_inverse_solution` has the ``--magreg`` and ``--gradreg`` combined to
+:ref:`mne_do_inverse_operator` has the ``--magreg`` and ``--gradreg`` combined to
a single option, ``--megreg`` , see :ref:`CIHCFJEI`.
Suggested range of values for :math:`\varepsilon_k` is :math:`0.05 \dotso 0.2`.
diff --git a/doc/python_reference.rst b/doc/python_reference.rst
index 649d69d..2b7292f 100644
--- a/doc/python_reference.rst
+++ b/doc/python_reference.rst
@@ -4,10 +4,12 @@
Python API Reference
====================
-This is the classes and functions reference of mne-python. Functions are
+This is the classes and functions reference of MNE-Python. Functions are
grouped thematically by analysis stage. Functions and classes that are not
below a module heading are found in the :py:mod:`mne` namespace.
+MNE-Python also provides multiple command-line scripts that can be called
+directly from a terminal; see :ref:`python_commands`.
.. contents::
:local:
@@ -27,9 +29,11 @@ Classes
io.RawFIF
io.RawArray
Annotations
+ AcqParserFIF
Epochs
Evoked
SourceSpaces
+ Forward
SourceEstimate
VolSourceEstimate
MixedSourceEstimate
@@ -166,6 +170,7 @@ Functions:
read_source_spaces
read_surface
read_trans
+ read_tri
save_stc_as_volume
write_labels_to_annot
write_bem_solution
@@ -179,7 +184,7 @@ Functions:
write_source_spaces
write_surface
write_trans
-
+ io.read_info
Creating data objects from arrays
=================================
@@ -325,8 +330,10 @@ Functions:
plot_evoked_joint
plot_evoked_field
plot_evoked_white
+ plot_compare_evokeds
plot_ica_sources
plot_ica_components
+ plot_ica_properties
plot_ica_scores
plot_ica_overlay
plot_epochs_image
@@ -434,10 +441,11 @@ Functions:
maxwell_filter
read_ica
run_ica
+ corrmap
EEG referencing:
-.. currentmodule:: mne.io
+.. currentmodule:: mne
.. autosummary::
:toctree: generated/
@@ -461,6 +469,8 @@ EEG referencing:
band_pass_filter
construct_iir_filter
+ estimate_ringing_samples
+ filter_data
high_pass_filter
low_pass_filter
notch_filter
@@ -526,6 +536,7 @@ Events
combine_event_ids
equalize_epoch_counts
+
Sensor Space Data
=================
@@ -744,6 +755,18 @@ Functions:
fit_dipole
+:py:mod:`mne.dipole`:
+
+.. currentmodule:: mne.dipole
+
+Functions:
+
+.. autosummary::
+ :toctree: generated/
+ :template: function.rst
+
+ get_phantom_dipoles
+
Source Space Data
=================
@@ -793,6 +816,7 @@ Classes:
:template: class.rst
AverageTFR
+ EpochsTFR
Functions that operate on mne-python objects:
@@ -800,7 +824,7 @@ Functions that operate on mne-python objects:
:toctree: generated/
:template: function.rst
- compute_epochs_csd
+ csd_epochs
psd_welch
psd_multitaper
fit_iir_model_raw
@@ -816,10 +840,10 @@ Functions that operate on ``np.ndarray`` objects:
:toctree: generated/
:template: function.rst
+ csd_array
cwt_morlet
dpss_windows
morlet
- multitaper_psd
single_trial_power
stft
istft
diff --git a/doc/sphinxext/gen_commands.py b/doc/sphinxext/gen_commands.py
index 24d2f9b..cf26750 100644
--- a/doc/sphinxext/gen_commands.py
+++ b/doc/sphinxext/gen_commands.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
+from __future__ import print_function
+
import os
import glob
from os import path as op
@@ -53,8 +55,9 @@ def generate_commands_rst(app):
out_dir = op.abspath(op.join(op.dirname(__file__), '..', 'generated'))
out_fname = op.join(out_dir, 'commands.rst')
- command_path = op.join(os.path.dirname(__file__), '..', '..', 'mne', 'commands')
- print(command_path)
+ command_path = op.join(os.path.dirname(__file__), '..', '..', 'mne',
+ 'commands')
+ print('Generating commands for: %s ... ' % command_path, end='')
fnames = glob.glob(op.join(command_path, 'mne_*.py'))
with open(out_fname, 'w') as f:
@@ -63,9 +66,9 @@ def generate_commands_rst(app):
cmd_name = op.basename(fname)[:-3]
output, _ = run_subprocess(['python', fname, '--help'])
- f.write(command_rst % (cmd_name, cmd_name.replace('mne_', 'mne '), output))
-
- print('Done')
+ f.write(command_rst % (cmd_name, cmd_name.replace('mne_', 'mne '),
+ output))
+ print('[Done]')
# This is useful for testing/iterating to see what the result looks like
diff --git a/doc/sphinxext/numpy_ext/__init__.py b/doc/sphinxext/numpy_ext/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/doc/sphinxext/numpy_ext/docscrape.py b/doc/sphinxext/numpy_ext/docscrape.py
deleted file mode 100644
index 5e01fd8..0000000
--- a/doc/sphinxext/numpy_ext/docscrape.py
+++ /dev/null
@@ -1,512 +0,0 @@
-"""Extract reference documentation from the NumPy source tree.
-
-"""
-
-import inspect
-import textwrap
-import re
-import pydoc
-from warnings import warn, catch_warnings
-# Try Python 2 first, otherwise load from Python 3
-try:
- from StringIO import StringIO
-except:
- from io import StringIO
-
-
-class Reader(object):
- """A line-based string reader.
-
- """
- def __init__(self, data):
- """
- Parameters
- ----------
- data : str
- String with lines separated by '\n'.
-
- """
- if isinstance(data, list):
- self._str = data
- else:
- self._str = data.split('\n') # store string as list of lines
-
- self.reset()
-
- def __getitem__(self, n):
- return self._str[n]
-
- def reset(self):
- self._l = 0 # current line nr
-
- def read(self):
- if not self.eof():
- out = self[self._l]
- self._l += 1
- return out
- else:
- return ''
-
- def seek_next_non_empty_line(self):
- for l in self[self._l:]:
- if l.strip():
- break
- else:
- self._l += 1
-
- def eof(self):
- return self._l >= len(self._str)
-
- def read_to_condition(self, condition_func):
- start = self._l
- for line in self[start:]:
- if condition_func(line):
- return self[start:self._l]
- self._l += 1
- if self.eof():
- return self[start:self._l + 1]
- return []
-
- def read_to_next_empty_line(self):
- self.seek_next_non_empty_line()
-
- def is_empty(line):
- return not line.strip()
- return self.read_to_condition(is_empty)
-
- def read_to_next_unindented_line(self):
- def is_unindented(line):
- return (line.strip() and (len(line.lstrip()) == len(line)))
- return self.read_to_condition(is_unindented)
-
- def peek(self, n=0):
- if self._l + n < len(self._str):
- return self[self._l + n]
- else:
- return ''
-
- def is_empty(self):
- return not ''.join(self._str).strip()
-
-
-class NumpyDocString(object):
- def __init__(self, docstring, config={}):
- docstring = textwrap.dedent(docstring).split('\n')
-
- self._doc = Reader(docstring)
- self._parsed_data = {
- 'Signature': '',
- 'Summary': [''],
- 'Extended Summary': [],
- 'Parameters': [],
- 'Returns': [],
- 'Raises': [],
- 'Warns': [],
- 'Other Parameters': [],
- 'Attributes': [],
- 'Methods': [],
- 'See Also': [],
- 'Notes': [],
- 'Warnings': [],
- 'References': '',
- 'Examples': '',
- 'index': {}
- }
-
- self._parse()
-
- def __getitem__(self, key):
- return self._parsed_data[key]
-
- def __setitem__(self, key, val):
- if key not in self._parsed_data:
- warn("Unknown section %s" % key)
- else:
- self._parsed_data[key] = val
-
- def _is_at_section(self):
- self._doc.seek_next_non_empty_line()
-
- if self._doc.eof():
- return False
-
- l1 = self._doc.peek().strip() # e.g. Parameters
-
- if l1.startswith('.. index::'):
- return True
-
- l2 = self._doc.peek(1).strip() # ---------- or ==========
- return l2.startswith('-' * len(l1)) or l2.startswith('=' * len(l1))
-
- def _strip(self, doc):
- i = 0
- j = 0
- for i, line in enumerate(doc):
- if line.strip():
- break
-
- for j, line in enumerate(doc[::-1]):
- if line.strip():
- break
-
- return doc[i:len(doc) - j]
-
- def _read_to_next_section(self):
- section = self._doc.read_to_next_empty_line()
-
- while not self._is_at_section() and not self._doc.eof():
- if not self._doc.peek(-1).strip(): # previous line was empty
- section += ['']
-
- section += self._doc.read_to_next_empty_line()
-
- return section
-
- def _read_sections(self):
- while not self._doc.eof():
- data = self._read_to_next_section()
- name = data[0].strip()
-
- if name.startswith('..'): # index section
- yield name, data[1:]
- elif len(data) < 2:
- yield StopIteration
- else:
- yield name, self._strip(data[2:])
-
- def _parse_param_list(self, content):
- r = Reader(content)
- params = []
- while not r.eof():
- header = r.read().strip()
- if ' : ' in header:
- arg_name, arg_type = header.split(' : ')[:2]
- else:
- arg_name, arg_type = header, ''
-
- desc = r.read_to_next_unindented_line()
- desc = dedent_lines(desc)
-
- params.append((arg_name, arg_type, desc))
-
- return params
-
- _name_rgx = re.compile(r"^\s*(:(?P<role>\w+):`(?P<name>[a-zA-Z0-9_.-]+)`|"
- r" (?P<name2>[a-zA-Z0-9_.-]+))\s*", re.X)
-
- def _parse_see_also(self, content):
- """
- func_name : Descriptive text
- continued text
- another_func_name : Descriptive text
- func_name1, func_name2, :meth:`func_name`, func_name3
-
- """
- items = []
-
- def parse_item_name(text):
- """Match ':role:`name`' or 'name'"""
- m = self._name_rgx.match(text)
- if m:
- g = m.groups()
- if g[1] is None:
- return g[3], None
- else:
- return g[2], g[1]
- raise ValueError("%s is not a item name" % text)
-
- def push_item(name, rest):
- if not name:
- return
- name, role = parse_item_name(name)
- items.append((name, list(rest), role))
- del rest[:]
-
- current_func = None
- rest = []
-
- for line in content:
- if not line.strip():
- continue
-
- m = self._name_rgx.match(line)
- if m and line[m.end():].strip().startswith(':'):
- push_item(current_func, rest)
- current_func, line = line[:m.end()], line[m.end():]
- rest = [line.split(':', 1)[1].strip()]
- if not rest[0]:
- rest = []
- elif not line.startswith(' '):
- push_item(current_func, rest)
- current_func = None
- if ',' in line:
- for func in line.split(','):
- push_item(func, [])
- elif line.strip():
- current_func = line
- elif current_func is not None:
- rest.append(line.strip())
- push_item(current_func, rest)
- return items
-
- def _parse_index(self, section, content):
- """
- .. index: default
- :refguide: something, else, and more
-
- """
- def strip_each_in(lst):
- return [s.strip() for s in lst]
-
- out = {}
- section = section.split('::')
- if len(section) > 1:
- out['default'] = strip_each_in(section[1].split(','))[0]
- for line in content:
- line = line.split(':')
- if len(line) > 2:
- out[line[1]] = strip_each_in(line[2].split(','))
- return out
-
- def _parse_summary(self):
- """Grab signature (if given) and summary"""
- if self._is_at_section():
- return
-
- summary = self._doc.read_to_next_empty_line()
- summary_str = " ".join([s.strip() for s in summary]).strip()
- if re.compile('^([\w., ]+=)?\s*[\w\.]+\(.*\)$').match(summary_str):
- self['Signature'] = summary_str
- if not self._is_at_section():
- self['Summary'] = self._doc.read_to_next_empty_line()
- else:
- self['Summary'] = summary
-
- if not self._is_at_section():
- self['Extended Summary'] = self._read_to_next_section()
-
- def _parse(self):
- self._doc.reset()
- self._parse_summary()
-
- for (section, content) in self._read_sections():
- if not section.startswith('..'):
- section = ' '.join([s.capitalize()
- for s in section.split(' ')])
- if section in ('Parameters', 'Attributes', 'Methods',
- 'Returns', 'Raises', 'Warns'):
- self[section] = self._parse_param_list(content)
- elif section.startswith('.. index::'):
- self['index'] = self._parse_index(section, content)
- elif section == 'See Also':
- self['See Also'] = self._parse_see_also(content)
- else:
- self[section] = content
-
- # string conversion routines
-
- def _str_header(self, name, symbol='-'):
- return [name, len(name) * symbol]
-
- def _str_indent(self, doc, indent=4):
- out = []
- for line in doc:
- out += [' ' * indent + line]
- return out
-
- def _str_signature(self):
- if self['Signature']:
- return [self['Signature'].replace('*', '\*')] + ['']
- else:
- return ['']
-
- def _str_summary(self):
- if self['Summary']:
- return self['Summary'] + ['']
- else:
- return []
-
- def _str_extended_summary(self):
- if self['Extended Summary']:
- return self['Extended Summary'] + ['']
- else:
- return []
-
- def _str_param_list(self, name):
- out = []
- if self[name]:
- out += self._str_header(name)
- for param, param_type, desc in self[name]:
- out += ['%s : %s' % (param, param_type)]
- out += self._str_indent(desc)
- out += ['']
- return out
-
- def _str_section(self, name):
- out = []
- if self[name]:
- out += self._str_header(name)
- out += self[name]
- out += ['']
- return out
-
- def _str_see_also(self, func_role):
- if not self['See Also']:
- return []
- out = []
- out += self._str_header("See Also")
- last_had_desc = True
- for func, desc, role in self['See Also']:
- if role:
- link = ':%s:`%s`' % (role, func)
- elif func_role:
- link = ':%s:`%s`' % (func_role, func)
- else:
- link = "`%s`_" % func
- if desc or last_had_desc:
- out += ['']
- out += [link]
- else:
- out[-1] += ", %s" % link
- if desc:
- out += self._str_indent([' '.join(desc)])
- last_had_desc = True
- else:
- last_had_desc = False
- out += ['']
- return out
-
- def _str_index(self):
- idx = self['index']
- out = []
- out += ['.. index:: %s' % idx.get('default', '')]
- for section, references in idx.iteritems():
- if section == 'default':
- continue
- out += [' :%s: %s' % (section, ', '.join(references))]
- return out
-
- def __str__(self, func_role=''):
- out = []
- out += self._str_signature()
- out += self._str_summary()
- out += self._str_extended_summary()
- for param_list in ('Parameters', 'Returns', 'Raises'):
- out += self._str_param_list(param_list)
- out += self._str_section('Warnings')
- out += self._str_see_also(func_role)
- for s in ('Notes', 'References', 'Examples'):
- out += self._str_section(s)
- for param_list in ('Attributes', 'Methods'):
- out += self._str_param_list(param_list)
- out += self._str_index()
- return '\n'.join(out)
-
-
-def indent(str, indent=4):
- indent_str = ' ' * indent
- if str is None:
- return indent_str
- lines = str.split('\n')
- return '\n'.join(indent_str + l for l in lines)
-
-
-def dedent_lines(lines):
- """Deindent a list of lines maximally"""
- return textwrap.dedent("\n".join(lines)).split("\n")
-
-
-def header(text, style='-'):
- return text + '\n' + style * len(text) + '\n'
-
-
-class FunctionDoc(NumpyDocString):
- def __init__(self, func, role='func', doc=None, config={}):
- self._f = func
- self._role = role # e.g. "func" or "meth"
-
- if doc is None:
- if func is None:
- raise ValueError("No function or docstring given")
- doc = inspect.getdoc(func) or ''
- NumpyDocString.__init__(self, doc)
-
- if not self['Signature'] and func is not None:
- func, func_name = self.get_func()
- try:
- # try to read signature
- with catch_warnings(record=True):
- argspec = inspect.getargspec(func)
- argspec = inspect.formatargspec(*argspec)
- argspec = argspec.replace('*', '\*')
- signature = '%s%s' % (func_name, argspec)
- except TypeError as e:
- signature = '%s()' % func_name
- self['Signature'] = signature
-
- def get_func(self):
- func_name = getattr(self._f, '__name__', self.__class__.__name__)
- if inspect.isclass(self._f):
- func = getattr(self._f, '__call__', self._f.__init__)
- else:
- func = self._f
- return func, func_name
-
- def __str__(self):
- out = ''
-
- func, func_name = self.get_func()
- signature = self['Signature'].replace('*', '\*')
-
- roles = {'func': 'function',
- 'meth': 'method'}
-
- if self._role:
- if not roles.has_key(self._role):
- print("Warning: invalid role %s" % self._role)
- out += '.. %s:: %s\n \n\n' % (roles.get(self._role, ''),
- func_name)
-
- out += super(FunctionDoc, self).__str__(func_role=self._role)
- return out
-
-
-class ClassDoc(NumpyDocString):
- def __init__(self, cls, doc=None, modulename='', func_doc=FunctionDoc,
- config=None):
- if not inspect.isclass(cls) and cls is not None:
- raise ValueError("Expected a class or None, but got %r" % cls)
- self._cls = cls
-
- if modulename and not modulename.endswith('.'):
- modulename += '.'
- self._mod = modulename
-
- if doc is None:
- if cls is None:
- raise ValueError("No class or documentation string given")
- doc = pydoc.getdoc(cls)
-
- NumpyDocString.__init__(self, doc)
-
- if config is not None and config.get('show_class_members', True):
- if not self['Methods']:
- self['Methods'] = [(name, '', '')
- for name in sorted(self.methods)]
- if not self['Attributes']:
- self['Attributes'] = [(name, '', '')
- for name in sorted(self.properties)]
-
- @property
- def methods(self):
- if self._cls is None:
- return []
- return [name for name, func in inspect.getmembers(self._cls)
- if not name.startswith('_') and callable(func)]
-
- @property
- def properties(self):
- if self._cls is None:
- return []
- return [name for name, func in inspect.getmembers(self._cls)
- if not name.startswith('_') and func is None]
diff --git a/doc/sphinxext/numpy_ext/docscrape_sphinx.py b/doc/sphinxext/numpy_ext/docscrape_sphinx.py
deleted file mode 100644
index ca28300..0000000
--- a/doc/sphinxext/numpy_ext/docscrape_sphinx.py
+++ /dev/null
@@ -1,240 +0,0 @@
-import re
-import inspect
-import textwrap
-import pydoc
-from .docscrape import NumpyDocString
-from .docscrape import FunctionDoc
-from .docscrape import ClassDoc
-
-
-class SphinxDocString(NumpyDocString):
- def __init__(self, docstring, config=None):
- config = {} if config is None else config
- self.use_plots = config.get('use_plots', False)
- NumpyDocString.__init__(self, docstring, config=config)
-
- # string conversion routines
- def _str_header(self, name, symbol='`'):
- return ['.. rubric:: ' + name, '']
-
- def _str_field_list(self, name):
- return [':' + name + ':']
-
- def _str_indent(self, doc, indent=4):
- out = []
- for line in doc:
- out += [' ' * indent + line]
- return out
-
- def _str_signature(self):
- return ['']
- if self['Signature']:
- return ['``%s``' % self['Signature']] + ['']
- else:
- return ['']
-
- def _str_summary(self):
- return self['Summary'] + ['']
-
- def _str_extended_summary(self):
- return self['Extended Summary'] + ['']
-
- def _str_param_list(self, name):
- out = []
- if self[name]:
- out += self._str_field_list(name)
- out += ['']
- for param, param_type, desc in self[name]:
- out += self._str_indent(['**%s** : %s' % (param.strip(),
- param_type)])
- out += ['']
- out += self._str_indent(desc, 8)
- out += ['']
- return out
-
- @property
- def _obj(self):
- if hasattr(self, '_cls'):
- return self._cls
- elif hasattr(self, '_f'):
- return self._f
- return None
-
- def _str_member_list(self, name):
- """
- Generate a member listing, autosummary:: table where possible,
- and a table where not.
-
- """
- out = []
- if self[name]:
- out += ['.. rubric:: %s' % name, '']
- prefix = getattr(self, '_name', '')
-
- if prefix:
- prefix = '~%s.' % prefix
-
- autosum = []
- others = []
- for param, param_type, desc in self[name]:
- param = param.strip()
- if not self._obj or hasattr(self._obj, param):
- autosum += [" %s%s" % (prefix, param)]
- else:
- others.append((param, param_type, desc))
-
- if autosum:
- # GAEL: Toctree commented out below because it creates
- # hundreds of sphinx warnings
- # out += ['.. autosummary::', ' :toctree:', '']
- out += ['.. autosummary::', '']
- out += autosum
-
- if others:
- maxlen_0 = max([len(x[0]) for x in others])
- maxlen_1 = max([len(x[1]) for x in others])
- hdr = "=" * maxlen_0 + " " + "=" * maxlen_1 + " " + "=" * 10
- fmt = '%%%ds %%%ds ' % (maxlen_0, maxlen_1)
- n_indent = maxlen_0 + maxlen_1 + 4
- out += [hdr]
- for param, param_type, desc in others:
- out += [fmt % (param.strip(), param_type)]
- out += self._str_indent(desc, n_indent)
- out += [hdr]
- out += ['']
- return out
-
- def _str_section(self, name):
- out = []
- if self[name]:
- out += self._str_header(name)
- out += ['']
- content = textwrap.dedent("\n".join(self[name])).split("\n")
- out += content
- out += ['']
- return out
-
- def _str_see_also(self, func_role):
- out = []
- if self['See Also']:
- see_also = super(SphinxDocString, self)._str_see_also(func_role)
- out = ['.. seealso::', '']
- out += self._str_indent(see_also[2:])
- return out
-
- def _str_warnings(self):
- out = []
- if self['Warnings']:
- out = ['.. warning::', '']
- out += self._str_indent(self['Warnings'])
- return out
-
- def _str_index(self):
- idx = self['index']
- out = []
- if len(idx) == 0:
- return out
-
- out += ['.. index:: %s' % idx.get('default', '')]
- for section, references in idx.iteritems():
- if section == 'default':
- continue
- elif section == 'refguide':
- out += [' single: %s' % (', '.join(references))]
- else:
- out += [' %s: %s' % (section, ','.join(references))]
- return out
-
- def _str_references(self):
- out = []
- if self['References']:
- out += self._str_header('References')
- if isinstance(self['References'], str):
- self['References'] = [self['References']]
- out.extend(self['References'])
- out += ['']
- # Latex collects all references to a separate bibliography,
- # so we need to insert links to it
- import sphinx # local import to avoid test dependency
- if sphinx.__version__ >= "0.6":
- out += ['.. only:: latex', '']
- else:
- out += ['.. latexonly::', '']
- items = []
- for line in self['References']:
- m = re.match(r'.. \[([a-z0-9._-]+)\]', line, re.I)
- if m:
- items.append(m.group(1))
- out += [' ' + ", ".join(["[%s]_" % item for item in items]), '']
- return out
-
- def _str_examples(self):
- examples_str = "\n".join(self['Examples'])
-
- if (self.use_plots and 'import matplotlib' in examples_str
- and 'plot::' not in examples_str):
- out = []
- out += self._str_header('Examples')
- out += ['.. plot::', '']
- out += self._str_indent(self['Examples'])
- out += ['']
- return out
- else:
- return self._str_section('Examples')
-
- def __str__(self, indent=0, func_role="obj"):
- out = []
- out += self._str_signature()
- out += self._str_index() + ['']
- out += self._str_summary()
- out += self._str_extended_summary()
- for param_list in ('Parameters', 'Returns', 'Raises', 'Attributes'):
- out += self._str_param_list(param_list)
- out += self._str_warnings()
- out += self._str_see_also(func_role)
- out += self._str_section('Notes')
- out += self._str_references()
- out += self._str_examples()
- for param_list in ('Methods',):
- out += self._str_member_list(param_list)
- out = self._str_indent(out, indent)
- return '\n'.join(out)
-
-
-class SphinxFunctionDoc(SphinxDocString, FunctionDoc):
- def __init__(self, obj, doc=None, config={}):
- self.use_plots = config.get('use_plots', False)
- FunctionDoc.__init__(self, obj, doc=doc, config=config)
-
-
-class SphinxClassDoc(SphinxDocString, ClassDoc):
- def __init__(self, obj, doc=None, func_doc=None, config={}):
- self.use_plots = config.get('use_plots', False)
- ClassDoc.__init__(self, obj, doc=doc, func_doc=None, config=config)
-
-
-class SphinxObjDoc(SphinxDocString):
- def __init__(self, obj, doc=None, config=None):
- self._f = obj
- SphinxDocString.__init__(self, doc, config=config)
-
-
-def get_doc_object(obj, what=None, doc=None, config={}):
- if what is None:
- if inspect.isclass(obj):
- what = 'class'
- elif inspect.ismodule(obj):
- what = 'module'
- elif callable(obj):
- what = 'function'
- else:
- what = 'object'
- if what == 'class':
- return SphinxClassDoc(obj, func_doc=SphinxFunctionDoc, doc=doc,
- config=config)
- elif what in ('function', 'method'):
- return SphinxFunctionDoc(obj, doc=doc, config=config)
- else:
- if doc is None:
- doc = pydoc.getdoc(obj)
- return SphinxObjDoc(obj, doc, config=config)
diff --git a/doc/sphinxext/numpy_ext/numpydoc.py b/doc/sphinxext/numpy_ext/numpydoc.py
deleted file mode 100644
index 6ff03e0..0000000
--- a/doc/sphinxext/numpy_ext/numpydoc.py
+++ /dev/null
@@ -1,192 +0,0 @@
-"""
-========
-numpydoc
-========
-
-Sphinx extension that handles docstrings in the Numpy standard format. [1]
-
-It will:
-
-- Convert Parameters etc. sections to field lists.
-- Convert See Also section to a See also entry.
-- Renumber references.
-- Extract the signature from the docstring, if it can't be determined
- otherwise.
-
-.. [1] http://projects.scipy.org/numpy/wiki/CodingStyleGuidelines#docstring-standard
-
-"""
-
-from __future__ import unicode_literals
-
-import sys # Only needed to check Python version
-import os
-import re
-import pydoc
-from .docscrape_sphinx import get_doc_object
-from .docscrape_sphinx import SphinxDocString
-import inspect
-
-
-def mangle_docstrings(app, what, name, obj, options, lines,
- reference_offset=[0]):
-
- cfg = dict(use_plots=app.config.numpydoc_use_plots,
- show_class_members=app.config.numpydoc_show_class_members)
-
- if what == 'module':
- # Strip top title
- title_re = re.compile(r'^\s*[#*=]{4,}\n[a-z0-9 -]+\n[#*=]{4,}\s*',
- re.I | re.S)
- lines[:] = title_re.sub('', "\n".join(lines)).split("\n")
- else:
- doc = get_doc_object(obj, what, "\n".join(lines), config=cfg)
- if sys.version_info[0] < 3:
- lines[:] = unicode(doc).splitlines()
- else:
- lines[:] = str(doc).splitlines()
-
- if app.config.numpydoc_edit_link and hasattr(obj, '__name__') and \
- obj.__name__:
- if hasattr(obj, '__module__'):
- v = dict(full_name="%s.%s" % (obj.__module__, obj.__name__))
- else:
- v = dict(full_name=obj.__name__)
- lines += [u'', u'.. htmlonly::', '']
- lines += [u' %s' % x for x in
- (app.config.numpydoc_edit_link % v).split("\n")]
-
- # replace reference numbers so that there are no duplicates
- references = []
- for line in lines:
- line = line.strip()
- m = re.match(r'^.. \[([a-z0-9_.-])\]', line, re.I)
- if m:
- references.append(m.group(1))
-
- # start renaming from the longest string, to avoid overwriting parts
- references.sort(key=lambda x: -len(x))
- if references:
- for i, line in enumerate(lines):
- for r in references:
- if re.match(r'^\d+$', r):
- new_r = "R%d" % (reference_offset[0] + int(r))
- else:
- new_r = u"%s%d" % (r, reference_offset[0])
- lines[i] = lines[i].replace(u'[%s]_' % r,
- u'[%s]_' % new_r)
- lines[i] = lines[i].replace(u'.. [%s]' % r,
- u'.. [%s]' % new_r)
-
- reference_offset[0] += len(references)
-
-
-def mangle_signature(app, what, name, obj,
- options, sig, retann):
- # Do not try to inspect classes that don't define `__init__`
- if (inspect.isclass(obj) and
- (not hasattr(obj, '__init__') or
- 'initializes x; see ' in pydoc.getdoc(obj.__init__))):
- return '', ''
-
- if not (callable(obj) or hasattr(obj, '__argspec_is_invalid_')):
- return
- if not hasattr(obj, '__doc__'):
- return
-
- doc = SphinxDocString(pydoc.getdoc(obj))
- if doc['Signature']:
- sig = re.sub("^[^(]*", "", doc['Signature'])
- return sig, ''
-
-
-def setup(app, get_doc_object_=get_doc_object):
- global get_doc_object
- get_doc_object = get_doc_object_
-
- if sys.version_info[0] < 3:
- app.connect(b'autodoc-process-docstring', mangle_docstrings)
- app.connect(b'autodoc-process-signature', mangle_signature)
- else:
- app.connect('autodoc-process-docstring', mangle_docstrings)
- app.connect('autodoc-process-signature', mangle_signature)
- app.add_config_value('numpydoc_edit_link', None, False)
- app.add_config_value('numpydoc_use_plots', None, False)
- app.add_config_value('numpydoc_show_class_members', True, True)
-
- # Extra mangling domains
- app.add_domain(NumpyPythonDomain)
- app.add_domain(NumpyCDomain)
-
-#-----------------------------------------------------------------------------
-# Docstring-mangling domains
-#-----------------------------------------------------------------------------
-
-try:
- import sphinx # lazy to avoid test dependency
-except ImportError:
- CDomain = PythonDomain = object
-else:
- from sphinx.domains.c import CDomain
- from sphinx.domains.python import PythonDomain
-
-
-class ManglingDomainBase(object):
- directive_mangling_map = {}
-
- def __init__(self, *a, **kw):
- super(ManglingDomainBase, self).__init__(*a, **kw)
- self.wrap_mangling_directives()
-
- def wrap_mangling_directives(self):
- for name, objtype in self.directive_mangling_map.items():
- self.directives[name] = wrap_mangling_directive(
- self.directives[name], objtype)
-
-
-class NumpyPythonDomain(ManglingDomainBase, PythonDomain):
- name = 'np'
- directive_mangling_map = {
- 'function': 'function',
- 'class': 'class',
- 'exception': 'class',
- 'method': 'function',
- 'classmethod': 'function',
- 'staticmethod': 'function',
- 'attribute': 'attribute',
- }
-
-
-class NumpyCDomain(ManglingDomainBase, CDomain):
- name = 'np-c'
- directive_mangling_map = {
- 'function': 'function',
- 'member': 'attribute',
- 'macro': 'function',
- 'type': 'class',
- 'var': 'object',
- }
-
-
-def wrap_mangling_directive(base_directive, objtype):
- class directive(base_directive):
- def run(self):
- env = self.state.document.settings.env
-
- name = None
- if self.arguments:
- m = re.match(r'^(.*\s+)?(.*?)(\(.*)?', self.arguments[0])
- name = m.group(2).strip()
-
- if not name:
- name = self.arguments[0]
-
- lines = list(self.content)
- mangle_docstrings(env.app, objtype, name, None, None, lines)
- # local import to avoid testing dependency
- from docutils.statemachine import ViewList
- self.content = ViewList(lines, self.content.parent)
-
- return base_directive.run(self)
-
- return directive
diff --git a/doc/tutorials.rst b/doc/tutorials.rst
index e9c1e6d..a67d6d8 100644
--- a/doc/tutorials.rst
+++ b/doc/tutorials.rst
@@ -43,12 +43,24 @@ For further reading:
.. raw:: html
<h2>Introduction to MNE and Python</h2>
-
+
.. toctree::
:maxdepth: 1
auto_tutorials/plot_python_intro.rst
tutorials/seven_stories_about_mne.rst
+ auto_tutorials/plot_introduction.rst
+
+.. container:: span box
+
+ .. raw:: html
+
+ <h2>Background information</h2>
+
+ .. toctree::
+ :maxdepth: 1
+
+ auto_tutorials/plot_background_filtering.rst
.. container:: span box
@@ -104,6 +116,7 @@ For further reading:
:maxdepth: 1
auto_tutorials/plot_object_raw.rst
+ auto_tutorials/plot_modifying_data_inplace.rst
auto_tutorials/plot_object_epochs.rst
auto_tutorials/plot_object_evoked.rst
auto_tutorials/plot_creating_data_structures.rst
@@ -123,6 +136,9 @@ For further reading:
auto_tutorials/plot_mne_dspm_source_localization.rst
auto_tutorials/plot_dipole_fit.rst
auto_tutorials/plot_brainstorm_auditory.rst
+ auto_tutorials/plot_brainstorm_phantom_ctf.rst
+ auto_tutorials/plot_brainstorm_phantom_elekta.rst
+ auto_tutorials/plot_point_spread.rst
.. container:: span box
diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 8a5ba70..0b30924 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -7,6 +7,251 @@ What's new
Note, we are now using links to highlight new functions and classes.
Please be sure to follow the examples below like :func:`mne.stats.f_mway_rm`, so the whats_new page will have a link to the function/class documentation.
+.. currentmodule:: mne
+
+.. _changes_0_13:
+
+Version 0.13
+------------
+
+Changelog
+~~~~~~~~~
+
+ - Add new class :class:`AcqParserFIF` to parse Elekta/Neuromag MEG acquisition info, allowing e.g. collecting epochs according to acquisition-defined averaging categories by `Jussi Nurminen`_
+
+ - Adds automatic determination of FIR filter parameters ``filter_length``, ``l_trans_bandwidth``, and ``h_trans_bandwidth`` and adds the ``phase`` argument, e.g. in :meth:`mne.io.Raw.filter`, by `Eric Larson`_
+
+ - Adds faster ``n_fft='auto'`` option to :meth:`mne.io.Raw.apply_hilbert` by `Eric Larson`_
+
+ - Adds new function :func:`mne.time_frequency.csd_array` to compute the cross-spectral density of multivariate signals stored in an array, by `Nick Foti`_
+
+ - Add order params 'selection' and 'position' for :func:`mne.viz.plot_raw` to allow plotting of specific brain regions by `Jaakko Leppakangas`_
+
+ - Added the ability to decimate :class:`mne.Evoked` objects with :func:`mne.Evoked.decimate` by `Eric Larson`_
+
+ - Add generic array-filtering function :func:`mne.filter.filter_data` by `Eric Larson`_
+
+ - :func:`mne.viz.plot_trans` now also shows head position indicators by `Christian Brodbeck`_
+
+ - Add label center of mass function :func:`mne.Label.center_of_mass` by `Eric Larson`_
+
+ - Added :func:`mne.viz.plot_ica_properties` that allows plotting of independent component properties similar to ``pop_prop`` in EEGLAB. Also :class:`mne.preprocessing.ICA` has :func:`mne.preprocessing.ICA.plot_properties` method now. Added by `Mikołaj Magnuski`_
+
+ - Add second-order sections (instead of ``(b, a)`` form) IIR filtering for reduced numerical error by `Eric Larson`_
+
+ - Add interactive colormap option to image plotting functions by `Jaakko Leppakangas`_
+
+ - Add support for the University of Maryland KIT system by `Christian Brodbeck`_
+
+ - Add support for \*.elp and \*.hsp files to the KIT2FIFF converter and :func:`mne.channels.read_dig_montage` by `Teon Brooks`_ and `Christian Brodbeck`_
+
+ - Add option to preview events in the KIT2FIFF GUI by `Christian Brodbeck`_
+
+ - Add approximation of size of :class:`io.Raw`, :class:`Epochs`, and :class:`Evoked` in :func:`repr` by `Eric Larson`_
+
+ - Add possibility to select a subset of sensors by lasso selector to :func:`mne.viz.plot_sensors` and :func:`mne.viz.plot_raw` when using order='selection' or order='position' by `Jaakko Leppakangas`_
+
+ - Add the option to plot brain surfaces and source spaces to :func:`viz.plot_bem` by `Christian Brodbeck`_
+
+ - Add the ``--filterchpi`` option to :ref:`mne browse_raw <gen_mne_browse_raw>`, by `Felix Raimundo`_
+
+ - Add the ``--no-decimate`` option to :ref:`mne make_scalp_surfaces <gen_mne_make_scalp_surfaces>` to skip the high-resolution surface decimation step, by `Eric Larson`_
+
+ - Add new class :class:`mne.decoding.EMS` to transform epochs with the event-matched spatial filters and add 'cv' parameter to :func:`mne.decoding.compute_ems`, by `Jean-Remi King`_
+
+ - Added :class:`mne.time_frequency.EpochsTFR` and average parameter in :func:`mne.time_frequency.tfr_morlet` and :func:`mne.time_frequency.tfr_multitaper` to compute time-frequency transforms on single trial epochs without averaging, by `Jean-Remi King`_ and `Alex Gramfort`_
+
+ - Added :class:`mne.decoding.TimeFrequency` to transform signals in scikit-learn pipelines, by `Jean-Remi King`_
+
+ - Added :class:`mne.decoding.UnsupervisedSpatialFilter` providing an interface for scikit-learn decomposition algorithms to be used with MNE data, by `Jean-Remi King`_ and `Asish Panda`_
+
+ - Added support for multiclass decoding in :class:`mne.decoding.CSP`, by `Jean-Remi King`_ and `Alexandre Barachant`_
+
+ - Components obtained from :class:`mne.preprocessing.ICA` are now sorted by explained variance, by `Mikołaj Magnuski`_
+
+ - Adding an EEG reference channel using :func:`mne.io.add_reference_channels` will now use its digitized location from the FIFF file, if present, by `Chris Bailey`_
+
+ - Added interactivity to :func:`mne.preprocessing.ICA.plot_components` - passing an instance of :class:`io.Raw` or :class:`Epochs` in the ``inst`` argument allows opening component properties by clicking on component topomaps, by `Mikołaj Magnuski`_
+
+ - Adds new function :func:`mne.viz.plot_compare_evokeds` to show multiple evoked time courses at a single location, or the mean over a ROI, or the GFP, automatically averaging and calculating a CI if multiple subjects are given, by `Jona Sassenhagen`_
+
+ - Added the ``transform_into`` parameter to :class:`mne.decoding.CSP` to retrieve the average power of each source or the time course of each source, by `Jean-Remi King`_
+
+ - Added support for reading MaxShield (IAS) evoked data (e.g., from the acquisition machine) in :func:`mne.read_evokeds` by `Eric Larson`_
+
+ - Added support for functional near-infrared spectroscopy (fNIRS) channels by `Jaakko Leppakangas`_
+
+BUG
+~~~
+
+ - Fixed a bug where selecting epochs using hierarchical event IDs (HIDs) was *and*-like instead of *or*-like. When doing e.g. ``epochs[('Auditory', 'Left')]``, previously all trials that contain ``'Auditory'`` *and* ``'Left'`` (like ``'Auditory/Left'``) would be selected, but now any conditions matching ``'Auditory'`` *or* ``'Left'`` will be selected (like ``'Auditory/Left'``, ``'Auditory/Right'``, and ``'Visual/Left'``). This is now consistent with how epoch selection was done witho [...]
+
+ - Fixed Infomax/Extended Infomax when the user provides an initial weights matrix by `Jair Montoya Martinez`_
+
+ - Fixed the default raw FIF writing buffer size to be 1 second instead of 10 seconds by `Eric Larson`_
+
+ - Fixed channel selection order when MEG channels do not come first in :func:`mne.preprocessing.maxwell_filter` by `Eric Larson`_
+
+ - Fixed color ranges to correspond to the colorbar when plotting several time instances with :func:`mne.viz.plot_evoked_topomap` by `Jaakko Leppakangas`_
+
+ - Added units to :func:`mne.io.read_raw_brainvision` for reading non-data channels and enabled the default behavior of inferring channel type by unit, by `Jaakko Leppakangas`_ and `Pablo-Arias`_
+
+ - Fixed minor bugs with :func:`mne.Epochs.resample` and :func:`mne.Epochs.decimate` by `Eric Larson`_
+
+ - Fixed a bug where duplicate vertices were not strictly checked by :func:`mne.simulation.simulate_stc` by `Eric Larson`_
+
+ - Fixed a bug where some FIF files could not be read with :func:`mne.io.show_fiff` by `Christian Brodbeck`_ and `Eric Larson`_
+
+ - Fixed a bug where ``merge_grads=True`` causes :func:`mne.viz.plot_evoked_topo` to fail when plotting a list of evokeds by `Jaakko Leppakangas`_
+
+ - Fixed a bug when setting multiple bipolar references with :func:`mne.io.set_bipolar_reference` by `Marijn van Vliet`_.
+
+ - Fixed image scaling in :func:`mne.viz.plot_epochs_image` when plotting more than one channel by `Jaakko Leppakangas`_
+
+ - Fixed :class:`mne.preprocessing.Xdawn` to fit shuffled epochs by `Jean-Remi King`_
+
+ - Fixed a bug with channel order determination that could lead to an ``AssertionError`` when using :class:`mne.Covariance` matrices by `Eric Larson`_
+
+ - Fixed the check for CTF gradient compensation in :func:`mne.preprocessing.maxwell_filter` by `Eric Larson`_
+
+ - Fixed the import of EDF files with encoding characters in :func:`mne.io.read_raw_edf` by `Guillaume Dumas`_
+
+ - Fixed :class:`mne.Epochs` to ensure that detrend parameter is not a boolean by `Jean-Remi King`_
+
+ - Fixed bug with :func:`mne.realtime.FieldTripClient.get_data_as_epoch` when ``picks=None`` which crashed the function by `Mainak Jas`_
+
+ - Fixed reading of units in ``.elc`` montage files (from ``UnitsPosition`` field) so that :class:`mne.channels.Montage` objects are now returned with the ``pos`` attribute correctly in meters, by `Chris Mullins`_
+
+ - Fixed reading of BrainVision files by `Phillip Alday`_:
+
+ - Greater support for BVA files, especially older ones: alternate text coding schemes with fallback to Latin-1 as well as units in column headers
+
+ - Use online software filter information when present
+
+ - Fix comparisons of filter settings for determining "strictest"/"weakest" filter
+
+ - Weakest filter is now used for heterogeneous channel filter settings, leading to more consistent behavior with filtering methods applied to a subset of channels (e.g. ``Raw.filter`` with ``picks != None``).
+
+ - Fixed plotting and timing of :class:`Annotations` and restricted addition of annotations outside data range to prevent problems with cropping and concatenating data by `Jaakko Leppakangas`_
+
+ - Fixed ICA plotting functions to refer to IC index instead of component number by `Andreas Hojlund`_ and `Jaakko Leppakangas`_
+
+ - Fixed bug with ``picks`` when interpolating MEG channels by `Mainak Jas`_.
+
+ - Fixed bug in padding of Stockwell transform for signal of length a power of 2 by `Johannes Niediek`_
+
+API
+~~~
+
+ - The ``add_eeg_ref`` argument in core functions like :func:`mne.io.read_raw_fif` and :class:`mne.Epochs` has been deprecated in favor of using :func:`mne.set_eeg_reference` and equivalent instance methods like :meth:`raw.set_eeg_reference() <mne.io.Raw.set_eeg_reference>`. In functions like :func:`mne.io.read_raw_fif` where the default in 0.13 and older versions is ``add_eeg_ref=True``, the default will change to ``add_eeg_ref=False`` in 0.14, and the argument will be removed in 0.15.
+
+ - Multiple aspects of FIR filtering in MNE-Python have been refactored (see the sketch after this list):
+
+ 1. New recommended defaults for ``l_trans_bandwidth='auto'``, ``h_trans_bandwidth='auto'``, and ``filter_length='auto'``. This should generally reduce filter artifacts at the expense of a slight decrease in effective filter stop-band attenuation. For details see :ref:`tut_filtering_in_python`. The default values of ``l_trans_bandwidth=h_trans_bandwidth=0.5`` and ``filter_length='10s'`` will change to ``'auto'`` in 0.14.
+
+ 2. The ``filter_length=None`` option (i.e. use ``len(x)``) has been deprecated.
+
+ 3. An improved ``phase='zero'`` zero-phase FIR filtering has been added. Instead of running the designed filter forward and backward, the filter is applied once and we compensate for the linear phase of the filter. The previous ``phase='zero-double'`` default will change to ``phase='zero'`` in 0.14.
+
+ 4. A warning is provided when the filter is longer than the signal of interest, as this is unlikely to produce desired results.
+
+ 5. Previously, if the filter was as long or longer than the signal of interest, direct FFT-based computations were used. Now a single code path (overlap-add filtering) is used for all FIR filters. This could cause minor changes in how short signals are filtered.
+
+ - Support for Python 2.6 has been dropped, and the minimum supported dependencies are NumPy_ 1.8, SciPy_ 0.12, and Matplotlib_ 1.3 by `Eric Larson`_
+
+ - When CTF gradient compensation is applied to raw data, it is no longer reverted on save of :meth:`mne.io.Raw.save` by `Eric Larson`_
+
+ - Adds :func:`mne.time_frequency.csd_epochs` to replace :func:`mne.time_frequency.compute_epochs_csd` for naming consistency. :func:`mne.time_frequency.compute_epochs_csd` is now deprecated and will be removed in mne 0.14, by `Nick Foti`_
+
+ - Weighted addition and subtraction of :class:`Evoked` as ``ev1 + ev2`` and ``ev1 - ev2`` have been deprecated, use explicit :func:`mne.combine_evoked(..., weights='nave') <mne.combine_evoked>` instead by `Eric Larson`_
+
+ - Deprecated support for passing a list of filenames to the :class:`mne.io.Raw` constructor, use :func:`mne.io.read_raw_fif` and :func:`mne.concatenate_raws` instead by `Eric Larson`_
+
+ - Added options for setting data and date formats manually in :func:`mne.io.read_raw_cnt` by `Jaakko Leppakangas`_
+
+ - Now channels with units of 'C', 'µS', 'uS', 'ARU' and 'S' will be turned to misc by default in :func:`mne.io.read_raw_brainvision` by `Jaakko Leppakangas`_
+
+ - Add :func:`mne.io.anonymize_info` function to anonymize measurements and add methods to :class:`mne.io.Raw`, :class:`mne.Epochs` and :class:`mne.Evoked`, by `Jean-Remi King`_
+
+ - Now it is possible to plot only a subselection of channels in :func:`mne.viz.plot_raw` by using an array for the ``order`` parameter by `Jaakko Leppakangas`_
+
+ - EOG channels can now be included when calling :func:`mne.preprocessing.ICA.fit` and a proper error is raised when trying to include unsupported channels by `Alexander Rudiuk`_
+
+ - :func:`mne.concatenate_epochs` and :func:`mne.compute_covariance` now check to see if all :class:`Epochs` instances have the same MEG-to-Head transformation, and error by default if they do not by `Eric Larson`_
+
+ - Added option to pass a list of axes to :func:`mne.viz.epochs.plot_epochs_image` by `Mikołaj Magnuski`_
+
+ - Constructing IIR filters in :func:`mne.filter.construct_iir_filter` defaults to ``output='ba'`` in 0.13, but this will change to ``output='sos'`` in 0.14, by `Eric Larson`_
+
+ - Add ``zorder`` parameter to :func:`mne.Evoked.plot` and derived functions to allow sorting channels by, e.g., standard deviation, by `Jona Sassenhagen`_
+
+ - The ``baseline`` parameter of :func:`mne.Epochs.apply_baseline` now defaults to ``(None, 0)``, by `Felix Raimundo`_
+
+ - Adds :func:`mne.Evoked.apply_baseline` to be consistent with :func:`mne.Epochs.apply_baseline`, by `Felix Raimundo`_
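+
+   For example::
+
+       evoked.apply_baseline((None, 0))  # subtract the pre-stimulus mean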
+
+ - Deprecated the ``baseline`` parameter in :class:`mne.Evoked`, by `Felix Raimundo`_
+
+ - The API of :meth:`mne.SourceEstimate.plot` and :func:`mne.viz.plot_source_estimates` has been updated to reflect the current PySurfer 0.6 API. The ``config_opts`` parameter is now deprecated and will be removed in mne 0.14, and the default representation for time will change from ``ms`` to ``s`` in mne 0.14, by `Christian Brodbeck`_
+
+ - The default dataset location has been changed from ``examples/`` in the MNE-Python root directory to ``~/mne_data`` in the user's home directory, by `Eric Larson`_
+
+ - A new option ``set_env`` has been added to :func:`mne.set_config` that defaults to ``False`` in 0.13 but will change to ``True`` in 0.14, by `Eric Larson`_
+
+ - The ``compensation`` parameter in :func:`mne.io.read_raw_fif` has been deprecated in favor of the method :meth:`mne.io.Raw.apply_gradient_compensation` by `Eric Larson`_
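+
+   A sketch of the new method (``fname`` is a FIF file containing CTF compensation channels)::
+
+       raw = mne.io.read_raw_fif(fname, preload=True)
+       raw.apply_gradient_compensation(3)  # switch to grade-3 compensation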
+
+ - :class:`mne.decoding.EpochsVectorizer` has been deprecated in favor of :class:`mne.decoding.Vectorizer` by `Asish Panda`_
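+
+   For example, in a scikit-learn pipeline (``make_pipeline`` and the estimators are assumed to be imported from ``sklearn``)::
+
+       clf = make_pipeline(Vectorizer(), StandardScaler(),
+                           LogisticRegression())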
+
+ - The ``epochs_data`` parameter has been deprecated in :class:`mne.decoding.CSP`, in favour of the ``X`` parameter, to comply with the scikit-learn API, by `Jean-Remi King`_
+
+ - Deprecated :func:`mne.time_frequency.cwt_morlet` and :func:`mne.time_frequency.single_trial_power` in favour of :func:`mne.time_frequency.tfr_morlet` with the parameter ``average=False``, by `Jean-Remi King`_ and `Alex Gramfort`_
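+
+   A sketch of the replacement call, where ``freqs`` and ``n_cycles`` are user-chosen arrays::
+
+       from mne.time_frequency import tfr_morlet
+       power = tfr_morlet(epochs, freqs=freqs, n_cycles=n_cycles,
+                          return_itc=False, average=False)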
+
+ - Add argument ``mask_type`` to :func:`mne.read_events` and :func:`mne.find_events` to support the MNE-C style of trigger masking by `Teon Brooks`_ and `Eric Larson`_
+
+ - Extended Infomax is now the default in :func:`mne.preprocessing.infomax` (``extended=True``), by `Clemens Brunner`_
+
+ - :func:`mne.io.read_raw_eeglab` and :func:`mne.io.read_epochs_eeglab` now take an additional argument ``uint16_codec`` that allows defining the encoding of character arrays in the set file. This helps in rare cases when reading a set file fails with ``TypeError: buffer is too small for requested array``. By `Mikołaj Magnuski`_
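+
+   For example (the codec shown is a placeholder; the right one depends on how the set file was written)::
+
+       epochs = mne.io.read_epochs_eeglab('file.set', uint16_codec='latin1')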
+
+ - Added :class:`mne.decoding.TemporalFilter` to filter data in scikit-learn pipelines, by `Asish Panda`_
+
+ - :func:`mne.preprocessing.create_ecg_epochs` now includes all the channels when ``picks=None`` by `Jaakko Leppakangas`_
+
+Authors
+~~~~~~~
+
The committer list for this release is the following (sorted alphabetically):
+
+ * Alexander Rudiuk
+ * Alexandre Barachant
+ * Alexandre Gramfort
+ * Asish Panda
+ * Camilo Lamus
+ * Chris Holdgraf
+ * Christian Brodbeck
+ * Christopher J. Bailey
+ * Christopher Mullins
+ * Clemens Brunner
+ * Denis A. Engemann
+ * Eric Larson
+ * Federico Raimondo
+ * Félix Raimundo
+ * Guillaume Dumas
+ * Jaakko Leppakangas
+ * Jair Montoya
+ * Jean-Remi King
+ * Johannes Niediek
+ * Jona Sassenhagen
+ * Jussi Nurminen
+ * Keith Doelling
+ * Mainak Jas
+ * Marijn van Vliet
+ * Michael Krause
+ * Mikolaj Magnuski
+ * Nick Foti
+ * Phillip Alday
+ * Simon-Shlomo Poil
+ * Teon Brooks
+ * Yaroslav Halchenko
+
.. _changes_0_12:
Version 0.12
@@ -127,7 +372,7 @@ BUG
- Fix bug in source normal adjustment that occurred when 1) patch information is available (e.g., when distances have been calculated) and 2) points are excluded from the source space (by inner skull distance) by `Eric Larson`_
- Fix bug when merging info that has a field with list of dicts by `Jaakko Leppakangas`_
-
+
- The BTI/4D reader now considers user-defined channel labels instead of the hardware names, however only for channels other than MEG. By `Denis Engemann`_ and `Alex Gramfort`_.
- Fix bug in :func:`mne.compute_raw_covariance` where rejection by non-data channels (e.g. EOG) was not done properly by `Eric Larson`_.
@@ -136,6 +381,8 @@ BUG
- Fix bug in :func:`mne.io.Raw.save` where, in rare cases, automatically split files could end up writing an extra empty file that wouldn't be read properly by `Eric Larson`_
+ - Fix :class:`mne.realtime.StimServer` by removing superfluous argument ``ip`` used while initializing the object by `Mainak Jas`_.
+
API
~~~
@@ -848,7 +1095,7 @@ API
- Pick functions (e.g., ``pick_types``) are now in the mne namespace (e.g. use ``mne.pick_types``).
- - Deprecated ICA methods specific to one container type. Use ICA.fit, ICA.get_sources ICA.apply and ICA.plot_XXX for processing Raw, Epochs and Evoked objects.
+ - Deprecated ICA methods specific to one container type. Use ICA.fit, ICA.get_sources ICA.apply and ``ICA.plot_*`` for processing Raw, Epochs and Evoked objects.
- The default smoothing method for ``mne.stc_to_label`` will change in v0.9, and the old method is deprecated.
@@ -1547,3 +1794,23 @@ of commits):
.. _Natalie Klein: http://www.stat.cmu.edu/people/students/neklein
.. _Jon Houck: http://www.unm.edu/~jhouck/
+
+.. _Pablo-Arias: https://github.com/Pablo-Arias
+
+.. _Alexander Rudiuk: https://github.com/ARudiuk
+
+.. _Mikołaj Magnuski: https://github.com/mmagnuski
+
+.. _Felix Raimundo: https://github.com/gamazeps
+
+.. _Nick Foti: http://nfoti.github.io
+
+.. _Guillaume Dumas: http://www.extrospection.eu
+
+.. _Chris Mullins: http://crmullins.com
+
+.. _Phillip Alday: http://palday.bitbucket.org
+
+.. _Andreas Hojlund: https://github.com/ahoejlund
+
+.. _Johannes Niediek: https://github.com/jniediek
diff --git a/examples/connectivity/plot_mne_inverse_coherence_epochs.py b/examples/connectivity/plot_mne_inverse_coherence_epochs.py
index df4b3e1..cbf10a1 100644
--- a/examples/connectivity/plot_mne_inverse_coherence_epochs.py
+++ b/examples/connectivity/plot_mne_inverse_coherence_epochs.py
@@ -36,7 +36,7 @@ method = "dSPM" # use dSPM method (could also be MNE or sLORETA)
# Load data
inverse_operator = read_inverse_operator(fname_inv)
label_lh = mne.read_label(fname_label_lh)
-raw = mne.io.read_raw_fif(fname_raw)
+raw = mne.io.read_raw_fif(fname_raw, add_eeg_ref=False)
events = mne.read_events(fname_event)
# Add a bad channel
@@ -48,8 +48,8 @@ picks = mne.pick_types(raw.info, meg=True, eeg=False, stim=False, eog=True,
# Read epochs
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- baseline=(None, 0), reject=dict(mag=4e-12, grad=4000e-13,
- eog=150e-6))
+ add_eeg_ref=False, baseline=(None, 0),
+ reject=dict(mag=4e-12, grad=4000e-13, eog=150e-6))
# First, we find the most active vertex in the left auditory cortex, which
# we will later use as seed for the connectivity computation
diff --git a/examples/decoding/plot_decoding_spatio_temporal_source.py b/examples/decoding/plot_decoding_spatio_temporal_source.py
index 631fe41..bb62fda 100644
--- a/examples/decoding/plot_decoding_spatio_temporal_source.py
+++ b/examples/decoding/plot_decoding_spatio_temporal_source.py
@@ -148,6 +148,5 @@ stc_feat = mne.SourceEstimate(feature_weights, vertices=vertices,
tmin=stc.tmin, tstep=stc.tstep,
subject='sample')
-brain = stc_feat.plot()
-brain.set_time(100)
-brain.show_view('l') # take the medial view to further explore visual areas
+brain = stc_feat.plot(hemi='split', views=['lat', 'med'], transparent=True,
+ initial_time=0.1, time_unit='s')
diff --git a/examples/decoding/plot_decoding_time_generalization_conditions.py b/examples/decoding/plot_decoding_time_generalization_conditions.py
index 7212b7f..77643c1 100644
--- a/examples/decoding/plot_decoding_time_generalization_conditions.py
+++ b/examples/decoding/plot_decoding_time_generalization_conditions.py
@@ -5,11 +5,6 @@ Decoding sensor space data with generalization across time and conditions
This example runs the analysis computed in:
-Jean-Remi King, Alexandre Gramfort, Aaron Schurger, Lionel Naccache
-and Stanislas Dehaene, "Two distinct dynamic modes subtend the detection of
-unexpected sounds", PLOS ONE, 2013,
-http://www.ncbi.nlm.nih.gov/pubmed/24475052
-
King & Dehaene (2014) 'Characterizing the dynamics of mental
representations: the temporal generalization method', Trends In Cognitive
Sciences, 18(4), 203-210.
@@ -71,5 +66,4 @@ gat.fit(epochs[('AudL', 'VisL')], y=viz_vs_auditory_l)
viz_vs_auditory_r = (triggers[np.in1d(triggers, (2, 4))] == 4).astype(int)
gat.score(epochs[('AudR', 'VisR')], y=viz_vs_auditory_r)
-gat.plot(
- title="Generalization Across Time (visual vs auditory): left to right")
+gat.plot(title="Temporal Generalization (visual vs auditory): left to right")
diff --git a/examples/decoding/plot_decoding_unsupervised_spatial_filter.py b/examples/decoding/plot_decoding_unsupervised_spatial_filter.py
new file mode 100644
index 0000000..3c9c71f
--- /dev/null
+++ b/examples/decoding/plot_decoding_unsupervised_spatial_filter.py
@@ -0,0 +1,67 @@
+"""
+==================================================================
+Analysis of evoked response using ICA and PCA reduction techniques
+==================================================================
+
+This example computes PCA and ICA of evoked or epochs data. Then the
+PCA / ICA components, a.k.a. spatial filters, are used to transform
+the channel data to new sources / virtual channels. The output is
+visualized on the average of all the epochs.
+"""
+# Authors: Jean-Remi King <jeanremi.king at gmail.com>
+# Asish Panda <asishrocks95 at gmail.com>
+#
+# License: BSD (3-clause)
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+import mne
+from mne.datasets import sample
+from mne.decoding import UnsupervisedSpatialFilter
+
+from sklearn.decomposition import PCA, FastICA
+
+print(__doc__)
+
+# Preprocess data
+data_path = sample.data_path()
+
+# Load and filter data, set up epochs
+raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
+event_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif'
+tmin, tmax = -0.1, 0.3
+event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4)
+
+raw = mne.io.read_raw_fif(raw_fname, preload=True)
+raw.filter(1, 20)
+events = mne.read_events(event_fname)
+
+picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
+ exclude='bads')
+
+epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=False,
+ picks=picks, baseline=None, preload=True,
+ add_eeg_ref=False, verbose=False)
+
+X = epochs.get_data()
+
+##############################################################################
+# Transform data with PCA computed on the raw epochs (no averaging)
+pca = UnsupervisedSpatialFilter(PCA(30), average=False)
+pca_data = pca.fit_transform(X)
+ev = mne.EvokedArray(np.mean(pca_data, axis=0),
+ mne.create_info(30, epochs.info['sfreq'],
+ ch_types='eeg'), tmin=tmin)
+ev.plot(show=False, window_title="PCA")
+
+##############################################################################
+# Transform data with ICA computed on the raw epochs (no averaging)
+ica = UnsupervisedSpatialFilter(FastICA(30), average=False)
+ica_data = ica.fit_transform(X)
+ev1 = mne.EvokedArray(np.mean(ica_data, axis=0),
+ mne.create_info(30, epochs.info['sfreq'],
+ ch_types='eeg'), tmin=tmin)
+ev1.plot(show=False, window_title='ICA')
+
+plt.show()
diff --git a/examples/decoding/plot_decoding_xdawn_eeg.py b/examples/decoding/plot_decoding_xdawn_eeg.py
index 8ac6a6c..2cc864c 100644
--- a/examples/decoding/plot_decoding_xdawn_eeg.py
+++ b/examples/decoding/plot_decoding_xdawn_eeg.py
@@ -34,7 +34,7 @@ from sklearn.preprocessing import MinMaxScaler
from mne import io, pick_types, read_events, Epochs
from mne.datasets import sample
from mne.preprocessing import Xdawn
-from mne.decoding import EpochsVectorizer
+from mne.decoding import Vectorizer
from mne.viz import tight_layout
@@ -63,7 +63,7 @@ epochs = Epochs(raw, events, event_id, tmin, tmax, proj=False,
# Create classification pipeline
clf = make_pipeline(Xdawn(n_components=3),
- EpochsVectorizer(),
+ Vectorizer(),
MinMaxScaler(),
LogisticRegression(penalty='l1'))
diff --git a/examples/decoding/plot_ems_filtering.py b/examples/decoding/plot_ems_filtering.py
index 1022e4b..f449892 100644
--- a/examples/decoding/plot_ems_filtering.py
+++ b/examples/decoding/plot_ems_filtering.py
@@ -23,72 +23,110 @@ condition. Finally a topographic plot is created which exhibits the
temporal evolution of the spatial filters.
"""
# Author: Denis Engemann <denis.engemann at gmail.com>
+# Jean-Remi King <jeanremi.king at gmail.com>
#
# License: BSD (3-clause)
+import numpy as np
import matplotlib.pyplot as plt
import mne
-from mne import io
+from mne import io, EvokedArray
from mne.datasets import sample
-from mne.decoding import compute_ems
+from mne.decoding import EMS, compute_ems
+from sklearn.cross_validation import StratifiedKFold
print(__doc__)
data_path = sample.data_path()
-# Set parameters
+# Preprocess the data
raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
event_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif'
-event_ids = {'AudL': 1, 'VisL': 3, 'AudR': 2, 'VisR': 4}
-tmin = -0.2
-tmax = 0.5
+event_ids = {'AudL': 1, 'VisL': 3}
# Read data and create epochs
raw = io.read_raw_fif(raw_fname, preload=True)
-raw.filter(1, 45)
+raw.filter(0.5, 45, l_trans_bandwidth='auto', h_trans_bandwidth='auto',
+ filter_length='auto', phase='zero')
events = mne.read_events(event_fname)
-include = [] # or stim channels ['STI 014']
-ch_type = 'grad'
-picks = mne.pick_types(raw.info, meg=ch_type, eeg=False, stim=False, eog=True,
- include=include, exclude='bads')
+picks = mne.pick_types(raw.info, meg='grad', eeg=False, stim=False, eog=True,
+ exclude='bads')
-reject = dict(grad=4000e-13, eog=150e-6)
+epochs = mne.Epochs(raw, events, event_ids, tmin=-0.2, tmax=0.5, picks=picks,
+ baseline=None, reject=dict(grad=4000e-13, eog=150e-6),
+ preload=True)
+epochs.drop_bad()
+epochs.pick_types(meg='grad')
-epochs = mne.Epochs(raw, events, event_ids, tmin, tmax, picks=picks,
- baseline=None, reject=reject)
+# Set up the data in a scikit-learn-friendly way:
+X = epochs.get_data()  # The MEG data
+y = epochs.events[:, 2]  # The condition indices
+n_epochs, n_channels, n_times = X.shape
-# Let's equalize the trial counts in each condition
-epochs.equalize_event_counts(epochs.event_id, copy=False)
+#############################################################################
-# compute surrogate time series
-surrogates, filters, conditions = compute_ems(epochs, ['AudL', 'VisL'])
+# Initialize EMS transformer
+ems = EMS()
-times = epochs.times * 1e3
+# Initialize the variables of interest
+X_transform = np.zeros((n_epochs, n_times)) # Data after EMS transformation
+filters = list() # Spatial filters at each time point
+
+# In the original paper, the cross-validation is a leave-one-out. However,
+# we recommend using a Stratified KFold, because leave-one-out tends
+# to overfit and cannot be used to estimate the variance of the
+# prediction within a given fold.
+
+for train, test in StratifiedKFold(y):
+ # In the original paper, the z-scoring is applied outside the CV.
+ # However, we recommend applying this preprocessing inside the CV.
+ # Note that such scaling should be done separately for each channel type
+ # if the data contains multiple channel types.
+ X_scaled = X / np.std(X[train])
+
+ # Fit and store the spatial filters
+ ems.fit(X_scaled[train], y[train])
+
+ # Store filters for future plotting
+ filters.append(ems.filters_)
+
+ # Generate the transformed data
+ X_transform[test] = ems.transform(X_scaled[test])
+
+# Average the spatial filters across folds
+filters = np.mean(filters, axis=0)
+
+# Plot individual trials
plt.figure()
plt.title('single trial surrogates')
-plt.imshow(surrogates[conditions.argsort()], origin='lower', aspect='auto',
- extent=[times[0], times[-1], 1, len(surrogates)],
+plt.imshow(X_transform[y.argsort()], origin='lower', aspect='auto',
+ extent=[epochs.times[0], epochs.times[-1], 1, len(X_transform)],
cmap='RdBu_r')
plt.xlabel('Time (ms)')
plt.ylabel('Trials (reordered by condition)')
+# Plot average response
plt.figure()
plt.title('Average EMS signal')
-
-mappings = [(k, v) for k, v in event_ids.items() if v in conditions]
+mappings = [(key, value) for key, value in event_ids.items()]
for key, value in mappings:
- ems_ave = surrogates[conditions == value]
- ems_ave *= 1e13
- plt.plot(times, ems_ave.mean(0), label=key)
+ ems_ave = X_transform[y == value]
+ plt.plot(epochs.times, ems_ave.mean(0), label=key)
plt.xlabel('Time (ms)')
-plt.ylabel('fT/cm')
+plt.ylabel('a.u.')
plt.legend(loc='best')
-
-
-# visualize spatial filters across time
plt.show()
-evoked = epochs.average()
-evoked.data = filters
-evoked.plot_topomap(ch_type=ch_type)
+
+# Visualize spatial filters across time
+evoked = EvokedArray(filters, epochs.info, tmin=epochs.tmin)
+evoked.plot_topomap()
+
+#############################################################################
+# Note that a similar transformation can be applied with ``compute_ems``.
+# However, this function replicates Schurger et al.'s original paper, and thus
+# applies the normalization outside a leave-one-out cross-validation, which we
+# do not recommend.
+epochs.equalize_event_counts(event_ids)
+X_transform, filters, classes = compute_ems(epochs)
diff --git a/examples/inverse/plot_dics_beamformer.py b/examples/inverse/plot_dics_beamformer.py
index 81dc049..c2631d7 100644
--- a/examples/inverse/plot_dics_beamformer.py
+++ b/examples/inverse/plot_dics_beamformer.py
@@ -21,7 +21,7 @@ import matplotlib.pyplot as plt
import numpy as np
from mne.datasets import sample
-from mne.time_frequency import compute_epochs_csd
+from mne.time_frequency import csd_epochs
from mne.beamformer import dics
print(__doc__)
@@ -57,10 +57,10 @@ forward = mne.read_forward_solution(fname_fwd, surf_ori=True)
# Computing the data and noise cross-spectral density matrices
# The time-frequency window was chosen on the basis of spectrograms from
# example time_frequency/plot_time_frequency.py
-data_csd = compute_epochs_csd(epochs, mode='multitaper', tmin=0.04, tmax=0.15,
- fmin=6, fmax=10)
-noise_csd = compute_epochs_csd(epochs, mode='multitaper', tmin=-0.11, tmax=0.0,
- fmin=6, fmax=10)
+data_csd = csd_epochs(epochs, mode='multitaper', tmin=0.04, tmax=0.15,
+ fmin=6, fmax=10)
+noise_csd = csd_epochs(epochs, mode='multitaper', tmin=-0.11, tmax=0.0,
+ fmin=6, fmax=10)
evoked = epochs.average()
@@ -77,8 +77,8 @@ plt.title('DICS time course of the 30 largest sources.')
plt.show()
# Plot brain in 3D with PySurfer if available
-brain = stc.plot(hemi='rh', subjects_dir=subjects_dir)
-brain.set_data_time_index(180)
+brain = stc.plot(hemi='rh', subjects_dir=subjects_dir,
+ initial_time=0.1, time_unit='s')
brain.show_view('lateral')
# Uncomment to save image
diff --git a/examples/inverse/plot_dics_source_power.py b/examples/inverse/plot_dics_source_power.py
index 12f4a61..1a9ab0d 100644
--- a/examples/inverse/plot_dics_source_power.py
+++ b/examples/inverse/plot_dics_source_power.py
@@ -17,7 +17,7 @@ in the human brain. PNAS (2001) vol. 98 (2) pp. 694-699
import mne
from mne.datasets import sample
-from mne.time_frequency import compute_epochs_csd
+from mne.time_frequency import csd_epochs
from mne.beamformer import dics_source_power
print(__doc__)
@@ -51,12 +51,12 @@ forward = mne.read_forward_solution(fname_fwd, surf_ori=True)
# Computing the data and noise cross-spectral density matrices
# The time-frequency window was chosen on the basis of spectrograms from
# example time_frequency/plot_time_frequency.py
-# As fsum is False compute_epochs_csd returns a list of CrossSpectralDensity
+# As fsum is False csd_epochs returns a list of CrossSpectralDensity
# instances than can then be passed to dics_source_power
-data_csds = compute_epochs_csd(epochs, mode='multitaper', tmin=0.04, tmax=0.15,
- fmin=15, fmax=30, fsum=False)
-noise_csds = compute_epochs_csd(epochs, mode='multitaper', tmin=-0.11,
- tmax=-0.001, fmin=15, fmax=30, fsum=False)
+data_csds = csd_epochs(epochs, mode='multitaper', tmin=0.04, tmax=0.15,
+ fmin=15, fmax=30, fsum=False)
+noise_csds = csd_epochs(epochs, mode='multitaper', tmin=-0.11,
+ tmax=-0.001, fmin=15, fmax=30, fsum=False)
# Compute DICS spatial filter and estimate source power
stc = dics_source_power(epochs.info, forward, noise_csds, data_csds)
diff --git a/examples/inverse/plot_lcmv_beamformer.py b/examples/inverse/plot_lcmv_beamformer.py
index 5caf21d..42af4e1 100644
--- a/examples/inverse/plot_lcmv_beamformer.py
+++ b/examples/inverse/plot_lcmv_beamformer.py
@@ -87,6 +87,6 @@ plt.legend()
plt.show()
# Plot last stc in the brain in 3D with PySurfer if available
-brain = stc.plot(hemi='lh', subjects_dir=subjects_dir)
-brain.set_data_time_index(180)
+brain = stc.plot(hemi='lh', subjects_dir=subjects_dir,
+ initial_time=0.1, time_unit='s')
brain.show_view('lateral')
diff --git a/examples/inverse/plot_morph_data.py b/examples/inverse/plot_morph_data.py
index 7d9d162..c91b361 100644
--- a/examples/inverse/plot_morph_data.py
+++ b/examples/inverse/plot_morph_data.py
@@ -28,7 +28,6 @@ subject_to = 'fsaverage'
subjects_dir = data_path + '/subjects'
fname = data_path + '/MEG/sample/sample_audvis-meg'
-src_fname = data_path + '/MEG/sample/sample_audvis-meg-oct-6-fwd.fif'
# Read input stc file
stc_from = mne.read_source_estimate(fname)
diff --git a/examples/inverse/plot_tf_dics.py b/examples/inverse/plot_tf_dics.py
index 7d7191b..7300b3d 100644
--- a/examples/inverse/plot_tf_dics.py
+++ b/examples/inverse/plot_tf_dics.py
@@ -17,7 +17,7 @@ dynamics of cortical activity. NeuroImage (2008) vol. 40 (4) pp. 1686-1700
import mne
from mne.event import make_fixed_length_events
from mne.datasets import sample
-from mne.time_frequency import compute_epochs_csd
+from mne.time_frequency import csd_epochs
from mne.beamformer import tf_dics
from mne.viz import plot_source_spectrogram
@@ -104,10 +104,10 @@ subtract_evoked = False
# from the baseline period in the data, change epochs_noise to epochs
noise_csds = []
for freq_bin, win_length, n_fft in zip(freq_bins, win_lengths, n_ffts):
- noise_csd = compute_epochs_csd(epochs_noise, mode='fourier',
- fmin=freq_bin[0], fmax=freq_bin[1],
- fsum=True, tmin=-win_length, tmax=0,
- n_fft=n_fft)
+ noise_csd = csd_epochs(epochs_noise, mode='fourier',
+ fmin=freq_bin[0], fmax=freq_bin[1],
+ fsum=True, tmin=-win_length, tmax=0,
+ n_fft=n_fft)
noise_csds.append(noise_csd)
# Computing DICS solutions for time-frequency windows in a label in source
diff --git a/examples/inverse/plot_time_frequency_mixed_norm_inverse.py b/examples/inverse/plot_time_frequency_mixed_norm_inverse.py
index 960891c..cc32bb3 100644
--- a/examples/inverse/plot_time_frequency_mixed_norm_inverse.py
+++ b/examples/inverse/plot_time_frequency_mixed_norm_inverse.py
@@ -112,9 +112,8 @@ plot_sparse_source_estimates(forward['src'], stc, bgcolor=(1, 1, 1),
time_label = 'TF-MxNE time=%0.2f ms'
clim = dict(kind='value', lims=[10e-9, 15e-9, 20e-9])
-brain = stc.plot('sample', 'inflated', 'rh', clim=clim, time_label=time_label,
- smoothing_steps=5, subjects_dir=subjects_dir)
-brain.show_view('medial')
-brain.set_data_time_index(120)
+brain = stc.plot('sample', 'inflated', 'rh', views='medial',
+ clim=clim, time_label=time_label, smoothing_steps=5,
+ subjects_dir=subjects_dir, initial_time=150, time_unit='ms')
brain.add_label("V1", color="yellow", scalar_thresh=.5, borders=True)
brain.add_label("V2", color="red", scalar_thresh=.5, borders=True)
diff --git a/examples/io/plot_elekta_epochs.py b/examples/io/plot_elekta_epochs.py
new file mode 100644
index 0000000..de0e0f3
--- /dev/null
+++ b/examples/io/plot_elekta_epochs.py
@@ -0,0 +1,68 @@
+"""
+======================================
+Getting averaging info from .fif files
+======================================
+
+Parse averaging information defined in Elekta Vectorview/TRIUX DACQ (data
+acquisition). Extract and average epochs accordingly. Modify some
+averaging parameters and get epochs.
+"""
+# Author: Jussi Nurminen (jnu at iki.fi)
+#
+# License: BSD (3-clause)
+
+
+import mne
+import os
+from mne.datasets import multimodal
+from mne import AcqParserFIF
+
+fname_raw = os.path.join(multimodal.data_path(), 'multimodal_raw.fif')
+
+
+print(__doc__)
+
+###############################################################################
+# Read raw file and create parser instance
+raw = mne.io.read_raw_fif(fname_raw)
+ap = AcqParserFIF(raw.info)
+
+###############################################################################
+# Check DACQ defined averaging categories and other info
+print(ap)
+
+###############################################################################
+# Extract epochs corresponding to a category
+cond = ap.get_condition(raw, 'Auditory right')
+epochs = mne.Epochs(raw, **cond)
+epochs.average().plot_topo()
+
+###############################################################################
+# Get epochs from all conditions, average
+evokeds = []
+for cat in ap.categories:
+ cond = ap.get_condition(raw, cat)
+ # copy (supported) rejection parameters from DACQ settings
+ epochs = mne.Epochs(raw, reject=ap.reject, flat=ap.flat, **cond)
+ evoked = epochs.average()
+ evoked.comment = cat['comment']
+ evokeds.append(evoked)
+# save all averages to an evoked fiff file
+# fname_out = 'multimodal-ave.fif'
+# mne.write_evokeds(fname_out, evokeds)
+
+###############################################################################
+# Make a new averaging category
+newcat = dict()
+newcat['comment'] = 'Visual lower left, longer epochs'
+newcat['event'] = 3 # reference event
+newcat['start'] = -.2 # epoch start rel. to ref. event (in seconds)
+newcat['end'] = .7 # epoch end
+newcat['reqevent'] = 0 # additional required event; 0 if none
+newcat['reqwithin'] = .5 # ...required within .5 sec (before or after)
+newcat['reqwhen'] = 2 # ...required before (1) or after (2) ref. event
+newcat['index'] = 9 # can be set freely
+
+cond = ap.get_condition(raw, newcat)
+epochs = mne.Epochs(raw, reject=ap.reject, flat=ap.flat, **cond)
+epochs.average().plot()
diff --git a/examples/preprocessing/plot_run_ica.py b/examples/preprocessing/plot_run_ica.py
index b9114ca..9b49188 100644
--- a/examples/preprocessing/plot_run_ica.py
+++ b/examples/preprocessing/plot_run_ica.py
@@ -7,8 +7,6 @@ Compute ICA components on epochs
ICA is fit to MEG raw data.
We assume that the non-stationary EOG artifacts have already been removed.
The sources matching the ECG are automatically found and displayed.
-Subsequently, artefact detection and rejection quality are assessed.
-Finally, the impact on the evoked ERF is visualized.
Note that this example does quite a bit of processing, so even on a
fast machine it can take about a minute to complete.
@@ -25,22 +23,36 @@ from mne.datasets import sample
print(__doc__)
###############################################################################
-# Fit ICA model using the FastICA algorithm, detect and inspect components
+# Read and preprocess the data. Preprocessing consists of:
+#
+# - meg channel selection
+#
+# - 1 - 30 Hz band-pass IIR filter
+#
+# - epoching -0.2 to 0.5 seconds with respect to events
data_path = sample.data_path()
raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
raw = mne.io.read_raw_fif(raw_fname, preload=True)
-raw.filter(1, 30, method='iir')
raw.pick_types(meg=True, eeg=False, exclude='bads', stim=True)
+raw.filter(1, 30, method='iir')
# longer + more epochs for more artifact exposure
events = mne.find_events(raw, stim_channel='STI 014')
epochs = mne.Epochs(raw, events, event_id=None, tmin=-0.2, tmax=0.5)
+###############################################################################
+# Fit ICA model using the FastICA algorithm, detect and plot components
+# explaining ECG artifacts.
+
ica = ICA(n_components=0.95, method='fastica').fit(epochs)
ecg_epochs = create_ecg_epochs(raw, tmin=-.5, tmax=.5)
ecg_inds, scores = ica.find_bads_ecg(ecg_epochs)
ica.plot_components(ecg_inds)
+
+###############################################################################
+# Plot properties of ECG components:
+ica.plot_properties(epochs, picks=ecg_inds)
diff --git a/examples/preprocessing/plot_xdawn_denoising.py b/examples/preprocessing/plot_xdawn_denoising.py
index 03f6f4b..7c15f5d 100644
--- a/examples/preprocessing/plot_xdawn_denoising.py
+++ b/examples/preprocessing/plot_xdawn_denoising.py
@@ -31,8 +31,7 @@ efficient sensor selection in a P300 BCI. In Signal Processing Conference,
# License: BSD (3-clause)
-from mne import (io, compute_raw_covariance, read_events, pick_types,
- Epochs)
+from mne import (io, compute_raw_covariance, read_events, pick_types, Epochs)
from mne.datasets import sample
from mne.preprocessing import Xdawn
from mne.viz import plot_epochs_image
@@ -76,5 +75,5 @@ xd.fit(epochs)
# Denoise epochs
epochs_denoised = xd.apply(epochs)
-# Plot image epoch after xdawn
+# Plot image epoch after Xdawn
plot_epochs_image(epochs_denoised['vis_r'], picks=[230], vmin=-500, vmax=500)
diff --git a/examples/realtime/ftclient_rt_average.py b/examples/realtime/ftclient_rt_average.py
index 8f0985a..486ae12 100644
--- a/examples/realtime/ftclient_rt_average.py
+++ b/examples/realtime/ftclient_rt_average.py
@@ -70,6 +70,7 @@ with FieldTripClient(host='localhost', port=1972,
for ii, ev in enumerate(rt_epochs.iter_evoked()):
print("Just got epoch %d" % (ii + 1))
+ ev.pick_types(meg=True, eog=False)
if ii == 0:
evoked = ev
else:
diff --git a/examples/realtime/plot_compute_rt_decoder.py b/examples/realtime/plot_compute_rt_decoder.py
index 9069891..1c16308 100644
--- a/examples/realtime/plot_compute_rt_decoder.py
+++ b/examples/realtime/plot_compute_rt_decoder.py
@@ -55,14 +55,14 @@ from sklearn import preprocessing # noqa
from sklearn.svm import SVC # noqa
from sklearn.pipeline import Pipeline # noqa
from sklearn.cross_validation import cross_val_score, ShuffleSplit # noqa
-from mne.decoding import EpochsVectorizer, FilterEstimator # noqa
+from mne.decoding import Vectorizer, FilterEstimator # noqa
scores_x, scores, std_scores = [], [], []
filt = FilterEstimator(rt_epochs.info, 1, 40)
scaler = preprocessing.StandardScaler()
-vectorizer = EpochsVectorizer()
+vectorizer = Vectorizer()
clf = SVC(C=1, kernel='linear')
concat_classifier = Pipeline([('filter', filt), ('vector', vectorizer),
diff --git a/examples/realtime/rt_feedback_server.py b/examples/realtime/rt_feedback_server.py
index c5febc2..296a664 100644
--- a/examples/realtime/rt_feedback_server.py
+++ b/examples/realtime/rt_feedback_server.py
@@ -43,7 +43,7 @@ import mne
from mne.datasets import sample
from mne.realtime import StimServer
from mne.realtime import MockRtClient
-from mne.decoding import EpochsVectorizer, FilterEstimator
+from mne.decoding import Vectorizer, FilterEstimator
print(__doc__)
@@ -55,7 +55,7 @@ raw = mne.io.read_raw_fif(raw_fname, preload=True)
# Instantiating stimulation server
# The with statement is necessary to ensure a clean exit
-with StimServer('localhost', port=4218) as stim_server:
+with StimServer(port=4218) as stim_server:
# The channels to be used while decoding
picks = mne.pick_types(raw.info, meg='grad', eeg=False, eog=True,
@@ -66,7 +66,7 @@ with StimServer('localhost', port=4218) as stim_server:
# Constructing the pipeline for classification
filt = FilterEstimator(raw.info, 1, 40)
scaler = preprocessing.StandardScaler()
- vectorizer = EpochsVectorizer()
+ vectorizer = Vectorizer()
clf = SVC(C=1, kernel='linear')
concat_classifier = Pipeline([('filter', filt), ('vector', vectorizer),
diff --git a/examples/time_frequency/README.txt b/examples/time_frequency/README.txt
index 16050ac..c6b0ab6 100644
--- a/examples/time_frequency/README.txt
+++ b/examples/time_frequency/README.txt
@@ -2,5 +2,4 @@
Time-Frequency Examples
-----------------------
-Some examples of how to explore time frequency content of M/EEG data with MNE.
-
+Some examples of how to explore time-frequency content of M/EEG data with MNE.
diff --git a/examples/time_frequency/plot_compute_raw_data_spectrum.py b/examples/time_frequency/plot_compute_raw_data_spectrum.py
index 6a158d9..80f6656 100644
--- a/examples/time_frequency/plot_compute_raw_data_spectrum.py
+++ b/examples/time_frequency/plot_compute_raw_data_spectrum.py
@@ -75,7 +75,7 @@ raw.plot_psd(tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, n_fft=n_fft,
ax.set_title('Four left-temporal magnetometers')
plt.legend(['Without SSP', 'With SSP', 'SSP + Notch'])
-# Alternatively, you may also create PSDs from Raw objects with psd_XXX
+# Alternatively, you may also create PSDs from Raw objects with ``psd_*``
f, ax = plt.subplots()
psds, freqs = psd_multitaper(raw, low_bias=True, tmin=tmin, tmax=tmax,
fmin=fmin, fmax=fmax, proj=True, picks=picks,
diff --git a/examples/time_frequency/plot_source_label_time_frequency.py b/examples/time_frequency/plot_source_label_time_frequency.py
index 1b47722..e3b203d 100644
--- a/examples/time_frequency/plot_source_label_time_frequency.py
+++ b/examples/time_frequency/plot_source_label_time_frequency.py
@@ -67,13 +67,13 @@ plt.close('all')
for ii, (this_epochs, title) in enumerate(zip([epochs, epochs_induced],
['evoked + induced',
'induced only'])):
- # compute the source space power and phase lock
- power, phase_lock = source_induced_power(
+ # compute the source space power and the inter-trial coherence
+ power, itc = source_induced_power(
this_epochs, inverse_operator, frequencies, label, baseline=(-0.1, 0),
baseline_mode='percent', n_cycles=n_cycles, n_jobs=1)
power = np.mean(power, axis=0) # average over sources
- phase_lock = np.mean(phase_lock, axis=0) # average over sources
+ itc = np.mean(itc, axis=0) # average over sources
times = epochs.times
##########################################################################
@@ -89,13 +89,13 @@ for ii, (this_epochs, title) in enumerate(zip([epochs, epochs_induced],
plt.colorbar()
plt.subplot(2, 2, 2 * ii + 2)
- plt.imshow(phase_lock,
+ plt.imshow(itc,
extent=[times[0], times[-1], frequencies[0], frequencies[-1]],
aspect='auto', origin='lower', vmin=0, vmax=0.7,
cmap='RdBu_r')
plt.xlabel('Time (s)')
plt.ylabel('Frequency (Hz)')
- plt.title('Phase-lock (%s)' % title)
+ plt.title('ITC (%s)' % title)
plt.colorbar()
plt.show()
diff --git a/make/install_python.ps1 b/make/install_python.ps1
deleted file mode 100644
index 23b996f..0000000
--- a/make/install_python.ps1
+++ /dev/null
@@ -1,93 +0,0 @@
-# Sample script to install Python and pip under Windows
-# Authors: Olivier Grisel, Jonathan Helmus and Kyle Kastner
-# License: CC0 1.0 Universal: http://creativecommons.org/publicdomain/zero/1.0/
-
-$MINICONDA_URL = "http://repo.continuum.io/miniconda/"
-$BASE_URL = "https://www.python.org/ftp/python/"
-
-
-function DownloadMiniconda ($python_version, $platform_suffix) {
- $webclient = New-Object System.Net.WebClient
- if ($python_version -eq "3.4") {
- $filename = "Miniconda3-latest-Windows-" + $platform_suffix + ".exe"
- } else {
- $filename = "Miniconda-latest-Windows-" + $platform_suffix + ".exe"
- }
- $url = $MINICONDA_URL + $filename
-
- $basedir = $pwd.Path + "\"
- $filepath = $basedir + $filename
- if (Test-Path $filename) {
- Write-Host "Reusing" $filepath
- return $filepath
- }
-
- # Download and retry up to 3 times in case of network transient errors.
- Write-Host "Downloading" $filename "from" $url
- $retry_attempts = 2
- for($i=0; $i -lt $retry_attempts; $i++){
- try {
- $webclient.DownloadFile($url, $filepath)
- break
- }
- Catch [Exception]{
- Start-Sleep 1
- }
- }
- if (Test-Path $filepath) {
- Write-Host "File saved at" $filepath
- } else {
- # Retry once to get the error message if any at the last try
- $webclient.DownloadFile($url, $filepath)
- }
- return $filepath
-}
-
-
-function InstallMiniconda ($python_version, $architecture, $python_home) {
- Write-Host "Installing Python" $python_version "for" $architecture "bit architecture to" $python_home
- if (Test-Path $python_home) {
- Write-Host $python_home "already exists, skipping."
- return $false
- }
- if ($architecture -eq "32") {
- $platform_suffix = "x86"
- } else {
- $platform_suffix = "x86_64"
- }
- $filepath = DownloadMiniconda $python_version $platform_suffix
- Write-Host "Installing" $filepath "to" $python_home
- $install_log = $python_home + ".log"
- $args = "/S /D=$python_home"
- Write-Host $filepath $args
- Start-Process -FilePath $filepath -ArgumentList $args -Wait -Passthru
- if (Test-Path $python_home) {
- Write-Host "Python $python_version ($architecture) installation complete"
- } else {
- Write-Host "Failed to install Python in $python_home"
- Get-Content -Path $install_log
- Exit 1
- }
-}
-
-
-function InstallMinicondaPip ($python_home) {
- $pip_path = $python_home + "\Scripts\pip.exe"
- $conda_path = $python_home + "\Scripts\conda.exe"
- if (-not(Test-Path $pip_path)) {
- Write-Host "Installing pip..."
- $args = "install --yes pip"
- Write-Host $conda_path $args
- Start-Process -FilePath "$conda_path" -ArgumentList $args -Wait -Passthru
- } else {
- Write-Host "pip already installed."
- }
-}
-
-
-function main () {
- InstallMiniconda $env:PYTHON_VERSION $env:PYTHON_ARCH $env:PYTHON
- InstallMinicondaPip $env:PYTHON
-}
-
-main
diff --git a/mne/__init__.py b/mne/__init__.py
index aba4d28..e58f017 100644
--- a/mne/__init__.py
+++ b/mne/__init__.py
@@ -17,7 +17,7 @@
# Dev branch marker is: 'X.Y.devN' where N is an integer.
#
-__version__ = '0.12.0'
+__version__ = '0.13'
# have to import verbose first since it's needed by many things
from .utils import (set_log_level, set_log_file, verbose, set_config,
@@ -28,11 +28,12 @@ from .io.pick import (pick_types, pick_channels,
pick_types_forward, pick_channels_cov,
pick_channels_evoked, pick_info)
from .io.base import concatenate_raws
-from .chpi import get_chpi_positions
from .io.meas_info import create_info, Info
from .io.proj import Projection
from .io.kit import read_epochs_kit
from .io.eeglab import read_epochs_eeglab
+from .io.reference import (set_eeg_reference, set_bipolar_reference,
+ add_reference_channels)
from .bem import (make_sphere_model, make_bem_model, make_bem_solution,
read_bem_surfaces, write_bem_surfaces,
read_bem_solution, write_bem_solution)
@@ -40,9 +41,9 @@ from .cov import (read_cov, write_cov, Covariance, compute_raw_covariance,
compute_covariance, whiten_evoked, make_ad_hoc_cov)
from .event import (read_events, write_events, find_events, merge_events,
pick_events, make_fixed_length_events, concatenate_events,
- find_stim_steps)
+ find_stim_steps, AcqParserFIF)
from .forward import (read_forward_solution, apply_forward, apply_forward_raw,
- do_forward_solution, average_forward_solutions,
+ average_forward_solutions, Forward,
write_forward_solution, make_forward_solution,
convert_forward_solution, make_field_map,
make_forward_dipole)
@@ -58,7 +59,7 @@ from .source_estimate import (read_source_estimate, MixedSourceEstimate,
spatio_temporal_tris_connectivity,
spatio_temporal_dist_connectivity,
save_stc_as_volume, extract_label_time_course)
-from .surface import (read_surface, write_surface, decimate_surface,
+from .surface import (read_surface, write_surface, decimate_surface, read_tri,
read_morph_map, get_head_surf, get_meg_helmet_surf)
from .source_space import (read_source_spaces, vertex_to_mni,
write_source_spaces, setup_source_space,
diff --git a/mne/annotations.py b/mne/annotations.py
index 33df574..4f4a94b 100644
--- a/mne/annotations.py
+++ b/mne/annotations.py
@@ -13,15 +13,32 @@ from .externals.six import string_types
class Annotations(object):
"""Annotation object for annotating segments of raw data.
+ Annotations are added to an instance of :class:`mne.io.Raw` as an
+ attribute named ``annotations``. See the example below. To reject bad
+ epochs using annotations, use a description starting with the keyword
+ 'bad'. Epochs overlapping bad segments are then rejected automatically
+ by default.
+
+ To remove epochs with blinks you can do::
+
+ >>> eog_events = mne.preprocessing.find_eog_events(raw) # doctest: +SKIP
+ >>> n_blinks = len(eog_events) # doctest: +SKIP
+ >>> onset = eog_events[:, 0] / raw.info['sfreq'] - 0.25 # doctest: +SKIP
+ >>> duration = np.repeat(0.5, n_blinks) # doctest: +SKIP
+ >>> description = ['bad blink'] * n_blinks # doctest: +SKIP
+ >>> annotations = mne.Annotations(onset, duration, description) # doctest: +SKIP
+ >>> raw.annotations = annotations # doctest: +SKIP
+ >>> epochs = mne.Epochs(raw, events, event_id, tmin, tmax) # doctest: +SKIP
+
Parameters
----------
onset : array of float, shape (n_annotations,)
- Annotation time onsets from the beginning of the recording.
+ Annotation time onsets from the beginning of the recording in seconds.
duration : array of float, shape (n_annotations,)
- Durations of the annotations.
+ Durations of the annotations in seconds.
description : array of str, shape (n_annotations,) | str
Array of strings containing description for each annotation. If a
- string, all the annotations are given the same description.
+ string, all the annotations are given the same description. To reject
+ epochs, use a description starting with the keyword 'bad'. See the
+ example above.
orig_time : float | int | instance of datetime | array of int | None
A POSIX Timestamp, datetime or an array containing the timestamp as the
first element and microseconds as the second element. Determines the
@@ -33,10 +50,10 @@ class Annotations(object):
Notes
-----
- Annotations are synced to sample 0. ``raw.first_samp`` is taken
- into account in the same way as with events.
- """
-
+ If ``orig_time`` is None, the annotations are synced to the start of the
+ data (0 seconds). Otherwise the annotations are synced to sample 0 and
+ ``raw.first_samp`` is taken into account in the same way as with events.
+ """ # noqa
def __init__(self, onset, duration, description, orig_time=None):
if orig_time is not None:
@@ -48,10 +65,10 @@ class Annotations(object):
orig_time = float(orig_time) # np.int not serializable
self.orig_time = orig_time
- onset = np.array(onset)
+ onset = np.array(onset, dtype=float)
if onset.ndim != 1:
raise ValueError('Onset must be a one dimensional array.')
- duration = np.array(duration)
+ duration = np.array(duration, dtype=float)
if isinstance(description, string_types):
description = np.repeat(description, len(onset))
if duration.ndim != 1:
@@ -84,11 +101,10 @@ def _combine_annotations(annotations, last_samps, first_samps, sfreq):
old_description = annotations[0].description
old_orig_time = annotations[0].orig_time
- if annotations[1].orig_time is None:
- onset = (annotations[1].onset +
- (sum(last_samps[:-1]) - sum(first_samps[:-1])) / sfreq)
- else:
- onset = annotations[1].onset
+ extra_samps = len(first_samps) - 1 # Account for sample 0
+ onset = (annotations[1].onset + (sum(last_samps[:-1]) + extra_samps -
+ sum(first_samps[:-1])) / sfreq)
+
onset = np.concatenate([old_onset, onset])
duration = np.concatenate([old_duration, annotations[1].duration])
description = np.concatenate([old_description, annotations[1].description])
@@ -102,12 +118,15 @@ def _onset_to_seconds(raw, onset):
if meas_date is None:
meas_date = 0
elif not np.isscalar(meas_date):
- meas_date = meas_date[0] + meas_date[1] / 1000000.
+ if len(meas_date) > 1:
+ meas_date = meas_date[0] + meas_date[1] / 1000000.
+ else:
+ meas_date = meas_date[0]
if raw.annotations.orig_time is None:
orig_time = meas_date
else:
- orig_time = raw.annotations.orig_time
+ orig_time = (raw.annotations.orig_time -
+ raw.first_samp / raw.info['sfreq'])
- annot_start = (orig_time - meas_date + onset -
- raw.first_samp / raw.info['sfreq'])
+ annot_start = orig_time - meas_date + onset
return annot_start
diff --git a/mne/baseline.py b/mne/baseline.py
index 320b822..aa2c4f9 100644
--- a/mne/baseline.py
+++ b/mne/baseline.py
@@ -36,19 +36,22 @@ def rescale(data, times, baseline, mode='mean', copy=True, verbose=None):
Time instants in seconds.
baseline : tuple or list of length 2, or None
The time interval to apply rescaling / baseline correction.
- If None do not apply it. If baseline is (a, b)
- the interval is between "a (s)" and "b (s)".
- If a is None the beginning of the data is used
- and if b is None then b is set to the end of the interval.
- If baseline is equal ot (None, None) all the time
- interval is used. If None, no correction is applied.
- mode : 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent' | 'zlogratio'
+ If None do not apply it. If baseline is ``(bmin, bmax)``
+ the interval is between ``bmin`` (s) and ``bmax`` (s).
+ If ``bmin is None`` the beginning of the data is used
+ and if ``bmax is None`` then ``bmax`` is set to the end of the
+ interval. If baseline is ``(None, None)`` the entire time
+ interval is used. If baseline is None, no correction is applied.
+ mode : None | 'ratio' | 'zscore' | 'mean' | 'percent' | 'logratio' | 'zlogratio' # noqa
Do baseline correction with ratio (power is divided by mean
power during baseline) or zscore (power is divided by standard
deviation of power during baseline after subtracting the mean,
- power = [power - mean(power_baseline)] / std(power_baseline)).
- logratio is the same an mean but in log-scale, zlogratio is the
- same as zscore but data is rendered in log-scale first.
+ power = [power - mean(power_baseline)] / std(power_baseline)), mean
+ simply subtracts the mean power, percent is the same as applying ratio
+ then mean, logratio is the same as mean but then rendered in log-scale,
+ zlogratio is the same as zscore but data is rendered in log-scale
+ first.
+ If None no baseline correction is applied.
copy : bool
Whether to return a new instance or modify in place.
verbose : bool, str, int, or None
@@ -68,11 +71,22 @@ def rescale(data, times, baseline, mode='mean', copy=True, verbose=None):
if bmin is None:
imin = 0
else:
- imin = int(np.where(times >= bmin)[0][0])
+ imin = np.where(times >= bmin)[0]
+ if len(imin) == 0:
+ raise ValueError('bmin is too large (%s), it exceeds the largest '
+ 'time value' % (bmin,))
+ imin = int(imin[0])
if bmax is None:
imax = len(times)
else:
- imax = int(np.where(times <= bmax)[0][-1]) + 1
+ imax = np.where(times <= bmax)[0]
+ if len(imax) == 0:
+ raise ValueError('bmax is too small (%s), it is smaller than the '
+ 'smallest time value' % (bmax,))
+ imax = int(imax[-1]) + 1
+ if imin >= imax:
+ raise ValueError('Bad rescaling slice (%s:%s) from time values %s, %s'
+ % (imin, imax, bmin, bmax))
# avoid potential "empty slice" warning
if data.shape[-1] > 0:
diff --git a/mne/beamformer/_dics.py b/mne/beamformer/_dics.py
index a7299d3..8532fc2 100644
--- a/mne/beamformer/_dics.py
+++ b/mne/beamformer/_dics.py
@@ -14,7 +14,7 @@ from ..utils import logger, verbose, warn
from ..forward import _subject_from_forward
from ..minimum_norm.inverse import combine_xyz, _check_reference
from ..source_estimate import _make_stc
-from ..time_frequency import CrossSpectralDensity, compute_epochs_csd
+from ..time_frequency import CrossSpectralDensity, csd_epochs
from ._lcmv import _prepare_beamformer_input, _setup_picks
from ..externals import six
@@ -564,13 +564,13 @@ def tf_dics(epochs, forward, noise_csds, tmin, tmax, tstep, win_lengths,
win_tmax = win_tmax + 1e-10
# Calculating data CSD in current time window
- data_csd = compute_epochs_csd(epochs, mode=mode,
- fmin=freq_bin[0],
- fmax=freq_bin[1], fsum=True,
- tmin=win_tmin, tmax=win_tmax,
- n_fft=n_fft,
- mt_bandwidth=mt_bandwidth,
- mt_low_bias=mt_low_bias)
+ data_csd = csd_epochs(epochs, mode=mode,
+ fmin=freq_bin[0],
+ fmax=freq_bin[1], fsum=True,
+ tmin=win_tmin, tmax=win_tmax,
+ n_fft=n_fft,
+ mt_bandwidth=mt_bandwidth,
+ mt_low_bias=mt_low_bias)
# Scale data CSD to allow data and noise CSDs to have different
# length
diff --git a/mne/beamformer/_lcmv.py b/mne/beamformer/_lcmv.py
index 75a9b02..9a15809 100644
--- a/mne/beamformer/_lcmv.py
+++ b/mne/beamformer/_lcmv.py
@@ -751,7 +751,7 @@ def tf_lcmv(epochs, forward, noise_covs, tmin, tmax, tstep, win_lengths,
raw_band = raw.copy()
raw_band.filter(l_freq, h_freq, picks=raw_picks, method='iir',
- n_jobs=n_jobs)
+ n_jobs=n_jobs, iir_params=dict(output='ba'))
raw_band.info['highpass'] = l_freq
raw_band.info['lowpass'] = h_freq
epochs_band = Epochs(raw_band, epochs.events, epochs.event_id,
diff --git a/mne/beamformer/tests/test_dics.py b/mne/beamformer/tests/test_dics.py
index 81faec3..ccdd50c 100644
--- a/mne/beamformer/tests/test_dics.py
+++ b/mne/beamformer/tests/test_dics.py
@@ -10,9 +10,9 @@ from numpy.testing import assert_array_equal, assert_array_almost_equal
import mne
from mne.datasets import testing
from mne.beamformer import dics, dics_epochs, dics_source_power, tf_dics
-from mne.time_frequency import compute_epochs_csd
+from mne.time_frequency import csd_epochs
from mne.externals.six import advance_iterator
-from mne.utils import run_tests_if_main, clean_warning_registry
+from mne.utils import run_tests_if_main
# Note that this is the first test file, this will apply to all subsequent
# tests in a full nosetest:
@@ -29,9 +29,6 @@ fname_event = op.join(data_path, 'MEG', 'sample',
label = 'Aud-lh'
fname_label = op.join(data_path, 'MEG', 'sample', 'labels', '%s.label' % label)
-# bit of a hack to deal with old scipy/numpy throwing warnings in tests
-clean_warning_registry()
-
def read_forward_solution_meg(*args, **kwargs):
fwd = mne.read_forward_solution(*args, **kwargs)
@@ -43,7 +40,7 @@ def _get_data(tmin=-0.11, tmax=0.15, read_all_forward=True, compute_csds=True):
"""
label = mne.read_label(fname_label)
events = mne.read_events(fname_event)[:10]
- raw = mne.io.read_raw_fif(fname_raw, preload=False)
+ raw = mne.io.read_raw_fif(fname_raw, preload=False, add_eeg_ref=False)
raw.add_proj([], remove_existing=True) # we'll subselect so remove proj
forward = mne.read_forward_solution(fname_fwd)
if read_all_forward:
@@ -70,18 +67,19 @@ def _get_data(tmin=-0.11, tmax=0.15, read_all_forward=True, compute_csds=True):
# Read epochs
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True,
picks=picks, baseline=(None, 0), preload=True,
- reject=dict(grad=4000e-13, mag=4e-12, eog=150e-6))
+ reject=dict(grad=4000e-13, mag=4e-12, eog=150e-6),
+ add_eeg_ref=False)
epochs.resample(200, npad=0, n_jobs=2)
evoked = epochs.average()
# Computing the data and noise cross-spectral density matrices
if compute_csds:
- data_csd = compute_epochs_csd(epochs, mode='multitaper', tmin=0.045,
- tmax=None, fmin=8, fmax=12,
- mt_bandwidth=72.72)
- noise_csd = compute_epochs_csd(epochs, mode='multitaper', tmin=None,
- tmax=0.0, fmin=8, fmax=12,
- mt_bandwidth=72.72)
+ data_csd = csd_epochs(epochs, mode='multitaper', tmin=0.045,
+ tmax=None, fmin=8, fmax=12,
+ mt_bandwidth=72.72)
+ noise_csd = csd_epochs(epochs, mode='multitaper', tmin=None,
+ tmax=0.0, fmin=8, fmax=12,
+ mt_bandwidth=72.72)
else:
data_csd, noise_csd = None, None
@@ -240,10 +238,10 @@ def test_tf_dics():
noise_csds = []
for freq_bin, win_length in zip(freq_bins, win_lengths):
- noise_csd = compute_epochs_csd(epochs, mode='fourier',
- fmin=freq_bin[0], fmax=freq_bin[1],
- fsum=True, tmin=tmin,
- tmax=tmin + win_length)
+ noise_csd = csd_epochs(epochs, mode='fourier',
+ fmin=freq_bin[0], fmax=freq_bin[1],
+ fsum=True, tmin=tmin,
+ tmax=tmin + win_length)
noise_csds.append(noise_csd)
stcs = tf_dics(epochs, forward, noise_csds, tmin, tmax, tstep, win_lengths,
@@ -257,14 +255,14 @@ def test_tf_dics():
source_power = []
time_windows = [(-0.1, 0.1), (0.0, 0.2)]
for time_window in time_windows:
- data_csd = compute_epochs_csd(epochs, mode='fourier',
- fmin=freq_bins[0][0],
- fmax=freq_bins[0][1], fsum=True,
- tmin=time_window[0], tmax=time_window[1])
- noise_csd = compute_epochs_csd(epochs, mode='fourier',
- fmin=freq_bins[0][0],
- fmax=freq_bins[0][1], fsum=True,
- tmin=-0.2, tmax=0.0)
+ data_csd = csd_epochs(epochs, mode='fourier',
+ fmin=freq_bins[0][0],
+ fmax=freq_bins[0][1], fsum=True,
+ tmin=time_window[0], tmax=time_window[1])
+ noise_csd = csd_epochs(epochs, mode='fourier',
+ fmin=freq_bins[0][0],
+ fmax=freq_bins[0][1], fsum=True,
+ tmin=-0.2, tmax=0.0)
data_csd.data /= data_csd.n_fft
noise_csd.data /= noise_csd.n_fft
stc_source_power = dics_source_power(epochs.info, forward, noise_csd,
diff --git a/mne/beamformer/tests/test_lcmv.py b/mne/beamformer/tests/test_lcmv.py
index ab05444..ff463b9 100644
--- a/mne/beamformer/tests/test_lcmv.py
+++ b/mne/beamformer/tests/test_lcmv.py
@@ -40,7 +40,7 @@ def _get_data(tmin=-0.1, tmax=0.15, all_forward=True, epochs=True,
"""
label = mne.read_label(fname_label)
events = mne.read_events(fname_event)
- raw = mne.io.read_raw_fif(fname_raw, preload=True)
+ raw = mne.io.read_raw_fif(fname_raw, preload=True, add_eeg_ref=False)
forward = mne.read_forward_solution(fname_fwd)
if all_forward:
forward_surf_ori = read_forward_solution_meg(fname_fwd, surf_ori=True)
@@ -69,7 +69,8 @@ def _get_data(tmin=-0.1, tmax=0.15, all_forward=True, epochs=True,
epochs = mne.Epochs(
raw, events, event_id, tmin, tmax, proj=True,
baseline=(None, 0), preload=epochs_preload,
- reject=dict(grad=4000e-13, mag=4e-12, eog=150e-6))
+ reject=dict(grad=4000e-13, mag=4e-12, eog=150e-6),
+ add_eeg_ref=False)
if epochs_preload:
epochs.resample(200, npad=0, n_jobs=2)
evoked = epochs.average()
@@ -267,7 +268,7 @@ def test_tf_lcmv():
"""
label = mne.read_label(fname_label)
events = mne.read_events(fname_event)
- raw = mne.io.read_raw_fif(fname_raw, preload=True)
+ raw = mne.io.read_raw_fif(fname_raw, preload=True, add_eeg_ref=False)
forward = mne.read_forward_solution(fname_fwd)
event_id, tmin, tmax = 1, -0.2, 0.2
@@ -287,7 +288,8 @@ def test_tf_lcmv():
# Read epochs
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True,
baseline=None, preload=False,
- reject=dict(grad=4000e-13, mag=4e-12, eog=150e-6))
+ reject=dict(grad=4000e-13, mag=4e-12, eog=150e-6),
+ add_eeg_ref=False)
epochs.drop_bad()
freq_bins = [(4, 12), (15, 40)]
@@ -300,10 +302,11 @@ def test_tf_lcmv():
noise_covs = []
for (l_freq, h_freq), win_length in zip(freq_bins, win_lengths):
raw_band = raw.copy()
- raw_band.filter(l_freq, h_freq, method='iir', n_jobs=1)
+ raw_band.filter(l_freq, h_freq, method='iir', n_jobs=1,
+ iir_params=dict(output='ba'))
epochs_band = mne.Epochs(
raw_band, epochs.events, epochs.event_id, tmin=tmin, tmax=tmax,
- baseline=None, proj=True)
+ baseline=None, proj=True, add_eeg_ref=False)
with warnings.catch_warnings(record=True): # not enough samples
noise_cov = compute_covariance(epochs_band, tmin=tmin, tmax=tmin +
win_length)
@@ -364,7 +367,8 @@ def test_tf_lcmv():
# Test correct detection of preloaded epochs objects that do not contain
# the underlying raw object
epochs_preloaded = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True,
- baseline=(None, 0), preload=True)
+ baseline=(None, 0), preload=True,
+ add_eeg_ref=False)
epochs_preloaded._raw = None
with warnings.catch_warnings(record=True): # not enough samples
assert_raises(ValueError, tf_lcmv, epochs_preloaded, forward,
diff --git a/mne/bem.py b/mne/bem.py
index 5a7a760..625bd48 100644
--- a/mne/bem.py
+++ b/mne/bem.py
@@ -5,16 +5,16 @@
#
# License: BSD (3-clause)
-import sys
+from functools import partial
+import glob
import os
import os.path as op
import shutil
-import glob
+import sys
import numpy as np
from scipy import linalg
-from .fixes import partial
from .utils import verbose, logger, run_subprocess, get_subjects_dir, warn
from .transforms import _ensure_trans, apply_trans
from .io import Info
@@ -318,6 +318,7 @@ def make_bem_solution(surfs, verbose=None):
logger.info('Homogeneous model surface loaded.')
else:
raise RuntimeError('Only 1- or 3-layer BEM computations supported')
+ _check_bem_size(bem['surfs'])
_fwd_bem_linear_collocation_solution(bem)
logger.info('BEM geometry computations complete.')
return bem
@@ -556,6 +557,7 @@ def make_bem_model(subject, ico=4, conductivity=(0.3, 0.006, 0.3),
surfaces = surfaces[:1]
ids = ids[:1]
surfaces = _surfaces_to_bem(surfaces, ids, conductivity, ico)
+ _check_bem_size(surfaces)
logger.info('Complete.\n')
return surfaces
@@ -812,7 +814,7 @@ _dig_kind_ints = tuple(_dig_kind_dict.values())
@verbose
-def fit_sphere_to_headshape(info, dig_kinds='auto', units=None, verbose=None):
+def fit_sphere_to_headshape(info, dig_kinds='auto', units='m', verbose=None):
"""Fit a sphere to the headshape points to determine head center
Parameters
@@ -825,8 +827,7 @@ def fit_sphere_to_headshape(info, dig_kinds='auto', units=None, verbose=None):
be 'auto' (default), which will use only the 'extra' points if
enough are available, and if not, uses 'extra' and 'eeg' points.
units : str
- Can be "m" or "mm". The default in 0.12 is "mm" but will be changed
- to "m" in 0.13.
+ Can be "m" (default) or "mm".
.. versionadded:: 0.12
@@ -847,11 +848,6 @@ def fit_sphere_to_headshape(info, dig_kinds='auto', units=None, verbose=None):
This function excludes any points that are low and frontal
(``z < 0 and y > 0``) to improve the fit.
"""
- if units is None:
- warn('Please explicitly set the units. In 0.12 units="mm" will '
- 'be used, but this will change to units="m" in 0.13.',
- DeprecationWarning)
- units = 'mm'
if not isinstance(units, string_types) or units not in ('m', 'mm'):
raise ValueError('units must be "m" or "mm"')
if not isinstance(info, Info):
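With the deprecation path removed above, units now defaults to 'm'. A minimal sketch, assuming 'info' comes from a loaded measurement file:

    from mne.bem import fit_sphere_to_headshape

    # info = mne.io.read_info('sample_raw.fif')  # assumed source of info
    radius, origin_head, origin_device = fit_sphere_to_headshape(
        info, dig_kinds='auto', units='m')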
@@ -1019,7 +1015,7 @@ def make_watershed_bem(subject, subjects_dir=None, overwrite=False,
-----
.. versionadded:: 0.10
"""
- from .surface import read_surface
+ from .surface import read_surface, write_surface, _read_surface_geom
from .viz.misc import plot_bem
env, mri_dir = _prepare_env(subject, subjects_dir,
requires_freesurfer=True,
@@ -1069,14 +1065,20 @@ def make_watershed_bem(subject, subjects_dir=None, overwrite=False,
run_subprocess(cmd, env=env, stdout=sys.stdout, stderr=sys.stderr)
if op.isfile(T1_mgz):
- # XXX : do this with python code
+ new_info = _extract_volume_info(T1_mgz)
+ if new_info is None:
+ warn('nibabel is required to replace the volume info. Volume info '
+ 'not updated in the written surface.')
+ new_info = dict()
surfs = ['brain', 'inner_skull', 'outer_skull', 'outer_skin']
for s in surfs:
surf_ws_out = op.join(ws_dir, '%s_%s_surface' % (subject, s))
- cmd = ['mne_convert_surface', '--surf', surf_ws_out, '--mghmri',
- T1_mgz, '--surfout', s, "--replacegeom"]
- run_subprocess(cmd, env=env, stdout=sys.stdout, stderr=sys.stderr)
+ surf, volume_info = _read_surface_geom(surf_ws_out,
+ read_metadata=True)
+ volume_info.update(new_info) # replace volume info, 'head' stays
+
+ write_surface(s, surf['rr'], surf['tris'], volume_info=volume_info)
# Create symbolic links
surf_out = op.join(bem_dir, '%s.surf' % s)
if not overwrite and op.exists(surf_out):
@@ -1084,7 +1086,7 @@ def make_watershed_bem(subject, subjects_dir=None, overwrite=False,
else:
if op.exists(surf_out):
os.remove(surf_out)
- os.symlink(surf_ws_out, surf_out)
+ _symlink(surf_ws_out, surf_out)
skip_symlink = False
if skip_symlink:
@@ -1120,6 +1122,28 @@ def make_watershed_bem(subject, subjects_dir=None, overwrite=False,
logger.info('Created %s\n\nComplete.' % (fname_head,))
+def _extract_volume_info(mgz):
+ """Helper for extracting volume info from a mgz file."""
+ try:
+ import nibabel as nib
+ except ImportError:
+ return # warning raised elsewhere
+ header = nib.load(mgz).header
+ new_info = dict()
+ version = header['version']
+ if version == 1:
+ version = '%s # volume info valid' % version
+ else:
+ raise ValueError('Volume info invalid.')
+ new_info['valid'] = version
+ new_info['filename'] = mgz
+ new_info['volume'] = header['dims'][:3]
+ new_info['voxelsize'] = header['delta']
+ new_info['xras'], new_info['yras'], new_info['zras'] = header['Mdc'].T
+ new_info['cras'] = header['Pxyz_c']
+ return new_info
+
+
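A sketch of the header fields _extract_volume_info copies, assuming nibabel is installed and a FreeSurfer 'T1.mgz' is at hand:

    import nibabel as nib

    header = nib.load('T1.mgz').header
    # the helper reads these keys into the surface volume info
    print(header['version'], header['dims'][:3], header['delta'])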
# ############################################################################
# Read
@@ -1244,7 +1268,7 @@ def _read_bem_surface(fid, this, def_coord_frame, s_id=None):
tag = find_tag(fid, this, FIFF.FIFF_MNE_SOURCE_SPACE_NORMALS)
if tag is None:
- tag = tag = find_tag(fid, this, FIFF.FIFF_BEM_SURF_NORMALS)
+ tag = find_tag(fid, this, FIFF.FIFF_BEM_SURF_NORMALS)
if tag is None:
res['nn'] = list()
else:
@@ -1451,6 +1475,7 @@ def write_bem_solution(fname, bem):
--------
read_bem_solution
"""
+ _check_bem_size(bem['surfs'])
with start_file(fname) as fid:
start_block(fid, FIFF.FIFFB_BEM)
# Coordinate frame (mainly for backward compatibility)
@@ -1537,6 +1562,10 @@ def convert_flash_mris(subject, flash30=True, convert=True, unwarp=False,
appropriate series:
$ ln -s <FLASH 5 series dir> flash05
$ ln -s <FLASH 30 series dir> flash30
+ Some filesystems (e.g. FAT32) do not support symbolic links.
+ In that case, copy the series directories instead:
+ $ cp -r <FLASH 5 series dir> flash05
+ $ cp -r <FLASH 30 series dir> flash30
4. cd to the directory where flash05 and flash30 links are
5. Set SUBJECTS_DIR and SUBJECT environment variables appropriately
6. Run this script
@@ -1644,7 +1673,7 @@ def convert_flash_mris(subject, flash30=True, convert=True, unwarp=False,
@verbose
def make_flash_bem(subject, overwrite=False, show=True, subjects_dir=None,
- verbose=None):
+ flash_path=None, verbose=None):
"""Create 3-Layer BEM model from prepared flash MRI images
Parameters
@@ -1657,30 +1686,40 @@ def make_flash_bem(subject, overwrite=False, show=True, subjects_dir=None,
Show surfaces to visually inspect all three BEM surfaces (recommended).
subjects_dir : string, or None
Path to SUBJECTS_DIR if it is not set in the environment.
+ flash_path : str | None
+ Path to the flash images. If None (default), mri/flash/parameter_maps
+ within the subject reconstruction is used.
+
+ .. versionadded:: 0.13.0
+
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
Notes
-----
- This program assumes that FreeSurfer and MNE are installed and
- sourced properly.
+ This program assumes that FreeSurfer is installed and sourced properly.
This function extracts the BEM surfaces (outer skull, inner skull, and
outer skin) from multiecho FLASH MRI data with spin angles of 5 and 30
degrees, in mgz format.
- This function assumes that the flash images are available in the
- folder mri/bem/flash within the freesurfer subject reconstruction.
-
See Also
--------
convert_flash_mris
"""
from .viz.misc import plot_bem
+ from .surface import write_surface, read_tri
+
+ is_test = os.environ.get('MNE_SKIP_FS_FLASH_CALL', False)
+
env, mri_dir, bem_dir = _prepare_env(subject, subjects_dir,
requires_freesurfer=True,
- requires_mne=True)
+ requires_mne=False)
+ if flash_path is None:
+ flash_path = op.join(mri_dir, 'flash', 'parameter_maps')
+ else:
+ flash_path = op.abspath(flash_path)
curdir = os.getcwd()
subjects_dir = env['SUBJECTS_DIR']
@@ -1692,13 +1731,15 @@ def make_flash_bem(subject, overwrite=False, show=True, subjects_dir=None,
op.join(bem_dir, 'flash')))
# Step 4 : Register with MPRAGE
logger.info("\n---- Registering flash 5 with MPRAGE ----")
- if not op.exists('flash5_reg.mgz'):
+ flash5 = op.join(flash_path, 'flash5.mgz')
+ flash5_reg = op.join(flash_path, 'flash5_reg.mgz')
+ if not op.exists(flash5_reg):
if op.exists(op.join(mri_dir, 'T1.mgz')):
ref_volume = op.join(mri_dir, 'T1.mgz')
else:
ref_volume = op.join(mri_dir, 'T1')
- cmd = ['fsl_rigid_register', '-r', ref_volume, '-i', 'flash5.mgz',
- '-o', 'flash5_reg.mgz']
+ cmd = ['fsl_rigid_register', '-r', ref_volume, '-i', flash5,
+ '-o', flash5_reg]
run_subprocess(cmd, env=env, stdout=sys.stdout, stderr=sys.stderr)
else:
logger.info("Registered flash 5 image is already there")
@@ -1706,8 +1747,9 @@ def make_flash_bem(subject, overwrite=False, show=True, subjects_dir=None,
logger.info("\n---- Converting flash5 volume into COR format ----")
shutil.rmtree(op.join(mri_dir, 'flash5'), ignore_errors=True)
os.makedirs(op.join(mri_dir, 'flash5'))
- cmd = ['mri_convert', 'flash5_reg.mgz', op.join(mri_dir, 'flash5')]
- run_subprocess(cmd, env=env, stdout=sys.stdout, stderr=sys.stderr)
+ if not is_test: # CIs don't have freesurfer, skipped when testing.
+ cmd = ['mri_convert', flash5_reg, op.join(mri_dir, 'flash5')]
+ run_subprocess(cmd, env=env, stdout=sys.stdout, stderr=sys.stderr)
# Step 5b and c : Convert the mgz volumes into COR
os.chdir(mri_dir)
convert_T1 = False
@@ -1735,9 +1777,11 @@ def make_flash_bem(subject, overwrite=False, show=True, subjects_dir=None,
else:
logger.info("Brain volume is already in COR format")
# Finally ready to go
- logger.info("\n---- Creating the BEM surfaces ----")
- cmd = ['mri_make_bem_surfaces', subject]
- run_subprocess(cmd, env=env, stdout=sys.stdout, stderr=sys.stderr)
+ if not is_test: # CIs don't have freesurfer, skipped when testing.
+ logger.info("\n---- Creating the BEM surfaces ----")
+ cmd = ['mri_make_bem_surfaces', subject]
+ run_subprocess(cmd, env=env, stdout=sys.stdout, stderr=sys.stderr)
+
logger.info("\n---- Converting the tri files into surf files ----")
os.chdir(bem_dir)
if not op.exists('flash'):
@@ -1746,11 +1790,16 @@ def make_flash_bem(subject, overwrite=False, show=True, subjects_dir=None,
surfs = ['inner_skull', 'outer_skull', 'outer_skin']
for surf in surfs:
shutil.move(op.join(bem_dir, surf + '.tri'), surf + '.tri')
- cmd = ['mne_convert_surface', '--tri', surf + '.tri', '--surfout',
- surf + '.surf', '--swap', '--mghmri',
- op.join(subjects_dir, subject, 'mri', 'flash', 'parameter_maps',
- 'flash5_reg.mgz')]
- run_subprocess(cmd, env=env, stdout=sys.stdout, stderr=sys.stderr)
+
+ nodes, tris = read_tri(surf + '.tri', swap=True)
+ vol_info = _extract_volume_info(flash5_reg)
+ if vol_info is None:
+ warn('nibabel is required to update the volume info. Volume info '
+ 'omitted from the written surface.')
+ else:
+ vol_info['head'] = np.array([20])
+ write_surface(surf + '.surf', nodes, tris, volume_info=vol_info)
+
# Cleanup section
logger.info("\n---- Cleaning up ----")
os.chdir(bem_dir)
@@ -1774,7 +1823,7 @@ def make_flash_bem(subject, overwrite=False, show=True, subjects_dir=None,
else:
if op.exists(surf):
os.remove(surf)
- os.symlink(op.join('flash', surf), op.join(surf))
+ _symlink(op.join('flash', surf), op.join(surf))
skip_symlink = False
if skip_symlink:
logger.info("Unable to create all symbolic links to .surf files "
@@ -1794,3 +1843,22 @@ def make_flash_bem(subject, overwrite=False, show=True, subjects_dir=None,
# Go back to initial directory
os.chdir(curdir)
+
+
+def _check_bem_size(surfs):
+ """Helper for checking bem surface sizes."""
+ if surfs[0]['np'] > 10000:
+ msg = ('The bem surface has %s data points. 5120 (ico grade=4) should '
+ 'be enough.' % surfs[0]['np'])
+ if len(surfs) == 3:
+ msg += ' Dense 3-layer bems may not save properly.'
+ warn(msg)
+
+
+def _symlink(src, dest):
+ try:
+ os.symlink(src, dest)
+ except OSError:
+ warn('Could not create symbolic link %s. Check that your partition '
+ 'handles symbolic links. The file will be copied instead.' % dest)
+ shutil.copy(src, dest)
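The _symlink helper above implements a symlink-with-copy fallback; a self-contained sketch of the same pattern:

    import os
    import shutil
    from warnings import warn

    def link_or_copy(src, dest):
        """Symlink src to dest; copy if the filesystem lacks symlinks."""
        try:
            os.symlink(src, dest)
        except OSError:  # e.g. FAT32
            warn('Could not symlink %s; copying instead.' % dest)
            shutil.copy(src, dest)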
diff --git a/mne/channels/channels.py b/mne/channels/channels.py
index 7970047..1d55dc6 100644
--- a/mne/channels/channels.py
+++ b/mne/channels/channels.py
@@ -14,10 +14,12 @@ from scipy import sparse
from ..externals.six import string_types
-from ..utils import verbose, logger, warn, _check_copy_dep
+from ..utils import verbose, logger, warn, copy_function_doc_to_method_doc
+from ..io.compensator import get_current_comp
+from ..io.constants import FIFF
+from ..io.meas_info import anonymize_info
from ..io.pick import (channel_type, pick_info, pick_types,
_check_excludes_includes, _PICK_TYPES_KEYS)
-from ..io.constants import FIFF
def _get_meg_system(info):
@@ -25,6 +27,7 @@ def _get_meg_system(info):
system = '306m'
for ch in info['chs']:
if ch['kind'] == FIFF.FIFFV_MEG_CH:
+ # Only take first 16 bits, as higher bits store CTF grad comp order
coil_type = ch['coil_type'] & 0xFFFF
if coil_type == FIFF.FIFFV_COIL_NM_122:
system = '122m'
@@ -70,8 +73,9 @@ def _contains_ch_type(info, ch_type):
'`str`'.format(actual_class=type(ch_type)))
meg_extras = ['mag', 'grad', 'planar1', 'planar2']
+ fnirs_extras = ['hbo', 'hbr']
valid_channel_types = sorted([key for key in _PICK_TYPES_KEYS
- if key != 'meg'] + meg_extras)
+ if key != 'meg'] + meg_extras + fnirs_extras)
if ch_type not in valid_channel_types:
raise ValueError('ch_type must be one of %s, not "%s"'
% (valid_channel_types, ch_type))
@@ -144,7 +148,28 @@ class ContainsMixin(object):
"""Mixin class for Raw, Evoked, Epochs
"""
def __contains__(self, ch_type):
- """Check channel type membership"""
+ """Check channel type membership
+
+ Parameters
+ ----------
+ ch_type : str
+ Channel type to check for. Can be e.g. 'meg', 'eeg', 'stim', etc.
+
+ Returns
+ -------
+ in : bool
+ Whether or not the instance contains the given channel type.
+
+ Examples
+ --------
+ Channel type membership can be tested as::
+
+ >>> 'meg' in inst # doctest: +SKIP
+ True
+ >>> 'seeg' in inst # doctest: +SKIP
+ False
+
+ """
if ch_type == 'meg':
has_ch_type = (_contains_ch_type(self.info, 'mag') or
_contains_ch_type(self.info, 'grad'))
@@ -152,6 +177,12 @@ class ContainsMixin(object):
has_ch_type = _contains_ch_type(self.info, ch_type)
return has_ch_type
+ @property
+ def compensation_grade(self):
+ """The current gradient compensation grade"""
+ return get_current_comp(self.info)
+
+
# XXX Eventually de-duplicate with _kind_dict of mne/io/meas_info.py
_human2fiff = {'ecg': FIFF.FIFFV_ECG_CH,
'eeg': FIFF.FIFFV_EEG_CH,
@@ -165,7 +196,9 @@ _human2fiff = {'ecg': FIFF.FIFFV_ECG_CH,
'stim': FIFF.FIFFV_STIM_CH,
'syst': FIFF.FIFFV_SYST_CH,
'bio': FIFF.FIFFV_BIO_CH,
- 'ecog': FIFF.FIFFV_ECOG_CH}
+ 'ecog': FIFF.FIFFV_ECOG_CH,
+ 'hbo': FIFF.FIFFV_FNIRS_CH,
+ 'hbr': FIFF.FIFFV_FNIRS_CH}
_human2unit = {'ecg': FIFF.FIFF_UNIT_V,
'eeg': FIFF.FIFF_UNIT_V,
'emg': FIFF.FIFF_UNIT_V,
@@ -178,10 +211,13 @@ _human2unit = {'ecg': FIFF.FIFF_UNIT_V,
'stim': FIFF.FIFF_UNIT_NONE,
'syst': FIFF.FIFF_UNIT_NONE,
'bio': FIFF.FIFF_UNIT_V,
- 'ecog': FIFF.FIFF_UNIT_V}
+ 'ecog': FIFF.FIFF_UNIT_V,
+ 'hbo': FIFF.FIFF_UNIT_MOL,
+ 'hbr': FIFF.FIFF_UNIT_MOL}
_unit2human = {FIFF.FIFF_UNIT_V: 'V',
FIFF.FIFF_UNIT_T: 'T',
FIFF.FIFF_UNIT_T_M: 'T/m',
+ FIFF.FIFF_UNIT_MOL: 'M',
FIFF.FIFF_UNIT_NONE: 'NA'}
@@ -198,8 +234,52 @@ def _check_set(ch, projs, ch_type):
class SetChannelsMixin(object):
- """Mixin class for Raw, Evoked, Epochs
- """
+ """Mixin class for Raw, Evoked, Epochs."""
+
+ def set_eeg_reference(self, ref_channels=None):
+ """Rereference EEG channels to new reference channel(s).
+
+ If multiple reference channels are specified, they will be averaged. If
+ no reference channels are specified, an average reference will be
+ applied.
+
+ Parameters
+ ----------
+ ref_channels : list of str | None
+ The names of the channels to use to construct the reference. If
+ None (default), an average reference will be added as an SSP
+ projector but not immediately applied to the data. If an empty list
+ is specified, the data is assumed to already have a proper
+ reference and MNE will not attempt any re-referencing of the data.
+
+ Returns
+ -------
+ inst : instance of Raw | Epochs | Evoked
+ Data with EEG channels re-referenced. For ``ref_channels=None``,
+ an average projector will be added instead of directly subtracting
+ data.
+
+ Notes
+ -----
+ 1. If a reference is requested that is not the average reference, this
+ function removes any pre-existing average reference projections.
+
+ 2. During source localization, the EEG signal should have an average
+ reference.
+
+ 3. In order to apply a reference other than an average reference, the
+ data must be preloaded.
+
+ .. versionadded:: 0.13.0
+
+ See Also
+ --------
+ mne.set_bipolar_reference
+ """
+ from ..io.reference import set_eeg_reference
+ return set_eeg_reference(self, ref_channels, copy=False)[0]
+
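A hedged usage sketch for the new method above ('raw' is an assumed preloaded Raw instance):

    raw.set_eeg_reference()             # add (not apply) average-ref projector
    raw.set_eeg_reference(['EEG 001'])  # subtract one channel; needs preload
    raw.set_eeg_reference([])           # data already referenced; no-op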
def _get_channel_positions(self, picks=None):
"""Gets channel locations from info
@@ -258,7 +338,8 @@ class SetChannelsMixin(object):
"""Define the sensor type of channels.
Note: The following sensor types are accepted:
- ecg, eeg, emg, eog, exci, ias, misc, resp, seeg, stim, syst, ecog
+ ecg, eeg, emg, eog, exci, ias, misc, resp, seeg, stim, syst, ecog,
+ hbo, hbr
Parameters
----------
@@ -298,9 +379,14 @@ class SetChannelsMixin(object):
% (ch_name, _unit2human[unit_old], _unit2human[unit_new]))
self.info['chs'][c_ind]['unit'] = _human2unit[ch_type]
if ch_type in ['eeg', 'seeg', 'ecog']:
- self.info['chs'][c_ind]['coil_type'] = FIFF.FIFFV_COIL_EEG
+ coil_type = FIFF.FIFFV_COIL_EEG
+ elif ch_type == 'hbo':
+ coil_type = FIFF.FIFFV_COIL_FNIRS_HBO
+ elif ch_type == 'hbr':
+ coil_type = FIFF.FIFFV_COIL_FNIRS_HBR
else:
- self.info['chs'][c_ind]['coil_type'] = FIFF.FIFFV_COIL_NONE
+ coil_type = FIFF.FIFFV_COIL_NONE
+ self.info['chs'][c_ind]['coil_type'] = coil_type
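With the fNIRS branches above, 'hbo' and 'hbr' become valid targets for set_channel_types; a sketch ('raw' is assumed):

    # unit becomes mol and the coil type the matching FNIRS constant
    raw.set_channel_types({'MEG 2442': 'hbo'})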
def rename_channels(self, mapping):
"""Rename channels.
@@ -339,23 +425,50 @@ class SetChannelsMixin(object):
_set_montage(self.info, montage)
def plot_sensors(self, kind='topomap', ch_type=None, title=None,
- show_names=False, show=True):
+ show_names=False, ch_groups=None, axes=None, block=False,
+ show=True):
"""
Plot sensors positions.
Parameters
----------
kind : str
- Whether to plot the sensors as 3d or as topomap. Available options
- 'topomap', '3d'. Defaults to 'topomap'.
- ch_type : 'mag' | 'grad' | 'eeg' | 'seeg' | 'ecog' | None
- The channel type to plot. If None, then channels are chosen in the
- order given above.
+ Whether to plot the sensors as 3d, topomap or as an interactive
+ sensor selection dialog. Available options 'topomap', '3d',
+ 'select'. If 'select', a set of channels can be selected
+ interactively by using a lasso selector or by clicking while holding
+ the control key. The selected channels are returned along with the
+ figure instance. Defaults to 'topomap'.
+ ch_type : None | str
+ The channel type to plot. Available options 'mag', 'grad', 'eeg',
+ 'seeg', 'ecog', 'all'. If ``'all'``, all the available mag, grad,
+ eeg, seeg and ecog channels are plotted. If None (default), then
+ channels are chosen in the order given above.
title : str | None
- Title for the figure. If None (default), equals to
- ``'Sensor positions (%s)' % ch_type``.
+ Title for the figure. If None, defaults to ``'Sensor
+ positions (%s)' % ch_type``.
show_names : bool
Whether to display all channel names. Defaults to False.
+ ch_groups : 'position' | array of shape (ch_groups, picks) | None
+ Channel groups for coloring the sensors. If None (default), the
+ default coloring scheme is used. If 'position', the sensors are divided
+ into 8 regions. See ``order`` kwarg of :func:`mne.viz.plot_raw`. If
+ array, the channels are divided by picks given in the array.
+
+ .. versionadded:: 0.13.0
+
+ axes : instance of Axes | instance of Axes3D | None
+ Axes to draw the sensors to. If ``kind='3d'``, axes must be an
+ instance of Axes3D. If None (default), a new axes will be created.
+
+ .. versionadded:: 0.13.0
+
+ block : bool
+ Whether to halt program execution until the figure is closed.
+ Defaults to False.
+
+ .. versionadded:: 0.13.0
+
show : bool
Show figure if True. Defaults to True.
@@ -363,6 +476,8 @@ class SetChannelsMixin(object):
-------
fig : instance of matplotlib figure
Figure containing the sensor topography.
+ selection : list
+ A list of selected channels. Only returned if ``kind=='select'``.
See Also
--------
@@ -375,11 +490,19 @@ class SetChannelsMixin(object):
:func:`mne.viz.plot_trans`.
.. versionadded:: 0.12.0
-
"""
from ..viz.utils import plot_sensors
return plot_sensors(self.info, kind=kind, ch_type=ch_type, title=title,
- show_names=show_names, show=show)
+ show_names=show_names, ch_groups=ch_groups,
+ axes=axes, block=block, show=show)
+
+ @copy_function_doc_to_method_doc(anonymize_info)
+ def anonymize(self):
+ """
+ .. versionadded:: 0.13.0
+ """
+ anonymize_info(self.info)
+ return self
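A sketch of the extended plot_sensors signature and the new anonymize method ('raw' is assumed):

    fig = raw.plot_sensors(kind='topomap', ch_groups='position')
    fig, selection = raw.plot_sensors(kind='select')  # also returns picks
    raw.anonymize()  # strips identifying fields from raw.info, returns self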
class UpdateChannelsMixin(object):
@@ -388,8 +511,8 @@ class UpdateChannelsMixin(object):
def pick_types(self, meg=True, eeg=False, stim=False, eog=False,
ecg=False, emg=False, ref_meg='auto', misc=False,
resp=False, chpi=False, exci=False, ias=False, syst=False,
- seeg=False, bio=False, ecog=False, include=[],
- exclude='bads', selection=None, copy=None):
+ seeg=False, dipole=False, gof=False, bio=False, ecog=False,
+ fnirs=False, include=[], exclude='bads', selection=None):
"""Pick some channels by type and names
Parameters
@@ -427,10 +550,19 @@ class UpdateChannelsMixin(object):
System status channel information (on Triux systems only).
seeg : bool
Stereotactic EEG channels.
+ dipole : bool
+ Dipole time course channels.
+ gof : bool
+ Dipole goodness of fit channels.
bio : bool
Bio channels.
ecog : bool
Electrocorticography channels.
+ fnirs : bool | str
+ Functional near-infrared spectroscopy channels. If True include all
+ fNIRS channels. If False (default) include none. If a string, it can
+ be 'hbo' (to include channels measuring oxyhemoglobin) or 'hbr' (to
+ include channels measuring deoxyhemoglobin).
include : list of string
List of additional channels to include. If empty do not include
any.
@@ -439,10 +571,6 @@ class UpdateChannelsMixin(object):
in ``info['bads']``.
selection : list of string
Restrict sensor channels (MEG, EEG) to this list of channel names.
- copy : bool
- This parameter has been deprecated and will be removed in 0.13.
- Use inst.copy() instead.
- Whether to return a new instance or modify in place.
Returns
-------
@@ -453,26 +581,22 @@ class UpdateChannelsMixin(object):
-----
.. versionadded:: 0.9.0
"""
- inst = _check_copy_dep(self, copy)
idx = pick_types(
self.info, meg=meg, eeg=eeg, stim=stim, eog=eog, ecg=ecg, emg=emg,
ref_meg=ref_meg, misc=misc, resp=resp, chpi=chpi, exci=exci,
- ias=ias, syst=syst, seeg=seeg, bio=bio, ecog=ecog, include=include,
- exclude=exclude, selection=selection)
- inst._pick_drop_channels(idx)
- return inst
+ ias=ias, syst=syst, seeg=seeg, dipole=dipole, gof=gof, bio=bio,
+ ecog=ecog, fnirs=fnirs, include=include, exclude=exclude,
+ selection=selection)
+ self._pick_drop_channels(idx)
+ return self
- def pick_channels(self, ch_names, copy=None):
+ def pick_channels(self, ch_names):
"""Pick some channels
Parameters
----------
ch_names : list
The list of channels to select.
- copy : bool
- This parameter has been deprecated and will be removed in 0.13.
- Use inst.copy() instead.
- Whether to return a new instance or modify in place.
Returns
-------
@@ -487,23 +611,18 @@ class UpdateChannelsMixin(object):
-----
.. versionadded:: 0.9.0
"""
- inst = _check_copy_dep(self, copy)
_check_excludes_includes(ch_names)
- idx = [inst.ch_names.index(c) for c in ch_names if c in inst.ch_names]
- inst._pick_drop_channels(idx)
- return inst
+ idx = [self.ch_names.index(c) for c in ch_names if c in self.ch_names]
+ self._pick_drop_channels(idx)
+ return self
- def drop_channels(self, ch_names, copy=None):
+ def drop_channels(self, ch_names):
"""Drop some channels
Parameters
----------
ch_names : list
- The list of channels to remove.
- copy : bool
- This parameter has been deprecated and will be removed in 0.13.
- Use inst.copy() instead.
- Whether to return a new instance or modify in place.
+ List of the names of the channels to remove.
Returns
-------
@@ -518,12 +637,27 @@ class UpdateChannelsMixin(object):
-----
.. versionadded:: 0.9.0
"""
- inst = _check_copy_dep(self, copy)
- bad_idx = [inst.ch_names.index(c) for c in ch_names
- if c in inst.ch_names]
- idx = np.setdiff1d(np.arange(len(inst.ch_names)), bad_idx)
- inst._pick_drop_channels(idx)
- return inst
+ msg = ("'ch_names' should be a list of strings (the name[s] of the "
+ "channel to be dropped), not a {0}.")
+ if isinstance(ch_names, string_types):
+ raise ValueError(msg.format("string"))
+ else:
+ if not all([isinstance(ch_name, string_types)
+ for ch_name in ch_names]):
+ raise ValueError(msg.format(type(ch_names[0])))
+
+ missing = [ch_name for ch_name in ch_names
+ if ch_name not in self.ch_names]
+ if len(missing) > 0:
+ msg = "Channel(s) {0} not found, nothing dropped."
+ raise ValueError(msg.format(", ".join(missing)))
+
+ bad_idx = [self.ch_names.index(ch_name) for ch_name in ch_names
+ if ch_name in self.ch_names]
+ idx = np.setdiff1d(np.arange(len(self.ch_names)), bad_idx)
+ self._pick_drop_channels(idx)
+
+ return self
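With the deprecated copy parameter gone, these methods operate strictly in place; a sketch ('raw' is assumed):

    raw.copy().pick_types(meg=False, fnirs='hbo')  # keep the original intact
    raw.drop_channels(['MEG 0111'])  # must be a list; a bare str now raises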
def _pick_drop_channels(self, idx):
# avoid circular imports
@@ -560,7 +694,7 @@ class UpdateChannelsMixin(object):
elif isinstance(self, Evoked):
self.data = self.data.take(idx, axis=0)
- def add_channels(self, add_list, copy=None, force_update_info=False):
+ def add_channels(self, add_list, force_update_info=False):
"""Append new channels to the instance.
Parameters
@@ -568,10 +702,6 @@ class UpdateChannelsMixin(object):
add_list : list
A list of objects to append to self. Must contain all the same
type as the current object
- copy : bool
- This parameter has been deprecated and will be removed in 0.13.
- Use inst.copy() instead.
- Whether to return a new instance or modify in place.
force_update_info : bool
If True, force the info for objects to be appended to match the
values in `self`. This should generally only be used when adding
@@ -584,7 +714,6 @@ class UpdateChannelsMixin(object):
inst : instance of Raw, Epochs, or Evoked
The modified instance.
"""
- out = _check_copy_dep(self, copy)
# avoid circular imports
from ..io import _BaseRaw, _merge_info
from ..epochs import _BaseEpochs
@@ -623,12 +752,12 @@ class UpdateChannelsMixin(object):
new_info = _merge_info(infos, force_update_to_first=force_update_info)
# Now update the attributes
- setattr(out, data_name, data)
- out.info = new_info
+ setattr(self, data_name, data)
+ self.info = new_info
if isinstance(self, _BaseRaw):
- out._cals = np.concatenate([getattr(inst, '_cals')
- for inst in [self] + add_list])
- return out
+ self._cals = np.concatenate([getattr(inst, '_cals')
+ for inst in [self] + add_list])
+ return self
class InterpolationMixin(object):
diff --git a/mne/channels/data/layouts/KIT-UMD-3.lout b/mne/channels/data/layouts/KIT-UMD-3.lout
new file mode 100644
index 0000000..72cd69f
--- /dev/null
+++ b/mne/channels/data/layouts/KIT-UMD-3.lout
@@ -0,0 +1,158 @@
+ -25.00 28.00 -21.35 23.75
+000 -23.42 20.48 3.20 2.40 MEG 001
+001 -22.32 15.16 3.20 2.40 MEG 002
+002 -24.20 10.24 3.20 2.40 MEG 003
+003 -25.00 5.27 3.20 2.40 MEG 004
+004 -24.75 -0.21 3.20 2.40 MEG 005
+005 -23.41 -5.22 3.20 2.40 MEG 006
+006 -22.35 -11.37 3.20 2.40 MEG 007
+007 -14.06 -15.64 3.20 2.40 MEG 008
+008 -15.12 -18.15 3.20 2.40 MEG 009
+009 -11.26 -20.73 3.20 2.40 MEG 010
+010 -6.28 -20.94 3.20 2.40 MEG 011
+011 -2.04 -21.35 3.20 2.40 MEG 012
+012 2.04 -21.35 3.20 2.40 MEG 013
+013 6.28 -20.94 3.20 2.40 MEG 014
+014 11.26 -20.73 3.20 2.40 MEG 015
+015 15.12 -18.15 3.20 2.40 MEG 016
+016 19.41 -14.06 3.20 2.40 MEG 017
+017 22.35 -11.37 3.20 2.40 MEG 018
+018 24.06 -3.70 3.20 2.40 MEG 019
+019 24.23 1.80 3.20 2.40 MEG 020
+020 24.80 5.19 3.20 2.40 MEG 021
+021 22.03 13.42 3.20 2.40 MEG 022
+022 21.58 16.68 3.20 2.40 MEG 023
+023 23.42 20.48 3.20 2.40 MEG 024
+024 20.15 19.33 3.20 2.40 MEG 025
+025 7.46 -2.58 3.20 2.40 MEG 026
+026 22.86 7.70 3.20 2.40 MEG 027
+027 20.76 2.91 3.20 2.40 MEG 028
+028 19.70 -8.80 3.20 2.40 MEG 029
+029 3.41 -5.91 3.20 2.40 MEG 030
+030 14.06 -15.64 3.20 2.40 MEG 031
+031 0.12 -5.34 3.20 2.40 MEG 032
+032 1.80 -18.87 3.20 2.40 MEG 033
+033 -1.80 -18.87 3.20 2.40 MEG 034
+034 -10.12 -18.16 3.20 2.40 MEG 035
+035 -3.41 -5.91 3.20 2.40 MEG 036
+036 -18.35 -13.97 3.20 2.40 MEG 037
+037 -19.70 -8.80 3.20 2.40 MEG 038
+038 -20.76 2.91 3.20 2.40 MEG 039
+039 -22.86 7.70 3.20 2.40 MEG 040
+040 -7.46 -2.58 3.20 2.40 MEG 041
+041 -20.15 19.33 3.20 2.40 MEG 042
+042 -16.84 18.53 3.20 2.40 MEG 043
+043 -18.55 14.46 3.20 2.40 MEG 044
+044 -20.31 10.64 3.20 2.40 MEG 045
+045 -10.05 0.17 3.20 2.40 MEG 046
+046 -20.62 -2.66 3.20 2.40 MEG 047
+047 -17.20 -6.26 3.20 2.40 MEG 048
+048 -16.21 -11.50 3.20 2.40 MEG 049
+049 -8.92 -15.60 3.20 2.40 MEG 050
+050 -5.79 -18.42 3.20 2.40 MEG 051
+051 -1.62 -16.14 3.20 2.40 MEG 052
+052 -8.25 6.10 3.20 2.40 MEG 053
+053 5.79 -18.42 3.20 2.40 MEG 054
+054 8.92 -15.60 3.20 2.40 MEG 055
+055 16.21 -11.50 3.20 2.40 MEG 056
+056 17.20 -6.26 3.20 2.40 MEG 057
+057 20.62 -2.66 3.20 2.40 MEG 058
+058 -6.11 13.61 3.20 2.40 MEG 059
+059 20.31 10.64 3.20 2.40 MEG 060
+060 17.58 15.92 3.20 2.40 MEG 061
+061 16.84 18.53 3.20 2.40 MEG 062
+062 13.49 18.47 3.20 2.40 MEG 063
+063 15.28 13.32 3.20 2.40 MEG 064
+064 -4.11 11.13 3.20 2.40 MEG 065
+065 19.39 7.54 3.20 2.40 MEG 066
+066 17.50 3.47 3.20 2.40 MEG 067
+067 -6.54 8.57 3.20 2.40 MEG 068
+068 11.44 -8.04 3.20 2.40 MEG 069
+069 12.41 -13.14 3.20 2.40 MEG 070
+070 8.16 -13.13 3.20 2.40 MEG 071
+071 -7.60 2.77 3.20 2.40 MEG 072
+072 1.62 -16.14 3.20 2.40 MEG 073
+073 -6.80 0.14 3.20 2.40 MEG 074
+074 -5.40 -15.93 3.20 2.40 MEG 075
+075 -8.16 -13.13 3.20 2.40 MEG 076
+076 -12.41 -13.14 3.20 2.40 MEG 077
+077 -14.81 -8.97 3.20 2.40 MEG 078
+078 -3.23 -2.94 3.20 2.40 MEG 079
+079 -17.50 3.47 3.20 2.40 MEG 080
+080 -19.39 7.54 3.20 2.40 MEG 081
+081 4.03 -2.84 3.20 2.40 MEG 082
+082 -15.28 13.32 3.20 2.40 MEG 083
+083 -13.49 18.47 3.20 2.40 MEG 084
+084 -12.29 15.99 3.20 2.40 MEG 085
+085 -16.74 10.63 3.20 2.40 MEG 086
+086 6.80 0.14 3.20 2.40 MEG 087
+087 -17.30 -2.88 3.20 2.40 MEG 088
+088 -13.99 -4.86 3.20 2.40 MEG 089
+089 11.58 6.13 3.20 2.40 MEG 090
+090 -11.44 -8.04 3.20 2.40 MEG 091
+091 -3.30 -13.45 3.20 2.40 MEG 092
+092 6.54 8.57 3.20 2.40 MEG 093
+093 -9.52 -10.67 3.20 2.40 MEG 094
+094 9.52 -10.67 3.20 2.40 MEG 095
+095 4.11 11.13 3.20 2.40 MEG 096
+096 13.99 -4.86 3.20 2.40 MEG 097
+097 18.10 -0.17 3.20 2.40 MEG 098
+098 0.74 11.38 3.20 2.40 MEG 099
+099 16.74 10.63 3.20 2.40 MEG 100
+100 12.29 15.99 3.20 2.40 MEG 101
+101 10.11 18.86 3.20 2.40 MEG 102
+102 6.83 19.80 3.20 2.40 MEG 103
+103 3.48 21.35 3.20 2.40 MEG 104
+104 0.00 21.35 3.20 2.40 MEG 105
+105 -3.48 21.35 3.20 2.40 MEG 106
+106 -6.83 19.80 3.20 2.40 MEG 107
+107 -10.11 18.86 3.20 2.40 MEG 108
+108 -12.03 13.52 3.20 2.40 MEG 109
+109 -1.63 8.64 3.20 2.40 MEG 110
+110 -3.36 18.88 3.20 2.40 MEG 111
+111 -0.02 18.88 3.20 2.40 MEG 112
+112 3.36 18.88 3.20 2.40 MEG 113
+113 1.63 8.64 3.20 2.40 MEG 114
+114 9.01 16.34 3.20 2.40 MEG 115
+115 4.97 5.29 3.20 2.40 MEG 116
+116 13.28 10.76 3.20 2.40 MEG 117
+117 15.78 7.58 3.20 2.40 MEG 118
+118 14.24 3.60 3.20 2.40 MEG 119
+119 14.69 -0.31 3.20 2.40 MEG 120
+120 3.37 -0.21 3.20 2.40 MEG 121
+121 8.20 -8.14 3.20 2.40 MEG 122
+122 6.11 -10.67 3.20 2.40 MEG 123
+123 2.77 -10.98 3.20 2.40 MEG 124
+124 0.10 -13.43 3.20 2.40 MEG 125
+125 0.02 -0.57 3.20 2.40 MEG 126
+126 -2.77 -10.98 3.20 2.40 MEG 127
+127 -8.20 -8.14 3.20 2.40 MEG 128
+128 -3.37 -0.21 3.20 2.40 MEG 129
+129 -14.69 -0.31 3.20 2.40 MEG 130
+130 -14.24 3.60 3.20 2.40 MEG 131
+131 -15.78 7.58 3.20 2.40 MEG 132
+132 -13.28 10.76 3.20 2.40 MEG 133
+133 -4.97 5.29 3.20 2.40 MEG 134
+134 -9.46 11.02 3.20 2.40 MEG 135
+135 -12.21 7.84 3.20 2.40 MEG 136
+136 -10.93 3.58 3.20 2.40 MEG 137
+137 -10.71 -3.82 3.20 2.40 MEG 138
+138 -6.89 -5.51 3.20 2.40 MEG 139
+139 -1.66 5.24 3.20 2.40 MEG 140
+140 -2.40 -8.39 3.20 2.40 MEG 141
+141 2.40 -8.39 3.20 2.40 MEG 142
+142 -4.29 2.66 3.20 2.40 MEG 143
+143 6.89 -5.51 3.20 2.40 MEG 144
+144 10.71 -3.82 3.20 2.40 MEG 145
+145 10.93 3.58 3.20 2.40 MEG 146
+146 4.29 2.66 3.20 2.40 MEG 147
+147 9.46 11.02 3.20 2.40 MEG 148
+148 5.70 16.39 3.20 2.40 MEG 149
+149 1.66 5.24 3.20 2.40 MEG 150
+150 -2.37 16.38 3.20 2.40 MEG 151
+151 -5.70 16.39 3.20 2.40 MEG 152
+152 8.25 6.10 3.20 2.40 MEG 153
+153 -0.58 13.96 3.20 2.40 MEG 154
+154 2.81 13.89 3.20 2.40 MEG 155
+155 6.11 13.61 3.20 2.40 MEG 156
+156 2.37 16.38 3.20 2.40 MEG 157
diff --git a/mne/channels/data/neighbors/KIT-UMD-1_neighb.mat b/mne/channels/data/neighbors/KIT-UMD-1_neighb.mat
new file mode 100644
index 0000000..f860666
Binary files /dev/null and b/mne/channels/data/neighbors/KIT-UMD-1_neighb.mat differ
diff --git a/mne/channels/data/neighbors/KIT-UMD-2_neighb.mat b/mne/channels/data/neighbors/KIT-UMD-2_neighb.mat
new file mode 100644
index 0000000..19ad03c
Binary files /dev/null and b/mne/channels/data/neighbors/KIT-UMD-2_neighb.mat differ
diff --git a/mne/channels/data/neighbors/KIT-UMD-3_neighb.mat b/mne/channels/data/neighbors/KIT-UMD-3_neighb.mat
new file mode 100644
index 0000000..c7ded3d
Binary files /dev/null and b/mne/channels/data/neighbors/KIT-UMD-3_neighb.mat differ
diff --git a/mne/channels/data/neighbors/__init__.py b/mne/channels/data/neighbors/__init__.py
index 8fc6ea7..c741b62 100644
--- a/mne/channels/data/neighbors/__init__.py
+++ b/mne/channels/data/neighbors/__init__.py
@@ -4,3 +4,6 @@
# For additional information on how these definitions were computed, please
# consider the related fieldtrip documentation:
# http://fieldtrip.fcdonders.nl/template/neighbours.
+#
+# KIT neighbor files were computed with ft_prepare_neighbours using the
+# triangulation method.
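The new KIT-UMD templates load like any other neighbor definition; a minimal sketch:

    from mne.channels import read_ch_connectivity

    connectivity, ch_names = read_ch_connectivity('KIT-UMD-3')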
diff --git a/mne/channels/interpolation.py b/mne/channels/interpolation.py
index bffe644..988f6cd 100644
--- a/mne/channels/interpolation.py
+++ b/mne/channels/interpolation.py
@@ -172,14 +172,15 @@ def _interpolate_bads_meg(inst, mode='accurate', verbose=None):
If not None, override default verbose level (see mne.verbose).
"""
picks_meg = pick_types(inst.info, meg=True, eeg=False, exclude=[])
- ch_names = [inst.info['ch_names'][p] for p in picks_meg]
picks_good = pick_types(inst.info, meg=True, eeg=False, exclude='bads')
+ meg_ch_names = [inst.info['ch_names'][p] for p in picks_meg]
+ bads_meg = [ch for ch in inst.info['bads'] if ch in meg_ch_names]
# select the bad meg channel to be interpolated
- if len(inst.info['bads']) == 0:
+ if len(bads_meg) == 0:
picks_bad = []
else:
- picks_bad = pick_channels(ch_names, inst.info['bads'],
+ picks_bad = pick_channels(inst.info['ch_names'], bads_meg,
exclude=[])
# return without doing anything if there are no meg channels
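After this fix, bad non-MEG channels no longer confuse MEG interpolation; a sketch ('raw' is an assumed preloaded Raw with a bad stim channel):

    raw.info['bads'] = ['MEG 0141', 'STI 014']
    raw.interpolate_bads(reset_bads=True)  # the stim bad is simply ignored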
diff --git a/mne/channels/layout.py b/mne/channels/layout.py
index 911add1..4d19bf9 100644
--- a/mne/channels/layout.py
+++ b/mne/channels/layout.py
@@ -20,7 +20,7 @@ from ..bem import fit_sphere_to_headshape
from ..io.pick import pick_types
from ..io.constants import FIFF
from ..io.meas_info import Info
-from ..utils import _clean_names
+from ..utils import _clean_names, warn
from ..externals.six.moves import map
@@ -101,9 +101,7 @@ class Layout(object):
Notes
-----
-
.. versionadded:: 0.12.0
-
"""
from ..viz.topomap import plot_layout
return plot_layout(self, show=show)
@@ -387,7 +385,8 @@ def find_layout(info, ch_type=None, exclude='bads'):
'`ch_type` must be %s' % (ch_type, our_types))
chs = info['chs']
- coil_types = set([ch['coil_type'] for ch in chs])
+ # Only take first 16 bits, as higher bits store CTF comp order
+ coil_types = set([ch['coil_type'] & 0xFFFF for ch in chs])
channel_types = set([ch['kind'] for ch in chs])
has_vv_mag = any(k in coil_types for k in
@@ -409,7 +408,8 @@ def find_layout(info, ch_type=None, exclude='bads'):
(FIFF.FIFFV_MEG_CH in channel_types and
any(k in ctf_other_types for k in coil_types)))
# hack due to MNE-C bug in IO of CTF
- n_kit_grads = sum(ch['coil_type'] == FIFF.FIFFV_COIL_KIT_GRAD
+ # only take first 16 bits, as higher bits store CTF comp order
+ n_kit_grads = sum(ch['coil_type'] & 0xFFFF == FIFF.FIFFV_COIL_KIT_GRAD
for ch in chs)
has_any_meg = any([has_vv_mag, has_vv_grad, has_4D_mag, has_CTF_grad,
@@ -445,10 +445,8 @@ def find_layout(info, ch_type=None, exclude='bads'):
layout_name = 'magnesWH3600'
elif has_CTF_grad:
layout_name = 'CTF-275'
- elif n_kit_grads <= 157:
- layout_name = 'KIT-157'
- elif n_kit_grads > 157:
- layout_name = 'KIT-AD'
+ elif n_kit_grads > 0:
+ layout_name = _find_kit_layout(info, n_kit_grads)
else:
return None
@@ -461,6 +459,56 @@ def find_layout(info, ch_type=None, exclude='bads'):
return layout
+def _find_kit_layout(info, n_grads):
+ """Determine the KIT layout
+
+ Parameters
+ ----------
+ info : Info
+ Info object.
+ n_grads : int
+ Number of KIT-gradiometers in the info.
+
+ Returns
+ -------
+ kit_layout : str
+ One of 'KIT-AD', 'KIT-157' or 'KIT-UMD'.
+ """
+ if info['kit_system_id'] is not None:
+ # avoid circular import
+ from ..io.kit.constants import KIT_LAYOUT
+
+ if info['kit_system_id'] in KIT_LAYOUT:
+ kit_layout = KIT_LAYOUT[info['kit_system_id']]
+ if kit_layout is not None:
+ return kit_layout
+ raise NotImplementedError("The layout for the KIT system with ID %i "
+ "is missing. Please contact the developers "
+ "about adding it." % info['kit_system_id'])
+ elif n_grads > 157:
+ return 'KIT-AD'
+
+ # channels which are on the left hemisphere for NY and right for UMD
+ test_chs = ('MEG 13', 'MEG 14', 'MEG 15', 'MEG 16', 'MEG 25',
+ 'MEG 26', 'MEG 27', 'MEG 28', 'MEG 29', 'MEG 30',
+ 'MEG 31', 'MEG 32', 'MEG 57', 'MEG 60', 'MEG 61',
+ 'MEG 62', 'MEG 63', 'MEG 64', 'MEG 73', 'MEG 90',
+ 'MEG 93', 'MEG 95', 'MEG 96', 'MEG 105', 'MEG 112',
+ 'MEG 120', 'MEG 121', 'MEG 122', 'MEG 123', 'MEG 124',
+ 'MEG 125', 'MEG 126', 'MEG 142', 'MEG 144', 'MEG 153',
+ 'MEG 154', 'MEG 155', 'MEG 156')
+ x = [ch['loc'][0] < 0 for ch in info['chs'] if ch['ch_name'] in test_chs]
+ if np.all(x):
+ return 'KIT-157' # KIT-NY
+ elif np.all(np.invert(x)):
+ raise NotImplementedError("Guessing sensor layout for legacy UMD "
+ "files is not implemented. Please convert "
+ "your files using MNE-Python 0.13 or "
+ "higher.")
+ else:
+ raise RuntimeError("KIT system could not be determined for data")
+
+
def _box_size(points, width=None, height=None, padding=0.0):
""" Given a series of points, calculate an appropriate box size.
@@ -705,7 +753,8 @@ def _topo_to_sphere(pos, eegs):
return np.column_stack([xs, ys, zs])
-def _pair_grad_sensors(info, layout=None, topomap_coords=True, exclude='bads'):
+def _pair_grad_sensors(info, layout=None, topomap_coords=True, exclude='bads',
+ raise_error=True):
"""Find the picks for pairing grad channels
Parameters
@@ -720,6 +769,9 @@ def _pair_grad_sensors(info, layout=None, topomap_coords=True, exclude='bads'):
exclude : list of str | str
List of channels to exclude. If empty do not exclude any (default).
If 'bads', exclude channels in info['bads']. Defaults to 'bads'.
+ raise_error : bool
+ Whether to raise an error when no pairs are found. If False, raises a
+ warning.
Returns
-------
@@ -741,7 +793,11 @@ def _pair_grad_sensors(info, layout=None, topomap_coords=True, exclude='bads'):
pairs[key].append(ch)
pairs = [p for p in pairs.values() if len(p) == 2]
if len(pairs) == 0:
- raise ValueError("No 'grad' channel pairs found.")
+ if raise_error:
+ raise ValueError("No 'grad' channel pairs found.")
+ else:
+ warn("No 'grad' channel pairs found.")
+ return list()
# find the picks corresponding to the grad channels
grad_chs = sum(pairs, [])
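A sketch of the revised layout resolution, which masks coil types to 16 bits and resolves KIT systems via info['kit_system_id'] (the filename is an illustrative assumption):

    from mne.io import read_info
    from mne.channels import find_layout

    layout = find_layout(read_info('kit_raw.fif'))
    print(layout.kind)  # e.g. 'KIT-157', 'KIT-AD' or 'KIT-UMD-3'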
diff --git a/mne/channels/montage.py b/mne/channels/montage.py
index 849b714..8b86d4a 100644
--- a/mne/channels/montage.py
+++ b/mne/channels/montage.py
@@ -9,6 +9,7 @@
#
# License: Simplified BSD
+from collections import Iterable
import os
import os.path as op
@@ -131,9 +132,10 @@ def read_montage(kind, ch_names=None, path=None, unit='m', transform=False):
Parameters
----------
kind : str
- The name of the montage file (e.g. kind='easycap-M10' for
- 'easycap-M10.txt'). Files with extensions '.elc', '.txt', '.csd',
- '.elp', '.hpts', '.sfp' or '.loc' ('.locs' and '.eloc') are supported.
+ The name of the montage file without the file extension (e.g.
+ kind='easycap-M10' for 'easycap-M10.txt'). Files with extensions
+ '.elc', '.txt', '.csd', '.elp', '.hpts', '.sfp' or '.loc' ('.locs' and
+ '.eloc') are supported.
ch_names : list of str | None
If not all electrodes defined in the montage are present in the EEG
data, use this parameter to select subset of electrode positions to
@@ -187,15 +189,32 @@ def read_montage(kind, ch_names=None, path=None, unit='m', transform=False):
if ext == '.sfp':
# EGI geodesic
- dtype = np.dtype('S4, f8, f8, f8')
- data = np.loadtxt(fname, dtype=dtype)
- pos = np.c_[data['f1'], data['f2'], data['f3']]
- ch_names_ = data['f0'].astype(np.str)
+ with open(fname, 'r') as f:
+ lines = f.read().replace('\t', ' ').splitlines()
+
+ ch_names_, pos = [], []
+ for ii, line in enumerate(lines):
+ line = line.strip().split()
+ if len(line) > 0: # skip empty lines
+ if len(line) != 4: # name, x, y, z
+ raise ValueError("Malformed .sfp file in line " + str(ii))
+ this_name, x, y, z = line
+ ch_names_.append(this_name)
+ pos.append([float(cord) for cord in (x, y, z)])
+ pos = np.asarray(pos)
elif ext == '.elc':
# 10-5 system
ch_names_ = []
pos = []
with open(fname) as fid:
+ # Default units are meters
+ for line in fid:
+ if 'UnitPosition' in line:
+ units = line.split()[1]
+ scale_factor = dict(m=1., mm=1e-3)[units]
+ break
+ else:
+ raise RuntimeError('Could not detect units in file %s' % fname)
for line in fid:
if 'Positions\n' in line:
break
@@ -208,7 +227,7 @@ def read_montage(kind, ch_names=None, path=None, unit='m', transform=False):
if not line or not set(line) - set([' ']):
break
ch_names_.append(line.strip(' ').strip('\n'))
- pos = np.array(pos)
+ pos = np.array(pos) * scale_factor
elif ext == '.txt':
# easycap
try: # newer version
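A hedged usage sketch for read_montage with the clarified kind semantics:

    from mne.channels import read_montage

    montage = read_montage('easycap-M10')  # template name, no file extension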
@@ -421,7 +440,7 @@ def _check_frame(d, frame_str):
def read_dig_montage(hsp=None, hpi=None, elp=None, point_names=None,
- unit='mm', fif=None, transform=True, dev_head_t=False):
+ unit='auto', fif=None, transform=True, dev_head_t=False):
"""Read subject-specific digitization montage from a file
Parameters
@@ -430,8 +449,8 @@ def read_dig_montage(hsp=None, hpi=None, elp=None, point_names=None,
If str, this corresponds to the filename of the headshape points.
This is typically used with the Polhemus FastSCAN system.
If numpy.array, this corresponds to an array of positions of the
- headshape points in 3d. These points are in the native
- digitizer space.
+ headshape points in 3d. These points are assumed to be in the native
+ digitizer space and will be rescaled according to the unit parameter.
hpi : None | str | array, shape (n_hpi, 3)
If str, this corresponds to the filename of Head Position Indicator
(HPI) points. If numpy.array, this corresponds to an array
@@ -441,14 +460,17 @@ def read_dig_montage(hsp=None, hpi=None, elp=None, point_names=None,
points. This is typically used with the Polhemus FastSCAN system.
Fiducials should be listed first: nasion, left periauricular point,
right periauricular point, then the points corresponding to the HPI.
- These points are in the native digitizer space.
- If numpy.array, this corresponds to an array of fids + HPI points.
+ If numpy.array, this corresponds to an array of digitizer points in
+ the same order. These points are assumed to be in the native digitizer
+ space and will be rescaled according to the unit parameter.
point_names : None | list
If list, this corresponds to a list of point names. This must be
specified if elp is defined.
- unit : 'm' | 'cm' | 'mm'
- Unit of the input file. If not 'm', coordinates will be rescaled
- to 'm'. Default is 'mm'. This is applied only for hsp and elp files.
+ unit : 'auto' | 'm' | 'cm' | 'mm'
+ Unit of the digitizer files (hsp and elp). If not 'm', coordinates will
+ be rescaled to 'm'. Default is 'auto', which assumes 'm' for \*.hsp and
+ \*.elp files and 'mm' for \*.txt files, corresponding to the known
+ Polhemus export formats.
fif : str | None
FIF file from which to read digitization locations.
If str (filename), all other arguments are ignored.
@@ -483,11 +505,6 @@ def read_dig_montage(hsp=None, hpi=None, elp=None, point_names=None,
.. versionadded:: 0.9.0
"""
- if not isinstance(unit, string_types) or unit not in('m', 'mm', 'cm'):
- raise ValueError('unit must be "m", "mm", or "cm"')
- scale = dict(m=1., mm=1e-3, cm=1e-2)[unit]
- dig_ch_pos = None
- fids = None
if fif is not None:
# Use a different code path
if dev_head_t or not transform:
@@ -525,63 +542,84 @@ def read_dig_montage(hsp=None, hpi=None, elp=None, point_names=None,
dig_ch_pos['EEG%03d' % d['ident']] = d['r']
fids = np.array([fids[key] for key in ('nasion', 'lpa', 'rpa')])
hsp = np.array(hsp)
- hsp /= scale # will be multiplied later
elp = np.array(elp)
- elp /= scale # will be multiplied later
- transform = False
- if isinstance(hsp, string_types):
- hsp = _read_dig_points(hsp)
- if hsp is not None:
- hsp = hsp * scale
- if isinstance(hpi, string_types):
- ext = op.splitext(hpi)[-1]
- if ext == '.txt':
- hpi = _read_dig_points(hpi)
- elif ext in ('.sqd', '.mrk'):
- from ..io.kit import read_mrk
- hpi = read_mrk(hpi)
+ else:
+ dig_ch_pos = None
+ scale = {'mm': 1e-3, 'cm': 1e-2, 'auto': 1e-3, 'm': None}
+ if unit not in scale:
+ raise ValueError("Unit needs to be one of %s, not %r" %
+ (tuple(map(repr, scale)), unit))
+
+ # HSP
+ if isinstance(hsp, string_types):
+ hsp = _read_dig_points(hsp, unit=unit)
+ elif hsp is not None and scale[unit]:
+ hsp *= scale[unit]
+
+ # HPI
+ if isinstance(hpi, string_types):
+ ext = op.splitext(hpi)[-1]
+ if ext == '.txt':
+ hpi = _read_dig_points(hpi, unit='m')
+ elif ext in ('.sqd', '.mrk'):
+ from ..io.kit import read_mrk
+ hpi = read_mrk(hpi)
+ else:
+ raise ValueError('HPI file with extension *%s is not '
+ 'supported. Only *.txt, *.sqd and *.mrk are '
+ 'supported.' % ext)
+
+ # ELP
+ if isinstance(elp, string_types):
+ elp = _read_dig_points(elp, unit=unit)
+ elif elp is not None and scale[unit]:
+ elp *= scale[unit]
+
+ if elp is not None:
+ if not isinstance(point_names, Iterable):
+ raise TypeError("If elp is specified, point_names must "
+ "provide a list of str with one entry per ELP "
+ "point")
+ point_names = list(point_names)
+ if len(point_names) != len(elp):
+ raise ValueError("The elp file contains %i points, but %i "
+ "names were specified." %
+ (len(elp), len(point_names)))
+
+ # Transform digitizer coordinates to neuromag space
+ if transform:
+ if elp is None:
+ raise ValueError("ELP points are not specified. Points are "
+ "needed for transformation.")
+ names_lower = [name.lower() for name in point_names]
+
+ # check that all needed points are present
+ missing = [name for name in ('nasion', 'lpa', 'rpa')
+ if name not in names_lower]
+ if missing:
+ raise ValueError("The points %s are missing, but are needed "
+ "to transform the points to the MNE "
+ "coordinate system. Either add the points, "
+ "or read the montage with transform=False."
+ % str(missing))
+
+ nasion = elp[names_lower.index('nasion')]
+ lpa = elp[names_lower.index('lpa')]
+ rpa = elp[names_lower.index('rpa')]
+
+ # remove fiducials from elp
+ mask = np.ones(len(names_lower), dtype=bool)
+ for fid in ['nasion', 'lpa', 'rpa']:
+ mask[names_lower.index(fid)] = False
+ elp = elp[mask]
+
+ neuromag_trans = get_ras_to_neuromag_trans(nasion, lpa, rpa)
+ fids = apply_trans(neuromag_trans, [nasion, lpa, rpa])
+ elp = apply_trans(neuromag_trans, elp)
+ hsp = apply_trans(neuromag_trans, hsp)
else:
- raise TypeError('HPI file is not supported.')
- if isinstance(elp, string_types):
- elp = _read_dig_points(elp)
- if elp is not None:
- if len(elp) != len(point_names):
- raise ValueError("The elp file contains %i points, but %i names "
- "were specified." % (len(elp), len(point_names)))
- elp = elp * scale
- if transform:
- if elp is None:
- raise ValueError("ELP points are not specified. Points are needed "
- "for transformation.")
- names_lower = [name.lower() for name in point_names]
-
- # check that all needed points are present
- missing = tuple(name for name in ('nasion', 'lpa', 'rpa')
- if name not in names_lower)
- if missing:
- raise ValueError("The points %s are missing, but are needed "
- "to transform the points to the MNE coordinate "
- "system. Either add the points, or read the "
- "montage with transform=False." % str(missing))
-
- nasion = elp[names_lower.index('nasion')]
- lpa = elp[names_lower.index('lpa')]
- rpa = elp[names_lower.index('rpa')]
-
- # remove fiducials from elp
- mask = np.ones(len(names_lower), dtype=bool)
- for fid in ['nasion', 'lpa', 'rpa']:
- mask[names_lower.index(fid)] = False
- elp = elp[mask]
-
- neuromag_trans = get_ras_to_neuromag_trans(nasion, lpa, rpa)
+ fids = [None] * 3
- fids = np.array([nasion, lpa, rpa])
- fids = apply_trans(neuromag_trans, fids)
- elp = apply_trans(neuromag_trans, elp)
- hsp = apply_trans(neuromag_trans, hsp)
- elif fids is None:
- fids = [None] * 3
if dev_head_t:
from ..coreg import fit_matched_points
trans = fit_matched_points(tgt_pts=elp, src_pts=hpi, out='trans')
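A sketch of read_dig_montage with the new unit='auto' default, under which *.hsp/*.elp files are taken as meters and *.txt files as millimeters (filenames and point names are illustrative assumptions):

    from mne.channels import read_dig_montage

    names = ['nasion', 'lpa', 'rpa', 'HPI1', 'HPI2', 'HPI3', 'HPI4']
    montage = read_dig_montage(hsp='sub1.hsp', hpi='sub1-markers.mrk',
                               elp='sub1.elp', point_names=names)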
diff --git a/mne/channels/tests/test_channels.py b/mne/channels/tests/test_channels.py
index d7a28f9..0cda23d 100644
--- a/mne/channels/tests/test_channels.py
+++ b/mne/channels/tests/test_channels.py
@@ -6,17 +6,18 @@
import os.path as op
from copy import deepcopy
+from functools import partial
import warnings
import numpy as np
+from scipy.io import savemat
from numpy.testing import assert_array_equal
from nose.tools import assert_raises, assert_true, assert_equal
from mne.channels import rename_channels, read_ch_connectivity
from mne.channels.channels import _ch_neighbor_connectivity
-from mne.io import read_info, Raw
+from mne.io import read_info, read_raw_fif
from mne.io.constants import FIFF
-from mne.fixes import partial, savemat
from mne.utils import _TempDir, run_tests_if_main
from mne import pick_types, pick_channels
@@ -27,8 +28,7 @@ warnings.simplefilter('always')
def test_rename_channels():
- """Test rename channels
- """
+ """Test rename channels"""
info = read_info(raw_fname)
# Error Tests
# Test channel name exists in ch_names
@@ -64,9 +64,8 @@ def test_rename_channels():
def test_set_channel_types():
- """Test set_channel_types
- """
- raw = Raw(raw_fname)
+ """Test set_channel_types"""
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
# Error Tests
# Test channel name exists in ch_names
mapping = {'EEG 160': 'EEG060'}
@@ -76,12 +75,14 @@ def test_set_channel_types():
assert_raises(ValueError, raw.set_channel_types, mapping)
# Test changing type if in proj (avg eeg ref here)
mapping = {'EEG 058': 'ecog', 'EEG 059': 'ecg', 'EEG 060': 'eog',
- 'EOG 061': 'seeg', 'MEG 2441': 'eeg', 'MEG 2443': 'eeg'}
+ 'EOG 061': 'seeg', 'MEG 2441': 'eeg', 'MEG 2443': 'eeg',
+ 'MEG 2442': 'hbo'}
assert_raises(RuntimeError, raw.set_channel_types, mapping)
# Test type change
- raw2 = Raw(raw_fname, add_eeg_ref=False)
+ raw2 = read_raw_fif(raw_fname, add_eeg_ref=False)
raw2.info['bads'] = ['EEG 059', 'EEG 060', 'EOG 061']
- assert_raises(RuntimeError, raw2.set_channel_types, mapping) # has proj
+ with warnings.catch_warnings(record=True): # MEG channel change
+ assert_raises(RuntimeError, raw2.set_channel_types, mapping) # has prj
raw2.add_proj([], remove_existing=True)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
@@ -109,6 +110,10 @@ def test_set_channel_types():
assert_true(info['chs'][idx]['kind'] == FIFF.FIFFV_EEG_CH)
assert_true(info['chs'][idx]['unit'] == FIFF.FIFF_UNIT_V)
assert_true(info['chs'][idx]['coil_type'] == FIFF.FIFFV_COIL_EEG)
+ idx = pick_channels(raw.ch_names, ['MEG 2442'])[0]
+ assert_true(info['chs'][idx]['kind'] == FIFF.FIFFV_FNIRS_CH)
+ assert_true(info['chs'][idx]['unit'] == FIFF.FIFF_UNIT_MOL)
+ assert_true(info['chs'][idx]['coil_type'] == FIFF.FIFFV_COIL_FNIRS_HBO)
# Test meaningful error when setting channel type with unknown unit
raw.info['chs'][0]['unit'] = 0.
@@ -117,7 +122,7 @@ def test_set_channel_types():
def test_read_ch_connectivity():
- "Test reading channel connectivity templates"
+ """Test reading channel connectivity templates"""
tempdir = _TempDir()
a = partial(np.array, dtype='<U7')
# no pep8
@@ -158,9 +163,8 @@ def test_read_ch_connectivity():
def test_get_set_sensor_positions():
- """Test get/set functions for sensor positions
- """
- raw1 = Raw(raw_fname)
+ """Test get/set functions for sensor positions"""
+ raw1 = read_raw_fif(raw_fname, add_eeg_ref=False)
picks = pick_types(raw1.info, meg=False, eeg=True)
pos = np.array([ch['loc'][:3] for ch in raw1.info['chs']])[picks]
raw_pos = raw1._get_channel_positions(picks=picks)
@@ -168,7 +172,7 @@ def test_get_set_sensor_positions():
ch_name = raw1.info['ch_names'][13]
assert_raises(ValueError, raw1._set_channel_positions, [1, 2], ['name'])
- raw2 = Raw(raw_fname)
+ raw2 = read_raw_fif(raw_fname, add_eeg_ref=False)
raw2.info['chs'][13]['loc'][:3] = np.array([1, 2, 3])
raw1._set_channel_positions([[1, 2, 3]], [ch_name])
assert_array_equal(raw1.info['chs'][13]['loc'],
diff --git a/mne/channels/tests/test_interpolation.py b/mne/channels/tests/test_interpolation.py
index 6bae870..a326a80 100644
--- a/mne/channels/tests/test_interpolation.py
+++ b/mne/channels/tests/test_interpolation.py
@@ -12,7 +12,6 @@ from mne.utils import run_tests_if_main, slow_test
base_dir = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data')
raw_fname = op.join(base_dir, 'test_raw.fif')
event_name = op.join(base_dir, 'test-eve.fif')
-evoked_nf_name = op.join(base_dir, 'test-nf-ave.fif')
event_id, tmin, tmax = 1, -0.2, 0.5
event_id_2 = 2
@@ -93,6 +92,15 @@ def test_interpolation():
inst.info['bads'] = [inst.ch_names[1]]
assert_raises(ValueError, inst.interpolate_bads)
+ # check that interpolation works when non M/EEG channels are present
+ # before MEG channels
+ with warnings.catch_warnings(record=True): # change of units
+ raw.rename_channels({'MEG 0113': 'TRIGGER'})
+ raw.set_channel_types({'TRIGGER': 'stim'})
+ raw.info['bads'] = [raw.info['ch_names'][1]]
+ raw.load_data()
+ raw.interpolate_bads()
+
# check that interpolation works for MEG
epochs_meg.info['bads'] = ['MEG 0141']
evoked = epochs_meg.average()
diff --git a/mne/channels/tests/test_layout.py b/mne/channels/tests/test_layout.py
index 816e66e..7789b8f 100644
--- a/mne/channels/tests/test_layout.py
+++ b/mne/channels/tests/test_layout.py
@@ -15,14 +15,14 @@ import matplotlib
import numpy as np
from numpy.testing import (assert_array_almost_equal, assert_array_equal,
assert_allclose)
-from nose.tools import assert_true, assert_raises
+from nose.tools import assert_equal, assert_true, assert_raises
from mne.channels import (make_eeg_layout, make_grid_layout, read_layout,
find_layout)
from mne.channels.layout import (_box_size, _auto_topomap_coords,
generate_2d_layout)
from mne.utils import run_tests_if_main
from mne import pick_types, pick_info
-from mne.io import Raw, read_raw_kit, _empty_info
+from mne.io import read_raw_kit, _empty_info, read_info
from mne.io.constants import FIFF
from mne.bem import fit_sphere_to_headshape
from mne.utils import _TempDir
@@ -30,20 +30,13 @@ matplotlib.use('Agg') # for testing don't use X server
warnings.simplefilter('always')
-fif_fname = op.join(op.dirname(__file__), '..', '..', 'io',
- 'tests', 'data', 'test_raw.fif')
-
-lout_path = op.join(op.dirname(__file__), '..', '..', 'io',
- 'tests', 'data')
-
-bti_dir = op.join(op.dirname(__file__), '..', '..', 'io', 'bti',
- 'tests', 'data')
-
-fname_ctf_raw = op.join(op.dirname(__file__), '..', '..', 'io', 'tests',
- 'data', 'test_ctf_comp_raw.fif')
-
-fname_kit_157 = op.join(op.dirname(__file__), '..', '..', 'io', 'kit',
- 'tests', 'data', 'test.sqd')
+io_dir = op.join(op.dirname(__file__), '..', '..', 'io')
+fif_fname = op.join(io_dir, 'tests', 'data', 'test_raw.fif')
+lout_path = op.join(io_dir, 'tests', 'data')
+bti_dir = op.join(io_dir, 'bti', 'tests', 'data')
+fname_ctf_raw = op.join(io_dir, 'tests', 'data', 'test_ctf_comp_raw.fif')
+fname_kit_157 = op.join(io_dir, 'kit', 'tests', 'data', 'test.sqd')
+fname_kit_umd = op.join(io_dir, 'kit', 'tests', 'data', 'test_umd-raw.sqd')
def _get_test_info():
@@ -92,7 +85,7 @@ def test_io_layout_lay():
def test_auto_topomap_coords():
"""Test mapping of coordinates in 3D space to 2D"""
- info = Raw(fif_fname).info.copy()
+ info = read_info(fif_fname)
picks = pick_types(info, meg=False, eeg=True, eog=False, stim=False)
# Remove extra digitization point, so EEG digitization points match up
@@ -152,7 +145,7 @@ def test_make_eeg_layout():
tmp_name = 'foo'
lout_name = 'test_raw'
lout_orig = read_layout(kind=lout_name, path=lout_path)
- info = Raw(fif_fname).info
+ info = read_info(fif_fname)
info['bads'].append(info['ch_names'][360])
layout = make_eeg_layout(info, exclude=[])
assert_array_equal(len(layout.names), len([ch for ch in info['ch_names']
@@ -200,7 +193,7 @@ def test_find_layout():
import matplotlib.pyplot as plt
assert_raises(ValueError, find_layout, _get_test_info(), ch_type='meep')
- sample_info = Raw(fif_fname).info
+ sample_info = read_info(fif_fname)
grads = pick_types(sample_info, meg='grad')
sample_info2 = pick_info(sample_info, grads)
@@ -217,52 +210,62 @@ def test_find_layout():
sample_info5 = pick_info(sample_info, eegs)
lout = find_layout(sample_info, ch_type=None)
- assert_true(lout.kind == 'Vectorview-all')
+ assert_equal(lout.kind, 'Vectorview-all')
assert_true(all(' ' in k for k in lout.names))
lout = find_layout(sample_info2, ch_type='meg')
- assert_true(lout.kind == 'Vectorview-all')
+ assert_equal(lout.kind, 'Vectorview-all')
# test new vector-view
lout = find_layout(sample_info4, ch_type=None)
- assert_true(lout.kind == 'Vectorview-all')
+ assert_equal(lout.kind, 'Vectorview-all')
assert_true(all(' ' not in k for k in lout.names))
lout = find_layout(sample_info, ch_type='grad')
- assert_true(lout.kind == 'Vectorview-grad')
+ assert_equal(lout.kind, 'Vectorview-grad')
lout = find_layout(sample_info2)
- assert_true(lout.kind == 'Vectorview-grad')
+ assert_equal(lout.kind, 'Vectorview-grad')
lout = find_layout(sample_info2, ch_type='grad')
- assert_true(lout.kind == 'Vectorview-grad')
+ assert_equal(lout.kind, 'Vectorview-grad')
lout = find_layout(sample_info2, ch_type='meg')
- assert_true(lout.kind == 'Vectorview-all')
+ assert_equal(lout.kind, 'Vectorview-all')
lout = find_layout(sample_info, ch_type='mag')
- assert_true(lout.kind == 'Vectorview-mag')
+ assert_equal(lout.kind, 'Vectorview-mag')
lout = find_layout(sample_info3)
- assert_true(lout.kind == 'Vectorview-mag')
+ assert_equal(lout.kind, 'Vectorview-mag')
lout = find_layout(sample_info3, ch_type='mag')
- assert_true(lout.kind == 'Vectorview-mag')
+ assert_equal(lout.kind, 'Vectorview-mag')
lout = find_layout(sample_info3, ch_type='meg')
- assert_true(lout.kind == 'Vectorview-all')
+ assert_equal(lout.kind, 'Vectorview-all')
lout = find_layout(sample_info, ch_type='eeg')
- assert_true(lout.kind == 'EEG')
+ assert_equal(lout.kind, 'EEG')
lout = find_layout(sample_info5)
- assert_true(lout.kind == 'EEG')
+ assert_equal(lout.kind, 'EEG')
lout = find_layout(sample_info5, ch_type='eeg')
- assert_true(lout.kind == 'EEG')
+ assert_equal(lout.kind, 'EEG')
# no common layout, 'meg' option not supported
+ lout = find_layout(read_info(fname_ctf_raw))
+ assert_equal(lout.kind, 'CTF-275')
+
fname_bti_raw = op.join(bti_dir, 'exported4D_linux_raw.fif')
- lout = find_layout(Raw(fname_bti_raw).info)
- assert_true(lout.kind == 'magnesWH3600')
+ lout = find_layout(read_info(fname_bti_raw))
+ assert_equal(lout.kind, 'magnesWH3600')
+
+ raw_kit = read_raw_kit(fname_kit_157)
+ lout = find_layout(raw_kit.info)
+ assert_equal(lout.kind, 'KIT-157')
+
+ raw_kit.info['bads'] = ['MEG 13', 'MEG 14', 'MEG 15', 'MEG 16']
+ lout = find_layout(raw_kit.info)
+ assert_equal(lout.kind, 'KIT-157')
- lout = find_layout(Raw(fname_ctf_raw).info)
- assert_true(lout.kind == 'CTF-275')
+ raw_umd = read_raw_kit(fname_kit_umd)
+ lout = find_layout(raw_umd.info)
+ assert_equal(lout.kind, 'KIT-UMD-3')
- lout = find_layout(read_raw_kit(fname_kit_157).info)
- assert_true(lout.kind == 'KIT-157')
# Test plotting
lout.plot()
plt.close('all')
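
(Aside: a minimal sketch of the layout auto-detection exercised above, using the
read_info pattern the test switches to; the file name is a placeholder, not a
real path.)

    from mne.io import read_info
    from mne.channels import find_layout

    info = read_info('test_raw.fif')          # placeholder Vectorview recording
    layout = find_layout(info, ch_type=None)  # kind inferred from channel types
    print(layout.kind)                        # e.g. 'Vectorview-all'
    layout.plot()                             # same call the test exercises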
diff --git a/mne/channels/tests/test_montage.py b/mne/channels/tests/test_montage.py
index 138f07c..19ebd2d 100644
--- a/mne/channels/tests/test_montage.py
+++ b/mne/channels/tests/test_montage.py
@@ -5,11 +5,12 @@
import os.path as op
import warnings
-from nose.tools import assert_equal, assert_true
+from nose.tools import assert_equal, assert_true, assert_raises
import numpy as np
from numpy.testing import (assert_array_equal, assert_almost_equal,
- assert_allclose, assert_array_almost_equal)
+ assert_allclose, assert_array_almost_equal,
+ assert_array_less)
from mne.tests.common import assert_dig_allclose
from mne.channels.montage import read_montage, _set_montage, read_dig_montage
from mne.utils import _TempDir, run_tests_if_main
@@ -39,49 +40,39 @@ def test_montage():
"""Test making montages"""
tempdir = _TempDir()
# no pep8
- input_str = ["""FidNz 0.00000 10.56381 -2.05108
- FidT9 -7.82694 0.45386 -3.76056
- FidT10 7.82694 0.45386 -3.76056""",
- """// MatLab Sphere coordinates [degrees] Cartesian coordinates
- // Label Theta Phi Radius X Y Z off sphere surface
- E1 37.700 -14.000 1.000 0.7677 0.5934 -0.2419 -0.00000000000000011
- E2 44.600 -0.880 1.000 0.7119 0.7021 -0.0154 0.00000000000000000
- E3 51.700 11.000 1.000 0.6084 0.7704 0.1908 0.00000000000000000""", # noqa
- """# ASA electrode file
- ReferenceLabel avg
- UnitPosition mm
- NumberPositions= 68
- Positions
- -86.0761 -19.9897 -47.9860
- 85.7939 -20.0093 -48.0310
- 0.0083 86.8110 -39.9830
- Labels
- LPA
- RPA
- Nz
- """,
- """Site Theta Phi
- Fp1 -92 -72
- Fp2 92 72
- F3 -60 -51
- """,
- """346
- EEG F3 -62.027 -50.053 85
- EEG Fz 45.608 90 85
- EEG F4 62.01 50.103 85
- """,
- """
- eeg Fp1 -95.0 -31.0 -3.0
- eeg AF7 -81 -59 -3
- eeg AF3 -87 -41 28
- """]
- kinds = ['test.sfp', 'test.csd', 'test.elc', 'test.txt', 'test.elp',
- 'test.hpts']
+ input_str = [
+ 'FidNz 0.00000 10.56381 -2.05108\nFidT9 -7.82694 0.45386 -3.76056\n'
+ 'very_very_very_long_name 7.82694 0.45386 -3.76056',
+ '// MatLab Sphere coordinates [degrees] Cartesian coordinates\n' # noqa
+ '// Label Theta Phi Radius X Y Z off sphere surface\n' # noqa
+ 'E1 37.700 -14.000 1.000 0.7677 0.5934 -0.2419 -0.00000000000000011\n' # noqa
+ 'E2 44.600 -0.880 1.000 0.7119 0.7021 -0.0154 0.00000000000000000\n' # noqa
+ 'E3 51.700 11.000 1.000 0.6084 0.7704 0.1908 0.00000000000000000', # noqa
+ '# ASA electrode file\nReferenceLabel avg\nUnitPosition mm\n'
+ 'NumberPositions= 68\nPositions\n-86.0761 -19.9897 -47.9860\n'
+ '85.7939 -20.0093 -48.0310\n0.0083 86.8110 -39.9830\n'
+ 'Labels\nLPA\nRPA\nNz\n',
+ '# ASA electrode file\nReferenceLabel avg\nUnitPosition m\n'
+ 'NumberPositions= 68\nPositions\n-.0860761 -.0199897 -.0479860\n'
+ '.0857939 -.0200093 -.0480310\n.0000083 .00868110 -.0399830\n'
+ 'Labels\nLPA\nRPA\nNz\n',
+ 'Site Theta Phi\nFp1 -92 -72\nFp2 92 72\n'
+ 'very_very_very_long_name -60 -51\n',
+ '346\n'
+ 'EEG\t F3\t -62.027\t -50.053\t 85\n'
+ 'EEG\t Fz\t 45.608\t 90\t 85\n'
+ 'EEG\t F4\t 62.01\t 50.103\t 85\n',
+ 'eeg Fp1 -95.0 -31.0 -3.0\neeg AF7 -81 -59 -3\neeg AF3 -87 -41 28\n'
+ ]
+ kinds = ['test.sfp', 'test.csd', 'test_mm.elc', 'test_m.elc', 'test.txt',
+ 'test.elp', 'test.hpts']
for kind, text in zip(kinds, input_str):
fname = op.join(tempdir, kind)
with open(fname, 'w') as fid:
fid.write(text)
montage = read_montage(fname)
+ if ".sfp" in kind or ".txt" in kind:
+ assert_true('very_very_very_long_name' in montage.ch_names)
assert_equal(len(montage.ch_names), 3)
assert_equal(len(montage.ch_names), len(montage.pos))
assert_equal(montage.pos.shape, (3, 3))
@@ -96,6 +87,15 @@ def test_montage():
table = np.loadtxt(fname, skiprows=2, dtype=dtype)
pos2 = np.c_[table['x'], table['y'], table['z']]
assert_array_almost_equal(pos2, montage.pos, 4)
+ if kind.endswith('elc'):
+ # Make sure points are reasonable distance from geometric centroid
+ centroid = np.sum(montage.pos, axis=0) / montage.pos.shape[0]
+ distance_from_centroid = np.apply_along_axis(
+ np.linalg.norm, 1,
+ montage.pos - centroid)
+ assert_array_less(distance_from_centroid, 0.2)
+ assert_array_less(0.01, distance_from_centroid)
+
# test transform
input_str = """
eeg Fp1 -95.0 -31.0 -3.0
@@ -158,7 +158,7 @@ def test_montage():
def test_read_dig_montage():
"""Test read_dig_montage"""
names = ['nasion', 'lpa', 'rpa', '1', '2', '3', '4', '5']
- montage = read_dig_montage(hsp, hpi, elp, names, unit='m', transform=False)
+ montage = read_dig_montage(hsp, hpi, elp, names, transform=False)
elp_points = _read_dig_points(elp)
hsp_points = _read_dig_points(hsp)
hpi_points = read_mrk(hpi)
@@ -181,6 +181,21 @@ def test_read_dig_montage():
src_pts=montage.hpi, out='trans')
assert_array_equal(montage.dev_head_t, dev_head_t)
+ # Digitizer as array
+ m2 = read_dig_montage(hsp_points, hpi_points, elp_points, names, unit='m')
+ assert_array_equal(m2.hsp, montage.hsp)
+ m3 = read_dig_montage(hsp_points * 1000, hpi_points, elp_points * 1000,
+ names)
+ assert_allclose(m3.hsp, montage.hsp)
+
+ # test unit parameter
+ montage_cm = read_dig_montage(hsp, hpi, elp, names, unit='cm')
+ assert_allclose(montage_cm.hsp, montage.hsp * 10.)
+ assert_allclose(montage_cm.elp, montage.elp * 10.)
+ assert_array_equal(montage_cm.hpi, montage.hpi)
+ assert_raises(ValueError, read_dig_montage, hsp, hpi, elp, names,
+ unit='km')
+
def test_set_dig_montage():
"""Test applying DigMontage to inst
@@ -197,7 +212,7 @@ def test_set_dig_montage():
nasion_point, lpa_point, rpa_point = elp_points[:3]
hsp_points = apply_trans(nm_trans, hsp_points)
- montage = read_dig_montage(hsp, hpi, elp, names, unit='m', transform=True)
+ montage = read_dig_montage(hsp, hpi, elp, names, transform=True)
info = create_info(['Test Ch'], 1e3, ['eeg'])
_set_montage(info, montage)
hs = np.array([p['r'] for i, p in enumerate(info['dig'])
@@ -227,7 +242,9 @@ def test_fif_dig_montage():
dig_montage = read_dig_montage(fif=fif_dig_montage_fname)
# Make a BrainVision file like the one the user would have had
- raw_bv = read_raw_brainvision(bv_fname, preload=True)
+ with warnings.catch_warnings(record=True) as w:
+ raw_bv = read_raw_brainvision(bv_fname, preload=True)
+ assert_true(any('will be dropped' in str(ww.message) for ww in w))
raw_bv_2 = raw_bv.copy()
mapping = dict()
for ii, ch_name in enumerate(raw_bv.ch_names[:-1]):
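
(Aside: a hedged sketch of the unit handling tested above; the digitizer file
names are placeholders. Positions are rescaled according to unit, and an
unsupported unit raises ValueError.)

    from mne.channels import read_dig_montage

    names = ['nasion', 'lpa', 'rpa', '1', '2', '3', '4', '5']
    mon = read_dig_montage('hsp.txt', 'mrk.sqd', 'elp.txt', names)  # default mm
    mon_cm = read_dig_montage('hsp.txt', 'mrk.sqd', 'elp.txt', names, unit='cm')
    # mon_cm positions are 10x those of mon; unit='km' raises ValueError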
diff --git a/mne/chpi.py b/mne/chpi.py
index c3586b5..bc477e1 100644
--- a/mne/chpi.py
+++ b/mne/chpi.py
@@ -2,105 +2,30 @@
#
# License: BSD (3-clause)
+from functools import partial
+
import numpy as np
from scipy import linalg, fftpack
from .io.pick import pick_types, pick_channels
-from .io.base import _BaseRaw
from .io.constants import FIFF
from .forward import (_magnetic_dipole_field_vec, _create_meg_coils,
_concatenate_coils, _read_coil_defs)
from .cov import make_ad_hoc_cov, _get_whitener_data
from .transforms import (apply_trans, invert_transform, _angle_between_quats,
quat_to_rot, rot_to_quat)
-from .utils import (verbose, logger, check_version, use_log_level, deprecated,
+from .utils import (verbose, logger, check_version, use_log_level,
_check_fname, warn)
-from .fixes import partial
-from .externals.six import string_types
# Eventually we should add:
# hpicons
# high-passing of data during fits
+# parsing cHPI coil information from acq pars, then to PSD if necessary
# ############################################################################
# Reading from text or FIF file
-@deprecated('get_chpi_positions will be removed in v0.13, use '
- 'read_head_pos(fname) or raw[pick_types(meg=False, chpi=True), :] '
- 'instead')
-@verbose
-def get_chpi_positions(raw, t_step=None, return_quat=False, verbose=None):
- """Extract head positions
-
- Note that the raw instance must have CHPI channels recorded.
-
- Parameters
- ----------
- raw : instance of Raw | str
- Raw instance to extract the head positions from. Can also be a
- path to a Maxfilter head position estimation log file (str).
- t_step : float | None
- Sampling interval to use when converting data. If None, it will
- be automatically determined. By default, a sampling interval of
- 1 second is used if processing a raw data. If processing a
- Maxfilter log file, this must be None because the log file
- itself will determine the sampling interval.
- return_quat : bool
- If True, also return the quaternions.
-
- .. versionadded:: 0.11
-
- verbose : bool, str, int, or None
- If not None, override default verbose level (see mne.verbose).
-
- Returns
- -------
- translation : ndarray, shape (N, 3)
- Translations at each time point.
- rotation : ndarray, shape (N, 3, 3)
- Rotations at each time point.
- t : ndarray, shape (N,)
- The time points.
- quat : ndarray, shape (N, 3)
- The quaternions. Only returned if ``return_quat`` is True.
-
- Notes
- -----
- The digitized HPI head frame y is related to the frame position X as:
-
- Y = np.dot(rotation, X) + translation
-
- Note that if a Maxfilter log file is being processed, the start time
- may not use the same reference point as the rest of mne-python (i.e.,
- it could be referenced relative to raw.first_samp or something else).
- """
- if isinstance(raw, _BaseRaw):
- # for simplicity, we'll sample at 1 sec intervals like maxfilter
- if t_step is None:
- t_step = 1.0
- t_step = float(t_step)
- picks = pick_types(raw.info, meg=False, ref_meg=False,
- chpi=True, exclude=[])
- if len(picks) == 0:
- raise RuntimeError('raw file has no CHPI channels')
- time_idx = raw.time_as_index(np.arange(0, raw.times[-1], t_step))
- data = [raw[picks, ti] for ti in time_idx]
- t = np.array([d[1] for d in data])
- data = np.array([d[0][:, 0] for d in data])
- data = np.c_[t, data]
- else:
- if not isinstance(raw, string_types):
- raise TypeError('raw must be an instance of Raw or string')
- if t_step is not None:
- raise ValueError('t_step must be None if processing a log')
- data = read_head_pos(raw)
- out = head_pos_to_trans_rot_t(data)
- if return_quat:
- out = out + (data[:, 1:4],)
- return out
-
-
def read_head_pos(fname):
"""Read MaxFilter-formatted head position parameters
@@ -127,6 +52,9 @@ def read_head_pos(fname):
_check_fname(fname, must_exist=True, overwrite=True)
data = np.loadtxt(fname, skiprows=1) # first line is header, skip it
data.shape = (-1, 10) # ensure it's the right size even if empty
+ if np.isnan(data).any(): # make sure we didn't do something dumb
+ raise RuntimeError('positions could not be read properly from %s'
+ % fname)
return data
@@ -193,7 +121,8 @@ def head_pos_to_trans_rot_t(quats):
# ############################################################################
# Estimate positions from data
-def _get_hpi_info(info):
+@verbose
+def _get_hpi_info(info, adjust=False, verbose=None):
"""Helper to get HPI information from raw"""
if len(info['hpi_meas']) == 0 or \
('coil_freq' not in info['hpi_meas'][0]['hpi_coils'][0]):
@@ -206,7 +135,6 @@ def _get_hpi_info(info):
if d['kind'] == FIFF.FIFFV_POINT_HPI],
key=lambda x: x['ident']) # ascending (dig) order
pos_order = hpi_result['order'] - 1 # zero-based indexing, dig->info
- # hpi_result['dig_points'] are in FIFFV_COORD_UNKNOWN coords...?
# this shouldn't happen, eventually we could add the transforms
# necessary to put it in head coords
@@ -219,10 +147,47 @@ def _get_hpi_info(info):
% (len(hpi_result['used']),
' '.join(str(h) for h in hpi_result['used'])))
hpi_rrs = np.array([d['r'] for d in hpi_dig])[pos_order]
- # errors = 1000 * np.sqrt((hpi_rrs - hpi_rrs_fit) ** 2).sum(axis=1)
- # logger.debug('HPIFIT errors: %s'
- # % ', '.join('%0.1f' % e for e in errors))
- hpi_freqs = np.array([float(x['coil_freq']) for x in hpi_coils])
+
+ # Fitting errors
+ hpi_rrs_fit = sorted([d for d in info['hpi_results'][-1]['dig_points']],
+ key=lambda x: x['ident'])
+ hpi_rrs_fit = np.array([d['r'] for d in hpi_rrs_fit])
+ # hpi_result['dig_points'] are in FIFFV_COORD_UNKNOWN coords, but this
+ # is probably a misnomer because it should be FIFFV_COORD_DEVICE for this
+ # to work
+ assert hpi_result['coord_trans']['to'] == FIFF.FIFFV_COORD_HEAD
+ hpi_rrs_fit = apply_trans(hpi_result['coord_trans']['trans'], hpi_rrs_fit)
+ if 'moments' in hpi_result:
+ logger.debug('Hpi coil moments (%d %d):'
+ % hpi_result['moments'].shape[::-1])
+ for moment in hpi_result['moments']:
+ logger.debug("%g %g %g" % tuple(moment))
+ errors = np.sqrt(((hpi_rrs - hpi_rrs_fit) ** 2).sum(axis=1))
+ logger.debug('HPIFIT errors: %s mm.'
+ % ', '.join('%0.1f' % (1000. * e) for e in errors))
+ if errors.sum() < len(errors) * hpi_result['dist_limit']:
+ logger.info('HPI consistency of isotrak and hpifit is OK.')
+ elif not adjust and (len(hpi_result['used']) == len(hpi_coils)):
+ warn('HPI consistency of isotrak and hpifit is poor.')
+ else:
+ # adjust HPI coil locations using the hpifit transformation
+ for hi, (r_dig, r_fit) in enumerate(zip(hpi_rrs, hpi_rrs_fit)):
+ # transform to head frame
+ d = 1000 * np.sqrt(((r_dig - r_fit) ** 2).sum())
+ if not adjust:
+ warn('Discrepancy of HPI coil %d isotrak and hpifit is %.1f '
+ 'mm!' % (hi + 1, d))
+ elif hi + 1 not in hpi_result['used']:
+ if hpi_result['goodness'][hi] >= hpi_result['good_limit']:
+ logger.info('Note: HPI coil %d isotrak is adjusted by '
+ '%.1f mm!' % (hi + 1, d))
+ hpi_rrs[hi] = r_fit
+ else:
+ warn('Discrepancy of HPI coil %d isotrak and hpifit of '
+ '%.1f mm was not adjusted!' % (hi + 1, d))
+ logger.debug('HP fitting limits: err = %.1f mm, gval = %.3f.'
+ % (1000 * hpi_result['dist_limit'], hpi_result['good_limit']))
+
# how cHPI active is indicated in the FIF file
hpi_sub = info['hpi_subsystem']
if 'event_channel' in hpi_sub:
@@ -234,6 +199,9 @@ def _get_hpi_info(info):
# not all HPI coils will actually be used
hpi_on = np.array([hpi_on[hc['number'] - 1] for hc in hpi_coils])
assert len(hpi_coils) == len(hpi_on)
+
+ # get frequencies
+ hpi_freqs = np.array([float(x['coil_freq']) for x in hpi_coils])
logger.info('Using %s HPI coils: %s Hz'
% (len(hpi_freqs), ' '.join(str(int(s)) for s in hpi_freqs)))
return hpi_freqs, hpi_rrs, hpi_pick, hpi_on, pos_order
@@ -544,10 +512,9 @@ def _calculate_chpi_positions(raw, t_step_min=0.1, t_step_max=10.,
% ((use_mask.sum(), n_freqs) + vs))
# resulting errors in head coil positions
est_coil_head_rrs = apply_trans(this_dev_head_t, this_coil_dev_rrs)
- errs = 1000. * np.sqrt(np.sum((hpi['coil_head_rrs'] -
- est_coil_head_rrs) ** 2,
- axis=1))
- e = 0. # XXX eventually calculate this -- cumulative error of fit?
+ errs = 1000. * np.sqrt(((hpi['coil_head_rrs'] -
+ est_coil_head_rrs) ** 2).sum(axis=-1))
+ e = errs.mean() / 1000. # mm -> m
d = 100 * np.sqrt(np.sum(last['quat'][3:] - this_quat[3:]) ** 2) # cm
r = _angle_between_quats(last['quat'][:3], this_quat[:3]) / dt
v = d / dt # cm/sec
@@ -625,13 +592,13 @@ def filter_chpi(raw, include_line=True, verbose=None):
t_window = 0.2
t_step = 0.01
n_step = int(np.ceil(t_step * raw.info['sfreq']))
- hpi = _setup_chpi_fits(raw.info, t_window, t_window, exclude=(),
+ hpi = _setup_chpi_fits(raw.info, t_window, t_window, exclude='bads',
add_hpi_stim_pick=False, remove_aliased=True,
verbose=False)[0]
fit_idxs = np.arange(0, len(raw.times) + hpi['n_window'] // 2, n_step)
n_freqs = len(hpi['freqs'])
n_remove = 2 * n_freqs
- meg_picks = hpi['picks']
+ meg_picks = pick_types(raw.info, meg=True, exclude=()) # filter all chs
n_times = len(raw.times)
msg = 'Removing %s cHPI' % n_freqs
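
(Aside: a short sketch of the non-deprecated replacement for the removed
get_chpi_positions; the log file name is hypothetical.)

    from mne.chpi import read_head_pos, head_pos_to_trans_rot_t

    pos = read_head_pos('run01_headpos.txt')   # (N, 10); NaNs now raise
    translation, rotation, t = head_pos_to_trans_rot_t(pos)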
diff --git a/mne/commands/mne_browse_raw.py b/mne/commands/mne_browse_raw.py
index ecf59df..9ff3b2b 100755
--- a/mne/commands/mne_browse_raw.py
+++ b/mne/commands/mne_browse_raw.py
@@ -62,6 +62,8 @@ def run():
parser.add_option("--clipping", dest="clipping",
help="Enable trace clipping mode, either 'clip' or "
"'transparent'", default=None)
+ parser.add_option("--filterchpi", dest="filterchpi",
+ help="Enable filtering cHPI signals.", default=None)
options, args = parser.parse_args()
@@ -79,6 +81,7 @@ def run():
lowpass = options.lowpass
filtorder = options.filtorder
clipping = options.clipping
+ filterchpi = options.filterchpi
if raw_in is None:
parser.print_help()
@@ -93,6 +96,13 @@ def run():
events = mne.read_events(eve_in)
else:
events = None
+
+ if filterchpi:
+ if not preload:
+ raise RuntimeError(
+ 'Raw data must be preloaded for chpi, use --preload')
+ raw = mne.chpi.filter_chpi(raw)
+
highpass = None if highpass < 0 or filtorder <= 0 else highpass
lowpass = None if lowpass < 0 or filtorder <= 0 else lowpass
filtorder = 4 if filtorder <= 0 else filtorder
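
(Aside: the new --filterchpi flag mirrors this Python usage, a sketch with a
placeholder file name; preloaded data is required, matching the RuntimeError
above.)

    import mne

    raw = mne.io.read_raw_fif('raw.fif', preload=True)
    mne.chpi.filter_chpi(raw)  # removes cHPI (and line) frequencies in place
    raw.plot()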
diff --git a/mne/commands/mne_coreg.py b/mne/commands/mne_coreg.py
index 1ead4e9..fb37d16 100644
--- a/mne/commands/mne_coreg.py
+++ b/mne/commands/mne_coreg.py
@@ -7,20 +7,34 @@ example usage: $ mne coreg
"""
-import os
import sys
import mne
+from mne.utils import ETSContext
def run():
from mne.commands.utils import get_optparser
parser = get_optparser(__file__)
+
+ parser.add_option("-d", "--subjects-dir", dest="subjects_dir",
+ default=None, help="Subjects directory")
+ parser.add_option("-s", "--subject", dest="subject", default=None,
+ help="Subject name")
+ parser.add_option("-f", "--fiff", dest="inst", default=None,
+ help="FIFF file with digitizer data for coregistration")
+ parser.add_option("-t", "--tabbed", dest="tabbed", action="store_true",
+ default=False, help="Option for small screens: Combine "
+ "the data source panel and the coregistration panel "
+ "into a single panel with tabs.")
+
options, args = parser.parse_args()
- os.environ['ETS_TOOLKIT'] = 'qt4'
- mne.gui.coregistration()
+ with ETSContext():
+ mne.gui.coregistration(options.tabbed, inst=options.inst,
+ subject=options.subject,
+ subjects_dir=options.subjects_dir)
if is_main:
sys.exit(0)
diff --git a/mne/commands/mne_flash_bem.py b/mne/commands/mne_flash_bem.py
index f46f0a2..cc52010 100644
--- a/mne/commands/mne_flash_bem.py
+++ b/mne/commands/mne_flash_bem.py
@@ -23,6 +23,10 @@ Before running this script do the following:
appropriate series:
$ ln -s <FLASH 5 series dir> flash05
$ ln -s <FLASH 30 series dir> flash30
+ Some partition formats (e.g. FAT32) do not support symbolic links.
+ In this case, copy the file to the appropriate series:
+ $ cp <FLASH 5 series dir> flash05
+ $ cp <FLASH 30 series dir> flash30
4. cd to the directory where flash05 and flash30 links are
5. Set SUBJECTS_DIR and SUBJECT environment variables appropriately
6. Run this script
@@ -83,7 +87,7 @@ def run():
convert_flash_mris(subject=subject, subjects_dir=subjects_dir,
flash30=flash30, convert=convert, unwarp=unwarp)
make_flash_bem(subject=subject, subjects_dir=subjects_dir,
- overwrite=overwrite, show=show)
+ overwrite=overwrite, show=show, flash_path='.')
is_main = (__name__ == '__main__')
if is_main:
diff --git a/mne/commands/mne_kit2fiff.py b/mne/commands/mne_kit2fiff.py
index 2fcf086..bc337d5 100755
--- a/mne/commands/mne_kit2fiff.py
+++ b/mne/commands/mne_kit2fiff.py
@@ -8,11 +8,11 @@ Use without arguments to invoke GUI: $ mne kit2fiff
"""
-import os
import sys
import mne
from mne.io import read_raw_kit
+from mne.utils import ETSContext
def run():
@@ -44,8 +44,8 @@ def run():
input_fname = options.input_fname
if input_fname is None:
- os.environ['ETS_TOOLKIT'] = 'qt4'
- mne.gui.kit2fiff()
+ with ETSContext():
+ mne.gui.kit2fiff()
sys.exit(0)
hsp_fname = options.hsp_fname
diff --git a/mne/commands/mne_make_scalp_surfaces.py b/mne/commands/mne_make_scalp_surfaces.py
index 9c9318e..08c60b1 100755
--- a/mne/commands/mne_make_scalp_surfaces.py
+++ b/mne/commands/mne_make_scalp_surfaces.py
@@ -17,7 +17,7 @@ import copy
import os.path as op
import sys
import mne
-from mne.utils import run_subprocess, _TempDir, verbose, logger
+from mne.utils import run_subprocess, _TempDir, verbose, logger, ETSContext
def _check_file(fname, overwrite):
@@ -44,7 +44,9 @@ def run():
help='Print the debug messages.')
parser.add_option("-d", "--subjects-dir", dest="subjects_dir",
help="Subjects directory", default=subjects_dir)
-
+ parser.add_option("-n", "--no-decimate", dest="no_decimate",
+ help="Disable medium and sparse decimations "
+ "(dense only)", action='store_true')
options, args = parser.parse_args()
subject = vars(options).get('subject', os.getenv('SUBJECT'))
@@ -52,12 +54,13 @@ def run():
if subject is None or subjects_dir is None:
parser.print_help()
sys.exit(1)
+ print(options.no_decimate)
_run(subjects_dir, subject, options.force, options.overwrite,
- options.verbose)
+ options.no_decimate, options.verbose)
@verbose
-def _run(subjects_dir, subject, force, overwrite, verbose=None):
+def _run(subjects_dir, subject, force, overwrite, no_decimate, verbose=None):
this_env = copy.copy(os.environ)
this_env['SUBJECTS_DIR'] = subjects_dir
this_env['SUBJECT'] = subject
@@ -115,15 +118,16 @@ def _run(subjects_dir, subject, force, overwrite, verbose=None):
'--fif', dense_fname], env=this_env)
levels = 'medium', 'sparse'
my_surf = mne.read_bem_surfaces(dense_fname)[0]
- tris = [30000, 2500]
+ tris = [] if no_decimate else [30000, 2500]
if os.getenv('_MNE_TESTING_SCALP', 'false') == 'true':
tris = [len(my_surf['tris'])] # don't actually decimate
for ii, (n_tri, level) in enumerate(zip(tris, levels), 3):
logger.info('%i. Creating %s tessellation...' % (ii, level))
logger.info('%i.1 Decimating the dense tessellation...' % ii)
- points, tris = mne.decimate_surface(points=my_surf['rr'],
- triangles=my_surf['tris'],
- n_triangles=n_tri)
+ with ETSContext():
+ points, tris = mne.decimate_surface(points=my_surf['rr'],
+ triangles=my_surf['tris'],
+ n_triangles=n_tri)
other_fname = dense_fname.replace('dense', level)
logger.info('%i.2 Creating %s' % (ii, other_fname))
_check_file(other_fname, overwrite)
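
(Aside: a sketch of the decimation step now wrapped in ETSContext; the surface
file name is a placeholder. ETSContext only arranges the GUI toolkit for the
VTK-backed call.)

    import mne
    from mne.utils import ETSContext

    surf = mne.read_bem_surfaces('sample-head-dense.fif')[0]
    with ETSContext():
        points, tris = mne.decimate_surface(points=surf['rr'],
                                            triangles=surf['tris'],
                                            n_triangles=30000)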
diff --git a/mne/commands/mne_show_fiff.py b/mne/commands/mne_show_fiff.py
index cb4fb4c..3076fe7 100644
--- a/mne/commands/mne_show_fiff.py
+++ b/mne/commands/mne_show_fiff.py
@@ -8,6 +8,7 @@ $ mne show_fiff test_raw.fif
# Authors : Eric Larson, PhD
+import codecs
import sys
import mne
@@ -19,6 +20,11 @@ def run():
if len(args) != 1:
parser.print_help()
sys.exit(1)
+ # This works around an annoying bug on Windows for show_fiff, see:
+ # https://pythonhosted.org/kitchen/unicode-frustrations.html
+ if int(sys.version[0]) < 3:
+ UTF8Writer = codecs.getwriter('utf8')
+ sys.stdout = UTF8Writer(sys.stdout)
print(mne.io.show_fiff(args[0]))
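
(Aside: the Windows/Python 2 stdout workaround above in isolation, as a
sketch.)

    import codecs
    import sys

    if int(sys.version[0]) < 3:  # Python 2 only
        sys.stdout = codecs.getwriter('utf8')(sys.stdout)
    print(u'non-ascii output now encodes cleanly')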
diff --git a/mne/commands/tests/test_commands.py b/mne/commands/tests/test_commands.py
index 55de6ca..51520c1 100644
--- a/mne/commands/tests/test_commands.py
+++ b/mne/commands/tests/test_commands.py
@@ -6,6 +6,7 @@ import glob
import warnings
from nose.tools import assert_true, assert_raises
+from mne import concatenate_raws
from mne.commands import (mne_browse_raw, mne_bti2fiff, mne_clean_eog_ecg,
mne_compute_proj_ecg, mne_compute_proj_eog,
mne_coreg, mne_kit2fiff,
@@ -13,11 +14,11 @@ from mne.commands import (mne_browse_raw, mne_bti2fiff, mne_clean_eog_ecg,
mne_report, mne_surf2bem, mne_watershed_bem,
mne_compare_fiff, mne_flash_bem, mne_show_fiff,
mne_show_info)
+from mne.datasets import testing, sample
+from mne.io import read_raw_fif
from mne.utils import (run_tests_if_main, _TempDir, requires_mne, requires_PIL,
requires_mayavi, requires_tvtk, requires_freesurfer,
ArgvSetter, slow_test, ultra_slow_test)
-from mne.io import Raw
-from mne.datasets import testing, sample
base_dir = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data')
@@ -41,22 +42,22 @@ def check_usage(module, force_help=False):
@slow_test
def test_browse_raw():
- """Test mne browse_raw"""
+ """Test mne browse_raw."""
check_usage(mne_browse_raw)
def test_bti2fiff():
- """Test mne bti2fiff"""
+ """Test mne bti2fiff."""
check_usage(mne_bti2fiff)
def test_compare_fiff():
- """Test mne compare_fiff"""
+ """Test mne compare_fiff."""
check_usage(mne_compare_fiff)
def test_show_fiff():
- """Test mne compare_fiff"""
+ """Test mne compare_fiff."""
check_usage(mne_show_fiff)
with ArgvSetter((raw_fname,)):
mne_show_fiff.run()
@@ -64,10 +65,11 @@ def test_show_fiff():
@requires_mne
def test_clean_eog_ecg():
- """Test mne clean_eog_ecg"""
+ """Test mne clean_eog_ecg."""
check_usage(mne_clean_eog_ecg)
tempdir = _TempDir()
- raw = Raw([raw_fname, raw_fname, raw_fname])
+ raw = concatenate_raws([read_raw_fif(f, add_eeg_ref=False)
+ for f in [raw_fname, raw_fname, raw_fname]])
raw.info['bads'] = ['MEG 2443']
use_fname = op.join(tempdir, op.basename(raw_fname))
raw.save(use_fname)
@@ -81,7 +83,7 @@ def test_clean_eog_ecg():
@slow_test
def test_compute_proj_ecg_eog():
- """Test mne compute_proj_ecg/eog"""
+ """Test mne compute_proj_ecg/eog."""
for fun in (mne_compute_proj_ecg, mne_compute_proj_eog):
check_usage(fun)
tempdir = _TempDir()
@@ -100,12 +102,12 @@ def test_compute_proj_ecg_eog():
def test_coreg():
- """Test mne coreg"""
+ """Test mne coreg."""
assert_true(hasattr(mne_coreg, 'run'))
def test_kit2fiff():
- """Test mne kit2fiff"""
+ """Test mne kit2fiff."""
# Can't check
check_usage(mne_kit2fiff, force_help=True)
@@ -114,7 +116,7 @@ def test_kit2fiff():
@requires_mne
@testing.requires_testing_data
def test_make_scalp_surfaces():
- """Test mne make_scalp_surfaces"""
+ """Test mne make_scalp_surfaces."""
check_usage(mne_make_scalp_surfaces)
# Copy necessary files to avoid FreeSurfer call
tempdir = _TempDir()
@@ -122,7 +124,8 @@ def test_make_scalp_surfaces():
surf_path_new = op.join(tempdir, 'sample', 'surf')
os.mkdir(op.join(tempdir, 'sample'))
os.mkdir(surf_path_new)
- os.mkdir(op.join(tempdir, 'sample', 'bem'))
+ subj_dir = op.join(tempdir, 'sample', 'bem')
+ os.mkdir(subj_dir)
shutil.copy(op.join(surf_path, 'lh.seghead'), surf_path_new)
orig_fs = os.getenv('FREESURFER_HOME', None)
@@ -139,6 +142,8 @@ def test_make_scalp_surfaces():
assert_raises(RuntimeError, mne_make_scalp_surfaces.run)
os.environ['MNE_ROOT'] = orig_mne
mne_make_scalp_surfaces.run()
+ assert_true(op.isfile(op.join(subj_dir, 'sample-head-dense.fif')))
+ assert_true(op.isfile(op.join(subj_dir, 'sample-head-medium.fif')))
assert_raises(IOError, mne_make_scalp_surfaces.run) # no overwrite
finally:
if orig_fs is not None:
@@ -148,7 +153,7 @@ def test_make_scalp_surfaces():
def test_maxfilter():
- """Test mne maxfilter"""
+ """Test mne maxfilter."""
check_usage(mne_maxfilter)
with ArgvSetter(('-i', raw_fname, '--st', '--movecomp', '--linefreq', '60',
'--trans', raw_fname)) as out:
@@ -169,7 +174,7 @@ def test_maxfilter():
@requires_PIL
@testing.requires_testing_data
def test_report():
- """Test mne report"""
+ """Test mne report."""
check_usage(mne_report)
tempdir = _TempDir()
use_fname = op.join(tempdir, op.basename(raw_fname))
@@ -182,7 +187,7 @@ def test_report():
def test_surf2bem():
- """Test mne surf2bem"""
+ """Test mne surf2bem."""
check_usage(mne_surf2bem)
@@ -190,7 +195,7 @@ def test_surf2bem():
@requires_freesurfer
@testing.requires_testing_data
def test_watershed_bem():
- """Test mne watershed bem"""
+ """Test mne watershed bem."""
check_usage(mne_watershed_bem)
# Copy necessary files to tempdir
tempdir = _TempDir()
@@ -211,11 +216,10 @@ def test_watershed_bem():
@ultra_slow_test
-@requires_mne
@requires_freesurfer
@sample.requires_sample_data
def test_flash_bem():
- """Test mne flash_bem"""
+ """Test mne flash_bem."""
check_usage(mne_flash_bem, force_help=True)
# Using the sample dataset
subjects_dir = op.join(sample.data_path(download=False), 'subjects')
@@ -244,7 +248,7 @@ def test_flash_bem():
def test_show_info():
- """Test mne show_info"""
+ """Test mne show_info."""
check_usage(mne_show_info)
with ArgvSetter((raw_fname,)):
mne_show_info.run()
diff --git a/mne/connectivity/spectral.py b/mne/connectivity/spectral.py
index 273e8b4..5f83225 100644
--- a/mne/connectivity/spectral.py
+++ b/mne/connectivity/spectral.py
@@ -2,14 +2,14 @@
#
# License: BSD (3-clause)
-from ..externals.six import string_types
+from functools import partial
from inspect import getmembers
import numpy as np
from scipy.fftpack import fftfreq
from .utils import check_indices
-from ..fixes import tril_indices, partial, _get_args
+from ..fixes import _get_args
from ..parallel import parallel_func
from ..source_estimate import _BaseSourceEstimate
from ..epochs import _BaseEpochs
@@ -18,6 +18,7 @@ from ..time_frequency.multitaper import (dpss_windows, _mt_spectra,
_psd_from_mt_adaptive)
from ..time_frequency.tfr import morlet, cwt
from ..utils import logger, verbose, _time_mask, warn
+from ..externals.six import string_types
########################################################################
# Various connectivity estimators
@@ -704,7 +705,7 @@ def spectral_connectivity(data, method='coh', indices=None, sfreq=2 * np.pi,
The number of DPSS tapers used. Only defined in 'multitaper' mode.
Otherwise None is returned.
"""
- if n_jobs > 1:
+ if n_jobs != 1:
parallel, my_epoch_spectral_connectivity, _ = \
parallel_func(_epoch_spectral_connectivity, n_jobs,
verbose=verbose)
@@ -789,7 +790,7 @@ def spectral_connectivity(data, method='coh', indices=None, sfreq=2 * np.pi,
if indices is None:
# only compute r for lower-triangular region
- indices_use = tril_indices(n_signals, -1)
+ indices_use = np.tril_indices(n_signals, -1)
else:
indices_use = check_indices(indices)
@@ -924,8 +925,8 @@ def spectral_connectivity(data, method='coh', indices=None, sfreq=2 * np.pi,
cwt_n_cycles = cwt_n_cycles[freq_mask]
# get the Morlet wavelets
- wavelets = morlet(sfreq, freqs,
- n_cycles=cwt_n_cycles, zero_mean=True)
+ wavelets = morlet(sfreq, freqs, n_cycles=cwt_n_cycles,
+ zero_mean=True)
eigvals = None
n_tapers = None
window_fun = None
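
(Aside: a hedged sketch of the indices convention touched above; with
indices=None the lower triangle is computed, equivalent to np.tril_indices.
Data here are random placeholders.)

    import numpy as np
    from mne.connectivity import spectral_connectivity

    data = np.random.RandomState(0).randn(5, 3, 256)  # epochs, signals, times
    indices = np.tril_indices(3, -1)  # same pairs as the indices=None default
    con, freqs, times, n_epochs, n_tapers = spectral_connectivity(
        data, method='coh', sfreq=256., indices=indices, verbose=False)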
diff --git a/mne/connectivity/tests/test_spectral.py b/mne/connectivity/tests/test_spectral.py
index 8678f5b..9341939 100644
--- a/mne/connectivity/tests/test_spectral.py
+++ b/mne/connectivity/tests/test_spectral.py
@@ -1,11 +1,9 @@
-import os
+import warnings
+
import numpy as np
from numpy.testing import assert_array_almost_equal
from nose.tools import assert_true, assert_raises
-from nose.plugins.skip import SkipTest
-import warnings
-from mne.fixes import tril_indices
from mne.connectivity import spectral_connectivity
from mne.connectivity.spectral import _CohEst
@@ -13,6 +11,10 @@ from mne import SourceEstimate
from mne.utils import run_tests_if_main, slow_test
from mne.filter import band_pass_filter
+trans_bandwidth = 2.5
+filt_kwargs = dict(filter_length='auto', fir_window='hamming', phase='zero',
+ l_trans_bandwidth=trans_bandwidth,
+ h_trans_bandwidth=trans_bandwidth)
warnings.simplefilter('always')
@@ -35,10 +37,6 @@ def _stc_gen(data, sfreq, tmin, combo=False):
@slow_test
def test_spectral_connectivity():
"""Test frequency-domain connectivity methods"""
- # XXX For some reason on 14 Oct 2015 Travis started timing out on this
- # test, so for a quick workaround we will skip it:
- if os.getenv('TRAVIS', 'false') == 'true':
- raise SkipTest('Travis is broken')
# Use a case known to have no spurious correlations (it would bad if
# nosetests could randomly fail):
np.random.seed(0)
@@ -55,10 +53,9 @@ def test_spectral_connectivity():
# simulate connectivity from 5Hz..15Hz
fstart, fend = 5.0, 15.0
for i in range(n_epochs):
- with warnings.catch_warnings(record=True):
- warnings.simplefilter('always')
- data[i, 1, :] = band_pass_filter(data[i, 0, :],
- sfreq, fstart, fend)
+ data[i, 1, :] = band_pass_filter(data[i, 0, :],
+ sfreq, fstart, fend,
+ **filt_kwargs)
# add some noise, so the spectrum is not exactly zero
data[i, 1, :] += 1e-2 * np.random.randn(n_times)
@@ -117,31 +114,40 @@ def test_spectral_connectivity():
if mode == 'multitaper':
upper_t = 0.95
lower_t = 0.5
- else:
+ elif mode == 'fourier':
# other estimates have higher variance
upper_t = 0.8
lower_t = 0.75
+ else: # cwt_morlet
+ upper_t = 0.64
+ lower_t = 0.63
# test the simulated signal
if method == 'coh':
- idx = np.searchsorted(freqs, (fstart + 1, fend - 1))
+ idx = np.searchsorted(freqs, (fstart + trans_bandwidth,
+ fend - trans_bandwidth))
# we see something for zero-lag
- assert_true(np.all(con[1, 0, idx[0]:idx[1]] > upper_t))
+ assert_true(np.all(con[1, 0, idx[0]:idx[1]] > upper_t),
+ con[1, 0, idx[0]:idx[1]].min())
if mode != 'cwt_morlet':
- idx = np.searchsorted(freqs, (fstart - 1, fend + 1))
+ idx = np.searchsorted(freqs,
+ (fstart - trans_bandwidth * 2,
+ fend + trans_bandwidth * 2))
assert_true(np.all(con[1, 0, :idx[0]] < lower_t))
- assert_true(np.all(con[1, 0, idx[1]:] < lower_t))
+ assert_true(np.all(con[1, 0, idx[1]:] < lower_t),
+ con[1, 0, idx[1:]].max())
elif method == 'cohy':
idx = np.searchsorted(freqs, (fstart + 1, fend - 1))
# imaginary coh will be zero
- assert_true(np.all(np.imag(con[1, 0, idx[0]:idx[1]]) <
- lower_t))
+ check = np.imag(con[1, 0, idx[0]:idx[1]])
+ assert_true(np.all(check < lower_t), check.max())
# we see something for zero-lag
assert_true(np.all(np.abs(con[1, 0, idx[0]:idx[1]]) >
upper_t))
- idx = np.searchsorted(freqs, (fstart - 1, fend + 1))
+ idx = np.searchsorted(freqs, (fstart - trans_bandwidth * 2,
+ fend + trans_bandwidth * 2))
if mode != 'cwt_morlet':
assert_true(np.all(np.abs(con[1, 0, :idx[0]]) <
lower_t))
@@ -153,10 +159,11 @@ def test_spectral_connectivity():
assert_true(np.all(con[1, 0, idx[0]:idx[1]] < lower_t))
idx = np.searchsorted(freqs, (fstart - 1, fend + 1))
assert_true(np.all(con[1, 0, :idx[0]] < lower_t))
- assert_true(np.all(con[1, 0, idx[1]:] < lower_t))
+ assert_true(np.all(con[1, 0, idx[1]:] < lower_t),
+ con[1, 0, idx[1]:].max())
# compute same connections using indices and 2 jobs
- indices = tril_indices(n_signals, -1)
+ indices = np.tril_indices(n_signals, -1)
if not isinstance(method, list):
test_methods = (method, _CohEst)
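
(Aside: a sketch of the explicit filter settings pinned by filt_kwargs above;
fixing the transition bandwidths keeps the pass-band edges predictable for the
assertions. Signal values are illustrative.)

    import numpy as np
    from mne.filter import band_pass_filter

    x = np.random.randn(2048)
    y = band_pass_filter(x, 256., 5., 15., filter_length='auto',
                         fir_window='hamming', phase='zero',
                         l_trans_bandwidth=2.5, h_trans_bandwidth=2.5)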
diff --git a/mne/coreg.py b/mne/coreg.py
index b97cfac..9fb5fd6 100644
--- a/mne/coreg.py
+++ b/mne/coreg.py
@@ -39,6 +39,9 @@ head_bem_fname = pformat(bem_fname, name='head')
fid_fname = pformat(bem_fname, name='fiducials')
fid_fname_general = os.path.join(bem_dirname, "{head}-fiducials.fif")
src_fname = os.path.join(bem_dirname, '{subject}-{spacing}-src.fif')
+_high_res_head_fnames = (os.path.join(bem_dirname, '{subject}-head-dense.fif'),
+ os.path.join(surf_dirname, 'lh.seghead'),
+ os.path.join(surf_dirname, 'lh.smseghead'))
def _make_writable(fname):
@@ -54,6 +57,13 @@ def _make_writable_recursive(path):
_make_writable(os.path.join(root, f))
+def _find_high_res_head(subject, subjects_dir):
+ for fname in _high_res_head_fnames:
+ path = fname.format(subjects_dir=subjects_dir, subject=subject)
+ if os.path.exists(path):
+ return path
+
+
def create_default_subject(mne_root=None, fs_home=None, update=False,
subjects_dir=None):
"""Create an average brain subject for subjects without structural MRI
@@ -587,13 +597,16 @@ def _find_label_paths(subject='fsaverage', pattern=None, subjects_dir=None):
return paths
-def _find_mri_paths(subject='fsaverage', subjects_dir=None):
+def _find_mri_paths(subject, skip_fiducials, subjects_dir):
"""Find all files of an mri relevant for source transformation
Parameters
----------
subject : str
Name of the mri subject.
+ skip_fiducials : bool
+ Do not scale the MRI fiducials. If False, an IOError will be raised
+ if no fiducials file can be found.
subjects_dir : None | str
Override the SUBJECTS_DIR environment variable
(sys.environ['SUBJECTS_DIR'])
@@ -613,15 +626,18 @@ def _find_mri_paths(subject='fsaverage', subjects_dir=None):
# surf/ files
paths['surf'] = surf = []
surf_fname = os.path.join(surf_dirname, '{name}')
- surf_names = ('inflated', 'sphere', 'sphere.reg', 'white')
- if os.getenv('_MNE_FEW_SURFACES', '') != 'true': # for testing
- surf_names = surf_names + (
- 'orig', 'orig_avg', 'inflated_avg', 'inflated_pre', 'pial',
- 'pial_avg', 'smoothwm', 'white_avg', 'sphere.reg.avg')
- for name in surf_names:
+ surf_names = ('inflated', 'sphere', 'sphere.reg', 'white', 'orig',
+ 'orig_avg', 'inflated_avg', 'inflated_pre', 'pial',
+ 'pial_avg', 'smoothwm', 'white_avg', 'sphere.reg.avg')
+ if os.getenv('_MNE_FEW_SURFACES', '') == 'true': # for testing
+ surf_names = surf_names[:4]
+ for surf_name in surf_names:
for hemi in ('lh.', 'rh.'):
- fname = pformat(surf_fname, name=hemi + name)
- surf.append(fname)
+ name = hemi + surf_name
+ path = surf_fname.format(subjects_dir=subjects_dir,
+ subject=subject, name=name)
+ if os.path.exists(path):
+ surf.append(pformat(surf_fname, name=name))
# BEM files
paths['bem'] = bem = []
@@ -638,7 +654,17 @@ def _find_mri_paths(subject='fsaverage', subjects_dir=None):
bem.append(name)
# fiducials
- paths['fid'] = [fid_fname]
+ if skip_fiducials:
+ paths['fid'] = []
+ else:
+ paths['fid'] = _find_fiducials_files(subject, subjects_dir)
+ # check that we found at least one
+ if len(paths['fid']) == 0:
+ raise IOError("No fiducials file found for %s. The fiducials "
+ "file should be named "
+ "{subject}/bem/{subject}-fiducials.fif. In "
+ "order to scale an MRI without fiducials set "
+ "skip_fiducials=True." % subject)
# duplicate curvature files
paths['duplicate'] = dup = []
@@ -648,7 +674,7 @@ def _find_mri_paths(subject='fsaverage', subjects_dir=None):
dup.append(fname)
# check presence of required files
- for ftype in ['surf', 'fid', 'duplicate']:
+ for ftype in ['surf', 'duplicate']:
for fname in paths[ftype]:
path = fname.format(subjects_dir=subjects_dir, subject=subject)
path = os.path.realpath(path)
@@ -669,6 +695,24 @@ def _find_mri_paths(subject='fsaverage', subjects_dir=None):
return paths
+def _find_fiducials_files(subject, subjects_dir):
+ fid = []
+ # standard fiducials
+ if os.path.exists(fid_fname.format(subjects_dir=subjects_dir,
+ subject=subject)):
+ fid.append(fid_fname)
+ # fiducials with subject name
+ pattern = pformat(fid_fname_general, subjects_dir=subjects_dir,
+ subject=subject, head='*')
+ regex = pformat(fid_fname_general, subjects_dir=subjects_dir,
+ subject=subject, head='(.+)')
+ for path in iglob(pattern):
+ match = re.match(regex, path)
+ head = match.group(1).replace(subject, '{subject}')
+ fid.append(pformat(fid_fname_general, head=head))
+ return fid
+
+
def _is_mri_subject(subject, subjects_dir=None):
"""Check whether a directory in subjects_dir is an mri subject directory
@@ -685,12 +729,30 @@ def _is_mri_subject(subject, subjects_dir=None):
Whether ``subject`` is an mri subject.
"""
subjects_dir = get_subjects_dir(subjects_dir, raise_error=True)
-
fname = head_bem_fname.format(subjects_dir=subjects_dir, subject=subject)
- if not os.path.exists(fname):
- return False
+ return os.path.exists(fname)
+
+
+def _is_scaled_mri_subject(subject, subjects_dir=None):
+ """Check whether a directory in subjects_dir is a scaled mri subject
- return True
+ Parameters
+ ----------
+ subject : str
+ Name of the potential subject/directory.
+ subjects_dir : None | str
+ Override the SUBJECTS_DIR environment variable.
+
+ Returns
+ -------
+ is_scaled_mri_subject : bool
+ Whether ``subject`` is a scaled mri subject.
+ """
+ subjects_dir = get_subjects_dir(subjects_dir, raise_error=True)
+ if not _is_mri_subject(subject, subjects_dir):
+ return False
+ fname = os.path.join(subjects_dir, subject, 'MRI scaling parameters.cfg')
+ return os.path.exists(fname)
def _mri_subject_has_bem(subject, subjects_dir=None):
@@ -789,6 +851,22 @@ def _write_mri_config(fname, subject_from, subject_to, scale):
def _scale_params(subject_to, subject_from, scale, subjects_dir):
+ """Assemble parameters for scaling
+
+ Returns
+ -------
+ subjects_dir : str
+ Subjects directory.
+ subject_from : str
+ Name of the source subject.
+ scale : array
+ Scaling factor, either shape=() for uniform scaling or shape=(3,) for
+ non-uniform scaling.
+ nn_scale : None | array
+ Scaling factor for surface normal. If scaling is uniform, normals are
+        unchanged and nn_scale is None. If scaling is non-uniform, nn_scale is
+ an array of shape (3,).
+ """
subjects_dir = get_subjects_dir(subjects_dir, raise_error=True)
if (subject_from is None) != (scale is None):
raise TypeError("Need to provide either both subject_from and scale "
@@ -809,7 +887,15 @@ def _scale_params(subject_to, subject_from, scale, subjects_dir):
raise ValueError("Invalid shape for scale parameer. Need scalar "
"or array of length 3. Got %s." % str(scale))
- return subjects_dir, subject_from, n_params, scale
+ # prepare scaling parameter for normals
+ if n_params == 1:
+ nn_scale = None
+ elif n_params == 3:
+ nn_scale = 1. / scale
+ else:
+ raise RuntimeError("Invalid n_params value: %s" % repr(n_params))
+
+ return subjects_dir, subject_from, scale, nn_scale
def scale_bem(subject_to, bem_name, subject_from=None, scale=None,
@@ -833,9 +919,8 @@ def scale_bem(subject_to, bem_name, subject_from=None, scale=None,
subjects_dir : None | str
Override the SUBJECTS_DIR environment variable.
"""
- subjects_dir, subject_from, _, scale = _scale_params(subject_to,
- subject_from, scale,
- subjects_dir)
+ subjects_dir, subject_from, scale, nn_scale = \
+ _scale_params(subject_to, subject_from, scale, subjects_dir)
src = bem_fname.format(subjects_dir=subjects_dir, subject=subject_from,
name=bem_name)
@@ -846,12 +931,12 @@ def scale_bem(subject_to, bem_name, subject_from=None, scale=None,
raise IOError("File alredy exists: %s" % dst)
surfs = read_bem_surfaces(src)
- if len(surfs) != 1:
- raise NotImplementedError("BEM file with more than one surface: %r"
- % src)
- surf0 = surfs[0]
- surf0['rr'] = surf0['rr'] * scale
- write_bem_surfaces(dst, surf0)
+ for surf in surfs:
+ surf['rr'] *= scale
+ if nn_scale is not None:
+ surf['nn'] *= nn_scale
+ surf['nn'] /= np.sqrt(np.sum(surf['nn'] ** 2, 1))[:, np.newaxis]
+ write_bem_surfaces(dst, surfs)
def scale_labels(subject_to, pattern=None, overwrite=False, subject_from=None,
@@ -915,7 +1000,7 @@ def scale_labels(subject_to, pattern=None, overwrite=False, subject_from=None,
def scale_mri(subject_from, subject_to, scale, overwrite=False,
- subjects_dir=None):
+ subjects_dir=None, skip_fiducials=False):
"""Create a scaled copy of an MRI subject
Parameters
@@ -930,6 +1015,9 @@ def scale_mri(subject_from, subject_to, scale, overwrite=False,
If an MRI already exists for subject_to, overwrite it.
subjects_dir : None | str
Override the SUBJECTS_DIR environment variable.
+ skip_fiducials : bool
+ Do not scale the MRI fiducials. If False (default), an IOError will be
+ raised if no fiducials file can be found.
See Also
--------
@@ -937,7 +1025,7 @@ def scale_mri(subject_from, subject_to, scale, overwrite=False,
scale_source_space : add a source space to a scaled MRI
"""
subjects_dir = get_subjects_dir(subjects_dir, raise_error=True)
- paths = _find_mri_paths(subject_from, subjects_dir=subjects_dir)
+ paths = _find_mri_paths(subject_from, skip_fiducials, subjects_dir)
scale = np.asarray(scale)
# make sure we have an empty target directory
@@ -950,6 +1038,7 @@ def scale_mri(subject_from, subject_to, scale, overwrite=False,
raise IOError("Subject directory for %s already exists: %r"
% (subject_to, dest))
+ # create empty directory structure
for dirname in paths['dirs']:
dir_ = dirname.format(subject=subject_to, subjects_dir=subjects_dir)
os.makedirs(dir_)
@@ -1024,10 +1113,8 @@ def scale_source_space(subject_to, src_name, subject_from=None, scale=None,
applies if scale is an array of length 3, and will not use more cores
than there are source spaces).
"""
- subjects_dir, subject_from, n_params, scale = _scale_params(subject_to,
- subject_from,
- scale,
- subjects_dir)
+ subjects_dir, subject_from, scale, nn_scale = \
+ _scale_params(subject_to, subject_from, scale, subjects_dir)
# find the source space file names
if src_name.isdigit():
@@ -1047,15 +1134,6 @@ def scale_source_space(subject_to, src_name, subject_from=None, scale=None,
dst = src_pattern.format(subjects_dir=subjects_dir, subject=subject_to,
spacing=spacing)
- # prepare scaling parameters
- if n_params == 1:
- norm_scale = None
- elif n_params == 3:
- norm_scale = 1. / scale
- else:
- raise RuntimeError("Invalid n_params entry in MRI cfg file: %s"
- % str(n_params))
-
# read and scale the source space [in m]
sss = read_source_spaces(src)
logger.info("scaling source space %s: %s -> %s", spacing, subject_from,
@@ -1067,16 +1145,14 @@ def scale_source_space(subject_to, src_name, subject_from=None, scale=None,
ss['rr'] *= scale
# distances and patch info
- if norm_scale is None:
+ if nn_scale is None: # i.e. uniform scaling
if ss['dist'] is not None:
ss['dist'] *= scale
ss['nearest_dist'] *= scale
ss['dist_limit'] *= scale
- else:
- nn = ss['nn']
- nn *= norm_scale
- norm = np.sqrt(np.sum(nn ** 2, 1))
- nn /= norm[:, np.newaxis]
+ else: # non-uniform scaling
+ ss['nn'] *= nn_scale
+ ss['nn'] /= np.sqrt(np.sum(ss['nn'] ** 2, 1))[:, np.newaxis]
if ss['dist'] is not None:
add_dist = True
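
(Aside: a numpy-only sketch of the surface-normal correction shared by
scale_bem and scale_source_space: under non-uniform scaling, normals scale
with the inverse factors and are then renormalized. All names are local to the
example.)

    import numpy as np

    scale = np.array([1.1, 0.9, 1.0])  # non-uniform position scaling
    nn = np.random.randn(10, 3)
    nn /= np.linalg.norm(nn, axis=1)[:, np.newaxis]   # unit normals

    nn_scale = 1. / scale
    nn = nn * nn_scale                                # transform normals
    nn /= np.sqrt(np.sum(nn ** 2, 1))[:, np.newaxis]  # re-unit-normalize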
diff --git a/mne/cov.py b/mne/cov.py
index 3cbc947..292ed25 100644
--- a/mne/cov.py
+++ b/mne/cov.py
@@ -32,8 +32,8 @@ from .epochs import Epochs
from .event import make_fixed_length_events
from .utils import (check_fname, logger, verbose, estimate_rank,
_compute_row_norms, check_version, _time_mask, warn,
- _check_copy_dep)
-from .fixes import in1d
+ copy_function_doc_to_method_doc)
+from . import viz
from .externals.six.moves import zip
from .externals.six import string_types
@@ -163,16 +163,9 @@ class Covariance(dict):
"""
return cp.deepcopy(self)
- def as_diag(self, copy=None):
+ def as_diag(self):
"""Set covariance to be processed as being diagonal.
- Parameters
- ----------
- copy : bool
- This parameter has been deprecated and will be removed in 0.13.
- Use inst.copy() instead.
- Whether to return a new instance or modify in place.
-
Returns
-------
cov : dict
@@ -183,14 +176,13 @@ class Covariance(dict):
This function allows creation of inverse operators
equivalent to using the old "--diagnoise" mne option.
"""
- cov = _check_copy_dep(self, copy, default=True)
- if cov['diag']:
- return cov
- cov['diag'] = True
- cov['data'] = np.diag(cov['data'])
- cov['eig'] = None
- cov['eigvec'] = None
- return cov
+ if self['diag']:
+ return self
+ self['diag'] = True
+ self['data'] = np.diag(self['data'])
+ self['eig'] = None
+ self['eigvec'] = None
+ return self
def __repr__(self):
if self.data.ndim == 2:
@@ -227,38 +219,11 @@ class Covariance(dict):
return self
@verbose
+ @copy_function_doc_to_method_doc(viz.misc.plot_cov)
def plot(self, info, exclude=[], colorbar=True, proj=False, show_svd=True,
show=True, verbose=None):
- """Plot Covariance data.
-
- Parameters
- ----------
- info: dict
- Measurement info.
- exclude : list of string | str
- List of channels to exclude. If empty do not exclude any channel.
- If 'bads', exclude info['bads'].
- colorbar : bool
- Show colorbar or not.
- proj : bool
- Apply projections or not.
- show_svd : bool
- Plot also singular values of the noise covariance for each sensor
- type. We show square roots ie. standard deviations.
- show : bool
- Call pyplot.show() as the end or not.
- verbose : bool, str, int, or None
- If not None, override default verbose level (see mne.verbose).
-
- Returns
- -------
- fig_cov : instance of matplotlib.pyplot.Figure
- The covariance plot.
- fig_svd : instance of matplotlib.pyplot.Figure | None
- The SVD spectra plot of the covariance.
- """
- from .viz.misc import plot_cov
- return plot_cov(self, info, exclude, colorbar, proj, show_svd, show)
+ return viz.misc.plot_cov(self, info, exclude, colorbar, proj, show_svd,
+ show, verbose)
###############################################################################
@@ -466,13 +431,15 @@ def compute_raw_covariance(raw, tmin=0, tmax=None, tstep=0.2, reject=None,
if picks is None:
# Need to include all channels e.g. if eog rejection is to be used
picks = np.arange(raw.info['nchan'])
- pick_mask = in1d(
+ pick_mask = np.in1d(
picks, _pick_data_channels(raw.info, with_ref_meg=False))
else:
pick_mask = slice(None)
epochs = Epochs(raw, events, 1, 0, tstep_m1, baseline=None,
picks=picks, reject=reject, flat=flat, verbose=False,
- preload=False, proj=False)
+ preload=False, proj=False, add_eeg_ref=False)
+ if method is None:
+ method = 'empirical'
if isinstance(method, string_types) and method == 'empirical':
# potentially *much* more memory efficient to do it the iterative way
picks = picks[pick_mask]
@@ -486,8 +453,7 @@ def compute_raw_covariance(raw, tmin=0, tmax=None, tstep=0.2, reject=None,
data += np.dot(raw_segment, raw_segment.T)
n_samples += raw_segment.shape[1]
_check_n_samples(n_samples, len(picks))
- mu /= n_samples
- data -= n_samples * mu[:, None] * mu[None, :]
+ data -= mu[:, None] * (mu[None, :] / n_samples)
data /= (n_samples - 1.0)
logger.info("Number of samples used : %d" % n_samples)
logger.info('[done]')
@@ -514,7 +480,7 @@ def compute_raw_covariance(raw, tmin=0, tmax=None, tstep=0.2, reject=None,
def compute_covariance(epochs, keep_sample_mean=True, tmin=None, tmax=None,
projs=None, method='empirical', method_params=None,
cv=3, scalings=None, n_jobs=1, return_estimators=False,
- verbose=None):
+ on_mismatch='raise', verbose=None):
"""Estimate noise covariance matrix from epochs.
The noise covariance is typically estimated on pre-stim periods
@@ -622,6 +588,14 @@ def compute_covariance(epochs, keep_sample_mean=True, tmin=None, tmax=None,
return_estimators : bool (default False)
Whether to return all estimators or the best. Only considered if
method equals 'auto' or is a list of str. Defaults to False
+ on_mismatch : str
+ What to do when the MEG<->Head transformations do not match between
+ epochs. If "raise" (default) an error is raised, if "warn" then a
+ warning is emitted, if "ignore" then nothing is printed. Having
+ mismatched transforms can in some cases lead to unexpected or
+ unstable results in covariance calculation, e.g. when data
+ have been processed with Maxwell filtering but not transformed
+ to the same head position.
verbose : bool | str | int | or None (default None)
If not None, override default verbose level (see mne.verbose).
@@ -654,7 +628,7 @@ def compute_covariance(epochs, keep_sample_mean=True, tmin=None, tmax=None,
accepted_methods = ('auto', 'empirical', 'diagonal_fixed', 'ledoit_wolf',
'shrunk', 'pca', 'factor_analysis',)
msg = ('Invalid method ({method}). Accepted values (individually or '
- 'in a list) are "%s"' % '" or "'.join(accepted_methods + ('None',)))
+ 'in a list) are "%s" or None.' % '" or "'.join(accepted_methods))
# scale to natural unit for best stability with MEG/EEG
if isinstance(scalings, dict):
@@ -698,14 +672,30 @@ def compute_covariance(epochs, keep_sample_mean=True, tmin=None, tmax=None,
epochs = sum([_unpack_epochs(epoch) for epoch in epochs], [])
# check for baseline correction
- for epochs_t in epochs:
- if epochs_t.baseline is None and epochs_t.info['highpass'] < 0.5 and \
- keep_sample_mean:
- warn('Epochs are not baseline corrected, covariance '
- 'matrix may be inaccurate')
-
- for epoch in epochs:
+ if any(epochs_t.baseline is None and epochs_t.info['highpass'] < 0.5 and
+ keep_sample_mean for epochs_t in epochs):
+ warn('Epochs are not baseline corrected, covariance '
+ 'matrix may be inaccurate')
+
+ orig = epochs[0].info['dev_head_t']
+ if not isinstance(on_mismatch, string_types) or \
+ on_mismatch not in ['raise', 'warn', 'ignore']:
+ raise ValueError('on_mismatch must be "raise", "warn", or "ignore", '
+ 'got %s' % on_mismatch)
+ for ei, epoch in enumerate(epochs):
epoch.info._check_consistency()
+ if (orig is None) != (epoch.info['dev_head_t'] is None) or \
+ (orig is not None and not
+ np.allclose(orig['trans'],
+ epoch.info['dev_head_t']['trans'])):
+ msg = ('MEG<->Head transform mismatch between epochs[0]:\n%s\n\n'
+ 'and epochs[%s]:\n%s'
+ % (orig, ei, epoch.info['dev_head_t']))
+ if on_mismatch == 'raise':
+ raise ValueError(msg)
+ elif on_mismatch == 'warn':
+ warn(msg)
+
bads = epochs[0].info['bads']
if projs is None:
projs = cp.deepcopy(epochs[0].info['projs'])
@@ -732,7 +722,9 @@ def compute_covariance(epochs, keep_sample_mean=True, tmin=None, tmax=None,
ch_names = [epochs[0].ch_names[k] for k in picks_meeg]
info = epochs[0].info # we will overwrite 'epochs'
- if method == 'auto':
+ if method is None:
+ method = ['empirical']
+ elif method == 'auto':
method = ['shrunk', 'diagonal_fixed', 'empirical', 'factor_analysis']
if not isinstance(method, (list, tuple)):
@@ -1225,7 +1217,7 @@ def prepare_noise_cov(noise_cov, info, ch_names, rank=None,
Parameters
----------
- noise_cov : Covariance
+ noise_cov : instance of Covariance
The noise covariance to process.
info : dict
The measurement info (used to get channel types and bad channels).
@@ -1238,18 +1230,25 @@ def prepare_noise_cov(noise_cov, info, ch_names, rank=None,
to specify the rank for each modality.
scalings : dict | None
Data will be rescaled before rank estimation to improve accuracy.
- If dict, it will override the following dict (default if None):
+ If dict, it will override the following dict (default if None)::
dict(mag=1e12, grad=1e11, eeg=1e5)
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
+
+ Returns
+ -------
+ cov : instance of Covariance
+ A copy of the covariance with the good channels subselected
+ and parameters updated.
"""
- C_ch_idx = [noise_cov.ch_names.index(c) for c in ch_names]
- if noise_cov['diag'] is False:
- C = noise_cov.data[np.ix_(C_ch_idx, C_ch_idx)]
+ noise_cov_idx = [noise_cov.ch_names.index(c) for c in ch_names]
+ n_chan = len(ch_names)
+ if not noise_cov['diag']:
+ C = noise_cov.data[np.ix_(noise_cov_idx, noise_cov_idx)]
else:
- C = np.diag(noise_cov.data[C_ch_idx])
+ C = np.diag(noise_cov.data[noise_cov_idx])
scalings = _handle_default('scalings_cov_rank', scalings)
@@ -1260,17 +1259,33 @@ def prepare_noise_cov(noise_cov, info, ch_names, rank=None,
% ncomp)
C = np.dot(proj, np.dot(C, proj.T))
- pick_meg = pick_types(info, meg=True, eeg=False, ref_meg=False,
- exclude='bads')
- pick_eeg = pick_types(info, meg=False, eeg=True, ref_meg=False,
- exclude='bads')
- meg_names = [info['chs'][k]['ch_name'] for k in pick_meg]
- C_meg_idx = [k for k in range(len(C)) if ch_names[k] in meg_names]
- eeg_names = [info['chs'][k]['ch_name'] for k in pick_eeg]
- C_eeg_idx = [k for k in range(len(C)) if ch_names[k] in eeg_names]
-
- has_meg = len(C_meg_idx) > 0
- has_eeg = len(C_eeg_idx) > 0
+ info_pick_meg = pick_types(info, meg=True, eeg=False, ref_meg=False,
+ exclude='bads')
+ info_pick_eeg = pick_types(info, meg=False, eeg=True, ref_meg=False,
+ exclude='bads')
+ info_meg_names = [info['chs'][k]['ch_name'] for k in info_pick_meg]
+ out_meg_idx = [k for k in range(len(C)) if ch_names[k] in info_meg_names]
+ info_eeg_names = [info['chs'][k]['ch_name'] for k in info_pick_eeg]
+ out_eeg_idx = [k for k in range(len(C)) if ch_names[k] in info_eeg_names]
+ # re-index based on ch_names order
+ del info_pick_meg, info_pick_eeg
+ meg_names = [ch_names[k] for k in out_meg_idx]
+ eeg_names = [ch_names[k] for k in out_eeg_idx]
+ if len(meg_names) > 0:
+ info_pick_meg = pick_channels(info['ch_names'], meg_names)
+ else:
+ info_pick_meg = []
+ if len(eeg_names) > 0:
+ info_pick_eeg = pick_channels(info['ch_names'], eeg_names)
+ else:
+ info_pick_eeg = []
+ assert len(info_pick_meg) == len(meg_names) == len(out_meg_idx)
+ assert len(info_pick_eeg) == len(eeg_names) == len(out_eeg_idx)
+ assert(len(out_meg_idx) + len(out_eeg_idx) == n_chan)
+ eigvec = np.zeros((n_chan, n_chan))
+ eig = np.zeros(n_chan)
+ has_meg = len(out_meg_idx) > 0
+ has_eeg = len(out_eeg_idx) > 0
# Get the specified noise covariance rank
if rank is not None:
@@ -1284,45 +1299,26 @@ def prepare_noise_cov(noise_cov, info, ch_names, rank=None,
rank_meg, rank_eeg = None, None
if has_meg:
- C_meg = C[np.ix_(C_meg_idx, C_meg_idx)]
- this_info = pick_info(info, pick_meg)
+ C_meg = C[np.ix_(out_meg_idx, out_meg_idx)]
+ this_info = pick_info(info, info_pick_meg)
if rank_meg is None:
- if len(C_meg_idx) < len(pick_meg):
- this_info = pick_info(info, C_meg_idx)
rank_meg = _estimate_rank_meeg_cov(C_meg, this_info, scalings)
- C_meg_eig, C_meg_eigvec = _get_ch_whitener(C_meg, False, 'MEG',
- rank_meg)
+ eig[out_meg_idx], eigvec[np.ix_(out_meg_idx, out_meg_idx)] = \
+ _get_ch_whitener(C_meg, False, 'MEG', rank_meg)
if has_eeg:
- C_eeg = C[np.ix_(C_eeg_idx, C_eeg_idx)]
- this_info = pick_info(info, pick_eeg)
+ C_eeg = C[np.ix_(out_eeg_idx, out_eeg_idx)]
+ this_info = pick_info(info, info_pick_eeg)
if rank_eeg is None:
- if len(C_meg_idx) < len(pick_meg):
- this_info = pick_info(info, C_eeg_idx)
rank_eeg = _estimate_rank_meeg_cov(C_eeg, this_info, scalings)
- C_eeg_eig, C_eeg_eigvec = _get_ch_whitener(C_eeg, False, 'EEG',
- rank_eeg)
+ eig[out_eeg_idx], eigvec[np.ix_(out_eeg_idx, out_eeg_idx)], = \
+ _get_ch_whitener(C_eeg, False, 'EEG', rank_eeg)
if _needs_eeg_average_ref_proj(info):
warn('No average EEG reference present in info["projs"], covariance '
'may be adversely affected. Consider recomputing covariance using'
' a raw file with an average eeg reference projector added.')
-
- n_chan = len(ch_names)
- eigvec = np.zeros((n_chan, n_chan), dtype=np.float)
- eig = np.zeros(n_chan, dtype=np.float)
-
- if has_meg:
- eigvec[np.ix_(C_meg_idx, C_meg_idx)] = C_meg_eigvec
- eig[C_meg_idx] = C_meg_eig
- if has_eeg:
- eigvec[np.ix_(C_eeg_idx, C_eeg_idx)] = C_eeg_eigvec
- eig[C_eeg_idx] = C_eeg_eig
-
- assert(len(C_meg_idx) + len(C_eeg_idx) == n_chan)
-
noise_cov = cp.deepcopy(noise_cov)
noise_cov.update(data=C, eig=eig, eigvec=eigvec, dim=len(ch_names),
diag=False, names=ch_names)
-
return noise_cov
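The refactor above pre-allocates full-size eig and eigvec arrays and scatters each channel type's eigendecomposition into its own block, rather than building per-type arrays and merging them afterwards. A minimal standalone sketch of that fill pattern, with toy indices and identity matrices standing in for the per-type covariances (and assuming, as the in-place assignments above suggest, eigenvectors stored as rows):

    import numpy as np

    n_chan = 5
    meg_idx, eeg_idx = [0, 1, 2], [3, 4]
    eig = np.zeros(n_chan)
    eigvec = np.zeros((n_chan, n_chan))
    for idx in (meg_idx, eeg_idx):
        block = np.eye(len(idx))            # stand-in for C[np.ix_(idx, idx)]
        vals, vecs = np.linalg.eigh(block)
        eig[idx] = vals                     # fill eigenvalues in place
        eigvec[np.ix_(idx, idx)] = vecs.T   # fill the block in place
    assert len(meg_idx) + len(eeg_idx) == n_chan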
@@ -1630,13 +1626,17 @@ def _get_whitener_data(info, noise_cov, picks, diag=False, rank=None,
"""Get whitening matrix for a set of data."""
ch_names = [info['ch_names'][k] for k in picks]
noise_cov = pick_channels_cov(noise_cov, include=ch_names, exclude=[])
- info = pick_info(info, picks)
+ if len(noise_cov['data']) != len(ch_names):
+ missing = list(set(ch_names) - set(noise_cov['names']))
+ raise RuntimeError('Not all channels present in noise covariance:\n%s'
+ % missing)
if diag:
noise_cov = cp.deepcopy(noise_cov)
noise_cov['data'] = np.diag(np.diag(noise_cov['data']))
scalings = _handle_default('scalings_cov_rank', scalings)
- W = compute_whitener(noise_cov, info, rank=rank, scalings=scalings)[0]
+ W = compute_whitener(noise_cov, info, picks=picks, rank=rank,
+ scalings=scalings)[0]
return W
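The added length check turns a silent channel mismatch into an explicit error. The guard in isolation (channel names here are illustrative):

    def check_cov_channels(ch_names, cov_names):
        # Raise if picked channels are missing from the covariance.
        missing = list(set(ch_names) - set(cov_names))
        if missing:
            raise RuntimeError('Not all channels present in noise '
                               'covariance:\n%s' % missing)

    check_cov_channels(['MEG 0111', 'EEG 001'], ['MEG 0111'])  # raises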
@@ -1907,9 +1907,6 @@ def _estimate_rank_meeg_signals(data, info, scalings, tol='auto',
return_singular : bool
If True, also return the singular values that were used
to determine the rank.
- copy : bool
- If False, values in data will be modified in-place during
- rank estimation (saves memory).
Returns
-------
@@ -1935,7 +1932,7 @@ def _estimate_rank_meeg_signals(data, info, scalings, tol='auto',
def _estimate_rank_meeg_cov(data, info, scalings, tol='auto',
return_singular=False):
- """Estimate rank for M/EEG data.
+ """Estimate rank of M/EEG covariance data, given the covariance
Parameters
----------
diff --git a/mne/cuda.py b/mne/cuda.py
index 02ae626..f8d729e 100644
--- a/mne/cuda.py
+++ b/mne/cuda.py
@@ -5,7 +5,7 @@
import numpy as np
from scipy.fftpack import fft, ifft, rfft, irfft
-from .utils import sizeof_fmt, logger, get_config, warn
+from .utils import sizeof_fmt, logger, get_config, warn, _explain_exception
# Support CUDA for FFTs; requires scikits.cuda and pycuda
@@ -74,7 +74,7 @@ def init_cuda(ignore_config=False):
import pycuda.autoinit # noqa
except ImportError:
warn('pycuda.autoinit could not be imported, likely a hardware error, '
- 'CUDA not enabled')
+ 'CUDA not enabled%s' % _explain_exception())
return
# Make sure scikit-cuda is installed
cudafft = _get_cudafft()
@@ -94,8 +94,9 @@ def init_cuda(ignore_config=False):
# Make sure we can use 64-bit FFTs
try:
cudafft.Plan(16, np.float64, np.complex128) # will get auto-GC'ed
- except:
- warn('Device does not support 64-bit FFTs, CUDA not enabled')
+ except Exception:
+ warn('Device does not appear to support 64-bit FFTs, CUDA not '
+ 'enabled%s' % _explain_exception())
return
_cuda_capable = True
# Figure out limit for CUDA FFT calculations
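Both warnings now append _explain_exception() so the triggering traceback is not silently dropped by the broad except. As a rough sketch only (not the actual mne.utils implementation), such a helper can be built on the standard traceback module:

    import traceback

    def explain_exception(prefix='> '):
        # Format the exception currently being handled so it can be
        # appended to a warning message via '%s'.
        lines = traceback.format_exc().rstrip().split('\n')
        return ':\n' + '\n'.join(prefix + line for line in lines)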
diff --git a/mne/data/coil_def.dat b/mne/data/coil_def.dat
index 13bd7b4..3be7fff 100644
--- a/mne/data/coil_def.dat
+++ b/mne/data/coil_def.dat
@@ -33,7 +33,7 @@
#
# Produced with:
#
-# mne_list_coil_def version 1.12 compiled at Jan 13 2015 18:20:15
+# mne_list_coil_def version 1.12 compiled at Jul 12 2016 18:39:09
#
3 2 0 2 2.789e-02 1.620e-02 "Neuromag-122 planar gradiometer size = 27.89 mm base = 16.20 mm"
61.7284 8.100e-03 0.000e+00 0.000e+00 0.000 0.000 1.000
@@ -366,6 +366,21 @@
-0.1250 3.164e-03 -5.480e-03 5.000e-02 0.000 0.000 1.000
-0.1250 -3.164e-03 5.480e-03 5.000e-02 0.000 0.000 1.000
-0.1250 -3.164e-03 -5.480e-03 5.000e-02 0.000 0.000 1.000
+1 6002 0 1 1.550e-02 0.000e+00 "MIT KIT system reference magnetometer size = 15.50 mm"
+ 1.0000 0.000e+00 0.000e+00 0.000e+00 0.000 0.000 1.000
+1 6002 1 4 1.550e-02 0.000e+00 "MIT KIT system reference magnetometer size = 15.50 mm"
+ 0.2500 3.875e-03 3.875e-03 0.000e+00 0.000 0.000 1.000
+ 0.2500 -3.875e-03 3.875e-03 0.000e+00 0.000 0.000 1.000
+ 0.2500 -3.875e-03 -3.875e-03 0.000e+00 0.000 0.000 1.000
+ 0.2500 3.875e-03 -3.875e-03 0.000e+00 0.000 0.000 1.000
+1 6002 2 7 1.550e-02 0.000e+00 "MIT KIT system reference magnetometer size = 15.50 mm"
+ 0.2500 0.000e+00 0.000e+00 0.000e+00 0.000 0.000 1.000
+ 0.1250 6.328e-03 0.000e+00 0.000e+00 0.000 0.000 1.000
+ 0.1250 -6.328e-03 0.000e+00 0.000e+00 0.000 0.000 1.000
+ 0.1250 3.164e-03 5.480e-03 0.000e+00 0.000 0.000 1.000
+ 0.1250 3.164e-03 -5.480e-03 0.000e+00 0.000 0.000 1.000
+ 0.1250 -3.164e-03 5.480e-03 0.000e+00 0.000 0.000 1.000
+ 0.1250 -3.164e-03 -5.480e-03 0.000e+00 0.000 0.000 1.000
2 7001 0 2 6.000e-03 5.000e-02 "BabySQUID system gradiometer size = 6.00 mm base = 50.00 mm"
1.0000 0.000e+00 0.000e+00 0.000e+00 0.000 0.000 1.000
-1.0000 0.000e+00 0.000e+00 5.000e-02 0.000 0.000 1.000
diff --git a/mne/data/mne_analyze.sel b/mne/data/mne_analyze.sel
index b0e9034..ae4bf34 100644
--- a/mne/data/mne_analyze.sel
+++ b/mne/data/mne_analyze.sel
@@ -10,4 +10,10 @@ Left-occipital:MEG 2042|MEG 2043|MEG 1913|MEG 1912|MEG 2113|MEG 2112|MEG 1922|ME
Right-occipital:MEG 2032|MEG 2033|MEG 2313|MEG 2312|MEG 2342|MEG 2343|MEG 2322|MEG 2323|MEG 2433|MEG 2432|MEG 2122|MEG 2123|MEG 2333|MEG 2332|MEG 2513|MEG 2512|MEG 2523|MEG 2522|MEG 2133|MEG 2132|MEG 2542|MEG 2543|MEG 2532|MEG 2533|MEG 2031|MEG 2311|MEG 2341|MEG 2321|MEG 2431|MEG 2121|MEG 2331|MEG 2511|MEG 2521|MEG 2131|MEG 2541|MEG 2531
Left-frontal:MEG 0522|MEG 0523|MEG 0512|MEG 0513|MEG 0312|MEG 0313|MEG 0342|MEG 0343|MEG 0122|MEG 0123|MEG 0822|MEG 0823|MEG 0533|MEG 0532|MEG 0543|MEG 0542|MEG 0322|MEG 0323|MEG 0612|MEG 0613|MEG 0333|MEG 0332|MEG 0622|MEG 0623|MEG 0643|MEG 0642|MEG 0521|MEG 0511|MEG 0311|MEG 0341|MEG 0121|MEG 0821|MEG 0531|MEG 0541|MEG 0321|MEG 0611|MEG 0331|MEG 0621|MEG 0641
Right-frontal:MEG 0813|MEG 0812|MEG 0912|MEG 0913|MEG 0922|MEG 0923|MEG 1212|MEG 1213|MEG 1223|MEG 1222|MEG 1412|MEG 1413|MEG 0943|MEG 0942|MEG 0933|MEG 0932|MEG 1232|MEG 1233|MEG 1012|MEG 1013|MEG 1022|MEG 1023|MEG 1243|MEG 1242|MEG 1033|MEG 1032|MEG 0811|MEG 0911|MEG 0921|MEG 1211|MEG 1221|MEG 1411|MEG 0941|MEG 0931|MEG 1231|MEG 1011|MEG 1021|MEG 1241|MEG 1031
-
+#
+# EEG in groups of 32 channels
+#
+EEG 1-32:EEG 001|EEG 002|EEG 003|EEG 004|EEG 005|EEG 006|EEG 007|EEG 008|EEG 009|EEG 010|EEG 011|EEG 012|EEG 013|EEG 014|EEG 015|EEG 016|EEG 017|EEG 018|EEG 019|EEG 020|EEG 021|EEG 022|EEG 023|EEG 024|EEG 025|EEG 026|EEG 027|EEG 028|EEG 029|EEG 030|EEG 031|EEG 032
+EEG 33-64:EEG 033|EEG 034|EEG 035|EEG 036|EEG 037|EEG 038|EEG 039|EEG 040|EEG 041|EEG 042|EEG 043|EEG 044|EEG 045|EEG 046|EEG 047|EEG 048|EEG 049|EEG 050|EEG 051|EEG 052|EEG 053|EEG 054|EEG 055|EEG 056|EEG 057|EEG 058|EEG 059|EEG 060|EEG 061|EEG 062|EEG 063|EEG 064
+EEG 65-96:EEG 065|EEG 066|EEG 067|EEG 068|EEG 069|EEG 070|EEG 071|EEG 072|EEG 073|EEG 074|EEG 075|EEG 076|EEG 077|EEG 078|EEG 079|EEG 080|EEG 081|EEG 082|EEG 083|EEG 084|EEG 085|EEG 086|EEG 087|EEG 088|EEG 089|EEG 090|EEG 091|EEG 092|EEG 093|EEG 094|EEG 095|EEG 096
+EEG 97-128:EEG 097|EEG 098|EEG 099|EEG 100|EEG 101|EEG 102|EEG 103|EEG 104|EEG 105|EEG 106|EEG 107|EEG 108|EEG 109|EEG 110|EEG 111|EEG 112|EEG 113|EEG 114|EEG 115|EEG 116|EEG 117|EEG 118|EEG 119|EEG 120|EEG 121|EEG 122|EEG 123|EEG 124|EEG 125|EEG 126|EEG 127|EEG 128
\ No newline at end of file
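The four new selections follow a strict 32-channel pattern; a hypothetical throwaway snippet like this could have generated the lines:

    for start in range(1, 129, 32):
        chans = '|'.join('EEG %03d' % k for k in range(start, start + 32))
        print('EEG %d-%d:%s' % (start, start + 31, chans))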
diff --git a/mne/datasets/__init__.py b/mne/datasets/__init__.py
index bc86467..b97da34 100644
--- a/mne/datasets/__init__.py
+++ b/mne/datasets/__init__.py
@@ -7,6 +7,7 @@ from . import megsim
from . import misc
from . import sample
from . import somato
+from . import multimodal
from . import spm_face
from . import testing
from . import _fake
diff --git a/mne/datasets/brainstorm/__init__.py b/mne/datasets/brainstorm/__init__.py
index eb985dc..9e129ad 100644
--- a/mne/datasets/brainstorm/__init__.py
+++ b/mne/datasets/brainstorm/__init__.py
@@ -1,4 +1,5 @@
"""Brainstorm Dataset
"""
-from . import bst_raw, bst_resting, bst_auditory
+from . import (bst_raw, bst_resting, bst_auditory, bst_phantom_ctf,
+ bst_phantom_elekta)
diff --git a/mne/datasets/brainstorm/bst_auditory.py b/mne/datasets/brainstorm/bst_auditory.py
index af8bcc9..22af474 100644
--- a/mne/datasets/brainstorm/bst_auditory.py
+++ b/mne/datasets/brainstorm/bst_auditory.py
@@ -2,9 +2,10 @@
#
# License: BSD (3-clause)
+from functools import partial
import os.path as op
+
from ...utils import verbose
-from ...fixes import partial
from ..utils import (has_dataset, _data_path, _get_version, _version_doc,
_data_path_doc)
diff --git a/mne/datasets/brainstorm/bst_resting.py b/mne/datasets/brainstorm/bst_phantom_ctf.py
similarity index 71%
copy from mne/datasets/brainstorm/bst_resting.py
copy to mne/datasets/brainstorm/bst_phantom_ctf.py
index 8e999e0..ab10b63 100644
--- a/mne/datasets/brainstorm/bst_resting.py
+++ b/mne/datasets/brainstorm/bst_phantom_ctf.py
@@ -1,32 +1,31 @@
-# Authors: Mainak Jas <mainak.jas at telecom-paristech.fr>
+# Authors: Eric Larson <larson.eric.d at gmail.com>
#
# License: BSD (3-clause)
+from functools import partial
import os.path as op
+
from ...utils import verbose
-from ...fixes import partial
from ..utils import (has_dataset, _data_path, _get_version, _version_doc,
_data_path_doc)
has_brainstorm_data = partial(has_dataset, name='brainstorm')
+
_description = u"""
-URL: http://neuroimage.usc.edu/brainstorm/DatasetResting
- - One subject
- - Two runs of 10 min of resting state recordings
- - Eyes open
+URL: http://neuroimage.usc.edu/brainstorm/Tutorials/PhantomCtf
"""
@verbose
def data_path(path=None, force_update=False, update_path=True, download=True,
verbose=None):
- archive_name = dict(brainstorm='bst_resting.tar.gz')
+ archive_name = dict(brainstorm='bst_phantom_ctf.tar.gz')
data_path = _data_path(path=path, force_update=force_update,
update_path=update_path, name='brainstorm',
download=download, archive_name=archive_name)
if data_path != '':
- return op.join(data_path, 'bst_resting')
+ return op.join(data_path, 'bst_phantom_ctf')
else:
return data_path
@@ -34,7 +33,7 @@ _data_path_doc = _data_path_doc.format(name='brainstorm',
conf='MNE_DATASETS_BRAINSTORM_DATA'
'_PATH')
_data_path_doc = _data_path_doc.replace('brainstorm dataset',
- 'brainstorm (bst_resting) dataset')
+ 'brainstorm (bst_phantom_ctf) dataset')
data_path.__doc__ = _data_path_doc
@@ -45,7 +44,7 @@ get_version.__doc__ = _version_doc.format(name='brainstorm')
def description():
- """Get description of brainstorm (bst_resting) dataset
+ """Get description of brainstorm (bst_phantom_ctf) dataset
"""
for desc in _description.splitlines():
print(desc)
diff --git a/mne/datasets/brainstorm/bst_resting.py b/mne/datasets/brainstorm/bst_phantom_elekta.py
similarity index 71%
copy from mne/datasets/brainstorm/bst_resting.py
copy to mne/datasets/brainstorm/bst_phantom_elekta.py
index 8e999e0..3c66b8b 100644
--- a/mne/datasets/brainstorm/bst_resting.py
+++ b/mne/datasets/brainstorm/bst_phantom_elekta.py
@@ -1,32 +1,31 @@
-# Authors: Mainak Jas <mainak.jas at telecom-paristech.fr>
+# Authors: Eric Larson <larson.eric.d at gmail.com>
#
# License: BSD (3-clause)
+from functools import partial
import os.path as op
+
from ...utils import verbose
-from ...fixes import partial
from ..utils import (has_dataset, _data_path, _get_version, _version_doc,
_data_path_doc)
has_brainstorm_data = partial(has_dataset, name='brainstorm')
+
_description = u"""
-URL: http://neuroimage.usc.edu/brainstorm/DatasetResting
- - One subject
- - Two runs of 10 min of resting state recordings
- - Eyes open
+URL: http://neuroimage.usc.edu/brainstorm/Tutorials/PhantomElekta
"""
@verbose
def data_path(path=None, force_update=False, update_path=True, download=True,
verbose=None):
- archive_name = dict(brainstorm='bst_resting.tar.gz')
+ archive_name = dict(brainstorm='bst_phantom_elekta.tar.gz')
data_path = _data_path(path=path, force_update=force_update,
update_path=update_path, name='brainstorm',
download=download, archive_name=archive_name)
if data_path != '':
- return op.join(data_path, 'bst_resting')
+ return op.join(data_path, 'bst_phantom_elekta')
else:
return data_path
@@ -34,7 +33,8 @@ _data_path_doc = _data_path_doc.format(name='brainstorm',
conf='MNE_DATASETS_BRAINSTORM_DATA'
'_PATH')
_data_path_doc = _data_path_doc.replace('brainstorm dataset',
- 'brainstorm (bst_resting) dataset')
+ 'brainstorm (bst_phantom_elekta) '
+ 'dataset')
data_path.__doc__ = _data_path_doc
@@ -45,7 +45,7 @@ get_version.__doc__ = _version_doc.format(name='brainstorm')
def description():
- """Get description of brainstorm (bst_resting) dataset
+ """Get description of brainstorm (bst_phantom_elekta) dataset
"""
for desc in _description.splitlines():
print(desc)
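The new phantom modules expose the same tiny API as the other Brainstorm datasets; a hedged usage sketch:

    from mne.datasets.brainstorm import bst_phantom_ctf, bst_phantom_elekta

    # Downloads on first call (the Brainstorm license must be accepted)
    # and returns the local directory, e.g. .../bst_phantom_elekta.
    path = bst_phantom_elekta.data_path()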
diff --git a/mne/datasets/brainstorm/bst_raw.py b/mne/datasets/brainstorm/bst_raw.py
index dc3a187..6d94854 100644
--- a/mne/datasets/brainstorm/bst_raw.py
+++ b/mne/datasets/brainstorm/bst_raw.py
@@ -2,9 +2,10 @@
#
# License: BSD (3-clause)
+from functools import partial
import os.path as op
+
from ...utils import verbose
-from ...fixes import partial
from ..utils import (has_dataset, _data_path, _get_version, _version_doc,
_data_path_doc)
diff --git a/mne/datasets/brainstorm/bst_resting.py b/mne/datasets/brainstorm/bst_resting.py
index 8e999e0..2077fd3 100644
--- a/mne/datasets/brainstorm/bst_resting.py
+++ b/mne/datasets/brainstorm/bst_resting.py
@@ -2,9 +2,10 @@
#
# License: BSD (3-clause)
+from functools import partial
import os.path as op
+
from ...utils import verbose
-from ...fixes import partial
from ..utils import (has_dataset, _data_path, _get_version, _version_doc,
_data_path_doc)
diff --git a/mne/datasets/megsim/urls.py b/mne/datasets/megsim/urls.py
index c073b78..0c316ec 100644
--- a/mne/datasets/megsim/urls.py
+++ b/mne/datasets/megsim/urls.py
@@ -31,7 +31,7 @@ urls = ['/empdata/neuromag/visual/subject1_day1_vis_raw.fif',
'/simdata_singleTrials/4545_sim_oscOnly_v1_IPS_ILOG_30hzAdded.fif',
'/index.html',
-]
+ ]
data_formats = ['raw',
'raw',
diff --git a/mne/datasets/multimodal/__init__.py b/mne/datasets/multimodal/__init__.py
new file mode 100644
index 0000000..947071e
--- /dev/null
+++ b/mne/datasets/multimodal/__init__.py
@@ -0,0 +1,4 @@
+"""Multimodal dataset
+"""
+
+from .multimodal import data_path, has_multimodal_data, get_version
diff --git a/mne/datasets/somato/somato.py b/mne/datasets/multimodal/multimodal.py
similarity index 59%
copy from mne/datasets/somato/somato.py
copy to mne/datasets/multimodal/multimodal.py
index d0daf98..d8eb31a 100644
--- a/mne/datasets/somato/somato.py
+++ b/mne/datasets/multimodal/multimodal.py
@@ -3,27 +3,28 @@
# Eric Larson <larson.eric.d at gmail.com>
# License: BSD Style.
+from functools import partial
+
from ...utils import verbose
-from ...fixes import partial
from ..utils import (has_dataset, _data_path, _data_path_doc,
_get_version, _version_doc)
-has_somato_data = partial(has_dataset, name='somato')
+has_multimodal_data = partial(has_dataset, name='multimodal')
@verbose
def data_path(path=None, force_update=False, update_path=True, download=True,
verbose=None):
return _data_path(path=path, force_update=force_update,
- update_path=update_path, name='somato',
+ update_path=update_path, name='multimodal',
download=download)
-data_path.__doc__ = _data_path_doc.format(name='somato',
- conf='MNE_DATASETS_SOMATO_PATH')
+data_path.__doc__ = _data_path_doc.format(name='multimodal',
+ conf='MNE_DATASETS_MULTIMODAL_PATH')
def get_version():
- return _get_version('somato')
+ return _get_version('multimodal')
-get_version.__doc__ = _version_doc.format(name='somato')
+get_version.__doc__ = _version_doc.format(name='multimodal')
diff --git a/mne/datasets/sample/sample.py b/mne/datasets/sample/sample.py
index 46f40d9..68b977d 100644
--- a/mne/datasets/sample/sample.py
+++ b/mne/datasets/sample/sample.py
@@ -3,10 +3,11 @@
# Eric Larson <larson.eric.d at gmail.com>
# License: BSD Style.
+from functools import partial
+
import numpy as np
from ...utils import verbose, get_config
-from ...fixes import partial
from ..utils import (has_dataset, _data_path, _data_path_doc,
_get_version, _version_doc)
diff --git a/mne/datasets/somato/somato.py b/mne/datasets/somato/somato.py
index d0daf98..fd11302 100644
--- a/mne/datasets/somato/somato.py
+++ b/mne/datasets/somato/somato.py
@@ -3,8 +3,9 @@
# Eric Larson <larson.eric.d at gmail.com>
# License: BSD Style.
+from functools import partial
+
from ...utils import verbose
-from ...fixes import partial
from ..utils import (has_dataset, _data_path, _data_path_doc,
_get_version, _version_doc)
diff --git a/mne/datasets/spm_face/spm_data.py b/mne/datasets/spm_face/spm_data.py
index 8fea978..c476d2a 100644
--- a/mne/datasets/spm_face/spm_data.py
+++ b/mne/datasets/spm_face/spm_data.py
@@ -2,10 +2,11 @@
#
# License: BSD Style.
+from functools import partial
+
import numpy as np
from ...utils import verbose, get_config
-from ...fixes import partial
from ..utils import (has_dataset, _data_path, _data_path_doc,
_get_version, _version_doc)
diff --git a/mne/datasets/testing/_testing.py b/mne/datasets/testing/_testing.py
index 932bd2e..658a963 100644
--- a/mne/datasets/testing/_testing.py
+++ b/mne/datasets/testing/_testing.py
@@ -3,10 +3,11 @@
# Eric Larson <larson.eric.d at gmail.com>
# License: BSD Style.
+from functools import partial
+
import numpy as np
from ...utils import verbose, get_config
-from ...fixes import partial
from ..utils import (has_dataset, _data_path, _data_path_doc,
_get_version, _version_doc)
diff --git a/mne/datasets/tests/test_datasets.py b/mne/datasets/tests/test_datasets.py
index 34614ca..cdfdef1 100644
--- a/mne/datasets/tests/test_datasets.py
+++ b/mne/datasets/tests/test_datasets.py
@@ -1,3 +1,4 @@
+import os
from os import path as op
from nose.tools import assert_true, assert_equal
@@ -19,6 +20,15 @@ def test_datasets():
assert_true(isinstance(dataset.get_version(), string_types))
else:
assert_true(dataset.get_version() is None)
+ tempdir = _TempDir()
+ # don't let it read from the config file to get the directory,
+ # force it to look for the default
+ os.environ['_MNE_FAKE_HOME_DIR'] = tempdir
+ try:
+ assert_equal(datasets.utils._get_path(None, 'foo', 'bar'),
+ op.join(tempdir, 'mne_data'))
+ finally:
+ del os.environ['_MNE_FAKE_HOME_DIR']
@requires_good_network
diff --git a/mne/datasets/utils.py b/mne/datasets/utils.py
index 41047a3..3ea20c4 100644
--- a/mne/datasets/utils.py
+++ b/mne/datasets/utils.py
@@ -69,10 +69,11 @@ tutorial, e.g. for research purposes, is prohibited without written consent
from the MEG Lab.
If you reference this dataset in your publications, please:
-1) aknowledge its authors: Elizabeth Bock, Esther Florin, Francois Tadel and
-Sylvain Baillet
-2) cite Brainstorm as indicated on the website:
-http://neuroimage.usc.edu/brainstorm
+
+ 1) acknowledge its authors: Elizabeth Bock, Esther Florin, Francois Tadel
+ and Sylvain Baillet, and
+ 2) cite Brainstorm as indicated on the website:
+ http://neuroimage.usc.edu/brainstorm
For questions, please contact Francois Tadel (francois.tadel at mcgill.ca).
"""
@@ -94,35 +95,31 @@ def _dataset_version(path, name):
def _get_path(path, key, name):
"""Helper to get a dataset path"""
- if path is None:
- # use an intelligent guess if it's not defined
- def_path = op.realpath(op.join(op.dirname(__file__), '..', '..',
- 'examples'))
- if get_config(key) is None:
- key = 'MNE_DATA'
- path = get_config(key, def_path)
-
- # use the same for all datasets
- if not op.exists(path) or not os.access(path, os.W_OK):
- try:
- os.mkdir(path)
- except OSError:
- try:
- logger.info('Checking for %s data in '
- '"~/mne_data"...' % name)
- path = op.join(op.expanduser("~"), "mne_data")
- if not op.exists(path):
- logger.info("Trying to create "
- "'~/mne_data' in home directory")
- os.mkdir(path)
- except OSError:
- raise OSError("User does not have write permissions "
- "at '%s', try giving the path as an "
- "argument to data_path() where user has "
- "write permissions, for ex:data_path"
- "('/home/xyz/me2/')" % (path))
- if not isinstance(path, string_types):
- raise ValueError('path must be a string or None')
+ # 1. Input
+ if path is not None:
+ if not isinstance(path, string_types):
+ raise ValueError('path must be a string or None')
+ return path
+ # 2. get_config(key)
+ # 3. get_config('MNE_DATA')
+ path = get_config(key, get_config('MNE_DATA'))
+ if path is not None:
+ return path
+ # 4. ~/mne_data (but use a fake home during testing so we don't
+ # unnecessarily create ~/mne_data)
+ logger.info('Using default location ~/mne_data for %s...' % name)
+ path = op.join(os.getenv('_MNE_FAKE_HOME_DIR',
+ op.expanduser("~")), 'mne_data')
+ if not op.exists(path):
+ logger.info('Creating ~/mne_data')
+ try:
+ os.mkdir(path)
+ except OSError:
+ raise OSError("User does not have write permissions "
+ "at '%s', try giving the path as an "
+ "argument to data_path() where user has "
+ "write permissions, for ex:data_path"
+ "('/home/xyz/me2/')" % (path))
return path
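The rewrite replaces the nested guessing logic with an explicit precedence: (1) the path argument, (2) the dataset-specific config key, (3) MNE_DATA, (4) ~/mne_data, created on demand, with _MNE_FAKE_HOME_DIR letting tests avoid the real home directory. Mirroring the new test, the default fallback can be exercised directly (key and name are arbitrary strings here):

    from mne import datasets

    # With nothing configured this falls through to ~/mne_data:
    path = datasets.utils._get_path(None, 'foo', 'bar')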
@@ -145,7 +142,7 @@ def _do_path_update(path, update_path, key, name):
update_path = False
if update_path is True:
- set_config(key, path)
+ set_config(key, path, set_env=False)
return path
@@ -162,12 +159,13 @@ def _data_path(path=None, force_update=False, update_path=True, download=True,
'somato': 'MNE_DATASETS_SOMATO_PATH',
'brainstorm': 'MNE_DATASETS_BRAINSTORM_PATH',
'testing': 'MNE_DATASETS_TESTING_PATH',
+ 'multimodal': 'MNE_DATASETS_MULTIMODAL_PATH',
}[name]
path = _get_path(path, key, name)
# To update the testing or misc dataset, push commits, then make a new
# release on GitHub. Then update the "releases" variable:
- releases = dict(testing='0.19', misc='0.1')
+ releases = dict(testing='0.25', misc='0.1')
# And also update the "hashes['testing']" variable below.
# To update any other dataset, update the data archive itself (upload
@@ -178,6 +176,7 @@ def _data_path(path=None, force_update=False, update_path=True, download=True,
somato='MNE-somato-data.tar.gz',
spm='MNE-spm-face.tar.gz',
testing='mne-testing-data-%s.tar.gz' % releases['testing'],
+ multimodal='MNE-multimodal-data.tar.gz',
fake='foo.tgz',
)
if archive_name is not None:
@@ -188,6 +187,7 @@ def _data_path(path=None, force_update=False, update_path=True, download=True,
misc='MNE-misc-data',
sample='MNE-sample-data',
somato='MNE-somato-data',
+ multimodal='MNE-multimodal-data',
spm='MNE-spm-face',
testing='MNE-testing-data',
)
@@ -203,6 +203,7 @@ def _data_path(path=None, force_update=False, update_path=True, download=True,
spm='https://mne-tools.s3.amazonaws.com/datasets/%s',
testing='https://codeload.github.com/mne-tools/mne-testing-data/'
'tar.gz/%s' % releases['testing'],
+ multimodal='https://ndownloader.figshare.com/files/5999598',
)
hashes = dict(
brainstorm=None,
@@ -211,7 +212,8 @@ def _data_path(path=None, force_update=False, update_path=True, download=True,
sample='1d5da3a809fded1ef5734444ab5bf857',
somato='f3e3a8441477bb5bacae1d0c6e0964fb',
spm='f61041e3f3f2ba0def8a2ca71592cc41',
- testing='77b2a435d80adb23cbe7e19144e7bc47'
+ testing='217aed43e361c86b622dc0363ae3cef4',
+ multimodal='26ec847ae9ab80f58f204d09e2c08367',
)
folder_origs = dict( # not listed means None
misc='mne-misc-data-%s' % releases['misc'],
@@ -358,29 +360,29 @@ def _download_all_example_data(verbose=True):
"""Helper to download all datasets used in examples and tutorials"""
# This function is designed primarily to be used by CircleCI. It has
# verbose=True by default so we get nice status messages
+ # Consider adding datasets from here to CircleCI for PR-auto-build
from . import (sample, testing, misc, spm_face, somato, brainstorm, megsim,
- eegbci)
+ eegbci, multimodal)
sample.data_path()
testing.data_path()
misc.data_path()
spm_face.data_path()
somato.data_path()
+ multimodal.data_path()
sys.argv += ['--accept-brainstorm-license']
try:
brainstorm.bst_raw.data_path()
brainstorm.bst_auditory.data_path()
+ brainstorm.bst_phantom_elekta.data_path()
+ brainstorm.bst_phantom_ctf.data_path()
finally:
sys.argv.pop(-1)
- sys.argv += ['--update-dataset-path']
- try:
- megsim.load_data(condition='visual', data_format='single-trial',
- data_type='simulation')
- megsim.load_data(condition='visual', data_format='raw',
- data_type='experimental')
- megsim.load_data(condition='visual', data_format='evoked',
- data_type='simulation')
- finally:
- sys.argv.pop(-1)
+ megsim.load_data(condition='visual', data_format='single-trial',
+ data_type='simulation', update_path=True)
+ megsim.load_data(condition='visual', data_format='raw',
+ data_type='experimental', update_path=True)
+ megsim.load_data(condition='visual', data_format='evoked',
+ data_type='simulation', update_path=True)
url_root = 'http://www.physionet.org/physiobank/database/eegmmidb/'
eegbci.data_path(url_root + 'S001/S001R06.edf', update_path=True)
eegbci.data_path(url_root + 'S001/S001R10.edf', update_path=True)
diff --git a/mne/decoding/__init__.py b/mne/decoding/__init__.py
index 9a431a4..9764863 100644
--- a/mne/decoding/__init__.py
+++ b/mne/decoding/__init__.py
@@ -1,7 +1,9 @@
from .transformer import Scaler, FilterEstimator
-from .transformer import PSDEstimator, EpochsVectorizer
+from .transformer import (PSDEstimator, EpochsVectorizer, Vectorizer,
+ UnsupervisedSpatialFilter, TemporalFilter)
from .mixin import TransformerMixin
from .base import BaseEstimator, LinearModel
from .csp import CSP
-from .ems import compute_ems
+from .ems import compute_ems, EMS
from .time_gen import GeneralizationAcrossTime, TimeDecoding
+from .time_frequency import TimeFrequency
diff --git a/mne/decoding/base.py b/mne/decoding/base.py
index 8e65dcb..5f845c2 100644
--- a/mne/decoding/base.py
+++ b/mne/decoding/base.py
@@ -10,6 +10,7 @@ import numpy as np
from ..externals.six import iteritems
from ..fixes import _get_args
+from ..utils import check_version
class BaseEstimator(object):
@@ -368,9 +369,20 @@ class LinearModel(BaseEstimator):
If None, the maximum absolute value is used. If vmin is None,
but vmax is not, defaults to np.min(data).
If callable, the output equals vmax(data).
- cmap : matplotlib colormap
- Colormap. For magnetometers and eeg defaults to 'RdBu_r', else
- 'Reds'.
+ cmap : matplotlib colormap | (colormap, bool) | 'interactive' | None
+ Colormap to use. If tuple, the first value indicates the colormap
+ to use and the second value is a boolean defining interactivity. In
+ interactive mode the colors are adjustable by clicking and dragging
+ the colorbar with left and right mouse button. Left mouse button
+ moves the scale up and down and right mouse button adjusts the
+ range. Hitting space bar resets the range. Up and down arrows can
+ be used to change the colormap. If None, 'Reds' is used for all
+ positive data, otherwise defaults to 'RdBu_r'. If 'interactive',
+ translates to (None, True). Defaults to 'RdBu_r'.
+
+ .. warning:: Interactive mode works smoothly only for a small
+ number of topomaps.
+
sensors : bool | str
Add markers for sensor locations to the plot. Accepts matplotlib
plot format string (e.g., 'r+' for red plusses). If True,
@@ -474,7 +486,7 @@ class LinearModel(BaseEstimator):
mask=mask, outlines=outlines,
contours=contours, title=title,
image_interp=image_interp, show=show,
- head_pos=head_pos)
+ head_pos=head_pos, average=average)
def plot_filters(self, info, times=None, ch_type=None, layout=None,
vmin=None, vmax=None, cmap='RdBu_r', sensors=True,
@@ -518,9 +530,20 @@ class LinearModel(BaseEstimator):
If None, the maximum absolute value is used. If vmin is None,
but vmax is not, defaults to np.min(data).
If callable, the output equals vmax(data).
- cmap : matplotlib colormap
- Colormap. For magnetometers and eeg defaults to 'RdBu_r', else
- 'Reds'.
+ cmap : matplotlib colormap | (colormap, bool) | 'interactive' | None
+ Colormap to use. If tuple, the first value indicates the colormap
+ to use and the second value is a boolean defining interactivity. In
+ interactive mode the colors are adjustable by clicking and dragging
+ the colorbar with left and right mouse button. Left mouse button
+ moves the scale up and down and right mouse button adjusts the
+ range. Hitting space bar resets the range. Up and down arrows can
+ be used to change the colormap. If None, 'Reds' is used for all
+ positive data, otherwise defaults to 'RdBu_r'. If 'interactive',
+ translates to (None, True). Defaults to 'RdBu_r'.
+
+ .. warning:: Interactive mode works smoothly only for a small
+ number of topomaps.
+
sensors : bool | str
Add markers for sensor locations to the plot. Accepts matplotlib
plot format string (e.g., 'r+' for red plusses). If True,
@@ -624,4 +647,79 @@ class LinearModel(BaseEstimator):
mask=mask, outlines=outlines,
contours=contours, title=title,
image_interp=image_interp, show=show,
- head_pos=head_pos)
+ head_pos=head_pos, average=average)
+
+
+def _set_cv(cv, estimator=None, X=None, y=None):
+ """ Set the default cross-validation depending on whether clf is classifier
+ or regressor. """
+
+ from sklearn.base import is_classifier
+
+ # Detect whether classification or regression
+ if estimator in ['classifier', 'regressor']:
+ est_is_classifier = estimator == 'classifier'
+ else:
+ est_is_classifier = is_classifier(estimator)
+ # Setup CV
+ if check_version('sklearn', '0.18'):
+ from sklearn import model_selection as models
+ from sklearn.model_selection import (check_cv, StratifiedKFold, KFold)
+ if isinstance(cv, (int, np.int)):
+ XFold = StratifiedKFold if est_is_classifier else KFold
+ cv = XFold(n_splits=cv)
+ elif isinstance(cv, str):
+ if not hasattr(models, cv):
+ raise ValueError('Unknown cross-validation')
+ cv = getattr(models, cv)
+ cv = cv()
+ cv = check_cv(cv=cv, y=y, classifier=est_is_classifier)
+ else:
+ from sklearn import cross_validation as models
+ from sklearn.cross_validation import (check_cv, StratifiedKFold, KFold)
+ if isinstance(cv, (int, np.int)):
+ if est_is_classifier:
+ cv = StratifiedKFold(y=y, n_folds=cv)
+ else:
+ cv = KFold(n=len(y), n_folds=cv)
+ elif isinstance(cv, str):
+ if not hasattr(models, cv):
+ raise ValueError('Unknown cross-validation')
+ cv = getattr(models, cv)
+ if cv.__name__ not in ['KFold', 'LeaveOneOut']:
+ raise NotImplementedError('CV cannot be defined with str for'
+ ' sklearn < 0.18.')
+ cv = cv(len(y))
+ cv = check_cv(cv=cv, X=X, y=y, classifier=est_is_classifier)
+
+ # Extract train and test set to retrieve them at predict time
+ if hasattr(cv, 'split'):
+ cv_splits = [(train, test) for train, test in
+ cv.split(X=np.zeros_like(y), y=y)]
+ else:
+ # XXX support sklearn.cross_validation cv
+ cv_splits = [(train, test) for train, test in cv]
+
+ if not np.all([len(train) for train, _ in cv_splits]):
+ raise ValueError('Some folds do not have any train epochs.')
+
+ return cv, cv_splits
+
+
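A hedged usage sketch of the private helper: an int becomes a (Stratified)KFold, a string is looked up in the sklearn cross-validation module, and the materialized splits come back alongside the cv object so they can be reused at predict time:

    import numpy as np
    from mne.decoding.base import _set_cv

    y = np.array([0, 1] * 10)
    cv, cv_splits = _set_cv(5, 'classifier', X=np.zeros((20, 2)), y=y)
    # cv_splits is a list of (train, test) index arrays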
+def _check_estimator(estimator, get_params=True):
+ """Check whether an object has the fit, transform, fit_transform and
+ get_params methods required by scikit-learn"""
+ valid_methods = ('predict', 'transform', 'predict_proba',
+ 'decision_function')
+ if (
+ (not hasattr(estimator, 'fit')) or
+ (not any(hasattr(estimator, method) for method in valid_methods))
+ ):
+ raise ValueError('estimator must be a scikit-learn transformer or '
+ 'an estimator with the fit and a predict-like (e.g. '
+ 'predict_proba) or a transform method.')
+
+ if get_params and not hasattr(estimator, 'get_params'):
+ raise ValueError('estimator must be a scikit-learn transformer or an '
+ 'estimator with the get_params method that allows '
+ 'cloning.')
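Usage sketch: anything with fit plus a predict-like or transform method passes; a bare object does not:

    from sklearn.linear_model import LogisticRegression
    from mne.decoding.base import _check_estimator

    _check_estimator(LogisticRegression())  # fit/predict/get_params: OK
    try:
        _check_estimator(object())          # no fit method
    except ValueError:
        pass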
diff --git a/mne/decoding/csp.py b/mne/decoding/csp.py
index ade2177..2d34303 100644
--- a/mne/decoding/csp.py
+++ b/mne/decoding/csp.py
@@ -3,6 +3,7 @@
# Alexandre Gramfort <alexandre.gramfort at telecom-paristech.fr>
# Alexandre Barachant <alexandre.barachant at gmail.com>
# Clemens Brunner <clemens.brunner at gmail.com>
+# Jean-Remi King <jeanremi.king at gmail.com>
#
# License: BSD (3-clause)
@@ -11,46 +12,54 @@ import copy as cp
import numpy as np
from scipy import linalg
-from .mixin import TransformerMixin, EstimatorMixin
+from .mixin import TransformerMixin
+from .base import BaseEstimator
from ..cov import _regularized_covariance
+from ..utils import warn
-class CSP(TransformerMixin, EstimatorMixin):
+class CSP(TransformerMixin, BaseEstimator):
"""M/EEG signal decomposition using the Common Spatial Patterns (CSP).
This object can be used as a supervised decomposition to estimate
spatial filters for feature extraction in a 2 class decoding problem.
CSP in the context of EEG was first described in [1]; a comprehensive
- tutorial on CSP can be found in [2].
+ tutorial on CSP can be found in [2]. The multiclass case is solved
+ following the approach described in [3].
Parameters
----------
- n_components : int (default 4)
+ n_components : int, defaults to 4
The number of components to decompose M/EEG signals.
This number should be set by cross-validation.
- reg : float | str | None (default None)
+ reg : float | str | None, defaults to None
if not None, allow regularization for covariance estimation
if float, shrinkage covariance is used (0 <= shrinkage <= 1).
if str, optimal shrinkage using Ledoit-Wolf Shrinkage ('ledoit_wolf')
or Oracle Approximating Shrinkage ('oas').
- log : bool (default True)
- If true, apply log to standardize the features.
- If false, features are just z-scored.
- cov_est : str (default 'concat')
+ log : None | bool, defaults to None
+ If transform_into == 'average_power' and log is None or True, then
+ applies a log transform to standardize the features, else the features
+ are z-scored. If transform_into == 'csp_space', then log must be None.
+ cov_est : 'concat' | 'epoch', defaults to 'concat'
If 'concat', covariance matrices are estimated on concatenated epochs
for each class.
If 'epoch', covariance matrices are estimated on each epoch separately
and then averaged over each class.
+ transform_into : {'average_power', 'csp_space'}
+ If 'average_power' then self.transform will return the average power of
+ each spatial filter. If 'csp_space' self.transform will return the data
+ in CSP space. Defaults to 'average_power'.
Attributes
----------
- filters_ : ndarray, shape (n_channels, n_channels)
+ ``filters_`` : ndarray, shape (n_channels, n_channels)
If fit, the CSP components used to decompose the data, else None.
- patterns_ : ndarray, shape (n_channels, n_channels)
+ ``patterns_`` : ndarray, shape (n_channels, n_channels)
If fit, the CSP patterns used to restore M/EEG signals, else None.
- mean_ : ndarray, shape (n_channels,)
+ ``mean_`` : ndarray, shape (n_components,)
If fit, the mean squared power for each component.
- std_ : ndarray, shape (n_channels,)
+ ``std_`` : ndarray, shape (n_components,)
If fit, the std squared power for each component.
References
@@ -62,40 +71,67 @@ class CSP(TransformerMixin, EstimatorMixin):
Klaus-Robert Müller. Optimizing Spatial Filters for Robust EEG
Single-Trial Analysis. IEEE Signal Processing Magazine 25(1), 41-56,
2008.
+ [3] Grosse-Wentrup, Moritz, and Martin Buss. Multiclass common spatial
+ patterns and information theoretic feature extraction. IEEE
+ Transactions on Biomedical Engineering, Vol 55, no. 8, 2008.
"""
- def __init__(self, n_components=4, reg=None, log=True, cov_est="concat"):
+ def __init__(self, n_components=4, reg=None, log=None, cov_est="concat",
+ transform_into='average_power'):
"""Init of CSP."""
+ # Init default CSP
+ if not isinstance(n_components, int):
+ raise ValueError('n_components must be an integer.')
self.n_components = n_components
+
+ # Init default regularization
+ if (
+ (reg is not None) and
+ (reg not in ['oas', 'ledoit_wolf']) and
+ ((not isinstance(reg, (float, int))) or
+ (not ((reg <= 1.) and (reg >= 0.))))
+ ):
+ raise ValueError('reg must be None, "oas", "ledoit_wolf" or a '
+ 'float in between 0. and 1.')
self.reg = reg
- self.log = log
+
+ # Init default cov_est
+ if not (cov_est == "concat" or cov_est == "epoch"):
+ raise ValueError("unknown covariance estimation method")
self.cov_est = cov_est
- self.filters_ = None
- self.patterns_ = None
- self.mean_ = None
- self.std_ = None
- def get_params(self, deep=True):
- """Return all parameters (mimics sklearn API).
+ # Init default transform_into
+ if transform_into not in ('average_power', 'csp_space'):
+ raise ValueError('transform_into must be "average_power" or '
+ '"csp_space".')
+ self.transform_into = transform_into
+
+ # Init default log
+ if transform_into == 'average_power':
+ if log is not None and not isinstance(log, bool):
+ raise ValueError('log must be a boolean if transform_into == '
+ '"average_power".')
+ else:
+ if log is not None:
+ raise ValueError('log must be None if transform_into == '
+ '"csp_space".')
+ self.log = log
- Parameters
- ----------
- deep: boolean, optional
- If True, will return the parameters for this estimator and
- contained subobjects that are estimators.
- """
- params = {"n_components": self.n_components,
- "reg": self.reg,
- "log": self.log}
- return params
+ def _check_Xy(self, X, y=None):
+ """Aux. function to check input data."""
+ if y is not None:
+ if len(X) != len(y) or len(y) < 1:
+ raise ValueError('X and y must have the same length.')
+ if X.ndim < 3:
+ raise ValueError('X must have at least 3 dimensions.')
- def fit(self, epochs_data, y):
+ def fit(self, X, y, epochs_data=None):
"""Estimate the CSP decomposition on epochs.
Parameters
----------
- epochs_data : ndarray, shape (n_epochs, n_channels, n_times)
- The data to estimate the CSP on.
+ X : ndarray, shape (n_epochs, n_channels, n_times)
+ The data on which to estimate the CSP.
y : array, shape (n_epochs,)
The class for each epoch.
@@ -104,55 +140,81 @@ class CSP(TransformerMixin, EstimatorMixin):
self : instance of CSP
Returns the modified instance.
"""
- if not isinstance(epochs_data, np.ndarray):
- raise ValueError("epochs_data should be of type ndarray (got %s)."
- % type(epochs_data))
- epochs_data = np.atleast_3d(epochs_data)
- e, c, t = epochs_data.shape
- # check number of epochs
- if e != len(y):
- raise ValueError("n_epochs must be the same for epochs_data and y")
- classes = np.unique(y)
- if len(classes) != 2:
- raise ValueError("More than two different classes in the data.")
- if not (self.cov_est == "concat" or self.cov_est == "epoch"):
- raise ValueError("unknown covariance estimation method")
-
- if self.cov_est == "concat": # concatenate epochs
- class_1 = np.transpose(epochs_data[y == classes[0]],
- [1, 0, 2]).reshape(c, -1)
- class_2 = np.transpose(epochs_data[y == classes[1]],
- [1, 0, 2]).reshape(c, -1)
- cov_1 = _regularized_covariance(class_1, reg=self.reg)
- cov_2 = _regularized_covariance(class_2, reg=self.reg)
- elif self.cov_est == "epoch":
- class_1 = epochs_data[y == classes[0]]
- class_2 = epochs_data[y == classes[1]]
- cov_1 = np.zeros((c, c))
- for t in class_1:
- cov_1 += _regularized_covariance(t, reg=self.reg)
- cov_1 /= class_1.shape[0]
- cov_2 = np.zeros((c, c))
- for t in class_2:
- cov_2 += _regularized_covariance(t, reg=self.reg)
- cov_2 /= class_2.shape[0]
-
- # normalize by trace
- cov_1 /= np.trace(cov_1)
- cov_2 /= np.trace(cov_2)
-
- e, w = linalg.eigh(cov_1, cov_1 + cov_2)
- n_vals = len(e)
- # Rearrange vectors
- ind = np.empty(n_vals, dtype=int)
- ind[::2] = np.arange(n_vals - 1, n_vals // 2 - 1, -1)
- ind[1::2] = np.arange(0, n_vals // 2)
- w = w[:, ind] # first, last, second, second last, third, ...
- self.filters_ = w.T
- self.patterns_ = linalg.pinv(w)
+ X = _check_deprecate(epochs_data, X)
+ if not isinstance(X, np.ndarray):
+ raise ValueError("X should be of type ndarray (got %s)."
+ % type(X))
+ self._check_Xy(X, y)
+ n_channels = X.shape[1]
+
+ self._classes = np.unique(y)
+ n_classes = len(self._classes)
+ if n_classes < 2:
+ raise ValueError("n_classes must be >= 2.")
+
+ covs = np.zeros((n_classes, n_channels, n_channels))
+ sample_weights = list()
+ for class_idx, this_class in enumerate(self._classes):
+ if self.cov_est == "concat": # concatenate epochs
+ class_ = np.transpose(X[y == this_class], [1, 0, 2])
+ class_ = class_.reshape(n_channels, -1)
+ cov = _regularized_covariance(class_, reg=self.reg)
+ weight = sum(y == this_class)
+ elif self.cov_est == "epoch":
+ class_ = X[y == this_class]
+ cov = np.zeros((n_channels, n_channels))
+ for this_X in class_:
+ cov += _regularized_covariance(this_X, reg=self.reg)
+ cov /= len(class_)
+ weight = len(class_)
+
+ # normalize by trace and stack
+ covs[class_idx] = cov / np.trace(cov)
+ sample_weights.append(weight)
+
+ if n_classes == 2:
+ eigen_values, eigen_vectors = linalg.eigh(covs[0], covs.sum(0))
+ # sort eigenvectors
+ ix = np.argsort(np.abs(eigen_values - 0.5))[::-1]
+ else:
+ # The multiclass case is adapted from
+ # http://github.com/alexandrebarachant/pyRiemann
+ eigen_vectors, D = _ajd_pham(covs)
+
+ # Here we apply a Euclidean mean. See pyRiemann for other metrics.
+ mean_cov = np.average(covs, axis=0, weights=sample_weights)
+ eigen_vectors = eigen_vectors.T
+
+ # normalize
+ for ii in range(eigen_vectors.shape[1]):
+ tmp = np.dot(np.dot(eigen_vectors[:, ii].T, mean_cov),
+ eigen_vectors[:, ii])
+ eigen_vectors[:, ii] /= np.sqrt(tmp)
+
+ # class probability
+ class_probas = [np.mean(y == _class) for _class in self._classes]
+
+ # mutual information
+ mutual_info = []
+ for jj in range(eigen_vectors.shape[1]):
+ aa, bb = 0, 0
+ for (cov, prob) in zip(covs, class_probas):
+ tmp = np.dot(np.dot(eigen_vectors[:, jj].T, cov),
+ eigen_vectors[:, jj])
+ aa += prob * np.log(np.sqrt(tmp))
+ bb += prob * (tmp ** 2 - 1)
+ mi = - (aa + (3.0 / 16) * (bb ** 2))
+ mutual_info.append(mi)
+ ix = np.argsort(mutual_info)[::-1]
+
+ # sort eigenvectors
+ eigen_vectors = eigen_vectors[:, ix]
+
+ self.filters_ = eigen_vectors.T
+ self.patterns_ = linalg.pinv(eigen_vectors)
pick_filters = self.filters_[:self.n_components]
- X = np.asarray([np.dot(pick_filters, epoch) for epoch in epochs_data])
+ X = np.asarray([np.dot(pick_filters, epoch) for epoch in X])
# compute features (mean band power)
X = (X ** 2).mean(axis=-1)
@@ -163,38 +225,41 @@ class CSP(TransformerMixin, EstimatorMixin):
return self
- def transform(self, epochs_data, y=None):
+ def transform(self, X, epochs_data=None):
"""Estimate epochs sources given the CSP filters.
Parameters
----------
- epochs_data : array, shape (n_epochs, n_channels, n_times)
+ X : array, shape (n_epochs, n_channels, n_times)
The data.
- y : None
- Not used.
Returns
-------
- X : ndarray of shape (n_epochs, n_sources)
- The CSP features averaged over time.
+ X : ndarray
+ If self.transform_into == 'average_power' then returns the power of
+ CSP features averaged over time, shape (n_epochs, n_sources).
+ If self.transform_into == 'csp_space' then returns the data in CSP
+ space, shape (n_epochs, n_sources, n_times).
"""
- if not isinstance(epochs_data, np.ndarray):
- raise ValueError("epochs_data should be of type ndarray (got %s)."
- % type(epochs_data))
+ X = _check_deprecate(epochs_data, X)
+ if not isinstance(X, np.ndarray):
+ raise ValueError("X should be of type ndarray (got %s)." % type(X))
if self.filters_ is None:
raise RuntimeError('No filters available. Please first fit CSP '
'decomposition.')
pick_filters = self.filters_[:self.n_components]
- X = np.asarray([np.dot(pick_filters, epoch) for epoch in epochs_data])
+ X = np.asarray([np.dot(pick_filters, epoch) for epoch in X])
# compute features (mean band power)
- X = (X ** 2).mean(axis=-1)
- if self.log:
- X = np.log(X)
- else:
- X -= self.mean_
- X /= self.std_
+ if self.transform_into == 'average_power':
+ X = (X ** 2).mean(axis=-1)
+ log = True if self.log is None else self.log
+ if log:
+ X = np.log(X)
+ else:
+ X -= self.mean_
+ X /= self.std_
return X
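A hedged sketch of the two transform_into modes on synthetic data: 'average_power' yields (n_epochs, n_components) log-power features, while 'csp_space' returns the spatially filtered signals:

    import numpy as np
    from mne.decoding import CSP

    X = np.random.randn(40, 10, 50)  # epochs x channels x times
    y = np.array([0, 1] * 20)
    power = CSP(n_components=4).fit(X, y).transform(X)  # (40, 4)
    signals = CSP(n_components=4,
                  transform_into='csp_space').fit(X, y).transform(X)  # (40, 4, 50)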
def plot_patterns(self, info, components=None, ch_type=None, layout=None,
@@ -236,9 +301,20 @@ class CSP(TransformerMixin, EstimatorMixin):
If None, the maximum absolute value is used. If vmin is None,
but vmax is not, defaults to np.min(data).
If callable, the output equals vmax(data).
- cmap : matplotlib colormap
- Colormap. For magnetometers and eeg defaults to 'RdBu_r', else
- 'Reds'.
+ cmap : matplotlib colormap | (colormap, bool) | 'interactive' | None
+ Colormap to use. If tuple, the first value indicates the colormap
+ to use and the second value is a boolean defining interactivity. In
+ interactive mode the colors are adjustable by clicking and dragging
+ the colorbar with left and right mouse button. Left mouse button
+ moves the scale up and down and right mouse button adjusts the
+ range. Hitting space bar resets the range. Up and down arrows can
+ be used to change the colormap. If None, 'Reds' is used for all
+ positive data, otherwise defaults to 'RdBu_r'. If 'interactive',
+ translates to (None, True). Defaults to 'RdBu_r'.
+
+ .. warning:: Interactive mode works smoothly only for a small
+ number of topomaps.
+
sensors : bool | str
Add markers for sensor locations to the plot. Accepts matplotlib
plot format string (e.g., 'r+' for red plusses). If True,
@@ -381,9 +457,20 @@ class CSP(TransformerMixin, EstimatorMixin):
If None, the maximum absolute value is used. If vmin is None,
but vmax is not, defaults to np.min(data).
If callable, the output equals vmax(data).
- cmap : matplotlib colormap
- Colormap. For magnetometers and eeg defaults to 'RdBu_r', else
- 'Reds'.
+ cmap : matplotlib colormap | (colormap, bool) | 'interactive' | None
+ Colormap to use. If tuple, the first value indicates the colormap
+ to use and the second value is a boolean defining interactivity. In
+ interactive mode the colors are adjustable by clicking and dragging
+ the colorbar with left and right mouse button. Left mouse button
+ moves the scale up and down and right mouse button adjusts the
+ range. Hitting space bar resets the range. Up and down arrows can
+ be used to change the colormap. If None, 'Reds' is used for all
+ positive data, otherwise defaults to 'RdBu_r'. If 'interactive',
+ translates to (None, True). Defaults to 'RdBu_r'.
+
+ .. warning:: Interactive mode works smoothly only for a small
+ number of topomaps.
+
sensors : bool | str
Add markers for sensor locations to the plot. Accepts matplotlib
plot format string (e.g., 'r+' for red plusses). If True,
@@ -486,3 +573,95 @@ class CSP(TransformerMixin, EstimatorMixin):
contours=contours,
image_interp=image_interp, show=show,
head_pos=head_pos)
+
+
+def _ajd_pham(X, eps=1e-6, max_iter=15):
+ """Approximate joint diagonalization based on Pham's algorithm.
+
+ This is a direct implementation of the PHAM's AJD algorithm [1].
+
+ Parameters
+ ----------
+ X : ndarray, shape (n_epochs, n_channels, n_channels)
+ A set of covariance matrices to diagonalize.
+ eps : float, defaults to 1e-6
+ The tolerance for the stopping criterion.
+ max_iter : int, defaults to 15
+ The maximum number of iterations to reach convergence.
+
+ Returns
+ -------
+ V : ndarray, shape (n_channels, n_channels)
+ The diagonalizer.
+ D : ndarray, shape (n_epochs, n_channels, n_channels)
+ The set of quasi diagonal matrices.
+
+ References
+ ----------
+ [1] Pham, Dinh Tuan. "Joint approximate diagonalization of positive
+ definite Hermitian matrices." SIAM Journal on Matrix Analysis and
+ Applications 22, no. 4 (2001): 1136-1152.
+
+ """
+ # Adapted from http://github.com/alexandrebarachant/pyRiemann
+ n_epochs = X.shape[0]
+
+ # Reshape input matrix
+ A = np.concatenate(X, axis=0).T
+
+ # Init variables
+ n_times, n_m = A.shape
+ V = np.eye(n_times)
+ epsilon = n_times * (n_times - 1) * eps
+
+ for it in range(max_iter):
+ decr = 0
+ for ii in range(1, n_times):
+ for jj in range(ii):
+ Ii = np.arange(ii, n_m, n_times)
+ Ij = np.arange(jj, n_m, n_times)
+
+ c1 = A[ii, Ii]
+ c2 = A[jj, Ij]
+
+ g12 = np.mean(A[ii, Ij] / c1)
+ g21 = np.mean(A[ii, Ij] / c2)
+
+ omega21 = np.mean(c1 / c2)
+ omega12 = np.mean(c2 / c1)
+ omega = np.sqrt(omega12 * omega21)
+
+ tmp = np.sqrt(omega21 / omega12)
+ tmp1 = (tmp * g12 + g21) / (omega + 1)
+ tmp2 = (tmp * g12 - g21) / max(omega - 1, 1e-9)
+
+ h12 = tmp1 + tmp2
+ h21 = np.conj((tmp1 - tmp2) / tmp)
+
+ decr += n_epochs * (g12 * np.conj(h12) + g21 * h21) / 2.0
+
+ tmp = 1 + 1.j * 0.5 * np.imag(h12 * h21)
+ tmp = np.real(tmp + np.sqrt(tmp ** 2 - h12 * h21))
+ tau = np.array([[1, -h12 / tmp], [-h21 / tmp, 1]])
+
+ A[[ii, jj], :] = np.dot(tau, A[[ii, jj], :])
+ tmp = np.c_[A[:, Ii], A[:, Ij]]
+ tmp = np.reshape(tmp, (n_times * n_epochs, 2), order='F')
+ tmp = np.dot(tmp, tau.T)
+
+ tmp = np.reshape(tmp, (n_times, n_epochs * 2), order='F')
+ A[:, Ii] = tmp[:, :n_epochs]
+ A[:, Ij] = tmp[:, n_epochs:]
+ V[[ii, jj], :] = np.dot(tau, V[[ii, jj], :])
+ if decr < epsilon:
+ break
+ D = np.reshape(A, (n_times, n_m // n_times, n_times)).transpose(1, 0, 2)
+ return V, D
+
+
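A hedged sketch exercising the private diagonalizer: covariances built from one shared mixing should be approximately jointly diagonalized by the returned V:

    import numpy as np

    rng = np.random.RandomState(0)
    mixing = rng.randn(4, 4)
    covs = np.array([np.dot(np.dot(mixing, np.diag(rng.rand(4) + .1)),
                            mixing.T) for _ in range(6)])
    V, D = _ajd_pham(covs)  # D[k] ~ np.dot(np.dot(V, covs[k]), V.T)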
+def _check_deprecate(epochs_data, X):
+ """Aux. function to CSP to deal with the change param name."""
+ if epochs_data is not None:
+ X = epochs_data
+ warn('epochs_data will be deprecated in mne 0.14. Use X instead')
+ return X
diff --git a/mne/decoding/ems.py b/mne/decoding/ems.py
index 6064086..34b6c2b 100644
--- a/mne/decoding/ems.py
+++ b/mne/decoding/ems.py
@@ -1,19 +1,98 @@
# Author: Denis Engemann <denis.engemann at gmail.com>
# Alexandre Gramfort <alexandre.gramfort at telecom-paristech.fr>
+# Jean-Remi King <jeanremi.king at gmail.com>
#
# License: BSD (3-clause)
+from collections import Counter
+
import numpy as np
+from .mixin import TransformerMixin, EstimatorMixin
+from .base import _set_cv
from ..utils import logger, verbose
-from ..fixes import Counter
from ..parallel import parallel_func
from .. import pick_types, pick_info
+class EMS(TransformerMixin, EstimatorMixin):
+ """Transformer to compute event-matched spatial filters.
+
+ This version operates on the entire time course. The result is a spatial
+ filter at each time point and a corresponding time course. Intuitively,
+ the result gives the similarity between the filter at each time point and
+ the data vector (sensors) at that time point.
+
+ .. note:: EMS only works for binary classification.
+
+ References
+ ----------
+ [1] Aaron Schurger, Sebastien Marti, and Stanislas Dehaene, "Reducing
+ multi-sensor data to a single time course that reveals experimental
+ effects", BMC Neuroscience 2013, 14:122
+
+ Attributes
+ ----------
+ filters_ : ndarray, shape (n_channels, n_times)
+ The set of spatial filters.
+ classes_ : ndarray, shape (n_classes,)
+ The target classes.
+ """
+
+ def __repr__(self):
+ if hasattr(self, 'filters_'):
+ return '<EMS: fitted with %i filters on %i classes.>' % (
+ len(self.filters_), len(self.classes_))
+ else:
+ return '<EMS: not fitted.>'
+
+ def fit(self, X, y):
+ """Fit the spatial filters.
+
+ .. note:: EMS is fitted on data normalized by channel type before
+ the spatial filters are fitted.
+
+ Parameters
+ ----------
+ X : array, shape (n_epochs, n_channels, n_times)
+ The training data.
+ y : array of int, shape (n_epochs,)
+ The target classes.
+
+ Returns
+ -------
+ self : returns an instance of self.
+ """
+ classes = np.unique(y)
+ if len(classes) != 2:
+ raise ValueError('EMS only works for binary classification.')
+ self.classes_ = classes
+ filters = X[y == classes[0]].mean(0) - X[y == classes[1]].mean(0)
+ filters /= np.linalg.norm(filters, axis=0)[None, :]
+ self.filters_ = filters
+ return self
+
+ def transform(self, X):
+ """Transform the data by the spatial filters.
+
+ Parameters
+ ----------
+ X : array, shape (n_epochs, n_channels, n_times)
+ The input data.
+
+ Returns
+ -------
+ X : array, shape (n_epochs, n_times)
+ The input data transformed by the spatial filters.
+ """
+ Xt = np.sum(X * self.filters_, axis=1)
+ return Xt
+
+
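A hedged usage sketch of the new transformer on synthetic binary data; each epoch is reduced to a single time course by the per-time-point spatial filters:

    import numpy as np
    from mne.decoding import EMS

    X = np.random.randn(30, 8, 20)  # epochs x channels x times
    y = np.array([0, 1] * 15)       # EMS is binary only
    course = EMS().fit(X, y).transform(X)  # shape (30, 20)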
@verbose
-def compute_ems(epochs, conditions=None, picks=None, n_jobs=1, verbose=None):
- """Compute event-matched spatial filter on epochs
+def compute_ems(epochs, conditions=None, picks=None, n_jobs=1, verbose=None,
+ cv=None):
+ """Compute event-matched spatial filter on epochs.
This version operates on the entire time course. No time window needs to
be specified. The result is a spatial filter at each time point and a
@@ -21,6 +100,15 @@ def compute_ems(epochs, conditions=None, picks=None, n_jobs=1, verbose=None):
between the filter at each time point and the data vector (sensors) at
that time point.
+ .. note:: EMS only works for binary classification.
+ .. note:: The present function applies a leave-one-out cross-validation,
+ following Schurger et al.'s paper. However, we recommend using
+ a stratified k-fold cross-validation. Indeed, leave-one-out tends
+ to overfit and cannot be used to estimate the variance of the
+ prediction within a given fold.
+ .. note:: Because of the leave-one-out, this function needs an equal
+ number of epochs in each of the two conditions.
+
References
----------
[1] Aaron Schurger, Sebastien Marti, and Stanislas Dehaene, "Reducing
@@ -31,29 +119,33 @@ def compute_ems(epochs, conditions=None, picks=None, n_jobs=1, verbose=None):
----------
epochs : instance of mne.Epochs
The epochs.
- conditions : list of str | None
- If a list of strings, strings must match the
- epochs.event_id's key as well as the number of conditions supported
- by the objective_function. If None keys in epochs.event_id are used.
- picks : array-like of int | None
+ conditions : list of str | None, defaults to None
+ If a list of strings, strings must match the epochs.event_id's key as
+ well as the number of conditions supported by the objective_function.
+ If None keys in epochs.event_id are used.
+ picks : array-like of int | None, defaults to None
Channels to be included. If None only good data channels are used.
- Defaults to None
- n_jobs : int
+ n_jobs : int, defaults to 1
Number of jobs to run in parallel.
- verbose : bool, str, int, or None
+ verbose : bool, str, int, or None, defaults to self.verbose
If not None, override default verbose level (see mne.verbose).
- Defaults to self.verbose.
+ cv : cross-validation object | str | None, defaults to LeaveOneOut
+ The cross-validation scheme.
Returns
-------
- surrogate_trials : ndarray, shape (trials, n_trials, n_time_points)
+ surrogate_trials : ndarray, shape (n_trials // 2, n_times)
The trial surrogates.
mean_spatial_filter : ndarray, shape (n_channels, n_times)
The set of spatial filters.
- conditions : ndarray, shape (n_epochs,)
+ conditions : ndarray, shape (n_classes,)
The conditions used. Values correspond to original event ids.
"""
logger.info('...computing surrogate time series. This can take some time')
+
+ # Default to leave-one-out cv
+ cv = 'LeaveOneOut' if cv is None else cv
+
if picks is None:
picks = pick_types(epochs.info, meg=True, eeg=True)
@@ -76,7 +168,7 @@ def compute_ems(epochs, conditions=None, picks=None, n_jobs=1, verbose=None):
len(conditions))
ev = epochs.events[:, 2]
- # special care to avoid path dependent mappings and orders
+ # Special care to avoid path dependent mappings and orders
conditions = list(sorted(conditions))
cond_idx = [np.where(ev == epochs.event_id[k])[0] for k in conditions]
@@ -84,30 +176,28 @@ def compute_ems(epochs, conditions=None, picks=None, n_jobs=1, verbose=None):
data = epochs.get_data()[:, picks]
# Scale (z-score) the data by channel type
+ # XXX the z-scoring is applied outside the CV, which is not standard.
for ch_type in ['mag', 'grad', 'eeg']:
if ch_type in epochs:
+ # FIXME should be applied to all sort of data channels
if ch_type == 'eeg':
this_picks = pick_types(info, meg=False, eeg=True)
else:
this_picks = pick_types(info, meg=ch_type, eeg=False)
data[:, this_picks] /= np.std(data[:, this_picks])
- try:
- from sklearn.model_selection import LeaveOneOut
- except: # XXX support sklearn < 0.18
- from sklearn.cross_validation import LeaveOneOut
-
- def _iter_cv(n): # XXX support sklearn < 0.18
- if hasattr(LeaveOneOut, 'split'):
- cv = LeaveOneOut()
- return cv.split(np.zeros((n, 1)))
- else:
- cv = LeaveOneOut(len(data))
- return cv
+ # Setup cross-validation. Need to use _set_cv to deal with sklearn
+ # deprecation of cv objects.
+ y = epochs.events[:, 2]
+ _, cv_splits = _set_cv(cv, 'classifier', X=y, y=y)
parallel, p_func, _ = parallel_func(_run_ems, n_jobs=n_jobs)
+ # FIXME this parallelization should be removed:
+ # 1) it's NumPy computation, so it's already efficient,
+ # 2) it duplicates the data in RAM,
+ # 3) the computation is already super fast.
out = parallel(p_func(_ems_diff, data, cond_idx, train, test)
- for train, test in _iter_cv(len(data)))
+ for train, test in cv_splits)
surrogate_trials, spatial_filter = zip(*out)
surrogate_trials = np.array(surrogate_trials)
@@ -117,7 +207,8 @@ def compute_ems(epochs, conditions=None, picks=None, n_jobs=1, verbose=None):
def _ems_diff(data0, data1):
- """default diff objective function"""
+ """Aux. function to compute_ems that computes the default diff
+ objective function."""
return np.mean(data0, axis=0) - np.mean(data1, axis=0)
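
For orientation, here is a minimal usage sketch of the patched compute_ems
(not part of the patch; the event names 'aud_l' and 'vis_l' are assumptions
standing in for keys of epochs.event_id, and `epochs` is an mne.Epochs
instance):

    from mne.decoding import compute_ems

    # EMS requires an equal number of epochs per condition
    epochs.equalize_event_counts(epochs.event_id)

    # cv may be an int, a scikit-learn CV object, or None (leave-one-out)
    surrogates, filters, conditions = compute_ems(
        epochs, conditions=['aud_l', 'vis_l'], cv=2)
    # surrogates : surrogate time series, one row per held-out trial
    # filters    : (n_channels, n_times) mean spatial filter
    # conditions : original event ids of the two classes
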
diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py
new file mode 100644
index 0000000..6c9e301
--- /dev/null
+++ b/mne/decoding/search_light.py
@@ -0,0 +1,629 @@
+# Author: Jean-Remi King <jeanremi.king at gmail.com>
+#
+# License: BSD (3-clause)
+
+import numpy as np
+
+from .mixin import TransformerMixin
+from .base import BaseEstimator, _check_estimator
+from ..parallel import parallel_func
+
+
+class _SearchLight(BaseEstimator, TransformerMixin):
+ """Search Light.
+
+ Fit, predict and score a series of models to each subset of the dataset
+ along the last dimension.
+
+ Parameters
+ ----------
+ base_estimator : object
+ The base estimator to iteratively fit on a subset of the dataset.
+ scoring : callable | string | None, defaults to None
+ Score function (or loss function) with signature
+ score_func(y, y_pred, **kwargs).
+ n_jobs : int, optional (default=1)
+ The number of jobs to run in parallel for both `fit` and `predict`.
+ If -1, then the number of jobs is set to the number of cores.
+ """
+ def __repr__(self):
+ repr_str = '<' + super(_SearchLight, self).__repr__()
+ if hasattr(self, 'estimators_'):
+ repr_str = repr_str[:-1]
+ repr_str += ', fitted with %i estimators' % len(self.estimators_)
+ return repr_str + '>'
+
+ def __init__(self, base_estimator, scoring=None, n_jobs=1):
+ _check_estimator(base_estimator)
+ self.base_estimator = base_estimator
+ self.n_jobs = n_jobs
+ self.scoring = scoring
+
+ if not isinstance(self.n_jobs, int):
+ raise ValueError('n_jobs must be int, got %s' % n_jobs)
+
+ def fit_transform(self, X, y):
+ """Fit and transform a series of independent estimators to the dataset.
+
+ Parameters
+ ----------
+ X : array, shape (n_samples, nd_features, n_estimators)
+ The training input samples. For each data slice, a clone estimator
+ is fitted independently. The feature dimension can be
+ multidimensional e.g.
+ X.shape = (n_samples, n_features_1, n_features_2, n_estimators)
+ y : array, shape (n_samples,) | (n_samples, n_targets)
+ The target values.
+
+ Returns
+ -------
+ y_pred : array, shape (n_samples, n_estimators) | (n_samples, n_estimators, n_targets) # noqa
+ The predicted values for each estimator.
+ """
+ return self.fit(X, y).transform(X)
+
+ def fit(self, X, y):
+ """Fit a series of independent estimators to the dataset.
+
+ Parameters
+ ----------
+ X : array, shape (n_samples, nd_features, n_estimators)
+ The training input samples. For each data slice, a clone estimator
+ is fitted independently. The feature dimension can be
+ multidimensional e.g.
+ X.shape = (n_samples, n_features_1, n_features_2, n_estimators)
+ y : array, shape (n_samples,) | (n_samples, n_targets)
+ The target values.
+
+ Returns
+ -------
+ self : object
+ Return self.
+ """
+ self._check_Xy(X, y)
+ self.estimators_ = list()
+ # For fitting, the parallelization is across estimators.
+ parallel, p_func, n_jobs = parallel_func(_sl_fit, self.n_jobs)
+ n_jobs = min(n_jobs, X.shape[-1])
+ estimators = parallel(
+ p_func(self.base_estimator, split, y)
+ for split in np.array_split(X, n_jobs, axis=-1))
+ self.estimators_ = np.concatenate(estimators, 0)
+ return self
+
+ def _transform(self, X, method):
+ """Aux. function to make parallel predictions/transformation."""
+ self._check_Xy(X)
+ method = _check_method(self.base_estimator, method)
+ if X.shape[-1] != len(self.estimators_):
+ raise ValueError('The number of estimators does not match '
+ 'X.shape[-1]')
+ # For predictions/transforms the parallelization is across the data and
+ # not across the estimators to avoid memory load.
+ parallel, p_func, n_jobs = parallel_func(_sl_transform, self.n_jobs)
+ n_jobs = min(n_jobs, X.shape[-1])
+ X_splits = np.array_split(X, n_jobs, axis=-1)
+ est_splits = np.array_split(self.estimators_, n_jobs)
+ y_pred = parallel(p_func(est, x, method)
+ for (est, x) in zip(est_splits, X_splits))
+
+ if n_jobs > 1:
+ y_pred = np.concatenate(y_pred, axis=1)
+ else:
+ y_pred = y_pred[0]
+ return y_pred
+
+ def transform(self, X):
+ """Transform each data slice with a series of independent estimators.
+
+ Parameters
+ ----------
+ X : array, shape (n_samples, nd_features, n_estimators)
+ The input samples. For each data slice, the corresponding estimator
+ makes a transformation of the data:
+ e.g. [estimators[ii].transform(X[..., ii])
+ for ii in range(n_estimators)]
+ The feature dimension can be multidimensional e.g.
+ X.shape = (n_samples, n_features_1, n_features_2, n_estimators)
+
+ Returns
+ -------
+ Xt : array, shape (n_samples, n_estimators)
+ The transformed values generated by each estimator.
+ """
+ return self._transform(X, 'transform')
+
+ def predict(self, X):
+ """Predict each data slice with a series of independent estimators.
+
+ Parameters
+ ----------
+ X : array, shape (n_samples, nd_features, n_estimators)
+ The input samples. For each data slice, the corresponding estimator
+ makes the sample predictions:
+ e.g. [estimators[ii].predict(X[..., ii])
+ for ii in range(n_estimators)]
+ The feature dimension can be multidimensional e.g.
+ X.shape = (n_samples, n_features_1, n_features_2, n_estimators)
+
+ Returns
+ -------
+ y_pred : array, shape (n_samples, n_estimators) | (n_samples, n_estimators, n_targets) # noqa
+ Predicted values for each estimator/data slice.
+ """
+ return self._transform(X, 'predict')
+
+ def predict_proba(self, X):
+ """Predict each data slice with a series of independent estimators.
+
+ Parameters
+ ----------
+ X : array, shape (n_samples, nd_features, n_estimators)
+ The input samples. For each data slice, the corresponding estimator
+ makes the sample probabilistic predictions:
+ e.g. [estimators[ii].predict_proba(X[..., ii])
+ for ii in range(n_estimators)]
+ The feature dimension can be multidimensional e.g.
+ X.shape = (n_samples, n_features_1, n_features_2, n_estimators)
+
+
+ Returns
+ -------
+ y_pred : array, shape (n_samples, n_estimators, n_classes)
+ Predicted probabilities for each estimator/data slice.
+ """
+ return self._transform(X, 'predict_proba')
+
+ def decision_function(self, X):
+ """Estimate distances of each data slice to the hyperplanes.
+
+ Parameters
+ ----------
+ X : array, shape (n_samples, nd_features, n_estimators)
+ The input samples. For each data slice, the corresponding estimator
+ outputs the distance to the hyperplane:
+ e.g. [estimators[ii].decision_function(X[..., ii])
+ for ii in range(n_estimators)]
+ The feature dimension can be multidimensional e.g.
+ X.shape = (n_samples, n_features_1, n_features_2, n_estimators)
+
+ Returns
+ -------
+ y_pred : array, shape (n_samples, n_estimators, n_classes * (n_classes-1) / 2) # noqa
+ Predicted distances for each estimator/data slice.
+
+ Notes
+ -----
+ This requires base_estimator to have a `decision_function` method.
+ """
+ return self._transform(X, 'decision_function')
+
+ def _check_Xy(self, X, y=None):
+ """Aux. function to check input data."""
+ if y is not None:
+ if len(X) != len(y) or len(y) < 1:
+ raise ValueError('X and y must have the same length.')
+ if X.ndim < 3:
+ raise ValueError('X must have at least 3 dimensions.')
+
+ def score(self, X, y):
+ """Returns the score obtained for each estimators/data slice couple.
+
+ Parameters
+ ----------
+ X : array, shape (n_samples, nd_features, n_estimators)
+ The input samples. For each data slice, the corresponding estimator
+ scores the prediction: e.g. [estimators[ii].score(X[..., ii], y)
+ for ii in range(n_estimators)]
+ The feature dimension can be multidimensional e.g.
+ X.shape = (n_samples, n_features_1, n_features_2, n_estimators)
+
+ y : array, shape (n_samples,) | (n_samples, n_targets)
+ The target values.
+
+ Returns
+ -------
+ score : array, shape (n_estimators,)
+ Score for each estimator / data slice couple.
+ """
+ from sklearn.metrics import make_scorer, get_scorer
+ self._check_Xy(X)
+ if X.shape[-1] != len(self.estimators_):
+ raise ValueError('The number of estimators does not match '
+ 'X.shape[-1]')
+
+ # If scoring is None (default), the predictions are internally
+ # generated by estimator.score(). Else, we must first get the
+ # predictions based on the scorer.
+ if not isinstance(self.scoring, str):
+ scoring_ = (make_scorer(self.scoring) if self.scoring is
+ not None else self.scoring)
+
+ elif self.scoring is not None:
+ scoring_ = get_scorer(self.scoring)
+
+ # For predictions/transforms the parallelization is across the data and
+ # not across the estimators to avoid memory load.
+ parallel, p_func, n_jobs = parallel_func(_sl_score, self.n_jobs)
+ n_jobs = min(n_jobs, X.shape[-1])
+ X_splits = np.array_split(X, n_jobs, axis=-1)
+ est_splits = np.array_split(self.estimators_, n_jobs)
+ score = parallel(p_func(est, scoring_, x, y)
+ for (est, x) in zip(est_splits, X_splits))
+
+ if n_jobs > 1:
+ score = np.concatenate(score, axis=0)
+ else:
+ score = score[0]
+ return score
+
+
+def _sl_fit(estimator, X, y):
+ """Aux. function to fit _SearchLight in parallel.
+
+ Fit a clone estimator to each slice of data.
+
+ Parameters
+ ----------
+ base_estimator : object
+ The base estimator to iteratively fit on a subset of the dataset.
+ X : array, shape (n_samples, nd_features, n_estimators)
+ The target data. The feature dimension can be multidimensional e.g.
+ X.shape = (n_samples, n_features_1, n_features_2, n_estimators)
+ y : array, shape (n_sample, )
+ The target values.
+
+ Returns
+ -------
+ estimators_ : list of estimators
+ The fitted estimators.
+ """
+ from sklearn.base import clone
+ estimators_ = list()
+ for ii in range(X.shape[-1]):
+ est = clone(estimator)
+ est.fit(X[..., ii], y)
+ estimators_.append(est)
+ return estimators_
+
+
+def _sl_transform(estimators, X, method):
+ """Aux. function to transform _SearchLight in parallel.
+
+ Applies transform/predict/decision_function etc for each slice of data.
+
+ Parameters
+ ----------
+ estimators : list of estimators
+ The fitted estimators.
+ X : array, shape (n_samples, nd_features, n_estimators)
+ The target data. The feature dimension can be multidimensional e.g.
+ X.shape = (n_samples, n_features_1, n_features_2, n_estimators)
+ method : str
+ The estimator method to use (e.g. 'predict', 'transform').
+
+ Returns
+ -------
+ y_pred : array, shape (n_samples, n_estimators, n_classes * (n_classes-1) / 2) # noqa
+ The transformations for each slice of data.
+ """
+ for ii, est in enumerate(estimators):
+ transform = getattr(est, method)
+ _y_pred = transform(X[..., ii])
+ # Initialize array of predictions on the first transform iteration
+ if ii == 0:
+ y_pred = _sl_init_pred(_y_pred, X)
+ y_pred[:, ii, ...] = _y_pred
+ return y_pred
+
+
+def _sl_init_pred(y_pred, X):
+ """Aux. function to _SearchLight to initialize y_pred."""
+ n_sample, n_iter = X.shape[0], X.shape[-1]
+ if y_pred.ndim > 1:
+ # for estimators that generate multidimensional y_pred,
+ # e.g. clf.predict_proba()
+ y_pred = np.zeros(np.r_[n_sample, n_iter, y_pred.shape[1:]],
+ y_pred.dtype)
+ else:
+ # for estimators that generate unidimensional y_pred,
+ # e.g. clf.predict()
+ y_pred = np.zeros((n_sample, n_iter), y_pred.dtype)
+ return y_pred
+
+
+def _sl_score(estimators, scoring, X, y):
+ """Aux. function to score _SearchLight in parallel.
+
+ Predict and score each slice of data.
+
+ Parameters
+ ----------
+ estimators : list of estimators
+ The fitted estimators.
+ X : array, shape (n_samples, nd_features, n_estimators)
+ The target data. The feature dimension can be multidimensional e.g.
+ X.shape = (n_samples, n_features_1, n_features_2, n_estimators)
+ scoring : callable, string or None
+ If scoring is None (default), the predictions are internally
+ generated by estimator.score(). Else, we must first get the
+ predictions to pass them to ad-hoc scorer.
+ y : array, shape (n_samples,) | (n_samples, n_targets)
+ The target values.
+
+ Returns
+ -------
+ score : array, shape (n_estimators,)
+ The score for each slice of data.
+ """
+ n_iter = X.shape[-1]
+ for ii, est in enumerate(estimators):
+ if scoring is not None:
+ _score = scoring(est, X[..., ii], y)
+ else:
+ _score = est.score(X[..., ii], y)
+ # Initialize array of scores on the first score iteration
+ if ii == 0:
+ if isinstance(_score, np.ndarray):
+ dtype = _score.dtype
+ shape = np.r_[n_iter, _score.shape]
+ else:
+ dtype = type(_score)
+ shape = n_iter
+ score = np.zeros(shape, dtype)
+ score[ii] = _score
+ return score
+
+
+def _check_method(estimator, method):
+ """Check that an estimator has the requested method.
+
+ If method == 'transform' and the estimator does not have 'transform',
+ use 'predict' instead.
+ """
+ if method == 'transform' and not hasattr(estimator, 'transform'):
+ method = 'predict'
+ if not hasattr(estimator, method):
+ raise ValueError('base_estimator does not have `%s` method.' % method)
+ return method
+
+
+class _GeneralizationLight(_SearchLight):
+ """Generalization Light
+
+ Fit a search-light along the last dimension and use them to apply a
+ systematic cross-feature generalization.
+
+ Parameters
+ ----------
+ base_estimator : object
+ The base estimator to iteratively fit on a subset of the dataset.
+ scoring : callable | string | None
+ Score function (or loss function) with signature
+ score_func(y, y_pred, **kwargs).
+ n_jobs : int, optional (default=1)
+ The number of jobs to run in parallel for both `fit` and `predict`.
+ If -1, then the number of jobs is set to the number of cores.
+ """
+ def __repr__(self):
+ repr_str = super(_GeneralizationLight, self).__repr__()
+ if hasattr(self, 'estimators_'):
+ repr_str = repr_str[:-1]
+ repr_str += ', fitted with %i estimators>' % len(self.estimators_)
+ return repr_str
+
+ def _transform(self, X, method):
+ """Aux. function to make parallel predictions/transformation"""
+ self._check_Xy(X)
+ method = _check_method(self.base_estimator, method)
+ parallel, p_func, n_jobs = parallel_func(_gl_transform, self.n_jobs)
+ n_jobs = min(n_jobs, X.shape[-1])
+ y_pred = parallel(
+ p_func(self.estimators_, x_split, method)
+ for x_split in np.array_split(X, n_jobs, axis=-1))
+
+ y_pred = np.concatenate(y_pred, axis=2)
+ return y_pred
+
+ def transform(self, X):
+ """Transform each data slice with all possible estimators.
+
+ Parameters
+ ----------
+ X : array, shape (n_samples, nd_features, n_slices)
+ The input samples. For each estimator, the corresponding data slice
+ is used to make a transformation. The feature dimension can be
+ multidimensional e.g.
+ X.shape = (n_samples, n_features_1, n_features_2, n_estimators)
+
+ Returns
+ -------
+ Xt : array, shape (n_samples, n_estimators, n_slices)
+ The transformed values generated by each estimator.
+ """
+ return self._transform(X, 'transform')
+
+ def predict(self, X):
+ """Predict each data slice with all possible estimators.
+
+ Parameters
+ ----------
+ X : array, shape (n_samples, nd_features, n_slices)
+ The training input samples. Each fitted estimator predicts every
+ slice of the data independently. The feature
+ dimension can be multidimensional e.g.
+ X.shape = (n_samples, n_features_1, n_features_2, n_estimators)
+
+ Returns
+ -------
+ y_pred : array, shape (n_samples, n_estimators, n_slices) | (n_samples, n_estimators, n_slices, n_targets) # noqa
+ The predicted values for each estimator.
+ """
+ return self._transform(X, 'predict')
+
+ def predict_proba(self, X):
+ """Estimate probabilistic estimates of each data slice with all
+ possible estimators.
+
+ Parameters
+ ----------
+ X : array, shape (n_samples, nd_features, n_slices)
+ The training input samples. For each data slice, a fitted estimator
+ predicts a slice of the data. The feature dimension can be
+ multidimensional e.g.
+ X.shape = (n_samples, n_features_1, n_features_2, n_estimators)
+
+ Returns
+ -------
+ y_pred : array, shape (n_samples, n_estimators, n_slices, n_classes)
+ The predicted values for each estimator.
+
+ Notes
+ -----
+ This requires base_estimator to have a `predict_proba` method.
+ """
+ return self._transform(X, 'predict_proba')
+
+ def decision_function(self, X):
+ """Estimate distances of each data slice to all hyperplanes.
+
+ Parameters
+ ----------
+ X : array, shape (n_samples, nd_features, n_slices)
+ The training input samples. Each estimator outputs the distance to
+ its hyperplane: e.g. [estimators[ii].decision_function(X[..., ii])
+ for ii in range(n_estimators)]
+ The feature dimension can be multidimensional e.g.
+ X.shape = (n_samples, n_features_1, n_features_2, n_estimators)
+
+ Returns
+ -------
+ y_pred : array, shape (n_samples, n_estimators, n_slices, n_classes * (n_classes-1) / 2) # noqa
+ The predicted values for each estimator.
+
+ Notes
+ -----
+ This requires base_estimator to have a `decision_function` method.
+ """
+ return self._transform(X, 'decision_function')
+
+ def score(self, X, y):
+ """Score each of the estimators on the tested dimensions.
+
+ Parameters
+ ----------
+ X : array, shape (n_samples, nd_features, n_slices)
+ The input samples. For each data slice, the corresponding estimator
+ scores the prediction: e.g. [estimators[ii].score(X[..., ii], y)
+ for ii in range(n_slices)]
+ The feature dimension can be multidimensional e.g.
+ X.shape = (n_samples, n_features_1, n_features_2, n_estimators)
+ y : array, shape (n_samples,) | (n_samples, n_targets)
+ The target values.
+
+ Returns
+ -------
+ score : array, shape (n_estimators, n_slices)
+ Score for each estimator / data slice couple.
+ """
+ self._check_Xy(X)
+ # For predictions/transforms the parallelization is across the data and
+ # not across the estimators to avoid memory load.
+ parallel, p_func, n_jobs = parallel_func(_gl_score, self.n_jobs)
+ n_jobs = min(n_jobs, X.shape[-1])
+ X_splits = np.array_split(X, n_jobs, axis=-1)
+ score = parallel(p_func(self.estimators_, x, y) for x in X_splits)
+
+ if n_jobs > 1:
+ score = np.concatenate(score, axis=1)
+ else:
+ score = score[0]
+ return score
+
+
+def _gl_transform(estimators, X, method):
+ """Transform the dataset by applying each estimator to all slices of
+ the data.
+
+ Parameters
+ ----------
+ X : array, shape (n_samples, nd_features, n_slices)
+ The training input samples. For each data slice, a clone estimator
+ is fitted independently. The feature dimension can be multidimensional
+ e.g. X.shape = (n_samples, n_features_1, n_features_2, n_estimators)
+
+ Returns
+ -------
+ Xt : array, shape (n_samples, n_estimators, n_slices)
+ The transformed values generated by each estimator.
+ """
+ n_sample, n_iter = X.shape[0], X.shape[-1]
+ for ii, est in enumerate(estimators):
+ # stack generalized data for faster prediction
+ X_stack = X.transpose(np.r_[0, X.ndim - 1, range(1, X.ndim - 1)])
+ X_stack = X_stack.reshape(np.r_[n_sample * n_iter, X_stack.shape[2:]])
+ transform = getattr(est, method)
+ _y_pred = transform(X_stack)
+ # unstack generalizations
+ if _y_pred.ndim == 2:
+ _y_pred = np.reshape(_y_pred, [n_sample, n_iter, _y_pred.shape[1]])
+ else:
+ shape = np.r_[n_sample, n_iter, _y_pred.shape[1:]].astype(int)
+ _y_pred = np.reshape(_y_pred, shape)
+ # Initialize array of predictions on the first transform iteration
+ if ii == 0:
+ y_pred = _gl_init_pred(_y_pred, X, len(estimators))
+ y_pred[:, ii, ...] = _y_pred
+ return y_pred
+
+
+def _gl_init_pred(y_pred, X, n_train):
+ """Aux. function to _GeneralizationLight to initialize y_pred"""
+ n_sample, n_iter = X.shape[0], X.shape[-1]
+ if y_pred.ndim == 3:
+ y_pred = np.zeros((n_sample, n_train, n_iter, y_pred.shape[-1]),
+ y_pred.dtype)
+ else:
+ y_pred = np.zeros((n_sample, n_train, n_iter), y_pred.dtype)
+ return y_pred
+
+
+def _gl_score(estimators, X, y):
+ """Aux. function to score _GeneralizationLight in parallel.
+ Predict and score each slice of data.
+
+ Parameters
+ ----------
+ estimators : list of estimators
+ The fitted estimators.
+ X : array, shape (n_samples, nd_features, n_slices)
+ The target data. The feature dimension can be multidimensional e.g.
+ X.shape = (n_samples, n_features_1, n_features_2, n_estimators)
+ y : array, shape (n_samples,) | (n_samples, n_targets)
+ The target values.
+
+ Returns
+ -------
+ score : array, shape (n_estimators, n_slices)
+ The score for each slice of data.
+ """
+ # FIXME: this level of parallelization may be too high and memory
+ # consuming. Perhaps it should be pushed down to the loop across X slices.
+ n_iter = X.shape[-1]
+ n_est = len(estimators)
+ for ii, est in enumerate(estimators):
+ for jj in range(X.shape[-1]):
+ _score = est.score(X[..., jj], y)
+
+ # Initialize array of predictions on the first score iteration
+ if (ii == 0) & (jj == 0):
+ if isinstance(_score, np.ndarray):
+ dtype = _score.dtype
+ shape = np.r_[n_est, n_iter, _score.shape]
+ else:
+ dtype = type(_score)
+ shape = [n_est, n_iter]
+ score = np.zeros(shape, dtype)
+ score[ii, jj, ...] = _score
+ return score
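
As a quick sketch of the two estimators defined above (not part of the
patch), on synthetic data of shape (n_epochs, n_channels, n_times):

    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from mne.decoding.search_light import _SearchLight, _GeneralizationLight

    rng = np.random.RandomState(0)
    X = rng.randn(50, 32, 10)    # 50 epochs, 32 channels, 10 time points
    y = np.arange(50) % 2        # binary labels

    sl = _SearchLight(LogisticRegression())
    sl.fit(X, y)                 # one clone estimator per time point
    scores = sl.score(X, y)      # shape (n_times,)

    gl = _GeneralizationLight(LogisticRegression())
    gl.fit(X, y)                 # train at each time point...
    tgm = gl.score(X, y)         # ...test at every other: (n_times, n_times)
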
diff --git a/mne/decoding/tests/test_csp.py b/mne/decoding/tests/test_csp.py
index 09f80a5..d0ca998 100644
--- a/mne/decoding/tests/test_csp.py
+++ b/mne/decoding/tests/test_csp.py
@@ -5,12 +5,12 @@
import os.path as op
-from nose.tools import assert_true, assert_raises
+from nose.tools import assert_true, assert_raises, assert_equal
import numpy as np
-from numpy.testing import assert_array_almost_equal
+from numpy.testing import assert_array_almost_equal, assert_array_equal
from mne import io, Epochs, read_events, pick_types
-from mne.decoding.csp import CSP
+from mne.decoding.csp import CSP, _ajd_pham
from mne.utils import requires_sklearn, slow_test
data_dir = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data')
@@ -27,74 +27,125 @@ start, stop = 0, 8
def test_csp():
"""Test Common Spatial Patterns algorithm on epochs
"""
- raw = io.read_raw_fif(raw_fname, preload=False)
+ raw = io.read_raw_fif(raw_fname, preload=False, add_eeg_ref=False)
events = read_events(event_name)
picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
eog=False, exclude='bads')
- picks = picks[2:9:3] # subselect channels -> disable proj!
+ picks = picks[2:12:3] # subselect channels -> disable proj!
raw.add_proj([], remove_existing=True)
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- baseline=(None, 0), preload=True, proj=False)
+ baseline=(None, 0), preload=True, proj=False,
+ add_eeg_ref=False)
epochs_data = epochs.get_data()
n_channels = epochs_data.shape[1]
+ y = epochs.events[:, -1]
+
+ # Init
+ assert_raises(ValueError, CSP, n_components='foo')
+ for reg in ['foo', -0.1, 1.1]:
+ assert_raises(ValueError, CSP, reg=reg)
+ for reg in ['oas', 'ledoit_wolf', 0, 0.5, 1.]:
+ CSP(reg=reg)
+ for cov_est in ['foo', None]:
+ assert_raises(ValueError, CSP, cov_est=cov_est)
+ for cov_est in ['concat', 'epoch']:
+ CSP(cov_est=cov_est)
n_components = 3
csp = CSP(n_components=n_components)
+ # Fit
csp.fit(epochs_data, epochs.events[:, -1])
+ assert_equal(len(csp.mean_), n_components)
+ assert_equal(len(csp.std_), n_components)
- y = epochs.events[:, -1]
+ # Transform
X = csp.fit_transform(epochs_data, y)
+ sources = csp.transform(epochs_data)
+ assert_true(sources.shape[1] == n_components)
assert_true(csp.filters_.shape == (n_channels, n_channels))
assert_true(csp.patterns_.shape == (n_channels, n_channels))
- assert_array_almost_equal(csp.fit(epochs_data, y).transform(epochs_data),
- X)
+ assert_array_almost_equal(sources, X)
- # test init exception
+ # Test data exception
assert_raises(ValueError, csp.fit, epochs_data,
np.zeros_like(epochs.events))
assert_raises(ValueError, csp.fit, epochs, y)
- assert_raises(ValueError, csp.transform, epochs, y)
-
- csp.n_components = n_components
- sources = csp.transform(epochs_data)
- assert_true(sources.shape[1] == n_components)
+ assert_raises(ValueError, csp.transform, epochs)
+ # Test plots
epochs.pick_types(meg='mag')
-
- # test plot patterns
+ cmap = ('RdBu', True)
components = np.arange(n_components)
- csp.plot_patterns(epochs.info, components=components, res=12,
- show=False)
+ for plot in (csp.plot_patterns, csp.plot_filters):
+ plot(epochs.info, components=components, res=12, show=False, cmap=cmap)
- # test plot filters
- csp.plot_filters(epochs.info, components=components, res=12,
- show=False)
-
- # test covariance estimation methods (results should be roughly equal)
+ # Test covariance estimation methods (results should be roughly equal)
+ np.random.seed(0)
csp_epochs = CSP(cov_est="epoch")
csp_epochs.fit(epochs_data, y)
for attr in ('filters_', 'patterns_'):
corr = np.corrcoef(getattr(csp, attr).ravel(),
getattr(csp_epochs, attr).ravel())[0, 1]
- assert_true(corr >= 0.95, msg='%s < 0.95' % corr)
+ assert_true(corr >= 0.94)
+
+ # Test with more than 2 classes
+ epochs = Epochs(raw, events, tmin=tmin, tmax=tmax, picks=picks,
+ event_id=dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4),
+ baseline=(None, 0), proj=False, preload=True,
+ add_eeg_ref=False)
+ epochs_data = epochs.get_data()
+ n_channels = epochs_data.shape[1]
- # make sure error is raised for undefined estimation method
- csp_fail = CSP(cov_est="undefined")
- assert_raises(ValueError, csp_fail.fit, epochs_data, y)
+ for cov_est in ['concat', 'epoch']:
+ csp = CSP(n_components=n_components, cov_est=cov_est)
+ csp.fit(epochs_data, epochs.events[:, 2]).transform(epochs_data)
+ assert_equal(len(csp._classes), 4)
+ assert_array_equal(csp.filters_.shape, [n_channels, n_channels])
+ assert_array_equal(csp.patterns_.shape, [n_channels, n_channels])
+
+ # Test average power transform
+ n_components = 2
+ assert_true(csp.transform_into == 'average_power')
+ feature_shape = [len(epochs_data), n_components]
+ X_trans = dict()
+ for log in (None, True, False):
+ csp = CSP(n_components=n_components, log=log)
+ assert_true(csp.log is log)
+ Xt = csp.fit_transform(epochs_data, epochs.events[:, 2])
+ assert_array_equal(Xt.shape, feature_shape)
+ X_trans[str(log)] = Xt
+ # log=None => log=True
+ assert_array_almost_equal(X_trans['None'], X_trans['True'])
+ # Different normalization return different transform
+ assert_true(np.sum((X_trans['True'] - X_trans['False']) ** 2) > 1.)
+ # Check wrong inputs
+ assert_raises(ValueError, CSP, transform_into='average_power', log='foo')
+
+ # Test csp space transform
+ csp = CSP(transform_into='csp_space')
+ assert_true(csp.transform_into == 'csp_space')
+ for log in ('foo', True, False):
+ assert_raises(ValueError, CSP, transform_into='csp_space', log=log)
+ n_components = 2
+ csp = CSP(n_components=n_components, transform_into='csp_space')
+ Xt = csp.fit(epochs_data, epochs.events[:, 2]).transform(epochs_data)
+ feature_shape = [len(epochs_data), n_components, epochs_data.shape[2]]
+ assert_array_equal(Xt.shape, feature_shape)
@requires_sklearn
def test_regularized_csp():
"""Test Common Spatial Patterns algorithm using regularized covariance
"""
- raw = io.read_raw_fif(raw_fname, preload=False)
+ raw = io.read_raw_fif(raw_fname, preload=False, add_eeg_ref=False)
events = read_events(event_name)
picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
eog=False, exclude='bads')
picks = picks[1:13:3]
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- baseline=(None, 0), preload=True)
+ baseline=(None, 0), preload=True, add_eeg_ref=False)
epochs_data = epochs.get_data()
n_channels = epochs_data.shape[1]
@@ -114,7 +165,7 @@ def test_regularized_csp():
assert_raises(ValueError, csp.fit, epochs_data,
np.zeros_like(epochs.events))
assert_raises(ValueError, csp.fit, epochs, y)
- assert_raises(ValueError, csp.transform, epochs, y)
+ assert_raises(ValueError, csp.transform, epochs)
csp.n_components = n_components
sources = csp.transform(epochs_data)
@@ -132,3 +183,24 @@ def test_csp_pipeline():
pipe = Pipeline([("CSP", csp), ("SVC", svc)])
pipe.set_params(CSP__reg=0.2)
assert_true(pipe.get_params()["CSP__reg"] == 0.2)
+
+
+def test_ajd():
+ """Test if Approximate joint diagonalization implementation obtains same
+ results as the Matlab implementation by Pham Dinh-Tuan.
+ """
+ # Generate a set of cavariances matrices for test purpose
+ n_times, n_channels = 10, 3
+ seed = np.random.RandomState(0)
+ diags = 2.0 + 0.1 * seed.randn(n_times, n_channels)
+ A = 2 * seed.rand(n_channels, n_channels) - 1
+ A /= np.atleast_2d(np.sqrt(np.sum(A ** 2, 1))).T
+ covmats = np.empty((n_times, n_channels, n_channels))
+ for i in range(n_times):
+ covmats[i] = np.dot(np.dot(A, np.diag(diags[i])), A.T)
+ V, D = _ajd_pham(covmats)
+ # Results obtained with original matlab implementation
+ V_matlab = [[-3.507280775058041, -5.498189967306344, 7.720624541198574],
+ [0.694689013234610, 0.775690358505945, -1.162043086446043],
+ [-0.592603135588066, -0.598996925696260, 1.009550086271192]]
+ assert_array_almost_equal(V, V_matlab)
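
As a quick sketch of the CSP behavior covered by this test (not part of the
patch), using synthetic data:

    import numpy as np
    from mne.decoding import CSP

    rng = np.random.RandomState(0)
    X = rng.randn(40, 8, 100)    # (n_epochs, n_channels, n_times)
    y = np.arange(40) % 2        # two classes

    # default: log of the average band power per CSP component
    csp = CSP(n_components=3)
    features = csp.fit_transform(X, y)          # shape (40, 3)

    # keep the full time courses in CSP space instead
    csp_space = CSP(n_components=3, transform_into='csp_space')
    sources = csp_space.fit(X, y).transform(X)  # shape (40, 3, 100)
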
diff --git a/mne/decoding/tests/test_ems.py b/mne/decoding/tests/test_ems.py
index 6a05618..fa75aba 100644
--- a/mne/decoding/tests/test_ems.py
+++ b/mne/decoding/tests/test_ems.py
@@ -3,12 +3,13 @@
# License: BSD (3-clause)
import os.path as op
-
+import numpy as np
+from numpy.testing import assert_array_almost_equal
from nose.tools import assert_equal, assert_raises
from mne import io, Epochs, read_events, pick_types
-from mne.utils import requires_sklearn
-from mne.decoding import compute_ems
+from mne.utils import requires_sklearn_0_15, check_version
+from mne.decoding import compute_ems, EMS
data_dir = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data')
curdir = op.join(op.dirname(__file__))
@@ -20,10 +21,10 @@ tmin, tmax = -0.2, 0.5
event_id = dict(aud_l=1, vis_l=3)
- at requires_sklearn
+ at requires_sklearn_0_15
def test_ems():
"""Test event-matched spatial filters"""
- raw = io.read_raw_fif(raw_fname, preload=False)
+ raw = io.read_raw_fif(raw_fname, preload=False, add_eeg_ref=False)
# create unequal number of events
events = read_events(event_name)
@@ -32,7 +33,7 @@ def test_ems():
eog=False, exclude='bads')
picks = picks[1:13:3]
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- baseline=(None, 0), preload=True)
+ baseline=(None, 0), preload=True, add_eeg_ref=False)
assert_raises(ValueError, compute_ems, epochs, ['aud_l', 'vis_l'])
epochs = epochs.equalize_event_counts(epochs.event_id, copy=False)[0]
@@ -43,7 +44,7 @@ def test_ems():
events = read_events(event_name)
event_id2 = dict(aud_l=1, aud_r=2, vis_l=3)
epochs = Epochs(raw, events, event_id2, tmin, tmax, picks=picks,
- baseline=(None, 0), preload=True)
+ baseline=(None, 0), preload=True, add_eeg_ref=False)
epochs = epochs.equalize_event_counts(epochs.event_id, copy=False)[0]
n_expected = sum([len(epochs[k]) for k in ['aud_l', 'vis_l']])
@@ -53,4 +54,35 @@ def test_ems():
assert_equal(n_expected, len(surrogates))
assert_equal(n_expected, len(conditions))
assert_equal(list(set(conditions)), [2, 3])
+
+ # test compute_ems cv
+ epochs = epochs['aud_r', 'vis_l']
+ epochs.equalize_event_counts(epochs.event_id)
+ if check_version('sklearn', '0.18'):
+ from sklearn.model_selection import StratifiedKFold
+ cv = StratifiedKFold()
+ else:
+ from sklearn.cross_validation import StratifiedKFold
+ cv = StratifiedKFold(epochs.events[:, 2])
+ compute_ems(epochs, cv=cv)
+ compute_ems(epochs, cv=2)
+ assert_raises(ValueError, compute_ems, epochs, cv='foo')
+ assert_raises(ValueError, compute_ems, epochs, cv=len(epochs) + 1)
raw.close()
+
+ # EMS transformer, check that identical to compute_ems
+ X = epochs.get_data()
+ y = epochs.events[:, 2]
+ X = X / np.std(X) # X scaled outside cv in compute_ems
+ Xt, coefs = list(), list()
+ ems = EMS()
+ assert_equal(ems.__repr__(), '<EMS: not fitted.>')
+ # manual leave-one-out to avoid sklearn version problem
+ for test in range(len(y)):
+ train = np.setdiff1d(range(len(y)), test)
+ ems.fit(X[train], y[train])
+ coefs.append(ems.filters_)
+ Xt.append(ems.transform(X[[test]]))
+ assert_equal(ems.__repr__(), '<EMS: fitted with 4 filters on 2 classes.>')
+ assert_array_almost_equal(filters, np.mean(coefs, axis=0))
+ assert_array_almost_equal(surrogates, np.vstack(Xt))
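
The new EMS transformer mirrors compute_ems but exposes fit/transform so it
can be embedded in arbitrary cross-validation schemes. A sketch (not part of
the patch), assuming X has shape (n_epochs, n_channels, n_times) and y holds
binary labels:

    from mne.decoding import EMS

    ems = EMS()
    ems.fit(X[:-1], y[:-1])              # fit on all trials but the last
    filters = ems.filters_               # (n_channels, n_times) filters
    surrogate = ems.transform(X[[-1]])   # surrogate for the held-out trial
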
diff --git a/mne/decoding/tests/test_search_light.py b/mne/decoding/tests/test_search_light.py
new file mode 100644
index 0000000..4905c7f
--- /dev/null
+++ b/mne/decoding/tests/test_search_light.py
@@ -0,0 +1,170 @@
+# Author: Jean-Remi King, <jeanremi.king at gmail.com>
+#
+# License: BSD (3-clause)
+
+
+import numpy as np
+from numpy.testing import assert_array_equal
+from nose.tools import assert_raises, assert_true, assert_equal
+from ...utils import requires_sklearn_0_15
+from ..search_light import _SearchLight, _GeneralizationLight
+from .. import Vectorizer
+
+
+def make_data():
+ n_epochs, n_chan, n_time = 50, 32, 10
+ X = np.random.rand(n_epochs, n_chan, n_time)
+ y = np.arange(n_epochs) % 2
+ for ii in range(n_time):
+ coef = np.random.randn(n_chan)
+ X[y == 0, :, ii] += coef
+ X[y == 1, :, ii] -= coef
+ return X, y
+
+
+ at requires_sklearn_0_15
+def test_SearchLight():
+ """Test _SearchLight"""
+ from sklearn.linear_model import Ridge, LogisticRegression
+ from sklearn.pipeline import make_pipeline
+ from sklearn.metrics import roc_auc_score
+
+ X, y = make_data()
+ n_epochs, _, n_time = X.shape
+ # init
+ assert_raises(ValueError, _SearchLight, 'foo')
+ sl = _SearchLight(Ridge())
+ sl = _SearchLight(LogisticRegression())
+ # fit
+ assert_equal(sl.__repr__()[:14], '<_SearchLight(')
+ sl.fit(X, y)
+ assert_equal(sl.__repr__()[-28:], ', fitted with 10 estimators>')
+ assert_raises(ValueError, sl.fit, X[1:], y)
+ assert_raises(ValueError, sl.fit, X[:, :, 0], y)
+
+ # transforms
+ assert_raises(ValueError, sl.predict, X[:, :, :2])
+ y_pred = sl.predict(X)
+ assert_true(y_pred.dtype == int)
+ assert_array_equal(y_pred.shape, [n_epochs, n_time])
+ y_proba = sl.predict_proba(X)
+ assert_true(y_proba.dtype == float)
+ assert_array_equal(y_proba.shape, [n_epochs, n_time, 2])
+
+ # score
+ score = sl.score(X, y)
+ assert_array_equal(score.shape, [n_time])
+ assert_true(np.sum(np.abs(score)) != 0)
+ assert_true(score.dtype == float)
+
+ # change score method
+ sl1 = _SearchLight(LogisticRegression(), scoring=roc_auc_score)
+ sl1.fit(X, y)
+ score1 = sl1.score(X, y)
+ assert_array_equal(score1.shape, [n_time])
+ assert_true(score1.dtype == float)
+
+ X_2d = X.reshape(X.shape[0], X.shape[1] * X.shape[2])
+ lg_score = LogisticRegression().fit(X_2d, y).predict_proba(X_2d)[:, 1]
+ assert_equal(score1[0], roc_auc_score(y, lg_score))
+
+ sl2 = _SearchLight(LogisticRegression(), scoring='roc_auc')
+ sl2.fit(X, y)
+ assert_array_equal(score1, sl2.score(X, y))
+
+ sl = _SearchLight(LogisticRegression(), scoring='foo')
+ sl.fit(X, y)
+ assert_raises(ValueError, sl.score, X, y)
+
+ sl = _SearchLight(LogisticRegression())
+ assert_equal(sl.scoring, None)
+
+ # n_jobs
+ sl = _SearchLight(LogisticRegression(), n_jobs=2)
+ sl.fit(X, y)
+ sl.predict(X)
+ sl.score(X, y)
+
+ # n_jobs > n_estimators
+ sl.fit(X[..., [0]], y)
+ sl.predict(X[..., [0]])
+
+ # pipeline
+
+ class _LogRegTransformer(LogisticRegression):
+ # XXX needs transformer in pipeline to get first proba only
+ def transform(self, X):
+ return super(_LogRegTransformer, self).predict_proba(X)[..., 1]
+
+ pipe = make_pipeline(_SearchLight(_LogRegTransformer()),
+ LogisticRegression())
+ pipe.fit(X, y)
+ pipe.predict(X)
+
+ # n-dimensional feature space
+ X = np.random.rand(10, 3, 4, 2)
+ y = np.arange(10) % 2
+ y_preds = list()
+ for n_jobs in [1, 2]:
+ pipe = _SearchLight(make_pipeline(Vectorizer(), LogisticRegression()),
+ n_jobs=n_jobs)
+ y_preds.append(pipe.fit(X, y).predict(X))
+ features_shape = pipe.estimators_[0].steps[0][1].features_shape_
+ assert_array_equal(features_shape, [3, 4])
+ assert_array_equal(y_preds[0], y_preds[1])
+
+
+ at requires_sklearn_0_15
+def test_GeneralizationLight():
+ """Test _GeneralizationLight"""
+ from sklearn.pipeline import make_pipeline
+ from sklearn.linear_model import LogisticRegression
+ X, y = make_data()
+ n_epochs, _, n_time = X.shape
+ # fit
+ gl = _GeneralizationLight(LogisticRegression())
+ assert_equal(gl.__repr__()[:22], '<_GeneralizationLight(')
+ gl.fit(X, y)
+
+ assert_equal(gl.__repr__()[-28:], ', fitted with 10 estimators>')
+ # transforms
+ y_pred = gl.predict(X)
+ assert_array_equal(y_pred.shape, [n_epochs, n_time, n_time])
+ assert_true(y_pred.dtype == int)
+ y_proba = gl.predict_proba(X)
+ assert_true(y_proba.dtype == float)
+ assert_array_equal(y_proba.shape, [n_epochs, n_time, n_time, 2])
+
+ # transform to different datasize
+ y_pred = gl.predict(X[:, :, :2])
+ assert_array_equal(y_pred.shape, [n_epochs, n_time, 2])
+
+ # score
+ score = gl.score(X[:, :, :3], y)
+ assert_array_equal(score.shape, [n_time, 3])
+ assert_true(np.sum(np.abs(score)) != 0)
+ assert_true(score.dtype == float)
+
+ # n_jobs
+ gl = _GeneralizationLight(LogisticRegression(), n_jobs=2)
+ gl.fit(X, y)
+ y_pred = gl.predict(X)
+ assert_array_equal(y_pred.shape, [n_epochs, n_time, n_time])
+ score = gl.score(X, y)
+ assert_array_equal(score.shape, [n_time, n_time])
+
+ # n_jobs > n_estimators
+ gl.fit(X[..., [0]], y)
+ gl.predict(X[..., [0]])
+
+ # n-dimensional feature space
+ X = np.random.rand(10, 3, 4, 2)
+ y = np.arange(10) % 2
+ y_preds = list()
+ for n_jobs in [1, 2]:
+ pipe = _GeneralizationLight(
+ make_pipeline(Vectorizer(), LogisticRegression()), n_jobs=n_jobs)
+ y_preds.append(pipe.fit(X, y).predict(X))
+ features_shape = pipe.estimators_[0].steps[0][1].features_shape_
+ assert_array_equal(features_shape, [3, 4])
+ assert_array_equal(y_preds[0], y_preds[1])
diff --git a/mne/decoding/tests/test_time_frequency.py b/mne/decoding/tests/test_time_frequency.py
new file mode 100644
index 0000000..1c19b53
--- /dev/null
+++ b/mne/decoding/tests/test_time_frequency.py
@@ -0,0 +1,41 @@
+# Author: Jean-Remi King, <jeanremi.king at gmail.com>
+#
+# License: BSD (3-clause)
+
+
+import numpy as np
+from numpy.testing import assert_array_equal
+from nose.tools import assert_raises
+from mne.utils import requires_sklearn
+from mne.decoding.time_frequency import TimeFrequency
+
+
+ at requires_sklearn
+def test_timefrequency():
+ from sklearn.base import clone
+ # Init
+ n_freqs = 3
+ frequencies = np.linspace(20, 30, n_freqs)
+ tf = TimeFrequency(frequencies, sfreq=100)
+ for output in ['avg_power', 'foo', None]:
+ assert_raises(ValueError, TimeFrequency, frequencies, output=output)
+ tf = clone(tf)
+
+ # Fit
+ n_epochs, n_chans, n_times = 10, 2, 100
+ X = np.random.rand(n_epochs, n_chans, n_times)
+ tf.fit(X, None)
+
+ # Transform
+ tf = TimeFrequency(frequencies, sfreq=100)
+ tf.fit_transform(X, None)
+ # 3-D X
+ Xt = tf.transform(X)
+ assert_array_equal(Xt.shape, [n_epochs, n_chans, n_freqs, n_times])
+ # 2-D X
+ Xt = tf.transform(X[:, 0, :])
+ assert_array_equal(Xt.shape, [n_epochs, n_freqs, n_times])
+ # 3-D with decim
+ tf = TimeFrequency(frequencies, sfreq=100, decim=2)
+ Xt = tf.transform(X)
+ assert_array_equal(Xt.shape, [n_epochs, n_chans, n_freqs, n_times // 2])
diff --git a/mne/decoding/tests/test_time_gen.py b/mne/decoding/tests/test_time_gen.py
index fff4059..0a5bd13 100644
--- a/mne/decoding/tests/test_time_gen.py
+++ b/mne/decoding/tests/test_time_gen.py
@@ -12,7 +12,7 @@ from numpy.testing import assert_array_equal
from mne import io, Epochs, read_events, pick_types
from mne.utils import (requires_sklearn, requires_sklearn_0_15, slow_test,
- run_tests_if_main, check_version)
+ run_tests_if_main, check_version, use_log_level)
from mne.decoding import GeneralizationAcrossTime, TimeDecoding
@@ -28,7 +28,7 @@ warnings.simplefilter('always')
def make_epochs():
- raw = io.read_raw_fif(raw_fname, preload=False)
+ raw = io.read_raw_fif(raw_fname, preload=False, add_eeg_ref=False)
events = read_events(event_name)
picks = pick_types(raw.info, meg='mag', stim=False, ecg=False,
eog=False, exclude='bads')
@@ -38,7 +38,8 @@ def make_epochs():
# Test on time generalization within one condition
with warnings.catch_warnings(record=True):
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- baseline=(None, 0), preload=True, decim=decim)
+ baseline=(None, 0), preload=True, decim=decim,
+ add_eeg_ref=False)
return epochs
@@ -152,7 +153,8 @@ def test_generalization_across_time():
gat.predict_mode = 'mean-prediction'
epochs2.events[:, 2] += 10
gat_ = copy.deepcopy(gat)
- assert_raises(ValueError, gat_.score, epochs2)
+ with use_log_level('error'):
+ assert_raises(ValueError, gat_.score, epochs2)
gat.predict_mode = 'cross-validation'
# Test basics
@@ -234,8 +236,6 @@ def test_generalization_across_time():
assert_equal(np.shape(gat.scores_), (15, 1))
assert_array_equal([tim for ttime in gat.test_times_['times']
for tim in ttime], gat.train_times_['times'])
- from mne.utils import set_log_level
- set_log_level('error')
# Test generalization across conditions
gat = GeneralizationAcrossTime(predict_mode='mean-prediction', cv=2)
with warnings.catch_warnings(record=True):
@@ -249,7 +249,8 @@ def test_generalization_across_time():
gat_ = copy.deepcopy(gat)
# --- start stop outside time range
gat_.train_times = dict(start=-999.)
- assert_raises(ValueError, gat_.fit, epochs)
+ with use_log_level('error'):
+ assert_raises(ValueError, gat_.fit, epochs)
gat_.train_times = dict(start=999.)
assert_raises(ValueError, gat_.fit, epochs)
# --- impossible slices
@@ -316,8 +317,9 @@ def test_generalization_across_time():
# sklearn needs it: c.f.
# https://github.com/scikit-learn/scikit-learn/issues/2723
# and http://bit.ly/1u7t8UT
- assert_raises(ValueError, gat.score, epochs2)
- gat.score(epochs)
+ with use_log_level('error'):
+ assert_raises(ValueError, gat.score, epochs2)
+ gat.score(epochs)
assert_true(0.0 <= np.min(scores) <= 1.0)
assert_true(0.0 <= np.max(scores) <= 1.0)
@@ -409,7 +411,10 @@ def test_decoding_time():
"""Test TimeDecoding
"""
from sklearn.svm import SVR
- from sklearn.cross_validation import KFold
+ if check_version('sklearn', '0.18'):
+ from sklearn.model_selection import KFold
+ else:
+ from sklearn.cross_validation import KFold
epochs = make_epochs()
tg = TimeDecoding()
assert_equal("<TimeDecoding | no fit, no prediction, no score>", '%s' % tg)
diff --git a/mne/decoding/tests/test_transformer.py b/mne/decoding/tests/test_transformer.py
index 791810e..443d7cf 100644
--- a/mne/decoding/tests/test_transformer.py
+++ b/mne/decoding/tests/test_transformer.py
@@ -8,11 +8,14 @@ import os.path as op
import numpy as np
from nose.tools import assert_true, assert_raises
-from numpy.testing import assert_array_equal
+from numpy.testing import (assert_array_equal, assert_equal,
+ assert_array_almost_equal)
from mne import io, read_events, Epochs, pick_types
from mne.decoding import Scaler, FilterEstimator
-from mne.decoding import PSDEstimator, EpochsVectorizer
+from mne.decoding import (PSDEstimator, EpochsVectorizer, Vectorizer,
+ UnsupervisedSpatialFilter, TemporalFilter)
+from mne.utils import requires_sklearn_0_15
warnings.simplefilter('always') # enable b/c these tests throw warnings
@@ -26,16 +29,15 @@ event_name = op.join(data_dir, 'test-eve.fif')
def test_scaler():
- """Test methods of Scaler
- """
- raw = io.read_raw_fif(raw_fname, preload=False)
+ """Test methods of Scaler."""
+ raw = io.read_raw_fif(raw_fname, preload=False, add_eeg_ref=False)
events = read_events(event_name)
picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
eog=False, exclude='bads')
picks = picks[1:13:3]
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- baseline=(None, 0), preload=True)
+ baseline=(None, 0), preload=True, add_eeg_ref=False)
epochs_data = epochs.get_data()
scaler = Scaler(epochs.info)
y = epochs.events[:, -1]
@@ -59,26 +61,29 @@ def test_scaler():
def test_filterestimator():
- """Test methods of FilterEstimator
- """
- raw = io.read_raw_fif(raw_fname, preload=False)
+ """Test methods of FilterEstimator."""
+ raw = io.read_raw_fif(raw_fname, preload=False, add_eeg_ref=False)
events = read_events(event_name)
picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
eog=False, exclude='bads')
picks = picks[1:13:3]
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- baseline=(None, 0), preload=True)
+ baseline=(None, 0), preload=True, add_eeg_ref=False)
epochs_data = epochs.get_data()
# Add tests for different combinations of l_freq and h_freq
- filt = FilterEstimator(epochs.info, l_freq=1, h_freq=40)
+ filt = FilterEstimator(epochs.info, l_freq=40, h_freq=80,
+ filter_length='auto',
+ l_trans_bandwidth='auto', h_trans_bandwidth='auto')
y = epochs.events[:, -1]
with warnings.catch_warnings(record=True): # stop freq attenuation warning
X = filt.fit_transform(epochs_data, y)
assert_true(X.shape == epochs_data.shape)
assert_array_equal(filt.fit(epochs_data, y).transform(epochs_data), X)
- filt = FilterEstimator(epochs.info, l_freq=0, h_freq=40)
+ filt = FilterEstimator(epochs.info, l_freq=None, h_freq=40,
+ filter_length='auto',
+ l_trans_bandwidth='auto', h_trans_bandwidth='auto')
y = epochs.events[:, -1]
with warnings.catch_warnings(record=True): # stop freq attenuation warning
X = filt.fit_transform(epochs_data, y)
@@ -88,7 +93,9 @@ def test_filterestimator():
with warnings.catch_warnings(record=True): # stop freq attenuation warning
assert_raises(ValueError, filt.fit_transform, epochs_data, y)
- filt = FilterEstimator(epochs.info, l_freq=1, h_freq=None)
+ filt = FilterEstimator(epochs.info, l_freq=40, h_freq=None,
+ filter_length='auto',
+ l_trans_bandwidth='auto', h_trans_bandwidth='auto')
with warnings.catch_warnings(record=True): # stop freq attenuation warning
X = filt.fit_transform(epochs_data, y)
@@ -98,15 +105,14 @@ def test_filterestimator():
def test_psdestimator():
- """Test methods of PSDEstimator
- """
- raw = io.read_raw_fif(raw_fname, preload=False)
+ """Test methods of PSDEstimator."""
+ raw = io.read_raw_fif(raw_fname, preload=False, add_eeg_ref=False)
events = read_events(event_name)
picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
eog=False, exclude='bads')
picks = picks[1:13:3]
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- baseline=(None, 0), preload=True)
+ baseline=(None, 0), preload=True, add_eeg_ref=False)
epochs_data = epochs.get_data()
psd = PSDEstimator(2 * np.pi, 0, np.inf)
y = epochs.events[:, -1]
@@ -121,18 +127,18 @@ def test_psdestimator():
def test_epochs_vectorizer():
- """Test methods of EpochsVectorizer
- """
- raw = io.read_raw_fif(raw_fname, preload=False)
+ """Test methods of EpochsVectorizer."""
+ raw = io.read_raw_fif(raw_fname, preload=False, add_eeg_ref=False)
events = read_events(event_name)
picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
eog=False, exclude='bads')
picks = picks[1:13:3]
with warnings.catch_warnings(record=True):
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- baseline=(None, 0), preload=True)
+ baseline=(None, 0), preload=True, add_eeg_ref=False)
epochs_data = epochs.get_data()
- vector = EpochsVectorizer(epochs.info)
+ with warnings.catch_warnings(record=True): # deprecation
+ vector = EpochsVectorizer(epochs.info)
y = epochs.events[:, -1]
X = vector.fit_transform(epochs_data, y)
@@ -160,3 +166,94 @@ def test_epochs_vectorizer():
# Test init exception
assert_raises(ValueError, vector.fit, epochs, y)
assert_raises(ValueError, vector.transform, epochs, y)
+
+
+def test_vectorizer():
+ """Test Vectorizer."""
+ data = np.random.rand(150, 18, 6)
+ vect = Vectorizer()
+ result = vect.fit_transform(data)
+ assert_equal(result.ndim, 2)
+
+ # check inverse_transform
+ orig_data = vect.inverse_transform(result)
+ assert_equal(orig_data.ndim, 3)
+ assert_array_equal(orig_data, data)
+ assert_array_equal(vect.inverse_transform(result[1:]), data[1:])
+
+ # check with different shape
+ assert_equal(vect.fit_transform(np.random.rand(150, 18, 6, 3)).shape,
+ (150, 324))
+ assert_equal(vect.fit_transform(data[1:]).shape, (149, 108))
+
+ # check if raised errors are working correctly
+ vect.fit(np.random.rand(105, 12, 3))
+ assert_raises(ValueError, vect.transform, np.random.rand(105, 12, 3, 1))
+ assert_raises(ValueError, vect.inverse_transform,
+ np.random.rand(102, 12, 12))
+
+
+ at requires_sklearn_0_15
+def test_unsupervised_spatial_filter():
+ """Test unsupervised spatial filter."""
+ from sklearn.decomposition import PCA
+ from sklearn.kernel_ridge import KernelRidge
+ raw = io.read_raw_fif(raw_fname, preload=False, add_eeg_ref=False)
+ events = read_events(event_name)
+ picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
+ eog=False, exclude='bads')
+ picks = picks[1:13:3]
+ epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
+ preload=True, baseline=None, verbose=False,
+ add_eeg_ref=False)
+
+ # Test estimator
+ assert_raises(ValueError, UnsupervisedSpatialFilter, KernelRidge(2))
+
+ # Test fit
+ X = epochs.get_data()
+ n_components = 4
+ usf = UnsupervisedSpatialFilter(PCA(n_components))
+ usf.fit(X)
+ usf1 = UnsupervisedSpatialFilter(PCA(n_components))
+
+ # test transform
+ assert_equal(usf.transform(X).ndim, 3)
+ # test fit_transform
+ assert_array_almost_equal(usf.transform(X), usf1.fit_transform(X))
+ # assert shape
+ assert_equal(usf.transform(X).shape[1], n_components)
+
+ # Test with average param
+ usf = UnsupervisedSpatialFilter(PCA(4), average=True)
+ usf.fit_transform(X)
+ assert_raises(ValueError, UnsupervisedSpatialFilter, PCA(4), 2)
+
+
+def test_temporal_filter():
+ """Test methods of TemporalFilter."""
+ X = np.random.rand(5, 5, 1200)
+
+ # Test init
+ values = (('10hz', None, 100., 'auto'), (5., '10hz', 100., 'auto'),
+ (10., 20., 5., 'auto'), (None, None, 100., '5hz'))
+ for low, high, sf, ltrans in values:
+ filt = TemporalFilter(low, high, sf, ltrans)
+ assert_raises(ValueError, filt.fit_transform, X)
+
+ # Add tests for different combinations of l_freq and h_freq
+ for low, high in ((5., 15.), (None, 15.), (5., None)):
+ filt = TemporalFilter(low, high, sfreq=100.)
+ Xt = filt.fit_transform(X)
+ assert_array_equal(filt.fit_transform(X), Xt)
+ assert_true(X.shape == Xt.shape)
+
+ # Test fit and transform numpy type check
+ with warnings.catch_warnings(record=True):
+ assert_raises(TypeError, filt.transform, [1, 2])
+
+ # Test with 2 dimensional data array
+ X = np.random.rand(101, 500)
+ filt = TemporalFilter(l_freq=25., h_freq=50., sfreq=1000.,
+ filter_length=150)
+ assert_equal(filt.fit_transform(X).shape, X.shape)
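
As a quick sketch of the two new transformers exercised above (not part of
the patch):

    import numpy as np
    from sklearn.decomposition import PCA
    from mne.decoding import UnsupervisedSpatialFilter, TemporalFilter

    X = np.random.rand(20, 8, 500)   # (n_epochs, n_channels, n_times)

    usf = UnsupervisedSpatialFilter(PCA(4))
    Xp = usf.fit_transform(X)        # (20, 4, 500): PCA across channels

    tfilt = TemporalFilter(l_freq=5., h_freq=15., sfreq=100.)
    Xf = tfilt.fit_transform(X)      # band-passed copy, same shape as X
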
diff --git a/mne/decoding/time_frequency.py b/mne/decoding/time_frequency.py
new file mode 100644
index 0000000..c913130
--- /dev/null
+++ b/mne/decoding/time_frequency.py
@@ -0,0 +1,152 @@
+# Author: Jean-Remi King <jeanremi.king at gmail.com>
+#
+# License: BSD (3-clause)
+
+import numpy as np
+from .mixin import TransformerMixin
+from .base import BaseEstimator
+from ..time_frequency.tfr import _compute_tfr, _check_tfr_param
+
+
+class TimeFrequency(TransformerMixin, BaseEstimator):
+ """Time frequency transformer.
+
+ Time-frequency transform of time series along the last axis.
+
+ Parameters
+ ----------
+ frequencies : array-like of floats, shape (n_freqs,)
+ The frequencies.
+ sfreq : float | int, defaults to 1.0
+ Sampling frequency of the data.
+ method : 'multitaper' | 'morlet', defaults to 'morlet'
+ The time-frequency method. 'morlet' convolves a Morlet wavelet.
+ 'multitaper' uses Morlet wavelets windowed with multiple DPSS
+ multitapers.
+ n_cycles : float | array of float, defaults to 7.0
+ Number of cycles in the Morlet wavelet. Fixed number
+ or one per frequency.
+ time_bandwidth : float, defaults to None
+ Time x (full) bandwidth product. Only applies if
+ method == 'multitaper'. If None, it is set to 4.0 (3 tapers).
+ The number of good (low-bias) tapers is chosen automatically
+ based on this to equal floor(time_bandwidth - 1).
+ use_fft : bool, defaults to True
+ Use the FFT for convolutions or not.
+ decim : int | slice, defaults to 1
+ To reduce memory usage, decimation factor after time-frequency
+ decomposition.
+ If `int`, returns tfr[..., ::decim].
+ If `slice`, returns tfr[..., decim].
+ .. note::
+ Decimation may create aliasing artifacts; the decimation is
+ nevertheless applied after the convolutions.
+ output : str, defaults to 'complex'
+
+ * 'complex' : single trial complex.
+ * 'power' : single trial power.
+ * 'phase' : single trial phase.
+
+ n_jobs : int, defaults to 1
+ The number of epochs to process at the same time. The parallelization
+ is implemented across channels.
+ verbose : bool, str, int, or None, defaults to None
+ If not None, override default verbose level (see mne.verbose).
+
+ See Also
+ --------
+ mne.time_frequency.tfr_morlet
+ mne.time_frequency.tfr_multitaper
+ """
+
+ def __init__(self, frequencies, sfreq=1.0, method='morlet', n_cycles=7.0,
+ time_bandwidth=None, use_fft=True, decim=1, output='complex',
+ n_jobs=1, verbose=None):
+ """Init TimeFrequency transformer."""
+ frequencies, sfreq, _, n_cycles, time_bandwidth, decim = \
+ _check_tfr_param(frequencies, sfreq, method, True, n_cycles,
+ time_bandwidth, use_fft, decim, output)
+ self.frequencies = frequencies
+ self.sfreq = sfreq
+ self.method = method
+ self.n_cycles = n_cycles
+ self.time_bandwidth = time_bandwidth
+ self.use_fft = use_fft
+ self.decim = decim
+ # Check that output is not an average metric (e.g. ITC)
+ if output not in ['complex', 'power', 'phase']:
+ raise ValueError("output must be 'complex', 'power', 'phase'. "
+ "Got %s instead." % output)
+ self.output = output
+ self.n_jobs = n_jobs
+ self.verbose = verbose
+
+ def fit_transform(self, X, y=None):
+ """
+ Time-frequency transform of time series along the last axis.
+
+ Parameters
+ ----------
+ X : array, shape (n_samples, n_channels, n_times)
+ The training data samples. The channel dimension can be zero- or
+ 1-dimensional.
+ y : None
+ For scikit-learn compatibility purposes.
+
+ Returns
+ -------
+ Xt : array, shape (n_samples, n_channels, n_frequencies, n_times)
+ The time-frequency transform of the data, where n_channels can be
+ zero- or 1-dimensional.
+ """
+ return self.fit(X, y).transform(X)
+
+ def fit(self, X, y=None):
+ """ Does nothing, for scikit-learn compatibility purposes.
+
+ Parameters
+ ----------
+ X : array, shape (n_samples, n_channels, n_times)
+ The training data.
+ y : array | None
+ The target values.
+
+ Returns
+ -------
+ self : object
+ Return self.
+ """
+ return self
+
+ def transform(self, X):
+ """Time-frequency transform of times series along the last axis.
+
+ Parameters
+ ----------
+ X : array, shape (n_samples, n_channels, n_times)
+ The training data samples. The channel dimension can be zero- or
+ 1-dimensional.
+
+ Returns
+ -------
+ Xt : array, shape (n_samples, n_channels, n_frequencies, n_times)
+ The time-frequency transform of the data, where n_channels can be
+ zero- or 1-dimensional.
+
+ """
+ # Ensure 3-dimensional X
+ shape = X.shape[1:-1]
+ if not shape:
+ X = X[:, np.newaxis, :]
+
+ # Compute time-frequency
+ Xt = _compute_tfr(X, self.frequencies, self.sfreq, self.method,
+ self.n_cycles, True, self.time_bandwidth,
+ self.use_fft, self.decim, self.output, self.n_jobs,
+ self.verbose)
+
+ # Back to original shape
+ if not shape:
+ Xt = Xt[:, 0, :]
+
+ return Xt
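
A quick sketch of the transformer defined above (not part of the patch),
matching the shapes asserted in its tests:

    import numpy as np
    from mne.decoding.time_frequency import TimeFrequency

    X = np.random.rand(10, 2, 100)   # (n_epochs, n_channels, n_times)
    tf = TimeFrequency(np.linspace(20, 30, 3), sfreq=100., output='power',
                       decim=2)
    Xt = tf.fit_transform(X)         # shape (10, 2, 3, 50)
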
diff --git a/mne/decoding/time_gen.py b/mne/decoding/time_gen.py
index af073df..ce9c8ad 100644
--- a/mne/decoding/time_gen.py
+++ b/mne/decoding/time_gen.py
@@ -8,6 +8,7 @@
import numpy as np
import copy
+from .base import _set_cv
from ..io.pick import _pick_data_channels
from ..viz.decoding import plot_gat_matrix, plot_gat_times
from ..parallel import parallel_func, check_n_jobs
@@ -141,7 +142,7 @@ class _GeneralizationAcrossTime(object):
self.ch_names = [epochs.ch_names[p] for p in self.picks_]
# Prepare cross-validation
- self.cv_, self._cv_splits = _set_cv(self.cv, clf=self.clf, X=X, y=y)
+ self.cv_, self._cv_splits = _set_cv(self.cv, self.clf, X=X, y=y)
self.y_train_ = y
@@ -870,7 +871,7 @@ class GeneralizationAcrossTime(_GeneralizationAcrossTime):
Creates an estimator object used to 1) fit a series of classifiers on
multidimensional time-resolved data, and 2) test the ability of each
- classifier to generalize across other time samples.
+ classifier to generalize across other time samples, as in [1]_.
Parameters
----------
@@ -880,7 +881,7 @@ class GeneralizationAcrossTime(_GeneralizationAcrossTime):
cv : int | object
If an integer is passed, it is the number of folds.
Specific cross-validation objects can be passed, see
- scikit-learn.model_selection module for the list of possible objects.
+ scikit-learn.cross_validation module for the list of possible objects.
If clf is a classifier, defaults to StratifiedKFold(n_folds=5), else
defaults to KFold(n_folds=5).
clf : object | None
@@ -959,13 +960,13 @@ class GeneralizationAcrossTime(_GeneralizationAcrossTime):
Attributes
----------
- picks_ : array-like of int | None
+ ``picks_`` : array-like of int | None
The channels indices to include.
ch_names : list, array-like, shape (n_channels,)
Names of the channels used for training.
- y_train_ : list | ndarray, shape (n_samples,)
+ ``y_train_`` : list | ndarray, shape (n_samples,)
The categories used for training.
- train_times_ : dict
+ ``train_times_`` : dict
A dictionary that configures the training times:
* ``slices`` : ndarray, shape (n_clfs,)
@@ -975,7 +976,7 @@ class GeneralizationAcrossTime(_GeneralizationAcrossTime):
* ``times`` : ndarray, shape (n_clfs,)
The training times (in seconds).
- test_times_ : dict
+ ``test_times_`` : dict
A dictionary that configures the testing times for each training time:
``slices`` : ndarray, shape (n_clfs, n_testing_times)
@@ -983,20 +984,20 @@ class GeneralizationAcrossTime(_GeneralizationAcrossTime):
``times`` : ndarray, shape (n_clfs, n_testing_times)
The testing times (in seconds) for each training time.
- cv_ : CrossValidation object
+ ``cv_`` : CrossValidation object
The actual CrossValidation input depending on y.
- estimators_ : list of list of scikit-learn.base.BaseEstimator subclasses.
+ ``estimators_`` : list of list of scikit-learn.base.BaseEstimator subclasses.
The estimators for each time point and each fold.
- y_pred_ : list of lists of arrays of floats, shape (n_train_times, n_test_times, n_epochs, n_prediction_dims)
+ ``y_pred_`` : list of lists of arrays of floats, shape (n_train_times, n_test_times, n_epochs, n_prediction_dims)
The single-trial predictions estimated by self.predict() at each
training time and each testing time. Note that the number of testing
times per training time need not be regular, else
``np.shape(y_pred_) = (n_train_time, n_test_time, n_epochs).``
- y_true_ : list | ndarray, shape (n_samples,)
+ ``y_true_`` : list | ndarray, shape (n_samples,)
The categories used for scoring ``y_pred_``.
- scorer_ : object
+ ``scorer_`` : object
scikit-learn Scorer instance.
- scores_ : list of lists of float
+ ``scores_`` : list of lists of float
The scores estimated by ``self.scorer_`` at each training time and each
testing time (e.g. mean accuracy of self.predict(X)). Note that the
number of testing times per training time need not be regular;
@@ -1006,14 +1007,12 @@ class GeneralizationAcrossTime(_GeneralizationAcrossTime):
--------
TimeDecoding
- Notes
- -----
- The function implements the method used in:
-
- Jean-Remi King, Alexandre Gramfort, Aaron Schurger, Lionel Naccache
- and Stanislas Dehaene, "Two distinct dynamic modes subtend the
- detection of unexpected sounds", PLoS ONE, 2014
- DOI: 10.1371/journal.pone.0085791
+ References
+ ----------
+ .. [1] Jean-Remi King, Alexandre Gramfort, Aaron Schurger, Lionel Naccache
+ and Stanislas Dehaene, "Two distinct dynamic modes subtend the
+ detection of unexpected sounds", PLoS ONE, 2014
+ DOI: 10.1371/journal.pone.0085791
.. versionadded:: 0.9.0
""" # noqa
@@ -1209,7 +1208,7 @@ class TimeDecoding(_GeneralizationAcrossTime):
cv : int | object
If an integer is passed, it is the number of folds.
Specific cross-validation objects can be passed, see
- scikit-learn.model_selection module for the list of possible objects.
+ scikit-learn.cross_validation module for the list of possible objects.
If clf is a classifier, defaults to StratifiedKFold(n_folds=5), else
defaults to KFold(n_folds=5).
clf : object | None
@@ -1279,13 +1278,13 @@ class TimeDecoding(_GeneralizationAcrossTime):
Attributes
----------
- picks_ : array-like of int | None
+ ``picks_`` : array-like of int | None
The channels indices to include.
ch_names : list, array-like, shape (n_channels,)
Names of the channels used for training.
- y_train_ : ndarray, shape (n_samples,)
+ ``y_train_`` : ndarray, shape (n_samples,)
The categories used for training.
- times_ : dict
+ ``times_`` : dict
A dictionary that configures the training times:
* ``slices`` : ndarray, shape (n_clfs,)
@@ -1295,17 +1294,17 @@ class TimeDecoding(_GeneralizationAcrossTime):
* ``times`` : ndarray, shape (n_clfs,)
The training times (in seconds).
- cv_ : CrossValidation object
+ ``cv_`` : CrossValidation object
The actual CrossValidation input depending on y.
- estimators_ : list of list of scikit-learn.base.BaseEstimator subclasses.
+ ``estimators_`` : list of list of scikit-learn.base.BaseEstimator subclasses.
The estimators for each time point and each fold.
- y_pred_ : ndarray, shape (n_times, n_epochs, n_prediction_dims)
+ ``y_pred_`` : ndarray, shape (n_times, n_epochs, n_prediction_dims)
Class labels for samples in X.
- y_true_ : list | ndarray, shape (n_samples,)
+ ``y_true_`` : list | ndarray, shape (n_samples,)
The categories used for scoring ``y_pred_``.
- scorer_ : object
+ ``scorer_`` : object
scikit-learn Scorer instance.
- scores_ : list of float, shape (n_times,)
+ ``scores_`` : list of float, shape (n_times,)
The scores (mean accuracy of self.predict(X) wrt. y.).
See Also
@@ -1317,7 +1316,7 @@ class TimeDecoding(_GeneralizationAcrossTime):
The function is equivalent to the diagonal of GeneralizationAcrossTime()
.. versionadded:: 0.10
- """
+ """ # noqa
def __init__(self, picks=None, cv=5, clf=None, times=None,
predict_method='predict', predict_mode='cross-validation',
@@ -1560,37 +1559,3 @@ def _chunk_data(X, slices):
slices_chunk = [sl - start for sl in slices]
X_chunk = X[:, :, start:stop]
return X_chunk, slices_chunk
-
-
-def _set_cv(cv, clf=None, X=None, y=None):
- from sklearn.base import is_classifier
-
- # Set the default cross-validation depending on whether clf is classifier
- # or regressor.
- if check_version('sklearn', '0.18'):
- from sklearn.model_selection import (check_cv, StratifiedKFold, KFold)
- if isinstance(cv, (int, np.int)):
- XFold = StratifiedKFold if is_classifier(clf) else KFold
- cv = XFold(n_folds=cv)
- cv = check_cv(cv=cv, y=y, classifier=is_classifier(clf))
- else:
- from sklearn.cross_validation import (check_cv, StratifiedKFold, KFold)
- if isinstance(cv, (int, np.int)):
- if is_classifier(clf):
- cv = StratifiedKFold(y=y, n_folds=cv)
- else:
- cv = KFold(n=len(y), n_folds=cv)
- cv = check_cv(cv=cv, X=X, y=y, classifier=is_classifier(clf))
-
- # Extract train and test set to retrieve them at predict time
- if hasattr(cv, 'split'):
- cv_splits = [(train, test) for train, test in
- cv.split(X=np.zeros_like(y), y=y)]
- else:
- # XXX support sklearn.cross_validation cv
- cv_splits = [(train, test) for train, test in cv]
-
- if not np.all([len(train) for train, _ in cv_splits]):
- raise ValueError('Some folds do not have any train epochs.')
-
- return cv, cv_splits
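
The removed _set_cv helper (moved to mne/decoding/base.py by this commit)
bridges the scikit-learn 0.18 API change; a minimal sketch of its new-API
branch, with hypothetical toy labels:

    import numpy as np
    from sklearn.model_selection import check_cv, StratifiedKFold

    y = np.array([0, 1] * 10)
    cv = check_cv(StratifiedKFold(5), y=y, classifier=True)
    cv_splits = [(train, test) for train, test in
                 cv.split(np.zeros_like(y), y)]  # materialized for predict time
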
diff --git a/mne/decoding/transformer.py b/mne/decoding/transformer.py
index 6f1d772..220a12a 100644
--- a/mne/decoding/transformer.py
+++ b/mne/decoding/transformer.py
@@ -7,13 +7,14 @@
import numpy as np
from .mixin import TransformerMixin
+from .base import BaseEstimator
from .. import pick_types
from ..filter import (low_pass_filter, high_pass_filter, band_pass_filter,
- band_stop_filter)
+ band_stop_filter, filter_data, _triage_filter_params)
from ..time_frequency.psd import _psd_multitaper
from ..externals import six
-from ..utils import _check_type_picks
+from ..utils import _check_type_picks, deprecated
class Scaler(TransformerMixin):
@@ -33,9 +34,9 @@ class Scaler(TransformerMixin):
----------
info : instance of Info
The measurement info
- ch_mean_ : dict
+ ``ch_mean_`` : dict
The mean value for each channel type
- std_ : dict
+ ``std_`` : dict
The standard deviation for each channel type
"""
def __init__(self, info, with_mean=True, with_std=True):
@@ -147,6 +148,8 @@ class Scaler(TransformerMixin):
return X
+ at deprecated("EpochsVectorizer will be deprecated in version 0.14; "
+ "use Vectorizer instead")
class EpochsVectorizer(TransformerMixin):
"""EpochsVectorizer transforms epoch data to fit into a scikit-learn pipeline.
@@ -245,6 +248,108 @@ class EpochsVectorizer(TransformerMixin):
return X.reshape(-1, self.n_channels, self.n_times)
+class Vectorizer(TransformerMixin):
+ """Transforms n-dimensional array into 2D array of n_samples by n_features.
+
+ This class reshapes an n-dimensional array into an n_samples * n_features
+ array, usable by the estimators and transformers of scikit-learn.
+
+ Examples
+ --------
+ clf = make_pipeline(SpatialFilter(), _XdawnTransformer(), Vectorizer(),
+ LogisticRegression())
+
+ Attributes
+ ----------
+ ``features_shape_`` : tuple
+ Stores the original shape of data.
+ """
+
+ def fit(self, X, y=None):
+ """Stores the shape of the features of X.
+
+ Parameters
+ ----------
+ X : array-like
+ The data to fit. Can be, for example, a list or an array of at
+ least 2 dimensions. The first dimension must be of length n_samples, where
+ samples are the independent samples used by the estimator
+ (e.g. n_epochs for epoched data).
+ y : None | array, shape (n_samples,)
+ Used for scikit-learn compatibility.
+
+ Returns
+ -------
+ self : Instance of Vectorizer
+ Return the modified instance.
+ """
+ X = np.asarray(X)
+ self.features_shape_ = X.shape[1:]
+ return self
+
+ def transform(self, X):
+ """Convert given array into two dimensions.
+
+ Parameters
+ ----------
+ X : array-like
+ The data to fit. Can be, for example, a list or an array of at
+ least 2 dimensions. The first dimension must be of length n_samples, where
+ samples are the independent samples used by the estimator
+ (e.g. n_epochs for epoched data).
+
+ Returns
+ -------
+ X : array, shape (n_samples, n_features)
+ The transformed data.
+ """
+ X = np.asarray(X)
+ if X.shape[1:] != self.features_shape_:
+ raise ValueError("Shape of X used in fit and transform must be "
+ "same")
+ return X.reshape(len(X), -1)
+
+ def fit_transform(self, X, y=None):
+ """Fit the data, then transform in one step.
+
+ Parameters
+ ----------
+ X : array-like
+ The data to fit. Can be, for example, a list or an array of at
+ least 2 dimensions. The first dimension must be of length n_samples, where
+ samples are the independent samples used by the estimator
+ (e.g. n_epochs for epoched data).
+ y : None | array, shape (n_samples,)
+ Used for scikit-learn compatibility.
+
+ Returns
+ -------
+ X : array, shape (n_samples, -1)
+ The transformed data.
+ """
+ return self.fit(X).transform(X)
+
+ def inverse_transform(self, X):
+ """Transform 2D data back to its original feature shape.
+
+ Parameters
+ ----------
+ X : array-like, shape (n_samples, n_features)
+ Data to be transformed back to original shape.
+
+ Returns
+ -------
+ X : array
+ The data transformed into shape as used in fit. The first
+ dimension is of length n_samples.
+ """
+ X = np.asarray(X)
+ if X.ndim != 2:
+ raise ValueError("X should be of 2 dimensions but given has %s "
+ "dimension(s)" % X.ndim)
+ return X.reshape((len(X),) + self.features_shape_)
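
A minimal round-trip sketch for the Vectorizer added above:

    import numpy as np
    from mne.decoding import Vectorizer  # added by this commit

    X = np.random.randn(20, 5, 100)    # (n_epochs, n_channels, n_times)
    vec = Vectorizer()
    X2d = vec.fit_transform(X)         # (20, 500), scikit-learn friendly
    X3d = vec.inverse_transform(X2d)   # back to (20, 5, 100)
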
+
+
class PSDEstimator(TransformerMixin):
"""Compute power spectrum density (PSD) using a multi-taper method
@@ -392,9 +497,13 @@ class FilterEstimator(TransformerMixin):
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
Defaults to self.verbose.
+
+ See Also
+ --------
+ TemporalFilter
"""
- def __init__(self, info, l_freq, h_freq, picks=None, filter_length='10s',
- l_trans_bandwidth=0.5, h_trans_bandwidth=0.5, n_jobs=1,
+ def __init__(self, info, l_freq, h_freq, picks=None, filter_length='',
+ l_trans_bandwidth=None, h_trans_bandwidth=None, n_jobs=1,
method='fft', iir_params=None, verbose=None):
self.info = info
self.l_freq = l_freq
@@ -520,3 +629,245 @@ class FilterEstimator(TransformerMixin):
picks=self.picks, n_jobs=self.n_jobs,
copy=False, verbose=False)
return epochs_data
+
+
+class UnsupervisedSpatialFilter(TransformerMixin, BaseEstimator):
+ """Fit and transform with an unsupervised spatial filtering across time
+ and samples.
+
+ Parameters
+ ----------
+ estimator : scikit-learn estimator
+ Estimator using some decomposition algorithm.
+ average : bool, defaults to False
+ If True, the estimator is fitted on the average across samples
+ (e.g. epochs).
+ """
+ def __init__(self, estimator, average=False):
+ # XXX: Use _check_estimator #3381
+ for attr in ('fit', 'transform', 'fit_transform'):
+ if not hasattr(estimator, attr):
+ raise ValueError('estimator must be a scikit-learn '
+ 'transformer, missing %s method' % attr)
+
+ if not isinstance(average, bool):
+ raise ValueError("average parameter must be of bool type, got "
+ "%s instead" % type(bool))
+
+ self.estimator = estimator
+ self.average = average
+
+ def fit(self, X, y=None):
+ """Fit the spatial filters.
+
+ Parameters
+ ----------
+ X : array, shape (n_epochs, n_channels, n_times)
+ The data to be filtered.
+ y : None | array, shape (n_samples,)
+ Used for scikit-learn compatibility.
+
+ Returns
+ -------
+ self : Instance of UnsupervisedSpatialFilter
+ Return the modified instance.
+ """
+ if self.average:
+ X = np.mean(X, axis=0).T
+ else:
+ n_epochs, n_channels, n_times = X.shape
+ # trial as time samples
+ X = np.transpose(X, (1, 0, 2)).reshape((n_channels, n_epochs *
+ n_times)).T
+ self.estimator.fit(X)
+ return self
+
+ def fit_transform(self, X, y=None):
+ """Transform the data to its filtered components after fitting.
+
+ Parameters
+ ----------
+ X : array, shape (n_epochs, n_channels, n_times)
+ The data to be filtered.
+ y : None | array, shape (n_samples,)
+ Used for scikit-learn compatibility.
+
+ Returns
+ -------
+ X : array, shape (n_trials, n_channels, n_times)
+ The transformed data.
+ """
+ return self.fit(X).transform(X)
+
+ def transform(self, X):
+ """Transform the data to its spatial filters.
+
+ Parameters
+ ----------
+ X : array, shape (n_epochs, n_channels, n_times)
+ The data to be filtered.
+
+ Returns
+ -------
+ X : array, shape (n_trials, n_channels, n_times)
+ The transformed data.
+ """
+ n_epochs, n_channels, n_times = X.shape
+ # trial as time samples
+ X = np.transpose(X, [1, 0, 2]).reshape([n_channels, n_epochs *
+ n_times]).T
+ X = self.estimator.transform(X)
+ X = np.reshape(X.T, [-1, n_epochs, n_times]).transpose([1, 0, 2])
+ return X
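
A minimal sketch of the class above, using scikit-learn's PCA as the
decomposition estimator (any object with fit/transform/fit_transform works):

    import numpy as np
    from sklearn.decomposition import PCA
    from mne.decoding import UnsupervisedSpatialFilter  # assumed import path

    X = np.random.randn(30, 10, 200)  # (n_epochs, n_channels, n_times)
    usf = UnsupervisedSpatialFilter(PCA(4), average=False)
    Xf = usf.fit_transform(X)         # (30, 4, 200): 4 spatial components
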
+
+
+class TemporalFilter(TransformerMixin):
+ """Estimator to filter data array along the last dimension.
+
+ Applies a zero-phase low-pass, high-pass, band-pass, or band-stop
+ filter to the channels.
+
+ l_freq and h_freq are the frequencies below which and above which,
+ respectively, to filter out of the data. Thus the uses are:
+
+ - l_freq < h_freq: band-pass filter
+ - l_freq > h_freq: band-stop filter
+ - l_freq is not None, h_freq is None: high-pass filter
+ - l_freq is None, h_freq is not None: low-pass filter
+
+ See ``mne.filter.filter_data``.
+
+ Parameters
+ ----------
+ l_freq : float | None
+ Low cut-off frequency in Hz. If None the data are only low-passed.
+ h_freq : float | None
+ High cut-off frequency in Hz. If None the data are only
+ high-passed.
+ sfreq : float, defaults to 1.0
+ Sampling frequency in Hz.
+ filter_length : str | int, defaults to 'auto'
+ Length of the FIR filter to use (if applicable):
+ * int: specified length in samples.
+ * 'auto' (default in 0.14): the filter length is chosen based
+ on the size of the transition regions (7 times the reciprocal
+ of the shortest transition band).
+ * str: (default in 0.13 is "10s") a human-readable time in
+ units of "s" or "ms" (e.g., "10s" or "5500ms") will be
+ converted to that number of samples if ``phase="zero"``, or
+ the shortest power-of-two length at least that duration for
+ ``phase="zero-double"``.
+ l_trans_bandwidth : float | str, defaults to 'auto'
+ Width of the transition band at the low cut-off frequency in Hz
+ (high pass or cutoff 1 in bandpass). Can be "auto"
+ (default in 0.14) to use a multiple of ``l_freq``::
+ min(max(l_freq * 0.25, 2), l_freq)
+ Only used for ``method='fir'``.
+ h_trans_bandwidth : float | str, defaults to 'auto'
+ Width of the transition band at the high cut-off frequency in Hz
+ (low pass or cutoff 2 in bandpass). Can be "auto"
+ (default in 0.14) to use a multiple of ``h_freq``::
+ min(max(h_freq * 0.25, 2.), info['sfreq'] / 2. - h_freq)
+ Only used for ``method='fir'``.
+ n_jobs : int | str, defaults to 1
+ Number of jobs to run in parallel. Can be 'cuda' if scikits.cuda
+ is installed properly, CUDA is initialized, and method='fft'.
+ method : str, defaults to 'fir'
+ 'fir' will use overlap-add FIR filtering, 'iir' will use IIR
+ forward-backward filtering (via filtfilt).
+ iir_params : dict | None, defaults to None
+ Dictionary of parameters to use for IIR filtering.
+ See mne.filter.construct_iir_filter for details. If iir_params
+ is None and method="iir", 4th order Butterworth will be used.
+ fir_window : str, defaults to 'hamming'
+ The window to use in FIR design, can be "hamming", "hann",
+ or "blackman".
+ verbose : bool, str, int, or None, defaults to None
+ If not None, override default verbose level (see mne.verbose).
+ Defaults to self.verbose.
+
+ See Also
+ --------
+ FilterEstimator
+ Vectorizer
+ mne.filter.band_pass_filter
+ mne.filter.band_stop_filter
+ mne.filter.low_pass_filter
+ mne.filter.high_pass_filter
+ """
+ def __init__(self, l_freq=None, h_freq=None, sfreq=1.0,
+ filter_length='auto', l_trans_bandwidth='auto',
+ h_trans_bandwidth='auto', n_jobs=1, method='fir',
+ iir_params=None, fir_window='hamming', verbose=None):
+ self.l_freq = l_freq
+ self.h_freq = h_freq
+ self.sfreq = sfreq
+ self.filter_length = filter_length
+ self.l_trans_bandwidth = l_trans_bandwidth
+ self.h_trans_bandwidth = h_trans_bandwidth
+ self.n_jobs = n_jobs
+ self.method = method
+ self.iir_params = iir_params
+ self.fir_window = fir_window
+ self.verbose = verbose
+
+ if not isinstance(self.n_jobs, int) and self.n_jobs != 'cuda':
+ raise ValueError('n_jobs must be int or "cuda", got %s instead.'
+ % type(self.n_jobs))
+
+ def fit(self, X, y=None):
+ """Does nothing. For scikit-learn compatibility purposes.
+
+ Parameters
+ ----------
+ X : array, shape (n_epochs, n_channels, n_times) or shape (n_channels, n_times) # noqa
+ The data to be filtered over the last dimension. The channels
+ dimension can be zero when passing a 2D array.
+ y : None
+ Not used; present for scikit-learn compatibility.
+
+ Returns
+ -------
+ self : instance of TemporalFilter
+ Returns the modified instance.
+ """
+ return self
+
+ def transform(self, X):
+ """Filters data along the last dimension.
+
+ Parameters
+ ----------
+ X : array, shape (n_epochs, n_channels, n_times) or shape (n_channels, n_times) # noqa
+ The data to be filtered over the last dimension. The channels
+ dimension can be zero when passing a 2D array.
+
+ Returns
+ -------
+ X : array, shape is same as used in input.
+ The data after filtering.
+ """
+ X = np.atleast_2d(X)
+
+ if X.ndim > 3:
+ raise ValueError("Array must be of at max 3 dimensions instead "
+ "got %s dimensional matrix" % (X.ndim))
+
+ shape = X.shape
+ X = X.reshape(-1, shape[-1])
+ (X, self.sfreq, self.l_freq, self.h_freq, self.l_trans_bandwidth,
+ self.h_trans_bandwidth, self.filter_length, _, self.fir_window) = \
+ _triage_filter_params(X, self.sfreq, self.l_freq, self.h_freq,
+ self.l_trans_bandwidth,
+ self.h_trans_bandwidth, self.filter_length,
+ self.method, phase='zero',
+ fir_window=self.fir_window)
+ X = filter_data(X, self.sfreq, self.l_freq, self.h_freq,
+ filter_length=self.filter_length,
+ l_trans_bandwidth=self.l_trans_bandwidth,
+ h_trans_bandwidth=self.h_trans_bandwidth,
+ n_jobs=self.n_jobs, method=self.method,
+ iir_params=self.iir_params, copy=False,
+ fir_window=self.fir_window,
+ verbose=self.verbose)
+ return X.reshape(shape)
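
A minimal sketch of TemporalFilter on synthetic data (band-pass, since
l_freq < h_freq; the input shape is preserved):

    import numpy as np
    from mne.decoding import TemporalFilter  # assumed import path

    X = np.random.randn(5, 8, 1000)  # (n_epochs, n_channels, n_times)
    filt = TemporalFilter(l_freq=1., h_freq=30., sfreq=100.)
    Xf = filt.fit(X).transform(X)    # fit() is a no-op; Xf.shape == X.shape
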
diff --git a/mne/defaults.py b/mne/defaults.py
index 4f8faa2..fd657aa 100644
--- a/mne/defaults.py
+++ b/mne/defaults.py
@@ -7,29 +7,30 @@
from copy import deepcopy
DEFAULTS = dict(
- color=dict(mag='darkblue', grad='b', eeg='k', eog='k', ecg='m',
- emg='k', ref_meg='steelblue', misc='k', stim='k',
- resp='k', chpi='k', exci='k', ias='k', syst='k',
- seeg='k', dipole='k', gof='k', bio='k', ecog='k'),
- config_opts=dict(),
+ color=dict(mag='darkblue', grad='b', eeg='k', eog='k', ecg='m', emg='k',
+ ref_meg='steelblue', misc='k', stim='k', resp='k', chpi='k',
+ exci='k', ias='k', syst='k', seeg='k', dipole='k', gof='k',
+ bio='k', ecog='k', hbo='darkblue', hbr='b'),
units=dict(eeg='uV', grad='fT/cm', mag='fT', eog='uV', misc='AU',
seeg='uV', dipole='nAm', gof='GOF', emg='uV', ecg='uV',
- bio='uV', ecog='uV'),
+ bio='uV', ecog='uV', hbo='uM', hbr='uM'),
scalings=dict(mag=1e15, grad=1e13, eeg=1e6, eog=1e6, emg=1e6, ecg=1e6,
- misc=1.0, seeg=1e4, dipole=1e9, gof=1.0, bio=1e6, ecog=1e6),
- scalings_plot_raw=dict(mag=1e-12, grad=4e-11, eeg=20e-6,
- eog=150e-6, ecg=5e-4, emg=1e-3,
- ref_meg=1e-12, misc=1e-3,
- stim=1, resp=1, chpi=1e-4, exci=1,
- ias=1, syst=1, seeg=1e-5, bio=1e-6, ecog=1e-4),
+ misc=1.0, seeg=1e4, dipole=1e9, gof=1.0, bio=1e6, ecog=1e6,
+ hbo=1e6, hbr=1e6),
+ scalings_plot_raw=dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6,
+ ecg=5e-4, emg=1e-3, ref_meg=1e-12, misc=1e-3,
+ stim=1, resp=1, chpi=1e-4, exci=1, ias=1, syst=1,
+ seeg=1e-5, bio=1e-6, ecog=1e-4, hbo=10e-6,
+ hbr=10e-6),
scalings_cov_rank=dict(mag=1e12, grad=1e11, eeg=1e5),
ylim=dict(mag=(-600., 600.), grad=(-200., 200.), eeg=(-200., 200.),
misc=(-5., 5.), seeg=(-200., 200.), dipole=(-100., 100.),
- gof=(0., 1.), bio=(-500., 500.), ecog=(-200., 200.)),
+ gof=(0., 1.), bio=(-500., 500.), ecog=(-200., 200.), hbo=(0, 20),
+ hbr=(0, 20)),
titles=dict(eeg='EEG', grad='Gradiometers', mag='Magnetometers',
misc='misc', seeg='sEEG', dipole='Dipole', eog='EOG',
gof='Goodness of fit', ecg='ECG', emg='EMG', bio='BIO',
- ecog='ECoG'),
+ ecog='ECoG', hbo='Oxyhemoglobin', hbr='Deoxyhemoglobin'),
mask_params=dict(marker='o',
markerfacecolor='w',
markeredgecolor='k',
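
The new fNIRS entries can be checked directly; a minimal sketch:

    from mne.defaults import DEFAULTS

    print(DEFAULTS['units']['hbo'])   # 'uM'
    print(DEFAULTS['titles']['hbr'])  # 'Deoxyhemoglobin'
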
diff --git a/mne/dipole.py b/mne/dipole.py
index b5d8fde..7abe161 100644
--- a/mne/dipole.py
+++ b/mne/dipole.py
@@ -4,6 +4,7 @@
# License: Simplified BSD
from copy import deepcopy
+from functools import partial
import re
import numpy as np
@@ -31,7 +32,6 @@ from .bem import _bem_find_surface, _bem_explain_surface
from .source_space import (_make_volume_source_space, SourceSpaces,
_points_outside_surface)
from .parallel import parallel_func
-from .fixes import partial
from .utils import logger, verbose, _time_mask, warn, _check_fname, check_fname
@@ -210,25 +210,48 @@ class Dipole(object):
from .viz import plot_dipole_amplitudes
return plot_dipole_amplitudes([self], [color], show)
- def __getitem__(self, idx_slice):
- """Handle indexing"""
- if isinstance(idx_slice, int): # make sure attributes stay 2d
- idx_slice = [idx_slice]
+ def __getitem__(self, item):
+ """Get a time slice
- selected_times = self.times[idx_slice].copy()
- selected_pos = self.pos[idx_slice, :].copy()
- selected_amplitude = self.amplitude[idx_slice].copy()
- selected_ori = self.ori[idx_slice, :].copy()
- selected_gof = self.gof[idx_slice].copy()
- selected_name = self.name
+ Parameters
+ ----------
+ item : array-like or slice
+ The slice of time points to use.
- new_dipole = Dipole(selected_times, selected_pos,
- selected_amplitude, selected_ori,
- selected_gof, selected_name)
- return new_dipole
+ Returns
+ -------
+ dip : instance of Dipole
+ The sliced dipole.
+ """
+ if isinstance(item, int): # make sure attributes stay 2d
+ item = [item]
+
+ selected_times = self.times[item].copy()
+ selected_pos = self.pos[item, :].copy()
+ selected_amplitude = self.amplitude[item].copy()
+ selected_ori = self.ori[item, :].copy()
+ selected_gof = self.gof[item].copy()
+ selected_name = self.name
+ return Dipole(
+ selected_times, selected_pos, selected_amplitude, selected_ori,
+ selected_gof, selected_name)
def __len__(self):
- """Handle len function"""
+ """The number of dipoles
+
+ Returns
+ -------
+ len : int
+ The number of dipoles.
+
+ Examples
+ --------
+ This can be used as::
+
+ >>> len(dipoles) # doctest: +SKIP
+ 10
+
+ """
return self.pos.shape[0]
@@ -363,29 +386,80 @@ def read_dipole(fname, verbose=None):
_check_fname(fname, overwrite=True, must_exist=True)
if fname.endswith('.fif') or fname.endswith('.fif.gz'):
return _read_dipole_fixed(fname)
- try:
- data = np.loadtxt(fname, comments='%')
- except:
- data = np.loadtxt(fname, comments='#') # handle 2 types of comments...
- name = None
+ else:
+ return _read_dipole_text(fname)
+
+
+def _read_dipole_text(fname):
+ """Read a dipole text file."""
+ # Figure out the special fields
+ need_header = True
+ def_line = name = None
+ # There is a bug in older np.loadtxt regarding skipping fields,
+ # so just read the data ourselves (need to get name and header anyway)
+ data = list()
with open(fname, 'r') as fid:
- for line in fid.readlines():
- if line.startswith('##') or line.startswith('%%'):
- m = re.search('Name "(.*) dipoles"', line)
- if m:
- name = m.group(1)
- break
- if data.ndim == 1:
- data = data[None, :]
+ for line in fid:
+ if not (line.startswith('%') or line.startswith('#')):
+ need_header = False
+ data.append(line.strip().split())
+ else:
+ if need_header:
+ def_line = line
+ if line.startswith('##') or line.startswith('%%'):
+ m = re.search('Name "(.*) dipoles"', line)
+ if m:
+ name = m.group(1)
+ del line
+ data = np.atleast_2d(np.array(data, float))
+ if def_line is None:
+ raise IOError('Dipole text file is missing field definition '
+ 'comment, cannot parse %s' % (fname,))
+ # actually parse the fields
+ def_line = def_line.lstrip('%').lstrip('#').strip()
+ # MNE writes it out differently than Elekta, let's standardize them...
+ fields = re.sub('([X|Y|Z] )\(mm\)', # "X (mm)", etc.
+ lambda match: match.group(1).strip() + '/mm', def_line)
+ fields = re.sub('\((.*?)\)', # "Q(nAm)", etc.
+ lambda match: '/' + match.group(1), fields)
+ fields = re.sub('(begin|end) ', # "begin" and "end" with no units
+ lambda match: match.group(1) + '/ms', fields)
+ fields = fields.lower().split()
+ used_fields = ('begin/ms',
+ 'x/mm', 'y/mm', 'z/mm',
+ 'q/nam',
+ 'qx/nam', 'qy/nam', 'qz/nam',
+ 'g/%')
+ missing_fields = sorted(set(used_fields) - set(fields))
+ if len(missing_fields) > 0:
+ raise RuntimeError('Could not find necessary fields in header: %s'
+ % (missing_fields,))
+ ignored_fields = sorted(set(fields) - set(used_fields) - set(['end/ms']))
+ if len(ignored_fields) > 0:
+ warn('Ignoring extra fields in dipole file: %s' % (ignored_fields,))
+ if len(fields) != data.shape[1]:
+ raise IOError('Number of data fields (%s) does not match number '
+ 'of data columns (%s): %s'
+ % (len(fields), data.shape[1], fields))
+
logger.info("%d dipole(s) found" % len(data))
- times = data[:, 0] / 1000.
- pos = 1e-3 * data[:, 2:5] # put data in meters
- amplitude = data[:, 5]
+
+ if 'end/ms' in fields:
+ if np.diff(data[:, [fields.index('begin/ms'),
+ fields.index('end/ms')]], 1, -1).any():
+ warn('begin and end fields differed, but only begin will be used '
+ 'to store time values')
+
+ # Find the correct column in our data array, then scale to proper units
+ idx = [fields.index(field) for field in used_fields]
+ assert len(idx) == 9
+ times = data[:, idx[0]] / 1000.
+ pos = 1e-3 * data[:, idx[1:4]] # put data in meters
+ amplitude = data[:, idx[4]]
norm = amplitude.copy()
amplitude /= 1e9
norm[norm == 0] = 1
- ori = data[:, 6:9] / norm[:, np.newaxis]
- gof = data[:, 9]
+ ori = data[:, idx[5:8]] / norm[:, np.newaxis]
+ gof = data[:, idx[8]]
return Dipole(times, pos, amplitude, ori, gof, name)
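
A minimal sketch of the header normalization performed above, on a
hypothetical MNE-style field definition line:

    import re

    def_line = ('begin    end   X (mm)   Y (mm)   Z (mm)   Q(nAm)  '
                'Qx(nAm)  Qy(nAm)  Qz(nAm)    g/%')
    fields = re.sub(r'([X|Y|Z] )\(mm\)',
                    lambda m: m.group(1).strip() + '/mm', def_line)
    fields = re.sub(r'\((.*?)\)', lambda m: '/' + m.group(1), fields)
    fields = re.sub(r'(begin|end) ', lambda m: m.group(1) + '/ms', fields)
    print(fields.lower().split())
    # ['begin/ms', 'end/ms', 'x/mm', 'y/mm', 'z/mm', 'q/nam',
    #  'qx/nam', 'qy/nam', 'qz/nam', 'g/%']
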
@@ -650,7 +724,7 @@ def _fit_dipole(min_dist_to_inner_skull, B_orig, t, guess_rrs,
B2 = np.dot(B, B)
if B2 == 0:
warn('Zero field found for time %s' % t)
- return np.zeros(3), 0, np.zeros(3), 0
+ return np.zeros(3), 0, np.zeros(3), 0, B
idx = np.argmin([_fit_eval(guess_rrs[[fi], :], B, B2, fwd_svd)
for fi, fwd_svd in enumerate(guess_data['fwd_svd'])])
@@ -720,8 +794,8 @@ def fit_dipole(evoked, cov, bem, trans=None, min_dist=5., n_jobs=1,
The dataset to fit.
cov : str | instance of Covariance
The noise covariance.
- bem : str | dict
- The BEM filename (str) or a loaded sphere model (dict).
+ bem : str | instance of ConductorModel
+ The BEM filename (str) or conductor model.
trans : str | None
The head<->MRI transform filename. Must be provided unless BEM
is a sphere model.
@@ -795,15 +869,19 @@ def fit_dipole(evoked, cov, bem, trans=None, min_dist=5., n_jobs=1,
del min_dist
# Figure out our inputs
- neeg = len(pick_types(info, meg=False, eeg=True, exclude=[]))
+ neeg = len(pick_types(info, meg=False, eeg=True, ref_meg=False,
+ exclude=[]))
if isinstance(bem, string_types):
- logger.info('BEM : %s' % bem)
+ bem_extra = bem
+ else:
+ bem_extra = repr(bem)
+ logger.info('BEM : %s' % bem_extra)
if trans is not None:
logger.info('MRI transform : %s' % trans)
mri_head_t, trans = _get_trans(trans)
else:
mri_head_t = Transform('head', 'mri', np.eye(4))
- bem = _setup_bem(bem, bem, neeg, mri_head_t, verbose=False)
+ bem = _setup_bem(bem, bem_extra, neeg, mri_head_t, verbose=False)
if not bem['is_sphere']:
if trans is None:
raise ValueError('mri must not be None if BEM is provided')
@@ -909,7 +987,7 @@ def fit_dipole(evoked, cov, bem, trans=None, min_dist=5., n_jobs=1,
# Whitener for the data
logger.info('Decomposing the sensor noise covariance matrix...')
- picks = pick_types(info, meg=True, eeg=True)
+ picks = pick_types(info, meg=True, eeg=True, ref_meg=False)
# In case we want to more closely match MNE-C for debugging:
# from .io.pick import pick_info
@@ -1009,3 +1087,63 @@ def fit_dipole(evoked, cov, bem, trans=None, min_dist=5., n_jobs=1,
residual = out[4]
logger.info('%d time points fitted' % len(dipoles.times))
return dipoles, residual
+
+
+def get_phantom_dipoles(kind='vectorview'):
+ """Get standard phantom dipole locations and orientations
+
+ Parameters
+ ----------
+ kind : str
+ Get the information for the given system.
+
+ ``vectorview`` (default)
+ The Neuromag VectorView phantom.
+
+ ``122``
+ The Neuromag-122 phantom. This has the same dipoles
+ as the VectorView phantom, but in a different order.
+
+ Returns
+ -------
+ pos : ndarray, shape (n_dipoles, 3)
+ The dipole positions.
+ ori : ndarray, shape (n_dipoles, 3)
+ The dipole orientations.
+ """
+ _valid_types = ('122', 'vectorview')
+ if not isinstance(kind, string_types) or kind not in _valid_types:
+ raise ValueError('kind must be one of %s, got %s'
+ % (_valid_types, kind,))
+ if kind in ('122', 'vectorview'):
+ a = np.array([59.7, 48.6, 35.8, 24.8, 37.2, 27.5, 15.8, 7.9])
+ b = np.array([46.1, 41.9, 38.3, 31.5, 13.9, 16.2, 20, 19.3])
+ x = np.concatenate((a, [0] * 8, -b, [0] * 8))
+ y = np.concatenate(([0] * 8, -a, [0] * 8, b))
+ c = [22.9, 23.5, 25.5, 23.1, 52, 46.4, 41, 33]
+ d = [44.4, 34, 21.6, 12.7, 62.4, 51.5, 39.1, 27.9]
+ z = np.concatenate((c, c, d, d))
+ pos = np.vstack((x, y, z)).T / 1000.
+ if kind == '122':
+ reorder = (list(range(8, 16)) + list(range(0, 8)) +
+ list(range(24, 32)) + list(range(16, 24)))
+ pos = pos[reorder]
+ # Locs are always in XZ or YZ, and so are the oris. The oris are
+ # also in the same plane and tangential, so it's easy to determine
+ # the orientation.
+ ori = list()
+ for this_pos in pos:
+ this_ori = np.zeros(3)
+ idx = np.where(this_pos == 0)[0]
+ # assert len(idx) == 1
+ idx = np.setdiff1d(np.arange(3), idx[0])
+ this_ori[idx] = (this_pos[idx][::-1] /
+ np.linalg.norm(this_pos[idx])) * [1, -1]
+ # Now we have this quality, which we could uncomment to
+ # double-check:
+ # np.testing.assert_allclose(np.dot(this_ori, this_pos) /
+ # np.linalg.norm(this_pos), 0,
+ # atol=1e-15)
+ ori.append(this_ori)
+ ori = np.array(ori)
+ return pos, ori
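
A minimal sketch using the helper added above (positions are returned in
meters):

    from mne.dipole import get_phantom_dipoles

    pos, ori = get_phantom_dipoles('vectorview')
    print(pos.shape, ori.shape)  # (32, 3) (32, 3)
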
diff --git a/mne/epochs.py b/mne/epochs.py
index 8dbc202..7ffebdd 100644
--- a/mne/epochs.py
+++ b/mne/epochs.py
@@ -32,19 +32,18 @@ from .io.pick import (pick_types, channel_indices_by_type, channel_type,
from .io.proj import setup_proj, ProjMixin, _proj_equal
from .io.base import _BaseRaw, ToDataFrameMixin, TimeMixin
from .bem import _check_origin
-from .evoked import EvokedArray
+from .evoked import EvokedArray, _check_decim
from .baseline import rescale, _log_rescale
from .channels.channels import (ContainsMixin, UpdateChannelsMixin,
SetChannelsMixin, InterpolationMixin)
from .filter import resample, detrend, FilterMixin
-from .event import _read_events_fif
-from .fixes import in1d, _get_args
+from .event import _read_events_fif, make_fixed_length_events
+from .fixes import _get_args
from .viz import (plot_epochs, plot_epochs_psd, plot_epochs_psd_topomap,
- plot_epochs_image, plot_topo_image_epochs)
+ plot_epochs_image, plot_topo_image_epochs, plot_drop_log)
from .utils import (check_fname, logger, verbose, _check_type_picks,
- _time_mask, check_random_state, object_hash, warn,
- _check_copy_dep)
-from .utils import deprecated
+ _time_mask, check_random_state, warn, _check_copy_dep,
+ sizeof_fmt, SizeMixin, copy_function_doc_to_method_doc)
from .externals.six import iteritems, string_types
from .externals.six.moves import zip
@@ -142,7 +141,7 @@ def _save_split(epochs, fname, part_idx, n_parts):
class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
SetChannelsMixin, InterpolationMixin, FilterMixin,
- ToDataFrameMixin, TimeMixin):
+ ToDataFrameMixin, TimeMixin, SizeMixin):
"""Abstract base class for Epochs-type classes
This class provides basic functionality and should never be instantiated
@@ -152,7 +151,7 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
baseline=(None, 0), raw=None,
picks=None, name='Unknown', reject=None, flat=None,
decim=1, reject_tmin=None, reject_tmax=None, detrend=None,
- add_eeg_ref=True, proj=True, on_missing='error',
+ add_eeg_ref=False, proj=True, on_missing='error',
preload_at_end=False, selection=None, drop_log=None,
verbose=None):
@@ -188,7 +187,8 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
raise ValueError('events must be an array of type int')
if events.ndim != 2 or events.shape[1] != 3:
raise ValueError('events must be 2D with 3 columns')
-
+ if len(np.unique(events[:, 0])) != len(events):
+ raise RuntimeError('Event time samples were not unique')
for key, val in self.event_id.items():
if val not in events[:, 2]:
msg = ('No matching events found for %s '
@@ -201,7 +201,7 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
pass
values = list(self.event_id.values())
- selected = in1d(events[:, 2], values)
+ selected = np.in1d(events[:, 2], values)
if selection is None:
self.selection = np.where(selected)[0]
else:
@@ -237,25 +237,13 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
if (reject_tmin is not None) and (reject_tmax is not None):
if reject_tmin >= reject_tmax:
raise ValueError('reject_tmin needs to be < reject_tmax')
- if detrend not in [None, 0, 1]:
+ if (detrend not in [None, 0, 1]) or isinstance(detrend, bool):
raise ValueError('detrend must be None, 0, or 1')
# check that baseline is in available data
- if baseline is not None:
- baseline_tmin, baseline_tmax = baseline
- tstep = 1. / info['sfreq']
- if baseline_tmin is not None:
- if baseline_tmin < tmin - tstep:
- err = ("Baseline interval (tmin = %s) is outside of epoch "
- "data (tmin = %s)" % (baseline_tmin, tmin))
- raise ValueError(err)
- if baseline_tmax is not None:
- if baseline_tmax > tmax + tstep:
- err = ("Baseline interval (tmax = %s) is outside of epoch "
- "data (tmax = %s)" % (baseline_tmax, tmax))
- raise ValueError(err)
if tmin > tmax:
raise ValueError('tmin has to be less than or equal to tmax')
+ _check_baseline(baseline, tmin, tmax, info['sfreq'])
_log_rescale(baseline)
self.baseline = baseline
self.reject_tmin = reject_tmin
@@ -309,11 +297,10 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
logger.info('Entering delayed SSP mode.')
else:
self._do_delayed_proj = False
-
+ add_eeg_ref = _dep_eeg_ref(add_eeg_ref) if 'eeg' in self else False
activate = False if self._do_delayed_proj else proj
self._projector, self.info = setup_proj(self.info, add_eeg_ref,
activate=activate)
-
if preload_at_end:
assert self._data is None
assert self.preload is False
@@ -341,7 +328,7 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
.. versionadded:: 0.10.0
"""
if self.preload:
- return
+ return self
self._data = self._get_data()
self.preload = True
self._decim_slice = slice(None, None, None)
@@ -350,17 +337,16 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
assert self._data.shape[-1] == len(self.times)
return self
- def decimate(self, decim, copy=None, offset=0):
+ def decimate(self, decim, offset=0):
"""Decimate the epochs
+ .. note:: No filtering is performed. To avoid aliasing, ensure
+ your data are properly lowpassed.
+
Parameters
----------
decim : int
The amount to decimate data.
- copy : bool
- This parameter has been deprecated and will be removed in 0.13.
- Use inst.copy() instead.
- Whether to return a new instance or modify in place.
offset : int
Apply an offset to where the decimation starts relative to the
sample corresponding to t=0. The offset is in samples at the
@@ -373,6 +359,12 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
epochs : instance of Epochs
The decimated Epochs object.
+ See Also
+ --------
+ Evoked.decimate
+ Epochs.resample
+ Raw.resample
+
Notes
-----
Decimation can be done multiple times. For example,
@@ -381,62 +373,39 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
.. versionadded:: 0.10.0
"""
- if decim < 1 or decim != int(decim):
- raise ValueError('decim must be an integer > 0')
- decim = int(decim)
- epochs = _check_copy_dep(self, copy)
- del self
-
- new_sfreq = epochs.info['sfreq'] / float(decim)
- lowpass = epochs.info['lowpass']
- if decim > 1 and lowpass is None:
- warn('The measurement information indicates data is not low-pass '
- 'filtered. The decim=%i parameter will result in a sampling '
- 'frequency of %g Hz, which can cause aliasing artifacts.'
- % (decim, new_sfreq))
- elif decim > 1 and new_sfreq < 2.5 * lowpass:
- warn('The measurement information indicates a low-pass frequency '
- 'of %g Hz. The decim=%i parameter will result in a sampling '
- 'frequency of %g Hz, which can cause aliasing artifacts.'
- % (lowpass, decim, new_sfreq)) # > 50% nyquist lim
- offset = int(offset)
- if not 0 <= offset < decim:
- raise ValueError('decim must be at least 0 and less than %s, got '
- '%s' % (decim, offset))
- epochs._decim *= decim
- start_idx = int(round(epochs._raw_times[0] * (epochs.info['sfreq'] *
- epochs._decim)))
- i_start = start_idx % epochs._decim
- decim_slice = slice(i_start + offset, len(epochs._raw_times),
- epochs._decim)
- epochs.info['sfreq'] = new_sfreq
- if epochs.preload:
- epochs._data = epochs._data[:, :, decim_slice].copy()
- epochs._raw_times = epochs._raw_times[decim_slice].copy()
- epochs._decim_slice = slice(None, None, None)
- epochs._decim = 1
- epochs.times = epochs._raw_times
+ decim, offset, new_sfreq = _check_decim(self.info, decim, offset)
+ start_idx = int(round(-self._raw_times[0] * (self.info['sfreq'] *
+ self._decim)))
+ self._decim *= decim
+ i_start = start_idx % self._decim + offset
+ decim_slice = slice(i_start, None, self._decim)
+ self.info['sfreq'] = new_sfreq
+ if self.preload:
+ self._data = self._data[:, :, decim_slice].copy()
+ self._raw_times = self._raw_times[decim_slice].copy()
+ self._decim_slice = slice(None)
+ self._decim = 1
+ self.times = self._raw_times
else:
- epochs._decim_slice = decim_slice
- epochs.times = epochs._raw_times[epochs._decim_slice]
- return epochs
+ self._decim_slice = decim_slice
+ self.times = self._raw_times[self._decim_slice]
+ return self
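
Since decimate() now returns self, calls chain and compound; a minimal
sketch assuming a preloaded Epochs instance `epochs` sampled at 600 Hz:

    epochs.decimate(2).decimate(3)  # same as epochs.decimate(6); sfreq -> 100 Hz
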
@verbose
- def apply_baseline(self, baseline, copy=None, verbose=None):
+ def apply_baseline(self, baseline=(None, 0), verbose=None):
"""Baseline correct epochs
Parameters
----------
baseline : tuple of length 2
- The time interval to apply baseline correction. (a, b) is the
- interval is between "a (s)" and "b (s)". If a is None the beginning
- of the data is used and if b is None then b is set to the end of
- the interval. If baseline is equal to (None, None) all the time
- interval is used.
- copy : bool
- This parameter has been deprecated and will be removed in 0.13.
- Use inst.copy() instead.
- Whether to return a new instance or modify in place.
+ The time interval to apply baseline correction. If None do not
+ apply it. If baseline is (a, b) the interval is between "a (s)" and
+ "b (s)". If a is None the beginning of the data is used and if b is
+ None then b is set to the end of the interval. If baseline is equal
+ to (None, None) all the time interval is used. Correction is
+ applied by computing mean of the baseline period and subtracting it
+ from the data. The baseline (a, b) includes both endpoints, i.e.
+ all timepoints t such that a <= t <= b.
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
@@ -451,21 +420,23 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
.. versionadded:: 0.10.0
"""
- if not isinstance(baseline, tuple) or len(baseline) != 2:
- raise ValueError('`baseline=%s` is an invalid argument.'
- % str(baseline))
-
- epochs = _check_copy_dep(self, copy)
- picks = _pick_data_channels(epochs.info, exclude=[], with_ref_meg=True)
- picks_aux = _pick_aux_channels(epochs.info, exclude=[])
+ if not self.preload:
+ # Eventually we can relax this restriction, but it will require
+ # more careful checking of baseline (e.g., refactor with the
+ # _BaseEpochs.__init__ checks)
+ raise RuntimeError('Data must be loaded to apply a new baseline')
+ _check_baseline(baseline, self.tmin, self.tmax, self.info['sfreq'])
+
+ picks = _pick_data_channels(self.info, exclude=[], with_ref_meg=True)
+ picks_aux = _pick_aux_channels(self.info, exclude=[])
picks = np.sort(np.concatenate((picks, picks_aux)))
- data = epochs._data
+ data = self._data
data[:, picks, :] = rescale(data[:, picks, :], self.times, baseline,
copy=False)
- epochs.baseline = baseline
+ self.baseline = baseline
- return epochs
+ return self
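
A minimal sketch of re-baselining with the new signature, assuming an
Epochs instance `epochs` (data must be loaded, per the check above):

    epochs.load_data()                # no-op if already preloaded
    epochs.apply_baseline((None, 0))  # subtract the mean over [tmin, 0]
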
def _reject_setup(self, reject, flat):
"""Sets self._reject_time and self._channel_type_idx"""
@@ -577,7 +548,8 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
# Baseline correct
picks = pick_types(self.info, meg=True, eeg=True, stim=False,
ref_meg=True, eog=True, ecg=True, seeg=True,
- emg=True, bio=True, ecog=True, exclude=[])
+ emg=True, bio=True, ecog=True, fnirs=True,
+ exclude=[])
epoch[picks] = rescale(epoch[picks], self._raw_times, self.baseline,
copy=False, verbose=False)
@@ -683,18 +655,13 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
"""Wrapper for Py3k"""
return self.next(*args, **kwargs)
- def __hash__(self):
- if not self.preload:
- raise RuntimeError('Cannot hash epochs unless preloaded')
- return object_hash(dict(info=self.info, data=self._data))
-
def average(self, picks=None):
"""Compute average of epochs
Parameters
----------
picks : array-like of int | None
- If None only MEG, EEG, SEEG, and ECoG channels are kept
+ If None only MEG, EEG, SEEG, ECoG, and fNIRS channels are kept
otherwise the channels indices in picks are kept.
Returns
@@ -716,7 +683,7 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
Parameters
----------
picks : array-like of int | None
- If None only MEG, EEG, SEEG, and ECoG channels are kept
+ If None only MEG, EEG, SEEG, ECoG, and fNIRS channels are kept
otherwise the channels indices in picks are kept.
Returns
@@ -797,295 +764,54 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
"""Channel names"""
return self.info['ch_names']
- def plot(self, picks=None, scalings=None, show=True,
- block=False, n_epochs=20,
- n_channels=20, title=None):
- """Visualize epochs.
-
- Bad epochs can be marked with a left click on top of the epoch. Bad
- channels can be selected by clicking the channel name on the left side
- of the main axes. Calling this function drops all the selected bad
- epochs as well as bad epochs marked beforehand with rejection
- parameters.
-
- Parameters
- ----------
- picks : array-like of int | None
- Channels to be included. If None only good data channels are used.
- Defaults to None
- scalings : dict | None
- Scaling factors for the traces. If any fields in scalings are
- 'auto', the scaling factor is set to match the 99.5th percentile of
- a subset of the corresponding data. If scalings == 'auto', all
- scalings fields are set to 'auto'. If any fields are 'auto' and
- data is not preloaded, a subset of epochs up to 100mb will be
- loaded. If None, defaults to::
-
- dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, ecg=5e-4,
- emg=1e-3, ref_meg=1e-12, misc=1e-3, stim=1, resp=1,
- chpi=1e-4)
-
- show : bool
- Whether to show the figure or not.
- block : bool
- Whether to halt program execution until the figure is closed.
- Useful for rejecting bad trials on the fly by clicking on a
- sub plot.
- n_epochs : int
- The number of epochs per view.
- n_channels : int
- The number of channels per view on mne_browse_epochs. If trellis is
- True, this parameter has no effect. Defaults to 20.
- title : str | None
- The title of the window. If None, epochs name will be displayed.
- If trellis is True, this parameter has no effect.
- Defaults to None.
-
- Returns
- -------
- fig : Instance of matplotlib.figure.Figure
- The figure.
-
- Notes
- -----
- The arrow keys (up/down/left/right) can
- be used to navigate between channels and epochs and the scaling can be
- adjusted with - and + (or =) keys, but this depends on the backend
- matplotlib is configured to use (e.g., mpl.use(``TkAgg``) should work).
- Full screen mode can be toggled with f11 key. The amount of epochs
- and channels per view can be adjusted with home/end and
- page down/page up keys. Butterfly plot can be toggled with ``b`` key.
- Right mouse click adds a vertical line to the plot.
-
- .. versionadded:: 0.10.0
- """
+ @copy_function_doc_to_method_doc(plot_epochs)
+ def plot(self, picks=None, scalings=None, n_epochs=20, n_channels=20,
+ title=None, show=True, block=False):
return plot_epochs(self, picks=picks, scalings=scalings,
n_epochs=n_epochs, n_channels=n_channels,
title=title, show=show, block=block)
- def plot_psd(self, fmin=0, fmax=np.inf, proj=False, bandwidth=None,
- adaptive=False, low_bias=True, normalization='length',
- picks=None, ax=None, color='black', area_mode='std',
- area_alpha=0.33, dB=True, n_jobs=1, verbose=None, show=True):
- """Plot the power spectral density across epochs
-
- Parameters
- ----------
- fmin : float
- Start frequency to consider.
- fmax : float
- End frequency to consider.
- proj : bool
- Apply projection.
- bandwidth : float
- The bandwidth of the multi taper windowing function in Hz.
- The default value is a window half-bandwidth of 4.
- adaptive : bool
- Use adaptive weights to combine the tapered spectra into PSD
- (slow, use n_jobs >> 1 to speed up computation).
- low_bias : bool
- Only use tapers with more than 90% spectral concentration within
- bandwidth.
- normalization : str
- Either "full" or "length" (default). If "full", the PSD will
- be normalized by the sampling rate as well as the length of
- the signal (as in nitime).
- picks : array-like of int | None
- List of channels to use.
- ax : instance of matplotlib Axes | None
- Axes to plot into. If None, axes will be created.
- color : str | tuple
- A matplotlib-compatible color to use.
- area_mode : str | None
- Mode for plotting area. If 'std', the mean +/- 1 STD (across
- channels) will be plotted. If 'range', the min and max (across
- channels) will be plotted. Bad channels will be excluded from
- these calculations. If None, no area will be plotted.
- area_alpha : float
- Alpha for the area.
- dB : bool
- If True, transform data to decibels.
- n_jobs : int
- Number of jobs to run in parallel.
- verbose : bool, str, int, or None
- If not None, override default verbose level (see mne.verbose).
- show : bool
- Show figure if True.
-
- Returns
- -------
- fig : instance of matplotlib figure
- Figure distributing one image per channel across sensor topography.
- """
- return plot_epochs_psd(self, fmin=fmin, fmax=fmax, proj=proj,
- bandwidth=bandwidth, adaptive=adaptive,
- low_bias=low_bias, normalization=normalization,
- picks=picks, ax=ax, color=color,
- area_mode=area_mode, area_alpha=area_alpha,
- dB=dB, n_jobs=n_jobs, verbose=None, show=show)
-
- def plot_psd_topomap(self, bands=None, vmin=None, vmax=None, proj=False,
- bandwidth=None, adaptive=False, low_bias=True,
- normalization='length', ch_type=None,
+ @copy_function_doc_to_method_doc(plot_epochs_psd)
+ def plot_psd(self, fmin=0, fmax=np.inf, tmin=None, tmax=None, proj=False,
+ bandwidth=None, adaptive=False, low_bias=True,
+ normalization='length', picks=None, ax=None, color='black',
+ area_mode='std', area_alpha=0.33, dB=True, n_jobs=1,
+ show=True, verbose=None):
+ return plot_epochs_psd(self, fmin=fmin, fmax=fmax, tmin=tmin,
+ tmax=tmax, proj=proj, bandwidth=bandwidth,
+ adaptive=adaptive, low_bias=low_bias,
+ normalization=normalization, picks=picks, ax=ax,
+ color=color, area_mode=area_mode,
+ area_alpha=area_alpha, dB=dB, n_jobs=n_jobs,
+ show=show, verbose=verbose)
+
+ @copy_function_doc_to_method_doc(plot_epochs_psd_topomap)
+ def plot_psd_topomap(self, bands=None, vmin=None, vmax=None, tmin=None,
+ tmax=None, proj=False, bandwidth=None, adaptive=False,
+ low_bias=True, normalization='length', ch_type=None,
layout=None, cmap='RdBu_r', agg_fun=None, dB=True,
n_jobs=1, normalize=False, cbar_fmt='%0.3f',
- outlines='head', show=True, verbose=None):
- """Plot the topomap of the power spectral density across epochs
-
- Parameters
- ----------
- bands : list of tuple | None
- The lower and upper frequency and the name for that band. If None,
- (default) expands to:
-
- bands = [(0, 4, 'Delta'), (4, 8, 'Theta'), (8, 12, 'Alpha'),
- (12, 30, 'Beta'), (30, 45, 'Gamma')]
-
- vmin : float | callable | None
- The value specifying the lower bound of the color range.
- If None, and vmax is None, -vmax is used. Else np.min(data).
- If callable, the output equals vmin(data).
- vmax : float | callable | None
- The value specifying the upper bound of the color range.
- If None, the maximum absolute value is used. If callable, the
- output equals vmax(data). Defaults to None.
- proj : bool
- Apply projection.
- bandwidth : float
- The bandwidth of the multi taper windowing function in Hz.
- The default value is a window half-bandwidth of 4 Hz.
- adaptive : bool
- Use adaptive weights to combine the tapered spectra into PSD
- (slow, use n_jobs >> 1 to speed up computation).
- low_bias : bool
- Only use tapers with more than 90% spectral concentration within
- bandwidth.
- normalization : str
- Either "full" or "length" (default). If "full", the PSD will
- be normalized by the sampling rate as well as the length of
- the signal (as in nitime).
- ch_type : {None, 'mag', 'grad', 'planar1', 'planar2', 'eeg'}
- The channel type to plot. For 'grad', the gradiometers are
- collected in
- pairs and the RMS for each pair is plotted. If None, defaults to
- 'mag' if MEG data are present and to 'eeg' if only EEG data are
- present.
- layout : None | Layout
- Layout instance specifying sensor positions (does not need to
- be specified for Neuromag data). If possible, the correct layout
- file is inferred from the data; if no appropriate layout file was
- found, the layout is automatically generated from the sensor
- locations.
- cmap : matplotlib colormap
- Colormap. For magnetometers and eeg defaults to 'RdBu_r', else
- 'Reds'.
- agg_fun : callable
- The function used to aggregate over frequencies.
- Defaults to np.sum. if normalize is True, else np.mean.
- dB : bool
- If True, transform data to decibels (with ``10 * np.log10(data)``)
- following the application of `agg_fun`. Only valid if normalize
- is False.
- n_jobs : int
- Number of jobs to run in parallel.
- normalize : bool
- If True, each band will be divided by the total power. Defaults to
- False.
- cbar_fmt : str
- The colorbar format. Defaults to '%0.3f'.
- outlines : 'head' | 'skirt' | dict | None
- The outlines to be drawn. If 'head', the default head scheme will
- be drawn. If 'skirt' the head scheme will be drawn, but sensors are
- allowed to be plotted outside of the head circle. If dict, each key
- refers to a tuple of x and y positions, the values in 'mask_pos'
- will serve as image mask, and the 'autoshrink' (bool) field will
- trigger automated shrinking of the positions due to points outside
- the outline. Alternatively, a matplotlib patch object can be passed
- for advanced masking options, either directly or as a function that
- returns patches (required for multi-axis plots). If None, nothing
- will be drawn. Defaults to 'head'.
- show : bool
- Show figure if True.
- verbose : bool, str, int, or None
- If not None, override default verbose level (see mne.verbose).
-
- Returns
- -------
- fig : instance of matplotlib figure
- Figure distributing one image per channel across sensor topography.
- """
+ outlines='head', axes=None, show=True, verbose=None):
return plot_epochs_psd_topomap(
- self, bands=bands, vmin=vmin, vmax=vmax, proj=proj,
- bandwidth=bandwidth, adaptive=adaptive,
- low_bias=low_bias, normalization=normalization,
- ch_type=ch_type, layout=layout, cmap=cmap,
- agg_fun=agg_fun, dB=dB, n_jobs=n_jobs, normalize=normalize,
- cbar_fmt=cbar_fmt, outlines=outlines, show=show, verbose=None)
-
+ self, bands=bands, vmin=vmin, vmax=vmax, tmin=tmin, tmax=tmax,
+ proj=proj, bandwidth=bandwidth, adaptive=adaptive,
+ low_bias=low_bias, normalization=normalization, ch_type=ch_type,
+ layout=layout, cmap=cmap, agg_fun=agg_fun, dB=dB, n_jobs=n_jobs,
+ normalize=normalize, cbar_fmt=cbar_fmt, outlines=outlines,
+ axes=axes, show=show, verbose=verbose)
+
+ @copy_function_doc_to_method_doc(plot_topo_image_epochs)
def plot_topo_image(self, layout=None, sigma=0., vmin=None, vmax=None,
colorbar=True, order=None, cmap='RdBu_r',
layout_scale=.95, title=None, scalings=None,
- border='none', fig_facecolor='k', font_color='w',
- show=True):
- """Plot Event Related Potential / Fields image on topographies
-
- Parameters
- ----------
- layout: instance of Layout
- System specific sensor positions.
- sigma : float
- The standard deviation of the Gaussian smoothing to apply along the
- epoch axis to apply in the image. If 0., no smoothing is applied.
- vmin : float
- The min value in the image. The unit is uV for EEG channels,
- fT for magnetometers and fT/cm for gradiometers.
- vmax : float
- The max value in the image. The unit is uV for EEG channels,
- fT for magnetometers and fT/cm for gradiometers.
- colorbar : bool
- Display or not a colorbar.
- order : None | array of int | callable
- If not None, order is used to reorder the epochs on the y-axis
- of the image. If it's an array of int it should be of length
- the number of good epochs. If it's a callable the arguments
- passed are the times vector and the data as 2d array
- (data.shape[1] == len(times)).
- cmap : instance of matplotlib.pyplot.colormap
- Colors to be mapped to the values.
- layout_scale: float
- scaling factor for adjusting the relative size of the layout
- on the canvas.
- title : str
- Title of the figure.
- scalings : dict | None
- The scalings of the channel types to be applied for plotting. If
- None, defaults to `dict(eeg=1e6, grad=1e13, mag=1e15)`.
- border : str
- matplotlib borders style to be used for each sensor plot.
- fig_facecolor : str | obj
- The figure face color. Defaults to black.
- font_color : str | obj
- The color of tick labels in the colorbar. Defaults to white.
- show : bool
- Show figure if True.
-
- Returns
- -------
- fig : instance of matplotlib figure
- Figure distributing one image per channel across sensor topography.
- """
+ border='none', fig_facecolor='k', fig_background=None,
+ font_color='w', show=True):
return plot_topo_image_epochs(
self, layout=layout, sigma=sigma, vmin=vmin, vmax=vmax,
colorbar=colorbar, order=order, cmap=cmap,
layout_scale=layout_scale, title=title, scalings=scalings,
- border=border, fig_facecolor=fig_facecolor, font_color=font_color,
- show=show)
-
- @deprecated('drop_bad_epochs method has been renamed drop_bad. '
- 'drop_bad_epochs method will be removed in 0.13')
- def drop_bad_epochs(self, reject='existing', flat='existing'):
- """Drop bad epochs without retaining the epochs data"""
- return self.drop_bad(reject, flat)
+ border=border, fig_facecolor=fig_facecolor,
+ fig_background=fig_background, font_color=font_color, show=show)
@verbose
def drop_bad(self, reject='existing', flat='existing', verbose=None):
@@ -1158,114 +884,29 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
"""
return _drop_log_stats(self.drop_log, ignore)
+ @copy_function_doc_to_method_doc(plot_drop_log)
def plot_drop_log(self, threshold=0, n_max_plot=20, subject='Unknown',
color=(0.9, 0.9, 0.9), width=0.8, ignore=('IGNORED',),
show=True):
- """Show the channel stats based on a drop_log from Epochs
-
- Parameters
- ----------
- threshold : float
- The percentage threshold to use to decide whether or not to
- plot. Default is zero (always plot).
- n_max_plot : int
- Maximum number of channels to show stats for.
- subject : str
- The subject name to use in the title of the plot.
- color : tuple | str
- Color to use for the bars.
- width : float
- Width of the bars.
- ignore : list
- The drop reasons to ignore.
- show : bool
- Show figure if True.
-
- Returns
- -------
- perc : float
- Total percentage of epochs dropped.
- fig : Instance of matplotlib.figure.Figure
- The figure.
- """
if not self._bad_dropped:
raise ValueError("You cannot use plot_drop_log since bad "
"epochs have not yet been dropped. "
"Use epochs.drop_bad().")
-
- from .viz import plot_drop_log
return plot_drop_log(self.drop_log, threshold, n_max_plot, subject,
color=color, width=width, ignore=ignore,
show=show)
+ @copy_function_doc_to_method_doc(plot_epochs_image)
def plot_image(self, picks=None, sigma=0., vmin=None,
vmax=None, colorbar=True, order=None, show=True,
units=None, scalings=None, cmap='RdBu_r',
- fig=None, overlay_times=None):
- """Plot Event Related Potential / Fields image
-
- Parameters
- ----------
- picks : int | array-like of int | None
- The indices of the channels to consider. If None, the first
- five good channels are plotted.
- sigma : float
- The standard deviation of the Gaussian smoothing to apply along
- the epoch axis to apply in the image. If 0., no smoothing is
- applied.
- vmin : float
- The min value in the image. The unit is uV for EEG channels,
- fT for magnetometers and fT/cm for gradiometers.
- vmax : float
- The max value in the image. The unit is uV for EEG channels,
- fT for magnetometers and fT/cm for gradiometers.
- colorbar : bool
- Display or not a colorbar.
- order : None | array of int | callable
- If not None, order is used to reorder the epochs on the y-axis
- of the image. If it's an array of int it should be of length
- the number of good epochs. If it's a callable the arguments
- passed are the times vector and the data as 2d array
- (data.shape[1] == len(times).
- show : bool
- Show figure if True.
- units : dict | None
- The units of the channel types used for axes lables. If None,
- defaults to `units=dict(eeg='uV', grad='fT/cm', mag='fT')`.
- scalings : dict | None
- The scalings of the channel types to be applied for plotting.
- If None, defaults to `scalings=dict(eeg=1e6, grad=1e13, mag=1e15,
- eog=1e6)`.
- cmap : matplotlib colormap
- Colormap.
- fig : matplotlib figure | None
- Figure instance to draw the image to. Figure must contain two
- axes for drawing the single trials and evoked responses. If
- None a new figure is created. Defaults to None.
- overlay_times : array-like, shape (n_epochs,) | None
- If not None the parameter is interpreted as time instants in
- seconds and is added to the image. It is typically useful to
- display reaction times. Note that it is defined with respect
- to the order of epochs such that overlay_times[0] corresponds
- to epochs[0].
-
- Returns
- -------
- figs : list of matplotlib figures
- One figure per channel displayed.
- """
+ fig=None, axes=None, overlay_times=None):
return plot_epochs_image(self, picks=picks, sigma=sigma, vmin=vmin,
vmax=vmax, colorbar=colorbar, order=order,
show=show, units=units, scalings=scalings,
- cmap=cmap, fig=fig,
+ cmap=cmap, fig=fig, axes=axes,
overlay_times=overlay_times)
- @deprecated('drop_epochs method has been renamed drop. '
- 'drop_epochs method will be removed in 0.13')
- def drop_epochs(self, indices, reason='USER', verbose=None):
- """Drop epochs based on indices or boolean mask"""
- return self.drop(indices, reason, verbose)
-
@verbose
def drop(self, indices, reason='USER', verbose=None):
"""Drop epochs based on indices or boolean mask
@@ -1443,7 +1084,27 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
return self._get_data()
def __len__(self):
- """Number of epochs.
+ """The number of epochs
+
+ Returns
+ -------
+ n_epochs : int
+ The number of remaining epochs.
+
+ Notes
+ -----
+ This function only works if bad epochs have been dropped.
+
+ Examples
+ --------
+ This can be used as::
+
+ >>> epochs.drop_bad() # doctest: +SKIP
+ >>> len(epochs) # doctest: +SKIP
+ 43
+ >>> len(epochs.events) # doctest: +SKIP
+ 43
+
"""
if not self._bad_dropped:
raise RuntimeError('Since bad epochs have not been dropped, the '
@@ -1455,7 +1116,17 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
return len(self.events)
def __iter__(self):
- """To make iteration over epochs easy.
+ """Function to make iteration over epochs easy
+
+ Notes
+ -----
+ This enables the use of this Python pattern::
+
+ >>> for epoch in epochs: # doctest: +SKIP
+        ...     print(epoch)  # doctest: +SKIP
+
+        where ``epoch`` is given by successive outputs of
+        :func:`mne.Epochs.next`.
"""
self._current = 0
while True:
@@ -1514,30 +1185,68 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
return self.times[-1]
def __repr__(self):
- """ Build string representation
- """
+ """ Build string representation"""
s = 'n_events : %s ' % len(self.events)
s += '(all good)' if self._bad_dropped else '(good & bad)'
s += ', tmin : %s (s)' % self.tmin
s += ', tmax : %s (s)' % self.tmax
s += ', baseline : %s' % str(self.baseline)
+ s += ', ~%s' % (sizeof_fmt(self._size),)
+ s += ', data%s loaded' % ('' if self.preload else ' not')
if len(self.event_id) > 1:
counts = ['%r: %i' % (k, sum(self.events[:, 2] == v))
for k, v in sorted(self.event_id.items())]
s += ',\n %s' % ', '.join(counts)
class_name = self.__class__.__name__
- if class_name == '_BaseEpochs':
- class_name = 'Epochs'
+ class_name = 'Epochs' if class_name == '_BaseEpochs' else class_name
return '<%s | %s>' % (class_name, s)
- def _key_match(self, key):
- """Helper function for event dict use"""
- if key not in self.event_id:
- raise KeyError('Event "%s" is not in Epochs.' % key)
- return self.events[:, 2] == self.event_id[key]
+ def _keys_to_idx(self, keys):
+ """Find entries in event dict."""
+ return np.array([self.events[:, 2] == self.event_id[k]
+ for k in _hid_match(self.event_id, keys)]).any(axis=0)
+
+ def __getitem__(self, item):
+ """Return an Epochs object with a copied subset of epochs
+
+ Parameters
+ ----------
+ item : slice, array-like, str, or list
+ See below for use cases.
+
+ Returns
+ -------
+ epochs : instance of Epochs
+ See below for use cases.
+
+ Notes
+ -----
+ Epochs can be accessed as ``epochs[...]`` in several ways:
+
+ 1. ``epochs[idx]``: Return ``Epochs`` object with a subset of
+ epochs (supports single index and python-style slicing).
+
+ 2. ``epochs['name']``: Return ``Epochs`` object with a copy of the
+ subset of epochs corresponding to an experimental condition as
+ specified by 'name'.
+
+ If conditions are tagged by names separated by '/' (e.g.
+ 'audio/left', 'audio/right'), and 'name' is not in itself an
+ event key, this selects every event whose condition contains
+ the 'name' tag (e.g., 'left' matches 'audio/left' and
+ 'visual/left'; but not 'audio_left'). Note that tags like
+ 'auditory/left' and 'left/auditory' will be treated the
+ same way when accessed using tags.
+
+ 3. ``epochs[['name_1', 'name_2', ... ]]``: Return ``Epochs`` object
+ with a copy of the subset of epochs corresponding to multiple
+ experimental conditions as specified by
+ ``'name_1', 'name_2', ...`` .
+
+ If conditions are separated by '/', selects every item
+ containing every list tag (e.g. ['audio', 'left'] selects
+ 'audio/left' and 'audio/center/left', but not 'audio/right').
- def __getitem__(self, key):
- """Return an Epochs object with a subset of epochs
"""
data = self._data
del self._data
@@ -1545,26 +1254,15 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
self._data, epochs._data = data, data
del self
- if isinstance(key, string_types):
- key = [key]
-
- if isinstance(key, (list, tuple)) and isinstance(key[0], string_types):
- if any('/' in k_i for k_i in epochs.event_id.keys()):
- if any(k_e not in epochs.event_id for k_e in key):
- # Select a given key if the requested set of
- # '/'-separated types are a subset of the types in that key
- key = [k for k in epochs.event_id.keys()
- if all(set(k_i.split('/')).issubset(k.split('/'))
- for k_i in key)]
- if len(key) == 0:
- raise KeyError('Attempting selection of events via '
- 'multiple/partial matching, but no '
- 'event matches all criteria.')
- select = np.any(np.atleast_2d([epochs._key_match(k)
- for k in key]), axis=0)
- epochs.name = '+'.join(key)
+ if isinstance(item, string_types):
+ item = [item]
+
+ if isinstance(item, (list, tuple)) and \
+ isinstance(item[0], string_types):
+ select = epochs._keys_to_idx(item)
+ epochs.name = '+'.join(item)
else:
- select = key if isinstance(key, slice) else np.atleast_1d(key)
+ select = item if isinstance(item, slice) else np.atleast_1d(item)
key_selection = epochs.selection[select]
for k in np.setdiff1d(epochs.selection, key_selection):
@@ -1580,7 +1278,7 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
if v in epochs.events[:, 2])
return epochs
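The selection semantics documented in ``__getitem__`` above can be
summarized with a minimal sketch (assumes an ``epochs`` instance whose
``event_id`` uses '/'-separated names such as 'audio/left' and
'visual/left')::

    epochs[0]                  # single epoch by index
    epochs[:10]                # python-style slicing
    epochs['audio/left']       # one named condition
    epochs['left']             # tag match: 'audio/left' and 'visual/left'
    epochs[['audio', 'left']]  # conditions containing both tags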
- def crop(self, tmin=None, tmax=None, copy=None):
+ def crop(self, tmin=None, tmax=None):
"""Crops a time interval from epochs object.
Parameters
@@ -1589,10 +1287,6 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
Start time of selection in seconds.
tmax : float | None
End time of selection in seconds.
- copy : bool
- This parameter has been deprecated and will be removed in 0.13.
- Use inst.copy() instead.
- Whether to return a new instance or modify in place.
Returns
-------
@@ -1625,15 +1319,14 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
tmax = self.tmax
tmask = _time_mask(self.times, tmin, tmax, sfreq=self.info['sfreq'])
- this_epochs = _check_copy_dep(self, copy)
- this_epochs.times = this_epochs.times[tmask]
- this_epochs._raw_times = this_epochs._raw_times[tmask]
- this_epochs._data = this_epochs._data[:, :, tmask]
- return this_epochs
+ self.times = self.times[tmask]
+ self._raw_times = self._raw_times[tmask]
+ self._data = self._data[:, :, tmask]
+ return self
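Because ``crop`` now operates in place and returns the modified instance,
the previous ``copy=True`` behavior is obtained by chaining from
``copy()`` (a minimal sketch, assuming a preloaded ``epochs`` instance)::

    cropped = epochs.copy().crop(tmin=0., tmax=0.3)  # original preserved
    epochs.crop(tmin=0., tmax=0.3)                   # modified in place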
@verbose
- def resample(self, sfreq, npad=None, window='boxcar', n_jobs=1,
- copy=None, verbose=None):
+ def resample(self, sfreq, npad='auto', window='boxcar', n_jobs=1,
+ verbose=None):
"""Resample preloaded data
Parameters
@@ -1648,10 +1341,6 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
Window to use in resampling. See scipy.signal.resample.
n_jobs : int
Number of jobs to run in parallel.
- copy : bool
- This parameter has been deprecated and will be removed in 0.13.
- Use inst.copy() instead.
- Whether to return a new instance or modify in place.
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
Defaults to self.verbose.
@@ -1674,20 +1363,15 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
# XXX this could operate on non-preloaded data, too
if not self.preload:
raise RuntimeError('Can only resample preloaded data')
- if npad is None:
- npad = 100
- warn('npad is currently taken to be 100, but will be changed to '
- '"auto" in 0.13. Please set the value explicitly.',
- DeprecationWarning)
- inst = _check_copy_dep(self, copy)
- o_sfreq = inst.info['sfreq']
- inst._data = resample(inst._data, sfreq, o_sfreq, npad, window=window,
+ o_sfreq = self.info['sfreq']
+ self._data = resample(self._data, sfreq, o_sfreq, npad, window=window,
n_jobs=n_jobs)
# adjust indirectly affected variables
- inst.info['sfreq'] = float(sfreq)
- inst.times = (np.arange(inst._data.shape[2], dtype=np.float) /
- sfreq + inst.times[0])
- return inst
+ self.info['sfreq'] = float(sfreq)
+ self.times = (np.arange(self._data.shape[2], dtype=np.float) /
+ sfreq + self.times[0])
+ self._raw_times = self.times
+ return self
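The same in-place pattern applies to ``resample``, with ``npad`` now
defaulting to ``'auto'`` (a minimal sketch, assuming preloaded ``epochs``
sampled above 100 Hz)::

    epochs_ds = epochs.copy().resample(100., npad='auto')
    print(epochs_ds.info['sfreq'])  # 100.0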
def copy(self):
"""Return copy of Epochs instance"""
@@ -1766,7 +1450,7 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
list. If 'mintime', timing differences between each event list
will be minimized.
copy : bool
- This parameter has been deprecated and will be removed in 0.13.
+ This parameter has been deprecated and will be removed in 0.14.
Use inst.copy() instead.
Whether to return a new instance or modify in place.
@@ -1792,7 +1476,7 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
conditions will contribute evenly. E.g., it is possible to end up
with 70 'Nonspatial' trials, 69 'Left' and 1 'Right'.
"""
- epochs = _check_copy_dep(self, copy, default=True)
+ epochs = _check_copy_dep(self, copy)
if len(event_ids) == 0:
raise ValueError('event_ids must have at least one element')
if not epochs._bad_dropped:
@@ -1839,12 +1523,7 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
"orthogonal selection.")
for eq in event_ids:
- eq = np.atleast_1d(eq)
- # eq is now a list of types
- key_match = np.zeros(epochs.events.shape[0])
- for key in eq:
- key_match = np.logical_or(key_match, epochs._key_match(key))
- eq_inds.append(np.where(key_match)[0])
+ eq_inds.append(np.where(epochs._keys_to_idx(eq))[0])
event_times = [epochs.events[e, 0] for e in eq_inds]
indices = _get_drop_indices(event_times, method)
@@ -1855,6 +1534,65 @@ class _BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
return epochs, indices
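Since equalization now goes through ``_keys_to_idx``, it accepts the same
hierarchical keys as indexing (a minimal sketch, assuming conditions
'audio/left' and 'audio/right' exist in ``epochs.event_id``)::

    epochs_eq, dropped = epochs.equalize_event_counts(
        ['audio/left', 'audio/right'], method='mintime')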
+def _hid_match(event_id, keys):
+ """Match event IDs using HID selection.
+
+ Parameters
+ ----------
+ event_id : dict
+ The event ID dictionary.
+ keys : list | str
+ The event ID or subset (for HID), or list of such items.
+
+ Returns
+ -------
+ use_keys : list
+ The full keys that fit the selection criteria.
+ """
+ # form the hierarchical event ID mapping
+ keys = [keys] if not isinstance(keys, (list, tuple)) else keys
+ use_keys = []
+ for key in keys:
+ if not isinstance(key, string_types):
+ raise KeyError('keys must be strings, got %s (%s)'
+ % (type(key), key))
+ use_keys.extend(k for k in event_id.keys()
+ if set(key.split('/')).issubset(k.split('/')))
+ if len(use_keys) == 0:
+ raise KeyError('Event "%s" is not in Epochs.' % key)
+ use_keys = list(set(use_keys)) # deduplicate if necessary
+ return use_keys
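For instance, ``_hid_match`` expands a partial tag into all matching full
keys (a minimal sketch; the order of the result is arbitrary because of
the ``set``-based deduplication)::

    event_id = {'audio/left': 1, 'audio/right': 2, 'visual/left': 3}
    _hid_match(event_id, 'left')    # ['audio/left', 'visual/left']
    _hid_match(event_id, 'audio')   # ['audio/left', 'audio/right']
    _hid_match(event_id, 'center')  # raises KeyError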
+
+
+def _check_baseline(baseline, tmin, tmax, sfreq):
+ """Helper to check for a valid baseline"""
+ if baseline is not None:
+ if not isinstance(baseline, tuple) or len(baseline) != 2:
+ raise ValueError('`baseline=%s` is an invalid argument.'
+ % str(baseline))
+ baseline_tmin, baseline_tmax = baseline
+ tstep = 1. / float(sfreq)
+ if baseline_tmin is None:
+ baseline_tmin = tmin
+ baseline_tmin = float(baseline_tmin)
+ if baseline_tmax is None:
+ baseline_tmax = tmax
+ baseline_tmax = float(baseline_tmax)
+ if baseline_tmin < tmin - tstep:
+ raise ValueError(
+ "Baseline interval (tmin = %s) is outside of epoch "
+ "data (tmin = %s)" % (baseline_tmin, tmin))
+ if baseline_tmax > tmax + tstep:
+ raise ValueError(
+ "Baseline interval (tmax = %s) is outside of epoch "
+ "data (tmax = %s)" % (baseline_tmax, tmax))
+ if baseline_tmin > baseline_tmax:
+ raise ValueError(
+ "Baseline min (%s) must be less than baseline max (%s)"
+ % (baseline_tmin, baseline_tmax))
+ del baseline_tmin, baseline_tmax
+
+
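The checks above tolerate a baseline that extends at most one sample step
beyond the epoch limits; the boundary behavior can be sketched as::

    _check_baseline((None, 0), tmin=-0.2, tmax=0.5, sfreq=1000.)  # ok
    _check_baseline((-0.5, 0), tmin=-0.2, tmax=0.5, sfreq=1000.)  # ValueError
    _check_baseline((0.1, 0.), tmin=-0.2, tmax=0.5, sfreq=1000.)  # min > max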
def _drop_log_stats(drop_log, ignore=('IGNORED',)):
"""
Parameters
@@ -1876,6 +1614,24 @@ def _drop_log_stats(drop_log, ignore=('IGNORED',)):
return perc
+def _dep_eeg_ref(add_eeg_ref, current_default=True):
+ """Helper for deprecation add_eeg_ref -> False"""
+ if current_default is True:
+ if add_eeg_ref is None:
+ add_eeg_ref = True
+ warn('add_eeg_ref defaults to True in 0.13, will default to '
+              'False in 0.14, and will be removed in 0.15. We recommend '
+              'using add_eeg_ref=False and set_eeg_reference() instead.',
+ DeprecationWarning)
+ # current_default is False
+ elif add_eeg_ref is None:
+ add_eeg_ref = False
+ else:
+ warn('add_eeg_ref will be removed in 0.14, use set_eeg_reference()'
+ ' instead', DeprecationWarning)
+ return add_eeg_ref
+
+
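The migration this helper encourages looks like the following (a minimal
sketch, assuming ``raw``, ``events`` and ``event_id`` already exist)::

    raw, _ = mne.set_eeg_reference(raw)  # add an average EEG reference
    epochs = mne.Epochs(raw, events, event_id, add_eeg_ref=False)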
class Epochs(_BaseEpochs):
"""Epochs extracted from a Raw instance
@@ -1900,15 +1656,14 @@ class Epochs(_BaseEpochs):
tmax : float
End time after event. If nothing is provided, defaults to 0.5
baseline : None or tuple of length 2 (default (None, 0))
- The time interval to apply baseline correction.
- If None do not apply it. If baseline is (a, b)
- the interval is between "a (s)" and "b (s)".
- If a is None the beginning of the data is used
- and if b is None then b is set to the end of the interval.
- If baseline is equal to (None, None) all the time
- interval is used.
- The baseline (a, b) includes both endpoints, i.e. all
- timepoints t such that a <= t <= b.
+        The time interval to apply baseline correction. If None, do not apply
+        it. If baseline is (a, b), the interval is between "a (s)" and
+        "b (s)". If a is None, the beginning of the data is used; if b is
+        None, b is set to the end of the interval. If baseline is
+        (None, None), the entire time interval is used. Correction is applied
+        by computing the mean of the baseline period and subtracting it from
+        the data. The baseline (a, b) includes both endpoints, i.e. all
+        timepoints t such that a <= t <= b.
picks : array-like of int | None (default)
Indices of channels to include (if None, all channels are used).
name : string
@@ -1964,7 +1719,9 @@ class Epochs(_BaseEpochs):
(will yield equivalent results but be slower).
add_eeg_ref : bool
If True, an EEG average reference will be added (unless one
- already exists).
+ already exists). The default value of True in 0.13 will change to
+ False in 0.14, and the parameter will be removed in 0.15. Use
+ :func:`mne.set_eeg_reference` instead.
on_missing : str
What to do if one or several event ids are not found in the recording.
Valid keys are 'error' | 'warning' | 'ignore'
@@ -2008,48 +1765,24 @@ class Epochs(_BaseEpochs):
verbose : bool, str, int, or None
See above.
+ See Also
+ --------
+ mne.epochs.combine_event_ids
+ mne.Epochs.equalize_event_counts
+
Notes
-----
When accessing data, Epochs are detrended, baseline-corrected, and
decimated, then projectors are (optionally) applied.
- For indexing and slicing:
-
- epochs[idx] : Epochs
- Return Epochs object with a subset of epochs (supports single
- index and python-style slicing)
-
- For subset selection using categorial labels:
-
- epochs['name'] : Epochs
- Return Epochs object with a subset of epochs corresponding to an
- experimental condition as specified by 'name'.
-
- If conditions are tagged by names separated by '/' (e.g. 'audio/left',
- 'audio/right'), and 'name' is not in itself an event key, this selects
- every event whose condition contains the 'name' tag (e.g., 'left'
- matches 'audio/left' and 'visual/left'; but not 'audio_left'). Note
- that tags like 'auditory/left' and 'left/auditory' will be treated the
- same way when accessed using tags.
-
- epochs[['name_1', 'name_2', ... ]] : Epochs
- Return Epochs object with a subset of epochs corresponding to multiple
- experimental conditions as specified by 'name_1', 'name_2', ... .
-
- If conditions are separated by '/', selects every item containing every
- list tag (e.g. ['audio', 'left'] selects 'audio/left' and
- 'audio/center/left', but not 'audio/right').
-
- See Also
- --------
- mne.epochs.combine_event_ids
- mne.Epochs.equalize_event_counts
+ For indexing and slicing using ``epochs[...]``, see
+ :func:`mne.Epochs.__getitem__`.
"""
@verbose
def __init__(self, raw, events, event_id=None, tmin=-0.2, tmax=0.5,
baseline=(None, 0), picks=None, name='Unknown', preload=False,
reject=None, flat=None, proj=True, decim=1, reject_tmin=None,
- reject_tmax=None, detrend=None, add_eeg_ref=True,
+ reject_tmax=None, detrend=None, add_eeg_ref=None,
on_missing='error', reject_by_annotation=True, verbose=None):
if not isinstance(raw, _BaseRaw):
raise ValueError('The first argument to `Epochs` must be an '
@@ -2087,7 +1820,8 @@ class Epochs(_BaseEpochs):
event_samp = self.events[idx, 0]
# Read a data segment
first_samp = self._raw.first_samp
- start = int(round(event_samp + self.tmin * sfreq)) - first_samp
+ start = int(round(event_samp + self._raw_times[0] * sfreq))
+ start -= first_samp
stop = start + len(self._raw_times)
data = self._raw._check_bad_segment(start, stop, self.picks,
self.reject_by_annotation)
@@ -2142,14 +1876,15 @@ class EpochsArray(_BaseEpochs):
reject_tmax : scalar | None
End of the time window used to reject epochs (with the default None,
the window will end with tmax).
- baseline : None or tuple of length 2 (default: None)
- The time interval to apply baseline correction.
- If None do not apply it. If baseline is (a, b)
- the interval is between "a (s)" and "b (s)".
- If a is None the beginning of the data is used
- and if b is None then b is set to the end of the interval.
- If baseline is equal to (None, None) all the time
- interval is used.
+ baseline : None or tuple of length 2 (default None)
+        The time interval to apply baseline correction. If None, do not apply
+        it. If baseline is (a, b), the interval is between "a (s)" and
+        "b (s)". If a is None, the beginning of the data is used; if b is
+        None, b is set to the end of the interval. If baseline is
+        (None, None), the entire time interval is used. Correction is applied
+        by computing the mean of the baseline period and subtracting it from
+        the data. The baseline (a, b) includes both endpoints, i.e. all
+        timepoints t such that a <= t <= b.
proj : bool | 'delayed'
Apply SSP projection vectors. See :class:`mne.Epochs` for details.
verbose : bool, str, int, or None
@@ -2189,9 +1924,9 @@ class EpochsArray(_BaseEpochs):
tmax, baseline, reject=reject,
flat=flat, reject_tmin=reject_tmin,
reject_tmax=reject_tmax, decim=1,
- add_eeg_ref=False, proj=proj)
- if len(events) != in1d(self.events[:, 2],
- list(self.event_id.values())).sum():
+ proj=proj)
+ if len(events) != np.in1d(self.events[:, 2],
+ list(self.event_id.values())).sum():
raise ValueError('The events must only contain event numbers from '
'event_id')
for ii, e in enumerate(self._data):
@@ -2509,14 +2244,14 @@ def _read_one_epoch_file(f, tree, fname, preload):
if selection is None:
selection = np.arange(len(events))
if drop_log is None:
- drop_log = [[] for _ in range(len(epochs))] # noqa, analysis:ignore
+ drop_log = [[] for _ in range(len(events))]
return (info, data, data_tag, events, event_id, tmin, tmax, baseline, name,
selection, drop_log, epoch_shape, cals)
@verbose
-def read_epochs(fname, proj=True, add_eeg_ref=False, preload=True,
+def read_epochs(fname, proj=True, add_eeg_ref=None, preload=True,
verbose=None):
"""Read epochs from a fif file
@@ -2537,7 +2272,8 @@ def read_epochs(fname, proj=True, add_eeg_ref=False, preload=True,
recommended value if SSPs are not used for cleaning the data.
add_eeg_ref : bool
If True, an EEG average reference will be added (unless one
- already exists).
+ already exists). This parameter is deprecated and will be
+ removed in 0.14, use :func:`mne.set_eeg_reference` instead.
preload : bool
If True, read all epochs from disk immediately. If False, epochs will
be read on demand.
@@ -2550,6 +2286,7 @@ def read_epochs(fname, proj=True, add_eeg_ref=False, preload=True,
epochs : instance of Epochs
The epochs
"""
+ add_eeg_ref = _dep_eeg_ref(add_eeg_ref, False)
return EpochsFIF(fname, proj, add_eeg_ref, preload, verbose)
@@ -2586,7 +2323,9 @@ class EpochsFIF(_BaseEpochs):
recommended value if SSPs are not used for cleaning the data.
add_eeg_ref : bool
If True, an EEG average reference will be added (unless one
- already exists).
+ already exists). The default value of True in 0.13 will change to
+ False in 0.14, and the parameter will be removed in 0.15. Use
+ :func:`mne.set_eeg_reference` instead.
preload : bool
If True, read all epochs from disk immediately. If False, epochs will
be read on demand.
@@ -2601,10 +2340,9 @@ class EpochsFIF(_BaseEpochs):
mne.Epochs.equalize_event_counts
"""
@verbose
- def __init__(self, fname, proj=True, add_eeg_ref=True, preload=True,
+ def __init__(self, fname, proj=True, add_eeg_ref=None, preload=True,
verbose=None):
check_fname(fname, 'epochs', ('-epo.fif', '-epo.fif.gz'))
-
fnames = [fname]
ep_list = list()
raw = list()
@@ -2620,7 +2358,7 @@ class EpochsFIF(_BaseEpochs):
epoch = _BaseEpochs(
info, data, events, event_id, tmin, tmax, baseline,
on_missing='ignore', selection=selection, drop_log=drop_log,
- add_eeg_ref=False, proj=False, verbose=False)
+ proj=False, verbose=False)
ep_list.append(epoch)
if not preload:
# store everything we need to index back to the original data
@@ -2731,7 +2469,7 @@ def _check_merge_epochs(epochs_list):
@verbose
-def add_channels_epochs(epochs_list, name='Unknown', add_eeg_ref=True,
+def add_channels_epochs(epochs_list, name='Unknown', add_eeg_ref=None,
verbose=None):
"""Concatenate channels, info and data from two Epochs objects
@@ -2742,8 +2480,10 @@ def add_channels_epochs(epochs_list, name='Unknown', add_eeg_ref=True,
name : str
Comment that describes the Epochs data created.
add_eeg_ref : bool
- If True, an EEG average reference will be added (unless there is no
- EEG in the data).
+ If True, an EEG average reference will be added (unless there is
+ no EEG in the data). The default value of True in 0.13 will change to
+ False in 0.14, and the parameter will be removed in 0.15. Use
+ :func:`mne.set_eeg_reference` instead.
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
Defaults to True if any of the input epochs have verbose=True.
@@ -2753,6 +2493,7 @@ def add_channels_epochs(epochs_list, name='Unknown', add_eeg_ref=True,
epochs : instance of Epochs
Concatenated epochs.
"""
+ add_eeg_ref = _dep_eeg_ref(add_eeg_ref)
if not all(e.preload for e in epochs_list):
raise ValueError('All epochs must be preloaded.')
@@ -2811,10 +2552,29 @@ def _compare_epochs_infos(info1, info2, ind):
if any(not _proj_equal(p1, p2) for p1, p2 in
zip(info2['projs'], info1['projs'])):
raise ValueError('SSP projectors in epochs files must be the same')
+ if (info1['dev_head_t'] is None) != (info2['dev_head_t'] is None) or \
+ (info1['dev_head_t'] is not None and not
+ np.allclose(info1['dev_head_t']['trans'],
+ info2['dev_head_t']['trans'], rtol=1e-6)):
+ raise ValueError('epochs[%d][\'info\'][\'dev_head_t\'] must match. '
+ 'The epochs probably come from different runs, and '
+ 'are therefore associated with different head '
+ 'positions. Manually change info[\'dev_head_t\'] to '
+ 'avoid this message but beware that this means the '
+ 'MEG sensors will not be properly spatially aligned. '
+ 'See mne.preprocessing.maxwell_filter to realign the '
+ 'runs to a common head position.' % ind)
def _concatenate_epochs(epochs_list, with_data=True):
"""Auxiliary function for concatenating epochs."""
+ if not isinstance(epochs_list, (list, tuple)):
+ raise TypeError('epochs_list must be a list or tuple, got %s'
+ % (type(epochs_list),))
+ for ei, epochs in enumerate(epochs_list):
+ if not isinstance(epochs, _BaseEpochs):
+ raise TypeError('epochs_list[%d] must be an instance of Epochs, '
+ 'got %s' % (ei, type(epochs)))
out = epochs_list[0]
data = [out.get_data()] if with_data else None
events = [out.events]
@@ -2826,7 +2586,7 @@ def _concatenate_epochs(epochs_list, with_data=True):
selection = out.selection
for ii, epochs in enumerate(epochs_list[1:]):
_compare_epochs_infos(epochs.info, info, ii)
- if not np.array_equal(epochs.times, epochs_list[0].times):
+ if not np.allclose(epochs.times, epochs_list[0].times):
raise ValueError('Epochs must have same times')
if epochs.baseline != baseline:
@@ -2850,10 +2610,10 @@ def _finish_concat(info, data, events, event_id, tmin, tmax, baseline,
"""Helper to finish concatenation for epochs not read from disk"""
events[:, 0] = np.arange(len(events)) # arbitrary after concat
selection = np.where([len(d) == 0 for d in drop_log])[0]
- out = _BaseEpochs(info, data, events, event_id, tmin, tmax,
- baseline=baseline, add_eeg_ref=False,
- selection=selection, drop_log=drop_log,
- proj=False, on_missing='ignore', verbose=verbose)
+ out = _BaseEpochs(
+ info, data, events, event_id, tmin, tmax, baseline=baseline,
+ selection=selection, drop_log=drop_log, proj=False,
+ on_missing='ignore', verbose=verbose)
out.drop_bad()
return out
@@ -2882,7 +2642,7 @@ def concatenate_epochs(epochs_list):
def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None,
origin='auto', weight_all=True, int_order=8, ext_order=3,
destination=None, ignore_ref=False, return_mapping=False,
- pos=None, verbose=None):
+ mag_scale=100., verbose=None):
"""Average data using Maxwell filtering, transforming using head positions
Parameters
@@ -2899,7 +2659,7 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None,
event sample numbers in ``epochs.events``). Can be ``None``
if data have not been decimated or resampled.
picks : array-like of int | None
- If None only MEG, EEG, SEEG, and ECoG channels are kept
+        If None, only MEG, EEG, SEEG, ECoG, and fNIRS channels are kept;
        otherwise the channel indices in picks are kept.
origin : array-like, shape (3,) | str
Origin of internal and external multipolar moment space in head
@@ -2934,6 +2694,16 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None,
with reference channels is not currently supported.
return_mapping : bool
If True, return the mapping matrix.
+ mag_scale : float | str
+        The magnetometer scale-factor used to bring the magnetometers
+ to approximately the same order of magnitude as the gradiometers
+ (default 100.), as they have different units (T vs T/m).
+ Can be ``'auto'`` to use the reciprocal of the physical distance
+ between the gradiometer pickup loops (e.g., 0.0168 m yields
+ 59.5 for VectorView).
+
+ .. versionadded:: 0.13
+
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
@@ -2974,11 +2744,7 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None,
_check_usable, _col_norm_pinv,
_get_n_moments, _get_mf_picks,
_prep_mf_coils, _check_destination,
- _remove_meg_projs)
- if pos is not None:
- head_pos = pos
- warn('pos has been replaced by head_pos and will be removed in 0.13',
- DeprecationWarning)
+ _remove_meg_projs, _get_coil_scale)
if head_pos is None:
raise TypeError('head_pos must be provided and cannot be None')
from .chpi import head_pos_to_trans_rot_t
@@ -2999,8 +2765,10 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None,
% (len(epochs.events)))
if not np.array_equal(epochs.events[:, 0], np.unique(epochs.events[:, 0])):
raise RuntimeError('Epochs must have monotonically increasing events')
- meg_picks, _, _, good_picks, coil_scale, _ = \
+ meg_picks, mag_picks, grad_picks, good_picks, _ = \
_get_mf_picks(epochs.info, int_order, ext_order, ignore_ref)
+ coil_scale, mag_scale = _get_coil_scale(
+ meg_picks, mag_picks, grad_picks, mag_scale, epochs.info)
n_channels, n_times = len(epochs.ch_names), len(epochs.times)
other_picks = np.setdiff1d(np.arange(n_channels), meg_picks)
data = np.zeros((n_channels, n_times))
@@ -3085,3 +2853,29 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None,
_remove_meg_projs(evoked) # remove MEG projectors, they won't apply now
logger.info('Created Evoked dataset from %s epochs' % (count,))
return (evoked, mapping) if return_mapping else evoked
+
+
+@verbose
+def _segment_raw(raw, segment_length=1., verbose=None, **kwargs):
+ """Divide continuous raw data into equal-sized
+ consecutive epochs.
+
+ Parameters
+ ----------
+ raw : instance of Raw
+ Raw data to divide into segments.
+ segment_length : float
+ Length of each segment in seconds. Defaults to 1.
+    verbose : bool
+ Whether to report what is being done by printing text.
+ **kwargs
+ Any additional keyword arguments are passed to ``Epochs`` constructor.
+
+ Returns
+ -------
+ epochs : instance of ``Epochs``
+ Segmented data.
+ """
+ events = make_fixed_length_events(raw, 1, duration=segment_length)
+ return Epochs(raw, events, event_id=[1], tmin=0., tmax=segment_length,
+ verbose=verbose, baseline=None, add_eeg_ref=False, **kwargs)
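A minimal use of this private helper (a sketch, assuming a ``raw``
instance is loaded; extra keyword arguments are forwarded to the
``Epochs`` constructor)::

    segments = _segment_raw(raw, segment_length=2., preload=True)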
diff --git a/mne/event.py b/mne/event.py
index 6683e15..e2084d5 100644
--- a/mne/event.py
+++ b/mne/event.py
@@ -11,6 +11,7 @@
import numpy as np
from os.path import splitext
+
from .utils import check_fname, logger, verbose, _get_stim_channel, warn
from .io.constants import FIFF
from .io.tree import dir_tree_find
@@ -188,7 +189,8 @@ def _read_events_fif(fid, tree):
return event_list, mappings
-def read_events(filename, include=None, exclude=None, mask=0):
+def read_events(filename, include=None, exclude=None, mask=None,
+ mask_type=None):
"""Reads events from fif or text file
Parameters
@@ -209,7 +211,12 @@ def read_events(filename, include=None, exclude=None, mask=0):
the exclude parameter is ignored.
mask : int | None
The value of the digital mask to apply to the stim channel values.
- The default value is 0. ``None`` skips masking.
+ If None (default), no masking is performed.
+    mask_type : 'and' | 'not_and'
+ The type of operation between the mask and the trigger.
+ Choose 'and' for MNE-C masking behavior.
+
+ .. versionadded:: 0.13
Returns
-------
@@ -225,9 +232,8 @@ def read_events(filename, include=None, exclude=None, mask=0):
This function will discard the offset line (i.e., first line with zero
event number) if it is present in a text file.
- Working with downsampled data: Events that were computed before the data
- was decimated are no longer valid. Please recompute your events after
- decimation.
+ For more information on ``mask`` and ``mask_type``, see
+ :func:`mne.find_events`.
"""
check_fname(filename, 'events', ('.eve', '-eve.fif', '-eve.fif.gz',
'-eve.lst', '-eve.txt'))
@@ -265,7 +271,7 @@ def read_events(filename, include=None, exclude=None, mask=0):
event_list = pick_events(event_list, include, exclude)
unmasked_len = event_list.shape[0]
if mask is not None:
- event_list = _mask_trigs(event_list, mask)
+ event_list = _mask_trigs(event_list, mask, mask_type)
masked_len = event_list.shape[0]
if masked_len < unmasked_len:
warn('{0} of {1} events masked'.format(unmasked_len - masked_len,
@@ -382,7 +388,7 @@ def find_stim_steps(raw, pad_start=None, pad_stop=None, merge=0,
affected by the trigger. If None, the config variables
'MNE_STIM_CHANNEL', 'MNE_STIM_CHANNEL_1', 'MNE_STIM_CHANNEL_2',
etc. are read. If these are not found, it will default to
- 'STI 014'.
+ 'STI101' or 'STI 014', whichever is present.
Returns
-------
@@ -415,7 +421,7 @@ def find_stim_steps(raw, pad_start=None, pad_stop=None, merge=0,
def _find_events(data, first_samp, verbose=None, output='onset',
consecutive='increasing', min_samples=0, mask=0,
- uint_cast=False):
+ uint_cast=False, mask_type=None):
"""Helper function for find events"""
if min_samples > 0:
merge = int(min_samples // 1)
@@ -435,7 +441,7 @@ def _find_events(data, first_samp, verbose=None, output='onset',
data = np.abs(data) # make sure trig channel is positive
events = _find_stim_steps(data, first_samp, pad_stop=0, merge=merge)
- events = _mask_trigs(events, mask)
+ events = _mask_trigs(events, mask, mask_type)
# Determine event onsets and offsets
if consecutive == 'increasing':
@@ -487,7 +493,8 @@ def _find_events(data, first_samp, verbose=None, output='onset',
@verbose
def find_events(raw, stim_channel=None, output='onset',
consecutive='increasing', min_duration=0,
- shortest_event=2, mask=0, uint_cast=False, verbose=None):
+ shortest_event=2, mask=None, uint_cast=False,
+ mask_type=None, verbose=None):
"""Find events from raw file
Parameters
@@ -516,9 +523,9 @@ def find_events(raw, stim_channel=None, output='onset',
shortest_event : int
Minimum number of samples an event must last (default is 2). If the
duration is less than this an exception will be raised.
- mask : int
+ mask : int | None
The value of the digital mask to apply to the stim channel values.
- The default value is 0.
+ If None (default), no masking is performed.
uint_cast : bool
If True (default False), do a cast to ``uint16`` on the channel
data. This can be used to fix a bug with STI101 and STI014 in
@@ -528,6 +535,12 @@ def find_events(raw, stim_channel=None, output='onset',
.. versionadded:: 0.12
+    mask_type : 'and' | 'not_and'
+ The type of operation between the mask and the trigger.
+ Choose 'and' for MNE-C masking behavior.
+
+ .. versionadded:: 0.13
+
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
@@ -541,9 +554,24 @@ def find_events(raw, stim_channel=None, output='onset',
the second column contains the value of the stim channel after the
event offset.
+ See Also
+ --------
+ find_stim_steps : Find all the steps in the stim channel.
+ read_events : Read events from disk.
+ write_events : Write events to disk.
+
+ Notes
+ -----
+ .. warning:: If you are working with downsampled data, events computed
+ before decimation are no longer valid. Please recompute
+ your events after decimation, but note this reduces the
+ precision of event timing.
+
Examples
--------
- Consider data with a stim channel that looks like: [0, 32, 32, 33, 32, 0]
+ Consider data with a stim channel that looks like::
+
+ [0, 32, 32, 33, 32, 0]
By default, find_events returns all samples at which the value of the
stim channel increases::
@@ -593,18 +621,24 @@ def find_events(raw, stim_channel=None, output='onset',
... min_duration=0.002))
[[ 1 0 32]]
- For the digital mask, it will take the binary representation of the
- digital mask, e.g. 5 -> '00000101', and will block the values
- where mask is one, e.g.::
+ For the digital mask, if mask_type is set to 'and' it will take the
+ binary representation of the digital mask, e.g. 5 -> '00000101', and will
+ allow the values to pass where mask is one, e.g.::
+
+ 7 '0000111' <- trigger value
+ 37 '0100101' <- mask
+ ----------------
+ 5 '0000101'
+
+ For the digital mask, if mask_type is set to 'not_and' it will take the
+ binary representation of the digital mask, e.g. 5 -> '00000101', and will
+ block the values where mask is one, e.g.::
7 '0000111' <- trigger value
37 '0100101' <- mask
----------------
2 '0000010'
- See Also
- --------
- find_stim_steps : Find all the steps in the stim channel.
"""
min_samples = min_duration * raw.info['sfreq']
@@ -618,7 +652,7 @@ def find_events(raw, stim_channel=None, output='onset',
events = _find_events(data, raw.first_samp, verbose=verbose, output=output,
consecutive=consecutive, min_samples=min_samples,
- mask=mask, uint_cast=uint_cast)
+ mask=mask, uint_cast=uint_cast, mask_type=mask_type)
# add safety check for spurious events (for ex. from neuromag syst.) by
# checking the number of low sample events
@@ -633,17 +667,28 @@ def find_events(raw, stim_channel=None, output='onset',
return events
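The two masking modes reduce to plain bitwise operations, which can be
checked directly with the values from the docstring above::

    >>> import numpy as np
    >>> int(np.bitwise_and(7, 37))                  # mask_type='and'
    5
    >>> int(np.bitwise_and(7, np.bitwise_not(37)))  # mask_type='not_and'
    2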
-def _mask_trigs(events, mask):
+def _mask_trigs(events, mask, mask_type):
"""Helper function for masking digital trigger values"""
- if not isinstance(mask, int):
- raise TypeError('You provided a(n) %s. Mask must be an int.'
- % type(mask))
+ if mask is not None:
+ if not isinstance(mask, int):
+            raise TypeError('You provided a(n) %s. ' % type(mask) +
+                            'Mask must be an int or None.')
n_events = len(events)
if n_events == 0:
return events.copy()
- mask = np.bitwise_not(mask)
- events[:, 1:] = np.bitwise_and(events[:, 1:], mask)
+ if mask is not None:
+ if mask_type is None:
+ warn("The default setting for mask_type will change from "
+ "'not and' to 'and' in v0.14.", DeprecationWarning)
+ mask_type = 'not_and'
+ if mask_type == 'not_and':
+ mask = np.bitwise_not(mask)
+ elif mask_type != 'and':
+ if mask_type is not None:
+ raise ValueError("'mask_type' should be either 'and'"
+ " or 'not_and', instead of '%s'" % mask_type)
+ events[:, 1:] = np.bitwise_and(events[:, 1:], mask)
events = events[events[:, 1] != events[:, 2]]
return events
@@ -765,6 +810,15 @@ def make_fixed_length_events(raw, id, start=0, stop=None, duration=1.,
new_events : array
The new events.
"""
+ from .io.base import _BaseRaw
+ if not isinstance(raw, _BaseRaw):
+ raise ValueError('Input data must be an instance of Raw, got'
+ ' %s instead.' % (type(raw)))
+ if not isinstance(id, int):
+ raise ValueError('id must be an integer')
+ if not isinstance(duration, (int, float)):
+        raise ValueError('duration must be an integer or a float, '
+ 'got %s instead.' % (type(duration)))
start = raw.time_as_index(start)[0]
if stop is not None:
stop = raw.time_as_index(stop)[0]
@@ -775,8 +829,6 @@ def make_fixed_length_events(raw, id, start=0, stop=None, duration=1.,
stop = min([stop + raw.first_samp, raw.last_samp + 1])
else:
stop = min([stop, len(raw.times)])
- if not isinstance(id, int):
- raise ValueError('id must be an integer')
# Make sure we don't go out the end of the file:
stop -= int(np.ceil(raw.info['sfreq'] * duration))
# This should be inclusive due to how we generally use start and stop...
@@ -834,3 +886,457 @@ def concatenate_events(events, first_samps, last_samps):
events_out = np.concatenate((events_out, e2), axis=0)
return events_out
+
+
+class AcqParserFIF(object):
+ """ Parser for Elekta data acquisition settings.
+
+ This class parses parameters (e.g. events and averaging categories) that
+ are defined in the Elekta TRIUX/VectorView data acquisition software (DACQ)
+ and stored in ``info['acq_pars']``. It can be used to reaverage raw data
+ according to DACQ settings and modify original averaging settings if
+ necessary.
+
+ Parameters
+ ----------
+ info : Info
+ An instance of Info where the DACQ parameters will be taken from.
+
+ Attributes
+ ----------
+ categories : list
+ List of averaging categories marked active in DACQ.
+ events : list
+ List of events that are in use (referenced by some averaging category).
+ reject : dict
+ Rejection criteria from DACQ that can be used with mne.Epochs.
+ Note that mne does not support all DACQ rejection criteria
+ (e.g. spike, slope).
+ flat : dict
+ Flatness rejection criteria from DACQ that can be used with mne.Epochs.
+ acq_dict : dict
+ All DACQ parameters.
+
+ Notes
+ -----
+ Any averaging category (also non-active ones) can be accessed by indexing
+ as ``acqparserfif['category_name']``.
+ """
+
+ # DACQ variables always start with one of these
+ _acq_var_magic = ['ERF', 'DEF', 'ACQ', 'TCP']
+
+ # averager related DACQ variable names (without preceding 'ERF')
+ # old versions (DACQ < 3.4)
+ _dacq_vars_compat = ('megMax', 'megMin', 'megNoise', 'megSlope',
+ 'megSpike', 'eegMax', 'eegMin', 'eegNoise',
+ 'eegSlope', 'eegSpike', 'eogMax', 'ecgMax', 'ncateg',
+ 'nevent', 'stimSource', 'triggerMap', 'update',
+ 'artefIgnore', 'averUpdate')
+
+ _event_vars_compat = ('Comment', 'Delay')
+
+ _cat_vars = ('Comment', 'Display', 'Start', 'State', 'End', 'Event',
+ 'Nave', 'ReqEvent', 'ReqWhen', 'ReqWithin', 'SubAve')
+
+ # new versions only (DACQ >= 3.4)
+ _dacq_vars = _dacq_vars_compat + ('magMax', 'magMin', 'magNoise',
+ 'magSlope', 'magSpike', 'version')
+
+ _event_vars = _event_vars_compat + ('Name', 'Channel', 'NewBits',
+ 'OldBits', 'NewMask', 'OldMask')
+
+ def __init__(self, info):
+ acq_pars = info['acq_pars']
+ if not acq_pars:
+ raise ValueError('No acquisition parameters')
+ self.acq_dict = self._acqpars_dict(acq_pars)
+ if 'ERFversion' in self.acq_dict:
+ self.compat = False # DACQ ver >= 3.4
+ elif 'ERFncateg' in self.acq_dict: # probably DACQ < 3.4
+ self.compat = True
+ else:
+ raise ValueError('Cannot parse acquisition parameters')
+ dacq_vars = self._dacq_vars_compat if self.compat else self._dacq_vars
+ # set instance variables
+ for var in dacq_vars:
+ val = self.acq_dict['ERF' + var]
+ if var[:3] in ['mag', 'meg', 'eeg', 'eog', 'ecg']:
+ val = float(val)
+ elif var in ['ncateg', 'nevent']:
+ val = int(val)
+ setattr(self, var.lower(), val)
+ self.stimsource = (
+ 'Internal' if self.stimsource == '1' else 'External')
+ # collect all events and categories
+ self._events = self._events_from_acq_pars()
+ self._categories = self._categories_from_acq_pars()
+ # mark events that are used by a category
+ for cat in self._categories.values():
+ if cat['event']:
+ self._events[cat['event']]['in_use'] = True
+ if cat['reqevent']:
+ self._events[cat['reqevent']]['in_use'] = True
+ # make mne rejection dicts based on the averager parameters
+ self.reject = {'grad': self.megmax, 'eeg': self.eegmax,
+ 'eog': self.eogmax, 'ecg': self.ecgmax}
+ if not self.compat:
+ self.reject['mag'] = self.magmax
+ self.reject = {k: float(v) for k, v in self.reject.items()
+ if float(v) > 0}
+ self.flat = {'grad': self.megmin, 'eeg': self.eegmin}
+ if not self.compat:
+ self.flat['mag'] = self.magmin
+ self.flat = {k: float(v) for k, v in self.flat.items()
+ if float(v) > 0}
+
+ def __repr__(self):
+ s = '<AcqParserFIF | '
+ s += 'categories: %d ' % self.ncateg
+ cats_in_use = len(self._categories_in_use)
+ s += '(%d in use), ' % cats_in_use
+ s += 'events: %d ' % self.nevent
+ evs_in_use = len(self._events_in_use)
+ s += '(%d in use)' % evs_in_use
+ if self.categories:
+ s += '\nAveraging categories:'
+ for cat in self.categories:
+ s += '\n%d: "%s"' % (cat['index'], cat['comment'])
+ s += '>'
+ return s
+
+ def __getitem__(self, item):
+ """ Return an averaging category, or list of categories.
+
+ Parameters
+ ----------
+ item : str or list of str
+ Name of the category (comment field in DACQ).
+
+ Returns
+ -------
+        conds : dict or list of dict, each with the following keys:
+            comment : str
+ The comment field in DACQ.
+ state : bool
+ Whether the category was marked enabled in DACQ.
+ index : int
+ The index of the category in DACQ. Indices start from 1.
+ event : int
+ DACQ index of the reference event (trigger event, zero time for
+ the corresponding epochs). Note that the event indices start
+ from 1.
+ start : float
+ Start time of epoch relative to the reference event.
+ end : float
+ End time of epoch relative to the reference event.
+ reqevent : int
+ Index of the required (conditional) event.
+ reqwhen : int
+ Whether the required event is required before (1) or after (2)
+ the reference event.
+ reqwithin : float
+ The time range within which the required event must occur,
+ before or after the reference event.
+ display : bool
+ Whether the category was displayed online in DACQ.
+ nave : int
+ Desired number of averages. DACQ stops collecting averages once
+ this number is reached.
+ subave : int
+ Whether to compute normal and alternating subaverages, and
+ how many epochs to include. See the Elekta data acquisition
+ manual for details. Currently the class does not offer any
+ facility for computing subaverages, but it can be done manually
+ by the user after collecting the epochs.
+
+ """
+ if isinstance(item, str):
+ item = [item]
+ elif not isinstance(item, list):
+ raise ValueError('Keys must be category names')
+ cats = list()
+ for it in item:
+ if it in self._categories:
+ cats.append(self._categories[it])
+ else:
+ raise KeyError('No such category')
+ return cats[0] if len(cats) == 1 else cats
+
+ def __len__(self):
+ """ Return number of averaging categories marked active in DACQ. """
+ return len(self.categories)
+
+ def _events_from_acq_pars(self):
+ """ Collect DACQ events into a dict.
+
+ Events are keyed by number starting from 1 (DACQ index of event).
+ Each event is itself represented by a dict containing the event
+ parameters. """
+ # lookup table for event number -> bits for old DACQ versions
+ _compat_event_lookup = {1: 1, 2: 2, 3: 4, 4: 8, 5: 16, 6: 32, 7: 3,
+ 8: 5, 9: 6, 10: 7, 11: 9, 12: 10, 13: 11,
+ 14: 12, 15: 13, 16: 14, 17: 15}
+ events = dict()
+ for evnum in range(1, self.nevent + 1):
+ evnum_s = str(evnum).zfill(2) # '01', '02' etc.
+ evdi = dict()
+ event_vars = (self._event_vars_compat if self.compat
+ else self._event_vars)
+ for var in event_vars:
+ # name of DACQ variable, e.g. 'ERFeventNewBits01'
+ acq_key = 'ERFevent' + var + evnum_s
+ # corresponding dict key, e.g. 'newbits'
+ dict_key = var.lower()
+ val = self.acq_dict[acq_key]
+ # type convert numeric values
+ if dict_key in ['newbits', 'oldbits', 'newmask', 'oldmask']:
+ val = int(val)
+ elif dict_key in ['delay']:
+ val = float(val)
+ evdi[dict_key] = val
+ evdi['in_use'] = False # __init__() will set this
+ evdi['index'] = evnum
+ if self.compat:
+ evdi['name'] = str(evnum)
+ evdi['oldmask'] = 63
+ evdi['newmask'] = 63
+ evdi['oldbits'] = 0
+ evdi['newbits'] = _compat_event_lookup[evnum]
+ events[evnum] = evdi
+ return events
+
+ def _acqpars_dict(self, acq_pars):
+ """ Parse `` info['acq_pars']`` into a dict. """
+ return dict(self._acqpars_gen(acq_pars))
+
+ def _acqpars_gen(self, acq_pars):
+ """ Yields key/value pairs from ``info['acq_pars'])``. """
+ # DACQ variable names always start with one of these
+ key, val = '', ''
+ for line in acq_pars.split():
+ if any([line.startswith(x) for x in self._acq_var_magic]):
+ key = line
+ val = ''
+ else:
+ if not key:
+ raise ValueError('Cannot parse acquisition parameters')
+ # DACQ splits items with spaces into multiple lines
+ val += ' ' + line if val else line
+ yield key, val
+
+ def _categories_from_acq_pars(self):
+ """ Collect DACQ averaging categories into a dict.
+
+ Categories are keyed by the comment field in DACQ. Each category is
+        itself represented by a dict containing the category parameters. """
+ cats = dict()
+        for catnum in [str(x).zfill(2) for x in range(1, self.ncateg + 1)]:
+ catdi = dict()
+ # read all category variables
+ for var in self._cat_vars:
+ acq_key = 'ERFcat' + var + catnum
+ class_key = var.lower()
+ val = self.acq_dict[acq_key]
+ catdi[class_key] = val
+ # some type conversions
+ catdi['display'] = (catdi['display'] == '1')
+ catdi['state'] = (catdi['state'] == '1')
+ for key in ['start', 'end', 'reqwithin']:
+ catdi[key] = float(catdi[key])
+ for key in ['nave', 'event', 'reqevent', 'reqwhen', 'subave']:
+ catdi[key] = int(catdi[key])
+ # some convenient extra (non-DACQ) vars
+ catdi['index'] = int(catnum) # index of category in DACQ list
+ cats[catdi['comment']] = catdi
+ return cats
+
+ def _events_mne_to_dacq(self, mne_events):
+ """ Creates list of DACQ events based on mne trigger transitions list.
+
+ mne_events is typically given by mne.find_events (use consecutive=True
+ to get all transitions). Output consists of rows in the form
+ [t, 0, event_codes] where t is time in samples and event_codes is all
+ DACQ events compatible with the transition, bitwise ORed together:
+ e.g. [t1, 0, 5] means that events 1 and 3 occurred at time t1,
+ as 2**(1 - 1) + 2**(3 - 1) = 5. """
+ events_ = mne_events.copy()
+ events_[:, 1:3] = 0
+ for n, ev in self._events.items():
+ if ev['in_use']:
+ pre_ok = (
+ np.bitwise_and(ev['oldmask'],
+ mne_events[:, 1]) == ev['oldbits'])
+ post_ok = (
+ np.bitwise_and(ev['newmask'],
+ mne_events[:, 2]) == ev['newbits'])
+ ok_ind = np.where(pre_ok & post_ok)
+ events_[ok_ind, 2] |= 1 << (n - 1)
+ return events_
+
+ def _mne_events_to_category_t0(self, cat, mne_events, sfreq):
+ """ Translate mne_events to epoch zero times (t0).
+
+ First mne events (trigger transitions) are converted into DACQ events.
+ Then the zero times for the epochs are obtained by considering the
+ reference and conditional (required) events and the delay to stimulus.
+ """
+
+ cat_ev = cat['event']
+ cat_reqev = cat['reqevent']
+ # first convert mne events to dacq event list
+ events = self._events_mne_to_dacq(mne_events)
+ # next, take req. events and delays into account
+ times = events[:, 0]
+ # indices of times where ref. event occurs
+ refEvents_inds = np.where(events[:, 2] & (1 << cat_ev - 1))[0]
+ refEvents_t = times[refEvents_inds]
+ if cat_reqev:
+ # indices of times where req. event occurs
+ reqEvents_inds = np.where(events[:, 2] & (
+ 1 << cat_reqev - 1))[0]
+ reqEvents_t = times[reqEvents_inds]
+ # relative (to refevent) time window where req. event
+ # must occur (e.g. [0 .2])
+ twin = [0, (-1)**(cat['reqwhen']) * cat['reqwithin']]
+ win = np.round(np.array(sorted(twin)) * sfreq) # to samples
+ refEvents_wins = refEvents_t[:, None] + win
+ req_acc = np.zeros(refEvents_inds.shape, dtype=bool)
+ for t in reqEvents_t:
+ # mark time windows where req. condition is satisfied
+ reqEvent_in_win = np.logical_and(
+ t >= refEvents_wins[:, 0], t <= refEvents_wins[:, 1])
+ req_acc |= reqEvent_in_win
+ # drop ref. events where req. event condition is not satisfied
+ refEvents_inds = refEvents_inds[np.where(req_acc)]
+ refEvents_t = times[refEvents_inds]
+ # adjust for trigger-stimulus delay by delaying the ref. event
+ refEvents_t += int(np.round(self._events[cat_ev]['delay'] * sfreq))
+ return refEvents_t
+
+ @property
+ def categories(self):
+ """ Return list of averaging categories ordered by DACQ index.
+
+ Only returns categories marked active in DACQ.
+ """
+ cats = sorted(self._categories_in_use.values(),
+ key=lambda cat: cat['index'])
+ return cats
+
+ @property
+ def events(self):
+ """ Return events ordered by DACQ index.
+
+ Only returns events that are in use (referred to by a category).
+ """
+ evs = sorted(self._events_in_use.values(), key=lambda ev: ev['index'])
+ return evs
+
+ @property
+ def _categories_in_use(self):
+ return {k: v for k, v in self._categories.items() if v['state']}
+
+ @property
+ def _events_in_use(self):
+ return {k: v for k, v in self._events.items() if v['in_use']}
+
+ def get_condition(self, raw, condition=None, stim_channel=None, mask=None,
+ uint_cast=None, mask_type='and', delayed_lookup=True):
+ """ Get averaging parameters for a condition (averaging category).
+
+ Output is designed to be used with the Epochs class to extract the
+ corresponding epochs.
+
+ Parameters
+ ----------
+ raw : Raw object
+ An instance of Raw.
+ condition : None | str | dict | list of dict
+ Condition or a list of conditions. Conditions can be strings
+ (DACQ comment field, e.g. 'Auditory left') or category dicts
+ (e.g. acqp['Auditory left'], where acqp is an instance of
+ AcqParserFIF). If None, get all conditions marked active in
+ DACQ.
+ stim_channel : None | string | list of string
+ Name of the stim channel or all the stim channels
+ affected by the trigger. If None, the config variables
+ 'MNE_STIM_CHANNEL', 'MNE_STIM_CHANNEL_1', 'MNE_STIM_CHANNEL_2',
+ etc. are read. If these are not found, it will fall back to
+ 'STI101' or 'STI 014' if present, then fall back to the first
+ channel of type 'stim', if present.
+ mask : int | None
+ The value of the digital mask to apply to the stim channel values.
+ If None (default), no masking is performed.
+ uint_cast : bool
+ If True (default False), do a cast to ``uint16`` on the channel
+ data. This can be used to fix a bug with STI101 and STI014 in
+ Neuromag acquisition setups that use channel STI016 (channel 16
+ turns data into e.g. -32768), similar to ``mne_fix_stim14 --32``
+ in MNE-C.
+        mask_type : 'and' | 'not_and'
+ The type of operation between the mask and the trigger.
+ Choose 'and' for MNE-C masking behavior.
+        delayed_lookup : bool
+ If True, use the 'delayed lookup' procedure implemented in Elekta
+ software. When a trigger transition occurs, the lookup of
+ the new trigger value will not happen immediately at the following
+ sample, but with a 1-sample delay. This allows a slight
+ asynchrony between trigger onsets, when they are intended to be
+ synchronous. If you have accurate hardware and want to detect
+ transitions with a resolution of one sample, use
+ delayed_lookup=False.
+
+
+ Returns
+ -------
+        conds_data : dict or list of dict, each with the following keys:
+ events : array, shape (n_epochs_out, 3)
+ List of zero time points (t0) for the epochs matching the
+ condition. Use as the ``events`` parameter to Epochs. Note
+ that these are not (necessarily) actual events.
+ event_id : dict
+ Name of condition and index compatible with ``events``.
+ Should be passed as the ``event_id`` parameter to Epochs.
+ tmin : float
+ Epoch starting time relative to t0. Use as the ``tmin``
+ parameter to Epochs.
+ tmax : float
+ Epoch ending time relative to t0. Use as the ``tmax``
+ parameter to Epochs.
+
+ """
+ if condition is None:
+ condition = self.categories # get all
+ if not isinstance(condition, list):
+ condition = [condition] # single cond -> listify
+ conds_data = list()
+ for cat in condition:
+ if isinstance(cat, str):
+ cat = self[cat]
+ mne_events = find_events(raw, stim_channel=stim_channel, mask=mask,
+ mask_type=mask_type, output='step',
+ uint_cast=uint_cast, consecutive=True,
+ verbose=False, shortest_event=1)
+ if delayed_lookup:
+ ind = np.where(np.diff(mne_events[:, 0]) == 1)[0]
+ if 1 in np.diff(ind):
+ raise ValueError('There are multiple consecutive '
+ 'transitions on the trigger channel. '
+ 'This will not work well with '
+ 'delayed_lookup=True. You may want to '
+ 'check your trigger data and '
+ 'set delayed_lookup=False.')
+ mne_events[ind, 2] = mne_events[ind + 1, 2]
+ mne_events = np.delete(mne_events, ind + 1, axis=0)
+ sfreq = raw.info['sfreq']
+ cat_t0_ = self._mne_events_to_category_t0(cat, mne_events, sfreq)
+ # make it compatible with the usual events array
+ cat_t0 = np.c_[cat_t0_, np.zeros(cat_t0_.shape),
+ cat['index'] * np.ones(cat_t0_.shape)
+ ].astype(np.uint32)
+ cat_id = {cat['comment']: cat['index']}
+ tmin, tmax = cat['start'], cat['end']
+ conds_data.append(dict(events=cat_t0, event_id=cat_id,
+ tmin=tmin, tmax=tmax))
+ return conds_data[0] if len(conds_data) == 1 else conds_data
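
[Editor's note: the dicts returned above are designed to plug straight into
Epochs. A minimal usage sketch, assuming mne.AcqParserFIF is the public entry
point for this parser; the filename and the DACQ comment 'Auditory left' are
placeholders:

    import mne

    raw = mne.io.read_raw_fif('sample_raw.fif')  # placeholder filename
    acqp = mne.AcqParserFIF(raw.info)  # parse DACQ averaging settings
    cond = acqp.get_condition(raw, 'Auditory left')
    # events/event_id/tmin/tmax map one-to-one onto Epochs parameters
    epochs = mne.Epochs(raw, events=cond['events'],
                        event_id=cond['event_id'],
                        tmin=cond['tmin'], tmax=cond['tmax'])
    evoked = epochs.average()
]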
diff --git a/mne/evoked.py b/mne/evoked.py
index 001c2b4..dd486e3 100644
--- a/mne/evoked.py
+++ b/mne/evoked.py
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
# Authors: Alexandre Gramfort <alexandre.gramfort at telecom-paristech.fr>
# Matti Hamalainen <msh at nmr.mgh.harvard.edu>
# Denis Engemann <denis.engemann at gmail.com>
@@ -14,9 +15,8 @@ from .channels.channels import (ContainsMixin, UpdateChannelsMixin,
SetChannelsMixin, InterpolationMixin,
equalize_channels)
from .filter import resample, detrend, FilterMixin
-from .fixes import in1d
-from .utils import (check_fname, logger, verbose, object_hash, _time_mask,
- warn, _check_copy_dep)
+from .utils import (check_fname, logger, verbose, _time_mask, warn, sizeof_fmt,
+ deprecated, SizeMixin, copy_function_doc_to_method_doc)
from .viz import (plot_evoked, plot_evoked_topomap, plot_evoked_field,
plot_evoked_image, plot_evoked_topo)
from .viz.evoked import (_plot_evoked_white, plot_evoked_joint,
@@ -34,7 +34,7 @@ from .io.proj import ProjMixin
from .io.write import (start_file, start_block, end_file, end_block,
write_int, write_string, write_float_matrix,
write_id)
-from .io.base import ToDataFrameMixin, TimeMixin
+from .io.base import ToDataFrameMixin, TimeMixin, _check_maxshield
_aspect_dict = {'average': FIFF.FIFFV_ASPECT_AVERAGE,
'standard_error': FIFF.FIFFV_ASPECT_STD_ERR}
@@ -44,7 +44,7 @@ _aspect_rev = {str(FIFF.FIFFV_ASPECT_AVERAGE): 'average',
class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin,
SetChannelsMixin, InterpolationMixin, FilterMixin,
- ToDataFrameMixin, TimeMixin):
+ ToDataFrameMixin, TimeMixin, SizeMixin):
"""Evoked data
Parameters
@@ -56,18 +56,26 @@ class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin,
Dataset ID number (int) or comment/name (str). Optional if there is
only one data set in file.
baseline : tuple or list of length 2, or None
+ This parameter has been deprecated and will be removed in 0.14.
+ Use inst.apply_baseline(baseline) instead.
The time interval to apply rescaling / baseline correction.
If None do not apply it. If baseline is (a, b)
the interval is between "a (s)" and "b (s)".
If a is None the beginning of the data is used
and if b is None then b is set to the end of the interval.
- If baseline is equal ot (None, None) all the time
+ If baseline is equal to (None, None) all the time
interval is used. If None, no correction is applied.
proj : bool, optional
Apply SSP projection vectors
kind : str
Either 'average' or 'standard_error'. The type of data to read.
Only used if 'condition' is a str.
+ allow_maxshield : bool | str (default False)
+ If True, allow loading of data that has been recorded with internal
+ active compensation (MaxShield). Data recorded with MaxShield should
+ generally not be loaded directly, but should first be processed using
+ SSS/tSSS to remove the compensation signals that may also affect brain
+ activity. Can also be "yes" to load without eliciting a warning.
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
@@ -101,19 +109,51 @@ class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin,
"""
@verbose
def __init__(self, fname, condition=None, baseline=None, proj=True,
- kind='average', verbose=None):
+ kind='average', allow_maxshield=False, verbose=None):
if not isinstance(proj, bool):
raise ValueError(r"'proj' must be 'True' or 'False'")
# Read the requested data
self.info, self.nave, self._aspect_kind, self.first, self.last, \
self.comment, self.times, self.data = _read_evoked(
- fname, condition, kind)
+ fname, condition, kind, allow_maxshield)
self.kind = _aspect_rev.get(str(self._aspect_kind), 'Unknown')
self.verbose = verbose
# project and baseline correct
if proj:
self.apply_proj()
+ self.apply_baseline(baseline, self.verbose)
+
+ @verbose
+ def apply_baseline(self, baseline=(None, 0), verbose=None):
+ """Baseline correct evoked data
+
+ Parameters
+ ----------
+ baseline : tuple of length 2
+ The time interval to apply baseline correction. If None do not
+ apply it. If baseline is (a, b) the interval is between "a (s)" and
+ "b (s)". If a is None the beginning of the data is used and if b is
+ None then b is set to the end of the interval. If baseline is equal
+ to (None, None) all the time interval is used. Correction is
+ applied by computing mean of the baseline period and subtracting it
+ from the data. The baseline (a, b) includes both endpoints, i.e.
+ all timepoints t such that a <= t <= b.
+ verbose : bool, str, int, or None
+ If not None, override default verbose level (see mne.verbose).
+
+ Returns
+ -------
+ evoked : instance of Evoked
+ The baseline-corrected Evoked object.
+
+ Notes
+ -----
+ Baseline correction can be done multiple times.
+
+ .. versionadded:: 0.13.0
+ """
self.data = rescale(self.data, self.times, baseline, copy=False)
+ return self
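
[Editor's note: since apply_baseline operates in place and returns self, it
chains with other methods. A short sketch; the filename is a placeholder:

    import mne

    evoked = mne.read_evokeds('sample-ave.fif', condition=0, baseline=None)
    # subtract the mean of the pre-stimulus interval [tmin, 0] in place
    evoked.apply_baseline((None, 0))
]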
def save(self, fname):
"""Save dataset to file.
@@ -136,6 +176,7 @@ class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin,
s += ", time : [%f, %f]" % (self.times[0], self.times[-1])
s += ", n_epochs : %d" % self.nave
s += ", n_channels x n_times : %s x %s" % self.data.shape
+ s += ", ~%s" % (sizeof_fmt(self._size),)
return "<Evoked | %s>" % s
@property
@@ -143,7 +184,7 @@ class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin,
"""Channel names"""
return self.info['ch_names']
- def crop(self, tmin=None, tmax=None, copy=None):
+ def crop(self, tmin=None, tmax=None):
"""Crop data to a given time interval
Parameters
@@ -152,18 +193,66 @@ class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin,
Start time of selection in seconds.
tmax : float | None
End time of selection in seconds.
- copy : bool
- This parameter has been deprecated and will be removed in 0.13.
- Use inst.copy() instead.
- Whether to return a new instance or modify in place.
+
+ Returns
+ -------
+ evoked : instance of Evoked
+ The cropped Evoked object.
+
+ Notes
+ -----
+ Unlike Python slices, MNE time intervals include both their end points;
+ crop(tmin, tmax) returns the interval tmin <= t <= tmax.
"""
- inst = _check_copy_dep(self, copy)
- mask = _time_mask(inst.times, tmin, tmax, sfreq=self.info['sfreq'])
- inst.times = inst.times[mask]
- inst.first = int(inst.times[0] * inst.info['sfreq'])
- inst.last = len(inst.times) + inst.first - 1
- inst.data = inst.data[:, mask]
- return inst
+ mask = _time_mask(self.times, tmin, tmax, sfreq=self.info['sfreq'])
+ self.times = self.times[mask]
+ self.first = int(self.times[0] * self.info['sfreq'])
+ self.last = len(self.times) + self.first - 1
+ self.data = self.data[:, mask]
+ return self
+
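[Editor's note: because crop now modifies in place and returns self, callers
who want to keep the original should copy first. A sketch:

    # keep the full-length original; crop only the copy (endpoints included)
    evoked_window = evoked.copy().crop(tmin=0.0, tmax=0.3)
]
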
+ def decimate(self, decim, offset=0):
+ """Decimate the evoked data
+
+ .. note:: No filtering is performed. To avoid aliasing, ensure
+ your data are properly lowpassed.
+
+ Parameters
+ ----------
+ decim : int
+ The amount to decimate data.
+ offset : int
+ Apply an offset to where the decimation starts relative to the
+ sample corresponding to t=0. The offset is in samples at the
+ current sampling rate.
+
+ Returns
+ -------
+ evoked : instance of Evoked
+ The decimated Evoked object.
+
+ See Also
+ --------
+ Epochs.decimate
+ Epochs.resample
+ Raw.resample
+
+ Notes
+ -----
+ Decimation can be done multiple times. For example,
+ ``evoked.decimate(2).decimate(2)`` will be the same as
+ ``evoked.decimate(4)``.
+
+ .. versionadded:: 0.13.0
+ """
+ decim, offset, new_sfreq = _check_decim(self.info, decim, offset)
+ start_idx = int(round(self.times[0] * (self.info['sfreq'] * decim)))
+ i_start = start_idx % decim + offset
+ decim_slice = slice(i_start, None, decim)
+ self.info['sfreq'] = new_sfreq
+ self.data = self.data[:, decim_slice].copy()
+ self.times = self.times[decim_slice].copy()
+ return self
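
[Editor's note: as the Notes state, repeated decimation composes
multiplicatively. A sketch of the equivalence, with an assumed sampling rate:

    # e.g. 600 Hz -> 150 Hz either way; data must already be low-passed
    a = evoked.copy().decimate(4)
    b = evoked.copy().decimate(2).decimate(2)
    assert a.info['sfreq'] == b.info['sfreq'] == evoked.info['sfreq'] / 4
]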
def shift_time(self, tshift, relative=True):
"""Shift time scale in evoked data
@@ -193,190 +282,34 @@ class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin,
self.times = np.arange(self.first, self.last + 1,
dtype=np.float) / sfreq
+ @copy_function_doc_to_method_doc(plot_evoked)
def plot(self, picks=None, exclude='bads', unit=True, show=True, ylim=None,
xlim='tight', proj=False, hline=None, units=None, scalings=None,
titles=None, axes=None, gfp=False, window_title=None,
- spatial_colors=False):
- """Plot evoked data using butterfly plots
-
- Left click to a line shows the channel name. Selecting an area by
- clicking and holding left mouse button plots a topographic map of the
- painted area.
-
- Note: If bad channels are not excluded they are shown in red.
-
- Parameters
- ----------
- picks : array-like of int | None
- The indices of channels to plot. If None show all.
- exclude : list of str | 'bads'
- Channels names to exclude from being shown. If 'bads', the
- bad channels are excluded.
- unit : bool
- Scale plot with channel (SI) unit.
- show : bool
- Call pyplot.show() at the end or not.
- ylim : dict | None
- ylim for plots (after scaling has been applied). The value
- determines the upper and lower subplot limits. e.g.
- ylim = dict(eeg=[-20, 20]). Valid keys are eeg, mag, grad. If None,
- the ylim parameter for each channel is determined by the maximum
- absolute peak.
- xlim : 'tight' | tuple | None
- xlim for plots.
- proj : bool | 'interactive'
- If true SSP projections are applied before display. If
- 'interactive', a check box for reversible selection of SSP
- projection vectors will be shown.
- hline : list of floats | None
- The values at which show an horizontal line.
- units : dict | None
- The units of the channel types used for axes lables. If None,
- defaults to `dict(eeg='uV', grad='fT/cm', mag='fT')`.
- scalings : dict | None
- The scalings of the channel types to be applied for plotting.
- If None, defaults to `dict(eeg=1e6, grad=1e13, mag=1e15)`.
- titles : dict | None
- The titles associated with the channels. If None, defaults to
- `dict(eeg='EEG', grad='Gradiometers', mag='Magnetometers')`.
- axes : instance of Axes | list | None
- The axes to plot to. If list, the list must be a list of Axes of
- the same length as the number of channel types. If instance of
- Axes, there must be only one channel type plotted.
- gfp : bool | 'only'
- Plot GFP in green if True or "only". If "only", then the individual
- channel traces will not be shown.
- window_title : str | None
- The title to put at the top of the figure window.
- spatial_colors : bool
- If True, the lines are color coded by mapping physical sensor
- coordinates into color values. Spatially similar channels will have
- similar colors. Bad channels will be dotted. If False, the good
- channels are plotted black and bad channels red. Defaults to False.
-
- Returns
- -------
- fig : instance of matplotlib.figure.Figure
- Figure containing the butterfly plots.
- """
- return plot_evoked(self, picks=picks, exclude=exclude, unit=unit,
- show=show, ylim=ylim, proj=proj, xlim=xlim,
- hline=hline, units=units, scalings=scalings,
- titles=titles, axes=axes, gfp=gfp,
- window_title=window_title,
- spatial_colors=spatial_colors)
-
+ spatial_colors=False, zorder='unsorted', selectable=True):
+ return plot_evoked(
+ self, picks=picks, exclude=exclude, unit=unit, show=show,
+ ylim=ylim, proj=proj, xlim=xlim, hline=hline, units=units,
+ scalings=scalings, titles=titles, axes=axes, gfp=gfp,
+ window_title=window_title, spatial_colors=spatial_colors,
+ zorder=zorder, selectable=selectable)
+
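[Editor's note: the copy_function_doc_to_method_doc decorator used throughout
this file lets each plotting method drop its duplicated docstring. A minimal
sketch of the idea only; the real helper in mne.utils also splices out the
documentation of the function's first parameter:

    def copy_function_doc_to_method_doc(source):
        """Copy source's docstring onto the decorated method (sketch)."""
        def decorator(method):
            method.__doc__ = source.__doc__
            return method
        return decorator
]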
+ @copy_function_doc_to_method_doc(plot_evoked_image)
def plot_image(self, picks=None, exclude='bads', unit=True, show=True,
clim=None, xlim='tight', proj=False, units=None,
scalings=None, titles=None, axes=None, cmap='RdBu_r'):
- """Plot evoked data as images
-
- Parameters
- ----------
- picks : array-like of int | None
- The indices of channels to plot. If None show all.
- exclude : list of str | 'bads'
- Channels names to exclude from being shown. If 'bads', the
- bad channels are excluded.
- unit : bool
- Scale plot with channel (SI) unit.
- show : bool
- Call pyplot.show() at the end or not.
- clim : dict
- clim for images. e.g. clim = dict(eeg=[-200e-6, 200e6])
- Valid keys are eeg, mag, grad
- xlim : 'tight' | tuple | None
- xlim for plots.
- proj : bool | 'interactive'
- If true SSP projections are applied before display. If
- 'interactive', a check box for reversible selection of SSP
- projection vectors will be shown.
- units : dict | None
- The units of the channel types used for axes lables. If None,
- defaults to `dict(eeg='uV', grad='fT/cm', mag='fT')`.
- scalings : dict | None
- The scalings of the channel types to be applied for plotting.
- If None, defaults to `dict(eeg=1e6, grad=1e13, mag=1e15)`.
- titles : dict | None
- The titles associated with the channels. If None, defaults to
- `dict(eeg='EEG', grad='Gradiometers', mag='Magnetometers')`.
- axes : instance of Axes | list | None
- The axes to plot to. If list, the list must be a list of Axes of
- the same length as the number of channel types. If instance of
- Axes, there must be only one channel type plotted.
- cmap : matplotlib colormap
- Colormap.
-
- Returns
- -------
- fig : instance of matplotlib.figure.Figure
- Figure containing the images.
- """
return plot_evoked_image(self, picks=picks, exclude=exclude, unit=unit,
show=show, clim=clim, proj=proj, xlim=xlim,
units=units, scalings=scalings,
titles=titles, axes=axes, cmap=cmap)
+ @copy_function_doc_to_method_doc(plot_evoked_topo)
def plot_topo(self, layout=None, layout_scale=0.945, color=None,
border='none', ylim=None, scalings=None, title=None,
proj=False, vline=[0.0], fig_facecolor='k',
fig_background=None, axis_facecolor='k', font_color='w',
merge_grads=False, show=True):
- """Plot 2D topography of evoked responses.
-
- Clicking on the plot of an individual sensor opens a new figure showing
- the evoked response for the selected sensor.
-
- Parameters
- ----------
- layout : instance of Layout | None
- Layout instance specifying sensor positions (does not need to
- be specified for Neuromag data). If possible, the correct layout is
- inferred from the data.
- layout_scale: float
- Scaling factor for adjusting the relative size of the layout
- on the canvas
- color : list of color objects | color object | None
- Everything matplotlib accepts to specify colors. If not list-like,
- the color specified will be repeated. If None, colors are
- automatically drawn.
- border : str
- matplotlib borders style to be used for each sensor plot.
- ylim : dict | None
- ylim for plots. The value determines the upper and lower subplot
- limits. e.g. ylim = dict(eeg=[-20, 20]). Valid keys are eeg,
- mag, grad, misc. If None, the ylim parameter for each channel is
- determined by the maximum absolute peak.
- scalings : dict | None
- The scalings of the channel types to be applied for plotting. If
- None, defaults to `dict(eeg=1e6, grad=1e13, mag=1e15)`.
- title : str
- Title of the figure.
- proj : bool | 'interactive'
- If true SSP projections are applied before display. If
- 'interactive', a check box for reversible selection of SSP
- projection vectors will be shown.
- vline : list of floats | None
- The values at which to show a vertical line.
- fig_facecolor : str | obj
- The figure face color. Defaults to black.
- fig_background : None | numpy ndarray
- A background image for the figure. This must work with a call to
- plt.imshow. Defaults to None.
- axis_facecolor : str | obj
- The face color to be used for each sensor plot. Defaults to black.
- font_color : str | obj
- The color of text in the colorbar and title. Defaults to white.
- merge_grads : bool
- Whether to use RMS value of gradiometer pairs. Only works for
- Neuromag data. Defaults to False.
- show : bool
- Show figure if True.
-
- Returns
- -------
- fig : Instance of matplotlib.figure.Figure
- Images of evoked responses at sensor locations
+ """
Notes
-----
@@ -391,6 +324,7 @@ class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin,
font_color=font_color, merge_grads=merge_grads,
show=show)
+ @copy_function_doc_to_method_doc(plot_evoked_topomap)
def plot_topomap(self, times="auto", ch_type=None, layout=None, vmin=None,
vmax=None, cmap=None, sensors=True, colorbar=True,
scale=None, scale_time=1e3, unit=None, res=64, size=1,
@@ -399,123 +333,6 @@ class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin,
mask_params=None, outlines='head', contours=6,
image_interp='bilinear', average=None, head_pos=None,
axes=None):
- """Plot topographic maps of specific time points
-
- Parameters
- ----------
- times : float | array of floats | "auto" | "peaks".
- The time point(s) to plot. If "auto", the number of ``axes``
- determines the amount of time point(s). If ``axes`` is also None,
- 10 topographies will be shown with a regular time spacing between
- the first and last time instant. If "peaks", finds time points
- automatically by checking for local maxima in Global Field Power.
- ch_type : 'mag' | 'grad' | 'planar1' | 'planar2' | 'eeg' | None
- The channel type to plot. For 'grad', the gradiometers are collec-
- ted in pairs and the RMS for each pair is plotted.
- If None, then first available channel type from order given
- above is used. Defaults to None.
- layout : None | Layout
- Layout instance specifying sensor positions (does not need to
- be specified for Neuromag data). If possible, the correct
- layout file is inferred from the data; if no appropriate layout
- file was found, the layout is automatically generated from the
- sensor locations.
- vmin : float | callable
- The value specfying the lower bound of the color range.
- If None, and vmax is None, -vmax is used. Else np.min(data).
- If callable, the output equals vmin(data).
- vmax : float | callable
- The value specfying the upper bound of the color range.
- If None, the maximum absolute value is used. If vmin is None,
- but vmax is not, defaults to np.max(data).
- If callable, the output equals vmax(data).
- cmap : matplotlib colormap | None
- Colormap to use. If None, 'Reds' is used for all positive data,
- otherwise defaults to 'RdBu_r'.
- sensors : bool | str
- Add markers for sensor locations to the plot. Accepts matplotlib
- plot format string (e.g., 'r+' for red plusses). If True, a circle
- will be used (via .add_artist). Defaults to True.
- colorbar : bool
- Plot a colorbar.
- scale : dict | float | None
- Scale the data for plotting. If None, defaults to 1e6 for eeg, 1e13
- for grad and 1e15 for mag.
- scale_time : float | None
- Scale the time labels. Defaults to 1e3 (ms).
- unit : dict | str | None
- The unit of the channel type used for colorbar label. If
- scale is None the unit is automatically determined.
- res : int
- The resolution of the topomap image (n pixels along each side).
- size : scalar
- Side length of the topomaps in inches (only applies when plotting
- multiple topomaps at a time).
- cbar_fmt : str
- String format for colorbar values.
- time_format : str
- String format for topomap values. Defaults to ``"%01d ms"``.
- proj : bool | 'interactive'
- If true SSP projections are applied before display. If
- 'interactive', a check box for reversible selection of SSP
- projection vectors will be shown.
- show : bool
- Call pyplot.show() at the end.
- show_names : bool | callable
- If True, show channel names on top of the map. If a callable is
- passed, channel names will be formatted using the callable; e.g.,
- to delete the prefix 'MEG ' from all channel names, pass the
- function
- lambda x: x.replace('MEG ', ''). If `mask` is not None, only
- significant sensors will be shown.
- title : str | None
- Title. If None (default), no title is displayed.
- mask : ndarray of bool, shape (n_channels, n_times) | None
- The channels to be marked as significant at a given time point.
- Indices set to `True` will be considered. Defaults to None.
- mask_params : dict | None
- Additional plotting parameters for plotting significant sensors.
- Default (None) equals:
- ``dict(marker='o', markerfacecolor='w', markeredgecolor='k',
- linewidth=0, markersize=4)``.
- outlines : 'head' | 'skirt' | dict | None
- The outlines to be drawn. If 'head', the default head scheme will
- be drawn. If 'skirt' the head scheme will be drawn, but sensors are
- allowed to be plotted outside of the head circle. If dict, each key
- refers to a tuple of x and y positions, the values in 'mask_pos'
- will serve as image mask, and the 'autoshrink' (bool) field will
- trigger automated shrinking of the positions due to points outside
- the outline. Alternatively, a matplotlib patch object can be passed
- for advanced masking options, either directly or as a function that
- returns patches (required for multi-axis plots). If None, nothing
- will be drawn. Defaults to 'head'.
- contours : int | False | None
- The number of contour lines to draw. If 0, no contours will be
- drawn.
- image_interp : str
- The image interpolation to be used. All matplotlib options are
- accepted.
- average : float | None
- The time window around a given time to be used for averaging
- (seconds). For example, 0.01 would translate into window that
- starts 5 ms before and ends 5 ms after a given time point.
- Defaults to None, which means no averaging.
- head_pos : dict | None
- If None (default), the sensors are positioned such that they span
- the head circle. If dict, can have entries 'center' (tuple) and
- 'scale' (tuple) for what the center and scale of the head should be
- relative to the electrode locations.
- axes : instance of Axes | list | None
- The axes to plot to. If list, the list must be a list of Axes of
- the same length as ``times`` (unless ``times`` is None). If
- instance of Axes, ``times`` must be a float or a list of one float.
- Defaults to None.
-
- Returns
- -------
- fig : instance of matplotlib.figure.Figure
- Images of evoked responses at sensor locations
- """
return plot_evoked_topomap(self, times=times, ch_type=ch_type,
layout=layout, vmin=vmin, vmax=vmax,
cmap=cmap, sensors=sensors,
@@ -529,27 +346,9 @@ class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin,
image_interp=image_interp, average=average,
head_pos=head_pos, axes=axes)
+ @copy_function_doc_to_method_doc(plot_evoked_field)
def plot_field(self, surf_maps, time=None, time_label='t = %0.0f ms',
n_jobs=1):
- """Plot MEG/EEG fields on head surface and helmet in 3D
-
- Parameters
- ----------
- surf_maps : list
- The surface mapping information obtained with make_field_map.
- time : float | None
- The time point at which the field map shall be displayed. If None,
- the average peak latency (across sensor types) is used.
- time_label : str
- How to print info about the time instant visualized.
- n_jobs : int
- Number of jobs to run in parallel.
-
- Returns
- -------
- fig : instance of mlab.Figure
- The mayavi figure.
- """
return plot_evoked_field(self, surf_maps, time=time,
time_label=time_label, n_jobs=n_jobs)
@@ -593,53 +392,10 @@ class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin,
return _plot_evoked_white(self, noise_cov=noise_cov, scalings=None,
rank=None, show=show)
+ @copy_function_doc_to_method_doc(plot_evoked_joint)
def plot_joint(self, times="peaks", title='', picks=None,
exclude='bads', show=True, ts_args=None,
topomap_args=None):
- """Plot evoked data as butterfly plots and add topomaps for selected
- time points.
-
- Parameters
- ----------
- times : float | array of floats | "auto" | "peaks"
- The time point(s) to plot. If "auto", 5 evenly spaced topographies
- between the first and last time instant will be shown. If "peaks",
- finds time points automatically by checking for 3 local
- maxima in Global Field Power. Defaults to "peaks".
- title : str
- The title. If ``None``, suppress printing channel type. Defaults to
- an empty string.
- picks : array-like of int | None
- The indices of channels to plot. If ``None``, show all. Defaults
- to None.
- exclude : list of str | 'bads'
- Channels names to exclude from being shown. If 'bads', the
- bad channels are excluded. Defaults to 'bads'.
- show : bool
- Show figure if True. Defaults to True.
- ts_args : None | dict
- A dict of `kwargs` that are forwarded to `evoked.plot` to
- style the butterfly plot. `axes` and `show` are ignored.
- If `spatial_colors` is not in this dict, `spatial_colors=True`
- will be passed. Beyond that, if `None`, no customizable arguments
- will be passed.
- topomap_args : None | dict
- A dict of `kwargs` that are forwarded to `evoked.plot_topomap`
- to style the topomaps. `axes` and `show` are ignored. If `times`
- is not in this dict, automatic peak detection is used. Beyond
- that, if `None`, no customizable arguments will be passed.
-
- Returns
- -------
- fig : instance of matplotlib.figure.Figure | list
- The figure object containing the plot. If `evoked` has multiple
- channel types, a list of figures, one for each channel type, is
- returned.
-
- Notes
- -----
- .. versionadded:: 0.12.0
- """
return plot_evoked_joint(self, times=times, title=title, picks=picks,
exclude=exclude, show=show, ts_args=ts_args,
topomap_args=topomap_args)
@@ -719,7 +475,7 @@ class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin,
from .forward import _as_meg_type_evoked
return _as_meg_type_evoked(self, ch_type=ch_type, mode=mode)
- def resample(self, sfreq, npad=None, window='boxcar'):
+ def resample(self, sfreq, npad='auto', window='boxcar'):
"""Resample data
This function operates in-place.
@@ -734,12 +490,12 @@ class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin,
a power-of-two size (can be much faster).
window : string or tuple
Window to use in resampling. See scipy.signal.resample.
+
+ Returns
+ -------
+ evoked : instance of mne.Evoked
+ The resampled evoked object.
"""
- if npad is None:
- npad = 100
- warn('npad is currently taken to be 100, but will be changed to '
- '"auto" in 0.13. Please set the value explicitly.',
- DeprecationWarning)
sfreq = float(sfreq)
o_sfreq = self.info['sfreq']
self.data = resample(self.data, sfreq, o_sfreq, npad, -1, window)
@@ -749,6 +505,7 @@ class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin,
self.times[0])
self.first = int(self.times[0] * self.info['sfreq'])
self.last = len(self.times) + self.first - 1
+ return self
def detrend(self, order=1, picks=None):
"""Detrend data
@@ -761,12 +518,12 @@ class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin,
Either 0 or 1, the order of the detrending. 0 is a constant
(DC) detrend, 1 is a linear detrend.
picks : array-like of int | None
- If None only MEG, EEG, SEEG, and ECoG channels are detrended.
+ If None only MEG, EEG, SEEG, ECoG and fNIRS channels are detrended.
Returns
-------
evoked : instance of Evoked
- The evoked instance.
+ The detrended evoked object.
"""
if picks is None:
picks = _pick_data_channels(self.info)
@@ -783,6 +540,23 @@ class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin,
evoked = deepcopy(self)
return evoked
+ def __neg__(self):
+ """Negate channel responses
+
+ Returns
+ -------
+ evoked_neg : instance of Evoked
+ The Evoked instance with channel data negated and '-'
+ prepended to the comment.
+ """
+ out = self.copy()
+ out.data *= -1
+ out.comment = '-' + (out.comment or 'unknown')
+ return out
+
+ @deprecated('ev1 + ev2 weighted summation has been deprecated and will be '
+ 'removed in 0.14, use combine_evoked([ev1, ev2],'
+ 'weights="nave") instead')
def __add__(self, evoked):
"""Add evoked taking into account number of epochs
@@ -793,10 +567,13 @@ class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin,
--------
mne.combine_evoked
"""
- out = combine_evoked([self, evoked])
+ out = combine_evoked([self, evoked], weights='nave')
out.comment = self.comment + " + " + evoked.comment
return out
+ @deprecated('ev1 - ev2 weighted subtraction has been deprecated and will '
+ 'be removed in 0.14, use combine_evoked([ev1, -ev2], '
+ 'weights="nave") instead')
def __sub__(self, evoked):
"""Subtract evoked taking into account number of epochs
@@ -807,26 +584,21 @@ class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin,
--------
mne.combine_evoked
"""
- this_evoked = deepcopy(evoked)
- this_evoked.data *= -1.
- out = combine_evoked([self, this_evoked])
- if self.comment is None or this_evoked.comment is None:
+ out = combine_evoked([self, -evoked], weights='nave')
+ if self.comment is None or evoked.comment is None:
warn('evoked.comment expects a string but is None')
out.comment = 'unknown'
else:
- out.comment = self.comment + " - " + this_evoked.comment
+ out.comment = self.comment + " - " + evoked.comment
return out
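
[Editor's note: with the operators deprecated, the messages above point to
combine_evoked. A sketch of the replacements; ev1 and ev2 are assumed Evoked
instances:

    import mne

    summed = mne.combine_evoked([ev1, ev2], weights='nave')
    # subtraction via __neg__, which flips the data and the comment
    diffed = mne.combine_evoked([ev1, -ev2], weights='nave')
]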
- def __hash__(self):
- return object_hash(dict(info=self.info, data=self.data))
-
def get_peak(self, ch_type=None, tmin=None, tmax=None, mode='abs',
time_as_index=False):
"""Get location and latency of peak amplitude
Parameters
----------
- ch_type : {'mag', 'grad', 'eeg', 'seeg', 'ecog', 'misc', None}
+ ch_type : 'mag', 'grad', 'eeg', 'seeg', 'ecog', 'hbo', 'hbr', 'misc', None # noqa
The channel type to use. Defaults to None. If more than one sensor
type is present in the data, the channel type has to be explicitly
set.
@@ -852,7 +624,8 @@ class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin,
The time point of the maximum response, either latency in seconds
or index.
"""
- supported = ('mag', 'grad', 'eeg', 'seeg', 'ecog', 'misc', 'None')
+ supported = ('mag', 'grad', 'eeg', 'seeg', 'ecog', 'misc', 'hbo',
+ 'hbr', 'None')
data_picks = _pick_data_channels(self.info, with_ref_meg=False)
types_used = set([channel_type(self.info, idx) for idx in data_picks])
@@ -871,11 +644,9 @@ class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin,
'must not be `None`, pass a sensor type '
'value instead')
- meg = eeg = misc = seeg = ecog = False
+ meg = eeg = misc = seeg = ecog = fnirs = False
picks = None
- if ch_type == 'mag':
- meg = ch_type
- elif ch_type == 'grad':
+ if ch_type in ('mag', 'grad'):
meg = ch_type
elif ch_type == 'eeg':
eeg = True
@@ -885,10 +656,13 @@ class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin,
seeg = True
elif ch_type == 'ecog':
ecog = True
+ elif ch_type in ('hbo', 'hbr'):
+ fnirs = ch_type
if ch_type is not None:
picks = pick_types(self.info, meg=meg, eeg=eeg, misc=misc,
- seeg=seeg, ecog=ecog, ref_meg=False)
+ seeg=seeg, ecog=ecog, ref_meg=False,
+ fnirs=fnirs)
data = self.data
ch_names = self.ch_names
@@ -902,6 +676,30 @@ class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin,
time_idx if time_as_index else self.times[time_idx])
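
[Editor's note: with the fNIRS additions, peaks can now be queried for the
oxy-/deoxyhemoglobin channel types. A sketch, assuming the data contain 'hbo'
channels:

    ch_name, latency = evoked.get_peak(ch_type='hbo', tmin=0.0, tmax=0.5)
]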
+def _check_decim(info, decim, offset):
+ """Helper to check decimation parameters"""
+ if decim < 1 or decim != int(decim):
+ raise ValueError('decim must be an integer > 0')
+ decim = int(decim)
+ new_sfreq = info['sfreq'] / float(decim)
+ lowpass = info['lowpass']
+ if decim > 1 and lowpass is None:
+ warn('The measurement information indicates data is not low-pass '
+ 'filtered. The decim=%i parameter will result in a sampling '
+ 'frequency of %g Hz, which can cause aliasing artifacts.'
+ % (decim, new_sfreq))
+ elif decim > 1 and new_sfreq < 2.5 * lowpass:
+ warn('The measurement information indicates a low-pass frequency '
+ 'of %g Hz. The decim=%i parameter will result in a sampling '
+ 'frequency of %g Hz, which can cause aliasing artifacts.'
+ % (lowpass, decim, new_sfreq)) # > 50% nyquist lim
+ offset = int(offset)
+ if not 0 <= offset < decim:
+ raise ValueError('offset must be at least 0 and less than decim '
+ '(%s), got %s' % (decim, offset))
+ return decim, offset, new_sfreq
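
[Editor's note: a worked example of the aliasing check, with assumed numbers:
for sfreq=1000 Hz, lowpass=40 Hz and decim=8, new_sfreq is 125 Hz; since
125 >= 2.5 * 40, no warning is raised, whereas decim=16 gives 62.5 Hz < 100 Hz
and triggers one:

    decim, offset, new_sfreq = _check_decim(evoked.info, decim=8, offset=0)
    # new_sfreq == 125.0 when evoked.info['sfreq'] == 1000.0
]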
+
+
class EvokedArray(Evoked):
"""Evoked object from numpy array
@@ -967,7 +765,7 @@ class EvokedArray(Evoked):
self._aspect_kind = _aspect_dict[self.kind]
-def _get_entries(fid, evoked_node):
+def _get_entries(fid, evoked_node, allow_maxshield=False):
"""Helper to get all evoked entries"""
comments = list()
aspect_kinds = list()
@@ -978,7 +776,7 @@ def _get_entries(fid, evoked_node):
if my_kind == FIFF.FIFF_COMMENT:
tag = read_tag(fid, pos)
comments.append(tag.data)
- my_aspect = dir_tree_find(ev, FIFF.FIFFB_ASPECT)[0]
+ my_aspect = _get_aspect(ev, allow_maxshield)[0]
for k in range(my_aspect['nent']):
my_kind = my_aspect['directory'][k].kind
pos = my_aspect['directory'][k].pos
@@ -997,6 +795,19 @@ def _get_entries(fid, evoked_node):
return comments, aspect_kinds, t
+def _get_aspect(evoked, allow_maxshield):
+ """Get Evoked data aspect."""
+ is_maxshield = False
+ aspect = dir_tree_find(evoked, FIFF.FIFFB_ASPECT)
+ if len(aspect) == 0:
+ _check_maxshield(allow_maxshield)
+ aspect = dir_tree_find(evoked, FIFF.FIFFB_SMSH_ASPECT)
+ is_maxshield = True
+ if len(aspect) > 1:
+ logger.info('Multiple data aspects found. Taking first one.')
+ return aspect[0], is_maxshield
+
+
def _get_evoked_node(fname):
"""Helper to get info in evoked file"""
f, tree, _ = fiff_open(fname)
@@ -1057,8 +868,8 @@ def grand_average(all_evoked, interpolate_bads=True):
return grand_average
-def combine_evoked(all_evoked, weights='nave'):
- """Merge evoked data by weighted addition
+def combine_evoked(all_evoked, weights=None):
+ """Merge evoked data by weighted addition or subtraction
Data should have the same channels and the same time instants.
Subtraction can be performed by passing negative weights (e.g., [1, -1]).
@@ -1081,6 +892,11 @@ def combine_evoked(all_evoked, weights='nave'):
-----
.. versionadded:: 0.9.0
"""
+ if weights is None:
+ weights = 'nave'
+ warn('In 0.13 the default is weights="nave", but in 0.14 the default '
+ 'will be removed and it will have to be explicitly set',
+ DeprecationWarning)
evoked = all_evoked[0].copy()
if isinstance(weights, string_types):
if weights not in ('nave', 'equal'):
@@ -1110,14 +926,30 @@ def combine_evoked(all_evoked, weights='nave'):
evoked.info['bads'] = bads
evoked.data = sum(w * e.data for w, e in zip(weights, all_evoked))
- evoked.nave = max(int(1. / sum(w ** 2 / e.nave
- for w, e in zip(weights, all_evoked))), 1)
+ # We should set nave based on how variances change when summing Gaussian
+ # random variables. From:
+ #
+ # https://en.wikipedia.org/wiki/Weighted_arithmetic_mean
+ #
+ # We know that the variance of a weighted sample mean is:
+ #
+ # σ^2 = w_1^2 σ_1^2 + w_2^2 σ_2^2 + ... + w_n^2 σ_n^2
+ #
+ # We estimate the variance of each evoked instance as 1 / nave to get:
+ #
+ # σ^2 = w_1^2 / nave_1 + w_2^2 / nave_2 + ... + w_n^2 / nave_n
+ #
+ # And our resulting nave is the reciprocal of this:
+ evoked.nave = max(int(round(
+ 1. / sum(w ** 2 / e.nave for w, e in zip(weights, all_evoked)))), 1)
+ evoked.comment = ' + '.join('%0.3f * %s' % (w, e.comment or 'unknown')
+ for w, e in zip(weights, all_evoked))
return evoked
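
[Editor's note: plugging numbers into the formula above, with assumed naves:
for ev1.nave = 40 and ev2.nave = 60, an equal-weight average w = [0.5, 0.5]
gives nave = 1 / (0.25/40 + 0.25/60) = 96, while a difference w = [1, -1]
gives nave = 1 / (1/40 + 1/60) = 24, reflecting that subtraction accumulates
noise:

    diff = combine_evoked([ev1, ev2], weights=[1, -1])
    assert diff.nave == 24  # with the assumed naves above
]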
@verbose
def read_evokeds(fname, condition=None, baseline=None, kind='average',
- proj=True, verbose=None):
+ proj=True, allow_maxshield=False, verbose=None):
"""Read evoked dataset(s)
Parameters
@@ -1131,13 +963,22 @@ def read_evokeds(fname, condition=None, baseline=None, kind='average',
baseline : None (default) or tuple of length 2
The time interval to apply baseline correction. If None do not apply
it. If baseline is (a, b) the interval is between "a (s)" and "b (s)".
- If a is None the beginning of the data is used and if b is None then
- b is set to the end of the interval. If baseline is equal to
- (None, None) all the time interval is used.
+ If a is None the beginning of the data is used and if b is None then b
+ is set to the end of the interval. If baseline is equal to (None, None)
+ all the time interval is used. Correction is applied by computing mean
+ of the baseline period and subtracting it from the data. The baseline
+ (a, b) includes both endpoints, i.e. all timepoints t such that
+ a <= t <= b.
kind : str
Either 'average' or 'standard_error', the type of data to read.
proj : bool
If False, available projectors won't be applied to the data.
+ allow_maxshield : bool | str (default False)
+ If True, allow loading of data that has been recorded with internal
+ active compensation (MaxShield). Data recorded with MaxShield should
+ generally not be loaded directly, but should first be processed using
+ SSS/tSSS to remove the compensation signals that may also affect brain
+ activity. Can also be "yes" to load without eliciting a warning.
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
@@ -1161,13 +1002,15 @@ def read_evokeds(fname, condition=None, baseline=None, kind='average',
condition = [condition]
return_list = False
- out = [Evoked(fname, c, baseline=baseline, kind=kind, proj=proj,
- verbose=verbose) for c in condition]
+ out = [Evoked(fname, c, kind=kind, proj=proj,
+ allow_maxshield=allow_maxshield,
+ verbose=verbose).apply_baseline(baseline)
+ for c in condition]
return out if return_list else out[0]
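
[Editor's note: a sketch of reading an evoked file recorded with MaxShield
enabled; the filename is a placeholder, and 'yes' can be passed instead of
True to also suppress the warning:

    evoked = mne.read_evokeds('ias-ave.fif', condition=0,
                              allow_maxshield=True)
]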
-def _read_evoked(fname, condition=None, kind='average'):
+def _read_evoked(fname, condition=None, kind='average', allow_maxshield=False):
"""Read evoked data from a FIF file"""
if fname is None:
raise ValueError('No evoked filename specified')
@@ -1193,10 +1036,10 @@ def _read_evoked(fname, condition=None, kind='average'):
raise ValueError('kind must be "average" or '
'"standard_error"')
- comments, aspect_kinds, t = _get_entries(fid, evoked_node)
- goods = np.logical_and(in1d(comments, [condition]),
- in1d(aspect_kinds,
- [_aspect_dict[kind]]))
+ comments, aspect_kinds, t = _get_entries(fid, evoked_node,
+ allow_maxshield)
+ goods = (np.in1d(comments, [condition]) &
+ np.in1d(aspect_kinds, [_aspect_dict[kind]]))
found_cond = np.where(goods)[0]
if len(found_cond) != 1:
raise ValueError('condition "%s" (%s) not found, out of '
@@ -1205,7 +1048,8 @@ def _read_evoked(fname, condition=None, kind='average'):
condition = found_cond[0]
elif condition is None:
if len(evoked_node) > 1:
- _, _, conditions = _get_entries(fid, evoked_node)
+ _, _, conditions = _get_entries(fid, evoked_node,
+ allow_maxshield)
raise TypeError("Evoked file has more than one "
"conditions, the condition parameters "
"must be specified from:\n%s" % conditions)
@@ -1218,10 +1062,7 @@ def _read_evoked(fname, condition=None, kind='average'):
my_evoked = evoked_node[condition]
# Identify the aspects
- aspects = dir_tree_find(my_evoked, FIFF.FIFFB_ASPECT)
- if len(aspects) > 1:
- logger.info('Multiple aspects found. Taking first one.')
- my_aspect = aspects[0]
+ my_aspect, info['maxshield'] = _get_aspect(my_evoked, allow_maxshield)
# Now find the data in the evoked block
nchan = 0
@@ -1395,7 +1236,11 @@ def _write_evokeds(fname, evoked, check=True):
write_int(fid, FIFF.FIFF_LAST_SAMPLE, e.last)
# The epoch itself
- start_block(fid, FIFF.FIFFB_ASPECT)
+ if e.info.get('maxshield'):
+ aspect = FIFF.FIFFB_SMSH_ASPECT
+ else:
+ aspect = FIFF.FIFFB_ASPECT
+ start_block(fid, aspect)
write_int(fid, FIFF.FIFF_ASPECT_KIND, e._aspect_kind)
write_int(fid, FIFF.FIFF_NAVE, e.nave)
@@ -1406,7 +1251,7 @@ def _write_evokeds(fname, evoked, check=True):
e.info['chs'][k].get('scale', 1.0))
write_float_matrix(fid, FIFF.FIFF_EPOCH, decal * e.data)
- end_block(fid, FIFF.FIFFB_ASPECT)
+ end_block(fid, aspect)
end_block(fid, FIFF.FIFFB_EVOKED)
end_block(fid, FIFF.FIFFB_PROCESSED_DATA)
diff --git a/mne/externals/h5io/_h5io.py b/mne/externals/h5io/_h5io.py
index 36dd9f7..0130dff 100644
--- a/mne/externals/h5io/_h5io.py
+++ b/mne/externals/h5io/_h5io.py
@@ -19,6 +19,8 @@ PY3 = sys.version_info[0] == 3
text_type = str if PY3 else unicode # noqa
string_types = str if PY3 else basestring # noqa
+special_chars = {'{FWDSLASH}': '/'}
+
##############################################################################
# WRITING
@@ -47,8 +49,16 @@ def _create_titled_dataset(root, key, title, data, comp_kw=None):
return out
+def _create_pandas_dataset(fname, root, key, title, data):
+ h5py = _check_h5py()
+ rootpath = '/'.join([root, key])
+ data.to_hdf(fname, rootpath)
+ with h5py.File(fname, mode='a') as fid:
+ fid[rootpath].attrs['TITLE'] = 'pd_dataframe'
+
+
def write_hdf5(fname, data, overwrite=False, compression=4,
- title='h5io'):
+ title='h5io', slash='error'):
"""Write python object to HDF5 format using h5py
Parameters
@@ -58,42 +68,80 @@ def write_hdf5(fname, data, overwrite=False, compression=4,
data : object
Object to write. Can be of any of these types:
{ndarray, dict, list, tuple, int, float, str}
- Note that dict objects must only have ``str`` keys.
- overwrite : bool
- If True, overwrite file (if it exists).
+ Note that dict objects must only have ``str`` keys. It is recommended
+ to use ndarrays where possible, as it is handled most efficiently.
+ overwrite : True | False | 'update'
+ If True, overwrite file (if it exists). If 'update', appends the title
+ to the file (or replace value if title exists).
compression : int
Compression level to use (0-9) to compress data using gzip.
title : str
The top-level directory name to use. Typically it is useful to make
this your package name, e.g. ``'mnepython'``.
+ slash : 'error' | 'replace'
+ Whether to replace forward-slashes ('/') in keys found nested within
+ ``data``. This does not apply to the top-level name (title).
+ If 'error', '/' is not allowed in any lower-level keys.
"""
h5py = _check_h5py()
- if op.isfile(fname) and not overwrite:
- raise IOError('file "%s" exists, use overwrite=True to overwrite'
- % fname)
+ mode = 'w'
+ if op.isfile(fname):
+ if isinstance(overwrite, string_types):
+ if overwrite != 'update':
+ raise ValueError('overwrite must be "update" or a bool')
+ mode = 'a'
+ elif not overwrite:
+ raise IOError('file "%s" exists, use overwrite=True to overwrite'
+ % fname)
if not isinstance(title, string_types):
raise ValueError('title must be a string')
comp_kw = dict()
if compression > 0:
comp_kw = dict(compression='gzip', compression_opts=compression)
- with h5py.File(fname, mode='w') as fid:
- _triage_write(title, data, fid, comp_kw, str(type(data)))
-
+ with h5py.File(fname, mode=mode) as fid:
+ if title in fid:
+ del fid[title]
+ cleanup_data = []
+ _triage_write(title, data, fid, comp_kw, str(type(data)),
+ cleanup_data=cleanup_data, slash=slash, title=title)
+
+ # Will not be empty if any extra data to be written
+ for data in cleanup_data:
+ # In case different extra I/O needs different inputs
+ title = list(data.keys())[0]
+ if title in ['pd_dataframe', 'pd_series']:
+ rootname, key, value = data[title]
+ _create_pandas_dataset(fname, rootname, key, title, value)
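
[Editor's note: a sketch of the new overwrite='update' and slash handling,
assuming the externals package re-exports these helpers; the filename is a
placeholder:

    from mne.externals.h5io import read_hdf5, write_hdf5

    write_hdf5('store.h5', dict(a=1), title='first', overwrite=True)
    # append a second title to the same file instead of clobbering it
    write_hdf5('store.h5', {'b/c': 2}, title='second', overwrite='update',
               slash='replace')  # '/' stored as the token '{FWDSLASH}'
    data = read_hdf5('store.h5', title='second', slash='replace')
]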
+
+
+def _triage_write(key, value, root, comp_kw, where,
+ cleanup_data=[], slash='error', title=None):
+ if key != title and '/' in key:
+ if slash == 'error':
+ raise ValueError('Found a key with "/"; this is '
+ 'not allowed when slash == "error"')
+ elif slash == 'replace':
+ # Auto-replace keys with proper values
+ for key_spec, val_spec in special_chars.items():
+ key = key.replace(val_spec, key_spec)
+ else:
+ raise ValueError("slash must be one of ['error', 'replace'")
-def _triage_write(key, value, root, comp_kw, where):
if isinstance(value, dict):
sub_root = _create_titled_group(root, key, 'dict')
for key, sub_value in value.items():
if not isinstance(key, string_types):
raise TypeError('All dict keys must be strings')
- _triage_write('key_{0}'.format(key), sub_value, sub_root, comp_kw,
- where + '["%s"]' % key)
+ _triage_write(
+ 'key_{0}'.format(key), sub_value, sub_root, comp_kw,
+ where + '["%s"]' % key, cleanup_data=cleanup_data, slash=slash)
elif isinstance(value, (list, tuple)):
title = 'list' if isinstance(value, list) else 'tuple'
sub_root = _create_titled_group(root, key, title)
for vi, sub_value in enumerate(value):
- _triage_write('idx_{0}'.format(vi), sub_value, sub_root, comp_kw,
- where + '[%s]' % vi)
+ _triage_write(
+ 'idx_{0}'.format(vi), sub_value, sub_root, comp_kw,
+ where + '[%s]' % vi, cleanup_data=cleanup_data, slash=slash)
elif isinstance(value, type(None)):
_create_titled_dataset(root, key, 'None', [False])
elif isinstance(value, (int, float)):
@@ -102,6 +150,8 @@ def _triage_write(key, value, root, comp_kw, where):
else: # isinstance(value, float):
title = 'float'
_create_titled_dataset(root, key, title, np.atleast_1d(value))
+ elif isinstance(value, np.bool_):
+ _create_titled_dataset(root, key, 'np_bool_', np.atleast_1d(value))
elif isinstance(value, string_types):
if isinstance(value, text_type): # unicode
value = np.fromstring(value.encode('utf-8'), np.uint8)
@@ -115,19 +165,51 @@ def _triage_write(key, value, root, comp_kw, where):
elif sparse is not None and isinstance(value, sparse.csc_matrix):
sub_root = _create_titled_group(root, key, 'csc_matrix')
_triage_write('data', value.data, sub_root, comp_kw,
- where + '.csc_matrix_data')
+ where + '.csc_matrix_data', cleanup_data=cleanup_data,
+ slash=slash)
_triage_write('indices', value.indices, sub_root, comp_kw,
- where + '.csc_matrix_indices')
+ where + '.csc_matrix_indices', cleanup_data=cleanup_data,
+ slash=slash)
_triage_write('indptr', value.indptr, sub_root, comp_kw,
- where + '.csc_matrix_indptr')
+ where + '.csc_matrix_indptr', cleanup_data=cleanup_data,
+ slash=slash)
+ elif sparse is not None and isinstance(value, sparse.csr_matrix):
+ sub_root = _create_titled_group(root, key, 'csr_matrix')
+ _triage_write('data', value.data, sub_root, comp_kw,
+ where + '.csr_matrix_data', cleanup_data=cleanup_data,
+ slash=slash)
+ _triage_write('indices', value.indices, sub_root, comp_kw,
+ where + '.csr_matrix_indices', cleanup_data=cleanup_data,
+ slash=slash)
+ _triage_write('indptr', value.indptr, sub_root, comp_kw,
+ where + '.csr_matrix_indptr', cleanup_data=cleanup_data,
+ slash=slash)
+ _triage_write('shape', value.shape, sub_root, comp_kw,
+ where + '.csr_matrix_shape', cleanup_data=cleanup_data,
+ slash=slash)
else:
- raise TypeError('unsupported type %s (in %s)' % (type(value), where))
-
+ try:
+ from pandas import DataFrame, Series
+ except ImportError:
+ pass
+ else:
+ if isinstance(value, (DataFrame, Series)):
+ if isinstance(value, DataFrame):
+ title = 'pd_dataframe'
+ else:
+ title = 'pd_series'
+ rootname = root.name
+ cleanup_data.append({title: (rootname, key, value)})
+ return
+
+ err_str = 'unsupported type %s (in %s)' % (type(value), where)
+ raise TypeError(err_str)
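
[Editor's note: DataFrames and Series take a different path from the other
types: they are queued in cleanup_data and written after the h5py file is
closed, via pandas' own to_hdf, then tagged with TITLE='pd_dataframe' or
'pd_series' so _triage_read can dispatch on it. A sketch, assuming pandas is
installed:

    import pandas as pd

    write_hdf5('store.h5', dict(table=pd.DataFrame({'x': [1, 2]})),
               title='h5io', overwrite=True)
]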
##############################################################################
# READING
-def read_hdf5(fname, title='h5io'):
+
+def read_hdf5(fname, title='h5io', slash='ignore'):
"""Read python object from HDF5 format using h5py
Parameters
@@ -137,6 +219,10 @@ def read_hdf5(fname, title='h5io'):
title : str
The top-level directory name to use. Typically it is useful to make
this your package name, e.g. ``'mnepython'``.
+ slash : 'ignore' | 'replace'
+ Whether to replace the string {FWDSLASH} with the value /. This does
+ not apply to the top level name (title). If 'ignore', nothing will be
+ replaced.
Returns
-------
@@ -149,13 +235,18 @@ def read_hdf5(fname, title='h5io'):
if not isinstance(title, string_types):
raise ValueError('title must be a string')
with h5py.File(fname, mode='r') as fid:
- if title not in fid.keys():
+ if title not in fid:
raise ValueError('no "%s" data found' % title)
- data = _triage_read(fid[title])
+ if isinstance(fid[title], h5py.Group):
+ if 'TITLE' not in fid[title].attrs:
+ raise ValueError('no "%s" data found' % title)
+ data = _triage_read(fid[title], slash=slash)
return data
-def _triage_read(node):
+def _triage_read(node, slash='ignore'):
+ if slash not in ['ignore', 'replace']:
+ raise ValueError("slash must be one of 'replace', 'ignore'")
h5py = _check_h5py()
type_str = node.attrs['TITLE']
if isinstance(type_str, bytes):
@@ -164,7 +255,10 @@ def _triage_read(node):
if type_str == 'dict':
data = dict()
for key, subnode in node.items():
- data[key[4:]] = _triage_read(subnode)
+ if slash == 'replace':
+ for key_spec, val_spec in special_chars.items():
+ key = key.replace(key_spec, val_spec)
+ data[key[4:]] = _triage_read(subnode, slash=slash)
elif type_str in ['list', 'tuple']:
data = list()
ii = 0
@@ -172,7 +266,7 @@ def _triage_read(node):
subnode = node.get('idx_{0}'.format(ii), None)
if subnode is None:
break
- data.append(_triage_read(subnode))
+ data.append(_triage_read(subnode, slash=slash))
ii += 1
assert len(data) == ii
data = tuple(data) if type_str == 'tuple' else data
@@ -180,9 +274,25 @@ def _triage_read(node):
elif type_str == 'csc_matrix':
if sparse is None:
raise RuntimeError('scipy must be installed to read this data')
- data = sparse.csc_matrix((_triage_read(node['data']),
- _triage_read(node['indices']),
- _triage_read(node['indptr'])))
+ data = sparse.csc_matrix((_triage_read(node['data'], slash=slash),
+ _triage_read(node['indices'],
+ slash=slash),
+ _triage_read(node['indptr'],
+ slash=slash)))
+ elif type_str == 'csr_matrix':
+ if sparse is None:
+ raise RuntimeError('scipy must be installed to read this data')
+ data = sparse.csr_matrix((_triage_read(node['data'], slash=slash),
+ _triage_read(node['indices'],
+ slash=slash),
+ _triage_read(node['indptr'],
+ slash=slash)),
+ shape=_triage_read(node['shape']))
+ elif type_str in ['pd_dataframe', 'pd_series']:
+ from pandas import read_hdf
+ rootname = node.name
+ filename = node.file.filename
+ data = read_hdf(filename, rootname, mode='r')
else:
raise NotImplementedError('Unknown group type: {0}'
''.format(type_str))
@@ -191,6 +301,8 @@ def _triage_read(node):
elif type_str in ('int', 'float'):
cast = int if type_str == 'int' else float
data = cast(np.array(node)[0])
+ elif type_str == 'np_bool_':
+ data = np.bool_(np.array(node)[0])
elif type_str in ('unicode', 'ascii', 'str'): # 'str' for backward compat
decoder = 'utf-8' if type_str == 'unicode' else 'ASCII'
cast = text_type if type_str == 'unicode' else str
@@ -231,6 +343,12 @@ def object_diff(a, b, pre=''):
diffs : str
A string representation of the differences.
"""
+
+ try:
+ from pandas import DataFrame, Series
+ except ImportError:
+ DataFrame = Series = type(None)
+
out = ''
if type(a) != type(b):
out += pre + ' type mismatch (%s, %s)\n' % (type(a), type(b))
@@ -270,6 +388,16 @@ def object_diff(a, b, pre=''):
if c.nnz > 0:
out += pre + (' sparse matrix a and b differ on %s '
'elements' % c.nnz)
+ elif isinstance(a, (DataFrame, Series)):
+ if b.shape != a.shape:
+ out += pre + (' pandas values a and b shape mismatch '
+ '(%s vs %s)' % (a.shape, b.shape))
+ else:
+ c = a.values - b.values
+ nzeros = np.sum(c != 0)
+ if nzeros > 0:
+ out += pre + (' pandas values a and b differ on %s '
+ 'elements' % nzeros)
else:
raise RuntimeError(pre + ': unsupported type %s (%s)' % (type(a), a))
return out
diff --git a/mne/externals/tempita/_looper.py b/mne/externals/tempita/_looper.py
index 4413a5b..4b480b4 100644
--- a/mne/externals/tempita/_looper.py
+++ b/mne/externals/tempita/_looper.py
@@ -7,9 +7,9 @@ These can be awkward to manage in a normal Python loop, but using the
looper you can get a better sense of the context. Use like::
>>> for loop, item in looper(['a', 'b', 'c']):
- ... print loop.number, item
+ ... print("%d %s" % (loop.number, item))
... if not loop.last:
- ... print '---'
+ ... print('---')
1 a
---
2 b
diff --git a/mne/filter.py b/mne/filter.py
index 2401f85..cbd0717 100644
--- a/mne/filter.py
+++ b/mne/filter.py
@@ -1,19 +1,24 @@
"""IIR and FIR filtering functions"""
from copy import deepcopy
+from functools import partial
import numpy as np
-from scipy.fftpack import fft, ifftshift, fftfreq
+from scipy.fftpack import fft, ifftshift, fftfreq, ifft
from .cuda import (setup_cuda_fft_multiply_repeated, fft_multiply_repeated,
setup_cuda_fft_resample, fft_resample, _smart_pad)
from .externals.six import string_types, integer_types
-from .fixes import get_firwin2, get_filtfilt
+from .fixes import get_sosfiltfilt
from .parallel import parallel_func, check_n_jobs
from .time_frequency.multitaper import dpss_windows, _mt_spectra
from .utils import logger, verbose, sum_squared, check_version, warn
+# These values are *double* what is given in Ifeachor and Jervis.
+_length_factors = dict(hann=6.2, hamming=6.6, blackman=11.0)
+
+
def is_power2(num):
"""Test if number is a power of 2
@@ -38,121 +43,177 @@ def is_power2(num):
return num != 0 and ((num & (num - 1)) == 0)
-def _overlap_add_filter(x, h, n_fft=None, zero_phase=True, picks=None,
- n_jobs=1):
- """ Filter using overlap-add FFTs.
+def next_fast_len(target):
+ """
+ Find the next fast size of input data to `fft`, for zero-padding, etc.
- Filters the signal x using a filter with the impulse response h.
- If zero_phase==True, the the filter is applied twice, once in the forward
- direction and once backward , resulting in a zero-phase filter.
+ SciPy's FFTPACK has efficient functions for radix {2, 3, 4, 5}, so this
+ returns the next composite of the prime factors 2, 3, and 5 which is
+ greater than or equal to `target`. (These are also known as 5-smooth
+ numbers, regular numbers, or Hamming numbers.)
+
+ Parameters
+ ----------
+ target : int
+ Length to start searching from. Must be a positive integer.
+
+ Returns
+ -------
+ out : int
+ The first 5-smooth number greater than or equal to `target`.
+
+ Notes
+ -----
+ Copied from SciPy with minor modifications.
+ """
+ from bisect import bisect_left
+ hams = (8, 9, 10, 12, 15, 16, 18, 20, 24, 25, 27, 30, 32, 36, 40, 45, 48,
+ 50, 54, 60, 64, 72, 75, 80, 81, 90, 96, 100, 108, 120, 125, 128,
+ 135, 144, 150, 160, 162, 180, 192, 200, 216, 225, 240, 243, 250,
+ 256, 270, 288, 300, 320, 324, 360, 375, 384, 400, 405, 432, 450,
+ 480, 486, 500, 512, 540, 576, 600, 625, 640, 648, 675, 720, 729,
+ 750, 768, 800, 810, 864, 900, 960, 972, 1000, 1024, 1080, 1125,
+ 1152, 1200, 1215, 1250, 1280, 1296, 1350, 1440, 1458, 1500, 1536,
+ 1600, 1620, 1728, 1800, 1875, 1920, 1944, 2000, 2025, 2048, 2160,
+ 2187, 2250, 2304, 2400, 2430, 2500, 2560, 2592, 2700, 2880, 2916,
+ 3000, 3072, 3125, 3200, 3240, 3375, 3456, 3600, 3645, 3750, 3840,
+ 3888, 4000, 4050, 4096, 4320, 4374, 4500, 4608, 4800, 4860, 5000,
+ 5120, 5184, 5400, 5625, 5760, 5832, 6000, 6075, 6144, 6250, 6400,
+ 6480, 6561, 6750, 6912, 7200, 7290, 7500, 7680, 7776, 8000, 8100,
+ 8192, 8640, 8748, 9000, 9216, 9375, 9600, 9720, 10000)
+
+ if target <= 6:
+ return target
+
+ # Quickly check if it's already a power of 2
+ if not (target & (target - 1)):
+ return target
+
+ # Get result quickly for small sizes, since FFT itself is similarly fast.
+ if target <= hams[-1]:
+ return hams[bisect_left(hams, target)]
+
+ match = float('inf') # Anything found will be smaller
+ p5 = 1
+ while p5 < target:
+ p35 = p5
+ while p35 < target:
+ # Ceiling integer division, avoiding conversion to float
+ # (quotient = ceil(target / p35))
+ quotient = -(-target // p35)
+
+ p2 = 2 ** (quotient - 1).bit_length()
+
+ N = p2 * p35
+ if N == target:
+ return N
+ elif N < match:
+ match = N
+ p35 *= 3
+ if p35 == target:
+ return p35
+ if p35 < match:
+ match = p35
+ p5 *= 5
+ if p5 == target:
+ return p5
+ if p5 < match:
+ match = p5
+ return match
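
[Editor's note: for targets up to 10000 the lookup table answers directly;
larger targets fall through to the factor search. Two illustrative values,
easy to verify by hand:

    next_fast_len(1021)  # -> 1024 (= 2 ** 10)
    next_fast_len(7000)  # -> 7200 (= 2 ** 5 * 3 ** 2 * 5 ** 2)
]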
+
+
+def _overlap_add_filter(x, h, n_fft=None, phase='zero', picks=None,
+ n_jobs=1):
+ """Filter the signal x using h with overlap-add FFTs.
.. warning:: This operates on the data in-place.
Parameters
----------
- x : 2d array
- Signal to filter.
+ x : array, shape (n_signals, n_times)
+ Signals to filter.
h : 1d array
- Filter impulse response (FIR filter coefficients).
+ Filter impulse response (FIR filter coefficients). Must be odd length
+ if phase == 'linear'.
n_fft : int
Length of the FFT. If None, the best size is determined automatically.
- zero_phase : bool
- If True: the filter is applied in forward and backward direction,
- resulting in a zero-phase filter.
+ phase : str
+ If 'zero', the delay for the filter is compensated (and it must be
+ an odd-length symmetric filter). If 'linear', the response is
+ uncompensated. If 'zero-double', the filter is applied in the
+ forward and reverse directions.
picks : array-like of int | None
- Indices to filter. If None all indices will be filtered.
+ Indices of channels to filter. If None all channels will be
+ filtered. Only supported for 2D (n_channels, n_times) and 3D
+ (n_epochs, n_channels, n_times) data.
n_jobs : int | str
Number of jobs to run in parallel. Can be 'cuda' if scikits.cuda
is installed properly and CUDA is initialized.
Returns
-------
- xf : 2d array
+ xf : array, shape (n_signals, n_times)
x filtered.
"""
- if picks is None:
- picks = np.arange(x.shape[0])
-
# Extend the signal by mirroring the edges to reduce transient filter
# response
- n_h = len(h)
- if n_h == 1:
- return x * h ** 2 if zero_phase else x * h
- if x.shape[1] < len(h):
- raise ValueError('Overlap add should only be used for signals '
- 'longer than the requested filter')
- n_edge = max(min(n_h, x.shape[1]) - 1, 0)
-
+ _check_zero_phase_length(len(h), phase)
+ if len(h) == 1:
+ return x * h ** 2 if phase == 'zero-double' else x * h
+ n_edge = max(min(len(h), x.shape[1]) - 1, 0)
+ logger.debug('Smart-padding with: %s samples on each edge' % n_edge)
n_x = x.shape[1] + 2 * n_edge
+ if phase == 'zero-double':
+ h = np.convolve(h, h[::-1])
+
# Determine FFT length to use
+ min_fft = 2 * len(h) - 1
if n_fft is None:
- min_fft = 2 * n_h - 1
max_fft = n_x
if max_fft >= min_fft:
- n_tot = 2 * n_x if zero_phase else n_x
-
# cost function based on number of multiplications
N = 2 ** np.arange(np.ceil(np.log2(min_fft)),
np.ceil(np.log2(max_fft)) + 1, dtype=int)
- # if doing zero-phase, h needs to be thought of as ~ twice as long
- n_h_cost = 2 * n_h - 1 if zero_phase else n_h
- cost = (np.ceil(n_tot / (N - n_h_cost + 1).astype(np.float)) *
+ cost = (np.ceil(n_x / (N - len(h) + 1).astype(np.float)) *
N * (np.log2(N) + 1))
# add a heuristic term to prevent too-long FFT's which are slow
# (not predicted by mult. cost alone, 4e-5 exp. determined)
- cost += 4e-5 * N * n_tot
+ cost += 4e-5 * N * n_x
n_fft = N[np.argmin(cost)]
else:
# Use only a single block
- n_fft = 2 ** int(np.ceil(np.log2(n_x + n_h - 1)))
-
- if zero_phase and n_fft <= 2 * n_h - 1:
- raise ValueError("n_fft is too short, has to be at least "
- "2 * len(h) - 1 if zero_phase == True")
- elif not zero_phase and n_fft <= n_h:
- raise ValueError("n_fft is too short, has to be at least "
- "len(h) if zero_phase == False")
-
- if not is_power2(n_fft):
- warn("FFT length is not a power of 2. Can be slower.")
+ n_fft = next_fast_len(min_fft)
+ logger.debug('FFT block length: %s' % n_fft)
+ if n_fft < min_fft:
+ raise ValueError('n_fft is too short, has to be at least '
+ '2 * len(h) - 1 (%s), got %s' % (min_fft, n_fft))
# Filter in frequency domain
- h_fft = fft(np.concatenate([h, np.zeros(n_fft - n_h, dtype=h.dtype)]))
- assert(len(h_fft) == n_fft)
-
- if zero_phase:
- """Zero-phase filtering is now done in one pass by taking the squared
- magnitude of h_fft. This gives equivalent results to the old two-pass
- method but theoretically doubles the speed for long fft lengths. To
- compensate for this, overlapping must be done both before and after
- each segment. When zero_phase == False it only needs to be done after.
- """
- h_fft = (h_fft * h_fft.conj()).real
- # equivalent to convolving h(t) and h(-t) in the time domain
+ h_fft = fft(np.concatenate([h, np.zeros(n_fft - len(h), dtype=h.dtype)]))
# Figure out if we should use CUDA
n_jobs, cuda_dict, h_fft = setup_cuda_fft_multiply_repeated(n_jobs, h_fft)
# Process each row separately
+ picks = np.arange(len(x)) if picks is None else picks
if n_jobs == 1:
for p in picks:
- x[p] = _1d_overlap_filter(x[p], h_fft, n_h, n_edge, zero_phase,
+ x[p] = _1d_overlap_filter(x[p], h_fft, len(h), n_edge, phase,
cuda_dict)
else:
parallel, p_fun, _ = parallel_func(_1d_overlap_filter, n_jobs)
- data_new = parallel(p_fun(x[p], h_fft, n_h, n_edge, zero_phase,
- cuda_dict)
- for p in picks)
+ data_new = parallel(p_fun(x[p], h_fft, len(h), n_edge, phase,
+ cuda_dict) for p in picks)
for pp, p in enumerate(picks):
x[p] = data_new[pp]
return x
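
The core overlap-add idea above reduces to a few lines of NumPy. A minimal single-pass sketch with a hypothetical helper name (no edge padding, CUDA, or delay compensation, unlike the full implementation)::

    import numpy as np
    from numpy.fft import fft, ifft

    def overlap_add(x, h, n_fft):
        # Minimal single-pass overlap-add FIR filtering (sketch).
        n_h = len(h)
        n_seg = n_fft - n_h + 1          # input samples consumed per block
        h_fft = fft(h, n_fft)            # filter spectrum, zero-padded
        y = np.zeros(len(x) + n_h - 1)   # full linear-convolution length
        for start in range(0, len(x), n_seg):
            block = np.real(ifft(fft(x[start:start + n_seg], n_fft) * h_fft))
            stop = min(start + n_fft, len(y))
            y[start:stop] += block[:stop - start]
        return y[:len(x)]                # uncompensated linear-phase output

    # agrees with direct convolution:
    rng = np.random.RandomState(0)
    x, h = rng.randn(1000), rng.randn(31)
    assert np.allclose(overlap_add(x, h, 256), np.convolve(x, h)[:1000])
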
-def _1d_overlap_filter(x, h_fft, n_h, n_edge, zero_phase, cuda_dict):
+def _1d_overlap_filter(x, h_fft, n_h, n_edge, phase, cuda_dict):
"""Do one-dimensional overlap-add FFT FIR filtering"""
# pad to reduce ringing
if cuda_dict['use_cuda']:
@@ -163,21 +224,9 @@ def _1d_overlap_filter(x, h_fft, n_h, n_edge, zero_phase, cuda_dict):
n_x = len(x_ext)
x_filtered = np.zeros_like(x_ext)
- if zero_phase:
- # Segment length for signal x (convolving twice)
- n_seg = n_fft - 2 * (n_h - 1) - 1
-
- # Number of segments (including fractional segments)
- n_segments = int(np.ceil(n_x / float(n_seg)))
-
- # padding parameters to ensure filtering is done properly
- pre_pad = n_h - 1
- post_pad = n_fft - (n_h - 1)
- else:
- n_seg = n_fft - n_h + 1
- n_segments = int(np.ceil(n_x / float(n_seg)))
- pre_pad = 0
- post_pad = n_fft
+ n_seg = n_fft - n_h + 1
+ n_segments = int(np.ceil(n_x / float(n_seg)))
+ shift = ((n_h - 1) // 2 if phase.startswith('zero') else 0) + n_edge
# Now the actual filtering step is identical for zero-phase (filtfilt-like)
# or single-pass
@@ -185,21 +234,18 @@ def _1d_overlap_filter(x, h_fft, n_h, n_edge, zero_phase, cuda_dict):
start = seg_idx * n_seg
stop = (seg_idx + 1) * n_seg
seg = x_ext[start:stop]
- seg = np.concatenate([np.zeros(pre_pad), seg,
- np.zeros(post_pad - len(seg))])
+ seg = np.concatenate([seg, np.zeros(n_fft - len(seg))])
prod = fft_multiply_repeated(h_fft, seg, cuda_dict)
- start_filt = max(0, start - pre_pad)
- stop_filt = min(start - pre_pad + n_fft, n_x)
- start_prod = max(0, pre_pad - start)
+ start_filt = max(0, start - shift)
+ stop_filt = min(start - shift + n_fft, n_x)
+ start_prod = max(0, shift - start)
stop_prod = start_prod + stop_filt - start_filt
x_filtered[start_filt:stop_filt] += prod[start_prod:stop_prod]
- # Remove mirrored edges that we added and cast
- if n_edge > 0:
- x_filtered = x_filtered[n_edge:-n_edge]
- x_filtered = x_filtered.astype(x.dtype)
+ # Remove mirrored edges that we added and cast (n_edge can be zero)
+ x_filtered = x_filtered[:n_x - 2 * n_edge].astype(x.dtype)
return x_filtered
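
The ``shift`` logic above is what makes ``phase='zero'`` work: a symmetric FIR delays its input by ``(n_h - 1) // 2`` samples, and shifting the output back removes that delay exactly. A standalone illustration (``scipy.signal.firwin`` is used here only to obtain a symmetric filter)::

    import numpy as np
    from scipy.signal import firwin

    h = firwin(101, 0.2)                  # odd-length, symmetric low-pass
    x = np.zeros(1000)
    x[500] = 1.                           # unit impulse
    y = np.convolve(x, h, mode='full')    # delayed by (len(h) - 1) // 2
    shift = (len(h) - 1) // 2
    y_zero = y[shift:shift + len(x)]      # delay-compensated output
    assert np.argmax(np.abs(y_zero)) == 500   # peak stays at the impulse
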
@@ -208,7 +254,6 @@ def _filter_attenuation(h, freq, gain):
from scipy.signal import freqz
_, filt_resp = freqz(h.ravel(), worN=np.pi * freq)
filt_resp = np.abs(filt_resp) # use amplitude response
- filt_resp /= np.max(filt_resp)
filt_resp[np.where(gain == 1)] = 0
idx = np.argmax(filt_resp)
att_db = -20 * np.log10(filt_resp[idx])
@@ -216,23 +261,6 @@ def _filter_attenuation(h, freq, gain):
return att_db, att_freq
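
``_filter_attenuation`` evaluates the designed filter at the gain breakpoints and reports the worst stop-band leakage in dB. A rough standalone version of the same measurement, with assumed (not source-derived) design values::

    import numpy as np
    from scipy.signal import firwin2, freqz

    freq = np.array([0., 0.2, 0.3, 1.])    # normalized, 1.0 == Nyquist
    gain = np.array([1., 1., 0., 0.])      # low-pass with a transition band
    h = firwin2(101, freq, gain, window='hamming')
    _, resp = freqz(h, worN=np.pi * freq)  # response at the breakpoints
    resp = np.abs(resp)
    resp[gain == 1] = 0                    # ignore pass-band points
    att_db = -20 * np.log10(resp.max())    # worst stop-band attenuation
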
-def _1d_fftmult_ext(x, B, extend_x, cuda_dict):
- """Helper to parallelize FFT FIR, with extension if necessary"""
- # extend, if necessary
- if extend_x is True:
- x = np.r_[x, x[-1]]
-
- # do Fourier transforms
- xf = fft_multiply_repeated(B, x, cuda_dict)
-
- # put back to original size and type
- if extend_x is True:
- xf = xf[:-1]
-
- xf = xf.astype(x.dtype)
- return xf
-
-
def _prep_for_filtering(x, copy, picks=None):
"""Set up array as 2D for filtering ease"""
if x.dtype != np.float64:
@@ -253,17 +281,21 @@ def _prep_for_filtering(x, copy, picks=None):
elif len(orig_shape) > 3:
raise ValueError('picks argument is not supported for data with more'
' than three dimensions')
+ picks = np.array(picks, int).ravel()
+ if not all(0 <= pick < x.shape[0] for pick in picks) or \
+ len(set(picks)) != len(picks):
+ raise ValueError('bad argument for "picks": %s' % (picks,))
return x, orig_shape, picks
-def _filter(x, Fs, freq, gain, filter_length='10s', picks=None, n_jobs=1,
- copy=True):
+def _fir_filter(x, Fs, freq, gain, filter_length, picks=None, n_jobs=1,
+ copy=True, phase='zero', fir_window='hamming'):
"""Filter signal using gain control points in the frequency domain.
- The filter impulse response is constructed from a Hamming window (window
+ The filter impulse response is constructed from the given window (window
used in "firwin2" function) to avoid ripples in the frequency response
- (windowing is a smoothing in frequency domain). The filter is zero-phase.
+ (windowing is a smoothing in frequency domain).
If x is multi-dimensional, this operates along the last dimension.
@@ -277,28 +309,33 @@ def _filter(x, Fs, freq, gain, filter_length='10s', picks=None, n_jobs=1,
Frequency sampling points in Hz.
gain : 1d array
Filter gain at frequency sampling points.
- filter_length : str (Default: '10s') | int | None
- Length of the filter to use. If None or "len(x) < filter_length",
- the filter length used is len(x). Otherwise, if int, overlap-add
- filtering with a filter of the specified length in samples) is
- used (faster for long signals). If str, a human-readable time in
- units of "s" or "ms" (e.g., "10s" or "5500ms") will be converted
- to the shortest power-of-two length at least that duration.
+ filter_length : int
+ Length of the filter to use. Must be odd length if phase == "zero".
picks : array-like of int | None
- Indices to filter. If None all indices will be filtered.
+ Indices of channels to filter. If None all channels will be
+ filtered. Only supported for 2D (n_channels, n_times) and 3D
+ (n_epochs, n_channels, n_times) data.
n_jobs : int | str
Number of jobs to run in parallel. Can be 'cuda' if scikits.cuda
is installed properly and CUDA is initialized.
copy : bool
If True, a copy of x, filtered, is returned. Otherwise, it operates
on x in place.
+ phase : str
+ If 'zero', the delay for the filter is compensated (and it must be
+ an odd-length symmetric filter). If 'linear', the response is
+ uncompensated. If 'zero-double', the filter is applied in the
+ forward and reverse directions.
+ fir_window : str
+ The window to use in FIR design, can be "hamming" (default),
+ "hann", or "blackman".
Returns
-------
xf : array
x filtered.
"""
- firwin2 = get_firwin2()
+ from scipy.signal import firwin2
# set up array for filtering, reshape to 2D, operate on last axis
x, orig_shape, picks = _prep_for_filtering(x, copy, picks)
@@ -307,118 +344,134 @@ def _filter(x, Fs, freq, gain, filter_length='10s', picks=None, n_jobs=1,
# normalize frequencies
freq = np.array(freq) / (Fs / 2.)
+ if freq[0] != 0 or freq[-1] != 1:
+ raise ValueError('freq must start at 0 and end at Nyquist (%s), got %s'
+ % (Fs / 2., freq))
gain = np.array(gain)
- filter_length = _get_filter_length(filter_length, Fs, len_x=x.shape[1])
n_jobs = check_n_jobs(n_jobs, allow_cuda=True)
- if filter_length is None or x.shape[1] <= filter_length:
- # Use direct FFT filtering for short signals
-
- Norig = x.shape[1]
-
- extend_x = False
- if (gain[-1] == 0.0 and Norig % 2 == 1) \
- or (gain[-1] == 1.0 and Norig % 2 != 1):
- # Gain at Nyquist freq: 1: make x EVEN, 0: make x ODD
- extend_x = True
-
- N = x.shape[1] + (extend_x is True)
-
- h = firwin2(N, freq, gain)[np.newaxis, :]
-
- att_db, att_freq = _filter_attenuation(h, freq, gain)
- if att_db < min_att_db:
- att_freq *= Fs / 2
- warn('Attenuation at stop frequency %0.1fHz is only %0.1fdB.'
- % (att_freq, att_db))
-
- # Make zero-phase filter function
- B = np.abs(fft(h)).ravel()
-
- # Figure out if we should use CUDA
- n_jobs, cuda_dict, B = setup_cuda_fft_multiply_repeated(n_jobs, B)
+ # Use overlap-add filter with a fixed length
+ N = _check_zero_phase_length(filter_length, phase, gain[-1])
+ # construct symmetric (linear phase) filter
+ h = firwin2(N, freq, gain, window=fir_window)
+ att_db, att_freq = _filter_attenuation(h, freq, gain)
+ if phase == 'zero-double':
+ att_db += 6
+ if att_db < min_att_db:
+ att_freq *= Fs / 2.
+ warn('Attenuation at stop frequency %0.1fHz is only %0.1fdB. '
+ 'Increase filter_length for higher attenuation.'
+ % (att_freq, att_db))
+ x = _overlap_add_filter(x, h, phase=phase, picks=picks, n_jobs=n_jobs)
+ x.shape = orig_shape
+ return x
- if n_jobs == 1:
- for p in picks:
- x[p] = _1d_fftmult_ext(x[p], B, extend_x, cuda_dict)
- else:
- parallel, p_fun, _ = parallel_func(_1d_fftmult_ext, n_jobs)
- data_new = parallel(p_fun(x[p], B, extend_x, cuda_dict)
- for p in picks)
- for pp, p in enumerate(picks):
- x[p] = data_new[pp]
- else:
- # Use overlap-add filter with a fixed length
- N = filter_length
- if (gain[-1] == 0.0 and N % 2 == 1) \
- or (gain[-1] == 1.0 and N % 2 != 1):
- # Gain at Nyquist freq: 1: make N EVEN, 0: make N ODD
+def _check_zero_phase_length(N, phase, gain_nyq=0):
+ N = int(N)
+ if N % 2 == 0:
+ if phase == 'zero':
+ raise RuntimeError('filter_length must be odd if phase="zero", '
+ 'got %s' % N)
+ elif phase == 'zero-double' and gain_nyq == 1:
N += 1
+ return N
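
The odd-length requirement exists because a symmetric (type I) FIR of length N has a group delay of (N - 1) / 2 samples; only for odd N is that an integer that a pure sample shift can remove::

    >>> (101 - 1) / 2.  # odd length: integer delay, removable by shifting
    50.0
    >>> (100 - 1) / 2.  # even length: half-sample delay, not a pure shift
    49.5
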
- # construct filter with gain resulting from forward-backward filtering
- h = firwin2(N, freq, gain, window='hann')
-
- att_db, att_freq = _filter_attenuation(h, freq, gain)
- att_db += 6 # the filter is applied twice (zero phase)
- if att_db < min_att_db:
- att_freq *= Fs / 2
- warn('Attenuation at stop frequency %0.1fHz is only %0.1fdB. '
- 'Increase filter_length for higher attenuation.'
- % (att_freq, att_db))
- # reconstruct filter, this time with appropriate gain for fwd-bkwd
- gain = np.sqrt(gain)
- h = firwin2(N, freq, gain, window='hann')
- x = _overlap_add_filter(x, h, zero_phase=True, picks=picks,
- n_jobs=n_jobs)
-
- x.shape = orig_shape
- return x
-
-
-def _check_coefficients(b, a):
+def _check_coefficients(system):
"""Check for filter stability"""
- from scipy.signal import tf2zpk
- z, p, k = tf2zpk(b, a)
+ if isinstance(system, tuple):
+ from scipy.signal import tf2zpk
+ z, p, k = tf2zpk(*system)
+ else: # sos
+ from scipy.signal import sos2zpk
+ z, p, k = sos2zpk(system)
if np.any(np.abs(p) > 1.0):
raise RuntimeError('Filter poles outside unit circle, filter will be '
'unstable. Consider using different filter '
'coefficients.')
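
As a concrete case of what ``_check_coefficients`` rejects, a toy transfer function with a pole outside the unit circle::

    import numpy as np
    from scipy.signal import tf2zpk

    b, a = [1.], [1., -1.1]         # y[n] = x[n] + 1.1 * y[n-1] -> diverges
    z, p, k = tf2zpk(b, a)
    assert np.any(np.abs(p) > 1.0)  # _check_coefficients would raise here
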
-def _filtfilt(x, b, a, padlen, picks, n_jobs, copy):
+def _filtfilt(x, iir_params, picks, n_jobs, copy):
"""Helper to more easily call filtfilt"""
# set up array for filtering, reshape to 2D, operate on last axis
- filtfilt = get_filtfilt()
+ from scipy.signal import filtfilt
+ padlen = min(iir_params['padlen'], len(x))
n_jobs = check_n_jobs(n_jobs)
x, orig_shape, picks = _prep_for_filtering(x, copy, picks)
- _check_coefficients(b, a)
+ if 'sos' in iir_params:
+ sosfiltfilt = get_sosfiltfilt()
+ fun = partial(sosfiltfilt, sos=iir_params['sos'], padlen=padlen)
+ _check_coefficients(iir_params['sos'])
+ else:
+ fun = partial(filtfilt, b=iir_params['b'], a=iir_params['a'],
+ padlen=padlen)
+ _check_coefficients((iir_params['b'], iir_params['a']))
if n_jobs == 1:
for p in picks:
- x[p] = filtfilt(b, a, x[p], padlen=padlen)
+ x[p] = fun(x=x[p])
else:
- parallel, p_fun, _ = parallel_func(filtfilt, n_jobs)
- data_new = parallel(p_fun(b, a, x[p], padlen=padlen)
- for p in picks)
+ parallel, p_fun, _ = parallel_func(fun, n_jobs)
+ data_new = parallel(p_fun(x=x[p]) for p in picks)
for pp, p in enumerate(picks):
x[p] = data_new[pp]
x.shape = orig_shape
return x
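
The ``partial`` binding above lets one loop serve both 'ba' and 'sos' systems. A standalone sketch of the same dispatch pattern with plain SciPy::

    from functools import partial
    import numpy as np
    from scipy.signal import butter, filtfilt

    b, a = butter(4, 0.2)                         # 4th-order low-pass
    fun = partial(filtfilt, b=b, a=a, padlen=100)
    x = np.random.RandomState(0).randn(3, 1000)
    for p in range(len(x)):                       # same loop shape as above
        x[p] = fun(x=x[p])
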
-def _estimate_ringing_samples(b, a):
- """Helper function for determining IIR padding"""
- from scipy.signal import lfilter
- x = np.zeros(1000)
+def estimate_ringing_samples(system, max_try=100000):
+ """Estimate filter ringing
+
+ Parameters
+ ----------
+ system : tuple | ndarray
+ A tuple of (b, a) or ndarray of second-order sections coefficients.
+ max_try : int
+ Approximate maximum number of samples to try.
+ This will be rounded up to a multiple of 1000.
+
+ Returns
+ -------
+ n : int
+ The approximate ringing length, in samples.
+ """
+ from scipy import signal
+ if isinstance(system, tuple): # TF
+ kind = 'ba'
+ b, a = system
+ zi = [0.] * (len(a) - 1)
+ else:
+ kind = 'sos'
+ sos = system
+ zi = [[0.] * 2] * len(sos)
+ n_per_chunk = 1000
+ n_chunks_max = int(np.ceil(max_try / float(n_per_chunk)))
+ x = np.zeros(n_per_chunk)
x[0] = 1
- h = lfilter(b, a, x)
- return np.where(np.abs(h) > 0.001 * np.max(np.abs(h)))[0][-1]
+ last_good = n_per_chunk
+ thresh_val = 0
+ for ii in range(n_chunks_max):
+ if kind == 'ba':
+ h, zi = signal.lfilter(b, a, x, zi=zi)
+ else:
+ h, zi = signal.sosfilt(sos, x, zi=zi)
+ x[0] = 0 # for subsequent iterations we want zero input
+ h = np.abs(h)
+ thresh_val = max(0.001 * np.max(h), thresh_val)
+ idx = np.where(np.abs(h) > thresh_val)[0]
+ if len(idx) > 0:
+ last_good = idx[-1]
+ else: # this iteration had no sufficiently large values
+ idx = (ii - 1) * n_per_chunk + last_good
+ break
+ else:
+ warn('Could not properly estimate ringing for the filter')
+ idx = n_per_chunk * n_chunks_max
+ return idx
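
Usage sketch: sharper designs keep more energy near the unit circle and therefore ring longer, so they receive more padding (exact numbers depend on the design and SciPy version)::

    from scipy.signal import butter

    n_wide = estimate_ringing_samples(butter(4, 0.25, output='sos'))
    n_narrow = estimate_ringing_samples(butter(4, 0.01, output='sos'))
    # expect n_narrow to be much larger than n_wide
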
-def construct_iir_filter(iir_params=dict(b=[1, 0], a=[1, 0], padlen=0),
- f_pass=None, f_stop=None, sfreq=None, btype=None,
- return_copy=True):
+def construct_iir_filter(iir_params, f_pass=None, f_stop=None, sfreq=None,
+ btype=None, return_copy=True):
"""Use IIR parameters to get filtering coefficients
This function works like a wrapper for iirdesign and iirfilter in
@@ -428,19 +481,39 @@ def construct_iir_filter(iir_params=dict(b=[1, 0], a=[1, 0], padlen=0),
function) with the filter coefficients ('b' and 'a') and an estimate
of the padding necessary ('padlen') so IIR filtering can be performed.
+ .. note:: As of 0.14, second-order sections will be used in filter
+ design by default (replacing ``output='ba'`` by
+ ``output='sos'``) to help ensure filter stability and
+ reduce numerical error. Second-order sections filtering
+ requires SciPy >= 0.16.
+
+
Parameters
----------
iir_params : dict
Dictionary of parameters to use for IIR filtering.
- If iir_params['b'] and iir_params['a'] exist, these will be used
- as coefficients to perform IIR filtering. Otherwise, if
- iir_params['order'] and iir_params['ftype'] exist, these will be
- used with scipy.signal.iirfilter to make a filter. Otherwise, if
- iir_params['gpass'] and iir_params['gstop'] exist, these will be
- used with scipy.signal.iirdesign to design a filter.
- iir_params['padlen'] defines the number of samples to pad (and
- an estimate will be calculated if it is not given). See Notes for
- more details.
+
+ * If ``iir_params['sos']`` exists, it will be used as
+ second-order sections to perform IIR filtering.
+
+ .. versionadded:: 0.13
+
+ * Otherwise, if ``iir_params['b']`` and ``iir_params['a']``
+ exist, these will be used as coefficients to perform IIR
+ filtering.
+ * Otherwise, if ``iir_params['order']`` and
+ ``iir_params['ftype']`` exist, these will be used with
+ `scipy.signal.iirfilter` to make a filter.
+ * Otherwise, if ``iir_params['gpass']`` and
+ ``iir_params['gstop']`` exist, these will be used with
+ `scipy.signal.iirdesign` to design a filter.
+ * ``iir_params['padlen']`` defines the number of samples to pad
+ (and an estimate will be calculated if it is not given).
+ See Notes for more details.
+ * ``iir_params['output']`` defines the system output kind when
+ designing filters, either "sos" or "ba". For 0.13 the
+ default is 'ba' but will change to 'sos' in 0.14.
+
f_pass : float or list of float
Frequency for the pass-band. Low-pass and high-pass filters should
be a float, band-pass should be a 2-element list of float.
@@ -451,23 +524,31 @@ def construct_iir_filter(iir_params=dict(b=[1, 0], a=[1, 0], padlen=0),
The sample rate.
btype : str
Type of filter. Should be 'lowpass', 'highpass', or 'bandpass'
- (or analogous string representations known to scipy.signal).
+ (or analogous string representations known to
+ :func:`scipy.signal.iirfilter`).
return_copy : bool
- If False, the 'b', 'a', and 'padlen' entries in iir_params will be
- set inplace (if they weren't already). Otherwise, a new iir_params
- instance will be created and returned with these entries.
+ If False, the 'sos', 'b', 'a', and 'padlen' entries in
+ ``iir_params`` will be set inplace (if they weren't already).
+ Otherwise, a new ``iir_params`` instance will be created and
+ returned with these entries.
Returns
-------
iir_params : dict
Updated iir_params dict, with the entries (set only if they didn't
- exist before) for 'b', 'a', and 'padlen' for IIR filtering.
+ exist before) for 'sos' (or 'b', 'a'), and 'padlen' for
+ IIR filtering.
+
+ See Also
+ --------
+ mne.filter.filter_data
+ mne.io.Raw.filter
Notes
-----
- This function triages calls to scipy.signal.iirfilter and iirdesign
- based on the input arguments (see descriptions of these functions
- and scipy's scipy.signal.filter_design documentation for details).
+ This function triages calls to :func:`scipy.signal.iirfilter` and
+ :func:`scipy.signal.iirdesign` based on the input arguments (see
+ linked functions for more details).
Examples
--------
@@ -478,20 +559,20 @@ def construct_iir_filter(iir_params=dict(b=[1, 0], a=[1, 0], padlen=0),
filter 'N' and the type of filtering 'ftype' are specified. To get
coefficients for a 4th-order Butterworth filter, this would be:
- >>> iir_params = dict(order=4, ftype='butter')
- >>> iir_params = construct_iir_filter(iir_params, 40, None, 1000, 'low', return_copy=False)
- >>> print((len(iir_params['b']), len(iir_params['a']), iir_params['padlen']))
- (5, 5, 82)
+ >>> iir_params = dict(order=4, ftype='butter', output='sos') # doctest:+SKIP
+ >>> iir_params = construct_iir_filter(iir_params, 40, None, 1000, 'low', return_copy=False) # doctest:+SKIP
+ >>> print((2 * len(iir_params['sos']), iir_params['padlen'])) # doctest:+SKIP
+ (4, 82)
Filters can also be constructed using filter design methods. To get a
40 Hz Chebyshev type 1 lowpass with specific gain characteristics in the
pass and stop bands (assuming the desired stop band is at 45 Hz), this
would be a filter with much longer ringing:
- >>> iir_params = dict(ftype='cheby1', gpass=3, gstop=20)
- >>> iir_params = construct_iir_filter(iir_params, 40, 50, 1000, 'low')
- >>> print((len(iir_params['b']), len(iir_params['a']), iir_params['padlen']))
- (6, 6, 439)
+ >>> iir_params = dict(ftype='cheby1', gpass=3, gstop=20, output='sos') # doctest:+SKIP
+ >>> iir_params = construct_iir_filter(iir_params, 40, 50, 1000, 'low') # doctest:+SKIP
+ >>> print((2 * len(iir_params['sos']), iir_params['padlen'])) # doctest:+SKIP
+ (6, 439)
Padding and/or filter coefficients can also be manually specified. For
a 10-sample moving window with no padding during filtering, for example,
@@ -502,17 +583,32 @@ def construct_iir_filter(iir_params=dict(b=[1, 0], a=[1, 0], padlen=0),
>>> print((iir_params['b'], iir_params['a'], iir_params['padlen']))
(array([ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), [1, 0], 0)
+ For more information, see the tutorials :ref:`tut_background_filtering`
+ and :ref:`tut_artifacts_filter`.
""" # noqa
from scipy.signal import iirfilter, iirdesign
known_filters = ('bessel', 'butter', 'butterworth', 'cauer', 'cheby1',
'cheby2', 'chebyshev1', 'chebyshev2', 'chebyshevi',
'chebyshevii', 'ellip', 'elliptic')
- a = None
- b = None
+ if not isinstance(iir_params, dict):
+ raise TypeError('iir_params must be a dict, got %s' % type(iir_params))
+ system = None
# if the filter has been designed, we're good to go
- if 'a' in iir_params and 'b' in iir_params:
- [b, a] = [iir_params['b'], iir_params['a']]
+ if 'sos' in iir_params:
+ system = iir_params['sos']
+ output = 'sos'
+ elif 'a' in iir_params and 'b' in iir_params:
+ system = (iir_params['b'], iir_params['a'])
+ output = 'ba'
else:
+ output = iir_params.get('output', None)
+ if output is None:
+ warn('The default output type is "ba" in 0.13 but will change '
+ 'to "sos" in 0.14')
+ output = 'ba'
+ if not isinstance(output, string_types) or output not in ('ba', 'sos'):
+ raise ValueError('Output must be "ba" or "sos", got %s'
+ % (output,))
# ensure we have a valid ftype
if 'ftype' not in iir_params:
raise RuntimeError('ftype must be an entry in iir_params if ''b'' '
@@ -526,57 +622,235 @@ def construct_iir_filter(iir_params=dict(b=[1, 0], a=[1, 0], padlen=0),
# use order-based design
Wp = np.asanyarray(f_pass) / (float(sfreq) / 2)
if 'order' in iir_params:
- [b, a] = iirfilter(iir_params['order'], Wp, btype=btype,
- ftype=ftype)
+ system = iirfilter(iir_params['order'], Wp, btype=btype,
+ ftype=ftype, output=output)
else:
# use gpass / gstop design
Ws = np.asanyarray(f_stop) / (float(sfreq) / 2)
if 'gpass' not in iir_params or 'gstop' not in iir_params:
raise ValueError('iir_params must have at least ''gstop'' and'
' ''gpass'' (or ''N'') entries')
- [b, a] = iirdesign(Wp, Ws, iir_params['gpass'],
- iir_params['gstop'], ftype=ftype)
+ system = iirdesign(Wp, Ws, iir_params['gpass'],
+ iir_params['gstop'], ftype=ftype, output=output)
- if a is None or b is None:
+ if system is None:
raise RuntimeError('coefficients could not be created from iir_params')
+ # do some sanity checks
+ _check_coefficients(system)
# now deal with padding
if 'padlen' not in iir_params:
- padlen = _estimate_ringing_samples(b, a)
+ padlen = estimate_ringing_samples(system)
else:
padlen = iir_params['padlen']
if return_copy:
iir_params = deepcopy(iir_params)
- iir_params.update(dict(b=b, a=a, padlen=padlen))
+ iir_params.update(dict(padlen=padlen))
+ if output == 'sos':
+ iir_params.update(sos=system)
+ else:
+ iir_params.update(b=system[0], a=system[1])
return iir_params
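
For orientation, the order-based 'sos' branch above reduces to a single SciPy call; ``construct_iir_filter`` additionally estimates ``padlen`` and checks stability. An equivalent direct design with assumed example values::

    from scipy.signal import iirfilter

    Wp = 40. / (1000. / 2.)   # 40 Hz low-pass at sfreq=1000 Hz, normalized
    sos = iirfilter(4, Wp, btype='low', ftype='butter', output='sos')
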
def _check_method(method, iir_params, extra_types):
"""Helper to parse method arguments"""
- allowed_types = ['iir', 'fft'] + extra_types
+ allowed_types = ['iir', 'fir', 'fft'] + extra_types
if not isinstance(method, string_types):
raise TypeError('method must be a string')
if method not in allowed_types:
raise ValueError('method must be one of %s, not "%s"'
% (allowed_types, method))
+ if method == 'fft':
+ method = 'fir' # use the better name
if method == 'iir':
if iir_params is None:
- iir_params = dict(order=4, ftype='butter')
- if not isinstance(iir_params, dict):
- raise ValueError('iir_params must be a dict')
+ iir_params = dict()
+ if len(iir_params) == 0 or (len(iir_params) == 1 and
+ 'output' in iir_params):
+ # XXX update this after deprecation of ba
+ iir_params = dict(order=4, ftype='butter',
+ output=iir_params.get('output', 'ba'))
elif iir_params is not None:
raise ValueError('iir_params must be None if method != "iir"')
- method = method.lower()
- return iir_params
+ return iir_params, method
@verbose
-def band_pass_filter(x, Fs, Fp1, Fp2, filter_length='10s',
- l_trans_bandwidth=0.5, h_trans_bandwidth=0.5,
- method='fft', iir_params=None,
- picks=None, n_jobs=1, copy=True, verbose=None):
+def filter_data(data, sfreq, l_freq, h_freq, picks=None, filter_length='auto',
+ l_trans_bandwidth='auto', h_trans_bandwidth='auto', n_jobs=1,
+ method='fir', iir_params=None, copy=True, phase='zero',
+ fir_window='hamming', verbose=None):
+ """Filter a subset of channels.
+
+ Applies a zero-phase low-pass, high-pass, band-pass, or band-stop
+ filter to the channels selected by ``picks``.
+
+ ``l_freq`` and ``h_freq`` are the frequencies below which and above
+ which, respectively, to filter out of the data. Thus the uses are:
+
+ * ``l_freq < h_freq``: band-pass filter
+ * ``l_freq > h_freq``: band-stop filter
+ * ``l_freq is not None and h_freq is None``: high-pass filter
+ * ``l_freq is None and h_freq is not None``: low-pass filter
+
+ .. note:: If n_jobs > 1, more memory is required as
+ ``len(picks) * n_times`` additional time points need to
+ be temporarily stored in memory.
+
+ Parameters
+ ----------
+ data : ndarray, shape (..., n_times)
+ The data to filter.
+ sfreq : float
+ The sample frequency in Hz.
+ l_freq : float | None
+ Low cut-off frequency in Hz. If None the data are only low-passed.
+ h_freq : float | None
+ High cut-off frequency in Hz. If None the data are only
+ high-passed.
+ picks : array-like of int | None
+ Indices of channels to filter. If None all channels will be
+ filtered. Currently this is only supported for
+ 2D (n_channels, n_times) and 3D (n_epochs, n_channels, n_times)
+ arrays.
+ filter_length : str | int
+ Length of the FIR filter to use (if applicable):
+
+ * int: specified length in samples.
+ * 'auto' (default in 0.14): the filter length is chosen based
+ on the size of the transition regions (6.6 times the reciprocal
+ of the shortest transition band for fir_window='hamming').
+ * str: (default in 0.13 is "10s") a human-readable time in
+ units of "s" or "ms" (e.g., "10s" or "5500ms") will be
+ converted to that number of samples if ``phase="zero"``, or
+ the shortest power-of-two length at least that duration for
+ ``phase="zero-double"``.
+
+ l_trans_bandwidth : float | str
+ Width of the transition band at the low cut-off frequency in Hz
+ (high pass or cutoff 1 in bandpass). Can be "auto"
+ (default in 0.14) to use a multiple of ``l_freq``::
+
+ min(max(l_freq * 0.25, 2), l_freq)
+
+ Only used for ``method='fir'``.
+ h_trans_bandwidth : float | str
+ Width of the transition band at the high cut-off frequency in Hz
+ (low pass or cutoff 2 in bandpass). Can be "auto"
+ (default in 0.14) to use a multiple of ``h_freq``::
+
+ min(max(h_freq * 0.25, 2.), sfreq / 2. - h_freq)
+
+ Only used for ``method='fir'``.
+ n_jobs : int | str
+ Number of jobs to run in parallel. Can be 'cuda' if scikits.cuda
+ is installed properly, CUDA is initialized, and method='fir'.
+ method : str
+ 'fir' will use overlap-add FIR filtering, 'iir' will use IIR
+ forward-backward filtering (via filtfilt).
+ iir_params : dict | None
+ Dictionary of parameters to use for IIR filtering.
+ See mne.filter.construct_iir_filter for details. If iir_params
+ is None and method="iir", 4th order Butterworth will be used.
+ copy : bool
+ If True, a copy of x, filtered, is returned. Otherwise, it operates
+ on x in place.
+ phase : str
+ Phase of the filter, only used if ``method='fir'``.
+ By default, a symmetric linear-phase FIR filter is constructed.
+ If ``phase='zero'`` (default), the delay of this filter
+ is compensated for. If ``phase='zero-double'``, then this filter
+ is applied twice, once forward, and once backward.
+
+ .. versionadded:: 0.13
+
+ fir_window : str
+ The window to use in FIR design, can be "hamming" (default),
+ "hann" (default in 0.13), or "blackman".
+
+ .. versionadded:: 0.13
+
+ verbose : bool, str, int, or None
+ If not None, override default verbose level (see mne.verbose).
+
+ Returns
+ -------
+ data : ndarray, shape (..., n_times)
+ The filtered data.
+
+ See Also
+ --------
+ mne.filter.construct_iir_filter
+ mne.io.Raw.filter
+ band_pass_filter
+ band_stop_filter
+ high_pass_filter
+ low_pass_filter
+ notch_filter
+ resample
+
+ Notes
+ -----
+ For more information, see the tutorials :ref:`tut_background_filtering`
+ and :ref:`tut_artifacts_filter`.
+ """
+ if not isinstance(data, np.ndarray):
+ raise ValueError('data must be an array')
+ sfreq = float(sfreq)
+ if sfreq <= 0:
+ raise ValueError('sfreq must be positive')
+ if h_freq is not None:
+ h_freq = float(h_freq)
+ if h_freq > (sfreq / 2.):
+ raise ValueError('h_freq (%s) must be less than the Nyquist '
+ 'frequency %s' % (h_freq, sfreq / 2.))
+ if l_freq is not None:
+ l_freq = float(l_freq)
+ if l_freq == 0:
+ l_freq = None
+ if l_freq is None and h_freq is not None:
+ data = low_pass_filter(
+ data, sfreq, h_freq, filter_length=filter_length,
+ trans_bandwidth=h_trans_bandwidth, method=method,
+ iir_params=iir_params, picks=picks, n_jobs=n_jobs,
+ copy=copy, phase=phase, fir_window=fir_window)
+ if l_freq is not None and h_freq is None:
+ data = high_pass_filter(
+ data, sfreq, l_freq, filter_length=filter_length,
+ trans_bandwidth=l_trans_bandwidth, method=method,
+ iir_params=iir_params, picks=picks, n_jobs=n_jobs, copy=copy,
+ phase=phase, fir_window=fir_window)
+ if l_freq is not None and h_freq is not None:
+ if l_freq < h_freq:
+ data = band_pass_filter(
+ data, sfreq, l_freq, h_freq,
+ filter_length=filter_length,
+ l_trans_bandwidth=l_trans_bandwidth,
+ h_trans_bandwidth=h_trans_bandwidth,
+ method=method, iir_params=iir_params, picks=picks,
+ n_jobs=n_jobs, copy=copy, phase=phase, fir_window=fir_window)
+ else:
+ logger.info('Band-stop filtering from %0.2g - %0.2g Hz'
+ % (h_freq, l_freq))
+ data = band_stop_filter(
+ data, sfreq, h_freq, l_freq,
+ filter_length=filter_length,
+ l_trans_bandwidth=h_trans_bandwidth,
+ h_trans_bandwidth=l_trans_bandwidth, method=method,
+ iir_params=iir_params, picks=picks, n_jobs=n_jobs,
+ copy=copy, phase=phase, fir_window=fir_window)
+ return data
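
A usage sketch for the dispatcher above, on synthetic data (shapes and defaults as documented)::

    import numpy as np

    rng = np.random.RandomState(42)
    data = rng.randn(5, 10000)    # (n_channels, n_times) at sfreq=1000 Hz
    bp = filter_data(data, sfreq=1000., l_freq=1., h_freq=40.)   # band-pass
    bs = filter_data(data, sfreq=1000., l_freq=52., h_freq=48.)  # band-stop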
+
+
+@verbose
+def band_pass_filter(x, Fs, Fp1, Fp2, filter_length='',
+ l_trans_bandwidth=None, h_trans_bandwidth=None,
+ method='fir', iir_params=None, picks=None, n_jobs=1,
+ copy=True, phase='', fir_window='', verbose=None):
"""Bandpass filter for the signal x.
Applies a zero-phase bandpass filter to the signal x, operating on the
@@ -592,35 +866,66 @@ def band_pass_filter(x, Fs, Fp1, Fp2, filter_length='10s',
Low cut-off frequency in Hz.
Fp2 : float
High cut-off frequency in Hz.
- filter_length : str (Default: '10s') | int | None
- Length of the filter to use. If None or "len(x) < filter_length",
- the filter length used is len(x). Otherwise, if int, overlap-add
- filtering with a filter of the specified length in samples) is
- used (faster for long signals). If str, a human-readable time in
- units of "s" or "ms" (e.g., "10s" or "5500ms") will be converted
- to the shortest power-of-two length at least that duration.
- Not used for 'iir' filters.
- l_trans_bandwidth : float
- Width of the transition band at the low cut-off frequency in Hz.
- Not used if 'order' is specified in iir_params.
- h_trans_bandwidth : float
- Width of the transition band at the high cut-off frequency in Hz.
- Not used if 'order' is specified in iir_params.
+ filter_length : str | int
+ Length of the FIR filter to use (if applicable):
+
+ * int: specified length in samples.
+ * 'auto' (default in 0.14): the filter length is chosen based
+ on the size of the transition regions (6.6 times the reciprocal
+ of the shortest transition band for fir_window='hamming').
+ * str: (default in 0.13 is "10s") a human-readable time in
+ units of "s" or "ms" (e.g., "10s" or "5500ms") will be
+ converted to that number of samples if ``phase="zero"``, or
+ the shortest power-of-two length at least that duration for
+ ``phase="zero-double"``.
+
+ l_trans_bandwidth : float | str
+ Width of the transition band at the low cut-off frequency in Hz.
+ Can be "auto" (default in 0.14) to use a multiple of ``l_freq``::
+
+ min(max(l_freq * 0.25, 2), l_freq)
+
+ Only used for ``method='fir'``.
+ h_trans_bandwidth : float | str
+ Width of the transition band at the high cut-off frequency in Hz.
+ Can be "auto" (default in 0.14) to use a multiple of ``h_freq``::
+
+ min(max(h_freq * 0.25, 2.), Fs / 2. - h_freq)
+
+ Only used for ``method='fir'``.
method : str
- 'fft' will use overlap-add FIR filtering, 'iir' will use IIR
+ 'fir' will use overlap-add FIR filtering, 'iir' will use IIR
forward-backward filtering (via filtfilt).
iir_params : dict | None
Dictionary of parameters to use for IIR filtering.
See mne.filter.construct_iir_filter for details. If iir_params
is None and method="iir", 4th order Butterworth will be used.
picks : array-like of int | None
- Indices to filter. If None all indices will be filtered.
+ Indices of channels to filter. If None all channels will be
+ filtered. Only supported for 2D (n_channels, n_times) and 3D
+ (n_epochs, n_channels, n_times) data.
n_jobs : int | str
Number of jobs to run in parallel. Can be 'cuda' if scikits.cuda
- is installed properly, CUDA is initialized, and method='fft'.
+ is installed properly, CUDA is initialized, and method='fir'.
copy : bool
If True, a copy of x, filtered, is returned. Otherwise, it operates
on x in place.
+ phase : str
+ Phase of the filter, only used if ``method='fir'``.
+ By default, a symmetric linear-phase FIR filter is constructed.
+ If ``phase='zero'`` (default in 0.14), the delay of this filter
+ is compensated for. If ``phase='zero-double'`` (default in 0.13
+ and before), then this filter is applied twice, once forward, and
+ once backward.
+
+ .. versionadded:: 0.13
+
+ fir_window : str
+ The window to use in FIR design, can be "hamming" (default in
+ 0.14), "hann" (default in 0.13), or "blackman".
+
+ .. versionadded:: 0.13
+
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
@@ -631,61 +936,62 @@ def band_pass_filter(x, Fs, Fp1, Fp2, filter_length='10s',
See Also
--------
- low_pass_filter, high_pass_filter
+ filter_data
+ band_stop_filter
+ high_pass_filter
+ low_pass_filter
+ notch_filter
+ resample
Notes
-----
The frequency response is (approximately) given by::
- ----------
- /| | \
- / | | \
- / | | \
- / | | \
- ---------- | | -----------------
- | |
- Fs1 Fp1 Fp2 Fs2
+ 1-| ----------
+ | /| | \
+ |H| | / | | \
+ | / | | \
+ | / | | \
+ 0-|---------- | | --------------
+ | | | | | |
+ 0 Fs1 Fp1 Fp2 Fs2 Nyq
Where:
- Fs1 = Fp1 - l_trans_bandwidth in Hz
- Fs2 = Fp2 + h_trans_bandwidth in Hz
- """
- iir_params = _check_method(method, iir_params, [])
-
- Fs = float(Fs)
- Fp1 = float(Fp1)
- Fp2 = float(Fp2)
- Fs1 = Fp1 - l_trans_bandwidth if method == 'fft' else Fp1
- Fs2 = Fp2 + h_trans_bandwidth if method == 'fft' else Fp2
- if Fs2 > Fs / 2:
- raise ValueError('Effective band-stop frequency (%s) is too high '
- '(maximum based on Nyquist is %s)' % (Fs2, Fs / 2.))
-
- if Fs1 <= 0:
- raise ValueError('Filter specification invalid: Lower stop frequency '
- 'too low (%0.1fHz). Increase Fp1 or reduce '
- 'transition bandwidth (l_trans_bandwidth)' % Fs1)
+ * Fs1 = Fp1 - l_trans_bandwidth in Hz
+ * Fs2 = Fp2 + h_trans_bandwidth in Hz
- if method == 'fft':
- freq = [0, Fs1, Fp1, Fp2, Fs2, Fs / 2]
- gain = [0, 0, 1, 1, 0, 0]
- xf = _filter(x, Fs, freq, gain, filter_length, picks, n_jobs, copy)
+ """
+ iir_params, method = _check_method(method, iir_params, [])
+ logger.info('Band-pass filtering from %0.2g - %0.2g Hz' % (Fp1, Fp2))
+ x, Fs, Fp1, Fp2, Fs1, Fs2, filter_length, phase, fir_window = \
+ _triage_filter_params(
+ x, Fs, Fp1, Fp2, l_trans_bandwidth, h_trans_bandwidth,
+ filter_length, method, phase, fir_window)
+ if method == 'fir':
+ freq = [Fs1, Fp1, Fp2, Fs2]
+ gain = [0, 1, 1, 0]
+ if Fs1 != 0:
+ freq = [0.] + freq
+ gain = [0.] + gain
+ if Fs2 != Fs / 2:
+ freq += [Fs / 2.]
+ gain += [0.]
+ xf = _fir_filter(x, Fs, freq, gain, filter_length, picks, n_jobs, copy,
+ phase, fir_window)
else:
iir_params = construct_iir_filter(iir_params, [Fp1, Fp2],
[Fs1, Fs2], Fs, 'bandpass')
- padlen = min(iir_params['padlen'], len(x))
- xf = _filtfilt(x, iir_params['b'], iir_params['a'], padlen,
- picks, n_jobs, copy)
+ xf = _filtfilt(x, iir_params, picks, n_jobs, copy)
return xf
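
Concretely, for an 8-12 Hz band-pass at Fs = 100 Hz with 2 Hz transition bands, the lists built above come out as (a worked example; frequencies in Hz, normalized by Nyquist later inside ``_fir_filter``)::

    freq = [0., 6., 8., 12., 14., 50.]   # [0, Fs1, Fp1, Fp2, Fs2, Fs/2]
    gain = [0., 0., 1., 1., 0., 0.]      # stop, transition, pass, transition, stop
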
@verbose
-def band_stop_filter(x, Fs, Fp1, Fp2, filter_length='10s',
- l_trans_bandwidth=0.5, h_trans_bandwidth=0.5,
- method='fft', iir_params=None,
- picks=None, n_jobs=1, copy=True, verbose=None):
+def band_stop_filter(x, Fs, Fp1, Fp2, filter_length='',
+ l_trans_bandwidth=None, h_trans_bandwidth=None,
+ method='fir', iir_params=None, picks=None, n_jobs=1,
+ copy=True, phase='', fir_window='', verbose=None):
"""Bandstop filter for the signal x.
Applies a zero-phase bandstop filter to the signal x, operating on the
@@ -701,35 +1007,66 @@ def band_stop_filter(x, Fs, Fp1, Fp2, filter_length='10s',
Low cut-off frequency in Hz.
Fp2 : float | array of float
High cut-off frequency in Hz.
- filter_length : str (Default: '10s') | int | None
- Length of the filter to use. If None or "len(x) < filter_length",
- the filter length used is len(x). Otherwise, if int, overlap-add
- filtering with a filter of the specified length in samples) is
- used (faster for long signals). If str, a human-readable time in
- units of "s" or "ms" (e.g., "10s" or "5500ms") will be converted
- to the shortest power-of-two length at least that duration.
- Not used for 'iir' filters.
- l_trans_bandwidth : float
- Width of the transition band at the low cut-off frequency in Hz.
- Not used if 'order' is specified in iir_params.
- h_trans_bandwidth : float
- Width of the transition band at the high cut-off frequency in Hz.
- Not used if 'order' is specified in iir_params.
+ filter_length : str | int
+ Length of the FIR filter to use (if applicable):
+
+ * int: specified length in samples.
+ * 'auto' (default in 0.14): the filter length is chosen based
+ on the size of the transition regions (6.6 times the reciprocal
+ of the shortest transition band for fir_window='hamming').
+ * str: (default in 0.13 is "10s") a human-readable time in
+ units of "s" or "ms" (e.g., "10s" or "5500ms") will be
+ converted to that number of samples if ``phase="zero"``, or
+ the shortest power-of-two length at least that duration for
+ ``phase="zero-double"``.
+
+ l_trans_bandwidth : float | str
+ Width of the transition band at the low cut-off frequency in Hz.
+ Can be "auto" (default in 0.14) to use a multiple of ``l_freq``::
+
+ min(max(l_freq * 0.25, 2), l_freq)
+
+ Only used for ``method='fir'``.
+ h_trans_bandwidth : float | str
+ Width of the transition band at the high cut-off frequency in Hz.
+ Can be "auto" (default in 0.14) to use a multiple of ``h_freq``::
+
+ min(max(h_freq * 0.25, 2.), Fs / 2. - h_freq)
+
+ Only used for ``method='fir'``.
method : str
- 'fft' will use overlap-add FIR filtering, 'iir' will use IIR
+ 'fir' will use overlap-add FIR filtering, 'iir' will use IIR
forward-backward filtering (via filtfilt).
iir_params : dict | None
Dictionary of parameters to use for IIR filtering.
See mne.filter.construct_iir_filter for details. If iir_params
is None and method="iir", 4th order Butterworth will be used.
picks : array-like of int | None
- Indices to filter. If None all indices will be filtered.
+ Indices of channels to filter. If None all channels will be
+ filtered. Only supported for 2D (n_channels, n_times) and 3D
+ (n_epochs, n_channels, n_times) data.
n_jobs : int | str
Number of jobs to run in parallel. Can be 'cuda' if scikits.cuda
- is installed properly, CUDA is initialized, and method='fft'.
+ is installed properly, CUDA is initialized, and method='fir'.
copy : bool
If True, a copy of x, filtered, is returned. Otherwise, it operates
on x in place.
+ phase : str
+ Phase of the filter, only used if ``method='fir'``.
+ By default, a symmetric linear-phase FIR filter is constructed.
+ If ``phase='zero'`` (default in 0.14), the delay of this filter
+ is compensated for. If ``phase='zero-double'`` (default in 0.13
+ and before), then this filter is applied twice, once forward, and
+ once backward.
+
+ .. versionadded:: 0.13
+
+ fir_window : str
+ The window to use in FIR design, can be "hamming" (default in
+ 0.14), "hann" (default in 0.13), or "blackman".
+
+ .. versionadded:: 0.13
+
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
@@ -738,69 +1075,74 @@ def band_stop_filter(x, Fs, Fp1, Fp2, filter_length='10s',
xf : array
x filtered.
+ See Also
+ --------
+ filter_data
+ band_pass_filter
+ high_pass_filter
+ low_pass_filter
+ notch_filter
+ resample
+
Notes
-----
The frequency response is (approximately) given by::
- ---------- ----------
- |\ /|
- | \ / |
- | \ / |
- | \ / |
- | ----------- |
- | | | |
- Fp1 Fs1 Fs2 Fp2
+ 1-|--------- ----------
+ | \ /
+ |H| | \ /
+ | \ /
+ | \ /
+ 0-| -----------
+ | | | | | |
+ 0 Fp1 Fs1 Fs2 Fp2 Nyq
- Where:
-
- Fs1 = Fp1 + l_trans_bandwidth in Hz
- Fs2 = Fp2 - h_trans_bandwidth in Hz
+ Where ``Fs1 = Fp1 + l_trans_bandwidth`` and
+ ``Fs2 = Fp2 - h_trans_bandwidth``.
- Note that multiple stop bands can be specified using arrays.
+ Multiple stop bands can be specified using arrays.
"""
- iir_params = _check_method(method, iir_params, [])
-
- Fp1 = np.atleast_1d(Fp1)
- Fp2 = np.atleast_1d(Fp2)
- if not len(Fp1) == len(Fp2):
+ iir_params, method = _check_method(method, iir_params, [])
+ Fp1 = np.array(Fp1, float).ravel()
+ Fp2 = np.array(Fp2, float).ravel()
+ if len(Fp1) != len(Fp2):
raise ValueError('Fp1 and Fp2 must be the same length')
-
- Fs = float(Fs)
- Fp1 = Fp1.astype(float)
- Fp2 = Fp2.astype(float)
- Fs1 = Fp1 + l_trans_bandwidth if method == 'fft' else Fp1
- Fs2 = Fp2 - h_trans_bandwidth if method == 'fft' else Fp2
-
- if np.any(Fs1 <= 0):
- raise ValueError('Filter specification invalid: Lower stop frequency '
- 'too low (%0.1fHz). Increase Fp1 or reduce '
- 'transition bandwidth (l_trans_bandwidth)' % Fs1)
-
- if method == 'fft':
- freq = np.r_[0, Fp1, Fs1, Fs2, Fp2, Fs / 2]
- gain = np.r_[1, np.ones_like(Fp1), np.zeros_like(Fs1),
- np.zeros_like(Fs2), np.ones_like(Fp2), 1]
+ # Note: order of outputs is intentionally switched here!
+ x, Fs, Fs1, Fs2, Fp1, Fp2, filter_length, phase, fir_window = \
+ _triage_filter_params(
+ x, Fs, Fp1, Fp2, l_trans_bandwidth, h_trans_bandwidth,
+ filter_length, method, phase, fir_window,
+ bands='arr', reverse=True)
+ if method == 'fir':
+ freq = np.r_[Fp1, Fs1, Fs2, Fp2]
+ gain = np.r_[np.ones_like(Fp1), np.zeros_like(Fs1),
+ np.zeros_like(Fs2), np.ones_like(Fp2)]
order = np.argsort(freq)
freq = freq[order]
gain = gain[order]
+ if freq[0] != 0:
+ freq = np.r_[[0.], freq]
+ gain = np.r_[[1.], gain]
+ if freq[-1] != Fs / 2.:
+ freq = np.r_[freq, [Fs / 2.]]
+ gain = np.r_[gain, [1.]]
if np.any(np.abs(np.diff(gain, 2)) > 1):
raise ValueError('Stop bands are not sufficiently separated.')
- xf = _filter(x, Fs, freq, gain, filter_length, picks, n_jobs, copy)
+ xf = _fir_filter(x, Fs, freq, gain, filter_length, picks, n_jobs, copy,
+ phase, fir_window)
else:
for fp_1, fp_2, fs_1, fs_2 in zip(Fp1, Fp2, Fs1, Fs2):
iir_params_new = construct_iir_filter(iir_params, [fp_1, fp_2],
[fs_1, fs_2], Fs, 'bandstop')
- padlen = min(iir_params_new['padlen'], len(x))
- xf = _filtfilt(x, iir_params_new['b'], iir_params_new['a'], padlen,
- picks, n_jobs, copy)
+ xf = _filtfilt(x, iir_params_new, picks, n_jobs, copy)
return xf
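
The ``np.diff(gain, 2)`` test above flags stop bands whose edges interleave after sorting: for well-separated bands the second difference of the gain sequence never exceeds 1 in magnitude::

    import numpy as np

    ok = np.array([1., 1., 0., 0., 1., 1., 0., 0., 1., 1.])  # separated stops
    bad = np.array([1., 0., 1., 0., 1.])                     # edges interleave
    assert not np.any(np.abs(np.diff(ok, 2)) > 1)
    assert np.any(np.abs(np.diff(bad, 2)) > 1)
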
@verbose
-def low_pass_filter(x, Fs, Fp, filter_length='10s', trans_bandwidth=0.5,
- method='fft', iir_params=None,
- picks=None, n_jobs=1, copy=True, verbose=None):
+def low_pass_filter(x, Fs, Fp, filter_length='', trans_bandwidth=None,
+ method='fir', iir_params=None, picks=None, n_jobs=1,
+ copy=True, phase='', fir_window='', verbose=None):
"""Lowpass filter for the signal x.
Applies a zero-phase lowpass filter to the signal x, operating on the
@@ -814,32 +1156,59 @@ def low_pass_filter(x, Fs, Fp, filter_length='10s', trans_bandwidth=0.5,
Sampling rate in Hz.
Fp : float
Cut-off frequency in Hz.
- filter_length : str (Default: '10s') | int | None
- Length of the filter to use. If None or "len(x) < filter_length",
- the filter length used is len(x). Otherwise, if int, overlap-add
- filtering with a filter of the specified length in samples) is
- used (faster for long signals). If str, a human-readable time in
- units of "s" or "ms" (e.g., "10s" or "5500ms") will be converted
- to the shortest power-of-two length at least that duration.
- Not used for 'iir' filters.
- trans_bandwidth : float
- Width of the transition band in Hz. Not used if 'order' is specified
- in iir_params.
+ filter_length : str | int
+ Length of the FIR filter to use (if applicable):
+
+ * int: specified length in samples.
+ * 'auto' (default in 0.14): the filter length is chosen based
+ on the size of the transition regions (6.6 times the reciprocal
+ of the shortest transition band for fir_window='hamming').
+ * str: (default in 0.13 is "10s") a human-readable time in
+ units of "s" or "ms" (e.g., "10s" or "5500ms") will be
+ converted to that number of samples if ``phase="zero"``, or
+ the shortest power-of-two length at least that duration for
+ ``phase="zero-double"``.
+
+ trans_bandwidth : float | str
+ Width of the transition band in Hz. Can be "auto"
+ (default in 0.14) to use a multiple of ``Fp``::
+
+ min(max(Fp * 0.25, 2.), Fs / 2. - Fp)
+
+ Only used for ``method='fir'``.
method : str
- 'fft' will use overlap-add FIR filtering, 'iir' will use IIR
+ 'fir' will use overlap-add FIR filtering, 'iir' will use IIR
forward-backward filtering (via filtfilt).
iir_params : dict | None
Dictionary of parameters to use for IIR filtering.
See mne.filter.construct_iir_filter for details. If iir_params
is None and method="iir", 4th order Butterworth will be used.
picks : array-like of int | None
- Indices to filter. If None all indices will be filtered.
+ Indices of channels to filter. If None all channels will be
+ filtered. Only supported for 2D (n_channels, n_times) and 3D
+ (n_epochs, n_channels, n_times) data.
n_jobs : int | str
Number of jobs to run in parallel. Can be 'cuda' if scikits.cuda
- is installed properly, CUDA is initialized, and method='fft'.
+ is installed properly, CUDA is initialized, and method='fir'.
copy : bool
If True, a copy of x, filtered, is returned. Otherwise, it operates
on x in place.
+ phase : str
+ Phase of the filter, only used if ``method='fir'``.
+ By default, a symmetric linear-phase FIR filter is constructed.
+ If ``phase='zero'`` (default in 0.14), the delay of this filter
+ is compensated for. If ``phase='zero-double'`` (default in 0.13
+ and before), then this filter is applied twice, once forward, and
+ once backward.
+
+ .. versionadded:: 0.13
+
+ fir_window : str
+ The window to use in FIR design, can be "hamming" (default in
+ 0.14), "hann" (default in 0.13), or "blackman".
+
+ .. versionadded:: 0.13
+
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
@@ -850,48 +1219,53 @@ def low_pass_filter(x, Fs, Fp, filter_length='10s', trans_bandwidth=0.5,
See Also
--------
+ filter_data
+ band_pass_filter
+ band_stop_filter
+ high_pass_filter
+ notch_filter
resample
- band_pass_filter, high_pass_filter
Notes
-----
The frequency response is (approximately) given by::
- -------------------------
- | \
- | \
- | \
- | \
- | -----------------
- |
- Fp Fp+trans_bandwidth
+ 1-|------------------------
+ | \
+ |H| | \
+ | \
+ | \
+ 0-| ----------------
+ | | | |
+ 0 Fp Fstop Nyq
+ Where ``Fstop = Fp + trans_bandwidth``.
"""
- iir_params = _check_method(method, iir_params, [])
- Fs = float(Fs)
- Fp = float(Fp)
- Fstop = Fp + trans_bandwidth if method == 'fft' else Fp
- if Fstop > Fs / 2.:
- raise ValueError('Effective stop frequency (%s) is too high '
- '(maximum based on Nyquist is %s)' % (Fstop, Fs / 2.))
-
- if method == 'fft':
- freq = [0, Fp, Fstop, Fs / 2]
- gain = [1, 1, 0, 0]
- xf = _filter(x, Fs, freq, gain, filter_length, picks, n_jobs, copy)
+ iir_params, method = _check_method(method, iir_params, [])
+ logger.info('Low-pass filtering at %0.2g Hz' % (Fp,))
+ x, Fs, _, Fp, _, Fstop, filter_length, phase, fir_window = \
+ _triage_filter_params(
+ x, Fs, None, Fp, None, trans_bandwidth, filter_length, method,
+ phase, fir_window)
+ if method == 'fir':
+ freq = [0, Fp, Fstop]
+ gain = [1, 1, 0]
+ if Fstop != Fs / 2.:
+ freq += [Fs / 2.]
+ gain += [0]
+ xf = _fir_filter(x, Fs, freq, gain, filter_length, picks, n_jobs, copy,
+ phase, fir_window)
else:
iir_params = construct_iir_filter(iir_params, Fp, Fstop, Fs, 'low')
- padlen = min(iir_params['padlen'], len(x))
- xf = _filtfilt(x, iir_params['b'], iir_params['a'], padlen,
- picks, n_jobs, copy)
+ xf = _filtfilt(x, iir_params, picks, n_jobs, copy)
return xf
@verbose
-def high_pass_filter(x, Fs, Fp, filter_length='10s', trans_bandwidth=0.5,
- method='fft', iir_params=None,
- picks=None, n_jobs=1, copy=True, verbose=None):
+def high_pass_filter(x, Fs, Fp, filter_length='', trans_bandwidth=None,
+ method='fir', iir_params=None, picks=None, n_jobs=1,
+ copy=True, phase='', fir_window='', verbose=None):
"""Highpass filter for the signal x.
Applies a zero-phase highpass filter to the signal x, operating on the
@@ -905,32 +1279,59 @@ def high_pass_filter(x, Fs, Fp, filter_length='10s', trans_bandwidth=0.5,
Sampling rate in Hz.
Fp : float
Cut-off frequency in Hz.
- filter_length : str (Default: '10s') | int | None
- Length of the filter to use. If None or "len(x) < filter_length",
- the filter length used is len(x). Otherwise, if int, overlap-add
- filtering with a filter of the specified length in samples) is
- used (faster for long signals). If str, a human-readable time in
- units of "s" or "ms" (e.g., "10s" or "5500ms") will be converted
- to the shortest power-of-two length at least that duration.
- Not used for 'iir' filters.
- trans_bandwidth : float
- Width of the transition band in Hz. Not used if 'order' is
- specified in iir_params.
+ filter_length : str | int
+ Length of the FIR filter to use (if applicable):
+
+ * int: specified length in samples.
+ * 'auto' (default in 0.14): the filter length is chosen based
+ on the size of the transition regions (6.6 times the reciprocal
+ of the shortest transition band for fir_window='hamming').
+ * str: (default in 0.13 is "10s") a human-readable time in
+ units of "s" or "ms" (e.g., "10s" or "5500ms") will be
+ converted to that number of samples if ``phase="zero"``, or
+ the shortest power-of-two length at least that duration for
+ ``phase="zero-double"``.
+
+ trans_bandwidth : float | str
+ Width of the transition band in Hz. Can be "auto"
+ (default in 0.14) to use a multiple of ``Fp``::
+
+ min(max(Fp * 0.25, 2.), Fp)
+
+ Only used for ``method='fir'``.
method : str
- 'fft' will use overlap-add FIR filtering, 'iir' will use IIR
+ 'fir' will use overlap-add FIR filtering, 'iir' will use IIR
forward-backward filtering (via filtfilt).
iir_params : dict | None
Dictionary of parameters to use for IIR filtering.
See mne.filter.construct_iir_filter for details. If iir_params
is None and method="iir", 4th order Butterworth will be used.
picks : array-like of int | None
- Indices to filter. If None all indices will be filtered.
+ Indices of channels to filter. If None all channels will be
+ filtered. Only supported for 2D (n_channels, n_times) and 3D
+ (n_epochs, n_channels, n_times) data.
n_jobs : int | str
Number of jobs to run in parallel. Can be 'cuda' if scikits.cuda
- is installed properly, CUDA is initialized, and method='fft'.
+ is installed properly, CUDA is initialized, and method='fir'.
copy : bool
If True, a copy of x, filtered, is returned. Otherwise, it operates
on x in place.
+ phase : str
+ Phase of the filter, only used if ``method='fir'``.
+ By default, a symmetric linear-phase FIR filter is constructed.
+ If ``phase='zero'`` (default in 0.14), the delay of this filter
+ is compensated for. If ``phase='zero-double'`` (default in 0.13
+ and before), then this filter is applied twice, once forward, and
+ once backward.
+
+ .. versionadded:: 0.13
+
+ fir_window : str
+ The window to use in FIR design, can be "hamming" (default in
+ 0.14), "hann" (default in 0.13), or "blackman".
+
+ .. versionadded:: 0.13
+
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
@@ -941,51 +1342,53 @@ def high_pass_filter(x, Fs, Fp, filter_length='10s', trans_bandwidth=0.5,
See Also
--------
- low_pass_filter, band_pass_filter
+ filter_data
+ band_pass_filter
+ band_stop_filter
+ low_pass_filter
+ notch_filter
+ resample
Notes
-----
The frequency response is (approximately) given by::
- -----------------------
- /|
- / |
- / |
- / |
- ---------- |
- |
- Fstop Fp
+ 1-| -----------------------
+ | /
+ |H| | /
+ | /
+ | /
+ 0-|---------
+ | | | |
+ 0 Fstop Fp Nyq
- Where Fstop = Fp - trans_bandwidth.
+ Where ``Fstop = Fp - trans_bandwidth``.
"""
- iir_params = _check_method(method, iir_params, [])
- Fs = float(Fs)
- Fp = float(Fp)
-
- Fstop = Fp - trans_bandwidth if method == 'fft' else Fp
- if Fstop <= 0:
- raise ValueError('Filter specification invalid: Stop frequency too low'
- '(%0.1fHz). Increase Fp or reduce transition '
- 'bandwidth (trans_bandwidth)' % Fstop)
-
- if method == 'fft':
- freq = [0, Fstop, Fp, Fs / 2]
- gain = [0, 0, 1, 1]
- xf = _filter(x, Fs, freq, gain, filter_length, picks, n_jobs, copy)
+ iir_params, method = _check_method(method, iir_params, [])
+ logger.info('High-pass filtering at %0.2g Hz' % (Fp,))
+ x, Fs, Fp, _, Fstop, _, filter_length, phase, fir_window = \
+ _triage_filter_params(
+ x, Fs, Fp, None, trans_bandwidth, None, filter_length, method,
+ phase, fir_window)
+ if method == 'fir':
+ freq = [Fstop, Fp, Fs / 2.]
+ gain = [0, 1, 1]
+ if Fstop != 0:
+ freq = [0] + freq
+ gain = [0] + gain
+ xf = _fir_filter(x, Fs, freq, gain, filter_length, picks, n_jobs, copy,
+ phase, fir_window)
else:
iir_params = construct_iir_filter(iir_params, Fp, Fstop, Fs, 'high')
- padlen = min(iir_params['padlen'], len(x))
- xf = _filtfilt(x, iir_params['b'], iir_params['a'], padlen,
- picks, n_jobs, copy)
-
+ xf = _filtfilt(x, iir_params, picks, n_jobs, copy)
return xf
@verbose
-def notch_filter(x, Fs, freqs, filter_length='10s', notch_widths=None,
- trans_bandwidth=1, method='fft',
- iir_params=None, mt_bandwidth=None,
- p_value=0.05, picks=None, n_jobs=1, copy=True, verbose=None):
+def notch_filter(x, Fs, freqs, filter_length='', notch_widths=None,
+ trans_bandwidth=1, method='fir', iir_params=None,
+ mt_bandwidth=None, p_value=0.05, picks=None, n_jobs=1,
+ copy=True, phase='', fir_window='', verbose=None):
"""Notch filter for the signal x.
Applies a zero-phase notch filter to the signal x, operating on the last
@@ -1001,22 +1404,27 @@ def notch_filter(x, Fs, freqs, filter_length='10s', notch_widths=None,
Frequencies to notch filter in Hz, e.g. np.arange(60, 241, 60).
None can only be used with the mode 'spectrum_fit', where an F
test is used to find sinusoidal components.
- filter_length : str (Default: '10s') | int | None
- Length of the filter to use. If None or "len(x) < filter_length",
- the filter length used is len(x). Otherwise, if int, overlap-add
- filtering with a filter of the specified length in samples) is
- used (faster for long signals). If str, a human-readable time in
- units of "s" or "ms" (e.g., "10s" or "5500ms") will be converted
- to the shortest power-of-two length at least that duration.
- Not used for 'iir' filters.
+ filter_length : str | int
+ Length of the FIR filter to use (if applicable):
+
+ * int: specified length in samples.
+ * 'auto' (default in 0.14): the filter length is chosen based
+ on the size of the transition regions (6.6 times the reciprocal
+ of the shortest transition band for fir_window='hamming').
+ * str: (default in 0.13 is "10s") a human-readable time in
+ units of "s" or "ms" (e.g., "10s" or "5500ms") will be
+ converted to that number of samples if ``phase="zero"``, or
+ the shortest power-of-two length at least that duration for
+ ``phase="zero-double"``.
+
notch_widths : float | array of float | None
Width of the stop band (centred at each freq in freqs) in Hz.
If None, freqs / 200 is used.
trans_bandwidth : float
- Width of the transition band in Hz. Not used if 'order' is
- specified in iir_params.
+ Width of the transition band in Hz.
+ Only used for ``method='fir'``.
method : str
- 'fft' will use overlap-add FIR filtering, 'iir' will use IIR
+ 'fir' will use overlap-add FIR filtering, 'iir' will use IIR
forward-backward filtering (via filtfilt). 'spectrum_fit' will
use multi-taper estimation of sinusoidal components. If freqs=None
and method='spectrum_fit', significant sinusoidal components
@@ -1034,13 +1442,31 @@ def notch_filter(x, Fs, freqs, filter_length='10s', notch_widths=None,
freqs=None. Note that this will be Bonferroni corrected for the
number of frequencies, so large p-values may be justified.
picks : array-like of int | None
- Indices to filter. If None all indices will be filtered.
+ Indices of channels to filter. If None all channels will be
+ filtered. Only supported for 2D (n_channels, n_times) and 3D
+ (n_epochs, n_channels, n_times) data.
n_jobs : int | str
Number of jobs to run in parallel. Can be 'cuda' if scikits.cuda
- is installed properly, CUDA is initialized, and method='fft'.
+ is installed properly, CUDA is initialized, and method='fir'.
copy : bool
If True, a copy of x, filtered, is returned. Otherwise, it operates
on x in place.
+ phase : str
+ Phase of the filter, only used if ``method='fir'``.
+ By default, a symmetric linear-phase FIR filter is constructed.
+ If ``phase='zero'`` (default in 0.14), the delay of this filter
+ is compensated for. If ``phase='zero-double'`` (default in 0.13
+ and before), then this filter is applied twice, once forward, and
+ once backward.
+
+ .. versionadded:: 0.13
+
+ fir_window : str
+ The window to use in FIR design, can be "hamming" (default in
+ 0.14), "hann" (default in 0.13), or "blackman".
+
+ .. versionadded:: 0.13
+
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
@@ -1049,23 +1475,30 @@ def notch_filter(x, Fs, freqs, filter_length='10s', notch_widths=None,
xf : array
x filtered.
+ See Also
+ --------
+ filter_data
+ band_pass_filter
+ band_stop_filter
+ high_pass_filter
+ low_pass_filter
+ resample
+
Notes
-----
The frequency response is (approximately) given by::
- ---------- -----------
- |\ /|
- | \ / |
- | \ / |
- | \ / |
- | - |
- | | |
- Fp1 freq Fp2
-
- For each freq in freqs, where:
+ 1-|---------- -----------
+ | \ /
+ |H| | \ /
+ | \ /
+ | \ /
+ 0-| -
+ | | | | |
+ 0 Fp1 freq Fp2 Nyq
- Fp1 = freq - trans_bandwidth / 2 in Hz
- Fs2 = freq + trans_bandwidth / 2 in Hz
+ For each freq in freqs, where ``Fp1 = freq - trans_bandwidth / 2`` and
+ ``Fp2 = freq + trans_bandwidth / 2``.
References
----------
@@ -1074,7 +1507,7 @@ def notch_filter(x, Fs, freqs, filter_length='10s', notch_widths=None,
& Hemant Bokil, Oxford University Press, New York, 2008. Please
cite this in publications if method 'spectrum_fit' is used.
"""
- iir_params = _check_method(method, iir_params, ['spectrum_fit'])
+ iir_params, method = _check_method(method, iir_params, ['spectrum_fit'])
if freqs is not None:
freqs = np.atleast_1d(freqs)
@@ -1096,7 +1529,7 @@ def notch_filter(x, Fs, freqs, filter_length='10s', notch_widths=None,
raise ValueError('notch_widths must be None, scalar, or the '
'same length as freqs')
- if method in ['fft', 'iir']:
+ if method in ['fir', 'iir']:
# Speed this up by computing the fourier coefficients once
tb_2 = trans_bandwidth / 2.0
lows = [freq - nw / 2.0 - tb_2
@@ -1104,7 +1537,8 @@ def notch_filter(x, Fs, freqs, filter_length='10s', notch_widths=None,
highs = [freq + nw / 2.0 + tb_2
for freq, nw in zip(freqs, notch_widths)]
xf = band_stop_filter(x, Fs, lows, highs, filter_length, tb_2, tb_2,
- method, iir_params, picks, n_jobs, copy)
+ method, iir_params, picks, n_jobs, copy,
+ phase=phase, fir_window=fir_window)
elif method == 'spectrum_fit':
xf = _mt_spectrum_proc(x, Fs, freqs, notch_widths, mt_bandwidth,
p_value, picks, n_jobs, copy)
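
A usage sketch for the FIR notch path above, with the 0.13 defaults written out (assumes ``notch_filter`` remains importable from ``mne.filter``):

    import numpy as np
    from mne.filter import notch_filter

    rng = np.random.RandomState(0)
    x = rng.randn(4, 30000)  # 30 s of 4-channel data at 1 kHz
    # Remove 60 Hz line noise and its harmonics up to 240 Hz
    xf = notch_filter(x, Fs=1000., freqs=np.arange(60, 241, 60),
                      filter_length='10s', method='fir',
                      phase='zero', fir_window='hamming')
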
@@ -1271,7 +1705,7 @@ def resample(x, up, down, npad=100, axis=-1, window='boxcar', n_jobs=1,
axis : int
Axis along which to resample (default is the last axis).
window : string or tuple
- See scipy.signal.resample for description.
+ See :func:`scipy.signal.resample` for description.
n_jobs : int | str
Number of jobs to run in parallel. Can be 'cuda' if scikits.cuda
is installed properly and CUDA is initialized.
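
For reference, a minimal resampling sketch (hedged: assumes ``resample`` stays importable from ``mne.filter`` with the signature shown here):

    import numpy as np
    from mne.filter import resample

    x = np.random.RandomState(0).randn(2, 1000)  # 1 s at 1000 Hz
    y = resample(x, up=1., down=4.)              # -> 250 Hz, shape (2, 250)
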
@@ -1387,8 +1821,7 @@ def _resample_stim_channels(stim_data, up, down):
Parameters
----------
- stim_data : 1D array, shape (n_samples,) |
- 2D array, shape (n_stim_channels, n_samples)
+ stim_data : array, shape (n_samples,) or (n_stim_channels, n_samples)
Stim channels to resample.
up : float
Factor to upsample by.
@@ -1397,8 +1830,8 @@ def _resample_stim_channels(stim_data, up, down):
Returns
-------
- stim_resampled : 2D array, shape (n_stim_channels, n_samples_resampled)
- The resampled stim channels
+ stim_resampled : array, shape (n_stim_channels, n_samples_resampled)
+ The resampled stim channels.
Note
----
@@ -1479,45 +1912,170 @@ def detrend(x, order=1, axis=-1):
return y
-def _get_filter_length(filter_length, sfreq, min_length=128, len_x=np.inf):
- """Helper to determine a reasonable filter length"""
- if not isinstance(min_length, int):
- raise ValueError('min_length must be an int')
+def _triage_filter_params(x, sfreq, l_freq, h_freq,
+ l_trans_bandwidth, h_trans_bandwidth,
+ filter_length, method, phase, fir_window,
+ bands='scalar', reverse=False):
+ """Helper to validate and automate filter parameter selection"""
+ dep = list()
+ if not isinstance(phase, string_types) or phase not in \
+ ('linear', 'zero', 'zero-double', ''):
+ raise ValueError('phase must be "linear", "zero", or "zero-double", '
+ 'got "%s"' % (phase,))
+ if not isinstance(fir_window, string_types) or fir_window not in \
+ ('hann', 'hamming', 'blackman', ''):
+        raise ValueError('fir_window must be "hamming", "hann", or "blackman", '
+                         'got "%s"' % (fir_window,))
+ if phase == '':
+ if method == 'fir':
+ dep += ['phase in 0.13 is "zero-double" but will change to '
+ '"zero" in 0.14']
+ phase = 'zero-double'
+ if fir_window == '':
+ if method == 'fir':
+ dep += ['fir_window in 0.13 is "hann" but will change to '
+ '"hamming" in 0.14']
+ fir_window = 'hann'
+
+ def float_array(c):
+ return np.array(c, float).ravel()
+
+ if bands == 'arr':
+ cast = float_array
+ else:
+ cast = float
+ x = np.asanyarray(x)
+ len_x = x.shape[-1]
+ sfreq = float(sfreq)
+ if l_freq is not None:
+ l_freq = cast(l_freq)
+ if np.any(l_freq <= 0):
+ raise ValueError('highpass frequency %s must be greater than zero'
+ % (l_freq,))
+ if h_freq is not None:
+ h_freq = cast(h_freq)
+ if np.any(h_freq >= sfreq / 2.):
+ raise ValueError('lowpass frequency %s must be less than Nyquist '
+ '(%s)' % (h_freq, sfreq / 2.))
+ if method == 'iir':
+ # Ignore these parameters, effectively
+ l_stop, h_stop = l_freq, h_freq
+ else: # method == 'fir'
+ l_stop = h_stop = None
+ if l_freq is not None: # high-pass component
+ if isinstance(l_trans_bandwidth, string_types):
+ if l_trans_bandwidth != 'auto':
+ raise ValueError('l_trans_bandwidth must be "auto" if '
+ 'string, got "%s"' % l_trans_bandwidth)
+ l_trans_bandwidth = np.minimum(np.maximum(0.25 * l_freq, 2.),
+ l_freq)
+ logger.info('l_trans_bandwidth chosen to be %0.1f Hz'
+ % (l_trans_bandwidth,))
+ elif l_trans_bandwidth is None:
+ dep += ['lower transition bandwidth in 0.13 is 0.5 Hz but '
+ 'will change to "auto" in 0.14']
+ l_trans_bandwidth = 0.5
+ l_trans_bandwidth = cast(l_trans_bandwidth)
+ if np.any(l_trans_bandwidth <= 0):
+ raise ValueError('l_trans_bandwidth must be positive, got %s'
+ % (l_trans_bandwidth,))
+ l_stop = l_freq - l_trans_bandwidth
+ if reverse: # band-stop style
+ l_stop += l_trans_bandwidth
+ l_freq += l_trans_bandwidth
+ if np.any(l_stop < 0):
+ raise ValueError('Filter specification invalid: Lower stop '
+ 'frequency negative (%0.1fHz). Increase pass '
+ 'frequency or reduce the transition '
+ 'bandwidth (l_trans_bandwidth)' % l_stop)
+ if h_freq is not None: # low-pass component
+ if isinstance(h_trans_bandwidth, string_types):
+ if h_trans_bandwidth != 'auto':
+ raise ValueError('h_trans_bandwidth must be "auto" if '
+ 'string, got "%s"' % h_trans_bandwidth)
+ h_trans_bandwidth = np.minimum(np.maximum(0.25 * h_freq, 2.),
+ sfreq / 2. - h_freq)
+ logger.info('h_trans_bandwidth chosen to be %0.1f Hz'
+ % (h_trans_bandwidth))
+ elif h_trans_bandwidth is None:
+ dep += ['upper transition bandwidth in 0.13 is 0.5 Hz but '
+ 'will change to "auto" in 0.14']
+ h_trans_bandwidth = 0.5
+ h_trans_bandwidth = cast(h_trans_bandwidth)
+ if np.any(h_trans_bandwidth <= 0):
+ raise ValueError('h_trans_bandwidth must be positive, got %s'
+ % (h_trans_bandwidth,))
+ h_stop = h_freq + h_trans_bandwidth
+ if reverse: # band-stop style
+ h_stop -= h_trans_bandwidth
+ h_freq -= h_trans_bandwidth
+ if np.any(h_stop > sfreq / 2):
+ raise ValueError('Effective band-stop frequency (%s) is too '
+ 'high (maximum based on Nyquist is %s)'
+ % (h_stop, sfreq / 2.))
if isinstance(filter_length, string_types):
- # parse time values
- if filter_length[-2:].lower() == 'ms':
- mult_fact = 1e-3
- filter_length = filter_length[:-2]
- elif filter_length[-1].lower() == 's':
- mult_fact = 1
- filter_length = filter_length[:-1]
+ filter_length = filter_length.lower()
+ if filter_length == '':
+ if method == 'fir':
+ dep += ['The default filter length in 0.13 is "10s" but will '
+ 'change to "auto" in 0.14']
+ filter_length = '10s'
+ if filter_length == 'auto':
+ h_check = h_trans_bandwidth if h_freq is not None else np.inf
+ l_check = l_trans_bandwidth if l_freq is not None else np.inf
+ filter_length = max(int(round(
+ _length_factors[fir_window] * sfreq /
+ float(min(h_check, l_check)))), 1)
+ logger.info('Filter length of %s samples (%0.3f sec) selected'
+ % (filter_length, filter_length / sfreq))
else:
- raise ValueError('filter_length, if a string, must be a '
- 'human-readable time (e.g., "10s"), not '
- '"%s"' % filter_length)
- # now get the number
- try:
- filter_length = float(filter_length)
- except ValueError:
- raise ValueError('filter_length, if a string, must be a '
- 'human-readable time (e.g., "10s"), not '
- '"%s"' % filter_length)
- filter_length = 2 ** int(np.ceil(np.log2(filter_length *
- mult_fact * sfreq)))
- # shouldn't make filter longer than length of x
- if filter_length >= len_x:
- filter_length = len_x
- # only need to check min_length if the filter is shorter than len_x
- elif filter_length < min_length:
- filter_length = min_length
- warn('filter_length was too short, using filter of length %d '
- 'samples ("%0.1fs")'
- % (filter_length, filter_length / float(sfreq)))
-
- if filter_length is not None:
- if not isinstance(filter_length, integer_types):
- raise ValueError('filter_length must be str, int, or None')
- return filter_length
+ err_msg = ('filter_length, if a string, must be a human-readable '
+ 'time, e.g. "10s", or "auto", not "%s"' % filter_length)
+ if filter_length.lower().endswith('ms'):
+ mult_fact = 1e-3
+ filter_length = filter_length[:-2]
+ elif filter_length[-1].lower() == 's':
+ mult_fact = 1
+ filter_length = filter_length[:-1]
+ else:
+ raise ValueError(err_msg)
+ # now get the number
+ try:
+ filter_length = float(filter_length)
+ except ValueError:
+ raise ValueError(err_msg)
+ if phase == 'zero-double': # old mode
+ filter_length = 2 ** int(np.ceil(np.log2(
+ filter_length * mult_fact * sfreq)))
+ else:
+ filter_length = max(int(np.ceil(filter_length * mult_fact *
+ sfreq)), 1)
+ elif filter_length is None:
+ filter_length = len_x
+ if phase == 'zero':
+ filter_length -= (filter_length % 2 == 0)
+ dep += ['filter_length=None has been deprecated, set the filter '
+ 'length using an integer or string']
+ elif not isinstance(filter_length, integer_types):
+ raise ValueError('filter_length must be a str, int, or None, got %s'
+ % (type(filter_length),))
+ if phase == 'zero':
+ filter_length += (filter_length % 2 == 0)
+ if method != 'fir':
+ filter_length = len_x
+ if filter_length <= 0:
+ raise ValueError('filter_length must be positive, got %s'
+ % (filter_length,))
+ if filter_length > len_x:
+ warn('filter_length (%s) is longer than the signal (%s), '
+ 'distortion is likely. Reduce filter length or filter a '
+ 'longer signal.' % (filter_length, len_x))
+ logger.debug('Using filter length: %s' % filter_length)
+ if len(dep) > 0:
+ warn(('Multiple deprecated filter parameters were used:\n'
+ if len(dep) > 1 else '') + '\n'.join(dep), DeprecationWarning)
+ return (x, sfreq, l_freq, h_freq, l_stop, h_stop, filter_length, phase,
+ fir_window)
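
The ``'auto'`` length rule above boils down to ``length_factor * sfreq / transition_bandwidth``. A standalone sketch (the 6.6 factor for the Hamming window comes from the ``filter_length`` docstrings earlier in this diff; ``_length_factors`` is assumed to be a window-to-factor mapping defined elsewhere in ``mne/filter.py``):

    def auto_fir_length(sfreq, trans_bandwidth, length_factor=6.6):
        """Replicate the 'auto' rule: ~6.6 / transition bandwidth seconds."""
        return max(int(round(length_factor * sfreq / float(trans_bandwidth))), 1)

    auto_fir_length(1000., 2.)  # -> 3300 samples (3.3 s) for a 2 Hz transition
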
class FilterMixin(object):
@@ -1596,3 +2154,81 @@ class FilterMixin(object):
data[...] = savgol_filter(data, axis=axis, polyorder=5,
window_length=window_length)
return inst
+
+
+@verbose
+def design_mne_c_filter(sfreq, l_freq=None, h_freq=40.,
+ l_trans_bandwidth=None, h_trans_bandwidth=5.,
+ verbose=None):
+ """Create a FIR filter like that used by MNE-C
+
+ Parameters
+ ----------
+ sfreq : float
+ The sample frequency.
+ l_freq : float | None
+ The low filter frequency in Hz, default None.
+ Can be None to avoid high-passing.
+ h_freq : float
+ The high filter frequency in Hz, default 40.
+ Can be None to avoid low-passing.
+ l_trans_bandwidth : float | None
+ Low transition bandwidth in Hz. Can be None (default) to use 3 samples.
+ h_trans_bandwidth : float
+ High transition bandwidth in Hz.
+ verbose : bool, str, int, or None
+ If not None, override default verbose level (see mne.verbose).
+ Defaults to self.verbose.
+
+ Returns
+ -------
+ h : ndarray, shape (8193,)
+ The linear-phase (symmetric) FIR filter coefficients.
+
+ Notes
+ -----
+ This function is provided mostly for reference purposes.
+
+ MNE-C uses a frequency-domain filter design technique by creating a
+ linear-phase filter of length 8193. In the frequency domain, the
+ 4097 frequencies are directly constructed, with zeroes in the stop-band
+ and ones in the pass-band, with squared cosine ramps in between.
+ """
+ n_freqs = (4096 + 2 * 2048) // 2 + 1
+ freq_resp = np.ones(n_freqs)
+ l_freq = 0 if l_freq is None else float(l_freq)
+ if l_trans_bandwidth is None:
+ l_width = 3
+ else:
+ l_width = (int(((n_freqs - 1) * l_trans_bandwidth) /
+ (0.5 * sfreq)) + 1) // 2
+ l_start = int(((n_freqs - 1) * l_freq) / (0.5 * sfreq))
+ h_freq = sfreq / 2. if h_freq is None else float(h_freq)
+ h_width = (int(((n_freqs - 1) * h_trans_bandwidth) /
+ (0.5 * sfreq)) + 1) // 2
+ h_start = int(((n_freqs - 1) * h_freq) / (0.5 * sfreq))
+ logger.info('filter : %7.3f ... %6.1f Hz bins : %d ... %d of %d '
+ 'hpw : %d lpw : %d' % (l_freq, h_freq, l_start, h_start,
+ n_freqs, l_width, h_width))
+ if l_freq > 0:
+ start = l_start - l_width + 1
+ stop = start + 2 * l_width - 1
+ if start < 0 or stop >= n_freqs:
+ raise RuntimeError('l_freq too low or l_trans_bandwidth too large')
+ freq_resp[:start] = 0.
+ k = np.arange(-l_width + 1, l_width) / float(l_width) + 3.
+ freq_resp[start:stop] = np.cos(np.pi / 4. * k) ** 2
+
+ if h_freq < sfreq / 2.:
+ start = h_start - h_width + 1
+ stop = start + 2 * h_width - 1
+ if start < 0 or stop >= n_freqs:
+ raise RuntimeError('h_freq too high or h_trans_bandwidth too '
+ 'large')
+ k = np.arange(-h_width + 1, h_width) / float(h_width) + 1.
+ freq_resp[start:stop] *= np.cos(np.pi / 4. * k) ** 2
+ freq_resp[stop:] = 0.0
+ # Get the time-domain version of this signal
+ h = ifft(np.concatenate((freq_resp, freq_resp[::-1][:-1]))).real
+ h = np.roll(h, n_freqs - 1) # center the impulse like a linear-phase filt
+ return h
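
A quick sanity sketch for the function above (assuming it is importable from ``mne.filter``): the returned kernel has the documented length and, being linear-phase, is symmetric about its center:

    import numpy as np
    from mne.filter import design_mne_c_filter

    h = design_mne_c_filter(sfreq=1000., l_freq=1., h_freq=40.)
    assert h.shape == (8193,)
    assert np.allclose(h, h[::-1])  # symmetric impulse response
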
diff --git a/mne/fixes.py b/mne/fixes.py
index 399715e..7f98674 100644
--- a/mne/fixes.py
+++ b/mne/fixes.py
@@ -1,7 +1,7 @@
"""Compatibility fixes for older version of python, numpy and scipy
If you add content to this file, please give the version of the package
-at which the fixe is no longer needed.
+at which the fix is no longer needed.
# XXX : originally copied from scikit-learn
@@ -13,23 +13,13 @@ at which the fixe is no longer needed.
# License: BSD
from __future__ import division
-import collections
-from distutils.version import LooseVersion
-from functools import partial
-from gzip import GzipFile
+
import inspect
-from math import ceil, log
-from operator import itemgetter
import re
import warnings
import numpy as np
-from numpy.fft import irfft
-import scipy
-from scipy import linalg, sparse
-
-from .externals import six
-from .externals.six.moves import copyreg, xrange
+from scipy import linalg
###############################################################################
@@ -58,706 +48,239 @@ else:
return out[0]
-class gzip_open(GzipFile): # python2.6 doesn't have context managing
-
- def __enter__(self):
- if hasattr(GzipFile, '__enter__'):
- return GzipFile.__enter__(self)
- else:
- return self
-
- def __exit__(self, exc_type, exc_value, traceback):
- if hasattr(GzipFile, '__exit__'):
- return GzipFile.__exit__(self, exc_type, exc_value, traceback)
- else:
- return self.close()
-
-
-class _Counter(collections.defaultdict):
- """Partial replacement for Python 2.7 collections.Counter."""
- def __init__(self, iterable=(), **kwargs):
- super(_Counter, self).__init__(int, **kwargs)
- self.update(iterable)
-
- def most_common(self):
- return sorted(six.iteritems(self), key=itemgetter(1), reverse=True)
-
- def update(self, other):
- """Adds counts for elements in other"""
- if isinstance(other, self.__class__):
- for x, n in six.iteritems(other):
- self[x] += n
- else:
- for x in other:
- self[x] += 1
-
-try:
- Counter = collections.Counter
-except AttributeError:
- Counter = _Counter
-
-
-def _unique(ar, return_index=False, return_inverse=False):
- """A replacement for the np.unique that appeared in numpy 1.4.
-
- While np.unique existed long before, keyword return_inverse was
- only added in 1.4.
- """
+def _safe_svd(A, **kwargs):
+ """Wrapper to get around the SVD did not converge error of death"""
+ # Intel has a bug with their GESVD driver:
+ # https://software.intel.com/en-us/forums/intel-distribution-for-python/topic/628049 # noqa
+ # For SciPy 0.18 and up, we can work around it by using
+ # lapack_driver='gesvd' instead.
+ if kwargs.get('overwrite_a', False):
+ raise ValueError('Cannot set overwrite_a=True with this function')
try:
- ar = ar.flatten()
- except AttributeError:
- if not return_inverse and not return_index:
- items = sorted(set(ar))
- return np.asarray(items)
+ return linalg.svd(A, **kwargs)
+ except np.linalg.LinAlgError as exp:
+ from .utils import warn
+ if 'lapack_driver' in _get_args(linalg.svd):
+ warn('SVD error (%s), attempting to use GESVD instead of GESDD'
+ % (exp,))
+ return linalg.svd(A, lapack_driver='gesvd', **kwargs)
else:
- ar = np.asarray(ar).flatten()
+ raise
- if ar.size == 0:
- if return_inverse and return_index:
- return ar, np.empty(0, np.bool), np.empty(0, np.bool)
- elif return_inverse or return_index:
- return ar, np.empty(0, np.bool)
- else:
- return ar
-
- if return_inverse or return_index:
- perm = ar.argsort()
- aux = ar[perm]
- flag = np.concatenate(([True], aux[1:] != aux[:-1]))
- if return_inverse:
- iflag = np.cumsum(flag) - 1
- iperm = perm.argsort()
- if return_index:
- return aux[flag], perm[flag], iflag[iperm]
- else:
- return aux[flag], iflag[iperm]
- else:
- return aux[flag], perm[flag]
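
A sketch of the new SVD fallback in action (assumes the private helper stays importable from ``mne.fixes``); on a well-behaved matrix it acts exactly like ``scipy.linalg.svd``:

    import numpy as np
    from mne.fixes import _safe_svd

    A = np.random.RandomState(0).randn(10, 6)
    U, s, Vt = _safe_svd(A, full_matrices=False)
    assert np.allclose(np.dot(U * s, Vt), A)  # reconstruction holds
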
+###############################################################################
+# Back porting scipy.signal.sosfilt (0.17) and sosfiltfilt (0.18)
+
+
+def _sosfiltfilt(sos, x, axis=-1, padtype='odd', padlen=None):
+ """copy of SciPy sosfiltfilt"""
+ sos, n_sections = _validate_sos(sos)
+
+ # `method` is "pad"...
+ ntaps = 2 * n_sections + 1
+ ntaps -= min((sos[:, 2] == 0).sum(), (sos[:, 5] == 0).sum())
+ edge, ext = _validate_pad(padtype, padlen, x, axis,
+ ntaps=ntaps)
+
+ # These steps follow the same form as filtfilt with modifications
+ zi = sosfilt_zi(sos) # shape (n_sections, 2) --> (n_sections, ..., 2, ...)
+ zi_shape = [1] * x.ndim
+ zi_shape[axis] = 2
+ zi.shape = [n_sections] + zi_shape
+ x_0 = axis_slice(ext, stop=1, axis=axis)
+ (y, zf) = sosfilt(sos, ext, axis=axis, zi=zi * x_0)
+ y_0 = axis_slice(y, start=-1, axis=axis)
+ (y, zf) = sosfilt(sos, axis_reverse(y, axis=axis), axis=axis, zi=zi * y_0)
+ y = axis_reverse(y, axis=axis)
+ if edge > 0:
+ y = axis_slice(y, start=edge, stop=-edge, axis=axis)
+ return y
+
+
+def axis_slice(a, start=None, stop=None, step=None, axis=-1):
+ """Take a slice along axis 'axis' from 'a'"""
+ a_slice = [slice(None)] * a.ndim
+ a_slice[axis] = slice(start, stop, step)
+ b = a[a_slice]
+ return b
+
+
+def axis_reverse(a, axis=-1):
+ """Reverse the 1-d slices of `a` along axis `axis`."""
+ return axis_slice(a, step=-1, axis=axis)
+
+
+def _validate_pad(padtype, padlen, x, axis, ntaps):
+ """Helper to validate padding for filtfilt"""
+ if padtype not in ['even', 'odd', 'constant', None]:
+ raise ValueError(("Unknown value '%s' given to padtype. padtype "
+ "must be 'even', 'odd', 'constant', or None.") %
+ padtype)
+
+ if padtype is None:
+ padlen = 0
+
+ if padlen is None:
+ # Original padding; preserved for backwards compatibility.
+ edge = ntaps * 3
else:
- ar.sort()
- flag = np.concatenate(([True], ar[1:] != ar[:-1]))
- return ar[flag]
-
-if LooseVersion(np.__version__) < LooseVersion('1.5'):
- unique = _unique
-else:
- unique = np.unique
-
-
-def _bincount(X, weights=None, minlength=None):
- """Replacing np.bincount in numpy < 1.6 to provide minlength."""
- result = np.bincount(X, weights)
- if minlength is None or len(result) >= minlength:
- return result
- out = np.zeros(minlength, np.int)
- out[:len(result)] = result
- return out
-
-if LooseVersion(np.__version__) < LooseVersion('1.6'):
- bincount = _bincount
-else:
- bincount = np.bincount
-
-
-def _copysign(x1, x2):
- """Slow replacement for np.copysign, which was introduced in numpy 1.4"""
- return np.abs(x1) * np.sign(x2)
-
-if not hasattr(np, 'copysign'):
- copysign = _copysign
-else:
- copysign = np.copysign
-
-
-def _in1d(ar1, ar2, assume_unique=False, invert=False):
- """Replacement for in1d that is provided for numpy >= 1.4"""
- # Ravel both arrays, behavior for the first array could be different
- ar1 = np.asarray(ar1).ravel()
- ar2 = np.asarray(ar2).ravel()
-
- # This code is significantly faster when the condition is satisfied.
- if len(ar2) < 10 * len(ar1) ** 0.145:
- if invert:
- mask = np.ones(len(ar1), dtype=np.bool)
- for a in ar2:
- mask &= (ar1 != a)
+ edge = padlen
+
+ # x's 'axis' dimension must be bigger than edge.
+ if x.shape[axis] <= edge:
+ raise ValueError("The length of the input vector x must be at least "
+ "padlen, which is %d." % edge)
+
+ if padtype is not None and edge > 0:
+ # Make an extension of length `edge` at each
+ # end of the input array.
+ if padtype == 'even':
+ ext = even_ext(x, edge, axis=axis)
+ elif padtype == 'odd':
+ ext = odd_ext(x, edge, axis=axis)
else:
- mask = np.zeros(len(ar1), dtype=np.bool)
- for a in ar2:
- mask |= (ar1 == a)
- return mask
-
- # Otherwise use sorting
- if not assume_unique:
- ar1, rev_idx = unique(ar1, return_inverse=True)
- ar2 = np.unique(ar2)
-
- ar = np.concatenate((ar1, ar2))
- # We need this to be a stable sort, so always use 'mergesort'
- # here. The values from the first array should always come before
- # the values from the second array.
- order = ar.argsort(kind='mergesort')
- sar = ar[order]
- if invert:
- bool_ar = (sar[1:] != sar[:-1])
+ ext = const_ext(x, edge, axis=axis)
else:
- bool_ar = (sar[1:] == sar[:-1])
- flag = np.concatenate((bool_ar, [invert]))
- indx = order.argsort(kind='mergesort')[:len(ar1)]
-
- if assume_unique:
- return flag[indx]
- else:
- return flag[indx][rev_idx]
-
-
-if not hasattr(np, 'in1d') or LooseVersion(np.__version__) < '1.8':
- in1d = _in1d
-else:
- in1d = np.in1d
-
-
-def _digitize(x, bins, right=False):
- """Replacement for digitize with right kwarg (numpy < 1.7).
-
- Notes
- -----
- This fix is only meant for integer arrays. If ``right==True`` but either
- ``x`` or ``bins`` are of a different type, a NotImplementedError will be
- raised.
- """
- if right:
- x = np.asarray(x)
- bins = np.asarray(bins)
- if (x.dtype.kind not in 'ui') or (bins.dtype.kind not in 'ui'):
- raise NotImplementedError("Only implemented for integer input")
- return np.digitize(x - 1e-5, bins)
- else:
- return np.digitize(x, bins)
-
-if LooseVersion(np.__version__) < LooseVersion('1.7'):
- digitize = _digitize
-else:
- digitize = np.digitize
-
-
-def _tril_indices(n, k=0):
- """Replacement for tril_indices that is provided for numpy >= 1.4"""
- mask = np.greater_equal(np.subtract.outer(np.arange(n), np.arange(n)), -k)
- indices = np.where(mask)
-
- return indices
-
-if not hasattr(np, 'tril_indices'):
- tril_indices = _tril_indices
-else:
- tril_indices = np.tril_indices
-
-
-def _unravel_index(indices, dims):
- """Add support for multiple indices in unravel_index that is provided
- for numpy >= 1.4"""
- indices_arr = np.asarray(indices)
- if indices_arr.size == 1:
- return np.unravel_index(indices, dims)
- else:
- if indices_arr.ndim != 1:
- raise ValueError('indices should be one dimensional')
-
- ndims = len(dims)
- unraveled_coords = np.empty((indices_arr.size, ndims), dtype=np.int)
- for coord, idx in zip(unraveled_coords, indices_arr):
- coord[:] = np.unravel_index(idx, dims)
- return tuple(unraveled_coords.T)
-
-
-if LooseVersion(np.__version__) < LooseVersion('1.4'):
- unravel_index = _unravel_index
-else:
- unravel_index = np.unravel_index
-
-
-def _qr_economic_old(A, **kwargs):
- """
- Compat function for the QR-decomposition in economic mode
- Scipy 0.9 changed the keyword econ=True to mode='economic'
- """
- with warnings.catch_warnings(record=True):
- return linalg.qr(A, econ=True, **kwargs)
-
-
-def _qr_economic_new(A, **kwargs):
- return linalg.qr(A, mode='economic', **kwargs)
-
-
-if LooseVersion(scipy.__version__) < LooseVersion('0.9'):
- qr_economic = _qr_economic_old
-else:
- qr_economic = _qr_economic_new
-
-
-def savemat(file_name, mdict, oned_as="column", **kwargs):
- """MATLAB-format output routine that is compatible with SciPy 0.7's.
-
- 0.7.2 (or .1?) added the oned_as keyword arg with 'column' as the default
- value. It issues a warning if this is not provided, stating that "This will
- change to 'row' in future versions."
- """
- import scipy.io
- try:
- return scipy.io.savemat(file_name, mdict, oned_as=oned_as, **kwargs)
- except TypeError:
- return scipy.io.savemat(file_name, mdict, **kwargs)
-
-if hasattr(np, 'count_nonzero'):
- from numpy import count_nonzero
-else:
- def count_nonzero(X):
- return len(np.flatnonzero(X))
-
-# little dance to see if np.copy has an 'order' keyword argument
-if 'order' in _get_args(np.copy):
- def safe_copy(X):
- # Copy, but keep the order
- return np.copy(X, order='K')
-else:
- # Before an 'order' argument was introduced, numpy wouldn't muck with
- # the ordering
- safe_copy = np.copy
-
-
-def _meshgrid(*xi, **kwargs):
- """
- Return coordinate matrices from coordinate vectors.
- Make N-D coordinate arrays for vectorized evaluations of
- N-D scalar/vector fields over N-D grids, given
- one-dimensional coordinate arrays x1, x2,..., xn.
- .. versionchanged:: 1.9
- 1-D and 0-D cases are allowed.
- Parameters
- ----------
- x1, x2,..., xn : array_like
- 1-D arrays representing the coordinates of a grid.
- indexing : {'xy', 'ij'}, optional
- Cartesian ('xy', default) or matrix ('ij') indexing of output.
- See Notes for more details.
- .. versionadded:: 1.7.0
- sparse : bool, optional
- If True a sparse grid is returned in order to conserve memory.
- Default is False.
- .. versionadded:: 1.7.0
- copy : bool, optional
- If False, a view into the original arrays are returned in order to
- conserve memory. Default is True. Please note that
- ``sparse=False, copy=False`` will likely return non-contiguous
- arrays. Furthermore, more than one element of a broadcast array
- may refer to a single memory location. If you need to write to the
- arrays, make copies first.
- .. versionadded:: 1.7.0
- Returns
- -------
- X1, X2,..., XN : ndarray
- For vectors `x1`, `x2`,..., 'xn' with lengths ``Ni=len(xi)`` ,
- return ``(N1, N2, N3,...Nn)`` shaped arrays if indexing='ij'
- or ``(N2, N1, N3,...Nn)`` shaped arrays if indexing='xy'
- with the elements of `xi` repeated to fill the matrix along
- the first dimension for `x1`, the second for `x2` and so on.
- """
- ndim = len(xi)
-
- copy_ = kwargs.pop('copy', True)
- sparse = kwargs.pop('sparse', False)
- indexing = kwargs.pop('indexing', 'xy')
-
- if kwargs:
- raise TypeError("meshgrid() got an unexpected keyword argument '%s'"
- % (list(kwargs)[0],))
-
- if indexing not in ['xy', 'ij']:
- raise ValueError(
- "Valid values for `indexing` are 'xy' and 'ij'.")
-
- s0 = (1,) * ndim
- output = [np.asanyarray(x).reshape(s0[:i] + (-1,) + s0[i + 1::])
- for i, x in enumerate(xi)]
-
- shape = [x.size for x in output]
-
- if indexing == 'xy' and ndim > 1:
- # switch first and second axis
- output[0].shape = (1, -1) + (1,) * (ndim - 2)
- output[1].shape = (-1, 1) + (1,) * (ndim - 2)
- shape[0], shape[1] = shape[1], shape[0]
-
- if sparse:
- if copy_:
- return [x.copy() for x in output]
+ ext = x
+ return edge, ext
+
+
+def _validate_sos(sos):
+ """Helper to validate a SOS input"""
+ sos = np.atleast_2d(sos)
+ if sos.ndim != 2:
+ raise ValueError('sos array must be 2D')
+ n_sections, m = sos.shape
+ if m != 6:
+ raise ValueError('sos array must be shape (n_sections, 6)')
+ if not (sos[:, 3] == 1).all():
+ raise ValueError('sos[:, 3] should be all ones')
+ return sos, n_sections
+
+
+def odd_ext(x, n, axis=-1):
+ """Generate a new ndarray by making an odd extension of x along an axis."""
+ if n < 1:
+ return x
+ if n > x.shape[axis] - 1:
+ raise ValueError(("The extension length n (%d) is too big. " +
+ "It must not exceed x.shape[axis]-1, which is %d.")
+ % (n, x.shape[axis] - 1))
+ left_end = axis_slice(x, start=0, stop=1, axis=axis)
+ left_ext = axis_slice(x, start=n, stop=0, step=-1, axis=axis)
+ right_end = axis_slice(x, start=-1, axis=axis)
+ right_ext = axis_slice(x, start=-2, stop=-(n + 2), step=-1, axis=axis)
+ ext = np.concatenate((2 * left_end - left_ext,
+ x,
+ 2 * right_end - right_ext),
+ axis=axis)
+ return ext
+
+
+def even_ext(x, n, axis=-1):
+ """Create an ndarray that is an even extension of x along an axis."""
+ if n < 1:
+ return x
+ if n > x.shape[axis] - 1:
+ raise ValueError(("The extension length n (%d) is too big. " +
+ "It must not exceed x.shape[axis]-1, which is %d.")
+ % (n, x.shape[axis] - 1))
+ left_ext = axis_slice(x, start=n, stop=0, step=-1, axis=axis)
+ right_ext = axis_slice(x, start=-2, stop=-(n + 2), step=-1, axis=axis)
+ ext = np.concatenate((left_ext,
+ x,
+ right_ext),
+ axis=axis)
+ return ext
+
+
+def const_ext(x, n, axis=-1):
+ """Create an ndarray that is a constant extension of x along an axis"""
+ if n < 1:
+ return x
+ left_end = axis_slice(x, start=0, stop=1, axis=axis)
+ ones_shape = [1] * x.ndim
+ ones_shape[axis] = n
+ ones = np.ones(ones_shape, dtype=x.dtype)
+ left_ext = ones * left_end
+ right_end = axis_slice(x, start=-1, axis=axis)
+ right_ext = ones * right_end
+ ext = np.concatenate((left_ext,
+ x,
+ right_ext),
+ axis=axis)
+ return ext
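
The three padding helpers above mirror SciPy's; on a toy array their behavior is easy to see (values follow directly from the definitions above):

    import numpy as np

    x = np.array([1., 2., 3., 4.])
    odd_ext(x, 2)    # array([-1., 0., 1., 2., 3., 4., 5., 6.])
    even_ext(x, 2)   # array([ 3., 2., 1., 2., 3., 4., 3., 2.])
    const_ext(x, 2)  # array([ 1., 1., 1., 2., 3., 4., 4., 4.])
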
+
+
+def sosfilt_zi(sos):
+ """Compute an initial state `zi` for the sosfilt function"""
+ from scipy.signal import lfilter_zi
+ sos = np.asarray(sos)
+ if sos.ndim != 2 or sos.shape[1] != 6:
+ raise ValueError('sos must be shape (n_sections, 6)')
+
+ n_sections = sos.shape[0]
+ zi = np.empty((n_sections, 2))
+ scale = 1.0
+ for section in range(n_sections):
+ b = sos[section, :3]
+ a = sos[section, 3:]
+ zi[section] = scale * lfilter_zi(b, a)
+ # If H(z) = B(z)/A(z) is this section's transfer function, then
+ # b.sum()/a.sum() is H(1), the gain at omega=0. That's the steady
+ # state value of this section's step response.
+ scale *= b.sum() / a.sum()
+
+ return zi
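
``sosfilt_zi`` enables the usual steady-state trick: scale the returned state by the first sample so a step input produces no start-up transient. A sketch using the backported functions above (``butter(..., output='sos')`` needs SciPy >= 0.16):

    import numpy as np
    from scipy.signal import butter

    sos = butter(4, 0.1, output='sos')  # 4th-order lowpass as two sections
    x = np.ones(100)                    # unit step
    y, _ = sosfilt(sos, x, zi=sosfilt_zi(sos) * x[0])
    assert np.allclose(y, 1.)           # filter starts in steady state
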
+
+
+def sosfilt(sos, x, axis=-1, zi=None):
+ """Filter data along one dimension using cascaded second-order sections"""
+ from scipy.signal import lfilter
+ x = np.asarray(x)
+
+ sos = np.atleast_2d(sos)
+ if sos.ndim != 2:
+ raise ValueError('sos array must be 2D')
+
+ n_sections, m = sos.shape
+ if m != 6:
+ raise ValueError('sos array must be shape (n_sections, 6)')
+
+ use_zi = zi is not None
+ if use_zi:
+ zi = np.asarray(zi)
+ x_zi_shape = list(x.shape)
+ x_zi_shape[axis] = 2
+ x_zi_shape = tuple([n_sections] + x_zi_shape)
+ if zi.shape != x_zi_shape:
+ raise ValueError('Invalid zi shape. With axis=%r, an input with '
+ 'shape %r, and an sos array with %d sections, zi '
+ 'must have shape %r.' %
+ (axis, x.shape, n_sections, x_zi_shape))
+ zf = np.zeros_like(zi)
+
+ for section in range(n_sections):
+ if use_zi:
+ x, zf[section] = lfilter(sos[section, :3], sos[section, 3:],
+ x, axis, zi=zi[section])
else:
- return output
- else:
- # Return the full N-D matrix (not only the 1-D vector)
- if copy_:
- mult_fact = np.ones(shape, dtype=int)
- return [x * mult_fact for x in output]
- else:
- return np.broadcast_arrays(*output)
-
-if LooseVersion(np.__version__) < LooseVersion('1.7'):
- meshgrid = _meshgrid
-else:
- meshgrid = np.meshgrid
-
-
-###############################################################################
-# Back porting firwin2 for older scipy
-
-# Original version of firwin2 from scipy ticket #457, submitted by "tash".
-#
-# Rewritten by Warren Weckesser, 2010.
-
-
-def _firwin2(numtaps, freq, gain, nfreqs=None, window='hamming', nyq=1.0):
- """FIR filter design using the window method.
-
- From the given frequencies `freq` and corresponding gains `gain`,
- this function constructs an FIR filter with linear phase and
- (approximately) the given frequency response.
-
- Parameters
- ----------
- numtaps : int
- The number of taps in the FIR filter. `numtaps` must be less than
- `nfreqs`. If the gain at the Nyquist rate, `gain[-1]`, is not 0,
- then `numtaps` must be odd.
-
- freq : array-like, 1D
- The frequency sampling points. Typically 0.0 to 1.0 with 1.0 being
- Nyquist. The Nyquist frequency can be redefined with the argument
- `nyq`.
-
- The values in `freq` must be nondecreasing. A value can be repeated
- once to implement a discontinuity. The first value in `freq` must
- be 0, and the last value must be `nyq`.
-
- gain : array-like
- The filter gains at the frequency sampling points.
-
- nfreqs : int, optional
- The size of the interpolation mesh used to construct the filter.
- For most efficient behavior, this should be a power of 2 plus 1
- (e.g, 129, 257, etc). The default is one more than the smallest
- power of 2 that is not less than `numtaps`. `nfreqs` must be greater
- than `numtaps`.
-
- window : string or (string, float) or float, or None, optional
- Window function to use. Default is "hamming". See
- `scipy.signal.get_window` for the complete list of possible values.
- If None, no window function is applied.
-
- nyq : float
- Nyquist frequency. Each frequency in `freq` must be between 0 and
- `nyq` (inclusive).
-
- Returns
- -------
- taps : numpy 1D array of length `numtaps`
- The filter coefficients of the FIR filter.
-
- Examples
- --------
- A lowpass FIR filter with a response that is 1 on [0.0, 0.5], and
- that decreases linearly on [0.5, 1.0] from 1 to 0:
-
- >>> taps = firwin2(150, [0.0, 0.5, 1.0], [1.0, 1.0, 0.0]) # doctest: +SKIP
- >>> print(taps[72:78]) # doctest: +SKIP
- [-0.02286961 -0.06362756 0.57310236 0.57310236 -0.06362756 -0.02286961]
-
- See also
- --------
- scipy.signal.firwin
-
- Notes
- -----
-
- From the given set of frequencies and gains, the desired response is
- constructed in the frequency domain. The inverse FFT is applied to the
- desired response to create the associated convolution kernel, and the
- first `numtaps` coefficients of this kernel, scaled by `window`, are
- returned.
-
- The FIR filter will have linear phase. The filter is Type I if `numtaps`
- is odd and Type II if `numtaps` is even. Because Type II filters always
- have a zero at the Nyquist frequency, `numtaps` must be odd if `gain[-1]`
- is not zero.
-
- .. versionadded:: 0.9.0
-
- References
- ----------
- .. [1] Oppenheim, A. V. and Schafer, R. W., "Discrete-Time Signal
- Processing", Prentice-Hall, Englewood Cliffs, New Jersey (1989).
- (See, for example, Section 7.4.)
-
- .. [2] Smith, Steven W., "The Scientist and Engineer's Guide to Digital
- Signal Processing", Ch. 17. http://www.dspguide.com/ch17/1.htm
-
- """
-
- if len(freq) != len(gain):
- raise ValueError('freq and gain must be of same length.')
-
- if nfreqs is not None and numtaps >= nfreqs:
- raise ValueError('ntaps must be less than nfreqs, but firwin2 was '
- 'called with ntaps=%d and nfreqs=%s'
- % (numtaps, nfreqs))
-
- if freq[0] != 0 or freq[-1] != nyq:
- raise ValueError('freq must start with 0 and end with `nyq`.')
- d = np.diff(freq)
- if (d < 0).any():
- raise ValueError('The values in freq must be nondecreasing.')
- d2 = d[:-1] + d[1:]
- if (d2 == 0).any():
- raise ValueError('A value in freq must not occur more than twice.')
-
- if numtaps % 2 == 0 and gain[-1] != 0.0:
- raise ValueError("A filter with an even number of coefficients must "
- "have zero gain at the Nyquist rate.")
-
- if nfreqs is None:
- nfreqs = 1 + 2 ** int(ceil(log(numtaps, 2)))
-
- # Tweak any repeated values in freq so that interp works.
- eps = np.finfo(float).eps
- for k in range(len(freq)):
- if k < len(freq) - 1 and freq[k] == freq[k + 1]:
- freq[k] = freq[k] - eps
- freq[k + 1] = freq[k + 1] + eps
-
- # Linearly interpolate the desired response on a uniform mesh `x`.
- x = np.linspace(0.0, nyq, nfreqs)
- fx = np.interp(x, freq, gain)
-
- # Adjust the phases of the coefficients so that the first `ntaps` of the
- # inverse FFT are the desired filter coefficients.
- shift = np.exp(-(numtaps - 1) / 2. * 1.j * np.pi * x / nyq)
- fx2 = fx * shift
-
- # Use irfft to compute the inverse FFT.
- out_full = irfft(fx2)
-
- if window is not None:
- # Create the window to apply to the filter coefficients.
- from scipy.signal.signaltools import get_window
- wind = get_window(window, numtaps, fftbins=False)
- else:
- wind = 1
-
- # Keep only the first `numtaps` coefficients in `out`, and multiply by
- # the window.
- out = out_full[:numtaps] * wind
-
+ x = lfilter(sos[section, :3], sos[section, 3:], x, axis)
+ out = (x, zf) if use_zi else x
return out
-def get_firwin2():
- """Helper to get firwin2"""
- try:
- from scipy.signal import firwin2
- except ImportError:
- firwin2 = _firwin2
- return firwin2
-
-
-def _filtfilt(*args, **kwargs):
- """wrap filtfilt, excluding padding arguments"""
- from scipy.signal import filtfilt
- # cut out filter args
- if len(args) > 4:
- args = args[:4]
- if 'padlen' in kwargs:
- del kwargs['padlen']
- return filtfilt(*args, **kwargs)
-
-
-def get_filtfilt():
- """Helper to get filtfilt from scipy"""
- from scipy.signal import filtfilt
-
- if 'padlen' in _get_args(filtfilt):
- return filtfilt
-
- return _filtfilt
-
-
-def _get_argrelmax():
+def get_sosfiltfilt():
+ """Helper to get sosfiltfilt from scipy"""
try:
- from scipy.signal import argrelmax
+ from scipy.signal import sosfiltfilt
except ImportError:
- argrelmax = _argrelmax
- return argrelmax
-
-
-def _argrelmax(data, axis=0, order=1, mode='clip'):
- """Calculate the relative maxima of `data`.
-
- Parameters
- ----------
- data : ndarray
- Array in which to find the relative maxima.
- axis : int, optional
- Axis over which to select from `data`. Default is 0.
- order : int, optional
- How many points on each side to use for the comparison
- to consider ``comparator(n, n+x)`` to be True.
- mode : str, optional
- How the edges of the vector are treated.
- Available options are 'wrap' (wrap around) or 'clip' (treat overflow
- as the same as the last (or first) element).
- Default 'clip'. See `numpy.take`.
-
- Returns
- -------
- extrema : tuple of ndarrays
- Indices of the maxima in arrays of integers. ``extrema[k]`` is
- the array of indices of axis `k` of `data`. Note that the
- return value is a tuple even when `data` is one-dimensional.
- """
- comparator = np.greater
- if((int(order) != order) or (order < 1)):
- raise ValueError('Order must be an int >= 1')
- datalen = data.shape[axis]
- locs = np.arange(0, datalen)
- results = np.ones(data.shape, dtype=bool)
- main = data.take(locs, axis=axis, mode=mode)
- for shift in xrange(1, order + 1):
- plus = data.take(locs + shift, axis=axis, mode=mode)
- minus = data.take(locs - shift, axis=axis, mode=mode)
- results &= comparator(main, plus)
- results &= comparator(main, minus)
- if(~results.any()):
- return results
- return np.where(results)
+ sosfiltfilt = _sosfiltfilt
+ return sosfiltfilt
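
Putting the backport together, a zero-phase band-pass via ``get_sosfiltfilt`` (a sketch; assumes ``mne.fixes`` keeps exposing the helper):

    import numpy as np
    from scipy.signal import butter
    from mne.fixes import get_sosfiltfilt

    sosfiltfilt = get_sosfiltfilt()  # SciPy's own if >= 0.18, else the backport
    sos = butter(4, [0.02, 0.2], btype='bandpass', output='sos')
    x = np.random.RandomState(0).randn(3, 1000)
    y = sosfiltfilt(sos, x, axis=-1)  # forward-backward, zero phase
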
###############################################################################
-# Back porting matrix_rank for numpy < 1.7
-
-
-def _matrix_rank(M, tol=None):
- """ Return matrix rank of array using SVD method
-
- Rank of the array is the number of SVD singular values of the array that
- are greater than `tol`.
-
- Parameters
- ----------
- M : {(M,), (M, N)} array_like
- array of <=2 dimensions
- tol : {None, float}, optional
- threshold below which SVD values are considered zero. If `tol` is
- None, and ``S`` is an array with singular values for `M`, and
- ``eps`` is the epsilon value for datatype of ``S``, then `tol` is
- set to ``S.max() * max(M.shape) * eps``.
-
- Notes
- -----
- The default threshold to detect rank deficiency is a test on the magnitude
- of the singular values of `M`. By default, we identify singular values less
- than ``S.max() * max(M.shape) * eps`` as indicating rank deficiency (with
- the symbols defined above). This is the algorithm MATLAB uses [1]. It also
- appears in *Numerical recipes* in the discussion of SVD solutions for
- linear least squares [2].
-
- This default threshold is designed to detect rank deficiency accounting
- for the numerical errors of the SVD computation. Imagine that there is a
- column in `M` that is an exact (in floating point) linear combination of
- other columns in `M`. Computing the SVD on `M` will not produce a
- singular value exactly equal to 0 in general: any difference of the
- smallest SVD value from 0 will be caused by numerical imprecision in the
- calculation of the SVD. Our threshold for small SVD values takes this
- numerical imprecision into account, and the default threshold will detect
- such numerical rank deficiency. The threshold may declare a matrix `M`
- rank deficient even if the linear combination of some columns of `M` is
- not exactly equal to another column of `M` but only numerically very
- close to another column of `M`.
-
- We chose our default threshold because it is in wide use. Other
- thresholds are possible. For example, elsewhere in the 2007 edition of
- *Numerical recipes* there is an alternative threshold of ``S.max() *
- np.finfo(M.dtype).eps / 2. * np.sqrt(m + n + 1.)``. The authors describe
- this threshold as being based on "expected roundoff error" (p 71).
-
- The thresholds above deal with floating point roundoff error in the
- calculation of the SVD. However, you may have more information about the
- sources of error in `M` that would make you consider other tolerance
- values to detect *effective* rank deficiency. The most useful measure of
- the tolerance depends on the operations you intend to use on your matrix.
- For example, if your data come from uncertain measurements with
- uncertainties greater than floating point epsilon, choosing a tolerance
- near that uncertainty may be preferable. The tolerance may be absolute if
- the uncertainties are absolute rather than relative.
-
- References
- ----------
- .. [1] MATLAB reference documention, "Rank"
- http://www.mathworks.com/help/techdoc/ref/rank.html
- .. [2] W. H. Press, S. A. Teukolsky, W. T. Vetterling and B. P. Flannery,
- "Numerical Recipes (3rd edition)", Cambridge University Press, 2007,
- page 795.
-
- Examples
- --------
- >>> from numpy.linalg import matrix_rank
- >>> matrix_rank(np.eye(4)) # Full rank matrix
- 4
- >>> I=np.eye(4); I[-1,-1] = 0. # rank deficient matrix
- >>> matrix_rank(I)
- 3
- >>> matrix_rank(np.ones((4,))) # 1 dimension - rank 1 unless all 0
- 1
- >>> matrix_rank(np.zeros((4,)))
- 0
- """
- M = np.asarray(M)
- if M.ndim > 2:
- raise TypeError('array should have 2 or fewer dimensions')
- if M.ndim < 2:
- return np.int(not all(M == 0))
- S = np.linalg.svd(M, compute_uv=False)
- if tol is None:
- tol = S.max() * np.max(M.shape) * np.finfo(S.dtype).eps
- return np.sum(S > tol)
-
-if LooseVersion(np.__version__) > '1.7.1':
- from numpy.linalg import matrix_rank
-else:
- matrix_rank = _matrix_rank
-
-
-def _reconstruct_partial(func, args, kwargs):
- """Helper to pickle partial functions"""
- return partial(func, *args, **(kwargs or {}))
-
-
-def _reduce_partial(p):
- """Helper to pickle partial functions"""
- return _reconstruct_partial, (p.func, p.args, p.keywords)
-
-# This adds pickling functionality to older Python 2.6
-# Please always import partial from here.
-copyreg.pickle(partial, _reduce_partial)
-
-
-def normalize_colors(vmin, vmax, clip=False):
- """Helper to handle matplotlib API"""
- import matplotlib.pyplot as plt
- try:
- return plt.Normalize(vmin, vmax, clip=clip)
- except AttributeError:
- return plt.normalize(vmin, vmax, clip=clip)
-
+# Misc utilities
def assert_true(expr, msg='False is not True'):
"""Fake assert_true without message"""
@@ -852,139 +375,59 @@ def assert_raises_regex(exception_class, expected_regexp,
callable_obj, *args, **kwargs)
-def _sparse_block_diag(mats, format=None, dtype=None):
- """An implementation of scipy.sparse.block_diag since old versions of
- scipy don't have it. Forms a sparse matrix by stacking matrices in block
- diagonal form.
-
- Parameters
- ----------
- mats : list of matrices
- Input matrices.
- format : str, optional
- The sparse format of the result (e.g. "csr"). If not given, the
- matrix is returned in "coo" format.
- dtype : dtype specifier, optional
- The data-type of the output matrix. If not given, the dtype is
- determined from that of blocks.
-
- Returns
- -------
- res : sparse matrix
- """
- nmat = len(mats)
- rows = []
- for ia, a in enumerate(mats):
- row = [None] * nmat
- row[ia] = a
- rows.append(row)
- return sparse.bmat(rows, format=format, dtype=dtype)
-
-try:
- from scipy.sparse import block_diag as sparse_block_diag
-except Exception:
- sparse_block_diag = _sparse_block_diag
-
-
-def _isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):
- """
- Returns a boolean array where two arrays are element-wise equal within a
- tolerance.
-
- The tolerance values are positive, typically very small numbers. The
- relative difference (`rtol` * abs(`b`)) and the absolute difference
- `atol` are added together to compare against the absolute difference
- between `a` and `b`.
-
- Parameters
- ----------
- a, b : array_like
- Input arrays to compare.
- rtol : float
- The relative tolerance parameter (see Notes).
- atol : float
- The absolute tolerance parameter (see Notes).
- equal_nan : bool
- Whether to compare NaN's as equal. If True, NaN's in `a` will be
- considered equal to NaN's in `b` in the output array.
-
- Returns
- -------
- y : array_like
- Returns a boolean array of where `a` and `b` are equal within the
- given tolerance. If both `a` and `b` are scalars, returns a single
- boolean value.
-
- See Also
- --------
- allclose
-
- Notes
- -----
- .. versionadded:: 1.7.0
-
- For finite values, isclose uses the following equation to test whether
- two floating point values are equivalent.
-
- absolute(`a` - `b`) <= (`atol` + `rtol` * absolute(`b`))
-
- The above equation is not symmetric in `a` and `b`, so that
- `isclose(a, b)` might be different from `isclose(b, a)` in
- some rare cases.
-
- Examples
- --------
- >>> isclose([1e10,1e-7], [1.00001e10,1e-8])
- array([ True, False], dtype=bool)
- >>> isclose([1e10,1e-8], [1.00001e10,1e-9])
- array([ True, True], dtype=bool)
- >>> isclose([1e10,1e-8], [1.0001e10,1e-9])
- array([False, True], dtype=bool)
- >>> isclose([1.0, np.nan], [1.0, np.nan])
- array([ True, False], dtype=bool)
- >>> isclose([1.0, np.nan], [1.0, np.nan], equal_nan=True)
- array([ True, True], dtype=bool)
+def _read_volume_info(fobj):
+ """An implementation of nibabel.freesurfer.io._read_volume_info, since old
+ versions of nibabel (<=2.1.0) don't have it.
"""
- def within_tol(x, y, atol, rtol):
- with np.errstate(invalid='ignore'):
- result = np.less_equal(abs(x - y), atol + rtol * abs(y))
- if np.isscalar(a) and np.isscalar(b):
- result = bool(result)
- return result
-
- x = np.array(a, copy=False, subok=True, ndmin=1)
- y = np.array(b, copy=False, subok=True, ndmin=1)
-
- # Make sure y is an inexact type to avoid bad behavior on abs(MIN_INT).
- # This will cause casting of x later. Also, make sure to allow subclasses
- # (e.g., for numpy.ma).
- dt = np.core.multiarray.result_type(y, 1.)
- y = np.array(y, dtype=dt, copy=False, subok=True)
-
- xfin = np.isfinite(x)
- yfin = np.isfinite(y)
- if np.all(xfin) and np.all(yfin):
- return within_tol(x, y, atol, rtol)
- else:
- finite = xfin & yfin
- cond = np.zeros_like(finite, subok=True)
- # Because we're using boolean indexing, x & y must be the same shape.
- # Ideally, we'd just do x, y = broadcast_arrays(x, y). It's in
- # lib.stride_tricks, though, so we can't import it here.
- x = x * np.ones_like(cond)
- y = y * np.ones_like(cond)
- # Avoid subtraction with infinite/nan values...
- cond[finite] = within_tol(x[finite], y[finite], atol, rtol)
- # Check for equality of infinite values...
- cond[~finite] = (x[~finite] == y[~finite])
- if equal_nan:
- # Make NaN == NaN
- both_nan = np.isnan(x) & np.isnan(y)
- cond[both_nan] = both_nan[both_nan]
- return cond
-
-
-if LooseVersion(np.__version__) < LooseVersion('1.7'):
- isclose = _isclose
-else:
- isclose = np.isclose
+ volume_info = dict()
+ head = np.fromfile(fobj, '>i4', 1)
+    if not np.array_equal(head, [20]):  # Read two more int32 values
+ head = np.concatenate([head, np.fromfile(fobj, '>i4', 2)])
+ if not np.array_equal(head, [2, 0, 20]):
+ warnings.warn("Unknown extension code.")
+ return volume_info
+
+ volume_info['head'] = head
+ for key in ['valid', 'filename', 'volume', 'voxelsize', 'xras', 'yras',
+ 'zras', 'cras']:
+ pair = fobj.readline().decode('utf-8').split('=')
+ if pair[0].strip() != key or len(pair) != 2:
+ raise IOError('Error parsing volume info.')
+ if key in ('valid', 'filename'):
+ volume_info[key] = pair[1].strip()
+ elif key == 'volume':
+ volume_info[key] = np.array(pair[1].split()).astype(int)
+ else:
+ volume_info[key] = np.array(pair[1].split()).astype(float)
+ # Ignore the rest
+ return volume_info
+
+
+def _serialize_volume_info(volume_info):
+ """An implementation of nibabel.freesurfer.io._serialize_volume_info, since
+ old versions of nibabel (<=2.1.0) don't have it."""
+ keys = ['head', 'valid', 'filename', 'volume', 'voxelsize', 'xras', 'yras',
+ 'zras', 'cras']
+ diff = set(volume_info.keys()).difference(keys)
+ if len(diff) > 0:
+ raise ValueError('Invalid volume info: %s.' % diff.pop())
+
+ strings = list()
+ for key in keys:
+ if key == 'head':
+ if not (np.array_equal(volume_info[key], [20]) or np.array_equal(
+ volume_info[key], [2, 0, 20])):
+ warnings.warn("Unknown extension code.")
+ strings.append(np.array(volume_info[key], dtype='>i4').tostring())
+ elif key in ('valid', 'filename'):
+ val = volume_info[key]
+ strings.append('{0} = {1}\n'.format(key, val).encode('utf-8'))
+ elif key == 'volume':
+ val = volume_info[key]
+ strings.append('{0} = {1} {2} {3}\n'.format(
+ key, val[0], val[1], val[2]).encode('utf-8'))
+ else:
+ val = volume_info[key]
+ strings.append('{0} = {1:0.10g} {2:0.10g} {3:0.10g}\n'.format(
+ key.ljust(6), val[0], val[1], val[2]).encode('utf-8'))
+ return b''.join(strings)
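
The two helpers round-trip: serializing a volume-info dict and reading it back recovers the values. A sketch (the filename and geometry values are illustrative only; a real temporary file is used because ``np.fromfile`` does not accept ``BytesIO``):

    import tempfile
    import numpy as np
    from mne.fixes import _read_volume_info, _serialize_volume_info

    info = dict(head=[20], valid='1', filename='T1.mgz',
                volume=[256, 256, 256], voxelsize=[1., 1., 1.],
                xras=[-1., 0., 0.], yras=[0., 0., -1.],
                zras=[0., 1., 0.], cras=[0., 0., 0.])
    with tempfile.TemporaryFile() as f:
        f.write(_serialize_volume_info(info))
        f.seek(0)
        back = _read_volume_info(f)
    assert np.array_equal(back['volume'], info['volume'])
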
diff --git a/mne/forward/__init__.py b/mne/forward/__init__.py
index 1f8b21c..c413e39 100644
--- a/mne/forward/__init__.py
+++ b/mne/forward/__init__.py
@@ -4,7 +4,7 @@ from .forward import (Forward, read_forward_solution, write_forward_solution,
compute_orient_prior, compute_depth_prior,
apply_forward, apply_forward_raw,
restrict_forward_to_stc, restrict_forward_to_label,
- do_forward_solution, average_forward_solutions,
+ average_forward_solutions,
_restrict_gain_matrix, _stc_src_sel,
_fill_measurement_info, _apply_forward,
_subject_from_forward, convert_forward_solution,
diff --git a/mne/forward/_compute_forward.py b/mne/forward/_compute_forward.py
index f6bdab9..843e79c 100644
--- a/mne/forward/_compute_forward.py
+++ b/mne/forward/_compute_forward.py
@@ -733,7 +733,7 @@ def _prep_field_computation(rr, bem, fwd_data, n_jobs, verbose=None):
# Compute solution for EEG sensor
solution = _bem_specify_els(bem, coils, mults)
else:
- solution = bem
+ solution = csolution = bem
if coil_type == 'eeg':
logger.info('Using the equivalent source approach in the '
'homogeneous sphere for EEG')
diff --git a/mne/forward/_field_interpolation.py b/mne/forward/_field_interpolation.py
index 21fff15..71e868f 100644
--- a/mne/forward/_field_interpolation.py
+++ b/mne/forward/_field_interpolation.py
@@ -1,8 +1,10 @@
# -*- coding: utf-8 -*-
+from copy import deepcopy
+from functools import partial
+
import numpy as np
from scipy import linalg
-from copy import deepcopy
from ..bem import _check_origin
from ..io.constants import FIFF
@@ -17,7 +19,6 @@ from ._lead_dots import (_do_self_dots, _do_surface_dots, _get_legen_table,
_do_cross_dots)
from ..parallel import check_n_jobs
from ..utils import logger, verbose
-from ..fixes import partial
def _is_axial_coil(coil):
diff --git a/mne/forward/_make_forward.py b/mne/forward/_make_forward.py
index 145d06a..b4521e3 100644
--- a/mne/forward/_make_forward.py
+++ b/mne/forward/_make_forward.py
@@ -26,7 +26,6 @@ from ..externals.six import string_types
from .forward import (Forward, write_forward_solution, _merge_meg_eeg_fwds,
convert_forward_solution)
from ._compute_forward import _compute_forwards
-from ..fixes import in1d
_accuracy_dict = dict(normal=FIFF.FWD_COIL_ACCURACY_NORMAL,
@@ -718,7 +717,7 @@ def make_forward_dipole(dipole, bem, info, trans=None, n_jobs=1, verbose=None):
data = np.zeros((len(amplitude), len(timepoints))) # (n_d, n_t)
row = 0
for tpind, tp in enumerate(timepoints):
- amp = amplitude[in1d(times, tp)]
+ amp = amplitude[np.in1d(times, tp)]
data[row:row + len(amp), tpind] = amp
row += len(amp)
diff --git a/mne/forward/forward.py b/mne/forward/forward.py
index c2f2fd1..cba55f5 100644
--- a/mne/forward/forward.py
+++ b/mne/forward/forward.py
@@ -17,7 +17,6 @@ from os import path as op
import tempfile
from ..externals.six import string_types
-from ..fixes import sparse_block_diag
from ..io import RawArray, Info
from ..io.constants import FIFF
from ..io.open import fiff_open
@@ -41,8 +40,7 @@ from ..source_estimate import VolSourceEstimate
from ..transforms import (transform_surface_to, invert_transform,
write_trans)
from ..utils import (_check_fname, get_subjects_dir, has_mne_c, warn,
- run_subprocess, check_fname, logger, verbose,
- deprecated)
+ run_subprocess, check_fname, logger, verbose, deprecated)
from ..label import Label
@@ -58,9 +56,9 @@ class Forward(dict):
entr = '<Forward'
- nchan = len(pick_types(self['info'], meg=True, eeg=False))
+ nchan = len(pick_types(self['info'], meg=True, eeg=False, exclude=[]))
entr += ' | ' + 'MEG channels: %d' % nchan
- nchan = len(pick_types(self['info'], meg=False, eeg=True))
+ nchan = len(pick_types(self['info'], meg=False, eeg=True, exclude=[]))
entr += ' | ' + 'EEG channels: %d' % nchan
src_types = np.array([src['type'] for src in self['src']])
@@ -98,6 +96,8 @@ class Forward(dict):
return entr
+@deprecated("it will be removed in mne 0.14; use mne.make_bem_solution() "
+ "instead.")
def prepare_bem_model(bem, sol_fname=None, method='linear'):
"""Wrapper for the mne_prepare_bem_model command line utility
@@ -576,7 +576,7 @@ def convert_forward_solution(fwd, surf_ori=False, force_fixed=False,
Parameters
----------
- fwd : dict
+ fwd : Forward
The forward solution to modify.
surf_ori : bool, optional (default False)
Use surface-based source coordinate system? Note that force_fixed=True
@@ -590,7 +590,7 @@ def convert_forward_solution(fwd, surf_ori=False, force_fixed=False,
Returns
-------
- fwd : dict
+ fwd : Forward
The modified forward solution.
"""
fwd = fwd.copy() if copy else fwd
@@ -620,7 +620,7 @@ def convert_forward_solution(fwd, surf_ori=False, force_fixed=False,
fwd['source_ori'] = FIFF.FIFFV_MNE_FIXED_ORI
if fwd['sol_grad'] is not None:
- x = sparse_block_diag([fix_rot] * 3)
+ x = sparse.block_diag([fix_rot] * 3)
fwd['sol_grad']['data'] = fwd['_orig_sol_grad'] * x # dot prod
fwd['sol_grad']['ncol'] = 3 * fwd['nsource']
logger.info(' [done]')
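
`scipy.sparse.block_diag` supersedes the `mne.fixes.sparse_block_diag` shim; it
stacks the given matrices along the diagonal of a sparse result. A minimal sketch:

    import numpy as np
    from scipy import sparse

    rot = np.eye(2)
    x = sparse.block_diag([rot] * 3)  # sparse 6 x 6 block-diagonal matrix
    print(x.shape)  # (6, 6)
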
@@ -663,7 +663,7 @@ def convert_forward_solution(fwd, surf_ori=False, force_fixed=False,
fwd['sol']['data'] = fwd['_orig_sol'] * surf_rot
fwd['sol']['ncol'] = 3 * fwd['nsource']
if fwd['sol_grad'] is not None:
- x = sparse_block_diag([surf_rot] * 3)
+ x = sparse.block_diag([surf_rot] * 3)
fwd['sol_grad']['data'] = fwd['_orig_sol_grad'] * x # dot prod
fwd['sol_grad']['ncol'] = 3 * fwd['nsource']
logger.info('[done]')
@@ -693,7 +693,7 @@ def write_forward_solution(fname, fwd, overwrite=False, verbose=None):
fname : str
File name to save the forward solution to. It should end with -fwd.fif
or -fwd.fif.gz.
- fwd : dict
+ fwd : Forward
Forward solution.
overwrite : bool
If True, overwrite destination file (if it exists).
@@ -773,7 +773,7 @@ def write_forward_solution(fname, fwd, overwrite=False, verbose=None):
inv_rot = _inv_block_diag(fwd['source_nn'].T, 3)
sol = sol * inv_rot
if sol_grad is not None:
- sol_grad = sol_grad * sparse_block_diag([inv_rot] * 3) # dot prod
+ sol_grad = sol_grad * sparse.block_diag([inv_rot] * 3) # dot prod
#
# MEG forward solution
@@ -1129,7 +1129,7 @@ def apply_forward(fwd, stc, info, start=None, stop=None,
Parameters
----------
- fwd : dict
+ fwd : Forward
Forward operator to use. Has to be fixed-orientation.
stc : SourceEstimate
The source estimate from which the sensor space data is computed.
@@ -1189,7 +1189,7 @@ def apply_forward_raw(fwd, stc, info, start=None, stop=None,
Parameters
----------
- fwd : dict
+ fwd : Forward
Forward operator to use. Has to be fixed-orientation.
stc : SourceEstimate
The source estimate from which the sensor space data is computed.
@@ -1239,7 +1239,7 @@ def restrict_forward_to_stc(fwd, stc):
Parameters
----------
- fwd : dict
+ fwd : Forward
Forward operator.
stc : SourceEstimate
Source estimate.
@@ -1286,7 +1286,7 @@ def restrict_forward_to_label(fwd, labels):
Parameters
----------
- fwd : dict
+ fwd : Forward
Forward operator.
labels : label object | list
Label object or list of label objects.
@@ -1361,86 +1361,6 @@ def restrict_forward_to_label(fwd, labels):
return fwd_out
-@deprecated('do_forward_solution is deprecated and will be removed in 0.13, '
-            'use make_forward_solution instead')
-@verbose
-def do_forward_solution(subject, meas, fname=None, src=None, spacing=None,
- mindist=None, bem=None, mri=None, trans=None,
- eeg=True, meg=True, fixed=False, grad=False,
- mricoord=False, overwrite=False, subjects_dir=None,
- verbose=None):
- """Calculate a forward solution for a subject using MNE-C routines
-
- This function wraps to mne_do_forward_solution, so the mne
- command-line tools must be installed and accessible from Python.
-
- Parameters
- ----------
- subject : str
- Name of the subject.
- meas : Raw | Epochs | Evoked | str
- If Raw or Epochs, a temporary evoked file will be created and
- saved to a temporary directory. If str, then it should be a
- filename to a file with measurement information the mne
- command-line tools can understand (i.e., raw or evoked).
- fname : str | None
- Destination forward solution filename. If None, the solution
- will be created in a temporary directory, loaded, and deleted.
- src : str | None
- Source space name. If None, the MNE default is used.
- spacing : str
- The spacing to use. Can be ``'#'`` for spacing in mm, ``'ico#'`` for a
- recursively subdivided icosahedron, or ``'oct#'`` for a recursively
- subdivided octahedron (e.g., ``spacing='ico4'``). Default is 7 mm.
- mindist : float | str | None
- Minimum distance of sources from inner skull surface (in mm).
- If None, the MNE default value is used. If string, 'all'
- indicates to include all points.
- bem : str | None
- Name of the BEM to use (e.g., "sample-5120-5120-5120"). If None
- (Default), the MNE default will be used.
- mri : str | None
- The name of the trans file in FIF format.
- If None, trans must not be None.
- trans : dict | str | None
- File name of the trans file in text format.
- If None, mri must not be None.
- eeg : bool
- If True (Default), include EEG computations.
- meg : bool
- If True (Default), include MEG computations.
- fixed : bool
- If True, make a fixed-orientation forward solution (Default:
- False). Note that fixed-orientation inverses can still be
- created from free-orientation forward solutions.
- grad : bool
- If True, compute the gradient of the field with respect to the
- dipole coordinates as well (Default: False).
- mricoord : bool
- If True, calculate in MRI coordinates (Default: False).
- overwrite : bool
- If True, the destination file (if it exists) will be overwritten.
- If False (default), an error will be raised if the file exists.
- subjects_dir : None | str
- Override the SUBJECTS_DIR environment variable.
- verbose : bool, str, int, or None
- If not None, override default verbose level (see mne.verbose).
-
- See Also
- --------
- forward.make_forward_solution
-
- Returns
- -------
- fwd : dict
- The generated forward solution.
- """
- return _do_forward_solution(subject, meas, fname, src, spacing,
- mindist, bem, mri, trans, eeg, meg, fixed,
- grad, mricoord, overwrite, subjects_dir,
- verbose)
-
-
def _do_forward_solution(subject, meas, fname=None, src=None, spacing=None,
mindist=None, bem=None, mri=None, trans=None,
eeg=True, meg=True, fixed=False, grad=False,
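
With the MNE-C wrapper gone from the public API, forward models are computed
directly in Python. A hedged sketch of the replacement call, assuming `info`,
`trans`, `src`, and `bem` are already loaded:

    import mne

    fwd = mne.make_forward_solution(info, trans=trans, src=src, bem=bem,
                                    meg=True, eeg=True, mindist=5.0)
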
@@ -1511,7 +1431,7 @@ def _do_forward_solution(subject, meas, fname=None, src=None, spacing=None,
Returns
-------
- fwd : dict
+ fwd : Forward
The generated forward solution.
"""
if not has_mne_c():
@@ -1657,7 +1577,7 @@ def average_forward_solutions(fwds, weights=None):
Parameters
----------
- fwds : list of dict
+ fwds : list of Forward
Forward solutions to average. Each entry (dict) should be a
forward solution.
weights : array | None
@@ -1667,7 +1587,7 @@ def average_forward_solutions(fwds, weights=None):
Returns
-------
- fwd : dict
+ fwd : Forward
The averaged forward solution.
"""
# check for fwds being a list
diff --git a/mne/forward/tests/test_field_interpolation.py b/mne/forward/tests/test_field_interpolation.py
index a8160cd..0393ac7 100644
--- a/mne/forward/tests/test_field_interpolation.py
+++ b/mne/forward/tests/test_field_interpolation.py
@@ -1,3 +1,4 @@
+from functools import partial
from os import path as op
import numpy as np
@@ -17,7 +18,6 @@ from mne.forward._field_interpolation import _setup_dots
from mne.surface import get_meg_helmet_surf, get_head_surf
from mne.datasets import testing
from mne import read_evokeds, pick_types
-from mne.fixes import partial
from mne.externals.six.moves import zip
from mne.utils import run_tests_if_main, slow_test
@@ -32,8 +32,7 @@ subjects_dir = op.join(data_path, 'subjects')
def test_legendre_val():
- """Test Legendre polynomial (derivative) equivalence
- """
+ """Test Legendre polynomial (derivative) equivalence"""
rng = np.random.RandomState(0)
# check table equiv
xs = np.linspace(-1., 1., 1000)
@@ -83,8 +82,7 @@ def test_legendre_val():
def test_legendre_table():
- """Test Legendre table calculation
- """
+ """Test Legendre table calculation"""
# double-check our table generation
n = 10
for ch_type in ['eeg', 'meg']:
@@ -98,8 +96,7 @@ def test_legendre_table():
@testing.requires_testing_data
def test_make_field_map_eeg():
- """Test interpolation of EEG field onto head
- """
+ """Test interpolation of EEG field onto head"""
evoked = read_evokeds(evoked_fname, condition='Left Auditory')
evoked.info['bads'] = ['MEG 2443', 'EEG 053'] # add some bads
surf = get_head_surf('sample', subjects_dir=subjects_dir)
@@ -124,8 +121,7 @@ def test_make_field_map_eeg():
@testing.requires_testing_data
@slow_test
def test_make_field_map_meg():
- """Test interpolation of MEG field onto helmet | head
- """
+ """Test interpolation of MEG field onto helmet | head"""
evoked = read_evokeds(evoked_fname, condition='Left Auditory')
info = evoked.info
surf = get_meg_helmet_surf(info)
diff --git a/mne/forward/tests/test_forward.py b/mne/forward/tests/test_forward.py
index 2e4b770..3f1f330 100644
--- a/mne/forward/tests/test_forward.py
+++ b/mne/forward/tests/test_forward.py
@@ -57,7 +57,10 @@ def test_convert_forward():
"""Test converting forward solution between different representations
"""
fwd = read_forward_solution(fname_meeg_grad)
- assert_true(repr(fwd))
+ fwd_repr = repr(fwd)
+ assert_true('306' in fwd_repr)
+ assert_true('60' in fwd_repr)
+ assert_true(fwd_repr)
assert_true(isinstance(fwd, Forward))
# look at surface orientation
fwd_surf = convert_forward_solution(fwd, surf_ori=True)
diff --git a/mne/forward/tests/test_make_forward.py b/mne/forward/tests/test_make_forward.py
index 1389db1..ca6b816 100644
--- a/mne/forward/tests/test_make_forward.py
+++ b/mne/forward/tests/test_make_forward.py
@@ -10,7 +10,7 @@ import numpy as np
from numpy.testing import (assert_equal, assert_allclose)
from mne.datasets import testing
-from mne.io import Raw, read_raw_kit, read_raw_bti, read_info
+from mne.io import read_raw_fif, read_raw_kit, read_raw_bti, read_info
from mne.io.constants import FIFF
from mne import (read_forward_solution, make_forward_solution,
convert_forward_solution, setup_volume_source_space,
@@ -190,7 +190,8 @@ def test_make_forward_solution_kit():
_compare_forwards(fwd, fwd_py, 274, n_src)
# CTF with compensation changed in python
- ctf_raw = Raw(fname_ctf_raw, compensation=2)
+ ctf_raw = read_raw_fif(fname_ctf_raw, add_eeg_ref=False)
+ ctf_raw.apply_gradient_compensation(2)
fwd_py = make_forward_solution(ctf_raw.info, fname_trans, src,
fname_bem_meg, eeg=False, meg=True)
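
The `compensation` constructor argument is replaced by an explicit method call on
the raw object. A minimal sketch of the new pattern, assuming `fname` points at a
CTF recording converted to FIF:

    from mne.io import read_raw_fif

    raw = read_raw_fif(fname, add_eeg_ref=False)
    raw.apply_gradient_compensation(2)  # CTF gradient compensation grade 2
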
diff --git a/mne/gui/__init__.py b/mne/gui/__init__.py
index acdd7b8..ac5231f 100644
--- a/mne/gui/__init__.py
+++ b/mne/gui/__init__.py
@@ -23,7 +23,7 @@ def combine_kit_markers():
return gui
-def coregistration(tabbed=False, split=True, scene_width=0o1, inst=None,
+def coregistration(tabbed=False, split=True, scene_width=500, inst=None,
subject=None, subjects_dir=None):
"""Coregister an MRI with a subject's head shape
diff --git a/mne/gui/_coreg_gui.py b/mne/gui/_coreg_gui.py
index a805c14..cb9d522 100644
--- a/mne/gui/_coreg_gui.py
+++ b/mne/gui/_coreg_gui.py
@@ -8,6 +8,7 @@ import os
from ..externals.six.moves import queue
import re
from threading import Thread
+import traceback
import warnings
import numpy as np
@@ -17,8 +18,8 @@ from scipy.spatial.distance import cdist
try:
from mayavi.core.ui.mayavi_scene import MayaviScene
from mayavi.tools.mlab_scene_model import MlabSceneModel
- from pyface.api import (error, confirm, warning, OK, YES, information,
- FileDialog, GUI)
+ from pyface.api import (error, confirm, OK, YES, NO, CANCEL,
+ information, FileDialog, GUI)
from traits.api import (Bool, Button, cached_property, DelegatesTo,
Directory, Enum, Float, HasTraits,
HasPrivateTraits, Instance, Int, on_trait_change,
@@ -37,16 +38,15 @@ except Exception:
NoButtons = SceneEditor = trait_wraith
+from ..bem import make_bem_solution, write_bem_solution
from ..coreg import bem_fname, trans_fname
-from ..forward import prepare_bem_model
from ..transforms import (write_trans, read_trans, apply_trans, rotation,
translation, scaling, rotation_angles, Transform)
from ..coreg import (fit_matched_points, fit_point_cloud, scale_mri,
- _point_cloud_error)
+ _find_fiducials_files, _point_cloud_error)
from ..utils import get_subjects_dir, logger
from ._fiducials_gui import MRIHeadWithFiducialsModel, FiducialsPanel
-from ._file_traits import (set_mne_root, trans_wildcard, InstSource,
- SubjectSelectorPanel)
+from ._file_traits import trans_wildcard, InstSource, SubjectSelectorPanel
from ._viewer import (defaults, HeadViewController, PointObject, SurfaceObject,
_testing_mode)
@@ -476,49 +476,22 @@ class CoregModel(HasPrivateTraits):
self.rot_x, self.rot_y, self.rot_z = est[:3]
- def get_scaling_job(self, subject_to):
- desc = 'Scaling %s' % subject_to
- func = scale_mri
- args = (self.mri.subject, subject_to, self.scale)
- kwargs = dict(overwrite=True, subjects_dir=self.mri.subjects_dir)
- return (desc, func, args, kwargs)
-
- def get_prepare_bem_model_job(self, subject_to):
+ def get_scaling_job(self, subject_to, skip_fiducials, do_bem_sol):
+ "Find all arguments needed for the scaling worker"
subjects_dir = self.mri.subjects_dir
subject_from = self.mri.subject
-
- bem_name = 'inner_skull-bem'
- bem_file = bem_fname.format(subjects_dir=subjects_dir,
- subject=subject_from, name=bem_name)
- if not os.path.exists(bem_file):
+ bem_names = []
+ if do_bem_sol:
pattern = bem_fname.format(subjects_dir=subjects_dir,
- subject=subject_to, name='(.+-bem)')
- bem_dir, bem_file = os.path.split(pattern)
- m = None
- bem_file_pattern = re.compile(bem_file)
- for name in os.listdir(bem_dir):
- m = bem_file_pattern.match(name)
- if m is not None:
- break
-
- if m is None:
- pattern = bem_fname.format(subjects_dir=subjects_dir,
- subject=subject_to, name='*-bem')
- err = ("No bem file found; looking for files matching "
- "%s" % pattern)
- error(None, err)
-
- bem_name = m.group(1)
-
- bem_file = bem_fname.format(subjects_dir=subjects_dir,
- subject=subject_to, name=bem_name)
-
- # job
- desc = 'mne_prepare_bem_model for %s' % subject_to
- func = prepare_bem_model
- args = (bem_file,)
- kwargs = {}
- return (desc, func, args, kwargs)
+ subject=subject_from, name='(.+-bem)')
+ bem_dir, pattern = os.path.split(pattern)
+ for filename in os.listdir(bem_dir):
+ match = re.match(pattern, filename)
+ if match:
+ bem_names.append(match.group(1))
+
+ return (subjects_dir, subject_from, subject_to, self.scale,
+ skip_fiducials, bem_names)
def load_trans(self, fname):
"""Load the head-mri transform from a fif file
@@ -597,7 +570,7 @@ class CoregPanel(HasPrivateTraits):
reset_params = Button(label='Reset')
grow_hair = DelegatesTo('model')
n_scale_params = DelegatesTo('model')
- scale_step = Float(1.01)
+ scale_step = Float(0.01)
scale_x = DelegatesTo('model')
scale_x_dec = Button('-')
scale_x_inc = Button('+')
@@ -655,7 +628,6 @@ class CoregPanel(HasPrivateTraits):
queue_current = Str('')
queue_len = Int(0)
queue_len_str = Property(Str, depends_on=['queue_len'])
- error = Str('')
view = View(VGroup(Item('grow_hair', show_label=True),
Item('n_scale_params', label='MRI Scaling',
@@ -785,26 +757,48 @@ class CoregPanel(HasPrivateTraits):
def __init__(self, *args, **kwargs):
super(CoregPanel, self).__init__(*args, **kwargs)
- # setup save worker
+ # Setup scaling worker
def worker():
while True:
- desc, cmd, args, kwargs = self.queue.get()
-
+ (subjects_dir, subject_from, subject_to, scale, skip_fiducials,
+ bem_names) = self.queue.get()
self.queue_len -= 1
- self.queue_current = 'Processing: %s' % desc
- # task
+ # Scale MRI files
+ self.queue_current = 'Scaling %s...' % subject_to
try:
- cmd(*args, **kwargs)
- except Exception as err:
- self.error = str(err)
- res = "Error in %s"
+ scale_mri(subject_from, subject_to, scale, True,
+ subjects_dir, skip_fiducials)
+ except:
+ logger.error('Error scaling %s:\n' % subject_to +
+ traceback.format_exc())
+ self.queue_feedback = ('Error scaling %s (see Terminal)' %
+ subject_to)
+ bem_names = () # skip bem solutions
else:
- res = "Done: %s"
-
- # finalize
+ self.queue_feedback = 'Done scaling %s.' % subject_to
+
+ # Precompute BEM solutions
+ for bem_name in bem_names:
+ self.queue_current = ('Computing %s solution...' %
+ bem_name)
+ try:
+ bem_file = bem_fname.format(subjects_dir=subjects_dir,
+ subject=subject_to,
+ name=bem_name)
+ bemsol = make_bem_solution(bem_file)
+ write_bem_solution(bem_file[:-4] + '-sol.fif', bemsol)
+ except:
+ logger.error('Error computing %s solution:\n' %
+ bem_name + traceback.format_exc())
+ self.queue_feedback = ('Error computing %s solution '
+ '(see Terminal)' % bem_name)
+ else:
+ self.queue_feedback = ('Done computing %s solution.' %
+ bem_name)
+
+ # Finalize
self.queue_current = ''
- self.queue_feedback = res % desc
self.queue.task_done()
t = Thread(target=worker)
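
The scaling worker uses the standard producer-consumer pattern: the GUI thread
puts argument tuples on a queue and a daemon thread consumes them. A stripped-down
sketch of the same pattern (plain Python 3 `queue` instead of the six.moves
import used above):

    from queue import Queue
    from threading import Thread

    jobs = Queue()

    def worker():
        while True:
            subject_to, scale = jobs.get()  # blocks until a job arrives
            print('scaling %s by %s' % (subject_to, scale))
            jobs.task_done()

    t = Thread(target=worker)
    t.daemon = True
    t.start()
    jobs.put(('sample2', 1.02))
    jobs.join()  # wait until the queued job is done
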
@@ -882,18 +876,6 @@ class CoregPanel(HasPrivateTraits):
self.model.fit_scale_hsp_points()
GUI.set_busy(False)
- def _n_scale_params_changed(self, new):
- if not new:
- return
-
- # Make sure that MNE_ROOT environment variable is set
- if not set_mne_root(True):
- err = ("MNE_ROOT environment variable could not be set. "
- "You will be able to scale MRIs, but the "
- "mne_prepare_bem_model tool will fail. Please install "
- "MNE.")
- warning(None, err, "MNE_ROOT Not Set")
-
def _reset_params_fired(self):
self.model.reset()
@@ -929,28 +911,43 @@ class CoregPanel(HasPrivateTraits):
self.model.load_trans(trans_file)
def _save_fired(self):
- if self.n_scale_params:
- subjects_dir = self.model.mri.subjects_dir
- subject_from = self.model.mri.subject
- subject_to = self.model.raw_subject or self.model.mri.subject
- else:
- subject_to = self.model.mri.subject
+ subjects_dir = self.model.mri.subjects_dir
+ subject_from = self.model.mri.subject
+
+ # check that fiducials are saved
+ skip_fiducials = False
+ if self.n_scale_params and not _find_fiducials_files(subject_from,
+ subjects_dir):
+ msg = ("No fiducials file has been found for {src}. If fiducials "
+ "are not saved, they will not be available in the scaled "
+ "MRI. Should the current fiducials be saved now? "
+ "Select Yes to save the fiducials at "
+ "{src}/bem/{src}-fiducials.fif. "
+ "Select No to proceed scaling the MRI without fiducials.".
+ format(src=subject_from))
+ title = "Save Fiducials for %s?" % subject_from
+ rc = confirm(None, msg, title, cancel=True, default=CANCEL)
+ if rc == CANCEL:
+ return
+ elif rc == YES:
+ self.model.mri.save(self.model.mri.default_fid_fname)
+ elif rc == NO:
+ skip_fiducials = True
+ else:
+ raise RuntimeError("rc=%s" % repr(rc))
- # ask for target subject
+ # find target subject
if self.n_scale_params:
+ subject_to = self.model.raw_subject or subject_from
mridlg = NewMriDialog(subjects_dir=subjects_dir,
subject_from=subject_from,
subject_to=subject_to)
ui = mridlg.edit_traits(kind='modal')
- if ui.result != True: # noqa
+ if not ui.result: # i.e., user pressed cancel
return
subject_to = mridlg.subject_to
-
- # find bem file to run mne_prepare_bem_model
- if self.can_prepare_bem_model and self.prepare_bem_model:
- bem_job = self.model.get_prepare_bem_model_job(subject_to)
else:
- bem_job = None
+ subject_to = subject_from
# find trans file destination
raw_dir = os.path.dirname(self.model.hsp.file)
@@ -962,7 +959,7 @@ class CoregPanel(HasPrivateTraits):
return
trans_file = dlg.path
if not trans_file.endswith('.fif'):
- trans_file = trans_file + '.fif'
+ trans_file += '.fif'
if os.path.exists(trans_file):
answer = confirm(None, "The file %r already exists. Should it "
"be replaced?", "Overwrite File?")
@@ -973,25 +970,23 @@ class CoregPanel(HasPrivateTraits):
try:
self.model.save_trans(trans_file)
except Exception as e:
- error(None, str(e), "Error Saving Trans File")
- return
+ error(None, "Error saving -trans.fif file: %s (See terminal for "
+ "details)" % str(e), "Error Saving Trans File")
+ raise
# save the scaled MRI
if self.n_scale_params:
- job = self.model.get_scaling_job(subject_to)
+ do_bem_sol = self.can_prepare_bem_model and self.prepare_bem_model
+ job = self.model.get_scaling_job(subject_to, skip_fiducials,
+ do_bem_sol)
self.queue.put(job)
self.queue_len += 1
- if bem_job is not None:
- self.queue.put(bem_job)
- self.queue_len += 1
-
def _scale_x_dec_fired(self):
- step = 1. / self.scale_step
- self.scale_x *= step
+ self.scale_x -= self.scale_step
def _scale_x_inc_fired(self):
- self.scale_x *= self.scale_step
+ self.scale_x += self.scale_step
def _scale_x_changed(self, old, new):
if self.n_scale_params == 1:
@@ -1074,7 +1069,10 @@ class NewMriDialog(HasPrivateTraits):
@on_trait_change('subject_to_dir,overwrite')
def update_dialog(self):
- if not self.subject_to:
+ if not self.subject_from:
+ # weird trait state that occurs even when subject_from is set
+ return
+ elif not self.subject_to:
self.feedback = "No subject specified..."
self.can_save = False
self.can_overwrite = False
@@ -1097,7 +1095,7 @@ class NewMriDialog(HasPrivateTraits):
self.can_overwrite = False
-def _make_view(tabbed=False, split=False, scene_width=-1):
+def _make_view(tabbed=False, split=False, scene_width=500):
"""Create a view for the CoregFrame
Parameters
@@ -1121,7 +1119,7 @@ def _make_view(tabbed=False, split=False, scene_width=-1):
scene = VGroup(Item('scene', show_label=False,
editor=SceneEditor(scene_class=MayaviScene),
- dock='vertical', width=500),
+ dock='vertical', width=scene_width),
view_options)
data_panel = VGroup(VGroup(Item('subject_panel', style='custom'),
@@ -1270,7 +1268,7 @@ class CoregFrame(HasTraits):
color = defaults['mri_color']
self.mri_obj = SurfaceObject(points=self.model.transformed_mri_points,
color=color, tri=self.model.mri.tris,
- scene=self.scene)
+ scene=self.scene, name="MRI Scalp")
# on_trait_change was unreliable, so link it another way:
self.model.mri.on_trait_change(self._on_mri_src_change, 'tris')
self.model.sync_trait('transformed_mri_points', self.mri_obj, 'points',
@@ -1280,18 +1278,18 @@ class CoregFrame(HasTraits):
# MRI Fiducials
point_scale = defaults['mri_fid_scale']
self.lpa_obj = PointObject(scene=self.scene, color=lpa_color,
- point_scale=point_scale)
+ point_scale=point_scale, name='LPA')
self.model.mri.sync_trait('lpa', self.lpa_obj, 'points', mutual=False)
self.model.sync_trait('scale', self.lpa_obj, 'trans', mutual=False)
self.nasion_obj = PointObject(scene=self.scene, color=nasion_color,
- point_scale=point_scale)
+ point_scale=point_scale, name='Nasion')
self.model.mri.sync_trait('nasion', self.nasion_obj, 'points',
mutual=False)
self.model.sync_trait('scale', self.nasion_obj, 'trans', mutual=False)
self.rpa_obj = PointObject(scene=self.scene, color=rpa_color,
- point_scale=point_scale)
+ point_scale=point_scale, name='RPA')
self.model.mri.sync_trait('rpa', self.rpa_obj, 'points', mutual=False)
self.model.sync_trait('scale', self.rpa_obj, 'trans', mutual=False)
@@ -1299,7 +1297,7 @@ class CoregFrame(HasTraits):
color = defaults['hsp_point_color']
point_scale = defaults['hsp_points_scale']
p = PointObject(view='cloud', scene=self.scene, color=color,
- point_scale=point_scale, resolution=5)
+ point_scale=point_scale, resolution=5, name='HSP')
self.hsp_obj = p
self.model.hsp.sync_trait('points', p, mutual=False)
self.model.sync_trait('head_mri_trans', p, 'trans', mutual=False)
@@ -1309,21 +1307,21 @@ class CoregFrame(HasTraits):
point_scale = defaults['hsp_fid_scale']
opacity = defaults['hsp_fid_opacity']
p = PointObject(scene=self.scene, color=lpa_color, opacity=opacity,
- point_scale=point_scale)
+ point_scale=point_scale, name='HSP-LPA')
self.hsp_lpa_obj = p
self.model.hsp.sync_trait('lpa', p, 'points', mutual=False)
self.model.sync_trait('head_mri_trans', p, 'trans', mutual=False)
self.sync_trait('hsp_visible', p, 'visible', mutual=False)
p = PointObject(scene=self.scene, color=nasion_color, opacity=opacity,
- point_scale=point_scale)
+ point_scale=point_scale, name='HSP-Nasion')
self.hsp_nasion_obj = p
self.model.hsp.sync_trait('nasion', p, 'points', mutual=False)
self.model.sync_trait('head_mri_trans', p, 'trans', mutual=False)
self.sync_trait('hsp_visible', p, 'visible', mutual=False)
p = PointObject(scene=self.scene, color=rpa_color, opacity=opacity,
- point_scale=point_scale)
+ point_scale=point_scale, name='HSP-RPA')
self.hsp_rpa_obj = p
self.model.hsp.sync_trait('rpa', p, 'points', mutual=False)
self.model.sync_trait('head_mri_trans', p, 'trans', mutual=False)
diff --git a/mne/gui/_fiducials_gui.py b/mne/gui/_fiducials_gui.py
index 4a9973b..635fbb3 100644
--- a/mne/gui/_fiducials_gui.py
+++ b/mne/gui/_fiducials_gui.py
@@ -4,7 +4,6 @@
#
# License: BSD (3-clause)
-from glob import glob
import os
from ..externals.six.moves import map
@@ -13,7 +12,7 @@ try:
from mayavi.core.ui.mayavi_scene import MayaviScene
from mayavi.tools.mlab_scene_model import MlabSceneModel
import numpy as np
- from pyface.api import confirm, FileDialog, OK, YES
+ from pyface.api import confirm, error, FileDialog, OK, YES
from traits.api import (HasTraits, HasPrivateTraits, on_trait_change,
cached_property, DelegatesTo, Event, Instance,
Property, Array, Bool, Button, Enum)
@@ -26,13 +25,14 @@ except Exception:
cached_property = on_trait_change = MayaviScene = MlabSceneModel = \
Array = Bool = Button = DelegatesTo = Enum = Event = Instance = \
Property = View = Item = HGroup = VGroup = SceneEditor = \
- NoButtons = trait_wraith
+ NoButtons = error = trait_wraith
-from ..coreg import fid_fname, fid_fname_general, head_bem_fname
+from ..coreg import (fid_fname, head_bem_fname, _find_fiducials_files,
+ _find_high_res_head)
from ..io import write_fiducials
from ..io.constants import FIFF
from ..utils import get_subjects_dir, logger
-from ._file_traits import (BemSource, fid_wildcard, FiducialsSource,
+from ._file_traits import (SurfaceSource, fid_wildcard, FiducialsSource,
MRISubjectSource, SubjectSelectorPanel)
from ._viewer import (defaults, HeadViewController, PointObject, SurfaceObject,
headview_borders)
@@ -55,7 +55,7 @@ class MRIHeadWithFiducialsModel(HasPrivateTraits):
Right peri-auricular point coordinates.
"""
subject_source = Instance(MRISubjectSource, ())
- bem = Instance(BemSource, ())
+ bem = Instance(SurfaceSource, ())
fid = Instance(FiducialsSource, ())
fid_file = DelegatesTo('fid', 'file')
@@ -64,6 +64,7 @@ class MRIHeadWithFiducialsModel(HasPrivateTraits):
subjects_dir = DelegatesTo('subject_source')
subject = DelegatesTo('subject_source')
subject_has_bem = DelegatesTo('subject_source')
+ use_high_res_head = DelegatesTo('subject_source')
points = DelegatesTo('bem')
norms = DelegatesTo('bem')
tris = DelegatesTo('bem')
@@ -160,34 +161,38 @@ class MRIHeadWithFiducialsModel(HasPrivateTraits):
# if subject changed because of a change of subjects_dir this was not
# triggered
- @on_trait_change('subjects_dir,subject')
+ @on_trait_change('subjects_dir,subject,use_high_res_head')
def _subject_changed(self):
subject = self.subject
subjects_dir = self.subjects_dir
if not subjects_dir or not subject:
return
- # update bem head
- path = head_bem_fname.format(subjects_dir=subjects_dir,
- subject=subject)
+ path = None
+ if self.use_high_res_head:
+ path = _find_high_res_head(subjects_dir=subjects_dir,
+ subject=subject)
+ if not path:
+ error(None, "No high resolution head model was found for "
+ "subject {0}, using standard head instead. In order to "
+ "generate a high resolution head model, run:\n\n"
+ " $ mne make_scalp_surfaces -s {0}"
+ "\n\n".format(subject), "No High Resolution Head")
+
+ if not path:
+ path = head_bem_fname.format(subjects_dir=subjects_dir,
+ subject=subject)
self.bem.file = path
# find fiducials file
- path = fid_fname.format(subjects_dir=subjects_dir, subject=subject)
- if os.path.exists(path):
- self.fid_file = path
- self.lock_fiducials = True
+ fid_files = _find_fiducials_files(subject, subjects_dir)
+ if len(fid_files) == 0:
+ self.fid.reset_traits(['file'])
+ self.lock_fiducials = False
else:
- path = fid_fname_general.format(subjects_dir=subjects_dir,
- subject=subject, head='*')
- fnames = glob(path)
- if fnames:
- path = fnames[0]
- self.fid.file = path
- self.lock_fiducials = True
- else:
- self.fid.reset_traits(['file'])
- self.lock_fiducials = False
+ self.fid_file = fid_files[0].format(subjects_dir=subjects_dir,
+ subject=subject)
+ self.lock_fiducials = True
# does not seem to happen by itself ... so hard code it:
self.reset_fiducials()
diff --git a/mne/gui/_file_traits.py b/mne/gui/_file_traits.py
index 777cd79..4dc0714 100644
--- a/mne/gui/_file_traits.py
+++ b/mne/gui/_file_traits.py
@@ -15,19 +15,19 @@ try:
on_trait_change, Array, Bool, Button, DelegatesTo,
Directory, Enum, Event, File, Instance, Int, List,
Property, Str)
- from traitsui.api import View, Item, VGroup
+ from traitsui.api import View, Item, VGroup, HGroup
from pyface.api import (DirectoryDialog, OK, ProgressDialog, error,
information)
except Exception:
from ..utils import trait_wraith
HasTraits = HasPrivateTraits = object
cached_property = on_trait_change = Any = Array = Bool = Button = \
- DelegatesTo = Directory = Enum = Event = File = Instance = \
+ DelegatesTo = Directory = Enum = Event = File = Instance = HGroup = \
Int = List = Property = Str = View = Item = VGroup = trait_wraith
from ..io.constants import FIFF
from ..io import read_info, read_fiducials
-from ..surface import read_bem_surfaces
+from ..surface import read_bem_surfaces, read_surface
from ..coreg import (_is_mri_subject, _mri_subject_has_bem,
create_default_subject)
from ..utils import get_config, set_config
@@ -92,7 +92,7 @@ def _get_root_home(cfg, name, check_fun):
root = dlg.path
problem = check_fun(root)
if problem is None:
- set_config(cfg, root)
+ set_config(cfg, root, set_env=False)
else:
return None
return root
@@ -185,27 +185,27 @@ def _mne_root_problem(mne_root):
"installation, consider reinstalling." % mne_root)
-class BemSource(HasTraits):
- """Expose points and tris of a given BEM file
+class SurfaceSource(HasTraits):
+ """Expose points and tris of a file storing a surface
Parameters
----------
file : File
- Path to the BEM file (*.fif).
+ Path to a *-bem.fif file or a file containing a FreeSurfer surface.
Attributes
----------
pts : Array, shape = (n_pts, 3)
- BEM file points.
- tri : Array, shape = (n_tri, 3)
- BEM file triangles.
+ Point coordinates.
+ tris : Array, shape = (n_tri, 3)
+ Triangles.
Notes
-----
tri is always updated after pts, so in case downstream objects depend on
- both, they should sync to a change in tri.
+ both, they should sync to a change in tris.
"""
- file = File(exists=True, filter=['*.fif'])
+ file = File(exists=True, filter=['*.fif', '*.*'])
points = Array(shape=(None, 3), value=np.empty((0, 3)))
norms = Array
tris = Array(shape=(None, 3), value=np.empty((0, 3)))
@@ -213,10 +213,24 @@ class BemSource(HasTraits):
@on_trait_change('file')
def read_file(self):
if os.path.exists(self.file):
- bem = read_bem_surfaces(self.file)[0]
- self.points = bem['rr']
- self.norms = bem['nn']
- self.tris = bem['tris']
+ if self.file.endswith('.fif'):
+ bem = read_bem_surfaces(self.file)[0]
+ self.points = bem['rr']
+ self.norms = bem['nn']
+ self.tris = bem['tris']
+ else:
+ try:
+ points, tris = read_surface(self.file)
+ points /= 1e3
+ self.points = points
+ self.norms = []
+ self.tris = tris
+ except Exception:
+ error(message="Error loading surface from %s (see "
+ "Terminal for details)." % self.file,
+ title="Error Loading Surface")
+ self.reset_traits(['file'])
+ raise
else:
self.points = np.empty((0, 3))
self.norms = np.empty((0, 3))
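
FreeSurfer surface files store coordinates in millimeters while MNE works in
meters, hence the division by 1e3 above. A minimal sketch, assuming a scalp
surface file produced by `mne make_scalp_surfaces`:

    from mne.surface import read_surface

    points, tris = read_surface('lh.seghead')  # vertex coordinates in mm
    points /= 1e3                              # convert to meters
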
@@ -250,17 +264,24 @@ class FiducialsSource(HasTraits):
if not os.path.exists(self.file):
return None
- points = np.zeros((3, 3))
- fids, _ = read_fiducials(self.file)
- for fid in fids:
- ident = fid['ident']
- if ident == FIFF.FIFFV_POINT_LPA:
- points[0] = fid['r']
- elif ident == FIFF.FIFFV_POINT_NASION:
- points[1] = fid['r']
- elif ident == FIFF.FIFFV_POINT_RPA:
- points[2] = fid['r']
- return points
+ try:
+ points = np.zeros((3, 3))
+ fids, _ = read_fiducials(self.file)
+ for fid in fids:
+ ident = fid['ident']
+ if ident == FIFF.FIFFV_POINT_LPA:
+ points[0] = fid['r']
+ elif ident == FIFF.FIFFV_POINT_NASION:
+ points[1] = fid['r']
+ elif ident == FIFF.FIFFV_POINT_RPA:
+ points[2] = fid['r']
+ return points
+ except Exception as err:
+ error(None, "Error reading fiducials from %s: %s (See terminal "
+ "for more information)" % (self.fname, str(err)),
+ "Error Reading Fiducials")
+ self.reset_traits(['file'])
+ raise
class InstSource(HasPrivateTraits):
@@ -318,7 +339,14 @@ class InstSource(HasPrivateTraits):
@cached_property
def _get_inst(self):
if self.file:
- return read_info(self.file)
+ info = read_info(self.file)
+ if info['dig'] is None:
+ error(None, "The selected FIFF file does not contain "
+ "digitizer information. Please select a different "
+ "file.", "Error Reading FIFF File")
+ self.reset_traits(['file'])
+ else:
+ return info
@cached_property
def _get_inst_dir(self):
@@ -405,6 +433,7 @@ class MRISubjectSource(HasPrivateTraits):
subjects_dir = Directory(exists=True)
subjects = Property(List(Str), depends_on=['subjects_dir', 'refresh'])
subject = Enum(values='subjects')
+ use_high_res_head = Bool(True)
# info
can_create_fsaverage = Property(Bool, depends_on=['subjects_dir',
@@ -481,6 +510,7 @@ class SubjectSelectorPanel(HasPrivateTraits):
subjects_dir = DelegatesTo('model')
subject = DelegatesTo('model')
subjects = DelegatesTo('model')
+ use_high_res_head = DelegatesTo('model')
create_fsaverage = Button("Copy FsAverage to Subjects Folder",
desc="Copy the files for the fsaverage subject "
@@ -488,6 +518,8 @@ class SubjectSelectorPanel(HasPrivateTraits):
view = View(VGroup(Item('subjects_dir', label='subjects_dir'),
'subject',
+ HGroup(Item('use_high_res_head',
+ label='High Resolution Head')),
Item('create_fsaverage', show_label=False,
enabled_when='can_create_fsaverage')))
@@ -507,3 +539,12 @@ class SubjectSelectorPanel(HasPrivateTraits):
raise
finally:
prog.close()
+
+ def _subjects_dir_changed(self, old, new):
+ if new and self.subjects == ['']:
+ information(None, "The directory selected as subjects-directory "
+ "(%s) does not contain any valid MRI subjects. MRI "
+ "subjects need to contain head surface models which "
+ "can be created by running:\n\n $ mne "
+ "make_scalp_surfaces" % self.subjects_dir,
+ "No Subjects Found")
diff --git a/mne/gui/_kit2fiff_gui.py b/mne/gui/_kit2fiff_gui.py
index 3ce49ad..d77b6b0 100644
--- a/mne/gui/_kit2fiff_gui.py
+++ b/mne/gui/_kit2fiff_gui.py
@@ -4,7 +4,9 @@
#
# License: BSD (3-clause)
+from collections import Counter
import os
+
import numpy as np
from scipy.linalg import inv
from threading import Thread
@@ -18,7 +20,8 @@ from ..utils import logger
try:
from mayavi.core.ui.mayavi_scene import MayaviScene
from mayavi.tools.mlab_scene_model import MlabSceneModel
- from pyface.api import confirm, error, FileDialog, OK, YES, information
+ from pyface.api import (confirm, error, FileDialog, OK, YES, information,
+ ProgressDialog)
from traits.api import (HasTraits, HasPrivateTraits, cached_property,
Instance, Property, Bool, Button, Enum, File,
Float, Int, List, Str, Array, DelegatesTo)
@@ -34,10 +37,12 @@ except Exception:
Str = Array = spring = View = Item = HGroup = VGroup = EnumEditor = \
NoButtons = CheckListEditor = SceneEditor = TextEditor = trait_wraith
-from ..io.kit.kit import RawKIT, KIT
-from ..transforms import (apply_trans, als_ras_trans, als_ras_trans_mm,
+from ..io.constants import FIFF
+from ..io.kit.kit import RawKIT, KIT, _make_stim_channel, _default_stim_chs
+from ..transforms import (apply_trans, als_ras_trans,
get_ras_to_neuromag_trans, Transform)
from ..coreg import _decimate_points, fit_matched_points
+from ..event import _find_events
from ._marker_gui import CombineMarkersPanel, CombineMarkersModel
from ._help import read_tooltips
from ._viewer import (HeadViewController, headview_item, PointObject,
@@ -48,12 +53,12 @@ use_editor = CheckListEditor(cols=5, values=[(i, str(i)) for i in range(5)])
backend_is_wx = False # is there a way to determine this?
if backend_is_wx:
# wx backend allows labels for wildcards
- hsp_points_wildcard = ['Head Shape Points (*.txt)|*.txt']
- hsp_fid_wildcard = ['Head Shape Fiducials (*.txt)|*.txt']
+ hsp_wildcard = ['Head Shape Points (*.hsp;*.txt)|*.hsp;*.txt']
+ elp_wildcard = ['Head Shape Fiducials (*.elp;*.txt)|*.elp;*.txt']
kit_con_wildcard = ['Continuous KIT Files (*.sqd;*.con)|*.sqd;*.con']
else:
- hsp_points_wildcard = ['*.txt']
- hsp_fid_wildcard = ['*.txt']
+ hsp_wildcard = ['*.hsp;*.txt']
+ elp_wildcard = ['*.elp;*.txt']
kit_con_wildcard = ['*.sqd;*.con']
@@ -71,13 +76,11 @@ class Kit2FiffModel(HasPrivateTraits):
# Input Traits
markers = Instance(CombineMarkersModel, ())
sqd_file = File(exists=True, filter=kit_con_wildcard)
- hsp_file = File(exists=True, filter=hsp_points_wildcard, desc="Digitizer "
- "head shape")
- fid_file = File(exists=True, filter=hsp_fid_wildcard, desc="Digitizer "
- "fiducials")
+ hsp_file = File(exists=True, filter=hsp_wildcard)
+ fid_file = File(exists=True, filter=elp_wildcard)
stim_coding = Enum(">", "<", "channel")
stim_chs = Str("")
- stim_chs_array = Property(depends_on='stim_chs')
+ stim_chs_array = Property(depends_on=['raw', 'stim_chs', 'stim_coding'])
stim_chs_ok = Property(depends_on='stim_chs_array')
stim_chs_comment = Property(depends_on='stim_chs_array')
stim_slope = Enum("-", "+")
@@ -104,17 +107,27 @@ class Kit2FiffModel(HasPrivateTraits):
dev_head_trans = Property(depends_on=['elp', 'mrk', 'use_mrk'])
head_dev_trans = Property(depends_on=['dev_head_trans'])
+ # event preview
+ raw = Property(depends_on='sqd_file')
+ misc_chs = Property(List, depends_on='raw')
+ misc_chs_desc = Property(Str, depends_on='misc_chs')
+ misc_data = Property(Array, depends_on='raw')
+ can_test_stim = Property(Bool, depends_on='raw')
+
# info
sqd_fname = Property(Str, depends_on='sqd_file')
hsp_fname = Property(Str, depends_on='hsp_file')
fid_fname = Property(Str, depends_on='fid_file')
- can_save = Property(Bool, depends_on=['stim_chs_ok', 'sqd_file', 'fid',
+ can_save = Property(Bool, depends_on=['stim_chs_ok', 'fid',
'elp', 'hsp', 'dev_head_trans'])
+ # Show GUI feedback (like error messages and progress bar)
+ show_gui = Bool(False)
+
@cached_property
def _get_can_save(self):
"Only allow saving when either all or no head shape elements are set."
- if not self.stim_chs_ok or not self.sqd_file:
+ if not self.stim_chs_ok:
return False
has_all_hsp = (np.any(self.dev_head_trans) and np.any(self.hsp) and
@@ -126,6 +139,10 @@ class Kit2FiffModel(HasPrivateTraits):
return not has_any_hsp
@cached_property
+ def _get_can_test_stim(self):
+ return self.raw is not None
+
+ @cached_property
def _get_dev_head_trans(self):
if (self.mrk is None) or not np.any(self.fid):
return np.eye(4)
@@ -135,9 +152,10 @@ class Kit2FiffModel(HasPrivateTraits):
n_use = len(self.use_mrk)
if n_use < 3:
- error(None, "Estimating the device head transform requires at "
- "least 3 marker points. Please adjust the markers used.",
- "Not Enough Marker Points")
+ if self.show_gui:
+ error(None, "Estimating the device head transform requires at "
+ "least 3 marker points. Please adjust the markers used.",
+ "Not Enough Marker Points")
return
elif n_use < 5:
src_pts = src_pts[self.use_mrk]
@@ -164,9 +182,10 @@ class Kit2FiffModel(HasPrivateTraits):
if len(pts) < 8:
raise ValueError("File contains %i points, need 8" % len(pts))
except Exception as err:
- error(None, str(err), "Error Reading Fiducials")
+ if self.show_gui:
+ error(None, str(err), "Error Reading Fiducials")
self.reset_traits(['fid_file'])
- raise
+ raise err
else:
return pts
@@ -218,19 +237,62 @@ class Kit2FiffModel(HasPrivateTraits):
"which is more than the recommended maximum ({n_rec}). "
"The file will be automatically downsampled, which "
"might take a while. A better way to downsample is "
- "using FastScan.")
- msg = msg.format(n_in=n_pts, n_rec=KIT.DIG_POINTS)
- information(None, msg, "Too Many Head Shape Points")
+ "using FastScan.".
+ format(n_in=n_pts, n_rec=KIT.DIG_POINTS))
+ if self.show_gui:
+ information(None, msg, "Too Many Head Shape Points")
pts = _decimate_points(pts, 5)
except Exception as err:
- error(None, str(err), "Error Reading Head Shape")
+ if self.show_gui:
+ error(None, str(err), "Error Reading Head Shape")
self.reset_traits(['hsp_file'])
raise
else:
return pts
@cached_property
+ def _get_misc_chs(self):
+ if not self.raw:
+ return
+ return [i for i, ch in enumerate(self.raw.info['chs']) if
+ ch['kind'] == FIFF.FIFFV_MISC_CH]
+
+ @cached_property
+ def _get_misc_chs_desc(self):
+ if self.misc_chs is None:
+ return "No SQD file selected..."
+ elif np.all(np.diff(self.misc_chs) == 1):
+ return "%i:%i" % (self.misc_chs[0], self.misc_chs[-1] + 1)
+ else:
+ return "%i... (discontinuous)" % self.misc_chs[0]
+
+ @cached_property
+ def _get_misc_data(self):
+ if not self.raw:
+ return
+ if self.show_gui:
+ # progress dialog with indefinite progress bar
+ prog = ProgressDialog(title="Loading SQD data...",
+ message="Loading stim channel data from SQD "
+ "file ...")
+ prog.open()
+ prog.update(0)
+ else:
+ prog = None
+
+ try:
+ data, times = self.raw[self.misc_chs]
+ except Exception as err:
+ if self.show_gui:
+ error(None, str(err), "Error Reading SQD file")
+ raise err
+ finally:
+ if self.show_gui:
+ prog.close()
+ return data
+
+ @cached_property
def _get_mrk(self):
return apply_trans(als_ras_trans, self.markers.mrk3.points)
@@ -238,11 +300,23 @@ class Kit2FiffModel(HasPrivateTraits):
def _get_polhemus_neuromag_trans(self):
if self.elp_raw is None:
return
- pts = apply_trans(als_ras_trans_mm, self.elp_raw[:3])
- nasion, lpa, rpa = pts
+ nasion, lpa, rpa = apply_trans(als_ras_trans, self.elp_raw[:3])
trans = get_ras_to_neuromag_trans(nasion, lpa, rpa)
- trans = np.dot(trans, als_ras_trans_mm)
- return trans
+ return np.dot(trans, als_ras_trans)
+
+ @cached_property
+ def _get_raw(self):
+ if not self.sqd_file:
+ return
+ try:
+ return RawKIT(self.sqd_file, stim=None)
+ except Exception as err:
+ self.reset_traits(['sqd_file'])
+ if self.show_gui:
+ error(None, "Error reading SQD data file: %s (Check the "
+ "terminal output for details)" % str(err),
+ "Error Reading SQD file")
+ raise err
@cached_property
def _get_sqd_fname(self):
@@ -253,25 +327,33 @@ class Kit2FiffModel(HasPrivateTraits):
@cached_property
def _get_stim_chs_array(self):
- if not self.stim_chs.strip():
- return True
- try:
- out = eval("r_[%s]" % self.stim_chs, vars(np))
- if out.dtype.kind != 'i':
- raise TypeError("Need array of int")
- except:
- return None
+ if self.raw is None:
+ return
+ elif not self.stim_chs.strip():
+ picks = _default_stim_chs(self.raw.info)
else:
- return out
+ try:
+ picks = eval("r_[%s]" % self.stim_chs, vars(np))
+ if picks.dtype.kind != 'i':
+ raise TypeError("Need array of int")
+ except:
+ return None
+
+ if self.stim_coding == '<': # Big-endian
+ return picks[::-1]
+ else:
+ return picks
@cached_property
def _get_stim_chs_comment(self):
- if self.stim_chs_array is None:
+ if self.raw is None:
+ return ""
+ elif not self.stim_chs_ok:
return "Invalid!"
- elif self.stim_chs_array is True:
- return "Ok: Default channels"
+ elif not self.stim_chs.strip():
+ return "Default: The first 8 MISC channels"
else:
- return "Ok: %i channels" % len(self.stim_chs_array)
+ return "Ok: %i channels" % len(self.stim_chs_array)
@cached_property
def _get_stim_chs_ok(self):
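
The stim-channel string is parsed by evaluating it in NumPy's namespace, so
ranges use `r_` slice syntax. A minimal sketch of what an entry like "1:4, 7"
produces:

    import numpy as np

    stim_chs = "1:4, 7"
    picks = eval("r_[%s]" % stim_chs, vars(np))
    # picks -> array([1, 2, 3, 7]); dtype kind 'i' marks a valid entry
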
@@ -283,19 +365,27 @@ class Kit2FiffModel(HasPrivateTraits):
self.reset_traits(['sqd_file', 'hsp_file', 'fid_file', 'use_mrk'])
def get_event_info(self):
- """
- Return a string with the number of events found for each trigger value
- """
- if len(self.events) == 0:
- return "No events found."
-
- count = ["Events found:"]
- events = np.array(self.events)
- for i in np.unique(events):
- n = np.sum(events == i)
- count.append('%3i: %i' % (i, n))
+ """Count events with current stim channel settings
- return os.linesep.join(count)
+ Returns
+ -------
+ event_count : Counter
+ Counter mapping event ID to number of occurrences.
+ """
+ if self.misc_data is None:
+ return
+ idx = [self.misc_chs.index(ch) for ch in self.stim_chs_array]
+ data = self.misc_data[idx]
+ if self.stim_coding == 'channel':
+ coding = 'channel'
+ else:
+ coding = 'binary'
+ stim_ch = _make_stim_channel(data, self.stim_slope,
+ self.stim_threshold, coding,
+ self.stim_chs_array)
+ events = _find_events(stim_ch, self.raw.first_samp, consecutive=True,
+ min_samples=3)
+ return Counter(events[:, 2])
def get_raw(self, preload=False):
"""Create a raw object based on the current model settings
@@ -304,30 +394,17 @@ class Kit2FiffModel(HasPrivateTraits):
raise ValueError("Not all necessary parameters are set")
# stim channels and coding
- if self.stim_chs_array is True:
- if self.stim_coding == 'channel':
- stim_code = 'channel'
- raise NotImplementedError("Finding default event channels")
- else:
- stim = self.stim_coding
- stim_code = 'binary'
+ if self.stim_coding == 'channel':
+ stim_code = 'channel'
+ elif self.stim_coding in '<>':
+ stim_code = 'binary'
else:
- stim = self.stim_chs_array
- if self.stim_coding == 'channel':
- stim_code = 'channel'
- elif self.stim_coding == '<':
- stim_code = 'binary'
- elif self.stim_coding == '>':
- # if stim is
- stim = stim[::-1]
- stim_code = 'binary'
- else:
- raise RuntimeError("stim_coding=%r" % self.stim_coding)
+ raise RuntimeError("stim_coding=%r" % self.stim_coding)
logger.info("Creating raw with stim=%r, slope=%r, stim_code=%r, "
- "stimthresh=%r", stim, self.stim_slope, stim_code,
- self.stim_threshold)
- raw = RawKIT(self.sqd_file, preload=preload, stim=stim,
+ "stimthresh=%r", self.stim_chs_array, self.stim_slope,
+ stim_code, self.stim_threshold)
+ raw = RawKIT(self.sqd_file, preload=preload, stim=self.stim_chs_array,
slope=self.stim_slope, stim_code=stim_code,
stimthresh=self.stim_threshold)
@@ -375,6 +452,10 @@ class Kit2FiffPanel(HasPrivateTraits):
sqd_fname = DelegatesTo('model')
hsp_fname = DelegatesTo('model')
fid_fname = DelegatesTo('model')
+ misc_chs_desc = DelegatesTo('model')
+ can_test_stim = DelegatesTo('model')
+ test_stim = Button(label="Find Events")
+ plot_raw = Button(label="Plot Raw")
# Source Files
reset_dig = Button
@@ -399,15 +480,20 @@ class Kit2FiffPanel(HasPrivateTraits):
VGroup(VGroup(Item('sqd_file', label="Data",
tooltip=tooltips['sqd_file']),
Item('sqd_fname', show_label=False, style='readonly'),
- Item('hsp_file', label='Dig Head Shape'),
+ Item('hsp_file', label='Digitizer\nHead Shape',
+ tooltip=tooltips['hsp_file']),
Item('hsp_fname', show_label=False, style='readonly'),
- Item('fid_file', label='Dig Points'),
+ Item('fid_file', label='Digitizer\nFiducials',
+ tooltip=tooltips['fid_file']),
Item('fid_fname', show_label=False, style='readonly'),
Item('reset_dig', label='Clear Digitizer Files',
show_label=False),
- Item('use_mrk', editor=use_editor, style='custom'),
+ Item('use_mrk', editor=use_editor, style='custom',
+ tooltip=tooltips['use_mrk']),
label="Sources", show_border=True),
- VGroup(Item('stim_slope', label="Event Onset", style='custom',
+ VGroup(Item('misc_chs_desc', label='MISC Channels',
+ style='readonly'),
+ Item('stim_slope', label="Event Onset", style='custom',
tooltip=tooltips['stim_slope'],
editor=EnumEditor(
values={'+': '2:Peak (0 to 5 V)',
@@ -423,9 +509,15 @@ class Kit2FiffPanel(HasPrivateTraits):
tooltip=tooltips["stim_chs"],
editor=TextEditor(evaluate_name='stim_chs_ok',
auto_set=True)),
- Item('stim_chs_comment', label='>', style='readonly'),
+ Item('stim_chs_comment', label='Evaluation',
+ style='readonly', show_label=False),
Item('stim_threshold', label='Threshold',
tooltip=tooltips['stim_threshold']),
+ HGroup(Item('test_stim', enabled_when='can_test_stim',
+ show_label=False),
+ Item('plot_raw', enabled_when='can_test_stim',
+ show_label=False),
+ show_labels=False),
label='Events', show_border=True),
HGroup(Item('save_as', enabled_when='can_save'), spring,
'clear_all', show_labels=False),
@@ -467,11 +559,11 @@ class Kit2FiffPanel(HasPrivateTraits):
# setup mayavi visualization
m = self.model
self.fid_obj = PointObject(scene=self.scene, color=(25, 225, 25),
- point_scale=5e-3)
+ point_scale=5e-3, name='Fiducials')
self.elp_obj = PointObject(scene=self.scene, color=(50, 50, 220),
- point_scale=1e-2, opacity=.2)
+ point_scale=1e-2, opacity=.2, name='ELP')
self.hsp_obj = PointObject(scene=self.scene, color=(200, 200, 200),
- point_scale=2e-3)
+ point_scale=2e-3, name='HSP')
if not _testing_mode():
for name, obj in zip(['fid', 'elp', 'hsp'],
[self.fid_obj, self.elp_obj, self.hsp_obj]):
@@ -490,6 +582,9 @@ class Kit2FiffPanel(HasPrivateTraits):
else:
return ''
+ def _plot_raw_fired(self):
+ self.model.raw.plot()
+
def _reset_dig_fired(self):
self.reset_traits(['hsp_file', 'fid_file'])
@@ -527,10 +622,28 @@ class Kit2FiffPanel(HasPrivateTraits):
self.queue.put((raw, fname))
self.queue_len += 1
+ def _test_stim_fired(self):
+ try:
+ events = self.model.get_event_info()
+ except Exception as err:
+ error(None, "Error reading events from SQD data file: %s (Check "
+ "the terminal output for details)" % str(err),
+ "Error Reading events from SQD file")
+ raise err
+
+ if len(events) == 0:
+ information(None, "No events were found with the current "
+ "settings.", "No Events Found")
+ else:
+ lines = ["Events found (ID: n events):"]
+ for id_ in sorted(events):
+ lines.append("%3i: \t%i" % (id_, events[id_]))
+ information(None, '\n'.join(lines), "Events in SQD File")
+
class Kit2FiffFrame(HasTraits):
"""GUI for interpolating between two KIT marker files"""
- model = Instance(Kit2FiffModel, ())
+ model = Instance(Kit2FiffModel, kw={'show_gui': True})
scene = Instance(MlabSceneModel, ())
headview = Instance(HeadViewController)
marker_panel = Instance(CombineMarkersPanel)
diff --git a/mne/gui/help/kit2fiff.json b/mne/gui/help/kit2fiff.json
index 47cea8b..37a5d3f 100644
--- a/mne/gui/help/kit2fiff.json
+++ b/mne/gui/help/kit2fiff.json
@@ -1,7 +1,10 @@
{
+ "hsp_file": "*.hsp or *.txt file containing digitized head shape points",
+ "fid_file": "*.elp or *.txt file containing 8 digitized fiducial points",
"stim_chs": "Define the channels that are used to generate events. If the field is empty, the default channels are used (for NYU systems only). Channels can be defined as comma separated channel numbers (1, 2, 3, 4, 5, 6), ranges (1:7) and combinations of the two (1:4, 7, 10:13).",
"stim_coding": "Specifies how stim-channel events are translated into trigger values. Little- and big-endian assume binary coding. In little-endian order, the first channel is assigned the smallest value (1) and the last channel is assigned the highest value (with 8 channels this would be 128). Channel# implies a different method of coding in which an event in a given channel is assigned the channel number as value.",
"sqd_file": "*.sqd or *.con file containing recorded MEG data",
"stim_slope": "How events are marked in stim channels. Trough: normally the signal is high, events are marked by transitory signal decrease. Peak: normally signal is low, events are marked by an increase.",
- "stim_threshold": "Threshold voltage to detect events in stim channels."
+ "stim_threshold": "Threshold voltage to detect events in stim channels",
+ "use_mrk": "Determine which marker points are used for the MEG-head shape coregistration"
}
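
In the binary (little-endian) scheme described above, the trigger value is a bit
field over the stim channels: channel i contributes 2**i when active. A hedged
sketch decoding eight channels, where `active` marks which channels are high:

    active = [True, False, False, False, False, False, False, True]
    value = sum(2 ** i for i, on in enumerate(active) if on)
    # first and last of 8 channels active -> 1 + 128 = 129
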
diff --git a/mne/gui/tests/test_coreg_gui.py b/mne/gui/tests/test_coreg_gui.py
index a82e09d..4bdd276 100644
--- a/mne/gui/tests/test_coreg_gui.py
+++ b/mne/gui/tests/test_coreg_gui.py
@@ -3,6 +3,7 @@
# License: BSD (3-clause)
import os
+import re
import numpy as np
from numpy.testing import assert_allclose
@@ -14,9 +15,17 @@ import mne
from mne.datasets import testing
from mne.io.kit.tests import data_dir as kit_data_dir
from mne.utils import (_TempDir, requires_traits, requires_mne,
- requires_freesurfer, run_tests_if_main)
+ requires_freesurfer, run_tests_if_main, requires_mayavi)
from mne.externals.six import string_types
+# backend needs to be set early
+try:
+ from traits.etsconfig.api import ETSConfig
+except ImportError:
+ pass
+else:
+ ETSConfig.toolkit = 'qt4'
+
data_path = testing.data_path(download=False)
raw_path = os.path.join(data_path, 'MEG', 'sample',
@@ -39,6 +48,8 @@ def test_coreg_model():
model = CoregModel()
assert_raises(RuntimeError, model.save_trans, 'blah.fif')
+ model.mri.use_high_res_head = False
+
model.mri.subjects_dir = subjects_dir
model.mri.subject = 'sample'
@@ -101,7 +112,26 @@ def test_coreg_model():
assert_true(isinstance(model.fid_eval_str, string_types))
assert_true(isinstance(model.points_eval_str, string_types))
- model.get_prepare_bem_model_job('sample')
+ # scaling job
+ sdir, sfrom, sto, scale, skip_fiducials, bemsol = \
+ model.get_scaling_job('sample2', False, True)
+ assert_equal(sdir, subjects_dir)
+ assert_equal(sfrom, 'sample')
+ assert_equal(sto, 'sample2')
+ assert_equal(scale, model.scale)
+ assert_equal(skip_fiducials, False)
+ # find BEM files
+ bems = set()
+ for fname in os.listdir(os.path.join(subjects_dir, 'sample', 'bem')):
+ match = re.match('sample-(.+-bem)\.fif', fname)
+ if match:
+ bems.add(match.group(1))
+ assert_equal(set(bemsol), bems)
+ sdir, sfrom, sto, scale, skip_fiducials, bemsol = \
+ model.get_scaling_job('sample2', True, False)
+ assert_equal(bemsol, [])
+ assert_true(skip_fiducials)
+
model.load_trans(fname_trans)
from mne.gui._coreg_gui import CoregFrame
@@ -120,13 +150,14 @@ def test_coreg_model():
@requires_mne
@requires_freesurfer
def test_coreg_model_with_fsaverage():
- """Test CoregModel"""
+ """Test CoregModel with the fsaverage brain data"""
tempdir = _TempDir()
from mne.gui._coreg_gui import CoregModel
mne.create_default_subject(subjects_dir=tempdir)
model = CoregModel()
+ model.mri.use_high_res_head = False
model.mri.subjects_dir = tempdir
model.mri.subject = 'fsaverage'
assert_true(model.mri.fid_ok)
@@ -165,12 +196,17 @@ def test_coreg_model_with_fsaverage():
avg_point_distance_1param = np.mean(model.point_distance)
assert_true(avg_point_distance_1param < avg_point_distance)
- desc, func, args, kwargs = model.get_scaling_job('test')
- assert_true(isinstance(desc, string_types))
- assert_equal(args[0], 'fsaverage')
- assert_equal(args[1], 'test')
- assert_allclose(args[2], model.scale)
- assert_equal(kwargs['subjects_dir'], tempdir)
+ # scaling job
+ sdir, sfrom, sto, scale, skip_fiducials, bemsol = \
+ model.get_scaling_job('scaled', False, True)
+ assert_equal(sdir, tempdir)
+ assert_equal(sfrom, 'fsaverage')
+ assert_equal(sto, 'scaled')
+ assert_equal(scale, model.scale)
+ assert_equal(set(bemsol), set(('inner_skull-bem',)))
+ sdir, sfrom, sto, scale, skip_fiducials, bemsol = \
+ model.get_scaling_job('scaled', False, False)
+ assert_equal(bemsol, [])
# scale with 3 parameters
model.n_scale_params = 3
@@ -184,4 +220,23 @@ def test_coreg_model_with_fsaverage():
assert_equal(model.hsp.n_omitted, 0)
+@testing.requires_testing_data
+@requires_mayavi
+def test_coreg_gui():
+ """Test Coregistration GUI"""
+ from mne.gui._coreg_gui import CoregFrame
+
+ frame = CoregFrame()
+ frame.edit_traits()
+
+ frame.model.mri.subjects_dir = subjects_dir
+ frame.model.mri.subject = 'sample'
+
+ assert_false(frame.model.mri.fid_ok)
+ frame.model.mri.lpa = [[-0.06, 0, 0]]
+ frame.model.mri.nasion = [[0, 0.05, 0]]
+ frame.model.mri.rpa = [[0.08, 0, 0]]
+ assert_true(frame.model.mri.fid_ok)
+
+
run_tests_if_main()
diff --git a/mne/gui/tests/test_fiducials_gui.py b/mne/gui/tests/test_fiducials_gui.py
index 4eea1f7..5cabefb 100644
--- a/mne/gui/tests/test_fiducials_gui.py
+++ b/mne/gui/tests/test_fiducials_gui.py
@@ -35,7 +35,7 @@ def test_mri_model():
bem_fname = os.path.basename(model.bem.file)
assert_false(model.can_reset)
- assert_equal(bem_fname, 'sample-head.fif')
+ assert_equal(bem_fname, 'lh.seghead')
model.save(tgt_fname)
assert_equal(model.fid_file, tgt_fname)
diff --git a/mne/gui/tests/test_file_traits.py b/mne/gui/tests/test_file_traits.py
index ea90fb1..038fe65 100644
--- a/mne/gui/tests/test_file_traits.py
+++ b/mne/gui/tests/test_file_traits.py
@@ -24,10 +24,10 @@ fid_path = os.path.join(fiff_data_dir, 'fsaverage-fiducials.fif')
@testing.requires_testing_data
@requires_traits
def test_bem_source():
- """Test BemSource"""
- from mne.gui._file_traits import BemSource
+ """Test SurfaceSource"""
+ from mne.gui._file_traits import SurfaceSource
- bem = BemSource()
+ bem = SurfaceSource()
assert_equal(bem.points.shape, (0, 3))
assert_equal(bem.tris.shape, (0, 3))
diff --git a/mne/gui/tests/test_kit2fiff_gui.py b/mne/gui/tests/test_kit2fiff_gui.py
index f6d5f59..6e8f688 100644
--- a/mne/gui/tests/test_kit2fiff_gui.py
+++ b/mne/gui/tests/test_kit2fiff_gui.py
@@ -11,7 +11,7 @@ from nose.tools import assert_true, assert_false, assert_equal
import mne
from mne.io.kit.tests import data_dir as kit_data_dir
-from mne.io import Raw
+from mne.io import read_raw_fif
from mne.utils import _TempDir, requires_traits, run_tests_if_main
mrk_pre_path = os.path.join(kit_data_dir, 'test_mrk_pre.sqd')
@@ -26,25 +26,35 @@ warnings.simplefilter('always')
@requires_traits
def test_kit2fiff_model():
- """Test CombineMarkersModel Traits Model"""
+ """Test CombineMarkersModel Traits Model."""
from mne.gui._kit2fiff_gui import Kit2FiffModel, Kit2FiffPanel
tempdir = _TempDir()
tgt_fname = os.path.join(tempdir, 'test-raw.fif')
model = Kit2FiffModel()
assert_false(model.can_save)
+ assert_equal(model.misc_chs_desc, "No SQD file selected...")
+ assert_equal(model.stim_chs_comment, "")
model.markers.mrk1.file = mrk_pre_path
model.markers.mrk2.file = mrk_post_path
model.sqd_file = sqd_path
+ assert_equal(model.misc_chs_desc, "160:192")
model.hsp_file = hsp_path
assert_false(model.can_save)
model.fid_file = fid_path
assert_true(model.can_save)
+ # events
+ model.stim_slope = '+'
+ assert_equal(model.get_event_info(), {1: 2})
+ model.stim_slope = '-'
+ assert_equal(model.get_event_info(), {254: 2, 255: 2})
+
# stim channels
model.stim_chs = "181:184, 186"
assert_array_equal(model.stim_chs_array, [181, 182, 183, 186])
assert_true(model.stim_chs_ok)
+ assert_equal(model.get_event_info(), {})
model.stim_chs = "181:184, bad"
assert_false(model.stim_chs_ok)
assert_false(model.can_save)
@@ -54,10 +64,10 @@ def test_kit2fiff_model():
# export raw
raw_out = model.get_raw()
raw_out.save(tgt_fname)
- raw = Raw(tgt_fname)
+ raw = read_raw_fif(tgt_fname, add_eeg_ref=False)
# Compare exported raw with the original binary conversion
- raw_bin = Raw(fif_path)
+ raw_bin = read_raw_fif(fif_path, add_eeg_ref=False)
trans_bin = raw.info['dev_head_t']['trans']
want_keys = list(raw_bin.info.keys())
assert_equal(sorted(want_keys), sorted(list(raw.info.keys())))
diff --git a/mne/inverse_sparse/mxne_inverse.py b/mne/inverse_sparse/mxne_inverse.py
index 72e5c75..4182ec2 100644
--- a/mne/inverse_sparse/mxne_inverse.py
+++ b/mne/inverse_sparse/mxne_inverse.py
@@ -374,7 +374,7 @@ def tf_mixed_norm(evoked, forward, noise_cov, alpha_space, alpha_time,
verbose=None):
"""Time-Frequency Mixed-norm estimate (TF-MxNE)
- Compute L1/L2 + L1 mixed-norm solution on time frequency
+ Compute L1/L2 + L1 mixed-norm solution on time-frequency
dictionary. Works with evoked data.
References:
diff --git a/mne/io/__init__.py b/mne/io/__init__.py
index 734869e..6083377 100644
--- a/mne/io/__init__.py
+++ b/mne/io/__init__.py
@@ -7,7 +7,8 @@
from .open import fiff_open, show_fiff, _fiff_get_fid
from .meas_info import (read_fiducials, write_fiducials, read_info, write_info,
- _empty_info, _merge_info, _force_update_info, Info)
+ _empty_info, _merge_info, _force_update_info, Info,
+ anonymize_info)
from .proj import make_eeg_average_ref_proj, Projection
from .tag import _loc_to_coil_trans, _coil_trans_to_loc, _loc_to_eeg_loc
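With ``anonymize_info`` now exported from ``mne.io``, measurement info can be scrubbed of subject-identifying fields directly. A minimal sketch, assuming a saved FIF file (the path is hypothetical):

    import mne

    raw = mne.io.read_raw_fif('sample_raw.fif', add_eeg_ref=False)  # hypothetical path
    mne.io.anonymize_info(raw.info)  # removes subject_info-type fields in place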
diff --git a/mne/io/array/array.py b/mne/io/array/array.py
index 25834fb..c9214f0 100644
--- a/mne/io/array/array.py
+++ b/mne/io/array/array.py
@@ -55,7 +55,7 @@ class RawArray(_BaseRaw):
info['buffer_size_sec'] = 1. # reasonable default
super(RawArray, self).__init__(info, data,
first_samps=(int(first_samp),),
- verbose=verbose)
+ dtype=dtype, verbose=verbose)
logger.info(' Range : %d ... %d = %9.3f ... %9.3f secs' % (
self.first_samp, self.last_samp,
float(self.first_samp) / info['sfreq'],
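Passing ``dtype`` through to the base class is what lets ``RawArray`` carry complex-valued data, as the new test below exercises. A self-contained sketch:

    import numpy as np
    from mne import create_info
    from mne.io import RawArray

    rng = np.random.RandomState(0)
    data = rng.randn(1, 100) + 1j * rng.randn(1, 100)  # complex-valued signal
    raw = RawArray(data, create_info(1, 1000., 'eeg'))
    data_out, times = raw[:, :]
    assert np.allclose(data_out, data)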
diff --git a/mne/io/array/tests/test_array.py b/mne/io/array/tests/test_array.py
index f49bf0d..0a50388 100644
--- a/mne/io/array/tests/test_array.py
+++ b/mne/io/array/tests/test_array.py
@@ -1,5 +1,3 @@
-from __future__ import print_function
-
# Author: Eric Larson <larson.eric.d at gmail.com>
#
# License: BSD (3-clause)
@@ -8,10 +6,12 @@ import os.path as op
import warnings
import matplotlib
+import numpy as np
from numpy.testing import assert_array_almost_equal, assert_allclose
from nose.tools import assert_equal, assert_raises, assert_true
+
from mne import find_events, Epochs, pick_types
-from mne.io import Raw
+from mne.io import read_raw_fif
from mne.io.array import RawArray
from mne.io.tests.test_raw import _test_raw_reader
from mne.io.meas_info import create_info, _kind_dict
@@ -32,15 +32,16 @@ def test_array_raw():
"""
import matplotlib.pyplot as plt
# creating
- raw = Raw(fif_fname).crop(2, 5, copy=False)
+ raw = read_raw_fif(fif_fname, add_eeg_ref=False).crop(2, 5)
data, times = raw[:, :]
sfreq = raw.info['sfreq']
ch_names = [(ch[4:] if 'STI' not in ch else ch)
for ch in raw.info['ch_names']] # change them, why not
# del raw
types = list()
- for ci in range(102):
+ for ci in range(101):
types.extend(('grad', 'grad', 'mag'))
+ types.extend(['ecog', 'seeg', 'hbo']) # actually 3 MEG channels, relabeled to test these types
types.extend(['stim'] * 9)
types.extend(['eeg'] * 60)
# wrong length
@@ -66,23 +67,27 @@ def test_array_raw():
picks = pick_types(raw2.info, misc=True, exclude='bads')[:4]
assert_equal(len(picks), 4)
raw_lp = raw2.copy()
- with warnings.catch_warnings(record=True):
- raw_lp.filter(0., 4.0 - 0.25, picks=picks, n_jobs=2)
+ raw_lp.filter(None, 4.0, h_trans_bandwidth=4.,
+ filter_length='auto', picks=picks, n_jobs=2, phase='zero',
+ fir_window='hamming')
raw_hp = raw2.copy()
- with warnings.catch_warnings(record=True):
- raw_hp.filter(8.0 + 0.25, None, picks=picks, n_jobs=2)
+ raw_hp.filter(16.0, None, l_trans_bandwidth=4.,
+ filter_length='auto', picks=picks, n_jobs=2, phase='zero',
+ fir_window='hamming')
raw_bp = raw2.copy()
- with warnings.catch_warnings(record=True):
- raw_bp.filter(4.0 + 0.25, 8.0 - 0.25, picks=picks)
+ raw_bp.filter(8.0, 12.0, l_trans_bandwidth=4.,
+ h_trans_bandwidth=4., filter_length='auto', picks=picks,
+ phase='zero', fir_window='hamming')
raw_bs = raw2.copy()
- with warnings.catch_warnings(record=True):
- raw_bs.filter(8.0 + 0.25, 4.0 - 0.25, picks=picks, n_jobs=2)
+ raw_bs.filter(16.0, 4.0, l_trans_bandwidth=4., h_trans_bandwidth=4.,
+ filter_length='auto', picks=picks, n_jobs=2, phase='zero',
+ fir_window='hamming')
data, _ = raw2[picks, :]
lp_data, _ = raw_lp[picks, :]
hp_data, _ = raw_hp[picks, :]
bp_data, _ = raw_bp[picks, :]
bs_data, _ = raw_bs[picks, :]
- sig_dec = 11
+ sig_dec = 15
assert_array_almost_equal(data, lp_data + bp_data + hp_data, sig_dec)
assert_array_almost_equal(data, bp_data + bs_data, sig_dec)
@@ -95,7 +100,8 @@ def test_array_raw():
events = find_events(raw2, stim_channel='STI 014')
events[:, 2] = 1
assert_true(len(events) > 2)
- epochs = Epochs(raw2, events, 1, -0.2, 0.4, preload=True)
+ epochs = Epochs(raw2, events, 1, -0.2, 0.4, preload=True,
+ add_eeg_ref=False)
epochs.plot_drop_log()
epochs.plot()
evoked = epochs.average()
@@ -103,4 +109,10 @@ def test_array_raw():
assert_equal(evoked.nave, len(events) - 1)
plt.close('all')
+ # complex data
+ rng = np.random.RandomState(0)
+ data = rng.randn(1, 100) + 1j * rng.randn(1, 100)
+ raw = RawArray(data, create_info(1, 1000., 'eeg'))
+ assert_allclose(raw._data, data)
+
run_tests_if_main()
diff --git a/mne/io/base.py b/mne/io/base.py
index 82e01c3..87746db 100644
--- a/mne/io/base.py
+++ b/mne/io/base.py
@@ -13,7 +13,6 @@ import os
import os.path as op
import numpy as np
-from scipy import linalg
from .constants import FIFF
from .pick import pick_types, channel_type, pick_channels, pick_info
@@ -23,26 +22,25 @@ from .proj import setup_proj, activate_proj, _proj_equal, ProjMixin
from ..channels.channels import (ContainsMixin, UpdateChannelsMixin,
SetChannelsMixin, InterpolationMixin)
from ..channels.montage import read_montage, _set_montage, Montage
-from .compensator import set_current_comp
+from .compensator import set_current_comp, make_compensator
from .write import (start_file, end_file, start_block, end_block,
write_dau_pack16, write_float, write_double,
write_complex64, write_complex128, write_int,
write_id, write_string, write_name_list, _get_split_size)
-from ..filter import (low_pass_filter, high_pass_filter, band_pass_filter,
- notch_filter, band_stop_filter, resample,
+from ..filter import (filter_data, notch_filter, resample, next_fast_len,
_resample_stim_channels)
-from ..fixes import in1d
from ..parallel import parallel_func
-from ..utils import (_check_fname, _check_pandas_installed,
+from ..utils import (_check_fname, _check_pandas_installed, sizeof_fmt,
_check_pandas_index_arguments, _check_copy_dep,
- check_fname, _get_stim_channel, object_hash,
- logger, verbose, _time_mask, warn, deprecated)
+ check_fname, _get_stim_channel,
+ logger, verbose, _time_mask, warn, SizeMixin,
+ copy_function_doc_to_method_doc)
from ..viz import plot_raw, plot_raw_psd, plot_raw_psd_topo
from ..defaults import _handle_default
from ..externals.six import string_types
from ..event import find_events, concatenate_events
-from ..annotations import _combine_annotations, _onset_to_seconds
+from ..annotations import Annotations, _combine_annotations, _onset_to_seconds
class ToDataFrameMixin(object):
@@ -51,7 +49,7 @@ class ToDataFrameMixin(object):
if picks is None:
picks = list(range(self.info['nchan']))
else:
- if not in1d(picks, np.arange(len(picks_check))).all():
+ if not np.in1d(picks, np.arange(len(picks_check))).all():
raise ValueError('At least one picked channel is not present '
'in this object instance.')
return picks
@@ -245,7 +243,7 @@ def _check_fun(fun, d, *args, **kwargs):
class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
SetChannelsMixin, InterpolationMixin, ToDataFrameMixin,
- TimeMixin):
+ TimeMixin, SizeMixin):
"""Base class for Raw data
Subclasses must provide the following methods:
@@ -261,8 +259,7 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
def __init__(self, info, preload=False,
first_samps=(0,), last_samps=None,
filenames=(None,), raw_extras=(None,),
- comp=None, orig_comp_grade=None, orig_format='double',
- dtype=np.float64, verbose=None):
+ orig_format='double', dtype=np.float64, verbose=None):
# wait until the end to preload data, but triage here
if isinstance(preload, np.ndarray):
# some functions (e.g., filtering) only work w/64-bit data
@@ -299,8 +296,13 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
self.verbose = verbose
self._cals = cals
self._raw_extras = list(raw_extras)
- self.comp = comp
- self._orig_comp_grade = orig_comp_grade
+ # deal with compensation (only relevant for CTF data, either CTF
+ # reader or MNE-C converted CTF->FIF files)
+ self._read_comp_grade = self.compensation_grade # read property
+ if self._read_comp_grade is not None:
+ logger.info('Current compensation grade : %d'
+ % self._read_comp_grade)
+ self._comp = None
self._filenames = list(filenames)
self.orig_format = orig_format
self._projectors = list()
@@ -312,6 +314,54 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
if load_from_disk:
self._preload_data(preload)
+ @verbose
+ def apply_gradient_compensation(self, grade, verbose=None):
+ """Apply CTF gradient compensation
+
+ .. warning:: The compensation matrices are stored with single
+ precision, so repeatedly switching between different grades
+ of compensation (e.g., 0->1->3->2) can increase
+ numerical noise, especially if data are saved to
+ disk in between changing grades. It is thus best to
+ only use a single gradient compensation level in
+ final analyses.
+
+ Parameters
+ ----------
+ grade : int
+ CTF gradient compensation level.
+ verbose : bool, str, int, or None
+ If not None, override default verbose level (see mne.verbose).
+
+ Returns
+ -------
+ raw : instance of Raw
+ The modified Raw instance. Works in-place.
+ """
+ grade = int(grade)
+ current_comp = self.compensation_grade
+ if current_comp != grade:
+ if self.proj:
+ raise RuntimeError('Cannot change compensation on data where '
+ 'projectors have been applied')
+ # Figure out what operator to use (varies depending on preload)
+ from_comp = current_comp if self.preload else self._read_comp_grade
+ comp = make_compensator(self.info, from_comp, grade)
+ logger.info('Compensator constructed to change %d -> %d'
+ % (current_comp, grade))
+ set_current_comp(self.info, grade)
+ # We might need to apply it to our data now
+ if self.preload:
+ logger.info('Applying compensator to loaded data')
+ lims = np.concatenate([np.arange(0, len(self.times), 10000),
+ [len(self.times)]])
+ for start, stop in zip(lims[:-1], lims[1:]):
+ self._data[:, start:stop] = np.dot(
+ comp, self._data[:, start:stop])
+ else:
+ self._comp = comp # store it for later use
+ return self
+
@property
def _dtype(self):
"""dtype for loading data (property so subclasses can override)"""
@@ -383,12 +433,12 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
# set up cals and mult (cals, compensation, and projector)
cals = self._cals.ravel()[np.newaxis, :]
- if self.comp is not None:
+ if self._comp is not None:
if projector is not None:
- mult = self.comp * cals
+ mult = self._comp * cals
mult = np.dot(projector[idx], mult)
else:
- mult = self.comp[idx] * cals
+ mult = self._comp[idx] * cals
elif projector is not None:
mult = projector[idx] * cals
else:
@@ -514,6 +564,7 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
self._data = self._read_segment(data_buffer=data_buffer)
assert len(self._data) == self.info['nchan']
self.preload = True
+ self._comp = None # no longer needed
self.close()
def _update_times(self):
@@ -534,6 +585,62 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
def _raw_lengths(self):
return [l - f + 1 for f, l in zip(self._first_samps, self._last_samps)]
+ @property
+ def annotations(self):
+ """Annotations for marking segments of data."""
+ return self._annotations
+
+ @annotations.setter
+ def annotations(self, annotations):
+ """Setter for annotations. Checks if they are inside the data range.
+
+ Parameters
+ ----------
+ annotations : instance of mne.Annotations
+ Annotations to set.
+ """
+ if annotations is not None:
+ if not isinstance(annotations, Annotations):
+ raise ValueError('Annotations must be an instance of '
+ 'mne.Annotations. Got %s.' % annotations)
+ meas_date = self.info['meas_date']
+ if meas_date is None:
+ meas_date = 0
+ elif not np.isscalar(meas_date):
+ if len(meas_date) > 1:
+ meas_date = meas_date[0] + meas_date[1] / 1000000.
+ if annotations.orig_time is not None:
+ offset = (annotations.orig_time - meas_date -
+ self.first_samp / self.info['sfreq'])
+ else:
+ offset = 0
+ omit_ind = list()
+ for ind, onset in enumerate(annotations.onset):
+ onset += offset
+ if onset > self.times[-1]:
+ warn('Omitting annotation outside data range.')
+ omit_ind.append(ind)
+ elif onset < self.times[0]:
+ if onset + annotations.duration[ind] < self.times[0]:
+ warn('Omitting annotation outside data range.')
+ omit_ind.append(ind)
+ else:
+ warn('Annotation starting outside the data range. '
+ 'Limiting to the start of data.')
+ duration = annotations.duration[ind] + onset
+ annotations.duration[ind] = duration
+ annotations.onset[ind] = self.times[0] - offset
+ elif onset + annotations.duration[ind] > self.times[-1]:
+ warn('Annotation extending beyond the data range. '
+ 'Limiting to the end of data.')
+ annotations.duration[ind] = self.times[-1] - onset
+ annotations.onset = np.delete(annotations.onset, omit_ind)
+ annotations.duration = np.delete(annotations.duration, omit_ind)
+ annotations.description = np.delete(annotations.description,
+ omit_ind)
+
+ self._annotations = annotations
+
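The setter above warns about and clips or drops annotations that fall outside the recording. A self-contained sketch of the behavior on synthetic data:

    import numpy as np
    from mne import Annotations, create_info
    from mne.io import RawArray

    raw = RawArray(np.zeros((1, 1000)), create_info(1, 100., 'eeg'))  # 10 s of data
    # first annotation is inside the data; second starts past the end
    raw.annotations = Annotations(onset=[1., 20.], duration=[0.5, 1.],
                                  description=['ok', 'outside'])  # warns, omits the second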
def __del__(self):
# remove file for memmap
if hasattr(self, '_data') and hasattr(self._data, 'filename'):
@@ -557,11 +664,6 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
except:
return exception_type, exception_val, trace
- def __hash__(self):
- if not self.preload:
- raise RuntimeError('Cannot hash raw unless preloaded')
- return object_hash(dict(info=self.info, data=self._data))
-
def _parse_get_set_params(self, item):
# make sure item is a tuple
if not isinstance(item, tuple): # only channel selection passed
@@ -612,7 +714,41 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
return sel, start, stop
def __getitem__(self, item):
- """getting raw data content with python slicing"""
+ """Get raw data and times
+
+ Parameters
+ ----------
+ item : tuple or array-like
+ See below for use cases.
+
+ Returns
+ -------
+ data : ndarray, shape (n_channels, n_times)
+ The raw data.
+ times : ndarray, shape (n_times,)
+ The times associated with the data.
+
+ Examples
+ --------
+ Generally raw data is accessed as::
+
+ >>> data, times = raw[picks, time_slice] # doctest: +SKIP
+
+ To get all data, you can thus do::
+
+ >>> data, times = raw[:] # doctest: +SKIP
+
+ which is equivalent to::
+
+ >>> data, times = raw[:, :] # doctest: +SKIP
+
+ To get only the good MEG data from 10-20 seconds, you could do::
+
+ >>> picks = mne.pick_types(raw.info, meg=True, exclude='bads') # doctest: +SKIP
+ >>> t_idx = raw.time_as_index([10., 20.]) # doctest: +SKIP
+ >>> data, times = raw[picks, t_idx[0]:t_idx[1]] # doctest: +SKIP
+
+ """ # noqa
sel, start, stop = self._parse_get_set_params(item)
if self.preload:
data = self._data[sel, start:stop]
@@ -630,19 +766,6 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
# set the data
self._data[sel, start:stop] = value
- def anonymize(self):
- """Anonymize data
-
- This function will remove ``raw.info['subject_info']`` if it exists.
-
- Returns
- -------
- raw : instance of Raw
- The raw object. Operates in place.
- """
- self.info._anonymize()
- return self
-
@verbose
def apply_function(self, fun, picks, dtype, n_jobs, *args, **kwargs):
""" Apply a function to a subset of channels.
@@ -712,7 +835,7 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
self._data[p, :] = data_picks_new[pp]
@verbose
- def apply_hilbert(self, picks, envelope=False, n_jobs=1, n_fft=None,
+ def apply_hilbert(self, picks, envelope=False, n_jobs=1, n_fft='',
verbose=None):
""" Compute analytic signal or envelope for a subset of channels.
@@ -746,10 +869,11 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
Compute the envelope signal of each channel.
n_jobs: int
Number of jobs to run in parallel.
- n_fft : int > self.n_times | None
+ n_fft : int | None | str
Points to use in the FFT for Hilbert transformation. The signal
will be padded with zeros before computing Hilbert, then cut back
- to original length. If None, n == self.n_times.
+ to original length. If None, n == self.n_times. If 'auto',
+ the next highest fast FFT length will be used.
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
Defaults to self.verbose.
@@ -773,7 +897,16 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
is cut off, but it may result in a slightly different result
(particularly around the edges). Use at your own risk.
"""
- n_fft = self.n_times if n_fft is None else n_fft
+ if n_fft is None:
+ n_fft = len(self.times)
+ elif isinstance(n_fft, string_types):
+ if n_fft == '':
+ n_fft = len(self.times)
+ warn('n_fft is None by default in 0.13 but will change to '
+ '"auto" in 0.14', DeprecationWarning)
+ elif n_fft == 'auto':
+ n_fft = next_fast_len(len(self.times))
+ n_fft = int(n_fft)
if n_fft < self.n_times:
raise ValueError("n_fft must be greater than n_times")
if envelope is True:
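A sketch of the new ``n_fft='auto'`` behavior, which pads to the next fast FFT length before the Hilbert transform (synthetic data, so it is self-contained):

    import numpy as np
    from mne import create_info, pick_types
    from mne.io import RawArray

    raw = RawArray(np.random.randn(2, 1000), create_info(2, 100., 'eeg'))
    picks = pick_types(raw.info, meg=False, eeg=True)
    raw.apply_hilbert(picks, envelope=True, n_fft='auto')  # zero-pads, then trims back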
@@ -784,9 +917,10 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
n_fft, envelope=envelope)
@verbose
- def filter(self, l_freq, h_freq, picks=None, filter_length='10s',
- l_trans_bandwidth=0.5, h_trans_bandwidth=0.5, n_jobs=1,
- method='fft', iir_params=None, verbose=None):
+ def filter(self, l_freq, h_freq, picks=None, filter_length='',
+ l_trans_bandwidth=None, h_trans_bandwidth=None, n_jobs=1,
+ method='fir', iir_params=None, phase='', fir_window='',
+ verbose=None):
"""Filter a subset of channels.
Applies a zero-phase low-pass, high-pass, band-pass, or band-stop
@@ -821,32 +955,62 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
picks : array-like of int | None
Indices of channels to filter. If None only the data (MEG/EEG)
channels will be filtered.
- filter_length : str (Default: '10s') | int | None
- Length of the filter to use. If None or "len(x) < filter_length",
- the filter length used is len(x). Otherwise, if int, overlap-add
- filtering with a filter of the specified length in samples) is
- used (faster for long signals). If str, a human-readable time in
- units of "s" or "ms" (e.g., "10s" or "5500ms") will be converted
- to the shortest power-of-two length at least that duration.
- Not used for 'iir' filters.
- l_trans_bandwidth : float
+ filter_length : str | int
+ Length of the FIR filter to use (if applicable):
+
+ * int: specified length in samples.
+ * 'auto' (default in 0.14): the filter length is chosen based
+ on the size of the transition regions (6.6 times the
+ reciprocal of the shortest transition band for
+ fir_window='hamming').
+ * str: (default in 0.13 is "10s") a human-readable time in
+ units of "s" or "ms" (e.g., "10s" or "5500ms") will be
+ converted to that number of samples if ``phase="zero"``, or
+ the shortest power-of-two length at least that duration for
+ ``phase="zero-double"``.
+
+ l_trans_bandwidth : float | str
Width of the transition band at the low cut-off frequency in Hz
- (high pass or cutoff 1 in bandpass). Not used if 'order' is
- specified in iir_params.
- h_trans_bandwidth : float
+ (high pass or cutoff 1 in bandpass). Can be "auto"
+ (default in 0.14) to use a multiple of ``l_freq``::
+
+ min(max(l_freq * 0.25, 2), l_freq)
+
+ Only used for ``method='fir'``.
+ h_trans_bandwidth : float | str
Width of the transition band at the high cut-off frequency in Hz
- (low pass or cutoff 2 in bandpass). Not used if 'order' is
- specified in iir_params.
+ (low pass or cutoff 2 in bandpass). Can be "auto"
+ (default in 0.14) to use a multiple of ``h_freq``::
+
+ min(max(h_freq * 0.25, 2.), info['sfreq'] / 2. - h_freq)
+
+ Only used for ``method='fir'``.
n_jobs : int | str
Number of jobs to run in parallel. Can be 'cuda' if scikits.cuda
- is installed properly, CUDA is initialized, and method='fft'.
+ is installed properly, CUDA is initialized, and method='fir'.
method : str
- 'fft' will use overlap-add FIR filtering, 'iir' will use IIR
+ 'fir' will use overlap-add FIR filtering, 'iir' will use IIR
forward-backward filtering (via filtfilt).
iir_params : dict | None
Dictionary of parameters to use for IIR filtering.
See mne.filter.construct_iir_filter for details. If iir_params
is None and method="iir", 4th order Butterworth will be used.
+ phase : str
+ Phase of the filter, only used if ``method='fir'``.
+ By default, a symmetric linear-phase FIR filter is constructed.
+ If ``phase='zero'`` (default in 0.14), the delay of this filter
+ is compensated for. If ``phase='zero-double'`` (default in 0.13
+ and before), then this filter is applied twice, once forward, and
+ once backward.
+
+ .. versionadded:: 0.13
+
+ fir_window : str
+ The window to use in FIR design, can be "hamming" (default in
+ 0.14), "hann" (default in 0.13), or "blackman".
+
+ .. versionadded:: 0.13
+
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
Defaults to self.verbose.
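Putting the new parameters together, a call that opts in to the 0.14 defaults (matching the updated tests earlier in this diff) might look like the following sketch on synthetic data:

    import numpy as np
    from mne import create_info
    from mne.io import RawArray

    raw = RawArray(np.random.randn(2, 2000), create_info(2, 200., 'eeg'))
    raw.filter(1., 40., filter_length='auto',
               l_trans_bandwidth='auto', h_trans_bandwidth='auto',
               method='fir', phase='zero', fir_window='hamming')  # zero-phase band-pass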
@@ -861,85 +1025,54 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
mne.Epochs.savgol_filter
mne.io.Raw.notch_filter
mne.io.Raw.resample
+ mne.filter.filter_data
+ mne.filter.construct_iir_filter
+
+ Notes
+ -----
+ For more information, see the tutorials :ref:`tut_background_filtering`
+ and :ref:`tut_artifacts_filter`.
"""
- fs = float(self.info['sfreq'])
- if l_freq == 0:
- l_freq = None
- if h_freq is not None and h_freq > (fs / 2.):
- h_freq = None
- if l_freq is not None and not isinstance(l_freq, float):
- l_freq = float(l_freq)
- if h_freq is not None and not isinstance(h_freq, float):
- h_freq = float(h_freq)
_check_preload(self, 'raw.filter')
+ data_picks = _pick_data_or_ica(self.info)
+ update_info = False
if picks is None:
- picks = _pick_data_or_ica(self.info)
+ picks = data_picks
+ update_info = True
# let's be safe.
- if len(picks) < 1:
+ if len(picks) == 0:
raise RuntimeError('Could not find any valid channels for '
'your Raw object. Please contact the '
'MNE-Python developers.')
-
- # update info if filter is applied to all data channels,
- # and it's not a band-stop filter
- if h_freq is not None:
- if (l_freq is None or l_freq < h_freq) and \
- (self.info["lowpass"] is None or
- h_freq < self.info['lowpass']):
- self.info['lowpass'] = h_freq
- if l_freq is not None:
- if (h_freq is None or l_freq < h_freq) and \
- (self.info["highpass"] is None or
- l_freq > self.info['highpass']):
- self.info['highpass'] = l_freq
- else:
- if h_freq is not None or l_freq is not None:
+ elif h_freq is not None or l_freq is not None:
+ if np.in1d(data_picks, picks).all():
+ update_info = True
+ else:
logger.info('Filtering a subset of channels. The highpass and '
'lowpass values in the measurement info will not '
'be updated.')
-
- if l_freq is None and h_freq is not None:
- logger.info('Low-pass filtering at %0.2g Hz' % h_freq)
- low_pass_filter(self._data, fs, h_freq,
- filter_length=filter_length,
- trans_bandwidth=h_trans_bandwidth, method=method,
- iir_params=iir_params, picks=picks, n_jobs=n_jobs,
- copy=False)
- if l_freq is not None and h_freq is None:
- logger.info('High-pass filtering at %0.2g Hz' % l_freq)
- high_pass_filter(self._data, fs, l_freq,
- filter_length=filter_length,
- trans_bandwidth=l_trans_bandwidth, method=method,
- iir_params=iir_params, picks=picks, n_jobs=n_jobs,
- copy=False)
- if l_freq is not None and h_freq is not None:
- if l_freq < h_freq:
- logger.info('Band-pass filtering from %0.2g - %0.2g Hz'
- % (l_freq, h_freq))
- self._data = band_pass_filter(
- self._data, fs, l_freq, h_freq,
- filter_length=filter_length,
- l_trans_bandwidth=l_trans_bandwidth,
- h_trans_bandwidth=h_trans_bandwidth,
- method=method, iir_params=iir_params, picks=picks,
- n_jobs=n_jobs, copy=False)
- else:
- logger.info('Band-stop filtering from %0.2g - %0.2g Hz'
- % (h_freq, l_freq))
- self._data = band_stop_filter(
- self._data, fs, h_freq, l_freq,
- filter_length=filter_length,
- l_trans_bandwidth=h_trans_bandwidth,
- h_trans_bandwidth=l_trans_bandwidth, method=method,
- iir_params=iir_params, picks=picks, n_jobs=n_jobs,
- copy=False)
+ filter_data(self._data, self.info['sfreq'], l_freq, h_freq, picks,
+ filter_length, l_trans_bandwidth, h_trans_bandwidth,
+ n_jobs, method, iir_params, copy=False, phase=phase,
+ fir_window=fir_window)
+ # update info if filter is applied to all data channels,
+ # and it's not a band-stop filter
+ if update_info:
+ if h_freq is not None and (l_freq is None or l_freq < h_freq) and \
+ (self.info["lowpass"] is None or
+ h_freq < self.info['lowpass']):
+ self.info['lowpass'] = float(h_freq)
+ if l_freq is not None and (h_freq is None or l_freq < h_freq) and \
+ (self.info["highpass"] is None or
+ l_freq > self.info['highpass']):
+ self.info['highpass'] = float(l_freq)
return self
@verbose
- def notch_filter(self, freqs, picks=None, filter_length='10s',
+ def notch_filter(self, freqs, picks=None, filter_length='',
notch_widths=None, trans_bandwidth=1.0, n_jobs=1,
method='fft', iir_params=None, mt_bandwidth=None,
- p_value=0.05, verbose=None):
+ p_value=0.05, phase='', fir_window='', verbose=None):
"""Notch filter a subset of channels.
Applies a zero-phase notch filter to the channels selected by
@@ -962,24 +1095,31 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
picks : array-like of int | None
Indices of channels to filter. If None only the data (MEG/EEG)
channels will be filtered.
- filter_length : str (Default: '10s') | int | None
- Length of the filter to use. If None or "len(x) < filter_length",
- the filter length used is len(x). Otherwise, if int, overlap-add
- filtering with a filter of the specified length in samples) is
- used (faster for long signals). If str, a human-readable time in
- units of "s" or "ms" (e.g., "10s" or "5500ms") will be converted
- to the shortest power-of-two length at least that duration.
- Not used for 'iir' filters.
+ filter_length : str | int
+ Length of the FIR filter to use (if applicable):
+
+ * int: specified length in samples.
+ * 'auto' (default in 0.14): the filter length is chosen based
+ on the size of the transition regions (6.6 times the
+ reciprocal of the shortest transition band for
+ fir_window='hamming').
+ * str: (default in 0.13 is "10s") a human-readable time in
+ units of "s" or "ms" (e.g., "10s" or "5500ms") will be
+ converted to that number of samples if ``phase="zero"``, or
+ the shortest power-of-two length at least that duration for
+ ``phase="zero-double"``.
+
notch_widths : float | array of float | None
Width of each stop band (centred at each freq in freqs) in Hz.
If None, freqs / 200 is used.
trans_bandwidth : float
Width of the transition band in Hz.
+ Only used for ``method='fir'``.
n_jobs : int | str
Number of jobs to run in parallel. Can be 'cuda' if scikits.cuda
- is installed properly, CUDA is initialized, and method='fft'.
+ is installed properly, CUDA is initialized, and method='fir'.
method : str
- 'fft' will use overlap-add FIR filtering, 'iir' will use IIR
+ 'fir' will use overlap-add FIR filtering, 'iir' will use IIR
forward-backward filtering (via filtfilt). 'spectrum_fit' will
use multi-taper estimation of sinusoidal components.
iir_params : dict | None
@@ -994,6 +1134,22 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
sinusoidal components to remove when method='spectrum_fit' and
freqs=None. Note that this will be Bonferroni corrected for the
number of frequencies, so large p-values may be justified.
+ phase : str
+ Phase of the filter, only used if ``method='fir'``.
+ By default, a symmetric linear-phase FIR filter is constructed.
+ If ``phase='zero'`` (default in 0.14), the delay of this filter
+ is compensated for. If ``phase='zero-double'`` (default in 0.13
+ and before), then this filter is applied twice, once forward, and
+ once backward.
+
+ .. versionadded:: 0.13
+
+ fir_window : str
+ The window to use in FIR design, can be "hamming" (default in
+ 0.14), "hann" (default in 0.13), or "blackman".
+
+ .. versionadded:: 0.13
+
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
Defaults to self.verbose.
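For example, a zero-phase notch at the mains frequency and its harmonics with the new-style arguments (a sketch on synthetic data):

    import numpy as np
    from mne import create_info
    from mne.io import RawArray

    raw = RawArray(np.random.randn(2, 10000), create_info(2, 1000., 'eeg'))
    raw.notch_filter(np.arange(50., 251., 50.), filter_length='auto',
                     phase='zero', fir_window='hamming')  # 50 Hz mains + harmonics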
@@ -1024,11 +1180,12 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
self._data, fs, freqs, filter_length=filter_length,
notch_widths=notch_widths, trans_bandwidth=trans_bandwidth,
method=method, iir_params=iir_params, mt_bandwidth=mt_bandwidth,
- p_value=p_value, picks=picks, n_jobs=n_jobs, copy=False)
+ p_value=p_value, picks=picks, n_jobs=n_jobs, copy=False,
+ phase=phase, fir_window=fir_window)
return self
@verbose
- def resample(self, sfreq, npad=None, window='boxcar', stim_picks=None,
+ def resample(self, sfreq, npad='auto', window='boxcar', stim_picks=None,
n_jobs=1, events=None, copy=None, verbose=None):
"""Resample all channels.
@@ -1095,11 +1252,6 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
For some data, it may be more accurate to use ``npad=0`` to reduce
artifacts. This is dataset dependent -- check your data!
""" # noqa
- if npad is None:
- npad = 100
- warn('npad is currently taken to be 100, but will be changed to '
- '"auto" in 0.13. Please set the value explicitly.',
- DeprecationWarning)
_check_preload(self, 'raw.resample')
inst = _check_copy_dep(self, copy)
@@ -1193,7 +1345,7 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
tmax : float | None
New end time in seconds of the data (cannot exceed data duration).
copy : bool
- This parameter has been deprecated and will be removed in 0.13.
+ This parameter has been deprecated and will be removed in 0.14.
Use inst.copy() instead.
Whether to return a new instance or modify in place.
@@ -1202,7 +1354,7 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
raw : instance of Raw
The cropped raw object.
"""
- raw = _check_copy_dep(self, copy, default=True)
+ raw = _check_copy_dep(self, copy)
max_time = (raw.n_times - 1) / raw.info['sfreq']
if tmax is None:
tmax = max_time
@@ -1236,10 +1388,14 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
# slice and copy to avoid the reference to large array
raw._data = raw._data[:, smin:smax + 1].copy()
raw._update_times()
+ if raw.annotations is not None:
+ annotations = raw.annotations
+ annotations.onset -= tmin
+ raw.annotations = annotations
return raw
@verbose
- def save(self, fname, picks=None, tmin=0, tmax=None, buffer_size_sec=10,
+ def save(self, fname, picks=None, tmin=0, tmax=None, buffer_size_sec=None,
drop_small_buffer=False, proj=False, fmt='single',
overwrite=False, split_size='2GB', verbose=None):
"""Save raw data to file
@@ -1260,8 +1416,8 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
Time in seconds of last sample to save. If None last sample
is used.
buffer_size_sec : float | None
- Size of data chunks in seconds. If None, the buffer size of
- the original file is used.
+ Size of data chunks in seconds. If None (default), the buffer
+ size of the original file is used.
drop_small_buffer : bool
Drop or not the last buffer. It is required by maxfilter (SSS)
that only accepts raw files with buffers of the same size.
@@ -1349,12 +1505,6 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
info = self.info
projector = None
- # set the correct compensation grade and make inverse compensator
- inv_comp = None
- if self.comp is not None:
- inv_comp = linalg.inv(self.comp)
- set_current_comp(info, self._orig_comp_grade)
-
#
# Set up the reading parameters
#
@@ -1372,278 +1522,44 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
# write the raw file
_write_raw(fname, self, info, picks, fmt, data_type, reset_range,
- start, stop, buffer_size, projector, inv_comp,
- drop_small_buffer, split_size, 0, None)
+ start, stop, buffer_size, projector, drop_small_buffer,
+ split_size, 0, None)
+ @copy_function_doc_to_method_doc(plot_raw)
def plot(self, events=None, duration=10.0, start=0.0, n_channels=20,
bgcolor='w', color=None, bad_color=(0.8, 0.8, 0.8),
event_color='cyan', scalings=None, remove_dc=True, order='type',
show_options=False, title=None, show=True, block=False,
highpass=None, lowpass=None, filtorder=4, clipping=None):
- """Plot raw data
-
- Parameters
- ----------
- events : array | None
- Events to show with vertical bars.
- duration : float
- Time window (sec) to plot in a given time.
- start : float
- Initial time to show (can be changed dynamically once plotted).
- n_channels : int
- Number of channels to plot at once. Defaults to 20.
- bgcolor : color object
- Color of the background.
- color : dict | color object | None
- Color for the data traces. If None, defaults to::
-
- dict(mag='darkblue', grad='b', eeg='k', eog='k', ecg='r',
- emg='k', ref_meg='steelblue', misc='k', stim='k',
- resp='k', chpi='k')
-
- bad_color : color object
- Color to make bad channels.
- event_color : color object
- Color to use for events.
- scalings : dict | None
- Scaling factors for the traces. If any fields in scalings are
- 'auto', the scaling factor is set to match the 99.5th percentile of
- a subset of the corresponding data. If scalings == 'auto', all
- scalings fields are set to 'auto'. If any fields are 'auto' and
- data is not preloaded, a subset of times up to 100mb will be
- loaded. If None, defaults to::
-
- dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, ecg=5e-4,
- emg=1e-3, ref_meg=1e-12, misc=1e-3, stim=1,
- resp=1, chpi=1e-4)
-
- remove_dc : bool
- If True remove DC component when plotting data.
- order : 'type' | 'original' | array
- Order in which to plot data. 'type' groups by channel type,
- 'original' plots in the order of ch_names, array gives the
- indices to use in plotting.
- show_options : bool
- If True, a dialog for options related to projection is shown.
- title : str | None
- The title of the window. If None, and either the filename of the
- raw object or '<unknown>' will be displayed as title.
- show : bool
- Show figures if True
- block : bool
- Whether to halt program execution until the figure is closed.
- Useful for setting bad channels on the fly (click on line).
- May not work on all systems / platforms.
- highpass : float | None
- Highpass to apply when displaying data.
- lowpass : float | None
- Lowpass to apply when displaying data.
- filtorder : int
- Filtering order. Note that for efficiency and simplicity,
- filtering during plotting uses forward-backward IIR filtering,
- so the effective filter order will be twice ``filtorder``.
- Filtering the lines for display may also produce some edge
- artifacts (at the left and right edges) of the signals
- during display. Filtering requires scipy >= 0.10.
- clipping : str | None
- If None, channels are allowed to exceed their designated bounds in
- the plot. If "clamp", then values are clamped to the appropriate
- range for display, creating step-like artifacts. If "transparent",
- then excessive values are not shown, creating gaps in the traces.
-
- Returns
- -------
- fig : Instance of matplotlib.figure.Figure
- Raw traces.
-
- Notes
- -----
- The arrow keys (up/down/left/right) can typically be used to navigate
- between channels and time ranges, but this depends on the backend
- matplotlib is configured to use (e.g., mpl.use('TkAgg') should work).
- The scaling can be adjusted with - and + (or =) keys. The viewport
- dimensions can be adjusted with page up/page down and home/end keys.
- Full screen mode can be to toggled with f11 key. To mark or un-mark a
- channel as bad, click on the rather flat segments of a channel's time
- series. The changes will be reflected immediately in the raw object's
- ``raw.info['bads']`` entry.
- """
return plot_raw(self, events, duration, start, n_channels, bgcolor,
color, bad_color, event_color, scalings, remove_dc,
order, show_options, title, show, block, highpass,
lowpass, filtorder, clipping)
@verbose
+ @copy_function_doc_to_method_doc(plot_raw_psd)
def plot_psd(self, tmin=0.0, tmax=60.0, fmin=0, fmax=np.inf,
proj=False, n_fft=2048, picks=None, ax=None,
color='black', area_mode='std', area_alpha=0.33,
n_overlap=0, dB=True, show=True, n_jobs=1, verbose=None):
- """Plot the power spectral density across channels
-
- Parameters
- ----------
- tmin : float
- Start time for calculations.
- tmax : float
- End time for calculations.
- fmin : float
- Start frequency to consider.
- fmax : float
- End frequency to consider.
- proj : bool
- Apply projection.
- n_fft : int
- Number of points to use in Welch FFT calculations.
- picks : array-like of int | None
- List of channels to use. Cannot be None if `ax` is supplied. If
- both `picks` and `ax` are None, separate subplots will be created
- for each standard channel type (`mag`, `grad`, and `eeg`).
- ax : instance of matplotlib Axes | None
- Axes to plot into. If None, axes will be created.
- color : str | tuple
- A matplotlib-compatible color to use.
- area_mode : str | None
- How to plot area. If 'std', the mean +/- 1 STD (across channels)
- will be plotted. If 'range', the min and max (across channels)
- will be plotted. Bad channels will be excluded from these
- calculations. If None, no area will be plotted.
- area_alpha : float
- Alpha for the area.
- n_overlap : int
- The number of points of overlap between blocks. The default value
- is 0 (no overlap).
- dB : bool
- If True, transform data to decibels.
- show : bool
- Call pyplot.show() at the end.
- n_jobs : int
- Number of jobs to run in parallel.
- verbose : bool, str, int, or None
- If not None, override default verbose level (see mne.verbose).
-
- Returns
- -------
- fig : instance of matplotlib figure
- Figure with frequency spectra of the data channels.
- """
return plot_raw_psd(self, tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax,
proj=proj, n_fft=n_fft, picks=picks, ax=ax,
color=color, area_mode=area_mode,
area_alpha=area_alpha, n_overlap=n_overlap,
dB=dB, show=show, n_jobs=n_jobs)
+ @copy_function_doc_to_method_doc(plot_raw_psd_topo)
def plot_psd_topo(self, tmin=0., tmax=None, fmin=0, fmax=100, proj=False,
n_fft=2048, n_overlap=0, layout=None, color='w',
fig_facecolor='k', axis_facecolor='k', dB=True,
- show=True, n_jobs=1, verbose=None):
- """Function for plotting channel wise frequency spectra as topography.
-
- Parameters
- ----------
- tmin : float
- Start time for calculations. Defaults to zero.
- tmax : float | None
- End time for calculations. If None (default), the end of data is
- used.
- fmin : float
- Start frequency to consider. Defaults to zero.
- fmax : float
- End frequency to consider. Defaults to 100.
- proj : bool
- Apply projection. Defaults to False.
- n_fft : int
- Number of points to use in Welch FFT calculations. Defaults to
- 2048.
- n_overlap : int
- The number of points of overlap between blocks. Defaults to 0
- (no overlap).
- layout : instance of Layout | None
- Layout instance specifying sensor positions (does not need to
- be specified for Neuromag data). If None (default), the correct
- layout is inferred from the data.
- color : str | tuple
- A matplotlib-compatible color to use for the curves. Defaults to
- white.
- fig_facecolor : str | tuple
- A matplotlib-compatible color to use for the figure background.
- Defaults to black.
- axis_facecolor : str | tuple
- A matplotlib-compatible color to use for the axis background.
- Defaults to black.
- dB : bool
- If True, transform data to decibels. Defaults to True.
- show : bool
- Show figure if True. Defaults to True.
- n_jobs : int
- Number of jobs to run in parallel. Defaults to 1.
- verbose : bool, str, int, or None
- If not None, override default verbose level (see mne.verbose).
-
- Returns
- -------
- fig : instance of matplotlib figure
- Figure distributing one image per channel across sensor topography.
- """
+ show=True, block=False, n_jobs=1, verbose=None):
return plot_raw_psd_topo(self, tmin=tmin, tmax=tmax, fmin=fmin,
fmax=fmax, proj=proj, n_fft=n_fft,
n_overlap=n_overlap, layout=layout,
color=color, fig_facecolor=fig_facecolor,
axis_facecolor=axis_facecolor, dB=dB,
- show=show, n_jobs=n_jobs, verbose=verbose)
-
- def time_as_index(self, times, use_first_samp=None, use_rounding=False):
- """Convert time to indices
-
- Parameters
- ----------
- times : list-like | float | int
- List of numbers or a number representing points in time.
- use_first_samp : boolean
- This is deprecated and will be removed in 0.13.
- If True, time is treated as relative to the session onset, else
- as relative to the recording onset. Default is False.
- use_rounding : boolean
- If True, use rounding (instead of truncation) when converting
- times to indices. This can help avoid non-unique indices.
-
- Returns
- -------
- index : ndarray
- Indices corresponding to the times supplied.
- """
- # Note: this entire class can be removed in 0.13 (proper method
- # will be inherited from TimeMixin)
- if use_first_samp is None:
- use_first_samp = False
- else:
- warn('use_first_samp is deprecated, add raw.first_samp manually '
- 'if first sample offset is required', DeprecationWarning)
- index = super(_BaseRaw, self).time_as_index(times, use_rounding)
- if use_first_samp:
- index -= self.first_samp
- return index
-
- @deprecated('index_as_time is deprecated and will be removed in 0.13, '
- 'use raw.times[idx] (or raw.times[idx + raw.first_samp] '
- 'instead')
- def index_as_time(self, index, use_first_samp=False):
- """Convert indices to time
-
- Parameters
- ----------
- index : list-like | int
- List of ints or int representing points in time.
- use_first_samp : boolean
- If True, the time returned is relative to the session onset, else
- relative to the recording onset.
-
- Returns
- -------
- times : ndarray
- Times corresponding to the index supplied.
- """
- return _index_as_time(index, self.info['sfreq'], self.first_samp,
- use_first_samp)
+ show=show, block=block, n_jobs=n_jobs,
+ verbose=verbose)
def estimate_rank(self, tstart=0.0, tstop=30.0, tol=1e-4,
return_singular=False, picks=None, scalings='norm'):
@@ -1739,6 +1655,21 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
return self.last_samp - self.first_samp + 1
def __len__(self):
+ """The number of time points
+
+ Returns
+ -------
+ len : int
+ The number of time points.
+
+ Examples
+ --------
+ This can be used as::
+
+ >>> len(raw) # doctest: +SKIP
+ 1000
+
+ """
return self.n_times
def load_bad_channels(self, bad_file=None, force=False):
@@ -1848,18 +1779,19 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
self.preload = True
# now combine information from each raw file to construct new self
+ annotations = self.annotations
for r in raws:
self._first_samps = np.r_[self._first_samps, r._first_samps]
self._last_samps = np.r_[self._last_samps, r._last_samps]
self._raw_extras += r._raw_extras
self._filenames += r._filenames
- self.annotations = _combine_annotations((self.annotations,
- r.annotations),
- self._last_samps,
- self._first_samps,
- self.info['sfreq'])
+ annotations = _combine_annotations((annotations, r.annotations),
+ self._last_samps,
+ self._first_samps,
+ self.info['sfreq'])
self._update_times()
+ self.annotations = annotations
if not (len(self._first_samps) == len(self._last_samps) ==
len(self._raw_extras) == len(self._filenames)):
@@ -1881,8 +1813,11 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
def __repr__(self):
name = self._filenames[0]
name = 'None' if name is None else op.basename(name)
- s = ('%s, n_channels x n_times : %s x %s (%0.1f sec)'
- % (name, len(self.ch_names), self.n_times, self.times[-1]))
+ size_str = str(sizeof_fmt(self._size)) # str in case it fails -> None
+ size_str += ', data%s loaded' % ('' if self.preload else ' not')
+ s = ('%s, n_channels x n_times : %s x %s (%0.1f sec), ~%s'
+ % (name, len(self.ch_names), self.n_times, self.times[-1],
+ size_str))
return "<%s | %s>" % (self.__class__.__name__, s)
def add_events(self, events, stim_channel=None):
@@ -1926,10 +1861,7 @@ class _BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin,
def _get_buffer_size(self, buffer_size_sec=None):
"""Helper to get the buffer size"""
if buffer_size_sec is None:
- if 'buffer_size_sec' in self.info:
- buffer_size_sec = self.info['buffer_size_sec']
- else:
- buffer_size_sec = 10.0
+ buffer_size_sec = self.info.get('buffer_size_sec', 1.)
return int(np.ceil(buffer_size_sec * self.info['sfreq']))
@@ -1992,7 +1924,7 @@ class _RawShell():
###############################################################################
# Writing
def _write_raw(fname, raw, info, picks, fmt, data_type, reset_range, start,
- stop, buffer_size, projector, inv_comp, drop_small_buffer,
+ stop, buffer_size, projector, drop_small_buffer,
split_size, part_idx, prev_fname):
"""Write raw file with splitting
"""
@@ -2050,7 +1982,7 @@ def _write_raw(fname, raw, info, picks, fmt, data_type, reset_range, start,
'[done]')
break
logger.info('Writing ...')
- _write_raw_buffer(fid, data, cals, fmt, inv_comp)
+ _write_raw_buffer(fid, data, cals, fmt)
pos = fid.tell()
this_buff_size_bytes = pos - pos_prev
@@ -2073,7 +2005,7 @@ def _write_raw(fname, raw, info, picks, fmt, data_type, reset_range, start,
next_fname, next_idx = _write_raw(
fname, raw, info, picks, fmt,
data_type, reset_range, first + buffer_size, stop, buffer_size,
- projector, inv_comp, drop_small_buffer, split_size,
+ projector, drop_small_buffer, split_size,
part_idx + 1, use_fname)
start_block(fid, FIFF.FIFFB_REF)
@@ -2178,7 +2110,7 @@ def _start_writing_raw(name, info, sel=None, data_type=FIFF.FIFFT_FLOAT,
return fid, cals
-def _write_raw_buffer(fid, buf, cals, fmt, inv_comp):
+def _write_raw_buffer(fid, buf, cals, fmt):
"""Write raw buffer
Parameters
@@ -2193,9 +2125,6 @@ def _write_raw_buffer(fid, buf, cals, fmt, inv_comp):
'short', 'int', 'single', or 'double' for 16/32 bit int or 32/64 bit
float for each item. This will be doubled for complex datatypes. Note
that short and int formats cannot be used for complex data.
- inv_comp : array | None
- The CTF compensation matrix used to revert compensation
- change when reading.
"""
if buf.shape[0] != len(cals):
raise ValueError('buffer and calibration sizes do not match')
@@ -2221,11 +2150,7 @@ def _write_raw_buffer(fid, buf, cals, fmt, inv_comp):
raise ValueError('only "single" and "double" supported for '
'writing complex data')
- if inv_comp is not None:
- buf = np.dot(inv_comp / np.ravel(cals)[:, None], buf)
- else:
- buf = buf / np.ravel(cals)[:, None]
-
+ buf = buf / np.ravel(cals)[:, None]
write_function(fid, FIFF.FIFF_DATA_BUFFER, buf)
@@ -2236,9 +2161,9 @@ def _my_hilbert(x, n_fft=None, envelope=False):
----------
x : array, shape (n_times)
The signal to convert
- n_fft : int, length > x.shape[-1] | None
- How much to pad the signal before Hilbert transform.
- Note that signal will then be cut back to original length.
+ n_fft : int
+ Size of the FFT to perform, must be at least ``len(x)``.
+ The signal will be cut back to original length.
envelope : bool
Whether to compute amplitude of the hilbert transform in order
to return the signal envelope.
@@ -2249,7 +2174,6 @@ def _my_hilbert(x, n_fft=None, envelope=False):
The hilbert transform of the signal, or the envelope.
"""
from scipy.signal import hilbert
- n_fft = x.shape[-1] if n_fft is None else n_fft
n_x = x.shape[-1]
out = hilbert(x, N=n_fft)[:n_x]
if envelope is True:
@@ -2347,3 +2271,21 @@ def _check_update_montage(info, montage, path=None, update_ch_names=False):
"definitions: %s. If those channels lack positions "
"because they are EOG channels use the eog parameter."
% str(missing_positions))
+
+
+def _check_maxshield(allow_maxshield):
+ """Warn or error about MaxShield."""
+ msg = ('This file contains raw Internal Active '
+ 'Shielding data. It may be distorted. Elekta '
+ 'recommends it be run through MaxFilter to '
+ 'produce reliable results. Consider closing '
+ 'the file and running MaxFilter on the data.')
+ if allow_maxshield:
+ if not (isinstance(allow_maxshield, string_types) and
+ allow_maxshield == 'yes'):
+ warn(msg)
+ allow_maxshield = 'yes'
+ else:
+ msg += (' Use allow_maxshield=True if you are sure you'
+ ' want to load the data despite this warning.')
+ raise ValueError(msg)
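From the user's side, the helper above surfaces as the ``allow_maxshield`` argument of the FIF reader; a hedged sketch (the file name is hypothetical):

    from mne.io import read_raw_fif

    # without allow_maxshield this raises ValueError with the message above;
    # passing the string 'yes' also silences the warning
    raw = read_raw_fif('ias_raw.fif', allow_maxshield='yes')  # hypothetical file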
diff --git a/mne/io/brainvision/brainvision.py b/mne/io/brainvision/brainvision.py
index 943686f..55bd100 100644
--- a/mne/io/brainvision/brainvision.py
+++ b/mne/io/brainvision/brainvision.py
@@ -4,6 +4,8 @@
# Authors: Teon Brooks <teon.brooks at gmail.com>
# Christian Brodbeck <christianbrodbeck at nyu.edu>
# Eric Larson <larson.eric.d at gmail.com>
+# Jona Sassenhagen <jona.sassenhagen at gmail.com>
+# Phillip Alday <phillip.alday at unisa.edu.au>
#
# License: BSD (3-clause)
@@ -13,7 +15,7 @@ import time
import numpy as np
-from ...utils import verbose, logger, warn, deprecated
+from ...utils import verbose, logger, warn
from ..constants import FIFF
from ..meas_info import _empty_info
from ..base import _BaseRaw, _check_update_montage
@@ -38,14 +40,14 @@ class RawBrainVision(_BaseRaw):
Names of channels or list of indices that should be designated
EOG channels. Values should correspond to the vhdr file.
Default is ``('HEOGL', 'HEOGR', 'VEOGb')``.
- misc : list or tuple
+ misc : list or tuple of str | 'auto'
Names of channels or list of indices that should be designated
MISC channels. Values should correspond to the electrodes
- in the vhdr file. Default is ``()``.
+ in the vhdr file. If 'auto', units in the vhdr file are used to
+ infer misc channels. Default is ``'auto'``.
scale : float
- The scaling factor for EEG data. Units are in volts. Default scale
- factor is 1. For microvolts, the scale factor would be 1e-6. This is
- used when the header file does not specify the scale factor.
+ The scaling factor for EEG data. Unless specified otherwise in the
+ header file, units are assumed to be microvolts. Default scale factor is 1.
preload : bool
If True, all data are loaded at initialization.
If False, data are not read until save.
@@ -56,7 +58,7 @@ class RawBrainVision(_BaseRaw):
typically another value or None will be necessary.
event_id : dict | None
The id of special events to consider in addition to those that
- follow the normal Brainvision trigger format ('SXXX').
+ follow the normal Brainvision trigger format ('S###').
If dict, the keys will be mapped to trigger values on the stimulus
channel. Example: {'SyncStatus': 1, 'Pulse Artifact': 3}. If None
or an empty dict (default), only stimulus events are added to the
@@ -70,7 +72,7 @@ class RawBrainVision(_BaseRaw):
"""
@verbose
def __init__(self, vhdr_fname, montage=None,
- eog=('HEOGL', 'HEOGR', 'VEOGb'), misc=(),
+ eog=('HEOGL', 'HEOGR', 'VEOGb'), misc='auto',
scale=1., preload=False, response_trig_shift=0,
event_id=None, verbose=None):
# Channel info and events
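A usage sketch for the updated reader defaults (the ``.vhdr`` file name and the extra marker name are hypothetical):

    from mne.io import read_raw_brainvision

    raw = read_raw_brainvision('recording.vhdr', misc='auto',  # infer misc from units
                               event_id={'Sync On': 5})  # map a non-'S###' marker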
@@ -100,20 +102,6 @@ class RawBrainVision(_BaseRaw):
dtype=dtype, n_channels=n_data_ch,
trigger_ch=self._event_ch)
- @deprecated('get_brainvision_events is deprecated and will be removed '
- 'in 0.13, use mne.find_events(raw, "STI014") to get properly '
- 'formatted events instead')
- def get_brainvision_events(self):
- """Retrieve the events associated with the Brain Vision Raw object
-
- Returns
- -------
- events : array, shape (n_events, 3)
- Events, each row consisting of an (onset, duration, trigger)
- sequence.
- """
- return self._get_brainvision_events()
-
def _get_brainvision_events(self):
"""Retrieve the events associated with the Brain Vision Raw object
@@ -125,19 +113,6 @@ class RawBrainVision(_BaseRaw):
"""
return self._events.copy()
- @deprecated('set_brainvision_events is deprecated and will be removed '
- 'in 0.13')
- def set_brainvision_events(self, events):
- """Set the events and update the synthesized stim channel
-
- Parameters
- ----------
- events : array, shape (n_events, 3)
- Events, each row consisting of an (onset, duration, trigger)
- sequence.
- """
- return self._set_brainvision_events(events)
-
def _set_brainvision_events(self, events):
"""Set the events and update the synthesized stim channel
@@ -172,7 +147,7 @@ def _read_vmrk_events(fname, event_id=None, response_trig_shift=0):
vmrk file to be read.
event_id : dict | None
The id of special events to consider in addition to those that
- follow the normal Brainvision trigger format ('SXXX').
+ follow the normal Brainvision trigger format ('S###').
If dict, the keys will be mapped to trigger values on the stimulus
channel. Example: {'SyncStatus': 1, 'Pulse Artifact': 3}. If None
or an empty dict (default), only stimulus events are added to the
@@ -190,14 +165,39 @@ def _read_vmrk_events(fname, event_id=None, response_trig_shift=0):
event_id = dict()
# read vmrk file
with open(fname, 'rb') as fid:
- txt = fid.read().decode('utf-8')
+ txt = fid.read()
- header = txt.split('\n')[0].strip()
+ # we don't actually need to know the encoding of the header line:
+ # the characters in it are all ASCII and thus identical in
+ # Latin-1 and UTF-8
+ header = txt.decode('ascii', 'ignore').split('\n')[0].strip()
_check_mrk_version(header)
if (response_trig_shift is not None and
not isinstance(response_trig_shift, int)):
raise TypeError("response_trig_shift must be an integer or None")
+ # although the markers themselves are guaranteed to be ASCII (they
+ # consist of numbers and a few reserved words), we should still
+ # decode the file properly here because other (currently unused)
+ # blocks, such as the one specifying the filename, are not
+ # guaranteed to be ASCII.
+
+ codepage = 'utf-8'
+ try:
+ # if there is an explicit codepage set, use it
+ # we pretend like it's ascii when searching for the codepage
+ cp_setting = re.search('Codepage=(.+)',
+ txt.decode('ascii', 'ignore'),
+ re.IGNORECASE | re.MULTILINE)
+ if cp_setting:
+ codepage = cp_setting.group(1).strip()
+ txt = txt.decode(codepage)
+ except UnicodeDecodeError:
+ # if UTF-8 (new standard) or explicit codepage setting fails,
+ # fallback to Latin-1, which is Windows default and implicit
+ # standard in older recordings
+ txt = txt.decode('latin-1')
+
# extract Marker Infos block
m = re.search("\[Marker Infos\]", txt)
if not m:
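The decoding strategy above (explicit ``Codepage=`` entry if present, else UTF-8, else a Latin-1 fallback) can be summarized in a standalone sketch; the helper name is hypothetical:

    import re

    def _decode_brainvision_bytes(raw_bytes):
        """Sketch of the decoding fallback used by the reader above."""
        codepage = 'utf-8'
        # look for an explicit codepage, pretending the bytes are ASCII
        cp_setting = re.search('Codepage=(.+)', raw_bytes.decode('ascii', 'ignore'),
                               re.IGNORECASE | re.MULTILINE)
        if cp_setting:
            codepage = cp_setting.group(1).strip()
        try:
            return raw_bytes.decode(codepage)
        except UnicodeDecodeError:
            return raw_bytes.decode('latin-1')  # older recordings' implicit default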
@@ -239,7 +239,7 @@ def _read_vmrk_events(fname, event_id=None, response_trig_shift=0):
examples += ", ..."
warn("Currently, {0} trigger(s) will be dropped, such as [{1}]. "
"Consider using ``event_id`` to parse triggers that "
- "do not follow the 'SXXX' pattern.".format(
+ "do not follow the 'S###' pattern.".format(
len(dropped), examples))
events = np.array(events).reshape(-1, 3)
@@ -268,7 +268,15 @@ _orientation_dict = dict(MULTIPLEXED='F', VECTORIZED='C')
_fmt_dict = dict(INT_16='short', INT_32='int', IEEE_FLOAT_32='single')
_fmt_byte_dict = dict(short=2, int=4, single=4)
_fmt_dtype_dict = dict(short='<i2', int='<i4', single='<f4')
-_unit_dict = {'V': 1., u'µV': 1e-6, 'uV': 1e-6}
+_unit_dict = {'V': 1., # V stands for Volt
+ u'µV': 1e-6,
+ 'uV': 1e-6,
+ 'C': 1, # C stands for Celsius
+ u'µS': 1e-6, # S stands for Siemens
+ u'uS': 1e-6,
+ u'ARU': 1, # ARU is the unit used for the breathing data
+ 'S': 1,
+ 'N': 1} # N stands for Newton
def _get_vhdr_info(vhdr_fname, eog, misc, scale, montage):
@@ -281,13 +289,14 @@ def _get_vhdr_info(vhdr_fname, eog, misc, scale, montage):
eog : list of str
Names of channels that should be designated EOG channels. Names should
correspond to the vhdr file.
- misc : list of str
- Names of channels that should be designated MISC channels. Names
- should correspond to the electrodes in the vhdr file.
+ misc : list or tuple of str | 'auto'
+ Names of channels or list of indices that should be designated
+ MISC channels. Values should correspond to the electrodes
+ in the vhdr file. If 'auto', units in the vhdr file are used to
+ infer misc channels. Default is ``'auto'``.
scale : float
- The scaling factor for EEG data. Units are in volts. Default scale
- factor is 1.. For microvolts, the scale factor would be 1e-6. This is
- used when the header file does not specify the scale factor.
+ The scaling factor for EEG data. Unless specified otherwise by the
+ header file, units are in microvolts. Default scale factor is 1.
montage : str | True | None | instance of Montage
Path or instance of montage containing electrode positions.
If None, sensor locations are (0,0,0). See the documentation of
@@ -312,9 +321,29 @@ def _get_vhdr_info(vhdr_fname, eog, misc, scale, montage):
"not the '%s' file." % ext)
with open(vhdr_fname, 'rb') as f:
# extract the first section to resemble a cfg
- header = f.readline().decode('utf-8').strip()
+ header = f.readline()
+ codepage = 'utf-8'
+ # we don't actually need to know the encoding of the header line:
+ # the characters in it all belong to ASCII and are thus the
+ # same in Latin-1 and UTF-8
+ header = header.decode('ascii', 'ignore').strip()
_check_hdr_version(header)
- settings = f.read().decode('utf-8')
+
+ settings = f.read()
+ try:
+ # if there is an explicit codepage set, use it
+ # we pretend it's ASCII while searching for the codepage entry
+ cp_setting = re.search('Codepage=(.+)',
+ settings.decode('ascii', 'ignore'),
+ re.IGNORECASE | re.MULTILINE)
+ if cp_setting:
+ codepage = cp_setting.group(1).strip()
+ settings = settings.decode(codepage)
+ except UnicodeDecodeError:
+ # if UTF-8 (new standard) or explicit codepage setting fails,
+ # fallback to Latin-1, which is Windows default and implicit
+ # standard in older recordings
+ settings = settings.decode('latin-1')
if settings.find('[Comment]') != -1:
params, settings = settings.split('[Comment]')
@@ -351,11 +380,15 @@ def _get_vhdr_info(vhdr_fname, eog, misc, scale, montage):
ranges = np.empty(nchan)
cals.fill(np.nan)
ch_dict = dict()
+ misc_chs = dict()
for chan, props in cfg.items('Channel Infos'):
n = int(re.findall(r'ch(\d+)', chan)[0]) - 1
props = props.split(',')
+ # default to microvolts because that's what the older BrainVision
+ # standard explicitly assumed; the unit is only allowed to be
+ # something else if explicitly stated (cf. EEGLAB export below)
if len(props) < 4:
- props += ('V',)
+ props += (u'µV',)
name, _, resolution, unit = props[:4]
ch_dict[chan] = name
ch_names[n] = name
@@ -366,7 +399,11 @@ def _get_vhdr_info(vhdr_fname, eog, misc, scale, montage):
resolution = 1. # for files with units specified, but not res
unit = unit.replace(u'\xc2', u'') # Remove unwanted control characters
cals[n] = float(resolution)
- ranges[n] = _unit_dict.get(unit, unit) * scale
+ ranges[n] = _unit_dict.get(unit, 1) * scale
+ if unit not in ('V', u'µV', 'uV'):
+ misc_chs[name] = (FIFF.FIFF_UNIT_CEL if unit == 'C'
+ else FIFF.FIFF_UNIT_NONE)
+ misc = list(misc_chs.keys()) if misc == 'auto' else misc
# create montage
if montage is True:
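The unit-based misc inference above reduces to two rules: scale each channel by a known unit factor (defaulting to 1 for unknown units), and flag every channel whose unit is not a Volt variant as MISC. A hedged sketch of that logic in isolation (the dict subset and the channel list are illustrative):

    _unit_dict = {'V': 1., u'µV': 1e-6, 'uV': 1e-6, 'C': 1, u'µS': 1e-6}

    def infer_misc(channels, scale=1.):
        """channels: list of (name, unit) pairs from the vhdr file."""
        ranges, misc = {}, []
        for name, unit in channels:
            ranges[name] = _unit_dict.get(unit, 1) * scale
            if unit not in ('V', u'µV', 'uV'):  # non-voltage -> MISC
                misc.append(name)
        return ranges, misc

    # infer_misc([('Cz', u'µV'), ('Temp', 'C')])
    # -> ({'Cz': 1e-06, 'Temp': 1}, ['Temp'])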
@@ -394,55 +431,161 @@ def _get_vhdr_info(vhdr_fname, eog, misc, scale, montage):
# set to zero.
settings = settings.splitlines()
idx = None
+
if 'Channels' in settings:
idx = settings.index('Channels')
settings = settings[idx + 1:]
+ hp_col, lp_col = 4, 5
for idx, setting in enumerate(settings):
if re.match('#\s+Name', setting):
break
else:
idx = None
+ # If software filters are active, they override the hardware setup.
+ # But we still want to be able to double-check the channel names, so
+ # for alignment purposes we keep track of the hardware setting idx
+ idx_amp = idx
+
+ if 'S o f t w a r e F i l t e r s' in settings:
+ idx = settings.index('S o f t w a r e F i l t e r s')
+ for idx, setting in enumerate(settings[idx + 1:], idx + 1):
+ if re.match('#\s+Low Cutoff', setting):
+ hp_col, lp_col = 1, 2
+ warn('Online software filter detected. Using software '
+ 'filter settings and ignoring hardware values')
+ break
+ else:
+ idx = idx_amp
+
if idx:
lowpass = []
highpass = []
+
+ # extract filter units and convert s to Hz if necessary;
+ # this cannot be done in post-processing because the inverse
+ # time-frequency relationship means that the min/max comparisons
+ # don't make sense unless we know the units
+ header = re.split('\s\s+', settings[idx])
+ hp_s = '[s]' in header[hp_col]
+ lp_s = '[s]' in header[lp_col]
+
for i, ch in enumerate(ch_names[:-1], 1):
- line = settings[idx + i].split()
- assert ch in line
- highpass.append(line[5])
- lowpass.append(line[6])
+ line = re.split('\s\s+', settings[idx + i])
+ # double-check alignment with the channel using the hw settings.
+ # The actual field divider is multiple spaces -- in newer BV
+ # files, the unit is specified for every channel separated
+ # by a single space, while in older files the unit is
+ # specified in the column headers
+ if idx == idx_amp:
+ line_amp = line
+ else:
+ line_amp = re.split('\s\s+', settings[idx_amp + i])
+ assert ch in line_amp
+ highpass.append(line[hp_col])
+ lowpass.append(line[lp_col])
if len(highpass) == 0:
pass
- elif all(highpass):
- if highpass[0] == 'NaN':
+ elif len(set(highpass)) == 1:
+ if highpass[0] in ('NaN', 'Off'):
pass # Placeholder for future use. Highpass set in _empty_info
elif highpass[0] == 'DC':
info['highpass'] = 0.
else:
info['highpass'] = float(highpass[0])
+ if hp_s:
+ info['highpass'] = 1. / info['highpass']
else:
- info['highpass'] = np.min(np.array(highpass, dtype=np.float))
- warn('Channels contain different highpass filters. Highest filter '
- 'setting will be stored.')
+ heterogeneous_hp_filter = True
+ if hp_s:
+ # We convert channels with disabled filters to having
+ # highpass relaxed / no filters
+ highpass = [float(filt) if filt not in ('NaN', 'Off', 'DC')
+ else np.Inf for filt in highpass]
+ info['highpass'] = np.max(np.array(highpass, dtype=np.float))
+ # Conveniently enough, 1 / np.Inf = 0.0, so this works for
+ # DC / no highpass filter
+ info['highpass'] = 1. / info['highpass']
+
+ # not exactly the cleanest use of floating point, but this
+ # makes us more conservative in *not* warning.
+ if info['highpass'] == 0.0 and len(set(highpass)) == 1:
+ # not actually heterogeneous in effect
+ # ... just heterogeneously disabled
+ heterogeneous_hp_filter = False
+ else:
+ highpass = [float(filt) if filt not in ('NaN', 'Off', 'DC')
+ else 0.0 for filt in highpass]
+ info['highpass'] = np.min(np.array(highpass, dtype=np.float))
+ if info['highpass'] == 0.0 and len(set(highpass)) == 1:
+ # not actually heterogeneous in effect
+ # ... just heterogeneously disabled
+ heterogeneous_hp_filter = False
+
+ if heterogeneous_hp_filter:
+ warn('Channels contain different highpass filters. '
+ 'Lowest (weakest) filter setting (%0.2f Hz) '
+ 'will be stored.' % info['highpass'])
+
if len(lowpass) == 0:
pass
- elif all(lowpass):
- if lowpass[0] == 'NaN':
+ elif len(set(lowpass)) == 1:
+ if lowpass[0] in ('NaN', 'Off'):
pass # Placeholder for future use. Lowpass set in _empty_info
else:
info['lowpass'] = float(lowpass[0])
+ if lp_s:
+ info['lowpass'] = 1. / info['lowpass']
else:
- info['lowpass'] = np.min(np.array(lowpass, dtype=np.float))
- warn('Channels contain different lowpass filters. Lowest filter '
- 'setting will be stored.')
-
- # Post process highpass and lowpass to take into account units
- header = settings[idx].split(' ')
- header = [h for h in header if len(h)]
- if '[s]' in header[4] and (info['highpass'] > 0):
- info['highpass'] = 1. / info['highpass']
- if '[s]' in header[5]:
- info['lowpass'] = 1. / info['lowpass']
+ heterogeneous_lp_filter = True
+ if lp_s:
+ # We convert channels with disabled filters to having
+ # infinitely relaxed / no filters
+ lowpass = [float(filt) if filt not in ('NaN', 'Off')
+ else 0.0 for filt in lowpass]
+ info['lowpass'] = np.min(np.array(lowpass, dtype=np.float))
+ try:
+ info['lowpass'] = 1. / info['lowpass']
+ except ZeroDivisionError:
+ if len(set(lowpass)) == 1:
+ # No lowpass actually set for the weakest setting
+ # so we set lowpass to the Nyquist frequency
+ info['lowpass'] = info['sfreq'] / 2.
+ # not actually heterogeneous in effect
+ # ... just heterogeneously disabled
+ heterogeneous_lp_filter = False
+ else:
+ # no lowpass filter is the weakest filter,
+ # but it wasn't the only filter
+ pass
+ else:
+ # We convert channels with disabled filters to having
+ # infinitely relaxed / no filters
+ lowpass = [float(filt) if filt not in ('NaN', 'Off')
+ else np.Inf for filt in lowpass]
+ info['lowpass'] = np.max(np.array(lowpass, dtype=np.float))
+
+ if np.isinf(info['lowpass']):
+ # No lowpass actually set for the weakest setting
+ # so we set lowpass to the Nyquist frequency
+ info['lowpass'] = info['sfreq'] / 2.
+ if len(set(lowpass)) == 1:
+ # not actually heterogeneous in effect
+ # ... just heterogeneously disabled
+ heterogeneous_lp_filter = False
+
+ if heterogeneous_lp_filter:
+ # this isn't a clean floating-point comparison, but then again,
+ # we only want to provide the Nyquist hint when the lowpass
+ # filter was actually calculated by dividing the sampling
+ # frequency by 2, so the exact/direct comparison (instead of
+ # a tolerance) makes sense
+ if info['lowpass'] == info['sfreq'] / 2.0:
+ nyquist = ', Nyquist limit'
+ else:
+ nyquist = ""
+ warn('Channels contain different lowpass filters. '
+ 'Highest (weakest) filter setting (%0.2f Hz%s) '
+ 'will be stored.' % (info['lowpass'], nyquist))
# locate EEG and marker files
path = os.path.dirname(vhdr_fname)
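The filter logic above normalizes per-channel cutoffs before comparing them: cutoffs listed in seconds are time constants, so the stored frequency is 1/t; disabled entries ('NaN'/'Off'/'DC') are mapped to whatever value makes them the weakest setting under the min/max; and when software filters are active, columns 1/2 are read instead of the hardware columns 4/5. A simplified sketch of the highpass branch (token conventions as in the vhdr tables):

    import numpy as np

    def weakest_highpass(values, in_seconds):
        """Weakest (lowest Hz) highpass from per-channel cutoff strings."""
        disabled = ('NaN', 'Off', 'DC')
        if in_seconds:
            # disabled -> infinite time constant; 1 / inf == 0.0 Hz (DC)
            vals = [float(v) if v not in disabled else np.inf
                    for v in values]
            return 1. / max(vals)
        # disabled -> 0.0 Hz directly
        vals = [float(v) if v not in disabled else 0.0 for v in values]
        return min(vals)

    # weakest_highpass(['10', 'DC'], in_seconds=True)  -> 0.0 (DC wins)
    # weakest_highpass(['10', '5'], in_seconds=False)  -> 5.0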
@@ -461,7 +604,10 @@ def _get_vhdr_info(vhdr_fname, eog, misc, scale, montage):
elif ch_name in misc or idx in misc or idx - nchan in misc:
kind = FIFF.FIFFV_MISC_CH
coil_type = FIFF.FIFFV_COIL_NONE
- unit = FIFF.FIFF_UNIT_V
+ if ch_name in misc_chs:
+ unit = misc_chs[ch_name]
+ else:
+ unit = FIFF.FIFF_UNIT_NONE
elif ch_name == 'STI 014':
kind = FIFF.FIFFV_STIM_CH
coil_type = FIFF.FIFFV_COIL_NONE
@@ -484,7 +630,7 @@ def _get_vhdr_info(vhdr_fname, eog, misc, scale, montage):
def read_raw_brainvision(vhdr_fname, montage=None,
- eog=('HEOGL', 'HEOGR', 'VEOGb'), misc=(),
+ eog=('HEOGL', 'HEOGR', 'VEOGb'), misc='auto',
scale=1., preload=False, response_trig_shift=0,
event_id=None, verbose=None):
"""Reader for Brain Vision EEG file
@@ -501,14 +647,14 @@ def read_raw_brainvision(vhdr_fname, montage=None,
Names of channels or list of indices that should be designated
EOG channels. Values should correspond to the vhdr file
Default is ``('HEOGL', 'HEOGR', 'VEOGb')``.
- misc : list or tuple of str
+ misc : list or tuple of str | 'auto'
Names of channels or list of indices that should be designated
MISC channels. Values should correspond to the electrodes
- in the vhdr file. Default is ``()``.
+ in the vhdr file. If 'auto', units in the vhdr file are used to
+ infer misc channels. Default is ``'auto'``.
scale : float
- The scaling factor for EEG data. Units are in volts. Default scale
- factor is 1. For microvolts, the scale factor would be 1e-6. This is
- used when the header file does not specify the scale factor.
+ The scaling factor for EEG data. Unless specified otherwise by the
+ header file, units are in microvolts. Default scale factor is 1.
preload : bool
If True, all data are loaded at initialization.
If False, data are not read until save.
@@ -519,7 +665,7 @@ def read_raw_brainvision(vhdr_fname, montage=None,
typically another value or None will be necessary.
event_id : dict | None
The id of special events to consider in addition to those that
- follow the normal Brainvision trigger format ('SXXX').
+ follow the normal Brainvision trigger format ('S###').
If dict, the keys will be mapped to trigger values on the stimulus
channel. Example: {'SyncStatus': 1, 'Pulse Artifact': 3}. If None
or an empty dict (default), only stimulus events are added to the
@@ -536,8 +682,7 @@ def read_raw_brainvision(vhdr_fname, montage=None,
--------
mne.io.Raw : Documentation of attribute and methods.
"""
- raw = RawBrainVision(vhdr_fname=vhdr_fname, montage=montage, eog=eog,
- misc=misc, scale=scale,
- preload=preload, verbose=verbose, event_id=event_id,
- response_trig_shift=response_trig_shift)
- return raw
+ return RawBrainVision(vhdr_fname=vhdr_fname, montage=montage, eog=eog,
+ misc=misc, scale=scale, preload=preload,
+ response_trig_shift=response_trig_shift,
+ event_id=event_id, verbose=verbose)
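With the new defaults, a typical call (file name and marker description hypothetical) picks up misc channels from the header units automatically, and markers that do not follow the 'S###' pattern can still be mapped via event_id:

    import mne

    raw = mne.io.read_raw_brainvision(
        'sample.vhdr',             # hypothetical recording
        misc='auto',               # infer MISC channels from vhdr units
        event_id={'Sync On': 98},  # map a non-'S###' marker to trigger 98
        preload=True)
    events = mne.find_events(raw, stim_channel='STI 014')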
diff --git a/mne/io/brainvision/tests/data/test.vhdr b/mne/io/brainvision/tests/data/test.vhdr
index 697ce92..2a52d72 100755
--- a/mne/io/brainvision/tests/data/test.vhdr
+++ b/mne/io/brainvision/tests/data/test.vhdr
@@ -46,12 +46,12 @@ Ch23=CP1,,0.5,µV
Ch24=CP2,,0.5,µV
Ch25=FC5,,0.5,µV
Ch26=FC6,,0.5,µV
-Ch27=CP5,,0.5,µV
-Ch28=CP6,,0.5,µV
-Ch29=HL,,0.5,µV
-Ch30=HR,,0.5,µV
-Ch31=Vb,,0.5,µV
-Ch32=ReRef,,0.5,µV
+Ch27=CP5,,0.5,BS
+Ch28=CP6,,0.5,µS
+Ch29=HL,,0.5,ARU
+Ch30=HR,,0.5,uS
+Ch31=Vb,,0.5,S
+Ch32=ReRef,,0.5,C
[Comment]
@@ -90,12 +90,12 @@ Channels
24 CP2 24 0.5 µV DC 250 Off 0
25 FC5 25 0.5 µV DC 250 Off 0
26 FC6 26 0.5 µV DC 250 Off 0
-27 CP5 27 0.5 µV DC 250 Off 0
-28 CP6 28 0.5 µV DC 250 Off 0
-29 HL 29 0.5 µV DC 250 Off 0
-30 HR 30 0.5 µV DC 250 Off 0
-31 Vb 31 0.5 µV DC 250 Off 0
-32 ReRef 32 0.5 µV DC 250 Off 0
+27 CP5 27 0.5 BS DC 250 Off 0
+28 CP6 28 0.5 µS DC 250 Off 0
+29 HL 29 0.5 ARU DC 250 Off 0
+30 HR 30 0.5 uS DC 250 Off 0
+31 Vb 31 0.5 S DC 250 Off 0
+32 ReRef 32 0.5 C DC 250 Off 0
S o f t w a r e F i l t e r s
==============================
diff --git a/mne/io/brainvision/tests/data/test_highpass_hz.vhdr b/mne/io/brainvision/tests/data/test_highpass_hz.vhdr
new file mode 100644
index 0000000..6315b44
--- /dev/null
+++ b/mne/io/brainvision/tests/data/test_highpass_hz.vhdr
@@ -0,0 +1,103 @@
+Brain Vision Data Exchange Header File Version 1.0
+; Data created by the Vision Recorder
+
+[Common Infos]
+Codepage=UTF-8
+DataFile=test.eeg
+MarkerFile=test.vmrk
+DataFormat=BINARY
+; Data orientation: MULTIPLEXED=ch1,pt1, ch2,pt1 ...
+DataOrientation=MULTIPLEXED
+NumberOfChannels=32
+; Sampling interval in microseconds
+SamplingInterval=1000
+
+[Binary Infos]
+BinaryFormat=INT_16
+
+[Channel Infos]
+; Each entry: Ch<Channel number>=<Name>,<Reference channel name>,
+; <Resolution in "Unit">,<Unit>, Future extensions..
+; Fields are delimited by commas, some fields might be omitted (empty).
+; Commas in channel names are coded as "\1".
+Ch1=FP1,,0.5,µV
+Ch2=FP2,,0.5,µV
+Ch3=F3,,0.5,µV
+Ch4=F4,,0.5,µV
+Ch5=C3,,0.5,µV
+Ch6=C4,,0.5,µV
+Ch7=P3,,0.5,µV
+Ch8=P4,,0.5,µV
+Ch9=O1,,0.5,µV
+Ch10=O2,,0.5,µV
+Ch11=F7,,0.5,µV
+Ch12=F8,,0.5,µV
+Ch13=P7,,0.5,µV
+Ch14=P8,,0.5,µV
+Ch15=Fz,,0.5,µV
+Ch16=FCz,,0.5,µV
+Ch17=Cz,,0.5,µV
+Ch18=CPz,,0.5,µV
+Ch19=Pz,,0.5,µV
+Ch20=POz,,0.5,µV
+Ch21=FC1,,0.5,µV
+Ch22=FC2,,0.5,µV
+Ch23=CP1,,0.5,µV
+Ch24=CP2,,0.5,µV
+Ch25=FC5,,0.5,µV
+Ch26=FC6,,0.5,µV
+Ch27=CP5,,0.5,µV
+Ch28=CP6,,0.5,µV
+Ch29=HL,,0.5,µV
+Ch30=HR,,0.5,µV
+Ch31=Vb,,0.5,µV
+Ch32=ReRef,,0.5,µV
+
+[Comment]
+
+A m p l i f i e r S e t u p
+============================
+Number of channels: 32
+Sampling Rate [Hz]: 1000
+Sampling Interval [µS]: 1000
+
+Channels
+--------
+# Name Phys. Chn. Resolution / Unit Low Cutoff [Hz] High Cutoff [Hz] Notch [Hz] Series Res. [kOhm] Gradient Offset
+1 FP1 1 0.5 µV 10 250 Off 0
+2 FP2 2 0.5 µV 10 250 Off 0
+3 F3 3 0.5 µV 10 250 Off 0
+4 F4 4 0.5 µV 10 250 Off 0
+5 C3 5 0.5 µV 10 250 Off 0
+6 C4 6 0.5 µV 10 250 Off 0
+7 P3 7 0.5 µV 10 250 Off 0
+8 P4 8 0.5 µV 10 250 Off 0
+9 O1 9 0.5 µV 10 250 Off 0
+10 O2 10 0.5 µV 10 250 Off 0
+11 F7 11 0.5 µV 10 250 Off 0
+12 F8 12 0.5 µV 10 250 Off 0
+13 P7 13 0.5 µV 10 250 Off 0
+14 P8 14 0.5 µV 10 250 Off 0
+15 Fz 15 0.5 µV 10 250 Off 0
+16 FCz 16 0.5 µV 10 250 Off 0
+17 Cz 17 0.5 µV 10 250 Off 0
+18 CPz 18 0.5 µV 10 250 Off 0
+19 Pz 19 0.5 µV 10 250 Off 0
+20 POz 20 0.5 µV 10 250 Off 0
+21 FC1 21 0.5 µV 10 250 Off 0
+22 FC2 22 0.5 µV 10 250 Off 0
+23 CP1 23 0.5 µV 10 250 Off 0
+24 CP2 24 0.5 µV 10 250 Off 0
+25 FC5 25 0.5 µV 10 250 Off 0
+26 FC6 26 0.5 µV 10 250 Off 0
+27 CP5 27 0.5 µV 10 250 Off 0
+28 CP6 28 0.5 µV 10 250 Off 0
+29 HL 29 0.5 µV 10 250 Off 0
+30 HR 30 0.5 µV 10 250 Off 0
+31 Vb 31 0.5 µV 10 250 Off 0
+32 ReRef 32 0.5 µV 10 250 Off 0
+
+S o f t w a r e F i l t e r s
+==============================
+Disabled
+
diff --git a/mne/io/brainvision/tests/data/test_lowpass_s.vhdr b/mne/io/brainvision/tests/data/test_lowpass_s.vhdr
new file mode 100755
index 0000000..483e290
--- /dev/null
+++ b/mne/io/brainvision/tests/data/test_lowpass_s.vhdr
@@ -0,0 +1,103 @@
+Brain Vision Data Exchange Header File Version 1.0
+; Data created by the Vision Recorder
+
+[Common Infos]
+Codepage=UTF-8
+DataFile=test.eeg
+MarkerFile=test.vmrk
+DataFormat=BINARY
+; Data orientation: MULTIPLEXED=ch1,pt1, ch2,pt1 ...
+DataOrientation=MULTIPLEXED
+NumberOfChannels=32
+; Sampling interval in microseconds
+SamplingInterval=1000
+
+[Binary Infos]
+BinaryFormat=INT_16
+
+[Channel Infos]
+; Each entry: Ch<Channel number>=<Name>,<Reference channel name>,
+; <Resolution in "Unit">,<Unit>, Future extensions..
+; Fields are delimited by commas, some fields might be omitted (empty).
+; Commas in channel names are coded as "\1".
+Ch1=FP1,,0.5,µV
+Ch2=FP2,,0.5,µV
+Ch3=F3,,0.5,µV
+Ch4=F4,,0.5,µV
+Ch5=C3,,0.5,µV
+Ch6=C4,,0.5,µV
+Ch7=P3,,0.5,µV
+Ch8=P4,,0.5,µV
+Ch9=O1,,0.5,µV
+Ch10=O2,,0.5,µV
+Ch11=F7,,0.5,µV
+Ch12=F8,,0.5,µV
+Ch13=P7,,0.5,µV
+Ch14=P8,,0.5,µV
+Ch15=Fz,,0.5,µV
+Ch16=FCz,,0.5,µV
+Ch17=Cz,,0.5,µV
+Ch18=CPz,,0.5,µV
+Ch19=Pz,,0.5,µV
+Ch20=POz,,0.5,µV
+Ch21=FC1,,0.5,µV
+Ch22=FC2,,0.5,µV
+Ch23=CP1,,0.5,µV
+Ch24=CP2,,0.5,µV
+Ch25=FC5,,0.5,µV
+Ch26=FC6,,0.5,µV
+Ch27=CP5,,0.5,µV
+Ch28=CP6,,0.5,µV
+Ch29=HL,,0.5,µV
+Ch30=HR,,0.5,µV
+Ch31=Vb,,0.5,µV
+Ch32=ReRef,,0.5,µV
+
+[Comment]
+
+A m p l i f i e r S e t u p
+============================
+Number of channels: 32
+Sampling Rate [Hz]: 1000
+Sampling Interval [µS]: 1000
+
+Channels
+--------
+# Name Phys. Chn. Resolution / Unit Low Cutoff [s] High Cutoff [s] Notch [Hz] Series Res. [kOhm] Gradient Offset
+1 FP1 1 0.5 µV 10 0.004 Off 0
+2 FP2 2 0.5 µV 10 0.004 Off 0
+3 F3 3 0.5 µV 10 0.004 Off 0
+4 F4 4 0.5 µV 10 0.004 Off 0
+5 C3 5 0.5 µV 10 0.004 Off 0
+6 C4 6 0.5 µV 10 0.004 Off 0
+7 P3 7 0.5 µV 10 0.004 Off 0
+8 P4 8 0.5 µV 10 0.004 Off 0
+9 O1 9 0.5 µV 10 0.004 Off 0
+10 O2 10 0.5 µV 10 0.004 Off 0
+11 F7 11 0.5 µV 10 0.004 Off 0
+12 F8 12 0.5 µV 10 0.004 Off 0
+13 P7 13 0.5 µV 10 0.004 Off 0
+14 P8 14 0.5 µV 10 0.004 Off 0
+15 Fz 15 0.5 µV 10 0.004 Off 0
+16 FCz 16 0.5 µV 10 0.004 Off 0
+17 Cz 17 0.5 µV 10 0.004 Off 0
+18 CPz 18 0.5 µV 10 0.004 Off 0
+19 Pz 19 0.5 µV 10 0.004 Off 0
+20 POz 20 0.5 µV 10 0.004 Off 0
+21 FC1 21 0.5 µV 10 0.004 Off 0
+22 FC2 22 0.5 µV 10 0.004 Off 0
+23 CP1 23 0.5 µV 10 0.004 Off 0
+24 CP2 24 0.5 µV 10 0.004 Off 0
+25 FC5 25 0.5 µV 10 0.004 Off 0
+26 FC6 26 0.5 µV 10 0.004 Off 0
+27 CP5 27 0.5 µV 10 0.004 Off 0
+28 CP6 28 0.5 µV 10 0.004 Off 0
+29 HL 29 0.5 µV 10 0.004 Off 0
+30 HR 30 0.5 µV 10 0.004 Off 0
+31 Vb 31 0.5 µV 10 0.004 Off 0
+32 ReRef 32 0.5 µV 10 0.004 Off 0
+
+S o f t w a r e F i l t e r s
+==============================
+Disabled
+
diff --git a/mne/io/brainvision/tests/data/test.vhdr b/mne/io/brainvision/tests/data/test_mixed_highpass.vhdr
similarity index 51%
copy from mne/io/brainvision/tests/data/test.vhdr
copy to mne/io/brainvision/tests/data/test_mixed_highpass.vhdr
index 697ce92..b9c175c 100755
--- a/mne/io/brainvision/tests/data/test.vhdr
+++ b/mne/io/brainvision/tests/data/test_mixed_highpass.vhdr
@@ -64,79 +64,40 @@ Sampling Interval [µS]: 1000
Channels
--------
# Name Phys. Chn. Resolution / Unit Low Cutoff [s] High Cutoff [Hz] Notch [Hz] Series Res. [kOhm] Gradient Offset
-1 FP1 1 0.5 µV DC 250 Off 0
-2 FP2 2 0.5 µV DC 250 Off 0
-3 F3 3 0.5 µV DC 250 Off 0
-4 F4 4 0.5 µV DC 250 Off 0
-5 C3 5 0.5 µV DC 250 Off 0
-6 C4 6 0.5 µV DC 250 Off 0
-7 P3 7 0.5 µV DC 250 Off 0
-8 P4 8 0.5 µV DC 250 Off 0
-9 O1 9 0.5 µV DC 250 Off 0
-10 O2 10 0.5 µV DC 250 Off 0
-11 F7 11 0.5 µV DC 250 Off 0
-12 F8 12 0.5 µV DC 250 Off 0
-13 P7 13 0.5 µV DC 250 Off 0
-14 P8 14 0.5 µV DC 250 Off 0
-15 Fz 15 0.5 µV DC 250 Off 0
-16 FCz 16 0.5 µV DC 250 Off 0
-17 Cz 17 0.5 µV DC 250 Off 0
-18 CPz 18 0.5 µV DC 250 Off 0
-19 Pz 19 0.5 µV DC 250 Off 0
-20 POz 20 0.5 µV DC 250 Off 0
-21 FC1 21 0.5 µV DC 250 Off 0
-22 FC2 22 0.5 µV DC 250 Off 0
-23 CP1 23 0.5 µV DC 250 Off 0
-24 CP2 24 0.5 µV DC 250 Off 0
-25 FC5 25 0.5 µV DC 250 Off 0
-26 FC6 26 0.5 µV DC 250 Off 0
-27 CP5 27 0.5 µV DC 250 Off 0
-28 CP6 28 0.5 µV DC 250 Off 0
-29 HL 29 0.5 µV DC 250 Off 0
-30 HR 30 0.5 µV DC 250 Off 0
-31 Vb 31 0.5 µV DC 250 Off 0
-32 ReRef 32 0.5 µV DC 250 Off 0
+1 FP1 1 0.5 µV 10 250 Off 0
+2 FP2 2 0.5 µV 10 250 Off 0
+3 F3 3 0.5 µV 10 250 Off 0
+4 F4 4 0.5 µV 10 250 Off 0
+5 C3 5 0.5 µV 10 250 Off 0
+6 C4 6 0.5 µV 10 250 Off 0
+7 P3 7 0.5 µV 10 250 Off 0
+8 P4 8 0.5 µV 10 250 Off 0
+9 O1 9 0.5 µV 10 250 Off 0
+10 O2 10 0.5 µV 10 250 Off 0
+11 F7 11 0.5 µV 10 250 Off 0
+12 F8 12 0.5 µV 10 250 Off 0
+13 P7 13 0.5 µV 10 250 Off 0
+14 P8 14 0.5 µV 10 250 Off 0
+15 Fz 15 0.5 µV 10 250 Off 0
+16 FCz 16 0.5 µV 10 250 Off 0
+17 Cz 17 0.5 µV 10 250 Off 0
+18 CPz 18 0.5 µV 10 250 Off 0
+19 Pz 19 0.5 µV 10 250 Off 0
+20 POz 20 0.5 µV 10 250 Off 0
+21 FC1 21 0.5 µV 10 250 Off 0
+22 FC2 22 0.5 µV 10 250 Off 0
+23 CP1 23 0.5 µV 10 250 Off 0
+24 CP2 24 0.5 µV 10 250 Off 0
+25 FC5 25 0.5 µV 10 250 Off 0
+26 FC6 26 0.5 µV 10 250 Off 0
+27 CP5 27 0.5 µV 10 250 Off 0
+28 CP6 28 0.5 µV 10 250 Off 0
+29 HL 29 0.5 µV 10 250 Off 0
+30 HR 30 0.5 µV 10 250 Off 0
+31 Vb 31 0.5 µV 10 250 Off 0
+32 ReRef 32 0.5 µV 5 250 Off 0
S o f t w a r e F i l t e r s
==============================
Disabled
-
-Data Electrodes Selected Impedance Measurement Range: 0 - 100 kOhm
-Ground Electrode Selected Impedance Measurement Range: 0 - 10 kOhm
-Reference Electrode Selected Impedance Measurement Range: 0 - 10 kOhm
-Impedance [kOhm] at 16:12:27 :
-FP1: ???
-FP2: ???
-F3: ???
-F4: ???
-C3: ???
-C4: ???
-P3: ???
-P4: ???
-O1: ???
-O2: ???
-F7: ???
-F8: ???
-P7: ???
-P8: ???
-Fz: ???
-FCz: ???
-Cz: ???
-CPz: ???
-Pz: ???
-POz: ???
-FC1: ???
-FC2: ???
-CP1: ???
-CP2: ???
-FC5: ???
-FC6: ???
-CP5: ???
-CP6: ???
-HL: ???
-HR: ???
-Vb: ???
-ReRef: ???
-Ref: 0
-Gnd: 4
diff --git a/mne/io/brainvision/tests/data/test_mixed_highpass_hz.vhdr b/mne/io/brainvision/tests/data/test_mixed_highpass_hz.vhdr
new file mode 100644
index 0000000..e178585
--- /dev/null
+++ b/mne/io/brainvision/tests/data/test_mixed_highpass_hz.vhdr
@@ -0,0 +1,103 @@
+Brain Vision Data Exchange Header File Version 1.0
+; Data created by the Vision Recorder
+
+[Common Infos]
+Codepage=UTF-8
+DataFile=test.eeg
+MarkerFile=test.vmrk
+DataFormat=BINARY
+; Data orientation: MULTIPLEXED=ch1,pt1, ch2,pt1 ...
+DataOrientation=MULTIPLEXED
+NumberOfChannels=32
+; Sampling interval in microseconds
+SamplingInterval=1000
+
+[Binary Infos]
+BinaryFormat=INT_16
+
+[Channel Infos]
+; Each entry: Ch<Channel number>=<Name>,<Reference channel name>,
+; <Resolution in "Unit">,<Unit>, Future extensions..
+; Fields are delimited by commas, some fields might be omitted (empty).
+; Commas in channel names are coded as "\1".
+Ch1=FP1,,0.5,µV
+Ch2=FP2,,0.5,µV
+Ch3=F3,,0.5,µV
+Ch4=F4,,0.5,µV
+Ch5=C3,,0.5,µV
+Ch6=C4,,0.5,µV
+Ch7=P3,,0.5,µV
+Ch8=P4,,0.5,µV
+Ch9=O1,,0.5,µV
+Ch10=O2,,0.5,µV
+Ch11=F7,,0.5,µV
+Ch12=F8,,0.5,µV
+Ch13=P7,,0.5,µV
+Ch14=P8,,0.5,µV
+Ch15=Fz,,0.5,µV
+Ch16=FCz,,0.5,µV
+Ch17=Cz,,0.5,µV
+Ch18=CPz,,0.5,µV
+Ch19=Pz,,0.5,µV
+Ch20=POz,,0.5,µV
+Ch21=FC1,,0.5,µV
+Ch22=FC2,,0.5,µV
+Ch23=CP1,,0.5,µV
+Ch24=CP2,,0.5,µV
+Ch25=FC5,,0.5,µV
+Ch26=FC6,,0.5,µV
+Ch27=CP5,,0.5,µV
+Ch28=CP6,,0.5,µV
+Ch29=HL,,0.5,µV
+Ch30=HR,,0.5,µV
+Ch31=Vb,,0.5,µV
+Ch32=ReRef,,0.5,µV
+
+[Comment]
+
+A m p l i f i e r S e t u p
+============================
+Number of channels: 32
+Sampling Rate [Hz]: 1000
+Sampling Interval [µS]: 1000
+
+Channels
+--------
+# Name Phys. Chn. Resolution / Unit Low Cutoff [Hz] High Cutoff [Hz] Notch [Hz] Series Res. [kOhm] Gradient Offset
+1 FP1 1 0.5 µV 10 250 Off 0
+2 FP2 2 0.5 µV 10 250 Off 0
+3 F3 3 0.5 µV 10 250 Off 0
+4 F4 4 0.5 µV 10 250 Off 0
+5 C3 5 0.5 µV 10 250 Off 0
+6 C4 6 0.5 µV 10 250 Off 0
+7 P3 7 0.5 µV 10 250 Off 0
+8 P4 8 0.5 µV 10 250 Off 0
+9 O1 9 0.5 µV 10 250 Off 0
+10 O2 10 0.5 µV 10 250 Off 0
+11 F7 11 0.5 µV 10 250 Off 0
+12 F8 12 0.5 µV 10 250 Off 0
+13 P7 13 0.5 µV 10 250 Off 0
+14 P8 14 0.5 µV 10 250 Off 0
+15 Fz 15 0.5 µV 10 250 Off 0
+16 FCz 16 0.5 µV 10 250 Off 0
+17 Cz 17 0.5 µV 10 250 Off 0
+18 CPz 18 0.5 µV 10 250 Off 0
+19 Pz 19 0.5 µV 10 250 Off 0
+20 POz 20 0.5 µV 10 250 Off 0
+21 FC1 21 0.5 µV 10 250 Off 0
+22 FC2 22 0.5 µV 10 250 Off 0
+23 CP1 23 0.5 µV 10 250 Off 0
+24 CP2 24 0.5 µV 10 250 Off 0
+25 FC5 25 0.5 µV 10 250 Off 0
+26 FC6 26 0.5 µV 10 250 Off 0
+27 CP5 27 0.5 µV 10 250 Off 0
+28 CP6 28 0.5 µV 10 250 Off 0
+29 HL 29 0.5 µV 10 250 Off 0
+30 HR 30 0.5 µV 10 250 Off 0
+31 Vb 31 0.5 µV 10 250 Off 0
+32 ReRef 32 0.5 µV 5 250 Off 0
+
+S o f t w a r e F i l t e r s
+==============================
+Disabled
+
diff --git a/mne/io/brainvision/tests/data/test.vhdr b/mne/io/brainvision/tests/data/test_mixed_lowpass.vhdr
old mode 100755
new mode 100644
similarity index 50%
copy from mne/io/brainvision/tests/data/test.vhdr
copy to mne/io/brainvision/tests/data/test_mixed_lowpass.vhdr
index 697ce92..4827608
--- a/mne/io/brainvision/tests/data/test.vhdr
+++ b/mne/io/brainvision/tests/data/test_mixed_lowpass.vhdr
@@ -64,79 +64,40 @@ Sampling Interval [µS]: 1000
Channels
--------
# Name Phys. Chn. Resolution / Unit Low Cutoff [s] High Cutoff [Hz] Notch [Hz] Series Res. [kOhm] Gradient Offset
-1 FP1 1 0.5 µV DC 250 Off 0
-2 FP2 2 0.5 µV DC 250 Off 0
-3 F3 3 0.5 µV DC 250 Off 0
-4 F4 4 0.5 µV DC 250 Off 0
-5 C3 5 0.5 µV DC 250 Off 0
-6 C4 6 0.5 µV DC 250 Off 0
-7 P3 7 0.5 µV DC 250 Off 0
-8 P4 8 0.5 µV DC 250 Off 0
-9 O1 9 0.5 µV DC 250 Off 0
-10 O2 10 0.5 µV DC 250 Off 0
-11 F7 11 0.5 µV DC 250 Off 0
-12 F8 12 0.5 µV DC 250 Off 0
-13 P7 13 0.5 µV DC 250 Off 0
-14 P8 14 0.5 µV DC 250 Off 0
-15 Fz 15 0.5 µV DC 250 Off 0
-16 FCz 16 0.5 µV DC 250 Off 0
-17 Cz 17 0.5 µV DC 250 Off 0
-18 CPz 18 0.5 µV DC 250 Off 0
-19 Pz 19 0.5 µV DC 250 Off 0
-20 POz 20 0.5 µV DC 250 Off 0
-21 FC1 21 0.5 µV DC 250 Off 0
-22 FC2 22 0.5 µV DC 250 Off 0
-23 CP1 23 0.5 µV DC 250 Off 0
-24 CP2 24 0.5 µV DC 250 Off 0
-25 FC5 25 0.5 µV DC 250 Off 0
-26 FC6 26 0.5 µV DC 250 Off 0
-27 CP5 27 0.5 µV DC 250 Off 0
-28 CP6 28 0.5 µV DC 250 Off 0
-29 HL 29 0.5 µV DC 250 Off 0
-30 HR 30 0.5 µV DC 250 Off 0
-31 Vb 31 0.5 µV DC 250 Off 0
-32 ReRef 32 0.5 µV DC 250 Off 0
+1 FP1 1 0.5 µV 10 250 Off 0
+2 FP2 2 0.5 µV 10 250 Off 0
+3 F3 3 0.5 µV 10 250 Off 0
+4 F4 4 0.5 µV 10 250 Off 0
+5 C3 5 0.5 µV 10 250 Off 0
+6 C4 6 0.5 µV 10 250 Off 0
+7 P3 7 0.5 µV 10 250 Off 0
+8 P4 8 0.5 µV 10 250 Off 0
+9 O1 9 0.5 µV 10 250 Off 0
+10 O2 10 0.5 µV 10 250 Off 0
+11 F7 11 0.5 µV 10 250 Off 0
+12 F8 12 0.5 µV 10 250 Off 0
+13 P7 13 0.5 µV 10 250 Off 0
+14 P8 14 0.5 µV 10 250 Off 0
+15 Fz 15 0.5 µV 10 250 Off 0
+16 FCz 16 0.5 µV 10 250 Off 0
+17 Cz 17 0.5 µV 10 250 Off 0
+18 CPz 18 0.5 µV 10 250 Off 0
+19 Pz 19 0.5 µV 10 250 Off 0
+20 POz 20 0.5 µV 10 250 Off 0
+21 FC1 21 0.5 µV 10 250 Off 0
+22 FC2 22 0.5 µV 10 250 Off 0
+23 CP1 23 0.5 µV 10 250 Off 0
+24 CP2 24 0.5 µV 10 250 Off 0
+25 FC5 25 0.5 µV 10 250 Off 0
+26 FC6 26 0.5 µV 10 250 Off 0
+27 CP5 27 0.5 µV 10 250 Off 0
+28 CP6 28 0.5 µV 10 250 Off 0
+29 HL 29 0.5 µV 10 250 Off 0
+30 HR 30 0.5 µV 10 250 Off 0
+31 Vb 31 0.5 µV 10 250 Off 0
+32 ReRef 32 0.5 µV 10 125 Off 0
S o f t w a r e F i l t e r s
==============================
Disabled
-
-Data Electrodes Selected Impedance Measurement Range: 0 - 100 kOhm
-Ground Electrode Selected Impedance Measurement Range: 0 - 10 kOhm
-Reference Electrode Selected Impedance Measurement Range: 0 - 10 kOhm
-Impedance [kOhm] at 16:12:27 :
-FP1: ???
-FP2: ???
-F3: ???
-F4: ???
-C3: ???
-C4: ???
-P3: ???
-P4: ???
-O1: ???
-O2: ???
-F7: ???
-F8: ???
-P7: ???
-P8: ???
-Fz: ???
-FCz: ???
-Cz: ???
-CPz: ???
-Pz: ???
-POz: ???
-FC1: ???
-FC2: ???
-CP1: ???
-CP2: ???
-FC5: ???
-FC6: ???
-CP5: ???
-CP6: ???
-HL: ???
-HR: ???
-Vb: ???
-ReRef: ???
-Ref: 0
-Gnd: 4
diff --git a/mne/io/brainvision/tests/data/test_mixed_lowpass_s.vhdr b/mne/io/brainvision/tests/data/test_mixed_lowpass_s.vhdr
new file mode 100644
index 0000000..41ddcad
--- /dev/null
+++ b/mne/io/brainvision/tests/data/test_mixed_lowpass_s.vhdr
@@ -0,0 +1,103 @@
+Brain Vision Data Exchange Header File Version 1.0
+; Data created by the Vision Recorder
+
+[Common Infos]
+Codepage=UTF-8
+DataFile=test.eeg
+MarkerFile=test.vmrk
+DataFormat=BINARY
+; Data orientation: MULTIPLEXED=ch1,pt1, ch2,pt1 ...
+DataOrientation=MULTIPLEXED
+NumberOfChannels=32
+; Sampling interval in microseconds
+SamplingInterval=1000
+
+[Binary Infos]
+BinaryFormat=INT_16
+
+[Channel Infos]
+; Each entry: Ch<Channel number>=<Name>,<Reference channel name>,
+; <Resolution in "Unit">,<Unit>, Future extensions..
+; Fields are delimited by commas, some fields might be omitted (empty).
+; Commas in channel names are coded as "\1".
+Ch1=FP1,,0.5,µV
+Ch2=FP2,,0.5,µV
+Ch3=F3,,0.5,µV
+Ch4=F4,,0.5,µV
+Ch5=C3,,0.5,µV
+Ch6=C4,,0.5,µV
+Ch7=P3,,0.5,µV
+Ch8=P4,,0.5,µV
+Ch9=O1,,0.5,µV
+Ch10=O2,,0.5,µV
+Ch11=F7,,0.5,µV
+Ch12=F8,,0.5,µV
+Ch13=P7,,0.5,µV
+Ch14=P8,,0.5,µV
+Ch15=Fz,,0.5,µV
+Ch16=FCz,,0.5,µV
+Ch17=Cz,,0.5,µV
+Ch18=CPz,,0.5,µV
+Ch19=Pz,,0.5,µV
+Ch20=POz,,0.5,µV
+Ch21=FC1,,0.5,µV
+Ch22=FC2,,0.5,µV
+Ch23=CP1,,0.5,µV
+Ch24=CP2,,0.5,µV
+Ch25=FC5,,0.5,µV
+Ch26=FC6,,0.5,µV
+Ch27=CP5,,0.5,µV
+Ch28=CP6,,0.5,µV
+Ch29=HL,,0.5,µV
+Ch30=HR,,0.5,µV
+Ch31=Vb,,0.5,µV
+Ch32=ReRef,,0.5,µV
+
+[Comment]
+
+A m p l i f i e r S e t u p
+============================
+Number of channels: 32
+Sampling Rate [Hz]: 1000
+Sampling Interval [µS]: 1000
+
+Channels
+--------
+# Name Phys. Chn. Resolution / Unit Low Cutoff [s] High Cutoff [s] Notch [Hz] Series Res. [kOhm] Gradient Offset
+1 FP1 1 0.5 µV 10 0.004 Off 0
+2 FP2 2 0.5 µV 10 0.004 Off 0
+3 F3 3 0.5 µV 10 0.004 Off 0
+4 F4 4 0.5 µV 10 0.004 Off 0
+5 C3 5 0.5 µV 10 0.004 Off 0
+6 C4 6 0.5 µV 10 0.004 Off 0
+7 P3 7 0.5 µV 10 0.004 Off 0
+8 P4 8 0.5 µV 10 0.004 Off 0
+9 O1 9 0.5 µV 10 0.004 Off 0
+10 O2 10 0.5 µV 10 0.004 Off 0
+11 F7 11 0.5 µV 10 0.004 Off 0
+12 F8 12 0.5 µV 10 0.004 Off 0
+13 P7 13 0.5 µV 10 0.004 Off 0
+14 P8 14 0.5 µV 10 0.004 Off 0
+15 Fz 15 0.5 µV 10 0.004 Off 0
+16 FCz 16 0.5 µV 10 0.004 Off 0
+17 Cz 17 0.5 µV 10 0.004 Off 0
+18 CPz 18 0.5 µV 10 0.004 Off 0
+19 Pz 19 0.5 µV 10 0.004 Off 0
+20 POz 20 0.5 µV 10 0.004 Off 0
+21 FC1 21 0.5 µV 10 0.004 Off 0
+22 FC2 22 0.5 µV 10 0.004 Off 0
+23 CP1 23 0.5 µV 10 0.004 Off 0
+24 CP2 24 0.5 µV 10 0.004 Off 0
+25 FC5 25 0.5 µV 10 0.004 Off 0
+26 FC6 26 0.5 µV 10 0.004 Off 0
+27 CP5 27 0.5 µV 10 0.004 Off 0
+28 CP6 28 0.5 µV 10 0.004 Off 0
+29 HL 29 0.5 µV 10 0.004 Off 0
+30 HR 30 0.5 µV 10 0.004 Off 0
+31 Vb 31 0.5 µV 10 0.004 Off 0
+32 ReRef 32 0.5 µV 10 0.008 Off 0
+
+S o f t w a r e F i l t e r s
+==============================
+Disabled
+
diff --git a/mne/io/brainvision/tests/data/test_old_layout_latin1_software_filter.eeg b/mne/io/brainvision/tests/data/test_old_layout_latin1_software_filter.eeg
new file mode 100644
index 0000000..b1b3710
Binary files /dev/null and b/mne/io/brainvision/tests/data/test_old_layout_latin1_software_filter.eeg differ
diff --git a/mne/io/brainvision/tests/data/test_old_layout_latin1_software_filter.vhdr b/mne/io/brainvision/tests/data/test_old_layout_latin1_software_filter.vhdr
new file mode 100644
index 0000000..cdb837f
--- /dev/null
+++ b/mne/io/brainvision/tests/data/test_old_layout_latin1_software_filter.vhdr
@@ -0,0 +1,156 @@
+Brain Vision Data Exchange Header File Version 1.0
+; Data created by the Vision Recorder and modified by hand
+
+[Common Infos]
+DataFile=test_old_layout_latin1_software_filter.eeg
+MarkerFile=test_old_layout_latin1_software_filter.vmrk
+DataFormat=BINARY
+; Data orientation: VECTORIZED=ch1,pt1, ch1,pt2...,MULTIPLEXED=ch1,pt1, ch2,pt1 ...
+DataOrientation=VECTORIZED
+NumberOfChannels=29
+; Sampling interval in microseconds
+SamplingInterval=4000
+
+[Binary Infos]
+BinaryFormat=IEEE_FLOAT_32
+
+[Channel Infos]
+; Each entry: Ch<Channel number>=<Name>,<Reference channel name>,
+; <Resolution in microvolts>,<Future extensions..
+; Fields are delimited by commas, some fields might be omitted (empty).
+; Commas in channel names are coded as "\1".
+Ch1=F7,,0.1
+Ch2=F3,,0.1
+Ch3=Fz,,0.1
+Ch4=F4,,0.1
+Ch5=F8,,0.1
+Ch6=FT7,,0.1
+Ch7=FC5,,0.1
+Ch8=FCz,,0.1
+Ch9=FC6,,0.1
+Ch10=FT8,,0.1
+Ch11=Cz,,0.1
+Ch12=C3,,0.1
+Ch13=CP5,,0.1
+Ch14=CPz,,0.1
+Ch15=CP6,,0.1
+Ch16=C4,,0.1
+Ch17=P7,,0.1
+Ch18=P3,,0.1
+Ch19=Pz,,0.1
+Ch20=P4,,0.1
+Ch21=P8,,0.1
+Ch22=POz,,0.1
+Ch23=O1,,0.1
+Ch24=O2,,0.1
+Ch25=A2,,0.1
+Ch26=VEOGo,,0.1
+Ch27=VEOGu,,0.1
+Ch28=HEOGli,,0.1
+Ch29=HEOGre,,0.1
+
+[Comment]
+
+A m p l i f i e r S e t u p
+============================
+Number of channels: 29
+Sampling Rate [Hz]: 250
+Sampling Interval [�S]: 4000
+
+Channels
+--------
+# Name Phys. Chn. Resolution [�V] Low Cutoff [s] High Cutoff [Hz] Notch [Hz]
+1 F7 1 0.1 10 1000 Off
+2 F3 2 0.1 10 1000 Off
+3 Fz 3 0.1 10 1000 Off
+4 F4 4 0.1 10 1000 Off
+5 F8 5 0.1 10 1000 Off
+6 FT7 6 0.1 10 1000 Off
+7 FC5 7 0.1 10 1000 Off
+8 FCz 8 0.1 10 1000 Off
+9 FC6 9 0.1 10 1000 Off
+10 FT8 10 0.1 10 1000 Off
+11 Cz 11 0.1 10 1000 Off
+12 C3 12 0.1 10 1000 Off
+13 CP5 13 0.1 10 1000 Off
+14 CPz 14 0.1 10 1000 Off
+15 CP6 15 0.1 10 1000 Off
+16 C4 16 0.1 10 1000 Off
+17 P7 17 0.1 10 1000 Off
+18 P3 18 0.1 10 1000 Off
+19 Pz 19 0.1 10 1000 Off
+20 P4 20 0.1 10 1000 Off
+21 P8 21 0.1 10 1000 Off
+22 POz 22 0.1 10 1000 Off
+23 O1 23 0.1 10 1000 Off
+24 O2 24 0.1 10 1000 Off
+25 A2 25 0.1 10 1000 Off
+26 VEOGo 26 0.1 10 1000 Off
+27 VEOGu 27 0.1 10 1000 Off
+28 HEOGli 28 0.1 10 1000 Off
+29 HEOGre 29 0.1 10 1000 Off
+
+S o f t w a r e F i l t e r s
+==============================
+# Low Cutoff [s] High Cutoff [Hz] Notch [Hz]
+1 0.9 50 50
+2 0.9 50 50
+3 0.9 50 50
+4 0.9 50 50
+5 0.9 50 50
+6 0.9 50 50
+7 0.9 50 50
+8 0.9 50 50
+9 0.9 50 50
+10 0.9 50 50
+11 0.9 50 50
+12 0.9 50 50
+13 0.9 50 50
+14 0.9 50 50
+15 0.9 50 50
+16 0.9 50 50
+17 0.9 50 50
+18 0.9 50 50
+19 0.9 50 50
+20 0.9 50 50
+21 0.9 50 50
+22 0.9 50 50
+23 0.9 50 50
+24 0.9 50 50
+25 0.9 50 50
+26 0.9 50 50
+27 0.9 50 50
+28 0.9 50 50
+29 0.9 50 50
+
+
+Impedance [kOhm] at 12:18:40 :
+F7: 0
+F3: 0
+Fz: 2
+F4: 1
+F8: 0
+FT7: 1
+FC5: 0
+FCz: 0
+FC6: 1
+FT8: 1
+Cz: 0
+C3: 0
+CP5: 2
+CPz: 0
+CP6: 0
+C4: 0
+P7: 1
+P3: 1
+Pz: 2
+P4: 0
+P8: 1
+POz: 3
+O1: 1
+O2: 3
+A2: 1
+VEOGo: 0
+VEOGu: 1
+HEOGli: 3
+HEOGre: 5
diff --git a/mne/io/brainvision/tests/data/test_old_layout_latin1_software_filter.vmrk b/mne/io/brainvision/tests/data/test_old_layout_latin1_software_filter.vmrk
new file mode 100644
index 0000000..c430411
--- /dev/null
+++ b/mne/io/brainvision/tests/data/test_old_layout_latin1_software_filter.vmrk
@@ -0,0 +1,14 @@
+Brain Vision Data Exchange Marker File, Version 1.0
+; Data created from the EEGLAB software and modified by hand
+; The channel numbers are related to the channels in the exported file.
+
+[Common Infos]
+DataFile=test_old_layout_latin1_software_filter.eeg
+
+[Marker Infos]
+; Each entry: Mk<Marker number>=<Type>,<Description>,<Position in data points>,
+; <Size in data points>, <Channel number (0 = marker is related to all channels)>,
+; <Date (YYYYMMDDhhmmssuuuuuu)>
+; Fields are delimited by commas, some fields might be omitted (empty).
+; Commas in type or description text are coded as "".
+Mk1=New Segment,,1,1,0,20070716122240937454
diff --git a/mne/io/brainvision/tests/data/test.vhdr b/mne/io/brainvision/tests/data/test_partially_disabled_hw_filter.vhdr
similarity index 74%
copy from mne/io/brainvision/tests/data/test.vhdr
copy to mne/io/brainvision/tests/data/test_partially_disabled_hw_filter.vhdr
index 697ce92..2c8d9bb 100755
--- a/mne/io/brainvision/tests/data/test.vhdr
+++ b/mne/io/brainvision/tests/data/test_partially_disabled_hw_filter.vhdr
@@ -46,12 +46,12 @@ Ch23=CP1,,0.5,µV
Ch24=CP2,,0.5,µV
Ch25=FC5,,0.5,µV
Ch26=FC6,,0.5,µV
-Ch27=CP5,,0.5,µV
-Ch28=CP6,,0.5,µV
-Ch29=HL,,0.5,µV
-Ch30=HR,,0.5,µV
-Ch31=Vb,,0.5,µV
-Ch32=ReRef,,0.5,µV
+Ch27=CP5,,0.5,BS
+Ch28=CP6,,0.5,µS
+Ch29=HL,,0.5,ARU
+Ch30=HR,,0.5,uS
+Ch31=Vb,,0.5,S
+Ch32=ReRef,,0.5,C
[Comment]
@@ -90,53 +90,14 @@ Channels
24 CP2 24 0.5 µV DC 250 Off 0
25 FC5 25 0.5 µV DC 250 Off 0
26 FC6 26 0.5 µV DC 250 Off 0
-27 CP5 27 0.5 µV DC 250 Off 0
-28 CP6 28 0.5 µV DC 250 Off 0
-29 HL 29 0.5 µV DC 250 Off 0
-30 HR 30 0.5 µV DC 250 Off 0
-31 Vb 31 0.5 µV DC 250 Off 0
-32 ReRef 32 0.5 µV DC 250 Off 0
+27 CP5 27 0.5 BS DC 250 Off 0
+28 CP6 28 0.5 µS DC 250 Off 0
+29 HL 29 0.5 ARU DC 250 Off 0
+30 HR 30 0.5 uS DC Off Off 0
+31 Vb 31 0.5 S Off 250 Off 0
+32 ReRef 32 0.5 C 10 250 Off 0
S o f t w a r e F i l t e r s
==============================
Disabled
-
-Data Electrodes Selected Impedance Measurement Range: 0 - 100 kOhm
-Ground Electrode Selected Impedance Measurement Range: 0 - 10 kOhm
-Reference Electrode Selected Impedance Measurement Range: 0 - 10 kOhm
-Impedance [kOhm] at 16:12:27 :
-FP1: ???
-FP2: ???
-F3: ???
-F4: ???
-C3: ???
-C4: ???
-P3: ???
-P4: ???
-O1: ???
-O2: ???
-F7: ???
-F8: ???
-P7: ???
-P8: ???
-Fz: ???
-FCz: ???
-Cz: ???
-CPz: ???
-Pz: ???
-POz: ???
-FC1: ???
-FC2: ???
-CP1: ???
-CP2: ???
-FC5: ???
-FC6: ???
-CP5: ???
-CP6: ???
-HL: ???
-HR: ???
-Vb: ???
-ReRef: ???
-Ref: 0
-Gnd: 4
diff --git a/mne/io/brainvision/tests/test_brainvision.py b/mne/io/brainvision/tests/test_brainvision.py
index 8882991..bbb4dc9 100644
--- a/mne/io/brainvision/tests/test_brainvision.py
+++ b/mne/io/brainvision/tests/test_brainvision.py
@@ -17,16 +17,37 @@ from numpy.testing import (assert_array_almost_equal, assert_array_equal,
from mne.utils import _TempDir, run_tests_if_main
from mne import pick_types, find_events
from mne.io.constants import FIFF
-from mne.io import Raw, read_raw_brainvision
+from mne.io import read_raw_fif, read_raw_brainvision
from mne.io.tests.test_raw import _test_raw_reader
FILE = inspect.getfile(inspect.currentframe())
data_dir = op.join(op.dirname(op.abspath(FILE)), 'data')
vhdr_path = op.join(data_dir, 'test.vhdr')
vmrk_path = op.join(data_dir, 'test.vmrk')
+
+vhdr_partially_disabled_hw_filter_path = op.join(data_dir,
+ 'test_partially_disabled'
+ '_hw_filter.vhdr')
+
+vhdr_old_path = op.join(data_dir,
+ 'test_old_layout_latin1_software_filter.vhdr')
+vmrk_old_path = op.join(data_dir,
+ 'test_old_layout_latin1_software_filter.vmrk')
+
vhdr_v2_path = op.join(data_dir, 'testv2.vhdr')
vmrk_v2_path = op.join(data_dir, 'testv2.vmrk')
+
vhdr_highpass_path = op.join(data_dir, 'test_highpass.vhdr')
+vhdr_mixed_highpass_path = op.join(data_dir, 'test_mixed_highpass.vhdr')
+vhdr_highpass_hz_path = op.join(data_dir, 'test_highpass_hz.vhdr')
+vhdr_mixed_highpass_hz_path = op.join(data_dir, 'test_mixed_highpass_hz.vhdr')
+
+# Not a typo: we can reuse the highpass file for the lowpass (Hz) test
+vhdr_lowpass_path = op.join(data_dir, 'test_highpass.vhdr')
+vhdr_mixed_lowpass_path = op.join(data_dir, 'test_mixed_lowpass.vhdr')
+vhdr_lowpass_s_path = op.join(data_dir, 'test_lowpass_s.vhdr')
+vhdr_mixed_lowpass_s_path = op.join(data_dir, 'test_mixed_lowpass_s.vhdr')
+
montage = op.join(data_dir, 'test.hpts')
eeg_bin = op.join(data_dir, 'test_bin_raw.fif')
eog = ['HL', 'HR', 'Vb']
@@ -34,9 +55,9 @@ eog = ['HL', 'HR', 'Vb']
warnings.simplefilter('always')
-def test_brainvision_data_filters():
- """Test reading raw Brain Vision files
- """
+def test_brainvision_data_highpass_filters():
+ """Test reading raw Brain Vision files with amplifier filter settings."""
+ # Homogeneous highpass in seconds (default measurement unit)
with warnings.catch_warnings(record=True) as w: # event parsing
raw = _test_raw_reader(
read_raw_brainvision, vhdr_fname=vhdr_highpass_path,
@@ -46,18 +67,170 @@ def test_brainvision_data_filters():
assert_equal(raw.info['highpass'], 0.1)
assert_equal(raw.info['lowpass'], 250.)
+ # Heterogeneous highpass in seconds (default measurement unit)
+ with warnings.catch_warnings(record=True) as w: # event parsing
+ raw = _test_raw_reader(
+ read_raw_brainvision, vhdr_fname=vhdr_mixed_highpass_path,
+ montage=montage, eog=eog)
+
+ trigger_warning = ['parse triggers that' in str(ww.message)
+ for ww in w]
+ lowpass_warning = ['different lowpass filters' in str(ww.message)
+ for ww in w]
+ highpass_warning = ['different highpass filters' in str(ww.message)
+ for ww in w]
-def test_brainvision_data():
- """Test reading raw Brain Vision files
+ expected_warnings = zip(trigger_warning, lowpass_warning, highpass_warning)
+
+ assert_true(all(any([trg, lp, hp]) for trg, lp, hp in expected_warnings))
+
+ assert_equal(raw.info['highpass'], 0.1)
+ assert_equal(raw.info['lowpass'], 250.)
+
+ # Homogeneous highpass in Hertz
+ with warnings.catch_warnings(record=True) as w: # event parsing
+ raw = _test_raw_reader(
+ read_raw_brainvision, vhdr_fname=vhdr_highpass_hz_path,
+ montage=montage, eog=eog)
+ assert_true(all('parse triggers that' in str(ww.message) for ww in w))
+
+ assert_equal(raw.info['highpass'], 10.)
+ assert_equal(raw.info['lowpass'], 250.)
+
+ # Heterogeneous highpass in Hertz
+ with warnings.catch_warnings(record=True) as w: # event parsing
+ raw = _test_raw_reader(
+ read_raw_brainvision, vhdr_fname=vhdr_mixed_highpass_hz_path,
+ montage=montage, eog=eog)
+
+ trigger_warning = ['parse triggers that' in str(ww.message)
+ for ww in w]
+ lowpass_warning = ['different lowpass filters' in str(ww.message)
+ for ww in w]
+ highpass_warning = ['different highpass filters' in str(ww.message)
+ for ww in w]
+
+ expected_warnings = zip(trigger_warning, lowpass_warning, highpass_warning)
+
+ assert_true(all(any([trg, lp, hp]) for trg, lp, hp in expected_warnings))
+
+ assert_equal(raw.info['highpass'], 5.)
+ assert_equal(raw.info['lowpass'], 250.)
+
+
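The warning checks in the tests below all follow one pattern: collect every caught warning, mark which of the three expected kinds it matches, and assert each warning matched at least one kind. A reduced sketch of the idea (warning texts illustrative):

    messages = ['... parse triggers that ...',
                '... different lowpass filters ...']
    kinds = ('parse triggers that', 'different lowpass filters',
             'different highpass filters')
    # every emitted warning must be one of the expected kinds
    assert all(any(k in msg for k in kinds) for msg in messages)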
+def test_brainvision_data_lowpass_filters():
+ """Test reading raw Brain Vision files with amplifier LP filter settings"""
+
+ # Homogeneous lowpass in Hertz (default measurement unit)
+ with warnings.catch_warnings(record=True) as w: # event parsing
+ raw = _test_raw_reader(
+ read_raw_brainvision, vhdr_fname=vhdr_lowpass_path,
+ montage=montage, eog=eog)
+ assert_true(all('parse triggers that' in str(ww.message) for ww in w))
+
+ assert_equal(raw.info['highpass'], 0.1)
+ assert_equal(raw.info['lowpass'], 250.)
+
+ # Heterogeneous lowpass in Hertz (default measurement unit)
+ with warnings.catch_warnings(record=True) as w: # event parsing
+ raw = _test_raw_reader(
+ read_raw_brainvision, vhdr_fname=vhdr_mixed_lowpass_path,
+ montage=montage, eog=eog)
+
+ trigger_warning = ['parse triggers that' in str(ww.message)
+ for ww in w]
+ lowpass_warning = ['different lowpass filters' in str(ww.message)
+ for ww in w]
+ highpass_warning = ['different highpass filters' in str(ww.message)
+ for ww in w]
+
+ expected_warnings = zip(trigger_warning, lowpass_warning, highpass_warning)
+
+ assert_true(all(any([trg, lp, hp]) for trg, lp, hp in expected_warnings))
+
+ assert_equal(raw.info['highpass'], 0.1)
+ assert_equal(raw.info['lowpass'], 250.)
+
+ # Homogeneous lowpass in seconds
+ with warnings.catch_warnings(record=True) as w: # event parsing
+ raw = _test_raw_reader(
+ read_raw_brainvision, vhdr_fname=vhdr_lowpass_s_path,
+ montage=montage, eog=eog)
+ assert_true(all('parse triggers that' in str(ww.message) for ww in w))
+
+ assert_equal(raw.info['highpass'], 0.1)
+ assert_equal(raw.info['lowpass'], 250.)
+
+ # Heterogeneous lowpass in seconds
+ with warnings.catch_warnings(record=True) as w: # event parsing
+ raw = _test_raw_reader(
+ read_raw_brainvision, vhdr_fname=vhdr_mixed_lowpass_s_path,
+ montage=montage, eog=eog)
+
+ trigger_warning = ['parse triggers that' in str(ww.message)
+ for ww in w]
+ lowpass_warning = ['different lowpass filters' in str(ww.message)
+ for ww in w]
+ highpass_warning = ['different highpass filters' in str(ww.message)
+ for ww in w]
+
+ expected_warnings = zip(trigger_warning, lowpass_warning, highpass_warning)
+
+ assert_true(all(any([trg, lp, hp]) for trg, lp, hp in expected_warnings))
+
+ assert_equal(raw.info['highpass'], 0.1)
+ assert_equal(raw.info['lowpass'], 250.)
+
+
+def test_brainvision_data_partially_disabled_hw_filters():
+ """Test reading raw Brain Vision files with heterogeneous amplifier
+ filter settings including non-numeric values
"""
+ with warnings.catch_warnings(record=True) as w: # event parsing
+ raw = _test_raw_reader(
+ read_raw_brainvision,
+ vhdr_fname=vhdr_partially_disabled_hw_filter_path,
+ montage=montage, eog=eog)
+
+ trigger_warning = ['parse triggers that' in str(ww.message)
+ for ww in w]
+ lowpass_warning = ['different lowpass filters' in str(ww.message)
+ for ww in w]
+ highpass_warning = ['different highpass filters' in str(ww.message)
+ for ww in w]
+
+ expected_warnings = zip(trigger_warning, lowpass_warning, highpass_warning)
+
+ assert_true(all(any([trg, lp, hp]) for trg, lp, hp in expected_warnings))
+
+ assert_equal(raw.info['highpass'], 0.)
+ assert_equal(raw.info['lowpass'], 500.)
+
+
+def test_brainvision_data_software_filters_latin1_global_units():
+ """Test reading raw Brain Vision files."""
+ with warnings.catch_warnings(record=True) as w: # event parsing
+ raw = _test_raw_reader(
+ read_raw_brainvision, vhdr_fname=vhdr_old_path,
+ eog=("VEOGo", "VEOGu", "HEOGli", "HEOGre"), misc=("A2",))
+ assert_true(all('software filter detected' in str(ww.message) for ww in w))
+
+ assert_equal(raw.info['highpass'], 1. / 0.9)
+ assert_equal(raw.info['lowpass'], 50.)
+
+
+def test_brainvision_data():
+ """Test reading raw Brain Vision files."""
assert_raises(IOError, read_raw_brainvision, vmrk_path)
assert_raises(ValueError, read_raw_brainvision, vhdr_path, montage,
preload=True, scale="foo")
+
with warnings.catch_warnings(record=True) as w: # event parsing
raw_py = _test_raw_reader(
read_raw_brainvision, vhdr_fname=vhdr_path, montage=montage,
- eog=eog)
+ eog=eog, misc='auto')
assert_true(all('parse triggers that' in str(ww.message) for ww in w))
+
assert_true('RawBrainVision' in repr(raw_py))
assert_equal(raw_py.info['highpass'], 0.)
@@ -67,7 +240,7 @@ def test_brainvision_data():
data_py, times_py = raw_py[picks]
# compare with a file that was generated using MNE-C
- raw_bin = Raw(eeg_bin, preload=True)
+ raw_bin = read_raw_fif(eeg_bin, preload=True, add_eeg_ref=False)
picks = pick_types(raw_py.info, meg=False, eeg=True, exclude='bads')
data_bin, times_bin = raw_bin[picks]
@@ -80,8 +253,15 @@ def test_brainvision_data():
assert_equal(ch['kind'], FIFF.FIFFV_EOG_CH)
elif ch['ch_name'] == 'STI 014':
assert_equal(ch['kind'], FIFF.FIFFV_STIM_CH)
+ elif ch['ch_name'] in ('CP5', 'CP6'):
+ assert_equal(ch['kind'], FIFF.FIFFV_MISC_CH)
+ assert_equal(ch['unit'], FIFF.FIFF_UNIT_NONE)
+ elif ch['ch_name'] == 'ReRef':
+ assert_equal(ch['kind'], FIFF.FIFFV_MISC_CH)
+ assert_equal(ch['unit'], FIFF.FIFF_UNIT_CEL)
elif ch['ch_name'] in raw_py.info['ch_names']:
assert_equal(ch['kind'], FIFF.FIFFV_EEG_CH)
+ assert_equal(ch['unit'], FIFF.FIFF_UNIT_V)
else:
raise RuntimeError("Unknown Channel: %s" % ch['ch_name'])
@@ -91,7 +271,7 @@ def test_brainvision_data():
def test_events():
- """Test reading and modifying events"""
+ """Test reading and modifying events."""
tempdir = _TempDir()
# check that events are read and stim channel is synthesized correcly
diff --git a/mne/io/bti/bti.py b/mne/io/bti/bti.py
index 2c25bf9..9e41ff5 100644
--- a/mne/io/bti/bti.py
+++ b/mne/io/bti/bti.py
@@ -12,7 +12,7 @@ from itertools import count
import numpy as np
-from ...utils import logger, verbose, sum_squared, warn
+from ...utils import logger, verbose, sum_squared
from ...transforms import (combine_transforms, invert_transform, apply_trans,
Transform)
from ..constants import FIFF
@@ -1303,10 +1303,11 @@ def _get_bti_info(pdf_fname, config_fname, head_shape_fname, rotation_x,
colcals=np.ones(mat.shape[1], dtype='>f4'),
save_calibrated=0)]
else:
- warn('Currently direct inclusion of 4D weight tables is not supported.'
- ' For critical use cases please take into account the MNE command'
- ' "mne_create_comp_data" to include weights as printed out by '
- 'the 4D "print_table" routine.')
+ logger.info(
+ 'Currently direct inclusion of 4D weight tables is not supported.'
+ ' For critical use cases please take into account the MNE command'
+ ' "mne_create_comp_data" to include weights as printed out by '
+ 'the 4D "print_table" routine.')
# check that the info is complete
info._update_redundant()
diff --git a/mne/io/bti/tests/test_bti.py b/mne/io/bti/tests/test_bti.py
index 50a9f6d..49c7719 100644
--- a/mne/io/bti/tests/test_bti.py
+++ b/mne/io/bti/tests/test_bti.py
@@ -5,7 +5,7 @@ from __future__ import print_function
import os
import os.path as op
-from functools import reduce
+from functools import reduce, partial
import warnings
import numpy as np
@@ -13,7 +13,7 @@ from numpy.testing import (assert_array_almost_equal, assert_array_equal,
assert_allclose)
from nose.tools import assert_true, assert_raises, assert_equal
-from mne.io import Raw, read_raw_bti
+from mne.io import read_raw_fif, read_raw_bti
from mne.io.bti.bti import (_read_config, _process_bti_headshape,
_read_bti_header, _get_bti_dev_t,
_correct_trans, _get_bti_info)
@@ -25,7 +25,6 @@ from mne import pick_types
from mne.utils import run_tests_if_main
from mne.transforms import Transform, combine_transforms, invert_transform
from mne.externals import six
-from mne.fixes import partial
warnings.simplefilter('always')
@@ -44,7 +43,7 @@ NCH = 248
def test_read_config():
- """ Test read bti config file """
+ """Test read bti config file."""
# for config in config_fname, config_solaris_fname:
for config in config_fnames:
cfg = _read_config(config)
@@ -53,12 +52,10 @@ def test_read_config():
def test_crop_append():
- """ Test crop and append raw """
- with warnings.catch_warnings(record=True): # preload warning
- warnings.simplefilter('always')
- raw = _test_raw_reader(
- read_raw_bti, pdf_fname=pdf_fnames[0],
- config_fname=config_fnames[0], head_shape_fname=hs_fnames[0])
+ """Test crop and append raw."""
+ raw = _test_raw_reader(
+ read_raw_bti, pdf_fname=pdf_fnames[0],
+ config_fname=config_fnames[0], head_shape_fname=hs_fnames[0])
y, t = raw[:]
t0, t1 = 0.25 * t[-1], 0.75 * t[-1]
mask = (t0 <= t) * (t <= t1)
@@ -69,12 +66,11 @@ def test_crop_append():
def test_transforms():
- """ Test transformations """
+ """Test transformations."""
bti_trans = (0.0, 0.02, 0.11)
bti_dev_t = Transform('ctf_meg', 'meg', _get_bti_dev_t(0.0, bti_trans))
for pdf, config, hs, in zip(pdf_fnames, config_fnames, hs_fnames):
- with warnings.catch_warnings(record=True): # weight tables
- raw = read_raw_bti(pdf, config, hs, preload=False)
+ raw = read_raw_bti(pdf, config, hs, preload=False)
dev_ctf_t = raw.info['dev_ctf_t']
dev_head_t_old = raw.info['dev_head_t']
ctf_head_t = raw.info['ctf_head_t']
@@ -92,7 +88,7 @@ def test_transforms():
def test_raw():
- """ Test bti conversion to Raw object """
+ """Test bti conversion to Raw object."""
for pdf, config, hs, exported in zip(pdf_fnames, config_fnames, hs_fnames,
exported_fnames):
# rx = 2 if 'linux' in pdf else 0
@@ -101,15 +97,13 @@ def test_raw():
preload=False)
if op.exists(tmp_raw_fname):
os.remove(tmp_raw_fname)
- ex = Raw(exported, preload=True)
- with warnings.catch_warnings(record=True): # weight tables
- ra = read_raw_bti(pdf, config, hs, preload=False)
+ ex = read_raw_fif(exported, preload=True, add_eeg_ref=False)
+ ra = read_raw_bti(pdf, config, hs, preload=False)
assert_true('RawBTi' in repr(ra))
assert_equal(ex.ch_names[:NCH], ra.ch_names[:NCH])
assert_array_almost_equal(ex.info['dev_head_t']['trans'],
ra.info['dev_head_t']['trans'], 7)
- with warnings.catch_warnings(record=True): # headshape
- assert_dig_allclose(ex.info, ra.info)
+ assert_dig_allclose(ex.info, ra.info)
coil1, coil2 = [np.concatenate([d['loc'].flatten()
for d in r_.info['chs'][:NCH]])
for r_ in (ra, ex)]
@@ -138,7 +132,7 @@ def test_raw():
ra.info[key][ent])
ra.save(tmp_raw_fname)
- re = Raw(tmp_raw_fname)
+ re = read_raw_fif(tmp_raw_fname, add_eeg_ref=False)
print(re)
for key in ('dev_head_t', 'dev_ctf_t', 'ctf_head_t'):
assert_true(isinstance(re.info[key], dict))
@@ -150,19 +144,18 @@ def test_raw():
def test_info_no_rename_no_reorder_no_pdf():
- """ Test private renaming, reordering and partial construction option """
+ """Test private renaming, reordering and partial construction option."""
for pdf, config, hs in zip(pdf_fnames, config_fnames, hs_fnames):
- with warnings.catch_warnings(record=True): # weight tables
- info, bti_info = _get_bti_info(
- pdf_fname=pdf, config_fname=config, head_shape_fname=hs,
- rotation_x=0.0, translation=(0.0, 0.02, 0.11), convert=False,
- ecg_ch='E31', eog_ch=('E63', 'E64'),
- rename_channels=False, sort_by_ch_name=False)
- info2, bti_info = _get_bti_info(
- pdf_fname=None, config_fname=config, head_shape_fname=hs,
- rotation_x=0.0, translation=(0.0, 0.02, 0.11), convert=False,
- ecg_ch='E31', eog_ch=('E63', 'E64'),
- rename_channels=False, sort_by_ch_name=False)
+ info, bti_info = _get_bti_info(
+ pdf_fname=pdf, config_fname=config, head_shape_fname=hs,
+ rotation_x=0.0, translation=(0.0, 0.02, 0.11), convert=False,
+ ecg_ch='E31', eog_ch=('E63', 'E64'),
+ rename_channels=False, sort_by_ch_name=False)
+ info2, bti_info = _get_bti_info(
+ pdf_fname=None, config_fname=config, head_shape_fname=hs,
+ rotation_x=0.0, translation=(0.0, 0.02, 0.11), convert=False,
+ ecg_ch='E31', eog_ch=('E63', 'E64'),
+ rename_channels=False, sort_by_ch_name=False)
assert_equal(info['ch_names'],
[ch['ch_name'] for ch in info['chs']])
@@ -196,15 +189,14 @@ def test_info_no_rename_no_reorder_no_pdf():
np.array([ch['loc'] for ch in info2['chs']]))
# just check reading data | corner case
- with warnings.catch_warnings(record=True): # weight tables
- raw1 = read_raw_bti(
- pdf_fname=pdf, config_fname=config, head_shape_fname=None,
- sort_by_ch_name=False, preload=True)
- # just check reading data | corner case
- raw2 = read_raw_bti(
- pdf_fname=pdf, config_fname=config, head_shape_fname=None,
- rename_channels=False,
- sort_by_ch_name=True, preload=True)
+ raw1 = read_raw_bti(
+ pdf_fname=pdf, config_fname=config, head_shape_fname=None,
+ sort_by_ch_name=False, preload=True)
+ # just check reading data | corner case
+ raw2 = read_raw_bti(
+ pdf_fname=pdf, config_fname=config, head_shape_fname=None,
+ rename_channels=False,
+ sort_by_ch_name=True, preload=True)
sort_idx = [raw1.bti_ch_labels.index(ch) for ch in raw2.bti_ch_labels]
raw1._data = raw1._data[sort_idx]
@@ -213,8 +205,7 @@ def test_info_no_rename_no_reorder_no_pdf():
def test_no_conversion():
- """ Test bti no-conversion option """
-
+ """Test bti no-conversion option."""
get_info = partial(
_get_bti_info,
rotation_x=0.0, translation=(0.0, 0.02, 0.11), convert=False,
@@ -222,12 +213,10 @@ def test_no_conversion():
rename_channels=False, sort_by_ch_name=False)
for pdf, config, hs in zip(pdf_fnames, config_fnames, hs_fnames):
- with warnings.catch_warnings(record=True): # weight tables
- raw_info, _ = get_info(pdf, config, hs, convert=False)
- with warnings.catch_warnings(record=True): # weight tables
- raw_info_con = read_raw_bti(
- pdf_fname=pdf, config_fname=config, head_shape_fname=hs,
- convert=True, preload=False).info
+ raw_info, _ = get_info(pdf, config, hs, convert=False)
+ raw_info_con = read_raw_bti(
+ pdf_fname=pdf, config_fname=config, head_shape_fname=hs,
+ convert=True, preload=False).info
pick_info(raw_info_con,
pick_types(raw_info_con, meg=True, ref_meg=True),
@@ -273,10 +262,9 @@ def test_no_conversion():
def test_bytes_io():
- """ Test bti bytes-io API """
+ """Test bti bytes-io API."""
for pdf, config, hs in zip(pdf_fnames, config_fnames, hs_fnames):
- with warnings.catch_warnings(record=True): # weight tables
- raw = read_raw_bti(pdf, config, hs, convert=True, preload=False)
+ raw = read_raw_bti(pdf, config, hs, convert=True, preload=False)
with open(pdf, 'rb') as fid:
pdf = six.BytesIO(fid.read())
@@ -284,14 +272,14 @@ def test_bytes_io():
config = six.BytesIO(fid.read())
with open(hs, 'rb') as fid:
hs = six.BytesIO(fid.read())
- with warnings.catch_warnings(record=True): # weight tables
- raw2 = read_raw_bti(pdf, config, hs, convert=True, preload=False)
+
+ raw2 = read_raw_bti(pdf, config, hs, convert=True, preload=False)
repr(raw2)
assert_array_equal(raw[:][0], raw2[:][0])
def test_setup_headshape():
- """ Test reading bti headshape """
+ """Test reading bti headshape."""
for hs in hs_fnames:
dig, t = _process_bti_headshape(hs)
expected = set(['kind', 'ident', 'r'])
diff --git a/mne/io/cnt/cnt.py b/mne/io/cnt/cnt.py
index db472f9..80413b8 100644
--- a/mne/io/cnt/cnt.py
+++ b/mne/io/cnt/cnt.py
@@ -20,7 +20,8 @@ from ..utils import read_str
def read_raw_cnt(input_fname, montage, eog=(), misc=(), ecg=(), emg=(),
- preload=False, verbose=None):
+ data_format='auto', date_format='mm/dd/yy', preload=False,
+ verbose=None):
"""Read CNT data as raw object.
.. Note::
@@ -49,18 +50,25 @@ def read_raw_cnt(input_fname, montage, eog=(), misc=(), ecg=(), emg=(),
EOG channels. If 'header', VEOG and HEOG channels assigned in the file
header are used. If 'auto', channel names containing 'EOG' are used.
Defaults to empty tuple.
- misc : list or tuple
+ misc : list | tuple
Names of channels or list of indices that should be designated
MISC channels. Defaults to empty tuple.
- ecg : list or tuple | 'auto'
+ ecg : list | tuple | 'auto'
Names of channels or list of indices that should be designated
ECG channels. If 'auto', the channel names containing 'ECG' are used.
Defaults to empty tuple.
- emg : list or tuple
+ emg : list | tuple
Names of channels or list of indices that should be designated
EMG channels. If 'auto', the channel names containing 'EMG' are used.
Defaults to empty tuple.
- preload : bool or str (default False)
+ data_format : 'auto' | 'int16' | 'int32'
+        Defines the format in which the data is read. If 'auto', it is
+        determined from the file header using the ``numsamples`` field.
+        Defaults to 'auto'.
+ date_format : str
+        Format of the date in the header. Currently supports 'mm/dd/yy'
+        (default) and 'dd/mm/yy'.
+ preload : bool | str (default False)
Preload data into memory for data manipulation and faster indexing.
If True, the data will be preloaded into memory (fast, requires
large amount of memory). If preload is a string, preload is the
@@ -83,10 +91,11 @@ def read_raw_cnt(input_fname, montage, eog=(), misc=(), ecg=(), emg=(),
.. versionadded:: 0.12
"""
return RawCNT(input_fname, montage=montage, eog=eog, misc=misc, ecg=ecg,
- emg=emg, preload=preload, verbose=verbose)
+ emg=emg, data_format=data_format, date_format=date_format,
+ preload=preload, verbose=verbose)
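
For illustration, a minimal usage sketch of the two new parameters (the file
name is hypothetical, not from this changeset):

    from mne.io import read_raw_cnt

    # force 32-bit samples and a European date order in the header
    raw = read_raw_cnt('subject01.cnt', montage=None, data_format='int32',
                       date_format='dd/mm/yy', preload=True)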
-def _get_cnt_info(input_fname, eog, ecg, emg, misc):
+def _get_cnt_info(input_fname, eog, ecg, emg, misc, data_format, date_format):
"""Helper for reading the cnt header."""
data_offset = 900 # Size of the 'SETUP' header.
cnt_info = dict()
@@ -129,6 +138,14 @@ def _get_cnt_info(input_fname, eog, ecg, emg, misc):
date[2] = '20' + date[2]
time = time.split(':')
if len(time) == 3:
+ if date_format == 'mm/dd/yy':
+ pass
+ elif date_format == 'dd/mm/yy':
+ date[0], date[1] = date[1], date[0]
+ else:
+ raise ValueError("Only date formats 'mm/dd/yy' and "
+ "'dd/mm/yy' supported. "
+ "Got '%s'." % date_format)
# Assuming mm/dd/yy
date = datetime.datetime(int(date[2]), int(date[0]),
int(date[1]), int(time[0]),
@@ -151,6 +168,11 @@ def _get_cnt_info(input_fname, eog, ecg, emg, misc):
fid.seek(438)
lowpass_toggle = np.fromfile(fid, 'i1', count=1)[0]
highpass_toggle = np.fromfile(fid, 'i1', count=1)[0]
+
+    # The header has a field for the number of samples, but it does not
+    # seem to be reliable. That is why there is an option for setting the
+    # data format (and thus the sample width) manually.
+ fid.seek(864)
+ n_samples = np.fromfile(fid, dtype='<i4', count=1)[0]
fid.seek(869)
lowcutoff = np.fromfile(fid, dtype='f4', count=1)[0]
fid.seek(2, 1)
@@ -160,14 +182,29 @@ def _get_cnt_info(input_fname, eog, ecg, emg, misc):
event_offset = np.fromfile(fid, dtype='<i4', count=1)[0]
cnt_info['continuous_seconds'] = np.fromfile(fid, dtype='<f4',
count=1)[0]
+
+ data_size = event_offset - (900 + 75 * n_channels)
+ if data_format == 'auto':
+ if (n_samples == 0 or
+ data_size // (n_samples * n_channels) not in [2, 4]):
+            warn('Could not determine the number of bytes automatically. '
+                 'Defaulting to 2.')
+ n_bytes = 2
+ n_samples = data_size // (n_bytes * n_channels)
+ else:
+ n_bytes = data_size // (n_samples * n_channels)
+ else:
+ if data_format not in ['int16', 'int32']:
+ raise ValueError("data_format should be 'auto', 'int16' or "
+ "'int32'. Got %s." % data_format)
+ n_bytes = 2 if data_format == 'int16' else 4
+ n_samples = data_size // (n_bytes * n_channels)
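
To make the auto-detection above concrete, a small worked example with
made-up header values: the data block spans from the end of the 900-byte
setup header plus the 75-byte-per-channel descriptors to the event table,
so the sample width falls out of a single division:

    # made-up header values, for illustration only
    n_channels = 32
    n_samples = 1000
    event_offset = 900 + 75 * n_channels + 128000

    data_size = event_offset - (900 + 75 * n_channels)  # bytes of sample data
    n_bytes = data_size // (n_samples * n_channels)     # -> 4, i.e. int32
    assert n_bytes in (2, 4)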
# Channel offset refers to the size of blocks per channel in the file.
cnt_info['channel_offset'] = np.fromfile(fid, dtype='<i4', count=1)[0]
if cnt_info['channel_offset'] > 1:
- cnt_info['channel_offset'] //= 2 # Data read as 2 byte ints.
+ cnt_info['channel_offset'] //= n_bytes
else:
cnt_info['channel_offset'] = 1
- n_samples = (event_offset - (900 + 75 * n_channels)) // (2 *
- n_channels)
ch_names, cals, baselines, chs, pos = (list(), list(), list(), list(),
list())
bads = list()
@@ -208,7 +245,8 @@ def _get_cnt_info(input_fname, eog, ecg, emg, misc):
event_id = np.fromfile(fid, dtype='u2', count=1)[0]
fid.seek(event_offset + 9 + i * event_bytes + 4)
offset = np.fromfile(fid, dtype='<i4', count=1)[0]
- event_time = (offset - 900 - 75 * n_channels) // (n_channels * 2)
+ event_time = (offset - 900 - 75 * n_channels) // (n_channels *
+ n_bytes)
stim_channel[event_time - 1] = event_id
info = _empty_info(sfreq)
@@ -245,7 +283,7 @@ def _get_cnt_info(input_fname, eog, ecg, emg, misc):
chs.append(chan_info)
baselines.append(0) # For stim channel
cnt_info.update(baselines=np.array(baselines), n_samples=n_samples,
- stim_channel=stim_channel)
+ stim_channel=stim_channel, n_bytes=n_bytes)
info.update(filename=input_fname, meas_date=np.array([meas_date, 0]),
description=str(session_label), buffer_size_sec=10., bads=bads,
subject_info=subject_info, chs=chs)
@@ -281,18 +319,25 @@ class RawCNT(_BaseRaw):
Names of channels or list of indices that should be designated
EOG channels. If 'auto', the channel names beginning with
``EOG`` are used. Defaults to empty tuple.
- misc : list or tuple
+ misc : list | tuple
Names of channels or list of indices that should be designated
MISC channels. Defaults to empty tuple.
- ecg : list or tuple
+ ecg : list | tuple
Names of channels or list of indices that should be designated
ECG channels. If 'auto', the channel names beginning with
``ECG`` are used. Defaults to empty tuple.
- emg : list or tuple
+ emg : list | tuple
Names of channels or list of indices that should be designated
EMG channels. If 'auto', the channel names beginning with
``EMG`` are used. Defaults to empty tuple.
- preload : bool or str (default False)
+ data_format : 'auto' | 'int16' | 'int32'
+        Defines the format in which the data is read. If 'auto', it is
+        determined from the file header using the ``numsamples`` field.
+        Defaults to 'auto'.
+ date_format : str
+        Format of the date in the header. Currently supports 'mm/dd/yy'
+        (default) and 'dd/mm/yy'.
+ preload : bool | str (default False)
Preload data into memory for data manipulation and faster indexing.
If True, the data will be preloaded into memory (fast, requires
large amount of memory). If preload is a string, preload is the
@@ -306,9 +351,11 @@ class RawCNT(_BaseRaw):
mne.io.Raw : Documentation of attribute and methods.
"""
def __init__(self, input_fname, montage, eog=(), misc=(), ecg=(), emg=(),
- preload=False, verbose=None):
+ data_format='auto', date_format='mm/dd/yy', preload=False,
+ verbose=None):
input_fname = path.abspath(input_fname)
- info, cnt_info = _get_cnt_info(input_fname, eog, ecg, emg, misc)
+ info, cnt_info = _get_cnt_info(input_fname, eog, ecg, emg, misc,
+ data_format, date_format)
last_samps = [cnt_info['n_samples'] - 1]
_check_update_montage(info, montage)
super(RawCNT, self).__init__(
@@ -323,7 +370,8 @@ class RawCNT(_BaseRaw):
channel_offset = self._raw_extras[0]['channel_offset']
baselines = self._raw_extras[0]['baselines']
stim_ch = self._raw_extras[0]['stim_channel']
- n_bytes = 2
+ n_bytes = self._raw_extras[0]['n_bytes']
+ dtype = '<i4' if n_bytes == 4 else '<i2'
sel = np.arange(n_channels + 1)[idx]
chunk_size = channel_offset * n_channels # Size of chunks in file.
# The data is divided into blocks of samples / channel.
@@ -354,7 +402,7 @@ class RawCNT(_BaseRaw):
extra_samps += chunk_size
count = n_samps // channel_offset * chunk_size + extra_samps
n_chunks = count // chunk_size
- samps = np.fromfile(fid, dtype='<i2', count=count)
+ samps = np.fromfile(fid, dtype=dtype, count=count)
samps = samps.reshape((n_chunks, n_channels, channel_offset),
order='C')
# Intermediate shaping to chunk sizes.
diff --git a/mne/io/cnt/tests/test_cnt.py b/mne/io/cnt/tests/test_cnt.py
index 7175a11..faf356b 100644
--- a/mne/io/cnt/tests/test_cnt.py
+++ b/mne/io/cnt/tests/test_cnt.py
@@ -8,7 +8,7 @@ import warnings
from nose.tools import assert_equal, assert_true
-import mne
+from mne import pick_types
from mne.utils import run_tests_if_main
from mne.datasets import testing
from mne.io.tests.test_raw import _test_raw_reader
@@ -26,8 +26,9 @@ def test_data():
with warnings.catch_warnings(record=True) as w:
raw = _test_raw_reader(read_raw_cnt, montage=None, input_fname=fname,
eog='auto', misc=['NA1', 'LEFT_EAR'])
- assert_true(all('meas date' in str(ww.message) for ww in w))
- eog_chs = mne.pick_types(raw.info, eog=True, exclude=[])
+ assert_true(all('meas date' in str(ww.message) or
+ 'number of bytes' in str(ww.message) for ww in w))
+ eog_chs = pick_types(raw.info, eog=True, exclude=[])
assert_equal(len(eog_chs), 2) # test eog='auto'
assert_equal(raw.info['bads'], ['LEFT_EAR', 'VEOGR']) # test bads
diff --git a/mne/io/compensator.py b/mne/io/compensator.py
index 8746c6e..70060d5 100644
--- a/mne/io/compensator.py
+++ b/mne/io/compensator.py
@@ -1,4 +1,5 @@
import numpy as np
+from scipy import linalg
from .constants import FIFF
@@ -29,11 +30,11 @@ def set_current_comp(info, comp):
chan['coil_type'] = int(rem + (comp << 16))
-def _make_compensator(info, kind):
+def _make_compensator(info, grade):
"""Auxiliary function for make_compensator
"""
for k in range(len(info['comps'])):
- if info['comps'][k]['kind'] == kind:
+ if info['comps'][k]['kind'] == grade:
this_data = info['comps'][k]['data']
# Create the preselector
@@ -61,8 +62,8 @@ def _make_compensator(info, kind):
this_comp = np.dot(postsel, np.dot(this_data['data'], presel))
return this_comp
- raise ValueError('Desired compensation matrix (kind = %d) not'
- ' found' % kind)
+ raise ValueError('Desired compensation matrix (grade = %d) not'
+ ' found' % grade)
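
For context, ``set_current_comp`` above stores the compensation grade in the
upper 16 bits of each channel's ``coil_type``; a minimal sketch of that
packing, assuming nothing beyond the bit layout shown in the diff:

    coil_type = 5001          # e.g. FIFFV_COIL_CTF_GRAD
    comp = 3                  # desired compensation grade

    packed = (coil_type & 0xFFFF) + (comp << 16)  # grade in the high bits
    assert packed >> 16 == comp                   # read the grade back
    assert packed & 0xFFFF == coil_type           # coil type unchanged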
def make_compensator(info, from_, to, exclude_comp_chs=False):
@@ -91,20 +92,26 @@ def make_compensator(info, from_, to, exclude_comp_chs=False):
if from_ == to:
return None
- if from_ == 0:
- C1 = np.zeros((info['nchan'], info['nchan']))
- else:
- C1 = _make_compensator(info, from_)
-
- if to == 0:
- C2 = np.zeros((info['nchan'], info['nchan']))
- else:
- C2 = _make_compensator(info, to)
-
# s_orig = s_from + C1*s_from = (I + C1)*s_from
# s_to = s_orig - C2*s_orig = (I - C2)*s_orig
# s_to = (I - C2)*(I + C1)*s_from = (I + C1 - C2 - C2*C1)*s_from
- comp = np.eye(info['nchan']) + C1 - C2 - np.dot(C2, C1)
+ if from_ != 0:
+ C1 = _make_compensator(info, from_)
+ comp_from_0 = linalg.inv(np.eye(info['nchan']) - C1)
+ if to != 0:
+ C2 = _make_compensator(info, to)
+ comp_0_to = np.eye(info['nchan']) - C2
+ if from_ != 0:
+ if to != 0:
+ # This is mathematically equivalent, but has higher numerical
+ # error than using the inverse to always go to zero and back
+ # comp = np.eye(info['nchan']) + C1 - C2 - np.dot(C2, C1)
+ comp = np.dot(comp_0_to, comp_from_0)
+ else:
+ comp = comp_from_0
+ else:
+ # from == 0, to != 0 guaranteed here
+ comp = comp_0_to
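
A quick numerical check of the equivalence claimed in the comment above:
when a compensation matrix only feeds reference channels into data channels
(so C @ C == 0), the exact inverse ``inv(I - C1)`` equals ``I + C1`` and the
two formulas agree. Toy matrices, for illustration only:

    import numpy as np
    from scipy import linalg

    n = 4
    eye = np.eye(n)
    # rows 0-1 (data channels) depend only on columns 2-3 (references),
    # so C @ C == 0 and inv(eye - C) == eye + C exactly
    C1 = np.zeros((n, n)); C1[:2, 2:] = 0.1
    C2 = np.zeros((n, n)); C2[:2, 2:] = 0.2

    comp_old = eye + C1 - C2 - np.dot(C2, C1)  # the commented-out formula
    comp_new = np.dot(eye - C2, linalg.inv(eye - C1))
    assert np.allclose(comp_old, comp_new)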
if exclude_comp_chs:
pick = [k for k, c in enumerate(info['chs'])
diff --git a/mne/io/constants.py b/mne/io/constants.py
index 64d1813..e72eb52 100644
--- a/mne/io/constants.py
+++ b/mne/io/constants.py
@@ -182,6 +182,7 @@ FIFF.FIFFV_IAS_CH = 910 # Internal Active Shielding data (maybe on Triux
FIFF.FIFFV_EXCI_CH = 920 # flux excitation channel used to be a stimulus channel
FIFF.FIFFV_DIPOLE_WAVE = 1000 # Dipole time curve (xplotter/xfit)
FIFF.FIFFV_GOODNESS_FIT = 1001 # Goodness of fit (xplotter/xfit)
+FIFF.FIFFV_FNIRS_CH = 1100 # Functional near-infrared spectroscopy
#
# Quaternion channels for head position monitoring
@@ -500,7 +501,7 @@ FIFF.FIFF_MNE_EXTERNAL_LITTLE_ENDIAN = 3553 # Reference to an external binar
# 3560... Miscellaneous
#
FIFF.FIFF_MNE_PROJ_ITEM_ACTIVE = 3560 # Is this projection item active?
-FIFF.FIFF_MNE_EVENT_LIST = 3561 # An event list (for STI 014)
+FIFF.FIFF_MNE_EVENT_LIST = 3561 # An event list (for STI101 / STI 014)
FIFF.FIFF_MNE_HEMI = 3562 # Hemisphere association for general purposes
FIFF.FIFF_MNE_DATA_SKIP_NOP = 3563 # A data skip turned off in the raw data
FIFF.FIFF_MNE_ORIG_CH_INFO = 3564 # Channel information before any changes
@@ -536,6 +537,10 @@ FIFF.FIFF_MNE_ICA_MATRIX = 3607 # ICA unmixing matrix
FIFF.FIFF_MNE_ICA_BADS = 3608 # ICA bad sources
FIFF.FIFF_MNE_ICA_MISC_PARAMS = 3609 # ICA misc params
#
+# Miscellaneous
+#
+FIFF.FIFF_MNE_KIT_SYSTEM_ID = 3612 # Unique ID assigned to KIT systems
+#
# Maxfilter tags
#
FIFF.FIFF_SSS_FRAME = 263
@@ -634,30 +639,7 @@ FIFF.FIFFV_MNE_COORD_FS_TAL = 2006 # FreeSurfer Talairach coordinat
#
FIFF.FIFFV_MNE_COORD_4D_HEAD = FIFF.FIFFV_MNE_COORD_CTF_HEAD
FIFF.FIFFV_MNE_COORD_KIT_HEAD = FIFF.FIFFV_MNE_COORD_CTF_HEAD
-#
-# KIT system coil types
-#
-FIFF.FIFFV_COIL_KIT_GRAD = 6001
-FIFF.FIFFV_COIL_KIT_REF_MAG = 6002
-#
-# CTF coil and channel types
-#
-FIFF.FIFFV_COIL_CTF_GRAD = 5001
-FIFF.FIFFV_COIL_CTF_REF_MAG = 5002
-FIFF.FIFFV_COIL_CTF_REF_GRAD = 5003
-FIFF.FIFFV_COIL_CTF_OFFDIAG_REF_GRAD = 5004
-#
-# Magnes reference sensors
-#
-FIFF.FIFFV_COIL_MAGNES_REF_MAG = 4003
-FIFF.FIFFV_COIL_MAGNES_REF_GRAD = 4004
-FIFF.FIFFV_COIL_MAGNES_OFFDIAG_REF_GRAD = 4005
-#
-# BabySQUID sensors
-#
-FIFF.FIFFV_COIL_BABY_GRAD = 7001
-FIFF.FIFFV_COIL_BABY_MAG = 7002
-FIFF.FIFFV_COIL_BABY_REF_MAG = 7003
+
#
# FWD Types
#
@@ -719,40 +701,40 @@ FIFF.FIFF_UNIT_NONE = -1
#
# SI base units
#
-FIFF.FIFF_UNIT_M = 1
-FIFF.FIFF_UNIT_KG = 2
-FIFF.FIFF_UNIT_SEC = 3
-FIFF.FIFF_UNIT_A = 4
-FIFF.FIFF_UNIT_K = 5
-FIFF.FIFF_UNIT_MOL = 6
+FIFF.FIFF_UNIT_M = 1 # meter
+FIFF.FIFF_UNIT_KG = 2 # kilogram
+FIFF.FIFF_UNIT_SEC = 3 # second
+FIFF.FIFF_UNIT_A = 4 # ampere
+FIFF.FIFF_UNIT_K = 5 # kelvin
+FIFF.FIFF_UNIT_MOL = 6 # mole
#
# SI Supplementary units
#
-FIFF.FIFF_UNIT_RAD = 7
-FIFF.FIFF_UNIT_SR = 8
+FIFF.FIFF_UNIT_RAD = 7 # radian
+FIFF.FIFF_UNIT_SR = 8 # steradian
#
# SI base candela
#
-FIFF.FIFF_UNIT_CD = 9
+FIFF.FIFF_UNIT_CD = 9 # candela
#
# SI derived units
#
-FIFF.FIFF_UNIT_HZ = 101
-FIFF.FIFF_UNIT_N = 102
-FIFF.FIFF_UNIT_PA = 103
-FIFF.FIFF_UNIT_J = 104
-FIFF.FIFF_UNIT_W = 105
-FIFF.FIFF_UNIT_C = 106
-FIFF.FIFF_UNIT_V = 107
-FIFF.FIFF_UNIT_F = 108
-FIFF.FIFF_UNIT_OHM = 109
-FIFF.FIFF_UNIT_MHO = 110
-FIFF.FIFF_UNIT_WB = 111
-FIFF.FIFF_UNIT_T = 112
-FIFF.FIFF_UNIT_H = 113
-FIFF.FIFF_UNIT_CEL = 114
-FIFF.FIFF_UNIT_LM = 115
-FIFF.FIFF_UNIT_LX = 116
+FIFF.FIFF_UNIT_HZ = 101 # hertz
+FIFF.FIFF_UNIT_N = 102 # newton
+FIFF.FIFF_UNIT_PA = 103 # pascal
+FIFF.FIFF_UNIT_J = 104 # joule
+FIFF.FIFF_UNIT_W = 105 # watt
+FIFF.FIFF_UNIT_C = 106 # coulomb
+FIFF.FIFF_UNIT_V = 107 # volt
+FIFF.FIFF_UNIT_F = 108 # farad
+FIFF.FIFF_UNIT_OHM = 109 # ohm
+FIFF.FIFF_UNIT_MHO = 110 # siemens (one per ohm)
+FIFF.FIFF_UNIT_WB = 111 # weber
+FIFF.FIFF_UNIT_T = 112 # tesla
+FIFF.FIFF_UNIT_H = 113 # henry
+FIFF.FIFF_UNIT_CEL = 114 # celsius
+FIFF.FIFF_UNIT_LM = 115 # lumen
+FIFF.FIFF_UNIT_LX = 116 # lux
#
# Others we need
#
@@ -793,6 +775,9 @@ FIFF.FIFFV_COIL_EEG_BIPOLAR = 5 # Bipolar EEG lead
FIFF.FIFFV_COIL_DIPOLE = 200 # Time-varying dipole definition
# The coil info contains dipole location (r0) and
# direction (ex)
+FIFF.FIFFV_COIL_FNIRS_HBO = 300 # fNIRS oxyhemoglobin
+FIFF.FIFFV_COIL_FNIRS_HBR = 301 # fNIRS deoxyhemoglobin
+
FIFF.FIFFV_COIL_MCG_42 = 1000 # For testing the MCG software
FIFF.FIFFV_COIL_POINT_MAGNETOMETER = 2000 # Simple point magnetometer
@@ -814,6 +799,30 @@ FIFF.FIFFV_COIL_MAGNES_GRAD = 4002 # Magnes WH gradiometer
FIFF.FIFFV_COIL_MAGNES_R_MAG = 4003 # Magnes WH reference magnetometer
FIFF.FIFFV_COIL_MAGNES_R_GRAD_DIA = 4004 # Magnes WH reference diagonal gradiometer
FIFF.FIFFV_COIL_MAGNES_R_GRAD_OFF = 4005 # Magnes WH reference off-diagonal gradiometer
+#
+# Magnes reference sensors
+#
+FIFF.FIFFV_COIL_MAGNES_REF_MAG = 4003
+FIFF.FIFFV_COIL_MAGNES_REF_GRAD = 4004
+FIFF.FIFFV_COIL_MAGNES_OFFDIAG_REF_GRAD = 4005
+#
+# CTF coil and channel types
+#
+FIFF.FIFFV_COIL_CTF_GRAD = 5001
+FIFF.FIFFV_COIL_CTF_REF_MAG = 5002
+FIFF.FIFFV_COIL_CTF_REF_GRAD = 5003
+FIFF.FIFFV_COIL_CTF_OFFDIAG_REF_GRAD = 5004
+#
+# KIT system coil types
+#
+FIFF.FIFFV_COIL_KIT_GRAD = 6001
+FIFF.FIFFV_COIL_KIT_REF_MAG = 6002
+#
+# BabySQUID sensors
+#
+FIFF.FIFFV_COIL_BABY_GRAD = 7001
+FIFF.FIFFV_COIL_BABY_MAG = 7002
+FIFF.FIFFV_COIL_BABY_REF_MAG = 7003
# MNE RealTime
FIFF.FIFF_MNE_RT_COMMAND = 3700 # realtime command
diff --git a/mne/io/ctf/info.py b/mne/io/ctf/info.py
index 36def60..da32be7 100644
--- a/mne/io/ctf/info.py
+++ b/mne/io/ctf/info.py
@@ -120,10 +120,26 @@ def _at_origin(x):
return (np.sum(x * x) < 1e-8)
+def _check_comp_ch(cch, kind, desired=None):
+ if 'reference' in kind.lower():
+ if cch['grad_order_no'] != 0:
+ raise RuntimeError('%s channel with non-zero compensation grade %s'
+ % (kind, cch['grad_order_no']))
+ else:
+ if desired is None:
+ desired = cch['grad_order_no']
+ if cch['grad_order_no'] != desired:
+ raise RuntimeError('%s channel with inconsistent compensation '
+ 'grade %s, should be %s'
+ % (kind, cch['grad_order_no'], desired))
+ return desired
+
+
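A minimal sketch of how the helper above enforces a consistent grade across
gradiometer channels (the plain dicts stand in for res4 channel records):

    first = {'grad_order_no': 3}
    second = {'grad_order_no': 3}
    bad = {'grad_order_no': 1}

    desired = _check_comp_ch(first, 'Gradiometer')            # returns 3
    desired = _check_comp_ch(second, 'Gradiometer', desired)  # still 3
    # _check_comp_ch(bad, 'Gradiometer', desired)  # would raise RuntimeError
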
def _convert_channel_info(res4, t, use_eeg_pos):
"""Convert CTF channel information to fif format"""
nmeg = neeg = nstim = nmisc = nref = 0
chs = list()
+ this_comp = None
for k, cch in enumerate(res4['chs']):
cal = float(1. / (cch['proper_gain'] * cch['qgain']))
ch = dict(scanno=k + 1, range=1., cal=cal, loc=np.zeros(12),
@@ -189,18 +205,22 @@ def _convert_channel_info(res4, t, use_eeg_pos):
# Set the coil type
if cch['sensor_type_index'] == CTF.CTFV_REF_MAG_CH:
ch['kind'] = FIFF.FIFFV_REF_MEG_CH
+ _check_comp_ch(cch, 'Reference magnetometer')
ch['coil_type'] = FIFF.FIFFV_COIL_CTF_REF_MAG
nref += 1
ch['logno'] = nref
elif cch['sensor_type_index'] == CTF.CTFV_REF_GRAD_CH:
ch['kind'] = FIFF.FIFFV_REF_MEG_CH
if off_diag:
+ _check_comp_ch(cch, 'Reference off-diagonal gradiometer')
ch['coil_type'] = FIFF.FIFFV_COIL_CTF_OFFDIAG_REF_GRAD
else:
+ _check_comp_ch(cch, 'Reference gradiometer')
ch['coil_type'] = FIFF.FIFFV_COIL_CTF_REF_GRAD
nref += 1
ch['logno'] = nref
else:
+ this_comp = _check_comp_ch(cch, 'Gradiometer', this_comp)
ch['kind'] = FIFF.FIFFV_MEG_CH
ch['coil_type'] = FIFF.FIFFV_COIL_CTF_GRAD
nmeg += 1
diff --git a/mne/io/ctf/tests/test_ctf.py b/mne/io/ctf/tests/test_ctf.py
index 6edd6f3..7460dfe 100644
--- a/mne/io/ctf/tests/test_ctf.py
+++ b/mne/io/ctf/tests/test_ctf.py
@@ -191,6 +191,7 @@ def test_read_ctf():
assert_true(all('MISC channel' in str(ww.message) for ww in w))
assert_allclose(raw[:][0], raw_c[:][0])
raw.plot(show=False) # Test plotting with ref_meg channels.
+ assert_raises(ValueError, raw.plot, order='selection')
assert_raises(TypeError, read_raw_ctf, 1)
assert_raises(ValueError, read_raw_ctf, ctf_fname_continuous + 'foo.ds')
# test ignoring of system clock
diff --git a/mne/io/edf/edf.py b/mne/io/edf/edf.py
index d1109cd..67215dc 100644
--- a/mne/io/edf/edf.py
+++ b/mne/io/edf/edf.py
@@ -204,8 +204,8 @@ class RawEDF(_BaseRaw):
# make sure events without duration get one sample
n_stop = n_stop if n_stop > n_start else n_start + 1
if any(stim[n_start:n_stop]):
- raise NotImplementedError('EDF+ with overlapping '
- 'events not supported.')
+                        warn('Overlapping events in EDF+ are '
+                             'not fully supported')
stim[n_start:n_stop] = evid
data[stim_channel_idx, :] = stim[start:stop]
else:
@@ -255,10 +255,13 @@ def _parse_tal_channel(tal_channel_data):
tals = bytearray()
for s in tal_channel_data:
i = int(s)
- tals.extend([i % 256, i // 256])
+ tals.extend(np.uint8([i % 256, i // 256]))
regex_tal = '([+-]\d+\.?\d*)(\x15(\d+\.?\d*))?(\x14.*?)\x14\x00'
- tal_list = re.findall(regex_tal, tals.decode('ascii'))
+    # latin-1 is used because it maps every byte value 0-255 to a code
+    # point, whereas decoding as utf-8 can trigger an "invalid
+    # continuation byte" error
+ tal_list = re.findall(regex_tal, tals.decode('latin-1'))
+
events = []
for ev in tal_list:
onset = float(ev[0])
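
The byte packing above can be sketched in isolation: each 16-bit sample of
the annotation channel is split back into its low and high bytes before
decoding, and latin-1 accepts any byte value (sample values are made up):

    import numpy as np

    samples = [0x3431, 0x0014]  # made-up 16-bit TAL samples
    tals = bytearray()
    for i in samples:
        tals.extend(np.uint8([i % 256, i // 256]))  # low byte, then high byte
    text = tals.decode('latin-1')  # never fails, unlike 'ascii' or 'utf-8'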
@@ -285,10 +288,8 @@ def _get_edf_info(fname, stim_channel, annot, annotmap, eog, misc, preload):
with open(fname, 'rb') as fid:
assert(fid.tell() == 0)
- fid.seek(8)
+        fid.seek(168)  # skip 8-byte version, 80-byte subject id, 80-byte rec id
- fid.read(80).strip().decode() # subject id
- fid.read(80).strip().decode() # recording id
day, month, year = [int(x) for x in re.findall('(\d+)',
fid.read(8).decode())]
hour, minute, sec = [int(x) for x in re.findall('(\d+)',
@@ -457,7 +458,7 @@ def _get_edf_info(fname, stim_channel, annot, annotmap, eog, misc, preload):
# Some keys to be consistent with FIF measurement info
info['description'] = None
- info['buffer_size_sec'] = 10.
+ info['buffer_size_sec'] = 1.
edf_info['nsamples'] = int(n_records * max_samp)
# These are the conditions under which a stim channel will be interpolated
diff --git a/mne/io/edf/tests/test_edf.py b/mne/io/edf/tests/test_edf.py
index 5029476..621a502 100644
--- a/mne/io/edf/tests/test_edf.py
+++ b/mne/io/edf/tests/test_edf.py
@@ -21,10 +21,10 @@ import numpy as np
from mne import pick_types
from mne.datasets import testing
from mne.externals.six import iterbytes
-from mne.utils import _TempDir, run_tests_if_main, requires_pandas
-from mne.io import read_raw_edf, Raw
+from mne.utils import run_tests_if_main, requires_pandas
+from mne.io import read_raw_edf
from mne.io.tests.test_raw import _test_raw_reader
-import mne.io.edf.edf as edfmodule
+from mne.io.edf.edf import _parse_tal_channel
from mne.event import find_events
warnings.simplefilter('always')
@@ -43,6 +43,8 @@ edf_txt_stim_channel_path = op.join(data_dir, 'test_edf_stim_channel.txt')
data_path = testing.data_path(download=False)
edf_stim_resamp_path = op.join(data_path, 'EDF', 'test_edf_stim_resamp.edf')
+edf_overlap_annot_path = op.join(data_path, 'EDF',
+ 'test_edf_overlapping_annotations.edf')
eog = ['REOG', 'LEOG', 'IEOG']
@@ -50,7 +52,7 @@ misc = ['EXG1', 'EXG5', 'EXG8', 'M1', 'M2']
def test_bdf_data():
- """Test reading raw bdf files"""
+ """Test reading raw bdf files."""
raw_py = _test_raw_reader(read_raw_edf, input_fname=bdf_path,
montage=montage_path, eog=eog, misc=misc)
assert_true('RawEDF' in repr(raw_py))
@@ -70,16 +72,21 @@ def test_bdf_data():
assert_true((raw_py.info['chs'][63]['loc']).any())
+ at testing.requires_testing_data
+def test_edf_overlapping_annotations():
+ """Test EDF with overlapping annotations."""
+ n_warning = 2
+ with warnings.catch_warnings(record=True) as w:
+ read_raw_edf(edf_overlap_annot_path, preload=True, verbose=True)
+ assert_equal(sum('overlapping' in str(ww.message) for ww in w),
+ n_warning)
+
+
def test_edf_data():
- """Test edf files"""
+ """Test edf files."""
_test_raw_reader(read_raw_edf, input_fname=edf_path, stim_channel=None)
raw_py = read_raw_edf(edf_path, preload=True)
# Test saving and loading when annotations were parsed.
- tempdir = _TempDir()
- raw_file = op.join(tempdir, 'test-raw.fif')
- raw_py.save(raw_file, overwrite=True, buffer_size_sec=1)
- Raw(raw_file, preload=True)
-
edf_events = find_events(raw_py, output='step', shortest_event=0,
stim_channel='STI 014')
@@ -107,7 +114,7 @@ def test_edf_data():
@testing.requires_testing_data
def test_stim_channel():
- """Test reading raw edf files with stim channel"""
+ """Test reading raw edf files with stim channel."""
raw_py = read_raw_edf(edf_path, misc=range(-4, 0), stim_channel=139,
preload=True)
@@ -152,8 +159,7 @@ def test_stim_channel():
def test_parse_annotation():
- """Test parsing the tal channel"""
-
+ """Test parsing the tal channel."""
# test the parser
annot = (b'+180\x14Lights off\x14Close door\x14\x00\x00\x00\x00\x00'
b'+180\x14Lights off\x14\x00\x00\x00\x00\x00\x00\x00\x00'
@@ -164,18 +170,14 @@ def test_parse_annotation():
annot = [a for a in iterbytes(annot)]
annot[1::2] = [a * 256 for a in annot[1::2]]
tal_channel = map(sum, zip(annot[0::2], annot[1::2]))
- events = edfmodule._parse_tal_channel(tal_channel)
- assert_equal(events, [[180.0, 0, 'Lights off'],
- [180.0, 0, 'Close door'],
- [180.0, 0, 'Lights off'],
- [180.0, 0, 'Close door'],
- [3.14, 4.2, 'nothing'],
- [1800.2, 25.5, 'Apnea']])
+ assert_equal(_parse_tal_channel(tal_channel),
+ [[180.0, 0, 'Lights off'], [180.0, 0, 'Close door'],
+ [180.0, 0, 'Lights off'], [180.0, 0, 'Close door'],
+ [3.14, 4.2, 'nothing'], [1800.2, 25.5, 'Apnea']])
def test_edf_annotations():
"""Test if events are detected correctly in a typical MNE workflow."""
-
# test an actual file
raw = read_raw_edf(edf_path, preload=True)
edf_events = find_events(raw, output='step', shortest_event=0,
@@ -204,7 +206,7 @@ def test_edf_annotations():
def test_edf_stim_channel():
- """Test stim channel for edf file"""
+ """Test stim channel for edf file."""
raw = read_raw_edf(edf_stim_channel_path, preload=True,
stim_channel=-1)
true_data = np.loadtxt(edf_txt_stim_channel_path).T
@@ -222,7 +224,7 @@ def test_edf_stim_channel():
@requires_pandas
def test_to_data_frame():
- """Test edf Raw Pandas exporter"""
+ """Test edf Raw Pandas exporter."""
for path in [edf_path, bdf_path]:
raw = read_raw_edf(path, stim_channel=None, preload=True)
_, times = raw[0, :10]
diff --git a/mne/io/eeglab/eeglab.py b/mne/io/eeglab/eeglab.py
index f53e6ae..ef45c71 100644
--- a/mne/io/eeglab/eeglab.py
+++ b/mne/io/eeglab/eeglab.py
@@ -64,6 +64,9 @@ def _get_info(eeg, montage, eog=()):
# add the ch_names and info['chs'][idx]['loc']
path = None
+ if not isinstance(eeg.chanlocs, np.ndarray) and eeg.nbchan == 1:
+ eeg.chanlocs = [eeg.chanlocs]
+
if len(eeg.chanlocs) > 0:
ch_names, pos = list(), list()
kind = 'user_defined'
@@ -107,7 +110,7 @@ def _get_info(eeg, montage, eog=()):
def read_raw_eeglab(input_fname, montage=None, eog=(), event_id=None,
event_id_func='strip_to_integer', preload=False,
- verbose=None):
+ verbose=None, uint16_codec=None):
"""Read an EEGLAB .set file
Parameters
@@ -149,8 +152,14 @@ def read_raw_eeglab(input_fname, montage=None, eog=(), event_id=None,
on the hard drive (slower, requires less memory). Note that
preload=False will be effective only if the data is stored in a
separate binary file.
- verbose : bool, str, int, or None
+ verbose : bool | str | int | None
If not None, override default verbose level (see mne.verbose).
+ uint16_codec : str | None
+        If your \*.set file contains non-ascii characters, reading it may
+        fail with an error message stating that the "buffer is too small".
+        ``uint16_codec`` allows you to specify which codec (for example
+        'latin1' or 'utf-8') should be used when reading character arrays
+        and can therefore help you solve this problem.
Returns
-------
@@ -167,11 +176,11 @@ def read_raw_eeglab(input_fname, montage=None, eog=(), event_id=None,
"""
return RawEEGLAB(input_fname=input_fname, montage=montage, preload=preload,
eog=eog, event_id=event_id, event_id_func=event_id_func,
- verbose=verbose)
+ verbose=verbose, uint16_codec=uint16_codec)
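
A hedged usage sketch of the new parameter (the file name is hypothetical):
if loading a .set file fails with a "buffer is too small" error, pass the
codec explicitly:

    from mne.io import read_raw_eeglab

    raw = read_raw_eeglab('subject01.set', montage=None,
                          uint16_codec='latin1', preload=True)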
def read_epochs_eeglab(input_fname, events=None, event_id=None, montage=None,
- eog=(), verbose=None):
+ eog=(), verbose=None, uint16_codec=None):
"""Reader function for EEGLAB epochs files
Parameters
@@ -204,8 +213,14 @@ def read_epochs_eeglab(input_fname, events=None, event_id=None, montage=None,
Names or indices of channels that should be designated EOG channels.
If 'auto', the channel names containing ``EOG`` or ``EYE`` are used.
Defaults to empty tuple.
- verbose : bool, str, int, or None
+ verbose : bool | str | int | None
If not None, override default verbose level (see mne.verbose).
+ uint16_codec : str | None
+        If your \*.set file contains non-ascii characters, reading it may
+        fail with an error message stating that the "buffer is too small".
+        ``uint16_codec`` allows you to specify which codec (for example
+        'latin1' or 'utf-8') should be used when reading character arrays
+        and can therefore help you solve this problem.
Returns
-------
@@ -222,7 +237,8 @@ def read_epochs_eeglab(input_fname, events=None, event_id=None, montage=None,
mne.Epochs : Documentation of attribute and methods.
"""
epochs = EpochsEEGLAB(input_fname=input_fname, events=events, eog=eog,
- event_id=event_id, montage=montage, verbose=verbose)
+ event_id=event_id, montage=montage, verbose=verbose,
+ uint16_codec=uint16_codec)
return epochs
@@ -266,8 +282,14 @@ class RawEEGLAB(_BaseRaw):
amount of memory). If preload is a string, preload is the file name of
a memory-mapped file which is used to store the data on the hard
drive (slower, requires less memory).
- verbose : bool, str, int, or None
+ verbose : bool | str | int | None
If not None, override default verbose level (see mne.verbose).
+ uint16_codec : str | None
+        If your \*.set file contains non-ascii characters, reading it may
+        fail with an error message stating that the "buffer is too small".
+        ``uint16_codec`` allows you to specify which codec (for example
+        'latin1' or 'utf-8') should be used when reading character arrays
+        and can therefore help you solve this problem.
Returns
-------
@@ -285,14 +307,14 @@ class RawEEGLAB(_BaseRaw):
@verbose
def __init__(self, input_fname, montage, eog=(), event_id=None,
event_id_func='strip_to_integer', preload=False,
- verbose=None):
+ verbose=None, uint16_codec=None):
"""Read EEGLAB .set file.
"""
from scipy import io
basedir = op.dirname(input_fname)
_check_mat_struct(input_fname)
eeg = io.loadmat(input_fname, struct_as_record=False,
- squeeze_me=True)['EEG']
+ squeeze_me=True, uint16_codec=uint16_codec)['EEG']
if eeg.trials != 1:
raise TypeError('The number of trials is %d. It must be 1 for raw'
' files. Please use `mne.io.read_epochs_eeglab` if'
@@ -329,7 +351,10 @@ class RawEEGLAB(_BaseRaw):
'the .set file')
# can't be done in standard way with preload=True because of
# different reading path (.set file)
- n_chan, n_times = eeg.data.shape
+ if eeg.nbchan == 1 and len(eeg.data.shape) == 1:
+ n_chan, n_times = [1, eeg.data.shape[0]]
+ else:
+ n_chan, n_times = eeg.data.shape
data = np.empty((n_chan + 1, n_times), dtype=np.double)
data[:-1] = eeg.data
data *= CAL
@@ -417,8 +442,14 @@ class EpochsEEGLAB(_BaseEpochs):
Names or indices of channels that should be designated EOG channels.
If 'auto', the channel names containing ``EOG`` or ``EYE`` are used.
Defaults to empty tuple.
- verbose : bool, str, int, or None
+ verbose : bool | str | int | None
If not None, override default verbose level (see mne.verbose).
+ uint16_codec : str | None
+        If your \*.set file contains non-ascii characters, reading it may
+        fail with an error message stating that the "buffer is too small".
+        ``uint16_codec`` allows you to specify which codec (for example
+        'latin1' or 'utf-8') should be used when reading character arrays
+        and can therefore help you solve this problem.
Notes
-----
@@ -431,11 +462,12 @@ class EpochsEEGLAB(_BaseEpochs):
@verbose
def __init__(self, input_fname, events=None, event_id=None, tmin=0,
baseline=None, reject=None, flat=None, reject_tmin=None,
- reject_tmax=None, montage=None, eog=(), verbose=None):
+ reject_tmax=None, montage=None, eog=(), verbose=None,
+ uint16_codec=None):
from scipy import io
_check_mat_struct(input_fname)
eeg = io.loadmat(input_fname, struct_as_record=False,
- squeeze_me=True)['EEG']
+ squeeze_me=True, uint16_codec=uint16_codec)['EEG']
if not ((events is None and event_id is None) or
(events is not None and event_id is not None)):
@@ -446,6 +478,7 @@ class EpochsEEGLAB(_BaseEpochs):
# first extract the events and construct an event_id dict
event_name, event_latencies, unique_ev = list(), list(), list()
ev_idx = 0
+ warn_multiple_events = False
for ep in eeg.epoch:
if not isinstance(ep.eventtype, string_types):
event_type = '/'.join(ep.eventtype.tolist())
@@ -453,8 +486,7 @@ class EpochsEEGLAB(_BaseEpochs):
# store latency of only first event
event_latencies.append(eeg.event[ev_idx].latency)
ev_idx += len(ep.eventtype)
- warn('An epoch has multiple events. Only the latency of '
- 'the first event will be retained.')
+ warn_multiple_events = True
else:
event_type = ep.eventtype
event_name.append(ep.eventtype)
@@ -467,6 +499,12 @@ class EpochsEEGLAB(_BaseEpochs):
# invent event dict but use id > 0 so you know its a trigger
event_id = dict((ev, idx + 1) for idx, ev
in enumerate(unique_ev))
+
+ # warn about multiple events in epoch if necessary
+ if warn_multiple_events:
+ warn('At least one epoch has multiple events. Only the latency'
+ ' of the first event will be retained.')
+
# now fill up the event array
events = np.zeros((eeg.trials, 3), dtype=int)
for idx in range(0, eeg.trials):
@@ -501,6 +539,9 @@ class EpochsEEGLAB(_BaseEpochs):
order="F")
else:
data = eeg.data
+
+ if eeg.nbchan == 1 and len(data.shape) == 2:
+ data = data[np.newaxis, :]
data = data.transpose((2, 0, 1)).astype('double')
data *= CAL
assert data.shape == (eeg.trials, eeg.nbchan, eeg.pnts)
@@ -526,17 +567,21 @@ def _read_eeglab_events(eeg, event_id=None, event_id_func='strip_to_integer'):
event_id = dict()
if isinstance(eeg.event, np.ndarray):
- types = [event.type for event in eeg.event]
+ types = [str(event.type) for event in eeg.event]
latencies = [event.latency for event in eeg.event]
else:
# only one event - TypeError: 'mat_struct' object is not iterable
- types = [eeg.event.type]
+ types = [str(eeg.event.type)]
latencies = [eeg.event.latency]
if "boundary" in types and "boundary" not in event_id:
warn("The data contains 'boundary' events, indicating data "
"discontinuities. Be cautious of filtering and epoching around "
"these events.")
+ if len(types) < 1: # if there are 0 events, we can exit here
+ logger.info('No events found, returning empty stim channel ...')
+ return np.zeros((0, 3))
+
not_in_event_id = set(x for x in types if x not in event_id)
not_purely_numeric = set(x for x in not_in_event_id if not x.isdigit())
no_numbers = set([x for x in not_purely_numeric
@@ -566,14 +611,15 @@ def _read_eeglab_events(eeg, event_id=None, event_id_func='strip_to_integer'):
pass # We're already raising warnings above, so we just drop
if len(events) < len(types):
- warn("Some event codes could not be mapped to integers. Use the "
- "`event_id` parameter to map such events to integers manually.")
- if len(events) < 1:
- warn("No events found, consider adding an `event_id`. As is, the "
- "trigger channel will consist entirely of zeros.")
- return np.zeros((0, 3))
- else:
- return np.asarray(events)
+ missings = len(types) - len(events)
+ msg = ("{0}/{1} event codes could not be mapped to integers. Use "
+ "the 'event_id' parameter to map such events manually.")
+ warn(msg.format(missings, len(types)))
+ if len(events) < 1:
+ warn("As is, the trigger channel will consist entirely of zeros.")
+ return np.zeros((0, 3))
+
+ return np.asarray(events)
def _strip_to_integer(trigger):
diff --git a/mne/io/eeglab/tests/test_eeglab.py b/mne/io/eeglab/tests/test_eeglab.py
index 6ca046a..94555ce 100644
--- a/mne/io/eeglab/tests/test_eeglab.py
+++ b/mne/io/eeglab/tests/test_eeglab.py
@@ -7,11 +7,13 @@ import shutil
import warnings
from nose.tools import assert_raises, assert_equal
+import numpy as np
from numpy.testing import assert_array_equal
from mne import write_events, read_epochs_eeglab, Epochs, find_events
from mne.io import read_raw_eeglab
from mne.io.tests.test_raw import _test_raw_reader
+from mne.io.eeglab.eeglab import _read_eeglab_events
from mne.datasets import testing
from mne.utils import _TempDir, run_tests_if_main, requires_version
@@ -52,23 +54,38 @@ def test_io_set():
raw3 = read_raw_eeglab(input_fname=raw_fname, montage=montage,
event_id=event_id)
raw4 = read_raw_eeglab(input_fname=raw_fname, montage=montage)
- Epochs(raw0, find_events(raw0), event_id)
- epochs = Epochs(raw1, find_events(raw1), event_id)
+ Epochs(raw0, find_events(raw0), event_id, add_eeg_ref=False)
+ epochs = Epochs(raw1, find_events(raw1), event_id, add_eeg_ref=False)
assert_equal(len(find_events(raw4)), 0) # no events without event_id
assert_equal(epochs["square"].average().nave, 80) # 80 with
assert_array_equal(raw0[:][0], raw1[:][0], raw2[:][0], raw3[:][0])
assert_array_equal(raw0[:][-1], raw1[:][-1], raw2[:][-1], raw3[:][-1])
assert_equal(len(w), 4)
# 1 for preload=False / str with fname_onefile, 3 for dropped events
- raw0.filter(1, None) # test that preloading works
+ raw0.filter(1, None, l_trans_bandwidth='auto', filter_length='auto',
+ phase='zero') # test that preloading works
+
+    # test that using uint16_codec does not break anything
+ raw0 = read_raw_eeglab(input_fname=raw_fname, montage=montage,
+ event_id=event_id, preload=False,
+ uint16_codec='ascii')
+
+ # test old EEGLAB version event import
+ eeg = io.loadmat(raw_fname, struct_as_record=False,
+ squeeze_me=True)['EEG']
+ for event in eeg.event: # old version allows integer events
+ event.type = 1
+ assert_equal(_read_eeglab_events(eeg)[-1, -1], 1)
+ eeg.event = eeg.event[0] # single event
+ assert_equal(_read_eeglab_events(eeg)[-1, -1], 1)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
epochs = read_epochs_eeglab(epochs_fname)
epochs2 = read_epochs_eeglab(epochs_fname_onefile)
- # 3 warnings for each read_epochs_eeglab because there are 3 epochs
+ # one warning for each read_epochs_eeglab because both files have epochs
# associated with multiple events
- assert_equal(len(w), 6)
+ assert_equal(len(w), 2)
assert_array_equal(epochs.get_data(), epochs2.get_data())
# test different combinations of events and event_ids
@@ -78,6 +95,7 @@ def test_io_set():
event_id = {'S255/S8': 1, 'S8': 2, 'S255/S9': 3}
epochs = read_epochs_eeglab(epochs_fname, epochs.events, event_id)
+ assert_equal(len(epochs.events), 4)
epochs = read_epochs_eeglab(epochs_fname, out_fname, event_id)
assert_raises(ValueError, read_epochs_eeglab, epochs_fname,
None, event_id)
@@ -99,6 +117,21 @@ def test_io_set():
read_raw_eeglab(input_fname=one_event_fname, montage=montage,
event_id=event_id, preload=True)
+ # test reading file with one channel
+ one_chan_fname = op.join(temp_dir, 'test_one_channel.set')
+ io.savemat(one_chan_fname, {'EEG':
+ {'trials': eeg.trials, 'srate': eeg.srate,
+ 'nbchan': 1, 'data': np.random.random((1, 3)),
+ 'epoch': eeg.epoch, 'event': eeg.epoch,
+ 'chanlocs': {'labels': 'E1', 'Y': -6.6069,
+ 'X': 6.3023, 'Z': -2.9423},
+ 'times': eeg.times[:3], 'pnts': 3}})
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter('always')
+ read_raw_eeglab(input_fname=one_chan_fname, preload=True)
+ # no warning for 'no events found'
+ assert_equal(len(w), 0)
+
# test if .dat file raises an error
eeg = io.loadmat(epochs_fname, struct_as_record=False,
squeeze_me=True)['EEG']
@@ -115,6 +148,6 @@ def test_io_set():
warnings.simplefilter('always')
assert_raises(NotImplementedError, read_epochs_eeglab,
bad_epochs_fname)
- assert_equal(len(w), 3)
+ assert_equal(len(w), 1)
run_tests_if_main()
diff --git a/mne/io/fiff/raw.py b/mne/io/fiff/raw.py
index 64f588e..f917c6f 100644
--- a/mne/io/fiff/raw.py
+++ b/mne/io/fiff/raw.py
@@ -18,12 +18,11 @@ from ..meas_info import read_meas_info
from ..tree import dir_tree_find
from ..tag import read_tag, read_tag_info
from ..proj import make_eeg_average_ref_proj, _needs_eeg_average_ref_proj
-from ..compensator import get_current_comp, set_current_comp, make_compensator
-from ..base import _BaseRaw, _RawShell, _check_raw_compatibility
+from ..base import (_BaseRaw, _RawShell, _check_raw_compatibility,
+ _check_maxshield)
from ..utils import _mult_cal_one
from ...annotations import Annotations, _combine_annotations
-from ...externals.six import string_types
from ...utils import check_fname, logger, verbose, warn
@@ -32,17 +31,17 @@ class Raw(_BaseRaw):
Parameters
----------
- fnames : list, or string
- A list of the raw files to treat as a Raw instance, or a single
- raw file. For files that have automatically been split, only the
- name of the first file has to be specified. Filenames should end
+ fname : str
+        The raw file to load. For files that have been split automatically,
+        the remaining parts are loaded as well. Filenames should end
with raw.fif, raw.fif.gz, raw_sss.fif, raw_sss.fif.gz,
raw_tsss.fif or raw_tsss.fif.gz.
allow_maxshield : bool | str (default False)
- allow_maxshield if True, allow loading of data that has been
- processed with Maxshield. Maxshield-processed data should generally
- not be loaded directly, but should be processed using SSS first.
- Can also be "yes" to load without eliciting a warning.
+ If True, allow loading of data that has been recorded with internal
+ active compensation (MaxShield). Data recorded with MaxShield should
+ generally not be loaded directly, but should first be processed using
+ SSS/tSSS to remove the compensation signals that may also affect brain
+ activity. Can also be "yes" to load without eliciting a warning.
preload : bool or str (default False)
Preload data into memory for data manipulation and faster indexing.
If True, the data will be preloaded into memory (fast, requires
@@ -50,19 +49,18 @@ class Raw(_BaseRaw):
file name of a memory-mapped file which is used to store the data
on the hard drive (slower, requires less memory).
proj : bool
- Apply the signal space projection (SSP) operators present in
- the file to the data. Note: Once the projectors have been
- applied, they can no longer be removed. It is usually not
- recommended to apply the projectors at this point as they are
- applied automatically later on (e.g. when computing inverse
- solutions).
+ Deprecated. Use :meth:`raw.apply_proj() <mne.io.Raw.apply_proj>`
+ instead.
compensation : None | int
- If None the compensation in the data is not modified.
- If set to n, e.g. 3, apply gradient compensation of grade n as
- for CTF systems.
+ Deprecated. Use :meth:`mne.io.Raw.apply_gradient_compensation`
+ instead.
add_eeg_ref : bool
- If True, add average EEG reference projector (if it's not already
- present).
+ If True, an EEG average reference will be added (unless one
+ already exists). The default value of True in 0.13 will change to
+ False in 0.14, and the parameter will be removed in 0.15. Use
+ :func:`mne.set_eeg_reference` instead.
+ fnames : list or str
+ Deprecated. Use :func:`mne.concatenate_raws` instead.
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
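
The deprecated list form maps onto the recommended pattern like this (paths
hypothetical), mirroring the updated tests further below:

    from mne import concatenate_raws
    from mne.io import read_raw_fif

    fnames = ['run1_raw.fif', 'run2_raw.fif']  # hypothetical paths
    raw = concatenate_raws([read_raw_fif(f) for f in fnames])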
@@ -80,12 +78,27 @@ class Raw(_BaseRaw):
See above.
"""
@verbose
- def __init__(self, fnames, allow_maxshield=False, preload=False,
- proj=False, compensation=None, add_eeg_ref=True,
- verbose=None):
-
+ def __init__(self, fname, allow_maxshield=False, preload=False,
+ proj=None, compensation=None, add_eeg_ref=None,
+ fnames=None, verbose=None):
+ if not proj:
+ proj = False
+ else:
+            warn('The proj parameter has been deprecated and will be removed '
+                 'in 0.14. Use raw.apply_proj() instead.', DeprecationWarning)
+        dep = ('Supplying a list of filenames with "fnames" to the Raw class '
+               'has been deprecated and will be removed in 0.14. Use multiple '
+               'calls to read_raw_fif with the "fname" argument followed by '
+               'concatenate_raws instead.')
+ if fnames is not None:
+ warn(dep, DeprecationWarning)
+ else:
+ fnames = fname
+ del fname
if not isinstance(fnames, list):
fnames = [fnames]
+ else:
+ warn(dep, DeprecationWarning)
fnames = [op.realpath(f) for f in fnames]
split_fnames = []
@@ -93,8 +106,7 @@ class Raw(_BaseRaw):
for ii, fname in enumerate(fnames):
do_check_fname = fname not in split_fnames
raw, next_fname = self._read_raw_file(fname, allow_maxshield,
- preload, compensation,
- do_check_fname)
+ preload, do_check_fname)
raws.append(raw)
if next_fname is not None:
if not op.exists(next_fname):
@@ -119,8 +131,10 @@ class Raw(_BaseRaw):
copy.deepcopy(raws[0].info), False,
[r.first_samp for r in raws], [r.last_samp for r in raws],
[r.filename for r in raws], [r._raw_extras for r in raws],
- copy.deepcopy(raws[0].comp), raws[0]._orig_comp_grade,
raws[0].orig_format, None, verbose=verbose)
+ if 'eeg' in self:
+ from ...epochs import _dep_eeg_ref
+ add_eeg_ref = _dep_eeg_ref(add_eeg_ref, True)
# combine information from each raw file to construct self
if add_eeg_ref and _needs_eeg_average_ref_proj(self.info):
@@ -140,6 +154,12 @@ class Raw(_BaseRaw):
last_samps,
first_samps,
r.info['sfreq'])
+ if compensation is not None:
+ warn('The "compensation" argument has been deprecated '
+ 'in favor of the "raw.apply_gradient_compensation" '
+ 'method and will be removed in 0.14',
+ DeprecationWarning)
+ self.apply_gradient_compensation(compensation)
if preload:
self._preload_data(preload)
else:
@@ -150,7 +170,7 @@ class Raw(_BaseRaw):
self.apply_proj()
@verbose
- def _read_raw_file(self, fname, allow_maxshield, preload, compensation,
+ def _read_raw_file(self, fname, allow_maxshield, preload,
do_check_fname=True, verbose=None):
"""Read in header information from a raw file"""
logger.info('Opening raw data file %s...' % fname)
@@ -197,22 +217,10 @@ class Raw(_BaseRaw):
raw_node = dir_tree_find(meas, FIFF.FIFFB_CONTINUOUS_DATA)
if (len(raw_node) == 0):
raw_node = dir_tree_find(meas, FIFF.FIFFB_SMSH_RAW_DATA)
- msg = ('This file contains raw Internal Active '
- 'Shielding data. It may be distorted. Elekta '
- 'recommends it be run through MaxFilter to '
- 'produce reliable results. Consider closing '
- 'the file and running MaxFilter on the data.')
if (len(raw_node) == 0):
raise ValueError('No raw data in %s' % fname)
- elif allow_maxshield:
- info['maxshield'] = True
- if not (isinstance(allow_maxshield, string_types) and
- allow_maxshield == 'yes'):
- warn(msg)
- else:
- msg += (' Use allow_maxshield=True if you are sure you'
- ' want to load the data despite this warning.')
- raise ValueError(msg)
+ _check_maxshield(allow_maxshield)
+ info['maxshield'] = True
if len(raw_node) == 1:
raw_node = raw_node[0]
@@ -325,22 +333,6 @@ class Raw(_BaseRaw):
raw._cals = cals
raw._raw_extras = raw_extras
- raw.comp = None
- raw._orig_comp_grade = None
-
- # Set up the CTF compensator
- current_comp = get_current_comp(info)
- if current_comp is not None:
- logger.info('Current compensation grade : %d' % current_comp)
-
- if compensation is not None:
- raw.comp = make_compensator(info, current_comp, compensation)
- if raw.comp is not None:
- logger.info('Appropriate compensator added to change to '
- 'grade %d.' % (compensation))
- raw._orig_comp_grade = current_comp
- set_current_comp(info, compensation)
-
logger.info(' Range : %d ... %d = %9.3f ... %9.3f secs' % (
raw.first_samp, raw.last_samp,
float(raw.first_samp) / info['sfreq'],
@@ -460,7 +452,7 @@ class Raw(_BaseRaw):
.. note:: The effect of the difference between the coil sizes on the
current estimates computed by the MNE software is very small.
- Therefore the use of this function is not mandatory.
+ Therefore the use of mne_fix_mag_coil_types is not mandatory.
"""
from ...channels import fix_mag_coil_types
fix_mag_coil_types(self.info)
@@ -473,23 +465,24 @@ def _check_entry(first, nent):
raise IOError('Could not read data, perhaps this is a corrupt file')
-def read_raw_fif(fnames, allow_maxshield=False, preload=False,
- proj=False, compensation=None, add_eeg_ref=True,
- verbose=None):
+def read_raw_fif(fname, allow_maxshield=False, preload=False,
+ proj=False, compensation=None, add_eeg_ref=None,
+ fnames=None, verbose=None):
"""Reader function for Raw FIF data
Parameters
----------
- fnames : list, or string
- A list of the raw files to treat as a Raw instance, or a single
- raw file. For files that have automatically been split, only the
- name of the first file has to be specified. Filenames should end
+ fname : str
+        The raw file to load. For files that have been split automatically,
+        the remaining parts are loaded as well. Filenames should end
with raw.fif, raw.fif.gz, raw_sss.fif, raw_sss.fif.gz,
raw_tsss.fif or raw_tsss.fif.gz.
- allow_maxshield : bool, (default False)
- allow_maxshield if True, allow loading of data that has been
- processed with Maxshield. Maxshield-processed data should generally
- not be loaded directly, but should be processed using SSS first.
+ allow_maxshield : bool | str (default False)
+ If True, allow loading of data that has been recorded with internal
+ active compensation (MaxShield). Data recorded with MaxShield should
+ generally not be loaded directly, but should first be processed using
+ SSS/tSSS to remove the compensation signals that may also affect brain
+ activity. Can also be "yes" to load without eliciting a warning.
preload : bool or str (default False)
Preload data into memory for data manipulation and faster indexing.
If True, the data will be preloaded into memory (fast, requires
@@ -497,19 +490,18 @@ def read_raw_fif(fnames, allow_maxshield=False, preload=False,
file name of a memory-mapped file which is used to store the data
on the hard drive (slower, requires less memory).
proj : bool
- Apply the signal space projection (SSP) operators present in
- the file to the data. Note: Once the projectors have been
- applied, they can no longer be removed. It is usually not
- recommended to apply the projectors at this point as they are
- applied automatically later on (e.g. when computing inverse
- solutions).
+ Deprecated. Use :meth:`raw.apply_proj() <mne.io.Raw.apply_proj>`
+ instead.
compensation : None | int
- If None the compensation in the data is not modified.
- If set to n, e.g. 3, apply gradient compensation of grade n as
- for CTF systems.
+ Deprecated. Use :meth:`mne.io.Raw.apply_gradient_compensation`
+ instead.
add_eeg_ref : bool
- If True, add average EEG reference projector (if it's not already
- present).
+ If True, an EEG average reference will be added (unless one
+ already exists). The default value of True in 0.13 will change to
+ False in 0.14, and the parameter will be removed in 0.15. Use
+ :func:`mne.set_eeg_reference` instead.
+    fnames : list | str
+ Deprecated. Use :func:`mne.concatenate_raws` instead.
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
@@ -522,6 +514,6 @@ def read_raw_fif(fnames, allow_maxshield=False, preload=False,
-----
.. versionadded:: 0.9.0
"""
- return Raw(fnames=fnames, allow_maxshield=allow_maxshield,
+ return Raw(fname=fname, allow_maxshield=allow_maxshield,
preload=preload, proj=proj, compensation=compensation,
- add_eeg_ref=add_eeg_ref, verbose=verbose)
+ add_eeg_ref=add_eeg_ref, fnames=fnames, verbose=verbose)
diff --git a/mne/io/fiff/tests/test_raw_fiff.py b/mne/io/fiff/tests/test_raw_fiff.py
index 4076c56..6129266 100644
--- a/mne/io/fiff/tests/test_raw_fiff.py
+++ b/mne/io/fiff/tests/test_raw_fiff.py
@@ -3,12 +3,13 @@
#
# License: BSD (3-clause)
+from copy import deepcopy
+from functools import partial
+import glob
+import itertools as itt
import os
import os.path as op
-import glob
-from copy import deepcopy
import warnings
-import itertools as itt
import numpy as np
from numpy.testing import (assert_array_almost_equal, assert_array_equal,
@@ -17,7 +18,7 @@ from nose.tools import assert_true, assert_raises, assert_not_equal
from mne.datasets import testing
from mne.io.constants import FIFF
-from mne.io import Raw, RawArray, concatenate_raws, read_raw_fif
+from mne.io import RawArray, concatenate_raws, read_raw_fif
from mne.io.tests.test_raw import _test_concat, _test_raw_reader
from mne import (concatenate_events, find_events, equalize_channels,
compute_proj_raw, pick_types, pick_channels, create_info)
@@ -46,15 +47,13 @@ bad_file_works = op.join(base_dir, 'test_bads.txt')
bad_file_wrong = op.join(base_dir, 'test_wrong_bads.txt')
hp_fname = op.join(base_dir, 'test_chpi_raw_hp.txt')
hp_fif_fname = op.join(base_dir, 'test_chpi_raw_sss.fif')
-rng = np.random.RandomState(0)
def test_fix_types():
- """Test fixing of channel types
- """
+ """Test fixing of channel types."""
for fname, change in ((hp_fif_fname, True), (test_fif_fname, False),
(ctf_fname, False)):
- raw = Raw(fname)
+ raw = read_raw_fif(fname, add_eeg_ref=False)
mag_picks = pick_types(raw.info, meg='mag')
other_picks = np.setdiff1d(np.arange(len(raw.ch_names)), mag_picks)
# we don't actually have any files suffering from this problem, so
@@ -75,27 +74,28 @@ def test_fix_types():
def test_concat():
- """Test RawFIF concatenation
- """
+ """Test RawFIF concatenation."""
# we trim the file to save lots of memory and some time
tempdir = _TempDir()
- raw = read_raw_fif(test_fif_fname)
+ raw = read_raw_fif(test_fif_fname, add_eeg_ref=False)
raw.crop(0, 2., copy=False)
test_name = op.join(tempdir, 'test_raw.fif')
raw.save(test_name)
# now run the standard test
- _test_concat(read_raw_fif, test_name)
+ _test_concat(partial(read_raw_fif, add_eeg_ref=False), test_name)
@testing.requires_testing_data
def test_hash_raw():
- """Test hashing raw objects
- """
- raw = read_raw_fif(fif_fname)
+ """Test hashing raw objects."""
+ raw = read_raw_fif(fif_fname, add_eeg_ref=False)
assert_raises(RuntimeError, raw.__hash__)
- raw = Raw(fif_fname).crop(0, 0.5, copy=False)
+ raw = read_raw_fif(fif_fname, add_eeg_ref=False).crop(0, 0.5, copy=False)
+ raw_size = raw._size
raw.load_data()
- raw_2 = Raw(fif_fname).crop(0, 0.5, copy=False)
+ raw_load_size = raw._size
+ assert_true(raw_size < raw_load_size)
+ raw_2 = read_raw_fif(fif_fname, add_eeg_ref=False).crop(0, 0.5, copy=False)
raw_2.load_data()
assert_equal(hash(raw), hash(raw_2))
# do NOT use assert_equal here, failing output is terrible
@@ -107,21 +107,19 @@ def test_hash_raw():
@testing.requires_testing_data
def test_maxshield():
- """Test maxshield warning
- """
+ """Test maxshield warning."""
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
- Raw(ms_fname, allow_maxshield=True)
+ read_raw_fif(ms_fname, allow_maxshield=True, add_eeg_ref=False)
assert_equal(len(w), 1)
assert_true('test_raw_fiff.py' in w[0].filename)
@testing.requires_testing_data
def test_subject_info():
- """Test reading subject information
- """
+ """Test reading subject information."""
tempdir = _TempDir()
- raw = Raw(fif_fname).crop(0, 1, copy=False)
+ raw = read_raw_fif(fif_fname, add_eeg_ref=False).crop(0, 1, copy=False)
assert_true(raw.info['subject_info'] is None)
# fake some subject data
keys = ['id', 'his_id', 'last_name', 'first_name', 'birthday', 'sex',
@@ -133,30 +131,17 @@ def test_subject_info():
raw.info['subject_info'] = subject_info
out_fname = op.join(tempdir, 'test_subj_info_raw.fif')
raw.save(out_fname, overwrite=True)
- raw_read = Raw(out_fname)
+ raw_read = read_raw_fif(out_fname, add_eeg_ref=False)
for key in keys:
assert_equal(subject_info[key], raw_read.info['subject_info'][key])
assert_equal(raw.info['meas_date'], raw_read.info['meas_date'])
- raw.anonymize()
- raw.save(out_fname, overwrite=True)
- raw_read = Raw(out_fname)
- for this_raw in (raw, raw_read):
- assert_true(this_raw.info.get('subject_info') is None)
- assert_equal(this_raw.info['meas_date'], [0, 0])
- assert_equal(raw.info['file_id']['secs'], 0)
- assert_equal(raw.info['meas_id']['secs'], 0)
- # When we write out with raw.save, these get overwritten with the
- # new save time
- assert_true(raw_read.info['file_id']['secs'] > 0)
- assert_true(raw_read.info['meas_id']['secs'] > 0)
@testing.requires_testing_data
def test_copy_append():
- """Test raw copying and appending combinations
- """
- raw = Raw(fif_fname, preload=True).copy()
- raw_full = Raw(fif_fname)
+ """Test raw copying and appending combinations."""
+ raw = read_raw_fif(fif_fname, preload=True, add_eeg_ref=False).copy()
+ raw_full = read_raw_fif(fif_fname, add_eeg_ref=False)
raw_full.append(raw)
data = raw_full[:, :][0]
assert_equal(data.shape[1], 2 * raw._data.shape[1])
@@ -165,20 +150,19 @@ def test_copy_append():
@slow_test
@testing.requires_testing_data
def test_rank_estimation():
- """Test raw rank estimation
- """
+ """Test raw rank estimation."""
iter_tests = itt.product(
[fif_fname, hp_fif_fname], # sss
['norm', dict(mag=1e11, grad=1e9, eeg=1e5)]
)
for fname, scalings in iter_tests:
- raw = Raw(fname)
+ raw = read_raw_fif(fname, add_eeg_ref=False)
(_, picks_meg), (_, picks_eeg) = _picks_by_type(raw.info,
meg_combined=True)
n_meg = len(picks_meg)
n_eeg = len(picks_eeg)
- raw = Raw(fname, preload=True)
+ raw = read_raw_fif(fname, preload=True, add_eeg_ref=False)
if 'proc_history' not in raw.info:
expected_rank = n_meg + n_eeg
else:
@@ -190,7 +174,7 @@ def test_rank_estimation():
scalings=scalings),
n_eeg)
- raw = Raw(fname, preload=False)
+ raw = read_raw_fif(fname, preload=False, add_eeg_ref=False)
if 'sss' in fname:
tstart, tstop = 0., 30.
raw.add_proj(compute_proj_raw(raw))
@@ -208,14 +192,14 @@ def test_rank_estimation():
@testing.requires_testing_data
def test_output_formats():
- """Test saving and loading raw data using multiple formats
- """
+ """Test saving and loading raw data using multiple formats."""
tempdir = _TempDir()
formats = ['short', 'int', 'single', 'double']
tols = [1e-4, 1e-7, 1e-7, 1e-15]
# let's fake a raw file with different formats
- raw = Raw(test_fif_fname).crop(0, 1, copy=False)
+ raw = read_raw_fif(test_fif_fname,
+ add_eeg_ref=False).crop(0, 1, copy=False)
temp_file = op.join(tempdir, 'raw.fif')
for ii, (fmt, tol) in enumerate(zip(formats, tols)):
@@ -223,13 +207,14 @@ def test_output_formats():
if ii > 0:
assert_raises(IOError, raw.save, temp_file, fmt=fmt)
raw.save(temp_file, fmt=fmt, overwrite=True)
- raw2 = Raw(temp_file)
+ raw2 = read_raw_fif(temp_file, add_eeg_ref=False)
raw2_data = raw2[:, :][0]
assert_allclose(raw2_data, raw[:, :][0], rtol=tol, atol=1e-25)
assert_equal(raw2.orig_format, fmt)
def _compare_combo(raw, new, times, n_times):
+ """Compare data."""
for ti in times: # let's do a subset of points for speed
orig = raw[:, ti % n_times][0]
# these are almost_equals because of possible dtype differences
@@ -239,11 +224,10 @@ def _compare_combo(raw, new, times, n_times):
@slow_test
@testing.requires_testing_data
def test_multiple_files():
- """Test loading multiple files simultaneously
- """
+ """Test loading multiple files simultaneously."""
# split file
tempdir = _TempDir()
- raw = Raw(fif_fname).crop(0, 10, copy=False)
+ raw = read_raw_fif(fif_fname, add_eeg_ref=False).crop(0, 10, copy=False)
raw.load_data()
raw.load_data() # test no operation
split_size = 3. # in seconds
@@ -260,7 +244,7 @@ def test_multiple_files():
for ri in range(len(tmins) - 1, -1, -1):
fname = op.join(tempdir, 'test_raw_split-%d_raw.fif' % ri)
raw.save(fname, tmin=tmins[ri], tmax=tmaxs[ri])
- raws[ri] = Raw(fname)
+ raws[ri] = read_raw_fif(fname, add_eeg_ref=False)
assert_equal(len(raws[ri].times),
int(round((tmaxs[ri] - tmins[ri]) *
raw.info['sfreq'])) + 1) # + 1 b/c inclusive
@@ -276,7 +260,7 @@ def test_multiple_files():
assert_equal(raw.first_samp, all_raw_1.first_samp)
assert_equal(raw.last_samp, all_raw_1.last_samp)
assert_allclose(raw[:, :][0], all_raw_1[:, :][0])
- raws[0] = Raw(fname)
+ raws[0] = read_raw_fif(fname, add_eeg_ref=False)
all_raw_2 = concatenate_raws(raws, preload=True)
assert_allclose(raw[:, :][0], all_raw_2[:, :][0])
@@ -287,64 +271,81 @@ def test_multiple_files():
assert_array_equal(events1, events3)
# test various methods of combining files
- raw = Raw(fif_fname, preload=True)
+ raw = read_raw_fif(fif_fname, preload=True, add_eeg_ref=False)
n_times = raw.n_times
# make sure that all our data match
times = list(range(0, 2 * n_times, 999))
# add potentially problematic points
times.extend([n_times - 1, n_times, 2 * n_times - 1])
- raw_combo0 = Raw([fif_fname, fif_fname], preload=True)
+ raw_combo0 = concatenate_raws([read_raw_fif(f, add_eeg_ref=False)
+ for f in [fif_fname, fif_fname]],
+ preload=True)
_compare_combo(raw, raw_combo0, times, n_times)
- raw_combo = Raw([fif_fname, fif_fname], preload=False)
+ raw_combo = concatenate_raws([read_raw_fif(f, add_eeg_ref=False)
+ for f in [fif_fname, fif_fname]],
+ preload=False)
_compare_combo(raw, raw_combo, times, n_times)
- raw_combo = Raw([fif_fname, fif_fname], preload='memmap8.dat')
+ raw_combo = concatenate_raws([read_raw_fif(f, add_eeg_ref=False)
+ for f in [fif_fname, fif_fname]],
+ preload='memmap8.dat')
_compare_combo(raw, raw_combo, times, n_times)
- assert_raises(ValueError, Raw, [fif_fname, ctf_fname])
- assert_raises(ValueError, Raw, [fif_fname, fif_bad_marked_fname])
+ with warnings.catch_warnings(record=True): # deprecated
+ assert_raises(ValueError, read_raw_fif, [fif_fname, ctf_fname])
+ assert_raises(ValueError, read_raw_fif,
+ [fif_fname, fif_bad_marked_fname])
assert_equal(raw[:, :][0].shape[1] * 2, raw_combo0[:, :][0].shape[1])
assert_equal(raw_combo0[:, :][0].shape[1], raw_combo0.n_times)
# with all data preloaded, result should be preloaded
- raw_combo = Raw(fif_fname, preload=True)
- raw_combo.append(Raw(fif_fname, preload=True))
+ raw_combo = read_raw_fif(fif_fname, preload=True, add_eeg_ref=False)
+ raw_combo.append(read_raw_fif(fif_fname, preload=True, add_eeg_ref=False))
assert_true(raw_combo.preload is True)
assert_equal(raw_combo.n_times, raw_combo._data.shape[1])
_compare_combo(raw, raw_combo, times, n_times)
# with any data not preloaded, don't set result as preloaded
- raw_combo = concatenate_raws([Raw(fif_fname, preload=True),
- Raw(fif_fname, preload=False)])
+ raw_combo = concatenate_raws([read_raw_fif(fif_fname, preload=True,
+ add_eeg_ref=False),
+ read_raw_fif(fif_fname, preload=False,
+ add_eeg_ref=False)])
assert_true(raw_combo.preload is False)
assert_array_equal(find_events(raw_combo, stim_channel='STI 014'),
find_events(raw_combo0, stim_channel='STI 014'))
_compare_combo(raw, raw_combo, times, n_times)
# user should be able to force data to be preloaded upon concat
- raw_combo = concatenate_raws([Raw(fif_fname, preload=False),
- Raw(fif_fname, preload=True)],
+ raw_combo = concatenate_raws([read_raw_fif(fif_fname, preload=False,
+ add_eeg_ref=False),
+ read_raw_fif(fif_fname, preload=True,
+ add_eeg_ref=False)],
preload=True)
assert_true(raw_combo.preload is True)
_compare_combo(raw, raw_combo, times, n_times)
- raw_combo = concatenate_raws([Raw(fif_fname, preload=False),
- Raw(fif_fname, preload=True)],
+ raw_combo = concatenate_raws([read_raw_fif(fif_fname, preload=False,
+ add_eeg_ref=False),
+ read_raw_fif(fif_fname, preload=True,
+ add_eeg_ref=False)],
preload='memmap3.dat')
_compare_combo(raw, raw_combo, times, n_times)
- raw_combo = concatenate_raws([Raw(fif_fname, preload=True),
- Raw(fif_fname, preload=True)],
- preload='memmap4.dat')
+ raw_combo = concatenate_raws([
+ read_raw_fif(fif_fname, preload=True, add_eeg_ref=False),
+ read_raw_fif(fif_fname, preload=True, add_eeg_ref=False)],
+ preload='memmap4.dat')
_compare_combo(raw, raw_combo, times, n_times)
- raw_combo = concatenate_raws([Raw(fif_fname, preload=False),
- Raw(fif_fname, preload=False)],
- preload='memmap5.dat')
+ raw_combo = concatenate_raws([
+ read_raw_fif(fif_fname, preload=False, add_eeg_ref=False),
+ read_raw_fif(fif_fname, preload=False, add_eeg_ref=False)],
+ preload='memmap5.dat')
_compare_combo(raw, raw_combo, times, n_times)
# verify that combining raws with different projectors throws an exception
raw.add_proj([], remove_existing=True)
- assert_raises(ValueError, raw.append, Raw(fif_fname, preload=True))
+ assert_raises(ValueError, raw.append,
+ read_raw_fif(fif_fname, preload=True, add_eeg_ref=False))
# now test event treatment for concatenated raw files
events = [find_events(raw, stim_channel='STI 014'),
@@ -362,10 +363,9 @@ def test_multiple_files():
@testing.requires_testing_data
def test_split_files():
- """Test writing and reading of split raw files
- """
+ """Test writing and reading of split raw files."""
tempdir = _TempDir()
- raw_1 = Raw(fif_fname, preload=True)
+ raw_1 = read_raw_fif(fif_fname, preload=True, add_eeg_ref=False)
# Test a very close corner case
raw_crop = raw_1.copy().crop(0, 1., copy=False)
@@ -373,7 +373,7 @@ def test_split_files():
split_fname = op.join(tempdir, 'split_raw.fif')
raw_1.save(split_fname, buffer_size_sec=1.0, split_size='10MB')
- raw_2 = Raw(split_fname)
+ raw_2 = read_raw_fif(split_fname, add_eeg_ref=False)
assert_allclose(raw_2.info['buffer_size_sec'], 1., atol=1e-2) # samp rate
data_1, times_1 = raw_1[:, :]
data_2, times_2 = raw_2[:, :]
@@ -385,7 +385,7 @@ def test_split_files():
fnames.extend(sorted(glob.glob(op.join(tempdir, 'split_raw-*.fif'))))
with warnings.catch_warnings(record=True):
warnings.simplefilter('always')
- raw_2 = Raw(fnames)
+ raw_2 = read_raw_fif(fnames, add_eeg_ref=False) # deprecated list
data_2, times_2 = raw_2[:, :]
assert_array_equal(data_1, data_2)
assert_array_equal(times_1, times_2)
@@ -411,7 +411,7 @@ def test_split_files():
# at a time so we hit GH#3210 if we aren't careful
raw_crop.save(split_fname, split_size='4.5MB',
buffer_size_sec=1., overwrite=True)
- raw_read = read_raw_fif(split_fname)
+ raw_read = read_raw_fif(split_fname, add_eeg_ref=False)
assert_allclose(raw_crop[:][0], raw_read[:][0], atol=1e-20)
# Check our buffer arithmetic
@@ -419,13 +419,13 @@ def test_split_files():
# 1 buffer required
raw_crop = raw_1.copy().crop(0, 1, copy=False)
raw_crop.save(split_fname, buffer_size_sec=1., overwrite=True)
- raw_read = read_raw_fif(split_fname)
+ raw_read = read_raw_fif(split_fname, add_eeg_ref=False)
assert_equal(len(raw_read._raw_extras[0]), 1)
assert_equal(raw_read._raw_extras[0][0]['nsamp'], 301)
assert_allclose(raw_crop[:][0], raw_read[:][0])
# 2 buffers required
raw_crop.save(split_fname, buffer_size_sec=0.5, overwrite=True)
- raw_read = read_raw_fif(split_fname)
+ raw_read = read_raw_fif(split_fname, add_eeg_ref=False)
assert_equal(len(raw_read._raw_extras[0]), 2)
assert_equal(raw_read._raw_extras[0][0]['nsamp'], 151)
assert_equal(raw_read._raw_extras[0][1]['nsamp'], 150)
@@ -434,7 +434,7 @@ def test_split_files():
raw_crop.save(split_fname,
buffer_size_sec=1. - 1.01 / raw_crop.info['sfreq'],
overwrite=True)
- raw_read = read_raw_fif(split_fname)
+ raw_read = read_raw_fif(split_fname, add_eeg_ref=False)
assert_equal(len(raw_read._raw_extras[0]), 2)
assert_equal(raw_read._raw_extras[0][0]['nsamp'], 300)
assert_equal(raw_read._raw_extras[0][1]['nsamp'], 1)
@@ -442,7 +442,7 @@ def test_split_files():
raw_crop.save(split_fname,
buffer_size_sec=1. - 2.01 / raw_crop.info['sfreq'],
overwrite=True)
- raw_read = read_raw_fif(split_fname)
+ raw_read = read_raw_fif(split_fname, add_eeg_ref=False)
assert_equal(len(raw_read._raw_extras[0]), 2)
assert_equal(raw_read._raw_extras[0][0]['nsamp'], 299)
assert_equal(raw_read._raw_extras[0][1]['nsamp'], 2)
@@ -450,13 +450,12 @@ def test_split_files():
def test_load_bad_channels():
- """Test reading/writing of bad channels
- """
+ """Test reading/writing of bad channels."""
tempdir = _TempDir()
# Load correctly marked file (manually done in mne_process_raw)
- raw_marked = Raw(fif_bad_marked_fname)
+ raw_marked = read_raw_fif(fif_bad_marked_fname, add_eeg_ref=False)
correct_bads = raw_marked.info['bads']
- raw = Raw(test_fif_fname)
+ raw = read_raw_fif(test_fif_fname, add_eeg_ref=False)
# Make sure it starts clean
assert_array_equal(raw.info['bads'], [])
@@ -464,7 +463,7 @@ def test_load_bad_channels():
raw.load_bad_channels(bad_file_works)
# Write it out, read it in, and check
raw.save(op.join(tempdir, 'foo_raw.fif'))
- raw_new = Raw(op.join(tempdir, 'foo_raw.fif'))
+ raw_new = read_raw_fif(op.join(tempdir, 'foo_raw.fif'), add_eeg_ref=False)
assert_equal(correct_bads, raw_new.info['bads'])
# Reset it
raw.info['bads'] = []
@@ -480,36 +479,37 @@ def test_load_bad_channels():
assert_equal(n_found, 1) # there could be other irrelevant errors
# write it out, read it in, and check
raw.save(op.join(tempdir, 'foo_raw.fif'), overwrite=True)
- raw_new = Raw(op.join(tempdir, 'foo_raw.fif'))
+ raw_new = read_raw_fif(op.join(tempdir, 'foo_raw.fif'),
+ add_eeg_ref=False)
assert_equal(correct_bads, raw_new.info['bads'])
# Check that bad channels are cleared
raw.load_bad_channels(None)
raw.save(op.join(tempdir, 'foo_raw.fif'), overwrite=True)
- raw_new = Raw(op.join(tempdir, 'foo_raw.fif'))
+ raw_new = read_raw_fif(op.join(tempdir, 'foo_raw.fif'), add_eeg_ref=False)
assert_equal([], raw_new.info['bads'])
@slow_test
@testing.requires_testing_data
def test_io_raw():
- """Test IO for raw data (Neuromag + CTF + gz)
- """
+ """Test IO for raw data (Neuromag + CTF + gz)."""
+ rng = np.random.RandomState(0)
tempdir = _TempDir()
# test unicode io
for chars in [b'\xc3\xa4\xc3\xb6\xc3\xa9', b'a']:
- with Raw(fif_fname) as r:
+ with read_raw_fif(fif_fname, add_eeg_ref=False) as r:
assert_true('Raw' in repr(r))
assert_true(op.basename(fif_fname) in repr(r))
desc1 = r.info['description'] = chars.decode('utf-8')
temp_file = op.join(tempdir, 'raw.fif')
r.save(temp_file, overwrite=True)
- with Raw(temp_file) as r2:
+ with read_raw_fif(temp_file, add_eeg_ref=False) as r2:
desc2 = r2.info['description']
assert_equal(desc1, desc2)
# Let's construct a simple test for IO first
- raw = Raw(fif_fname).crop(0, 3.5, copy=False)
+ raw = read_raw_fif(fif_fname, add_eeg_ref=False).crop(0, 3.5, copy=False)
raw.load_data()
# put in some data that we know the values of
data = rng.randn(raw._data.shape[0], raw._data.shape[1])
@@ -518,7 +518,7 @@ def test_io_raw():
fname = op.join(tempdir, 'test_copy_raw.fif')
raw.save(fname, buffer_size_sec=1.0)
# read it in, make sure the whole thing matches
- raw = Raw(fname)
+ raw = read_raw_fif(fname, add_eeg_ref=False)
assert_allclose(data, raw[:, :][0], rtol=1e-6, atol=1e-20)
# let's read portions across the 1-sec tag boundary, too
inds = raw.time_as_index([1.75, 2.25])
@@ -530,7 +530,7 @@ def test_io_raw():
fnames_out = ['raw.fif', 'raw.fif.gz', 'raw.fif']
for fname_in, fname_out in zip(fnames_in, fnames_out):
fname_out = op.join(tempdir, fname_out)
- raw = Raw(fname_in)
+ raw = read_raw_fif(fname_in, add_eeg_ref=False)
nchan = raw.info['nchan']
ch_names = raw.info['ch_names']
@@ -552,7 +552,7 @@ def test_io_raw():
# Writing with drop_small_buffer True
raw.save(fname_out, picks, tmin=0, tmax=4, buffer_size_sec=3,
drop_small_buffer=True, overwrite=True)
- raw2 = Raw(fname_out)
+ raw2 = read_raw_fif(fname_out, add_eeg_ref=False)
sel = pick_channels(raw2.ch_names, meg_ch_names)
data2, times2 = raw2[sel, :]
@@ -564,7 +564,7 @@ def test_io_raw():
if fname_in == fif_fname or fname_in == fif_fname + '.gz':
assert_equal(len(raw.info['dig']), 146)
- raw2 = Raw(fname_out)
+ raw2 = read_raw_fif(fname_out, add_eeg_ref=False)
sel = pick_channels(raw2.ch_names, meg_ch_names)
data2, times2 = raw2[sel, :]
@@ -602,19 +602,19 @@ def test_io_raw():
warnings.simplefilter("always")
raw_badname = op.join(tempdir, 'test-bad-name.fif.gz')
raw.save(raw_badname)
- Raw(raw_badname)
+ read_raw_fif(raw_badname, add_eeg_ref=False)
assert_naming(w, 'test_raw_fiff.py', 2)
@testing.requires_testing_data
def test_io_complex():
- """Test IO with complex data types
- """
+ """Test IO with complex data types."""
rng = np.random.RandomState(0)
tempdir = _TempDir()
dtypes = [np.complex64, np.complex128]
- raw = _test_raw_reader(Raw, fnames=fif_fname)
+ raw = _test_raw_reader(partial(read_raw_fif, add_eeg_ref=False),
+ fname=fif_fname)
picks = np.arange(5)
start, stop = raw.time_as_index([0, 5])
@@ -635,12 +635,13 @@ def test_io_complex():
# warning gets thrown on every instance b/c simplefilter('always')
assert_equal(len(w), 1)
- raw2 = Raw(op.join(tempdir, 'raw.fif'))
+ raw2 = read_raw_fif(op.join(tempdir, 'raw.fif'), add_eeg_ref=False)
raw2_data, _ = raw2[picks, :]
n_samp = raw2_data.shape[1]
assert_allclose(raw2_data[:, :n_samp], raw_cp._data[picks, :n_samp])
# with preloading
- raw2 = Raw(op.join(tempdir, 'raw.fif'), preload=True)
+ raw2 = read_raw_fif(op.join(tempdir, 'raw.fif'), preload=True,
+ add_eeg_ref=False)
raw2_data, _ = raw2[picks, :]
n_samp = raw2_data.shape[1]
assert_allclose(raw2_data[:, :n_samp], raw_cp._data[picks, :n_samp])
@@ -648,10 +649,9 @@ def test_io_complex():
@testing.requires_testing_data
def test_getitem():
- """Test getitem/indexing of Raw
- """
+ """Test getitem/indexing of Raw."""
for preload in [False, True, 'memmap.dat']:
- raw = Raw(fif_fname, preload=preload)
+ raw = read_raw_fif(fif_fname, preload=preload, add_eeg_ref=False)
data, times = raw[0, :]
data1, times1 = raw[0]
assert_array_equal(data, data1)
@@ -671,11 +671,12 @@ def test_getitem():
@testing.requires_testing_data
def test_proj():
- """Test SSP proj operations
- """
+ """Test SSP proj operations."""
tempdir = _TempDir()
for proj in [True, False]:
- raw = Raw(fif_fname, preload=False, proj=proj)
+ raw = read_raw_fif(fif_fname, preload=False, add_eeg_ref=False)
+ if proj:
+ raw.apply_proj()
assert_true(all(p['active'] == proj for p in raw.info['projs']))
data, times = raw[0:2, :]
@@ -701,23 +702,27 @@ def test_proj():
# test apply_proj() with and without preload
for preload in [True, False]:
- raw = Raw(fif_fname, preload=preload, proj=False)
+ raw = read_raw_fif(fif_fname, preload=preload, proj=False,
+ add_eeg_ref=False)
data, times = raw[:, 0:2]
raw.apply_proj()
data_proj_1 = np.dot(raw._projector, data)
# load the file again without proj
- raw = Raw(fif_fname, preload=preload, proj=False)
+ raw = read_raw_fif(fif_fname, preload=preload, proj=False,
+ add_eeg_ref=False)
# write the file with proj. activated, make sure proj has been applied
raw.save(op.join(tempdir, 'raw.fif'), proj=True, overwrite=True)
- raw2 = Raw(op.join(tempdir, 'raw.fif'), proj=False)
+ raw2 = read_raw_fif(op.join(tempdir, 'raw.fif'), proj=False,
+ add_eeg_ref=False)
data_proj_2, _ = raw2[:, 0:2]
assert_allclose(data_proj_1, data_proj_2)
assert_true(all(p['active'] for p in raw2.info['projs']))
# read orig file with proj. active
- raw2 = Raw(fif_fname, preload=preload, proj=True)
+ raw2 = read_raw_fif(fif_fname, preload=preload, add_eeg_ref=False)
+ raw2.apply_proj()
data_proj_2, _ = raw2[:, 0:2]
assert_allclose(data_proj_1, data_proj_2)
assert_true(all(p['active'] for p in raw2.info['projs']))
@@ -730,23 +735,25 @@ def test_proj():
tempdir = _TempDir()
out_fname = op.join(tempdir, 'test_raw.fif')
- raw = read_raw_fif(test_fif_fname, preload=True).crop(0, 0.002, copy=False)
+ raw = read_raw_fif(test_fif_fname, preload=True,
+ add_eeg_ref=False).crop(0, 0.002, copy=False)
raw.pick_types(meg=False, eeg=True)
raw.info['projs'] = [raw.info['projs'][-1]]
raw._data.fill(0)
raw._data[-1] = 1.
raw.save(out_fname)
- raw = read_raw_fif(out_fname, proj=True, preload=False)
+ raw = read_raw_fif(out_fname, preload=False, add_eeg_ref=False)
+ raw.apply_proj()
assert_allclose(raw[:, :][0][:1], raw[0, :][0])
@testing.requires_testing_data
def test_preload_modify():
- """Test preloading and modifying data
- """
+ """Test preloading and modifying data."""
tempdir = _TempDir()
+ rng = np.random.RandomState(0)
for preload in [False, True, 'memmap.dat']:
- raw = Raw(fif_fname, preload=preload)
+ raw = read_raw_fif(fif_fname, preload=preload, add_eeg_ref=False)
nsamp = raw.last_samp - raw.first_samp + 1
picks = pick_types(raw.info, meg='grad', exclude='bads')
@@ -764,7 +771,7 @@ def test_preload_modify():
tmp_fname = op.join(tempdir, 'raw.fif')
raw.save(tmp_fname, overwrite=True)
- raw_new = Raw(tmp_fname)
+ raw_new = read_raw_fif(tmp_fname, add_eeg_ref=False)
data_new, _ = raw_new[picks, :nsamp / 2]
assert_allclose(data, data_new)
@@ -773,21 +780,22 @@ def test_preload_modify():
@slow_test
@testing.requires_testing_data
def test_filter():
- """Test filtering (FIR and IIR) and Raw.apply_function interface
- """
- raw = Raw(fif_fname).crop(0, 7, copy=False)
+ """Test filtering (FIR and IIR) and Raw.apply_function interface."""
+ raw = read_raw_fif(fif_fname, add_eeg_ref=False).crop(0, 7, copy=False)
raw.load_data()
- sig_dec = 11
sig_dec_notch = 12
sig_dec_notch_fit = 12
picks_meg = pick_types(raw.info, meg=True, exclude='bads')
picks = picks_meg[:4]
- filter_params = dict(picks=picks, n_jobs=2)
- raw_lp = raw.copy().filter(0., 4.0 - 0.25, **filter_params)
- raw_hp = raw.copy().filter(8.0 + 0.25, None, **filter_params)
- raw_bp = raw.copy().filter(4.0 + 0.25, 8.0 - 0.25, **filter_params)
- raw_bs = raw.copy().filter(8.0 + 0.25, 4.0 - 0.25, **filter_params)
+ trans = 2.0
+ filter_params = dict(picks=picks, filter_length='auto',
+ h_trans_bandwidth=trans, l_trans_bandwidth=trans,
+ phase='zero', fir_window='hamming')
+ raw_lp = raw.copy().filter(None, 8.0, **filter_params)
+ raw_hp = raw.copy().filter(16.0, None, **filter_params)
+ raw_bp = raw.copy().filter(8.0 + trans, 16.0 - trans, **filter_params)
+ raw_bs = raw.copy().filter(16.0, 8.0, **filter_params)
data, _ = raw[picks, :]
@@ -796,11 +804,14 @@ def test_filter():
bp_data, _ = raw_bp[picks, :]
bs_data, _ = raw_bs[picks, :]
- assert_array_almost_equal(data, lp_data + bp_data + hp_data, sig_dec)
- assert_array_almost_equal(data, bp_data + bs_data, sig_dec)
+ tols = dict(atol=1e-20, rtol=1e-5)
+ assert_allclose(bs_data, lp_data + hp_data, **tols)
+ assert_allclose(data, lp_data + bp_data + hp_data, **tols)
+ assert_allclose(data, bp_data + bs_data, **tols)
- filter_params_iir = dict(picks=picks, n_jobs=2, method='iir')
- raw_lp_iir = raw.copy().filter(0., 4.0, **filter_params_iir)
+ filter_params_iir = dict(picks=picks, n_jobs=2, method='iir',
+ iir_params=dict(output='ba'))
+ raw_lp_iir = raw.copy().filter(None, 4.0, **filter_params_iir)
raw_hp_iir = raw.copy().filter(8.0, None, **filter_params_iir)
raw_bp_iir = raw.copy().filter(4.0, 8.0, **filter_params_iir)
del filter_params_iir
@@ -808,8 +819,7 @@ def test_filter():
hp_data_iir, _ = raw_hp_iir[picks, :]
bp_data_iir, _ = raw_bp_iir[picks, :]
summation = lp_data_iir + hp_data_iir + bp_data_iir
- assert_array_almost_equal(data[:, 100:-100], summation[:, 100:-100],
- sig_dec)
+ assert_array_almost_equal(data[:, 100:-100], summation[:, 100:-100], 11)
# make sure we didn't touch other channels
data, _ = raw[picks_meg[4:], :]
@@ -820,7 +830,7 @@ def test_filter():
# ... and that inplace changes are inplace
raw_copy = raw.copy()
- raw_copy.filter(None, 20., picks=picks, n_jobs=2)
+ raw_copy.filter(None, 20., n_jobs=2, **filter_params)
assert_true(raw._data[0, 0] != raw_copy._data[0, 0])
assert_equal(raw.copy().filter(None, 20., **filter_params)._data,
raw_copy._data)
@@ -828,10 +838,11 @@ def test_filter():
# do a very simple check on line filtering
with warnings.catch_warnings(record=True):
warnings.simplefilter('always')
- raw_bs = raw.copy().filter(60.0 + 0.5, 60.0 - 0.5, **filter_params)
+ raw_bs = raw.copy().filter(60.0 + trans, 60.0 - trans, **filter_params)
data_bs, _ = raw_bs[picks, :]
raw_notch = raw.copy().notch_filter(
- 60.0, picks=picks, n_jobs=2, method='fft')
+ 60.0, picks=picks, n_jobs=2, method='fir', filter_length='auto',
+ trans_bandwidth=2 * trans)
data_notch, _ = raw_notch[picks, :]
assert_array_almost_equal(data_bs, data_notch, sig_dec_notch)
@@ -842,26 +853,58 @@ def test_filter():
data, _ = raw[picks, :]
assert_array_almost_equal(data, data_notch, sig_dec_notch_fit)
+ # filter should set the "lowpass" and "highpass" parameters
+ raw = RawArray(np.random.randn(3, 1000),
+ create_info(3, 1000., ['eeg'] * 2 + ['stim']))
+ raw.info['lowpass'] = raw.info['highpass'] = None
+ for kind in ('none', 'lowpass', 'highpass', 'bandpass', 'bandstop'):
+ print(kind)
+ h_freq = l_freq = None
+ if kind in ('lowpass', 'bandpass'):
+ h_freq = 70
+ if kind in ('highpass', 'bandpass'):
+ l_freq = 30
+ if kind == 'bandstop':
+ l_freq, h_freq = 70, 30
+ assert_true(raw.info['lowpass'] is None)
+ assert_true(raw.info['highpass'] is None)
+ kwargs = dict(l_trans_bandwidth=20, h_trans_bandwidth=20,
+ filter_length='auto', phase='zero', fir_window='hann')
+ raw_filt = raw.copy().filter(l_freq, h_freq, picks=np.arange(1),
+ **kwargs)
+ assert_true(raw.info['lowpass'] is None)
+ assert_true(raw.info['highpass'] is None)
+ raw_filt = raw.copy().filter(l_freq, h_freq, **kwargs)
+ wanted_h = h_freq if kind != 'bandstop' else None
+ wanted_l = l_freq if kind != 'bandstop' else None
+ assert_equal(raw_filt.info['lowpass'], wanted_h)
+ assert_equal(raw_filt.info['highpass'], wanted_l)
+ # Using all data channels should still set the params (GH#3259)
+ raw_filt = raw.copy().filter(l_freq, h_freq, picks=np.arange(2),
+ **kwargs)
+ assert_equal(raw_filt.info['lowpass'], wanted_h)
+ assert_equal(raw_filt.info['highpass'], wanted_l)
+
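The filter tests above now spell out every FIR design parameter instead of relying on defaults. A minimal sketch of the explicit call, assuming a synthetic RawArray (channel names and data below are illustrative, not from the patch):

    import numpy as np
    from mne import create_info
    from mne.io import RawArray

    info = create_info(ch_names=['eeg1', 'eeg2'], sfreq=256.,
                       ch_types=['eeg', 'eeg'])
    raw = RawArray(np.random.randn(2, 2560), info)
    # 4-8 Hz band-pass with 2 Hz transition bands, zero-phase Hamming FIR
    raw.filter(4., 8., l_trans_bandwidth=2., h_trans_bandwidth=2.,
               filter_length='auto', phase='zero', fir_window='hamming')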
def test_filter_picks():
- """Test filtering default channel picks"""
- ch_types = ['mag', 'grad', 'eeg', 'seeg', 'misc', 'stim', 'ecog']
+ """Test filtering default channel picks."""
+ ch_types = ['mag', 'grad', 'eeg', 'seeg', 'misc', 'stim', 'ecog', 'hbo',
+ 'hbr']
info = create_info(ch_names=ch_types, ch_types=ch_types, sfreq=256)
raw = RawArray(data=np.zeros((len(ch_types), 1000)), info=info)
- # -- Deal with meg mag grad exception
+ # -- Deal with meg mag grad and fnirs exceptions
ch_types = ('misc', 'stim', 'meg', 'eeg', 'seeg', 'ecog')
# -- Filter data channels
- for ch_type in ('mag', 'grad', 'eeg', 'seeg', 'ecog'):
+ for ch_type in ('mag', 'grad', 'eeg', 'seeg', 'ecog', 'hbo', 'hbr'):
picks = dict((ch, ch == ch_type) for ch in ch_types)
picks['meg'] = ch_type if ch_type in ('mag', 'grad') else False
+ picks['fnirs'] = ch_type if ch_type in ('hbo', 'hbr') else False
raw_ = raw.copy().pick_types(**picks)
- # Avoid RuntimeWarning due to Attenuation
- with warnings.catch_warnings(record=True) as w:
- warnings.simplefilter('always')
- raw_.filter(10, 30)
- assert_true(any(['Attenuation' in str(ww.message) for ww in w]))
+ raw_.filter(10, 30, l_trans_bandwidth='auto',
+ h_trans_bandwidth='auto', filter_length='auto',
+ phase='zero', fir_window='hamming')
# -- Error if no data channel
for ch_type in ('misc', 'stim'):
@@ -872,10 +915,10 @@ def test_filter_picks():
@testing.requires_testing_data
def test_crop():
- """Test cropping raw files
- """
+ """Test cropping raw files."""
# split a concatenated file to test a difficult case
- raw = Raw([fif_fname, fif_fname], preload=False)
+ raw = concatenate_raws([read_raw_fif(f, add_eeg_ref=False)
+ for f in [fif_fname, fif_fname]])
split_size = 10. # in seconds
sfreq = raw.info['sfreq']
nsamp = (raw.last_samp - raw.first_samp + 1)
@@ -923,10 +966,9 @@ def test_crop():
@testing.requires_testing_data
def test_resample():
- """Test resample (with I/O and multiple files)
- """
+ """Test resample (with I/O and multiple files)."""
tempdir = _TempDir()
- raw = Raw(fif_fname).crop(0, 3, copy=False)
+ raw = read_raw_fif(fif_fname, add_eeg_ref=False).crop(0, 3, copy=False)
raw.load_data()
raw_resamp = raw.copy()
sfreq = raw.info['sfreq']
@@ -934,7 +976,8 @@ def test_resample():
raw_resamp.resample(sfreq * 2, n_jobs=2, npad='auto')
assert_equal(raw_resamp.n_times, len(raw_resamp.times))
raw_resamp.save(op.join(tempdir, 'raw_resamp-raw.fif'))
- raw_resamp = Raw(op.join(tempdir, 'raw_resamp-raw.fif'), preload=True)
+ raw_resamp = read_raw_fif(op.join(tempdir, 'raw_resamp-raw.fif'),
+ preload=True, add_eeg_ref=False)
assert_equal(sfreq, raw_resamp.info['sfreq'] / 2)
assert_equal(raw.n_times, raw_resamp.n_times / 2)
assert_equal(raw_resamp._data.shape[1], raw_resamp.n_times)
@@ -1037,24 +1080,26 @@ def test_resample():
@testing.requires_testing_data
def test_hilbert():
- """Test computation of analytic signal using hilbert
- """
- raw = Raw(fif_fname, preload=True)
+ """Test computation of analytic signal using hilbert."""
+ raw = read_raw_fif(fif_fname, preload=True, add_eeg_ref=False)
picks_meg = pick_types(raw.info, meg=True, exclude='bads')
picks = picks_meg[:4]
raw_filt = raw.copy()
- raw_filt.filter(10, 20)
+ raw_filt.filter(10, 20, picks=picks, l_trans_bandwidth='auto',
+ h_trans_bandwidth='auto', filter_length='auto',
+ phase='zero', fir_window='blackman')
raw_filt_2 = raw_filt.copy()
raw2 = raw.copy()
raw3 = raw.copy()
- raw.apply_hilbert(picks)
- raw2.apply_hilbert(picks, envelope=True, n_jobs=2)
+ raw.apply_hilbert(picks, n_fft='auto')
+ raw2.apply_hilbert(picks, n_fft='auto', envelope=True)
# Test custom n_fft
- raw_filt.apply_hilbert(picks)
- raw_filt_2.apply_hilbert(picks, n_fft=raw_filt_2.n_times + 1000)
+ raw_filt.apply_hilbert(picks, n_fft='auto')
+ n_fft = 2 ** int(np.ceil(np.log2(raw_filt_2.n_times + 1000)))
+ raw_filt_2.apply_hilbert(picks, n_fft=n_fft)
assert_equal(raw_filt._data.shape, raw_filt_2._data.shape)
assert_allclose(raw_filt._data[:, 50:-50], raw_filt_2._data[:, 50:-50],
atol=1e-13, rtol=1e-2)
@@ -1067,9 +1112,8 @@ def test_hilbert():
@testing.requires_testing_data
def test_raw_copy():
- """Test Raw copy
- """
- raw = Raw(fif_fname, preload=True)
+ """Test Raw copy."""
+ raw = read_raw_fif(fif_fname, preload=True, add_eeg_ref=False)
data, _ = raw[:, :]
copied = raw.copy()
copied_data, _ = copied[:, :]
@@ -1077,7 +1121,7 @@ def test_raw_copy():
assert_equal(sorted(raw.__dict__.keys()),
sorted(copied.__dict__.keys()))
- raw = Raw(fif_fname, preload=False)
+ raw = read_raw_fif(fif_fname, preload=False, add_eeg_ref=False)
data, _ = raw[:, :]
copied = raw.copy()
copied_data, _ = copied[:, :]
@@ -1088,8 +1132,8 @@ def test_raw_copy():
@requires_pandas
def test_to_data_frame():
- """Test raw Pandas exporter"""
- raw = Raw(test_fif_fname, preload=True)
+ """Test raw Pandas exporter."""
+ raw = read_raw_fif(test_fif_fname, preload=True, add_eeg_ref=False)
_, times = raw[0, :10]
df = raw.to_data_frame()
assert_true((df.columns == raw.ch_names).all())
@@ -1101,10 +1145,11 @@ def test_to_data_frame():
def test_add_channels():
- """Test raw splitting / re-appending channel types
- """
- raw = Raw(test_fif_fname).crop(0, 1, copy=False).load_data()
- raw_nopre = Raw(test_fif_fname, preload=False)
+ """Test raw splitting / re-appending channel types."""
+ rng = np.random.RandomState(0)
+ raw = read_raw_fif(test_fif_fname,
+ add_eeg_ref=False).crop(0, 1, copy=False).load_data()
+ raw_nopre = read_raw_fif(test_fif_fname, preload=False, add_eeg_ref=False)
raw_eeg_meg = raw.copy().pick_types(meg=True, eeg=True)
raw_eeg = raw.copy().pick_types(meg=False, eeg=True)
raw_meg = raw.copy().pick_types(meg=True, eeg=False)
@@ -1124,7 +1169,7 @@ def test_add_channels():
# Testing force updates
raw_arr_info = create_info(['1', '2'], raw_meg.info['sfreq'], 'eeg')
orig_head_t = raw_arr_info['dev_head_t']
- raw_arr = np.random.randn(2, raw_eeg.n_times)
+ raw_arr = rng.randn(2, raw_eeg.n_times)
raw_arr = RawArray(raw_arr, raw_arr_info)
# This should error because of conflicts in Info
assert_raises(ValueError, raw_meg.copy().add_channels, [raw_arr])
@@ -1145,93 +1190,193 @@ def test_add_channels():
@testing.requires_testing_data
-def test_raw_time_as_index():
- """ Test time as index conversion"""
- raw = Raw(fif_fname, preload=True)
- with warnings.catch_warnings(record=True): # deprecation
- first_samp = raw.time_as_index([0], True)[0]
- assert_equal(raw.first_samp, -first_samp)
-
-
-@testing.requires_testing_data
def test_save():
- """ Test saving raw"""
+ """Test saving raw."""
tempdir = _TempDir()
- raw = Raw(fif_fname, preload=False)
+ raw = read_raw_fif(fif_fname, preload=False, add_eeg_ref=False)
# can't write over file being read
assert_raises(ValueError, raw.save, fif_fname)
- raw = Raw(fif_fname, preload=True)
+ raw = read_raw_fif(fif_fname, preload=True, add_eeg_ref=False)
# can't overwrite file without overwrite=True
assert_raises(IOError, raw.save, fif_fname)
# test abspath support and annotations
- annot = Annotations([10], [10], ['test'], raw.info['meas_date'])
+ sfreq = raw.info['sfreq']
+ annot = Annotations([10], [5], ['test'],
+ raw.info['meas_date'] + raw.first_samp / sfreq)
raw.annotations = annot
new_fname = op.join(op.abspath(op.curdir), 'break-raw.fif')
raw.save(op.join(tempdir, new_fname), overwrite=True)
- new_raw = Raw(op.join(tempdir, new_fname), preload=False)
+ new_raw = read_raw_fif(op.join(tempdir, new_fname), preload=False,
+ add_eeg_ref=False)
assert_raises(ValueError, new_raw.save, new_fname)
assert_array_equal(annot.onset, new_raw.annotations.onset)
assert_array_equal(annot.duration, new_raw.annotations.duration)
assert_array_equal(annot.description, new_raw.annotations.description)
assert_equal(annot.orig_time, new_raw.annotations.orig_time)
+
+ # test that annotations are in sync after cropping and concatenating
+ annot = Annotations([5., 11., 15.], [2., 1., 3.], ['test', 'test', 'test'])
+ raw.annotations = annot
+ with warnings.catch_warnings(record=True) as w:
+ r1 = raw.copy().crop(2.5, 7.5)
+ r2 = raw.copy().crop(12.5, 17.5)
+ r3 = raw.copy().crop(10., 12.)
+ assert_true(all('data range' in str(ww.message) for ww in w))
+ raw = concatenate_raws([r1, r2, r3]) # segments reordered
+ onsets = raw.annotations.onset
+ durations = raw.annotations.duration
+ # 2*5s clips combined with annotations at 2.5s + 2s clip, annotation at 1s
+ assert_array_almost_equal([2.5, 7.5, 11.], onsets, decimal=2)
+ assert_array_almost_equal([2., 2.5, 1.], durations, decimal=2)
+
+ # test annotation clipping
+ annot = Annotations([0., raw.times[-1]], [2., 2.], 'test',
+ raw.info['meas_date'] + raw.first_samp / sfreq - 1.)
+ with warnings.catch_warnings(record=True) as w: # outside range
+ raw.annotations = annot
+ assert_true(all('data range' in str(ww.message) for ww in w))
+ assert_array_almost_equal(raw.annotations.duration, [1., 1.], decimal=3)
+
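The annotation arithmetic being tested here (onsets shifting on crop, segments reordering on concatenation, clipping at the data boundaries) hinges on the Annotations container. A minimal sketch of attaching annotations to a recording, assuming raw is any loaded Raw instance (placeholder):

    from mne import Annotations

    # two events: 2 s starting at 5 s and 1 s starting at 11 s (in seconds)
    annot = Annotations(onset=[5., 11.], duration=[2., 1.],
                        description=['test', 'test'])
    raw.annotations = annot  # onsets outside the data range warn and clip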
# make sure we can overwrite the file we loaded when preload=True
- new_raw = Raw(op.join(tempdir, new_fname), preload=True)
+ new_raw = read_raw_fif(op.join(tempdir, new_fname), preload=True,
+ add_eeg_ref=False)
new_raw.save(op.join(tempdir, new_fname), overwrite=True)
os.remove(new_fname)
@testing.requires_testing_data
def test_with_statement():
- """ Test with statement """
+ """Test with statement."""
for preload in [True, False]:
- with Raw(fif_fname, preload=preload) as raw_:
+ with read_raw_fif(fif_fname, preload=preload,
+ add_eeg_ref=False) as raw_:
print(raw_)
def test_compensation_raw():
- """Test Raw compensation
- """
+ """Test Raw compensation."""
tempdir = _TempDir()
- raw1 = Raw(ctf_comp_fname, compensation=None)
- assert_true(raw1.comp is None)
- data1, times1 = raw1[:, :]
- raw2 = Raw(ctf_comp_fname, compensation=3)
- data2, times2 = raw2[:, :]
- assert_true(raw2.comp is None) # unchanged (data come with grade 3)
- assert_array_equal(times1, times2)
- assert_array_equal(data1, data2)
- raw3 = Raw(ctf_comp_fname, compensation=1)
- data3, times3 = raw3[:, :]
- assert_true(raw3.comp is not None)
- assert_array_equal(times1, times3)
- # make sure it's different with a different compensation:
- assert_true(np.mean(np.abs(data1 - data3)) > 1e-12)
- assert_raises(ValueError, Raw, ctf_comp_fname, compensation=33)
+ raw_3 = read_raw_fif(ctf_comp_fname, add_eeg_ref=False)
+ assert_equal(raw_3.compensation_grade, 3)
+ data_3, times = raw_3[:, :]
+
+ # data come with grade 3
+ for ii in range(2):
+ raw_3_new = raw_3.copy()
+ if ii == 0:
+ raw_3_new.load_data()
+ raw_3_new.apply_gradient_compensation(3)
+ assert_equal(raw_3_new.compensation_grade, 3)
+ data_new, times_new = raw_3_new[:, :]
+ assert_array_equal(times, times_new)
+ assert_array_equal(data_3, data_new)
+ # deprecated way
+ preload = True if ii == 0 else False
+ raw_3_new = read_raw_fif(ctf_comp_fname, compensation=3,
+ preload=preload, verbose='error',
+ add_eeg_ref=False)
+ assert_equal(raw_3_new.compensation_grade, 3)
+ data_new, times_new = raw_3_new[:, :]
+ assert_array_equal(times, times_new)
+ assert_array_equal(data_3, data_new)
+
+ # change to grade 0
+ raw_0 = raw_3.copy().apply_gradient_compensation(0)
+ assert_equal(raw_0.compensation_grade, 0)
+ data_0, times_new = raw_0[:, :]
+ assert_array_equal(times, times_new)
+ assert_true(np.mean(np.abs(data_0 - data_3)) > 1e-12)
+ # change to grade 1
+ raw_1 = raw_0.copy().apply_gradient_compensation(1)
+ assert_equal(raw_1.compensation_grade, 1)
+ data_1, times_new = raw_1[:, :]
+ assert_array_equal(times, times_new)
+ assert_true(np.mean(np.abs(data_1 - data_3)) > 1e-12)
+ assert_raises(ValueError, read_raw_fif, ctf_comp_fname, compensation=33,
+ verbose='error', add_eeg_ref=False)
+ raw_bad = raw_0.copy()
+ raw_bad.add_proj(compute_proj_raw(raw_0, duration=0.5, verbose='error'))
+ raw_bad.apply_proj()
+ assert_raises(RuntimeError, raw_bad.apply_gradient_compensation, 1)
+ # with preload
+ tols = dict(rtol=1e-12, atol=1e-25)
+ raw_1_new = raw_3.copy().load_data().apply_gradient_compensation(1)
+ assert_equal(raw_1_new.compensation_grade, 1)
+ data_1_new, times_new = raw_1_new[:, :]
+ assert_array_equal(times, times_new)
+ assert_true(np.mean(np.abs(data_1_new - data_3)) > 1e-12)
+ assert_allclose(data_1, data_1_new, **tols)
+ # deprecated way
+ for preload in (True, False):
+ raw_1_new = read_raw_fif(ctf_comp_fname, compensation=1,
+ verbose='error', preload=preload,
+ add_eeg_ref=False)
+ assert_equal(raw_1_new.compensation_grade, 1)
+ data_1_new, times_new = raw_1_new[:, :]
+ assert_array_equal(times, times_new)
+ assert_true(np.mean(np.abs(data_1_new - data_3)) > 1e-12)
+ assert_allclose(data_1, data_1_new, **tols)
+ # change back
+ raw_3_new = raw_1.copy().apply_gradient_compensation(3)
+ data_3_new, times_new = raw_3_new[:, :]
+ assert_allclose(data_3, data_3_new, **tols)
+ raw_3_new = raw_1.copy().load_data().apply_gradient_compensation(3)
+ data_3_new, times_new = raw_3_new[:, :]
+ assert_allclose(data_3, data_3_new, **tols)
+
+ for load in (False, True):
+ for raw in (raw_0, raw_1):
+ raw_3_new = raw.copy()
+ if load:
+ raw_3_new.load_data()
+ raw_3_new.apply_gradient_compensation(3)
+ assert_equal(raw_3_new.compensation_grade, 3)
+ data_3_new, times_new = raw_3_new[:, :]
+ assert_array_equal(times, times_new)
+ assert_true(np.mean(np.abs(data_3_new - data_1)) > 1e-12)
+ assert_allclose(data_3, data_3_new, **tols)
# Try IO with compensation
temp_file = op.join(tempdir, 'raw.fif')
-
- raw1.save(temp_file, overwrite=True)
- raw4 = Raw(temp_file)
- data4, times4 = raw4[:, :]
- assert_array_equal(times1, times4)
- assert_array_equal(data1, data4)
+ raw_3.save(temp_file, overwrite=True)
+ for preload in (True, False):
+ raw_read = read_raw_fif(temp_file, preload=preload, add_eeg_ref=False)
+ assert_equal(raw_read.compensation_grade, 3)
+ data_read, times_new = raw_read[:, :]
+ assert_array_equal(times, times_new)
+ assert_allclose(data_3, data_read, **tols)
+ raw_read.apply_gradient_compensation(1)
+ data_read, times_new = raw_read[:, :]
+ assert_array_equal(times, times_new)
+ assert_allclose(data_1, data_read, **tols)
# Now save the file that has modified compensation
- # and make sure we can the same data as input ie. compensation
- # is undone
- raw3.save(temp_file, overwrite=True)
- raw5 = Raw(temp_file)
- data5, times5 = raw5[:, :]
- assert_array_equal(times1, times5)
- assert_allclose(data1, data5, rtol=1e-12, atol=1e-22)
+ # and make sure the compensation is the same as it was,
+ # but that we can undo it
+
+ # These channels have norm 1e-11/1e-12, so atol=1e-18 isn't awesome,
+ # but it's due to the single precision of the info['comps'] leading
+ # to inexact inversions with saving/loading (casting back to single)
+ # in between (e.g., 1->3->1 will degrade like this)
+ looser_tols = dict(rtol=1e-6, atol=1e-18)
+ raw_1.save(temp_file, overwrite=True)
+ for preload in (True, False):
+ raw_read = read_raw_fif(temp_file, preload=preload, verbose=True,
+ add_eeg_ref=False)
+ assert_equal(raw_read.compensation_grade, 1)
+ data_read, times_new = raw_read[:, :]
+ assert_array_equal(times, times_new)
+ assert_allclose(data_1, data_read, **looser_tols)
+ raw_read.apply_gradient_compensation(3, verbose=True)
+ data_read, times_new = raw_read[:, :]
+ assert_array_equal(times, times_new)
+ assert_allclose(data_3, data_read, **looser_tols)
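For orientation, a sketch of the gradient-compensation workflow the block above exercises (the file path is a placeholder for any CTF recording saved with grade-3 compensation):

    from mne.io import read_raw_fif

    raw = read_raw_fif('ctf_comp_raw.fif', preload=True, add_eeg_ref=False)
    print(raw.compensation_grade)       # 3 for data stored at grade 3
    raw.apply_gradient_compensation(1)  # recompute the data at grade 1
    raw.apply_gradient_compensation(3)  # round-trip back to grade 3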
@requires_mne
def test_compensation_raw_mne():
- """Test Raw compensation by comparing with MNE
- """
+ """Test Raw compensation by comparing with MNE-C."""
tempdir = _TempDir()
def compensate_mne(fname, grad):
@@ -1239,19 +1384,27 @@ def test_compensation_raw_mne():
cmd = ['mne_process_raw', '--raw', fname, '--save', tmp_fname,
'--grad', str(grad), '--projoff', '--filteroff']
run_subprocess(cmd)
- return Raw(tmp_fname, preload=True)
+ return read_raw_fif(tmp_fname, preload=True, add_eeg_ref=False)
for grad in [0, 2, 3]:
- raw_py = Raw(ctf_comp_fname, preload=True, compensation=grad)
+ with warnings.catch_warnings(record=True): # deprecated param
+ raw_py = read_raw_fif(ctf_comp_fname, preload=True,
+ compensation=grad, add_eeg_ref=False)
raw_c = compensate_mne(ctf_comp_fname, grad)
assert_allclose(raw_py._data, raw_c._data, rtol=1e-6, atol=1e-17)
+ assert_equal(raw_py.info['nchan'], raw_c.info['nchan'])
+ for ch_py, ch_c in zip(raw_py.info['chs'], raw_c.info['chs']):
+ for key in ('ch_name', 'coil_type', 'scanno', 'logno', 'unit',
+ 'coord_frame', 'kind'):
+ assert_equal(ch_py[key], ch_c[key])
+ for key in ('loc', 'unit_mul', 'range', 'cal'):
+ assert_allclose(ch_py[key], ch_c[key])
@testing.requires_testing_data
def test_drop_channels_mixin():
- """Test channels-dropping functionality
- """
- raw = Raw(fif_fname, preload=True)
+ """Test channels-dropping functionality."""
+ raw = read_raw_fif(fif_fname, preload=True, add_eeg_ref=False)
drop_ch = raw.ch_names[:3]
ch_names = raw.ch_names[3:]
@@ -1269,11 +1422,10 @@ def test_drop_channels_mixin():
@testing.requires_testing_data
def test_pick_channels_mixin():
- """Test channel-picking functionality
- """
+ """Test channel-picking functionality."""
# preload is True
- raw = Raw(fif_fname, preload=True)
+ raw = read_raw_fif(fif_fname, preload=True, add_eeg_ref=False)
ch_names = raw.ch_names[:3]
ch_names_orig = raw.ch_names
@@ -1288,16 +1440,15 @@ def test_pick_channels_mixin():
assert_equal(len(ch_names), raw._data.shape[0])
assert_raises(ValueError, raw.pick_channels, ch_names[0])
- raw = Raw(fif_fname, preload=False)
+ raw = read_raw_fif(fif_fname, preload=False, add_eeg_ref=False)
assert_raises(RuntimeError, raw.pick_channels, ch_names)
assert_raises(RuntimeError, raw.drop_channels, ch_names)
@testing.requires_testing_data
def test_equalize_channels():
- """Test equalization of channels
- """
- raw1 = Raw(fif_fname, preload=True)
+ """Test equalization of channels."""
+ raw1 = read_raw_fif(fif_fname, preload=True, add_eeg_ref=False)
raw2 = raw1.copy()
ch_names = raw1.ch_names[2:]
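The recurring edit throughout the test module above replaces the deprecated Raw(fname) constructor with read_raw_fif(fname), always passing add_eeg_ref=False to avoid the automatic average EEG reference and the deprecation warning around it. A sketch of the new-style call (the path is a placeholder):

    from mne.io import read_raw_fif

    raw = read_raw_fif('sample_raw.fif', preload=True, add_eeg_ref=False)
    raw.crop(0, 1, copy=False)  # crop in place, as the updated tests do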
diff --git a/mne/io/kit/constants.py b/mne/io/kit/constants.py
index 012f99a..168d28e 100644
--- a/mne/io/kit/constants.py
+++ b/mne/io/kit/constants.py
@@ -71,11 +71,12 @@ KIT_NY.HPFS = [0, 1, 3]
KIT_NY.LPFS = [10, 20, 50, 100, 200, 500, 1000, 2000]
-# Maryland-system channel information
-# Virtually the same as the NY-system except new ADC circa July 2014
+# University of Maryland - system channel information
+# Virtually the same as the NY-system except new ADC in July 2014
# 16-bit A-to-D converter, one bit for signed integer. range +/- 32768
-KIT_MD = Bunch(**KIT_NY)
-KIT_MD.DYNAMIC_RANGE = 2 ** 15
+KIT_UMD = KIT_NY
+KIT_UMD_2014 = Bunch(**KIT_UMD)
+KIT_UMD_2014.DYNAMIC_RANGE = 2 ** 15
# AD-system channel information
@@ -107,17 +108,41 @@ KIT_AD.LPFS = [10, 20, 50, 100, 200, 500, 1000, 10000]
# KIT recording system is encoded in the SQD file as integer:
-KIT_CONSTANTS = {32: KIT_NY, # NYU-NY, July 7, 2008 -
- 33: KIT_NY, # NYU-NY, January 24, 2009 -
- 34: KIT_NY, # NYU-NY, January 22, 2010 -
- # 440 NYU-AD, initial launch May 20, 2011 -
- 441: KIT_AD, # NYU-AD more channels July 11, 2012 -
- 442: KIT_AD, # NYU-AD move to NYUAD campus Nov 20, 2014 -
- 51: KIT_NY, # UMD
- 52: KIT_MD, # UMD update to 16 bit ADC, July 4, 2014 -
- 53: KIT_MD} # UMD December 4, 2014 -
-
-SYSNAMES = {33: 'NYU 160ch System since Jan24 2009',
- 34: 'NYU 160ch System since Jan24 2009',
- 441: "New York University Abu Dhabi",
- 442: "New York University Abu Dhabi"}
+KIT.SYSTEM_NYU_2008 = 32 # NYU-NY, July 7, 2008 -
+KIT.SYSTEM_NYU_2009 = 33 # NYU-NY, January 24, 2009 -
+KIT.SYSTEM_NYU_2010 = 34 # NYU-NY, January 22, 2010 -
+KIT.SYSTEM_NYUAD_2011 = 440 # NYU-AD initial launch May 20, 2011 -
+KIT.SYSTEM_NYUAD_2012 = 441 # NYU-AD more channels July 11, 2012 -
+KIT.SYSTEM_NYUAD_2014 = 442 # NYU-AD move to NYUAD campus Nov 20, 2014 -
+KIT.SYSTEM_UMD_2004 = 51 # UMD Marie Mount Hall, October 1, 2004 -
+KIT.SYSTEM_UMD_2014_07 = 52 # UMD update to 16 bit ADC, July 4, 2014 -
+KIT.SYSTEM_UMD_2014_12 = 53 # UMD December 4, 2014 -
+
+KIT_CONSTANTS = {KIT.SYSTEM_NYU_2008: KIT_NY,
+ KIT.SYSTEM_NYU_2009: KIT_NY,
+ KIT.SYSTEM_NYU_2010: KIT_NY,
+ KIT.SYSTEM_NYUAD_2011: KIT_AD,
+ KIT.SYSTEM_NYUAD_2012: KIT_AD,
+ KIT.SYSTEM_NYUAD_2014: KIT_AD,
+ KIT.SYSTEM_UMD_2004: KIT_UMD,
+ KIT.SYSTEM_UMD_2014_07: KIT_UMD_2014,
+ KIT.SYSTEM_UMD_2014_12: KIT_UMD_2014}
+
+KIT_LAYOUT = {KIT.SYSTEM_NYU_2008: 'KIT-157',
+ KIT.SYSTEM_NYU_2009: 'KIT-157',
+ KIT.SYSTEM_NYU_2010: 'KIT-157',
+ KIT.SYSTEM_NYUAD_2011: 'KIT-AD',
+ KIT.SYSTEM_NYUAD_2012: 'KIT-AD',
+ KIT.SYSTEM_NYUAD_2014: 'KIT-AD',
+ KIT.SYSTEM_UMD_2004: None,
+ KIT.SYSTEM_UMD_2014_07: None,
+ KIT.SYSTEM_UMD_2014_12: 'KIT-UMD-3'}
+
+# Names stored along with ID in SQD files
+SYSNAMES = {KIT.SYSTEM_NYU_2009: 'NYU 160ch System since Jan24 2009',
+ KIT.SYSTEM_NYU_2010: 'NYU 160ch System since Jan24 2009',
+ KIT.SYSTEM_NYUAD_2012: "New York University Abu Dhabi",
+ KIT.SYSTEM_NYUAD_2014: "New York University Abu Dhabi",
+ KIT.SYSTEM_UMD_2004: "University of Maryland",
+ KIT.SYSTEM_UMD_2014_07: "University of Maryland",
+ KIT.SYSTEM_UMD_2014_12: "University of Maryland"}
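With the magic integers replaced by named KIT.SYSTEM_* attributes, a system ID read from an SQD file now resolves through three parallel tables. An illustrative lookup (the ID and table names come from the hunk above; the surrounding code is a sketch, not from the patch):

    from mne.io.kit.constants import KIT, KIT_CONSTANTS, KIT_LAYOUT, SYSNAMES

    sysid = KIT.SYSTEM_UMD_2014_12         # == 53
    constants = KIT_CONSTANTS[sysid]       # -> KIT_UMD_2014 (16-bit ADC)
    layout = KIT_LAYOUT[sysid]             # -> 'KIT-UMD-3'
    name = SYSNAMES.get(sysid, 'unknown')  # -> 'University of Maryland'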
diff --git a/mne/io/kit/coreg.py b/mne/io/kit/coreg.py
index 48b56cd..9634b0f 100644
--- a/mne/io/kit/coreg.py
+++ b/mne/io/kit/coreg.py
@@ -43,7 +43,7 @@ def read_mrk(fname):
pts.append(np.fromfile(fid, dtype='d', count=3))
mrk_points = np.array(pts)
elif ext == '.txt':
- mrk_points = _read_dig_points(fname)
+ mrk_points = _read_dig_points(fname, unit='m')
elif ext == '.pickled':
with open(fname, 'rb') as fid:
food = pickle.load(fid)
@@ -53,9 +53,8 @@ def read_mrk(fname):
err = ("%r does not contain marker points." % fname)
raise ValueError(err)
else:
- err = ('KIT marker file must be *.sqd, *.txt or *.pickled, '
- 'not *%s.' % ext)
- raise ValueError(err)
+ raise ValueError('KIT marker file must be *.sqd, *.mrk, *.txt or '
+ '*.pickled, *%s is not supported.' % ext)
# check output
mrk_points = np.asarray(mrk_points)
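One behavioral change worth noting in read_mrk: .txt marker files are now read explicitly as meters (unit='m'). A hedged usage sketch, assuming the usual import location (the filename is a placeholder):

    from mne.io.kit import read_mrk

    mrk_points = read_mrk('markers.txt')  # coordinates in meters,
                                          # typically the 5 KIT marker coils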
diff --git a/mne/io/kit/kit.py b/mne/io/kit/kit.py
index 62edec9..8e83e39 100644
--- a/mne/io/kit/kit.py
+++ b/mne/io/kit/kit.py
@@ -18,7 +18,7 @@ from scipy import linalg
from ..pick import pick_types
from ...coreg import fit_matched_points, _decimate_points
from ...utils import verbose, logger, warn
-from ...transforms import (apply_trans, als_ras_trans, als_ras_trans_mm,
+from ...transforms import (apply_trans, als_ras_trans,
get_ras_to_neuromag_trans, Transform)
from ..base import _BaseRaw
from ..utils import _mult_cal_one
@@ -179,8 +179,7 @@ class RawKIT(_BaseRaw):
if stim is not None:
if isinstance(stim, str):
- picks = pick_types(info, meg=False, ref_meg=False,
- misc=True, exclude=[])[:8]
+ picks = _default_stim_chs(info)
if stim == '<':
stim = picks[::-1]
elif stim == '>':
@@ -240,6 +239,7 @@ class RawKIT(_BaseRaw):
data_offset = unpack('i', fid.read(KIT.INT))[0]
pointer = start * nchan * KIT.SHORT
fid.seek(data_offset + pointer)
+ stim = self._raw_extras[fi]['stim']
for blk_start in np.arange(0, data_left, blk_size) // nchan:
blk_size = min(blk_size, data_left - blk_start * nchan)
block = np.fromfile(fid, dtype='h', count=blk_size)
@@ -248,30 +248,44 @@ class RawKIT(_BaseRaw):
data_view = data[:, blk_start:blk_stop]
block *= conv_factor[:, np.newaxis]
- # Create a synthetic channel
- if self._raw_extras[fi]['stim'] is not None:
- trig_chs = block[self._raw_extras[fi]['stim'], :]
- if self._raw_extras[fi]['slope'] == '+':
- trig_chs = trig_chs > self._raw_extras[0]['stimthresh']
- elif self._raw_extras[fi]['slope'] == '-':
- trig_chs = trig_chs < self._raw_extras[0]['stimthresh']
- else:
- raise ValueError("slope needs to be '+' or '-'")
- # trigger value
- if self._raw_extras[0]['stim_code'] == 'binary':
- ntrigchan = len(self._raw_extras[0]['stim'])
- trig_vals = np.array(2 ** np.arange(ntrigchan),
- ndmin=2).T
- else:
- trig_vals = np.reshape(self._raw_extras[0]['stim'],
- (-1, 1))
- trig_chs = trig_chs * trig_vals
- stim_ch = np.array(trig_chs.sum(axis=0), ndmin=2)
+ # Create a synthetic stim channel
+ if stim is not None:
+ params = self._raw_extras[fi]
+ stim_ch = _make_stim_channel(block[stim, :],
+ params['slope'],
+ params['stimthresh'],
+ params['stim_code'], stim)
block = np.vstack((block, stim_ch))
+
_mult_cal_one(data_view, block, idx, None, mult)
# cals are all unity, so can be ignored
+def _default_stim_chs(info):
+ """Default stim channels for SQD files"""
+ return pick_types(info, meg=False, ref_meg=False, misc=True,
+ exclude=[])[:8]
+
+
+def _make_stim_channel(trigger_chs, slope, threshold, stim_code,
+ trigger_values):
+ """Create synthetic stim channel from multiple trigger channels"""
+ if slope == '+':
+ trig_chs_bin = trigger_chs > threshold
+ elif slope == '-':
+ trig_chs_bin = trigger_chs < threshold
+ else:
+ raise ValueError("slope needs to be '+' or '-'")
+ # trigger value
+ if stim_code == 'binary':
+ trigger_values = 2 ** np.arange(len(trigger_chs))
+ elif stim_code != 'channel':
+ raise ValueError("stim_code must be 'binary' or 'channel', got %s" %
+ repr(stim_code))
+ trig_chs = trig_chs_bin * trigger_values[:, np.newaxis]
+ return np.array(trig_chs.sum(axis=0), ndmin=2)
+
+
class EpochsKIT(_BaseEpochs):
"""Epochs Array object from KIT SQD file
@@ -504,8 +518,8 @@ def _set_dig_kit(mrk, elp, hsp):
if isinstance(mrk, string_types):
mrk = read_mrk(mrk)
- hsp = apply_trans(als_ras_trans_mm, hsp)
- elp = apply_trans(als_ras_trans_mm, elp)
+ hsp = apply_trans(als_ras_trans, hsp)
+ elp = apply_trans(als_ras_trans, elp)
mrk = apply_trans(als_ras_trans, mrk)
nasion, lpa, rpa = elp[:3]
@@ -558,6 +572,7 @@ def get_kit_info(rawfile):
"contact the MNE-Python developers."
% (sysname, sysid))
KIT_SYS = KIT_CONSTANTS[sysid]
+ logger.info("KIT-System ID %i: %s" % (sysid, sysname))
if sysid in SYSNAMES:
if sysname != SYSNAMES[sysid]:
warn("KIT file %s has system-name %r, expected %r"
@@ -659,7 +674,7 @@ def get_kit_info(rawfile):
info = _empty_info(float(sqd['sfreq']))
info.update(meas_date=int(time.time()), lowpass=sqd['lowpass'],
highpass=sqd['highpass'], filename=rawfile,
- buffer_size_sec=1.)
+ buffer_size_sec=1., kit_system_id=sysid)
# Creates a list of dicts of meg channels for raw.info
logger.info('Setting channel info structure...')
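The refactored _make_stim_channel isolates the trigger arithmetic: binarize each trigger channel against the threshold, weight by powers of two under 'binary' coding (or by the given channel numbers under 'channel' coding), and sum into one synthetic channel. A standalone sketch of that arithmetic on toy data (not from the patch):

    import numpy as np

    trigger_chs = np.array([[0., 5., 5.],   # trigger line 0 -> bit 0
                            [0., 0., 5.],   # trigger line 1 -> bit 1
                            [5., 0., 0.]])  # trigger line 2 -> bit 2
    binarized = trigger_chs > 1.0           # slope '+': high means active
    weights = 2 ** np.arange(len(trigger_chs))
    stim_ch = (binarized * weights[:, np.newaxis]).sum(axis=0)
    print(stim_ch)  # [4 1 3]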
diff --git a/mne/io/kit/tests/data/test.elp b/mne/io/kit/tests/data/test.elp
new file mode 100644
index 0000000..9e76e07
--- /dev/null
+++ b/mne/io/kit/tests/data/test.elp
@@ -0,0 +1,37 @@
+3 2
+//Probe file
+//Minor revision number
+1
+//ProbeName
+%N Name
+//Probe type, number of sensors
+0 5
+//Position of fiducials X+, Y+, Y- on the subject
+%F 0.11056 -5.421e-19 0
+%F -0.00021075 0.080793 -7.5894e-19
+%F 0.00021075 -0.080793 -2.8731e-18
+//Sensor type
+%S 4000
+//Sensor name and data for sensor # 1
+%N 0-RED
+0.0050132 0.077834 0.00010455
+//Sensor type
+%S 4000
+//Sensor name and data for sensor # 2
+%N 1-YELLOW
+0.010353 -0.076396 -0.0045289
+//Sensor type
+%S 4000
+//Sensor name and data for sensor # 3
+%N 2-BLUE
+0.11786 0.0049369 0.025876
+//Sensor type
+%S 4000
+//Sensor name and data for sensor # 4
+%N 3-WHITE
+0.1004 0.04654 0.024553
+//Sensor type
+%S 4000
+//Sensor name and data for sensor # 5
+%N 4-BLACK
+0.10746 -0.034116 0.031846
diff --git a/mne/io/kit/tests/data/test.hsp b/mne/io/kit/tests/data/test.hsp
new file mode 100644
index 0000000..a67ee50
--- /dev/null
+++ b/mne/io/kit/tests/data/test.hsp
@@ -0,0 +1,514 @@
+3 200
+//Shape file
+//Minor revision number
+2
+//Subject Name
+%N Name
+////Shape code, number of digitized points
+0 500
+//Position of fiducials X+, Y+, Y- on the subject
+%F 0.11056 -5.421e-19 0
+%F -0.00021075 0.080793 -7.5894e-19
+%F 0.00021075 -0.080793 -2.8731e-18
+//No of rows, no of columns; position of digitized points
+500 3
+-0.009834 -0.095567 0.031855
+-0.008069 -0.095958 0.032424
+-0.006919 -0.096273 0.031884
+-0.007449 -0.095409 0.033315
+-0.006068 -0.095745 0.032652
+-0.006486 -0.094639 0.034006
+-0.005279 -0.095372 0.033638
+-0.004021 -0.095359 0.032883
+-0.003826 -0.094649 0.034051
+-0.002588 -0.095163 0.033564
+-0.020912 -0.078901 0.062986
+-0.020881 -0.078405 0.063868
+-0.019912 -0.078852 0.063517
+-0.022529 -0.077117 0.065358
+-0.021557 -0.077512 0.064988
+-0.020648 -0.077893 0.064646
+-0.019743 -0.078500 0.064407
+-0.018814 -0.078601 0.063927
+-0.022660 -0.076428 0.066386
+-0.021752 -0.076993 0.066128
+-0.020854 -0.077467 0.065832
+-0.019939 -0.077899 0.065515
+-0.018818 -0.078569 0.065188
+-0.017899 -0.078722 0.064721
+-0.022362 -0.076047 0.067450
+-0.021125 -0.076871 0.067125
+-0.019944 -0.077598 0.066789
+-0.018579 -0.078204 0.066292
+-0.017217 -0.078790 0.065782
+-0.016124 -0.078197 0.064872
+-0.014800 -0.077737 0.063904
+-0.013284 -0.077979 0.063140
+-0.011681 -0.078419 0.062414
+-0.010418 -0.078829 0.061879
+-0.008465 -0.079392 0.061014
+-0.004792 -0.080425 0.059360
+-0.001462 -0.081407 0.057884
+0.008183 -0.084293 0.053657
+-0.028415 -0.072888 0.071012
+-0.026642 -0.073448 0.070271
+-0.025230 -0.073988 0.069716
+-0.024019 -0.074352 0.069195
+-0.022778 -0.075072 0.068818
+-0.021608 -0.076160 0.068658
+-0.020223 -0.076445 0.067990
+-0.018898 -0.077725 0.067833
+-0.017534 -0.078322 0.067337
+-0.016193 -0.079095 0.066933
+-0.014952 -0.078086 0.065744
+-0.013742 -0.077512 0.064772
+-0.012372 -0.077984 0.064203
+-0.010990 -0.078363 0.063587
+-0.009611 -0.078731 0.062958
+-0.008229 -0.079110 0.062342
+-0.006843 -0.079467 0.061705
+-0.005461 -0.079793 0.061056
+-0.004093 -0.080183 0.060441
+-0.002707 -0.080550 0.059804
+-0.001079 -0.080896 0.059031
+0.000495 -0.081585 0.058438
+0.002538 -0.082366 0.057624
+0.005302 -0.083072 0.056359
+0.007411 -0.083445 0.055317
+0.008359 -0.083534 0.054815
+-0.033495 -0.070560 0.074667
+-0.031961 -0.071110 0.074042
+-0.030579 -0.071756 0.073555
+-0.029200 -0.072310 0.073024
+-0.027816 -0.072771 0.072453
+-0.026445 -0.073253 0.071883
+-0.025058 -0.073735 0.071326
+-0.023689 -0.074249 0.070790
+-0.022314 -0.074691 0.070214
+-0.020977 -0.075835 0.069978
+-0.019587 -0.076089 0.069304
+-0.018256 -0.077162 0.069043
+-0.016912 -0.077976 0.068651
+-0.015567 -0.078781 0.068260
+-0.014187 -0.079015 0.067580
+-0.012684 -0.077664 0.066084
+-0.011296 -0.077897 0.065396
+-0.009918 -0.078317 0.064800
+-0.008538 -0.078675 0.064171
+-0.007162 -0.079013 0.063530
+-0.005788 -0.079391 0.062908
+-0.004406 -0.079780 0.062291
+-0.003020 -0.080137 0.061655
+-0.001642 -0.080423 0.060994
+-0.000259 -0.080749 0.060345
+0.001105 -0.081232 0.059783
+0.002447 -0.081892 0.059313
+0.003795 -0.082479 0.058805
+0.005229 -0.082441 0.057962
+0.006677 -0.082412 0.057118
+0.007871 -0.082406 0.056425
+-0.035214 -0.069369 0.077141
+-0.033374 -0.070139 0.076447
+-0.031531 -0.070948 0.075765
+-0.029694 -0.071677 0.075059
[401 added lines of whitespace-separated x/y/z digitizer coordinates elided: new KIT head-shape test data introduced by this commit]
diff --git a/mne/io/kit/tests/data/test_umd-raw.sqd b/mne/io/kit/tests/data/test_umd-raw.sqd
new file mode 100644
index 0000000..2d85ee8
Binary files /dev/null and b/mne/io/kit/tests/data/test_umd-raw.sqd differ
diff --git a/mne/io/kit/tests/test_coreg.py b/mne/io/kit/tests/test_coreg.py
index f117d99..7521777 100644
--- a/mne/io/kit/tests/test_coreg.py
+++ b/mne/io/kit/tests/test_coreg.py
@@ -4,7 +4,9 @@
import inspect
import os
+from ....externals.six.moves import cPickle as pickle
+from nose.tools import assert_raises
from numpy.testing import assert_array_equal
from mne.io.kit import read_mrk
@@ -28,3 +30,16 @@ def test_io_mrk():
_write_dig_points(path, pts)
pts_2 = read_mrk(path)
assert_array_equal(pts, pts_2, "read/write mrk to text")
+
+ # pickle
+ fname = os.path.join(tempdir, 'mrk.pickled')
+ with open(fname, 'wb') as fid:
+ pickle.dump(dict(mrk=pts), fid)
+ pts_2 = read_mrk(fname)
+ assert_array_equal(pts_2, pts, "pickle mrk")
+ with open(fname, 'wb') as fid:
+ pickle.dump(dict(), fid)
+ assert_raises(ValueError, read_mrk, fname)
+
+ # unsupported extension
+ assert_raises(ValueError, read_mrk, "file.ext")
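The new test above exercises pickle support in read_mrk: a pickled dict must carry its points under the 'mrk' key, and unsupported extensions are rejected. A minimal sketch, assuming hypothetical marker points and a temporary file:

    import os
    import pickle
    import tempfile
    import numpy as np
    from mne.io.kit import read_mrk

    pts = np.zeros((5, 3))  # five hypothetical HPI marker points
    fname = os.path.join(tempfile.mkdtemp(), 'mrk.pickled')
    with open(fname, 'wb') as fid:
        pickle.dump(dict(mrk=pts), fid)  # the 'mrk' key is required
    pts_read = read_mrk(fname)  # raises ValueError if the key is missing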
diff --git a/mne/io/kit/tests/test_kit.py b/mne/io/kit/tests/test_kit.py
index 2e171a5..7fa7e6e 100644
--- a/mne/io/kit/tests/test_kit.py
+++ b/mne/io/kit/tests/test_kit.py
@@ -9,51 +9,58 @@ import os.path as op
import inspect
import numpy as np
from numpy.testing import assert_array_almost_equal, assert_array_equal
-from nose.tools import assert_raises, assert_true
+from nose.tools import assert_equal, assert_raises, assert_true
+from scipy import linalg
import scipy.io
from mne import pick_types, Epochs, find_events, read_events
+from mne.transforms import apply_trans
from mne.tests.common import assert_dig_allclose
from mne.utils import run_tests_if_main
-from mne.io import Raw, read_raw_kit, read_epochs_kit
+from mne.io import read_raw_fif, read_raw_kit, read_epochs_kit
from mne.io.kit.coreg import read_sns
+from mne.io.kit.constants import KIT, KIT_CONSTANTS, KIT_NY, KIT_UMD_2014
from mne.io.tests.test_raw import _test_raw_reader
FILE = inspect.getfile(inspect.currentframe())
parent_dir = op.dirname(op.abspath(FILE))
data_dir = op.join(parent_dir, 'data')
sqd_path = op.join(data_dir, 'test.sqd')
+sqd_umd_path = op.join(data_dir, 'test_umd-raw.sqd')
epochs_path = op.join(data_dir, 'test-epoch.raw')
events_path = op.join(data_dir, 'test-eve.txt')
mrk_path = op.join(data_dir, 'test_mrk.sqd')
mrk2_path = op.join(data_dir, 'test_mrk_pre.sqd')
mrk3_path = op.join(data_dir, 'test_mrk_post.sqd')
-elp_path = op.join(data_dir, 'test_elp.txt')
-hsp_path = op.join(data_dir, 'test_hsp.txt')
+elp_txt_path = op.join(data_dir, 'test_elp.txt')
+hsp_txt_path = op.join(data_dir, 'test_hsp.txt')
+elp_path = op.join(data_dir, 'test.elp')
+hsp_path = op.join(data_dir, 'test.hsp')
def test_data():
- """Test reading raw kit files
- """
+ """Test reading raw kit files."""
assert_raises(TypeError, read_raw_kit, epochs_path)
assert_raises(TypeError, read_epochs_kit, sqd_path)
- assert_raises(ValueError, read_raw_kit, sqd_path, mrk_path, elp_path)
+ assert_raises(ValueError, read_raw_kit, sqd_path, mrk_path, elp_txt_path)
assert_raises(ValueError, read_raw_kit, sqd_path, None, None, None,
list(range(200, 190, -1)))
assert_raises(ValueError, read_raw_kit, sqd_path, None, None, None,
list(range(167, 159, -1)), '*', 1, True)
# check functionality
- raw_mrk = read_raw_kit(sqd_path, [mrk2_path, mrk3_path], elp_path,
- hsp_path)
- raw_py = _test_raw_reader(read_raw_kit,
- input_fname=sqd_path, mrk=mrk_path, elp=elp_path,
- hsp=hsp_path, stim=list(range(167, 159, -1)),
- slope='+', stimthresh=1)
+ raw_mrk = read_raw_kit(sqd_path, [mrk2_path, mrk3_path], elp_txt_path,
+ hsp_txt_path)
+ raw_py = _test_raw_reader(read_raw_kit, input_fname=sqd_path, mrk=mrk_path,
+ elp=elp_txt_path, hsp=hsp_txt_path,
+ stim=list(range(167, 159, -1)), slope='+',
+ stimthresh=1)
assert_true('RawKIT' in repr(raw_py))
+ assert_equal(raw_mrk.info['kit_system_id'], KIT.SYSTEM_NYU_2010)
+ assert_true(KIT_CONSTANTS[raw_mrk.info['kit_system_id']] is KIT_NY)
# Test stim channel
- raw_stim = read_raw_kit(sqd_path, mrk_path, elp_path, hsp_path, stim='<',
- preload=False)
+ raw_stim = read_raw_kit(sqd_path, mrk_path, elp_txt_path, hsp_txt_path,
+ stim='<', preload=False)
for raw in [raw_py, raw_stim, raw_mrk]:
stim_pick = pick_types(raw.info, meg=False, ref_meg=False,
stim=True, exclude='bads')
@@ -64,7 +71,7 @@ def test_data():
# Binary file only stores the sensor channels
py_picks = pick_types(raw_py.info, exclude='bads')
raw_bin = op.join(data_dir, 'test_bin_raw.fif')
- raw_bin = Raw(raw_bin, preload=True)
+ raw_bin = read_raw_fif(raw_bin, preload=True, add_eeg_ref=False)
bin_picks = pick_types(raw_bin.info, stim=True, exclude='bads')
data_bin, _ = raw_bin[bin_picks]
data_py, _ = raw_py[py_picks]
@@ -81,11 +88,19 @@ def test_data():
data_py, _ = raw_py[py_picks]
assert_array_almost_equal(data_py, data_bin)
+ # KIT-UMD data
+ _test_raw_reader(read_raw_kit, input_fname=sqd_umd_path)
+ raw = read_raw_kit(sqd_umd_path)
+ assert_equal(raw.info['kit_system_id'], KIT.SYSTEM_UMD_2014_12)
+ assert_true(KIT_CONSTANTS[raw.info['kit_system_id']] is KIT_UMD_2014)
+
def test_epochs():
+ """Test reading epoched SQD file."""
raw = read_raw_kit(sqd_path, stim=None)
events = read_events(events_path)
- raw_epochs = Epochs(raw, events, None, tmin=0, tmax=.099, baseline=None)
+ raw_epochs = Epochs(raw, events, None, tmin=0, tmax=.099, baseline=None,
+ add_eeg_ref=False)
data1 = raw_epochs.get_data()
epochs = read_epochs_kit(epochs_path, events_path)
data11 = epochs.get_data()
@@ -93,6 +108,7 @@ def test_epochs():
def test_raw_events():
+ """Test creating stim channel from raw SQD file."""
def evts(a, b, c, d, e, f=None):
out = [[269, a, b], [281, b, c], [1552, c, d], [1564, d, e]]
if f is not None:
@@ -117,10 +133,11 @@ def test_raw_events():
def test_ch_loc():
- """Test raw kit loc
- """
- raw_py = read_raw_kit(sqd_path, mrk_path, elp_path, hsp_path, stim='<')
- raw_bin = Raw(op.join(data_dir, 'test_bin_raw.fif'))
+ """Test raw kit loc."""
+ raw_py = read_raw_kit(sqd_path, mrk_path, elp_txt_path, hsp_txt_path,
+ stim='<')
+ raw_bin = read_raw_fif(op.join(data_dir, 'test_bin_raw.fif'),
+ add_eeg_ref=False)
ch_py = raw_py._raw_extras[0]['sensor_locs'][:, :5]
# ch locs stored as m, not mm
@@ -137,10 +154,32 @@ def test_ch_loc():
# test when more than one marker file provided
mrks = [mrk_path, mrk2_path, mrk3_path]
- read_raw_kit(sqd_path, mrks, elp_path, hsp_path, preload=False)
+ read_raw_kit(sqd_path, mrks, elp_txt_path, hsp_txt_path, preload=False)
# this dataset does not have the equivalent set of points :(
raw_bin.info['dig'] = raw_bin.info['dig'][:8]
raw_py.info['dig'] = raw_py.info['dig'][:8]
assert_dig_allclose(raw_py.info, raw_bin.info)
+
+def test_hsp_elp():
+ """Test KIT usage of *.elp and *.hsp files against *.txt files."""
+ raw_txt = read_raw_kit(sqd_path, mrk_path, elp_txt_path, hsp_txt_path)
+ raw_elp = read_raw_kit(sqd_path, mrk_path, elp_path, hsp_path)
+
+ # head points
+ pts_txt = np.array([dig_point['r'] for dig_point in raw_txt.info['dig']])
+ pts_elp = np.array([dig_point['r'] for dig_point in raw_elp.info['dig']])
+ assert_array_almost_equal(pts_elp, pts_txt, decimal=5)
+
+ # transforms
+ trans_txt = raw_txt.info['dev_head_t']['trans']
+ trans_elp = raw_elp.info['dev_head_t']['trans']
+ assert_array_almost_equal(trans_elp, trans_txt, decimal=5)
+
+ # head points in device space
+ pts_txt_in_dev = apply_trans(linalg.inv(trans_txt), pts_txt)
+ pts_elp_in_dev = apply_trans(linalg.inv(trans_elp), pts_elp)
+ assert_array_almost_equal(pts_elp_in_dev, pts_txt_in_dev, decimal=5)
+
+
run_tests_if_main()
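test_hsp_elp above checks that native Polhemus *.elp/*.hsp exports yield the same digitization points and dev_head_t transform as the older *.txt exports. A hedged usage sketch (all file names hypothetical):

    from mne.io import read_raw_kit

    # either export format should now give equivalent coregistration
    raw_txt = read_raw_kit('data.sqd', mrk='mrk.sqd',
                           elp='elp_points.txt', hsp='hsp_points.txt')
    raw_elp = read_raw_kit('data.sqd', mrk='mrk.sqd',
                           elp='points.elp', hsp='points.hsp')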
diff --git a/mne/io/meas_info.py b/mne/io/meas_info.py
index 4ecd461..f749640 100644
--- a/mne/io/meas_info.py
+++ b/mne/io/meas_info.py
@@ -4,9 +4,11 @@
#
# License: BSD (3-clause)
+from collections import Counter
from copy import deepcopy
from datetime import datetime as dt
import os.path as op
+import re
import numpy as np
from scipy import linalg
@@ -24,7 +26,6 @@ from .write import (start_file, end_file, start_block, end_block,
write_julian, write_float_matrix)
from .proc_history import _read_proc_history, _write_proc_history
from ..utils import logger, verbose, warn
-from ..fixes import Counter
from .. import __version__
from ..externals.six import b, BytesIO, string_types, text_type
@@ -42,6 +43,8 @@ _kind_dict = dict(
seeg=(FIFF.FIFFV_SEEG_CH, FIFF.FIFFV_COIL_EEG, FIFF.FIFF_UNIT_V),
bio=(FIFF.FIFFV_BIO_CH, FIFF.FIFFV_COIL_NONE, FIFF.FIFF_UNIT_V),
ecog=(FIFF.FIFFV_ECOG_CH, FIFF.FIFFV_COIL_EEG, FIFF.FIFF_UNIT_V),
+ hbo=(FIFF.FIFFV_FNIRS_CH, FIFF.FIFFV_COIL_FNIRS_HBO, FIFF.FIFF_UNIT_MOL),
+ hbr=(FIFF.FIFFV_FNIRS_CH, FIFF.FIFFV_COIL_FNIRS_HBR, FIFF.FIFF_UNIT_MOL)
)
@@ -202,6 +205,9 @@ class Info(dict):
elif k == 'meas_date' and np.iterable(v):
# first entry in meas_date is meaningful
entr = dt.fromtimestamp(v[0]).strftime('%Y-%m-%d %H:%M:%S')
+ elif k == 'kit_system_id' and v is not None:
+ from .kit.constants import SYSNAMES as KIT_SYSNAMES
+ entr = '%i (%s)' % (v, KIT_SYSNAMES.get(v, 'unknown'))
else:
this_len = (len(v) if hasattr(v, '__len__') else
('%s' % v if v is not None else None))
@@ -226,14 +232,6 @@ class Info(dict):
st %= non_empty
return st
- def _anonymize(self):
- if self.get('subject_info') is not None:
- del self['subject_info']
- self['meas_date'] = [0, 0]
- for key_1 in ('file_id', 'meas_id'):
- for key_2 in ('secs', 'msecs', 'usecs'):
- self[key_1][key_2] = 0
-
def _check_consistency(self):
"""Do some self-consistency checks and datatype tweaks"""
missing = [bad for bad in self['bads'] if bad not in self['ch_names']]
@@ -363,10 +361,11 @@ def _read_dig_fif(fid, meas_info):
return dig
-def _read_dig_points(fname, comments='%'):
+def _read_dig_points(fname, comments='%', unit='auto'):
"""Read digitizer data from a text file.
- This function can read space-delimited text files of digitizer data.
+    If fname ends in .hsp or .elp, the function assumes digitizer files in
+    [m]; otherwise it assumes space-delimited text files in [mm].
Parameters
----------
@@ -375,17 +374,45 @@ def _read_dig_points(fname, comments='%'):
comments : str
The character used to indicate the start of a comment;
Default: '%'.
+ unit : 'auto' | 'm' | 'cm' | 'mm'
+ Unit of the digitizer files (hsp and elp). If not 'm', coordinates will
+ be rescaled to 'm'. Default is 'auto', which assumes 'm' for *.hsp and
+ *.elp files and 'mm' for *.txt files, corresponding to the known
+ Polhemus export formats.
Returns
-------
dig_points : np.ndarray, shape (n_points, 3)
- Array of dig points.
+ Array of dig points in [m].
"""
- dig_points = np.loadtxt(fname, comments=comments, ndmin=2)
+ if unit not in ('auto', 'm', 'mm', 'cm'):
+ raise ValueError('unit must be one of "auto", "m", "mm", or "cm"')
+
+ _, ext = op.splitext(fname)
+ if ext == '.elp' or ext == '.hsp':
+ with open(fname) as fid:
+ file_str = fid.read()
+        value_pattern = r"\-?\d+\.?\d*e?\-?\d*"
+        coord_pattern = r"({0})\s+({0})\s+({0})\s*$".format(value_pattern)
+ if ext == '.hsp':
+ coord_pattern = '^' + coord_pattern
+ points_str = [m.groups() for m in re.finditer(coord_pattern, file_str,
+ re.MULTILINE)]
+ dig_points = np.array(points_str, dtype=float)
+ else:
+ dig_points = np.loadtxt(fname, comments=comments, ndmin=2)
+ if unit == 'auto':
+ unit = 'mm'
+
if dig_points.shape[-1] != 3:
err = 'Data must be (n, 3) instead of %s' % (dig_points.shape,)
raise ValueError(err)
+ if unit == 'mm':
+ dig_points /= 1000.
+ elif unit == 'cm':
+ dig_points /= 100.
+
return dig_points
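With the unit handling above, 'auto' infers metres for native *.elp/*.hsp files and millimetres for *.txt exports, so callers rarely need to pass unit explicitly. A sketch of the helper (file names hypothetical; the function is private, so the import path may change):

    from mne.io.meas_info import _read_dig_points

    pts_hsp = _read_dig_points('headshape.hsp')   # regex-parsed, kept in m
    pts_txt = _read_dig_points('headshape.txt')   # loadtxt, rescaled mm -> m
    pts_cm = _read_dig_points('headshape.txt', unit='cm')  # explicit override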
@@ -614,6 +641,7 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None):
line_freq = None
custom_ref_applied = False
xplotter_layout = None
+ kit_system_id = None
for k in range(meas_info['nent']):
kind = meas_info['directory'][k].kind
pos = meas_info['directory'][k].pos
@@ -669,6 +697,9 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None):
elif kind == FIFF.FIFF_XPLOTTER_LAYOUT:
tag = read_tag(fid, pos)
xplotter_layout = str(tag.data)
+ elif kind == FIFF.FIFF_MNE_KIT_SYSTEM_ID:
+ tag = read_tag(fid, pos)
+ kit_system_id = int(tag.data)
# Check that we have everything we need
if nchan is None:
@@ -949,6 +980,7 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None):
info['acq_stim'] = acq_stim
info['custom_ref_applied'] = custom_ref_applied
info['xplotter_layout'] = xplotter_layout
+ info['kit_system_id'] = kit_system_id
info._check_consistency()
return info, meas
@@ -1163,6 +1195,10 @@ def write_meas_info(fid, info, data_type=None, reset_range=True):
# CTF compensation info
write_ctf_comp(fid, info['comps'])
+ # KIT system ID
+ if info.get('kit_system_id') is not None:
+ write_int(fid, FIFF.FIFF_MNE_KIT_SYSTEM_ID, info['kit_system_id'])
+
end_block(fid, FIFF.FIFFB_MEAS_INFO)
# Processing history
@@ -1204,7 +1240,8 @@ def _is_equal_dict(dicts):
is_equal.append((k0 == k) and _is_equal_dict(v))
else:
is_equal.append(all(np.all(k == k0) and
- np.all(v == v0) for k, v in d))
+ (np.array_equal(v, v0) if isinstance(v, np.ndarray)
+ else np.all(v == v0)) for k, v in d))
return all(is_equal)
@@ -1350,6 +1387,17 @@ def _merge_info(infos, force_update_to_first=False, verbose=None):
msg = ("Measurement infos provide mutually inconsistent %s" %
trans_name)
raise ValueError(msg)
+
+ # KIT system-IDs
+ kit_sys_ids = [i['kit_system_id'] for i in infos if i['kit_system_id']]
+ if len(kit_sys_ids) == 0:
+ info['kit_system_id'] = None
+ elif len(set(kit_sys_ids)) == 1:
+ info['kit_system_id'] = kit_sys_ids[0]
+ else:
+ raise ValueError("Trying to merge channels from different KIT systems")
+
+ # other fields
other_fields = ['acq_pars', 'acq_stim', 'bads', 'buffer_size_sec',
'comps', 'custom_ref_applied', 'description', 'dig',
'experimenter', 'file_id', 'filename', 'highpass',
@@ -1357,9 +1405,9 @@ def _merge_info(infos, force_update_to_first=False, verbose=None):
'line_freq', 'lowpass', 'meas_date', 'meas_id',
'proj_id', 'proj_name', 'projs', 'sfreq',
'subject_info', 'sfreq', 'xplotter_layout']
-
for k in other_fields:
info[k] = _merge_dict_values(infos, k)
+
info._check_consistency()
return info
@@ -1371,13 +1419,13 @@ def create_info(ch_names, sfreq, ch_types=None, montage=None):
----------
ch_names : list of str | int
Channel names. If an int, a list of channel names will be created
- from range(ch_names)
+ from :func:`range(ch_names) <range>`.
sfreq : float
Sample rate of the data.
ch_types : list of str | str
Channel types. If None, data are assumed to be misc.
Currently supported fields are 'ecg', 'bio', 'stim', 'eog', 'misc',
- 'seeg', 'ecog', 'mag', 'eeg', 'ref_meg' or 'grad'.
+ 'seeg', 'ecog', 'mag', 'eeg', 'ref_meg', 'grad', 'hbr' or 'hbo'.
If str, then all channels are assumed to be of the same type.
montage : None | str | Montage | DigMontage | list
A montage containing channel positions. If str or Montage is
@@ -1387,6 +1435,11 @@ def create_info(ch_names, sfreq, ch_types=None, montage=None):
can be specifed and applied to the info. See also the documentation of
:func:`mne.channels.read_montage` for more information.
+ Returns
+ -------
+ info : instance of Info
+ The measurement info.
+
Notes
-----
The info dictionary will be sparsely populated to enable functionality
@@ -1410,7 +1463,8 @@ def create_info(ch_names, sfreq, ch_types=None, montage=None):
if isinstance(ch_types, string_types):
ch_types = [ch_types] * nchan
if len(ch_types) != nchan:
- raise ValueError('ch_types and ch_names must be the same length')
+ raise ValueError('ch_types and ch_names must be the same length '
+ '(%s != %s)' % (len(ch_types), nchan))
info = _empty_info(sfreq)
info['meas_date'] = np.array([0, 0], np.int32)
loc = np.concatenate((np.zeros(3), np.eye(3).ravel())).astype(np.float32)
@@ -1453,9 +1507,9 @@ RAW_INFO_FIELDS = (
'comps', 'ctf_head_t', 'custom_ref_applied', 'description', 'dev_ctf_t',
'dev_head_t', 'dig', 'experimenter', 'events',
'file_id', 'filename', 'highpass', 'hpi_meas', 'hpi_results',
- 'hpi_subsystem', 'line_freq', 'lowpass', 'meas_date', 'meas_id', 'nchan',
- 'proj_id', 'proj_name', 'projs', 'sfreq', 'subject_info',
- 'xplotter_layout',
+ 'hpi_subsystem', 'kit_system_id', 'line_freq', 'lowpass', 'meas_date',
+ 'meas_id', 'nchan', 'proj_id', 'proj_name', 'projs', 'sfreq',
+ 'subject_info', 'xplotter_layout',
)
@@ -1465,8 +1519,8 @@ def _empty_info(sfreq):
_none_keys = (
'acq_pars', 'acq_stim', 'buffer_size_sec', 'ctf_head_t', 'description',
'dev_ctf_t', 'dig', 'experimenter',
- 'file_id', 'filename', 'highpass', 'hpi_subsystem', 'line_freq',
- 'lowpass', 'meas_date', 'meas_id', 'proj_id', 'proj_name',
+ 'file_id', 'filename', 'highpass', 'hpi_subsystem', 'kit_system_id',
+ 'line_freq', 'lowpass', 'meas_date', 'meas_id', 'proj_id', 'proj_name',
'subject_info', 'xplotter_layout',
)
_list_keys = ('bads', 'chs', 'comps', 'events', 'hpi_meas', 'hpi_results',
@@ -1513,3 +1567,34 @@ def _force_update_info(info_base, info_target):
continue
for i_targ in info_target:
i_targ[key] = val
+
+
+def anonymize_info(info):
+ """Anonymize measurement information in place.
+
+ Reset 'subject_info', 'meas_date', 'file_id', and 'meas_id' keys if they
+ exist in ``info``.
+
+ Parameters
+ ----------
+ info : dict, instance of Info
+ Measurement information for the dataset.
+
+ Returns
+ -------
+ info : instance of Info
+ Measurement information for the dataset.
+
+ Notes
+ -----
+ Operates in place.
+ """
+ if not isinstance(info, Info):
+        raise ValueError('info must be an instance of Info.')
+ if info.get('subject_info') is not None:
+ del info['subject_info']
+ info['meas_date'] = [0, 0]
+ for key_1 in ('file_id', 'meas_id'):
+ for key_2 in ('secs', 'msecs', 'usecs'):
+ info[key_1][key_2] = 0
+ return info
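anonymize_info replaces the private Info._anonymize method removed earlier in this file. A minimal sketch (the file name is hypothetical):

    from mne.io import read_raw_fif, anonymize_info

    raw = read_raw_fif('subject_raw.fif', add_eeg_ref=False)
    # in place: drops subject_info, zeroes meas_date and the
    # file_id/meas_id timestamps
    anonymize_info(raw.info)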
diff --git a/mne/io/nicolet/nicolet.py b/mne/io/nicolet/nicolet.py
index 05956bb..6a1e0f4 100644
--- a/mne/io/nicolet/nicolet.py
+++ b/mne/io/nicolet/nicolet.py
@@ -105,7 +105,7 @@ def _get_nicolet_info(fname, ch_type, eog, ecg, emg, misc):
info = _empty_info(header_info['sample_freq'])
info.update({'filename': fname,
'meas_date': calendar.timegm(date.utctimetuple()),
- 'description': None, 'buffer_size_sec': 10.})
+ 'description': None, 'buffer_size_sec': 1.})
if ch_type == 'eeg':
ch_coil = FIFF.FIFFV_COIL_EEG
diff --git a/mne/io/open.py b/mne/io/open.py
index c5f0dfd..7e81ef8 100644
--- a/mne/io/open.py
+++ b/mne/io/open.py
@@ -4,17 +4,17 @@
#
# License: BSD (3-clause)
-from ..externals.six import string_types
-import numpy as np
import os.path as op
from io import BytesIO
+from gzip import GzipFile
+
+import numpy as np
from .tag import read_tag_info, read_tag, read_big, Tag
from .tree import make_dir_tree, dir_tree_find
from .constants import FIFF
from ..utils import logger, verbose
-from ..externals import six
-from ..fixes import gzip_open
+from ..externals.six import string_types, iteritems
def _fiff_get_fid(fname):
@@ -22,7 +22,7 @@ def _fiff_get_fid(fname):
if isinstance(fname, string_types):
if op.splitext(fname)[1].lower() == '.gz':
logger.debug('Using gzip')
- fid = gzip_open(fname, "rb") # Open in binary mode
+ fid = GzipFile(fname, "rb") # Open in binary mode
else:
logger.debug('Using normal I/O')
fid = open(fname, "rb") # Open in binary mode
@@ -179,8 +179,11 @@ def show_fiff(fname, indent=' ', read_limit=np.inf, max_str=30,
if output not in [list, str]:
raise ValueError('output must be list or str')
f, tree, directory = fiff_open(fname)
+ # This gets set to 0 (unknown) by fiff_open, but FIFFB_ROOT probably
+ # makes more sense for display
+ tree['block'] = FIFF.FIFFB_ROOT
with f as fid:
- out = _show_tree(fid, tree['children'][0], indent=indent, level=0,
+ out = _show_tree(fid, tree, indent=indent, level=0,
read_limit=read_limit, max_str=max_str)
if output == str:
out = '\n'.join(out)
@@ -189,7 +192,8 @@ def show_fiff(fname, indent=' ', read_limit=np.inf, max_str=30,
def _find_type(value, fmts=['FIFF_'], exclude=['FIFF_UNIT']):
"""Helper to find matching values"""
- vals = [k for k, v in six.iteritems(FIFF)
+ value = int(value)
+ vals = [k for k, v in iteritems(FIFF)
if v == value and any(fmt in k for fmt in fmts) and
not any(exc in k for exc in exclude)]
if len(vals) == 0:
@@ -203,7 +207,7 @@ def _show_tree(fid, tree, indent, level, read_limit, max_str):
this_idt = indent * level
next_idt = indent * (level + 1)
# print block-level information
- out = [this_idt + str(tree['block'][0]) + ' = ' +
+ out = [this_idt + str(int(tree['block'])) + ' = ' +
'/'.join(_find_type(tree['block'], fmts=['FIFFB_']))]
if tree['directory'] is not None:
kinds = [ent.kind for ent in tree['directory']] + [-1]
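With tree['block'] forced to FIFFB_ROOT, show_fiff now renders the tree from the root node instead of its first child. A hedged example (file name hypothetical):

    from mne.io import show_fiff

    print(show_fiff('sample_raw.fif'))  # display now starts at FIFFB_ROOT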
diff --git a/mne/io/pick.py b/mne/io/pick.py
index 6b8d6d8..da4b47d 100644
--- a/mne/io/pick.py
+++ b/mne/io/pick.py
@@ -28,7 +28,7 @@ def channel_type(info, idx):
-------
type : 'grad' | 'mag' | 'eeg' | 'stim' | 'eog' | 'emg' | 'ecg'
'ref_meg' | 'resp' | 'exci' | 'ias' | 'syst' | 'misc'
- 'seeg' | 'bio' | 'chpi' | 'dipole' | 'gof' | 'ecog'
+ 'seeg' | 'bio' | 'chpi' | 'dipole' | 'gof' | 'ecog' | 'hbo' | 'hbr'
Type of channel
"""
kind = info['chs'][idx]['kind']
@@ -74,6 +74,11 @@ def channel_type(info, idx):
return 'gof'
elif kind == FIFF.FIFFV_ECOG_CH:
return 'ecog'
+ elif kind == FIFF.FIFFV_FNIRS_CH:
+ if info['chs'][idx]['coil_type'] == FIFF.FIFFV_COIL_FNIRS_HBO:
+ return 'hbo'
+ elif info['chs'][idx]['coil_type'] == FIFF.FIFFV_COIL_FNIRS_HBR:
+ return 'hbr'
raise Exception('Unknown channel type')
@@ -167,6 +172,17 @@ def _triage_meg_pick(ch, meg):
return False
+def _triage_fnirs_pick(ch, fnirs):
+ """Helper to triage an fNIRS pick type."""
+ if fnirs is True:
+ return True
+ elif ch['coil_type'] == FIFF.FIFFV_COIL_FNIRS_HBO and fnirs == 'hbo':
+ return True
+ elif ch['coil_type'] == FIFF.FIFFV_COIL_FNIRS_HBR and fnirs == 'hbr':
+ return True
+ return False
+
+
def _check_meg_type(meg, allow_auto=False):
"""Helper to ensure a valid meg type"""
if isinstance(meg, string_types):
@@ -180,15 +196,15 @@ def _check_meg_type(meg, allow_auto=False):
def pick_types(info, meg=True, eeg=False, stim=False, eog=False, ecg=False,
emg=False, ref_meg='auto', misc=False, resp=False, chpi=False,
exci=False, ias=False, syst=False, seeg=False, dipole=False,
- gof=False, bio=False, ecog=False, include=[], exclude='bads',
- selection=None):
+ gof=False, bio=False, ecog=False, fnirs=False, include=[],
+ exclude='bads', selection=None):
"""Pick channels by type and names
Parameters
----------
info : dict
The measurement info.
- meg : bool or string
+ meg : bool | str
        If True include all MEG channels. If False include none.
If string it can be 'mag', 'grad', 'planar1' or 'planar2' to select
only magnetometers, all gradiometers, or a specific type of
@@ -230,11 +246,16 @@ def pick_types(info, meg=True, eeg=False, stim=False, eog=False, ecg=False,
Bio channels.
ecog : bool
Electrocorticography channels.
+ fnirs : bool | str
+ Functional near-infrared spectroscopy channels. If True include all
+ fNIRS channels. If False (default) include none. If string it can be
+ 'hbo' (to include channels measuring oxyhemoglobin) or 'hbr' (to
+ include channels measuring deoxyhemoglobin).
include : list of string
List of additional channels to include. If empty do not include any.
exclude : list of string | str
List of channels to exclude. If 'bads' (default), exclude channels
- in info['bads'].
+ in ``info['bads']``.
selection : list of string
Restrict sensor channels (MEG, EEG) to this list of channel names.
@@ -272,7 +293,8 @@ def pick_types(info, meg=True, eeg=False, stim=False, eog=False, ecg=False,
ias, syst, seeg, dipole, gof, bio, ecog):
if not isinstance(param, bool):
w = ('Parameters for all channel types (with the exception '
- 'of "meg" and "ref_meg") must be of type bool, not {}.')
+ 'of "meg", "ref_meg" and "fnirs") must be of type bool, '
+ 'not {0}.')
raise ValueError(w.format(type(param)))
for k in range(nchan):
@@ -317,6 +339,8 @@ def pick_types(info, meg=True, eeg=False, stim=False, eog=False, ecg=False,
pick[k] = True
elif kind == FIFF.FIFFV_ECOG_CH and ecog:
pick[k] = True
+ elif kind == FIFF.FIFFV_FNIRS_CH:
+ pick[k] = _triage_fnirs_pick(info['chs'][k], fnirs)
# restrict channels to selection if provided
if selection is not None:
@@ -557,8 +581,9 @@ def pick_types_forward(orig, meg=True, eeg=False, ref_meg=True, seeg=False,
def channel_indices_by_type(info):
"""Get indices of channels by type
"""
- idx = dict((key, list()) for key in _PICK_TYPES_KEYS if key != 'meg')
- idx.update(mag=list(), grad=list())
+ idx = dict((key, list()) for key in _PICK_TYPES_KEYS if
+ key not in ('meg', 'fnirs'))
+ idx.update(mag=list(), grad=list(), hbo=list(), hbr=list())
for k, ch in enumerate(info['chs']):
for key in idx.keys():
if channel_type(info, k) == key:
@@ -677,9 +702,9 @@ def _check_excludes_includes(chs, info=None, allow_bads=False):
_PICK_TYPES_DATA_DICT = dict(
meg=True, eeg=True, stim=False, eog=False, ecg=False, emg=False,
misc=False, resp=False, chpi=False, exci=False, ias=False, syst=False,
- seeg=True, dipole=False, gof=False, bio=False, ecog=True)
+ seeg=True, dipole=False, gof=False, bio=False, ecog=True, fnirs=True)
_PICK_TYPES_KEYS = tuple(list(_PICK_TYPES_DATA_DICT.keys()) + ['ref_meg'])
-_DATA_CH_TYPES_SPLIT = ['mag', 'grad', 'eeg', 'seeg', 'ecog']
+_DATA_CH_TYPES_SPLIT = ['mag', 'grad', 'eeg', 'seeg', 'ecog', 'hbo', 'hbr']
def _pick_data_channels(info, exclude='bads', with_ref_meg=True):
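Together with the new 'hbo'/'hbr' channel types in create_info, the fnirs argument enables type-specific picking. A sketch mirroring test_pick_fnirs later in this commit:

    from mne import create_info, pick_types

    info = create_info(['hbo1', 'hbo2', 'hbr1', 'Cz'], 1000.,
                       ['hbo', 'hbo', 'hbr', 'eeg'])
    pick_types(info, meg=False, fnirs='hbo')  # -> array([0, 1])
    pick_types(info, meg=False, fnirs=True)   # -> array([0, 1, 2])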
diff --git a/mne/io/proj.py b/mne/io/proj.py
index 634b753..edcf77f 100644
--- a/mne/io/proj.py
+++ b/mne/io/proj.py
@@ -18,7 +18,7 @@ from .constants import FIFF
from .pick import pick_types
from .write import (write_int, write_float, write_string, write_name_list,
write_float_matrix, end_block, start_block)
-from ..utils import logger, verbose, warn
+from ..utils import logger, verbose, warn, deprecated
from ..externals.six import string_types
@@ -105,13 +105,19 @@ class ProjMixin(object):
check_active=False, sort=False)
return self
+ @deprecated('This function is deprecated and will be removed in 0.14. '
+ 'Use set_eeg_reference() instead.')
def add_eeg_average_proj(self):
- """Add an average EEG reference projector if one does not exist
- """
+ """Add an average EEG reference projector if one does not exist."""
if _needs_eeg_average_ref_proj(self.info):
# Don't set as active, since we haven't applied it
eeg_proj = make_eeg_average_ref_proj(self.info, activate=False)
self.add_proj(eeg_proj)
+ elif self.info.get('custom_ref_applied', False):
+ raise RuntimeError('Cannot add an average EEG reference '
+ 'projection since a custom reference has been '
+ 'applied to the data earlier.')
+
return self
def apply_proj(self):
@@ -696,12 +702,17 @@ def make_eeg_average_ref_proj(info, activate=True, verbose=None):
return eeg_proj
-def _has_eeg_average_ref_proj(projs):
- """Determine if a list of projectors has an average EEG ref"""
+def _has_eeg_average_ref_proj(projs, check_active=False):
+    """Determine if a list of projectors has an average EEG ref.
+
+    Optionally, set check_active=True to additionally check whether the
+    common average reference (CAR) has already been applied.
+ """
for proj in projs:
if (proj['desc'] == 'Average EEG reference' or
proj['kind'] == FIFF.FIFFV_MNE_PROJ_ITEM_EEG_AVREF):
- return True
+ if not check_active or proj['active']:
+ return True
return False
diff --git a/mne/io/reference.py b/mne/io/reference.py
index f62b2ea..4d32af0 100644
--- a/mne/io/reference.py
+++ b/mne/io/reference.py
@@ -157,6 +157,10 @@ def add_reference_channels(inst, ref_channels, copy=True):
if ch in inst.info['ch_names']:
raise ValueError("Channel %s already specified in inst." % ch)
+ # Once CAR is applied (active), don't allow adding channels
+ if _has_eeg_average_ref_proj(inst.info['projs'], check_active=True):
+ raise RuntimeError('Average reference already applied to data.')
+
if copy:
inst = inst.copy()
@@ -181,6 +185,31 @@ def add_reference_channels(inst, ref_channels, copy=True):
raise TypeError("inst should be Raw, Epochs, or Evoked instead of %s."
% type(inst))
nchan = len(inst.info['ch_names'])
+
+ # only do this if we actually have digitisation points
+ if inst.info.get('dig', None) is not None:
+ # "zeroth" EEG electrode dig points is reference
+ ref_dig_loc = [dl for dl in inst.info['dig'] if (
+ dl['kind'] == FIFF.FIFFV_POINT_EEG and
+ dl['ident'] == 0)]
+ if len(ref_channels) > 1 or len(ref_dig_loc) != len(ref_channels):
+ ref_dig_array = np.zeros(12)
+ warn('The locations of multiple reference channels are ignored '
+ '(set to zero).')
+ else: # n_ref_channels == 1 and a single ref digitization exists
+ ref_dig_array = np.concatenate((ref_dig_loc[0]['r'],
+ ref_dig_loc[0]['r'], np.zeros(6)))
+ # Replace the (possibly new) Ref location for each channel
+ for idx in pick_types(inst.info, meg=False, eeg=True, exclude=[]):
+ inst.info['chs'][idx]['loc'][3:6] = ref_dig_loc[0]['r']
+ else:
+ # we should actually be able to do this from the montage, but
+ # it looks like the montage isn't stored, so we can't extract
+        # this information. The user will just have to call set_montage().
+        # By setting this to zero, we fall back to the old behavior when
+        # digitisation points are missing.
+ ref_dig_array = np.zeros(12)
+
for ch in ref_channels:
chan_info = {'ch_name': ch,
'coil_type': FIFF.FIFFV_COIL_EEG,
@@ -192,12 +221,13 @@ def add_reference_channels(inst, ref_channels, copy=True):
'unit_mul': 0.,
'unit': FIFF.FIFF_UNIT_V,
'coord_frame': FIFF.FIFFV_COORD_HEAD,
- 'loc': np.zeros(12)}
+ 'loc': ref_dig_array}
inst.info['chs'].append(chan_info)
inst.info._update_redundant()
if isinstance(inst, _BaseRaw):
inst._cals = np.hstack((inst._cals, [1] * len(ref_channels)))
inst.info._check_consistency()
+ set_eeg_reference(inst, ref_channels=ref_channels, copy=False)
return inst
@@ -212,11 +242,12 @@ def set_eeg_reference(inst, ref_channels=None, copy=True):
inst : instance of Raw | Epochs | Evoked
Instance of Raw or Epochs with EEG channels and reference channel(s).
ref_channels : list of str | None
- The names of the channels to use to construct the reference. If None is
- specified here, an average reference will be applied in the form of an
- SSP projector. If an empty list is specified, the data is assumed to
- already have a proper reference and MNE will not attempt any
- re-referencing of the data. Defaults to an average reference (None).
+ The names of the channels to use to construct the reference. If
+ None (default), an average reference will be added as an SSP
+ projector but not immediately applied to the data. If an empty list
+ is specified, the data is assumed to already have a proper reference
+ and MNE will not attempt any re-referencing of the data. Defaults
+ to an average reference (None).
copy : bool
Specifies whether the data will be copied (True) or modified in place
(False). Defaults to True.
@@ -224,9 +255,12 @@ def set_eeg_reference(inst, ref_channels=None, copy=True):
Returns
-------
inst : instance of Raw | Epochs | Evoked
- Data with EEG channels re-referenced.
+ Data with EEG channels re-referenced. For ``ref_channels=None``,
+        an average projector will be added instead of directly subtracting
+ data.
ref_data : array
- Array of reference data subtracted from EEG channels.
+ Array of reference data subtracted from EEG channels. This will
+ be None for an average reference.
Notes
-----
@@ -253,7 +287,6 @@ def set_eeg_reference(inst, ref_channels=None, copy=True):
'has been left untouched.')
return inst, None
else:
- inst.info['custom_ref_applied'] = False
inst.add_proj(make_eeg_average_ref_proj(inst.info, activate=False))
return inst, None
else:
@@ -375,7 +408,7 @@ def set_bipolar_reference(inst, anode, cathode, ch_name=None, ch_info=None,
inst.info['chs'][an_idx] = info
inst.info['chs'][an_idx]['ch_name'] = name
logger.info('Bipolar channel added as "%s".' % name)
- inst.info._update_redundant()
+ inst.info._update_redundant()
# Drop cathode channels
inst.drop_channels(cathode)
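The reworked set_eeg_reference docstring above describes two distinct behaviors; a hedged sketch (file and channel names hypothetical):

    from mne import set_eeg_reference
    from mne.io import read_raw_fif

    raw = read_raw_fif('sample_raw.fif', preload=True, add_eeg_ref=False)
    # ref_channels=None: an inactive average-reference SSP projector is
    # added, and ref_data comes back as None
    raw_avg, ref_data = set_eeg_reference(raw, ref_channels=None, copy=True)
    # explicit channels: data are re-referenced immediately
    raw_ref, ref_data = set_eeg_reference(raw, ref_channels=['EEG 010'],
                                          copy=True)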
diff --git a/mne/io/tag.py b/mne/io/tag.py
index 152bb28..689e90d 100644
--- a/mne/io/tag.py
+++ b/mne/io/tag.py
@@ -4,13 +4,13 @@
# License: BSD (3-clause)
import gzip
+from functools import partial
import os
import struct
import numpy as np
from .constants import FIFF
-from ..fixes import partial
from ..externals.six import text_type
from ..externals.jdcal import jd2jcal
diff --git a/mne/io/tests/data/test-ave-2.log b/mne/io/tests/data/test-ave-2.log
index 3d3f21c..fde2e35 100644
--- a/mne/io/tests/data/test-ave-2.log
+++ b/mne/io/tests/data/test-ave-2.log
@@ -11,7 +11,8 @@ Reading mne/fiff/tests/data/test-ave.fif ...
Created an SSP operator (subspace dimension = 4)
4 projection items activated
SSP projectors applied...
-No baseline correction applied...
+No baseline correction applied
+No baseline correction applied
Reading mne/fiff/tests/data/test-ave.fif ...
Read a total of 4 projection items:
PCA-v1 (1 x 102) idle
@@ -25,4 +26,5 @@ Reading mne/fiff/tests/data/test-ave.fif ...
Created an SSP operator (subspace dimension = 4)
4 projection items activated
SSP projectors applied...
-No baseline correction applied...
+No baseline correction applied
+No baseline correction applied
diff --git a/mne/io/tests/data/test-ave.log b/mne/io/tests/data/test-ave.log
index d663417..3b124b5 100644
--- a/mne/io/tests/data/test-ave.log
+++ b/mne/io/tests/data/test-ave.log
@@ -11,4 +11,5 @@ Reading mne/fiff/tests/data/test-ave.fif ...
Created an SSP operator (subspace dimension = 4)
4 projection items activated
SSP projectors applied...
-No baseline correction applied...
+No baseline correction applied
+No baseline correction applied
diff --git a/mne/io/tests/test_compensator.py b/mne/io/tests/test_compensator.py
index bc15630..80d8dbf 100644
--- a/mne/io/tests/test_compensator.py
+++ b/mne/io/tests/test_compensator.py
@@ -5,54 +5,76 @@
import os.path as op
from nose.tools import assert_true
import numpy as np
-from numpy.testing import assert_allclose
+from numpy.testing import assert_allclose, assert_equal
from mne import Epochs, read_evokeds, pick_types
from mne.io.compensator import make_compensator, get_current_comp
-from mne.io import Raw
-from mne.utils import _TempDir, requires_mne, run_subprocess
+from mne.io import read_raw_fif
+from mne.utils import _TempDir, requires_mne, run_subprocess, run_tests_if_main
base_dir = op.join(op.dirname(__file__), 'data')
ctf_comp_fname = op.join(base_dir, 'test_ctf_comp_raw.fif')
def test_compensation():
- """Test compensation
- """
+ """Test compensation."""
tempdir = _TempDir()
- raw = Raw(ctf_comp_fname, compensation=None)
+ raw = read_raw_fif(ctf_comp_fname, compensation=None, add_eeg_ref=False)
+ assert_equal(get_current_comp(raw.info), 3)
comp1 = make_compensator(raw.info, 3, 1, exclude_comp_chs=False)
assert_true(comp1.shape == (340, 340))
comp2 = make_compensator(raw.info, 3, 1, exclude_comp_chs=True)
assert_true(comp2.shape == (311, 340))
+ # round-trip
+ desired = np.eye(340)
+ for from_ in range(3):
+ for to in range(3):
+ if from_ == to:
+ continue
+ comp1 = make_compensator(raw.info, from_, to)
+ comp2 = make_compensator(raw.info, to, from_)
+ # To get 1e-12 here (instead of 1e-6) we must use the linalg.inv
+ # method mentioned in compensator.py
+ assert_allclose(np.dot(comp1, comp2), desired, atol=1e-12)
+ assert_allclose(np.dot(comp2, comp1), desired, atol=1e-12)
+
# make sure that changing the comp doesn't modify the original data
- raw2 = Raw(ctf_comp_fname, compensation=2)
- assert_true(get_current_comp(raw2.info) == 2)
+ raw2 = read_raw_fif(ctf_comp_fname, add_eeg_ref=False)
+ raw2.apply_gradient_compensation(2)
+ assert_equal(get_current_comp(raw2.info), 2)
fname = op.join(tempdir, 'ctf-raw.fif')
raw2.save(fname)
- raw2 = Raw(fname, compensation=None)
+ raw2 = read_raw_fif(fname, add_eeg_ref=False)
+ assert_equal(raw2.compensation_grade, 2)
+ raw2.apply_gradient_compensation(3)
+ assert_equal(raw2.compensation_grade, 3)
data, _ = raw[:, :]
data2, _ = raw2[:, :]
- assert_allclose(data, data2, rtol=1e-9, atol=1e-20)
+ # channels have norm ~1e-12
+ assert_allclose(data, data2, rtol=1e-9, atol=1e-18)
for ch1, ch2 in zip(raw.info['chs'], raw2.info['chs']):
assert_true(ch1['coil_type'] == ch2['coil_type'])
@requires_mne
def test_compensation_mne():
- """Test comensation by comparing with MNE
- """
+    """Test compensation by comparing with MNE."""
tempdir = _TempDir()
def make_evoked(fname, comp):
- raw = Raw(fname, compensation=comp)
+ """Make evoked data."""
+ raw = read_raw_fif(fname, add_eeg_ref=False)
+ if comp is not None:
+ raw.apply_gradient_compensation(comp)
picks = pick_types(raw.info, meg=True, ref_meg=True)
events = np.array([[0, 0, 1]], dtype=np.int)
- evoked = Epochs(raw, events, 1, 0, 20e-3, picks=picks).average()
+ evoked = Epochs(raw, events, 1, 0, 20e-3, picks=picks,
+ add_eeg_ref=False).average()
return evoked
def compensate_mne(fname, comp):
+ """Compensate using MNE-C."""
tmp_fname = '%s-%d-ave.fif' % (fname[:-4], comp)
cmd = ['mne_compensate_data', '--in', fname,
'--out', tmp_fname, '--grad', str(comp)]
@@ -70,3 +92,9 @@ def test_compensation_mne():
picks_c = pick_types(evoked_c.info, meg=True, ref_meg=True)
assert_allclose(evoked_py.data[picks_py], evoked_c.data[picks_c],
rtol=1e-3, atol=1e-17)
+ chs_py = [evoked_py.info['chs'][ii] for ii in picks_py]
+ chs_c = [evoked_c.info['chs'][ii] for ii in picks_c]
+ for ch_py, ch_c in zip(chs_py, chs_c):
+ assert_equal(ch_py['coil_type'], ch_c['coil_type'])
+
+run_tests_if_main()
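The test now uses the method-based compensation API in place of the removed compensation constructor argument. Sketch (file name hypothetical):

    from mne.io import read_raw_fif

    raw = read_raw_fif('ctf_raw.fif', add_eeg_ref=False)
    raw.apply_gradient_compensation(3)  # was read_raw_fif(..., compensation=3)
    print(raw.compensation_grade)       # -> 3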
diff --git a/mne/io/tests/test_meas_info.py b/mne/io/tests/test_meas_info.py
index b794f0b..1656319 100644
--- a/mne/io/tests/test_meas_info.py
+++ b/mne/io/tests/test_meas_info.py
@@ -8,7 +8,8 @@ from numpy.testing import assert_array_equal, assert_allclose
from mne import Epochs, read_events
from mne.io import (read_fiducials, write_fiducials, _coil_trans_to_loc,
- _loc_to_coil_trans, Raw, read_info, write_info)
+ _loc_to_coil_trans, read_raw_fif, read_info, write_info,
+ anonymize_info)
from mne.io.constants import FIFF
from mne.io.meas_info import (Info, create_info, _write_dig_points,
_read_dig_points, _make_dig_points, _merge_info,
@@ -21,14 +22,13 @@ fiducials_fname = op.join(base_dir, 'fsaverage-fiducials.fif')
raw_fname = op.join(base_dir, 'test_raw.fif')
chpi_fname = op.join(base_dir, 'test_chpi_raw_sss.fif')
event_name = op.join(base_dir, 'test-eve.fif')
-evoked_nf_name = op.join(base_dir, 'test-nf-ave.fif')
kit_data_dir = op.join(op.dirname(__file__), '..', 'kit', 'tests', 'data')
hsp_fname = op.join(kit_data_dir, 'test_hsp.txt')
elp_fname = op.join(kit_data_dir, 'test_elp.txt')
def test_coil_trans():
- """Test loc<->coil_trans functions"""
+ """Test loc<->coil_trans functions."""
rng = np.random.RandomState(0)
x = rng.randn(4, 4)
x[3] = [0, 0, 0, 1]
@@ -38,8 +38,7 @@ def test_coil_trans():
def test_make_info():
- """Test some create_info properties
- """
+ """Test some create_info properties."""
n_ch = 1
info = create_info(n_ch, 1000., 'eeg')
assert_equal(sorted(info.keys()), sorted(RAW_INFO_FIELDS))
@@ -88,7 +87,7 @@ def test_make_info():
def test_fiducials_io():
- """Test fiducials i/o"""
+ """Test fiducials i/o."""
tempdir = _TempDir()
pts, coord_frame = read_fiducials(fiducials_fname)
assert_equal(pts[0]['coord_frame'], FIFF.FIFFV_COORD_MRI)
@@ -110,13 +109,13 @@ def test_fiducials_io():
def test_info():
- """Test info object"""
- raw = Raw(raw_fname)
+ """Test info object."""
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
event_id, tmin, tmax = 1, -0.2, 0.5
events = read_events(event_name)
event_id = int(events[0, 2])
epochs = Epochs(raw, events[:1], event_id, tmin, tmax, picks=None,
- baseline=(None, 0))
+ baseline=(None, 0), add_eeg_ref=False)
evoked = epochs.average()
@@ -164,8 +163,7 @@ def test_info():
def test_read_write_info():
- """Test IO of info
- """
+ """Test IO of info."""
tempdir = _TempDir()
info = read_info(raw_fname)
temp_file = op.join(tempdir, 'info.fif')
@@ -191,7 +189,7 @@ def test_read_write_info():
def test_io_dig_points():
- """Test Writing for dig files"""
+ """Test Writing for dig files."""
tempdir = _TempDir()
points = _read_dig_points(hsp_fname)
@@ -200,7 +198,7 @@ def test_io_dig_points():
assert_raises(ValueError, _write_dig_points, dest, points[:, :2])
assert_raises(ValueError, _write_dig_points, dest_bad, points)
_write_dig_points(dest, points)
- points1 = _read_dig_points(dest)
+ points1 = _read_dig_points(dest, unit='m')
err = "Dig points diverged after writing and reading."
assert_array_equal(points, points1, err)
@@ -210,14 +208,14 @@ def test_io_dig_points():
def test_make_dig_points():
- """Test application of Polhemus HSP to info"""
+ """Test application of Polhemus HSP to info."""
dig_points = _read_dig_points(hsp_fname)
info = create_info(ch_names=['Test Ch'], sfreq=1000., ch_types=None)
assert_false(info['dig'])
info['dig'] = _make_dig_points(dig_points=dig_points)
assert_true(info['dig'])
- assert_array_equal(info['dig'][0]['r'], [-106.93, 99.80, 68.81])
+ assert_allclose(info['dig'][0]['r'], [-.10693, .09980, .06881])
dig_points = _read_dig_points(elp_fname)
nasion, lpa, rpa = dig_points[:3]
@@ -228,7 +226,7 @@ def test_make_dig_points():
assert_true(info['dig'])
idx = [d['ident'] for d in info['dig']].index(FIFF.FIFFV_POINT_NASION)
assert_array_equal(info['dig'][idx]['r'],
- np.array([1.3930, 13.1613, -4.6967]))
+ np.array([.0013930, .0131613, -.0046967]))
assert_raises(ValueError, _make_dig_points, nasion[:2])
assert_raises(ValueError, _make_dig_points, None, lpa[:2])
assert_raises(ValueError, _make_dig_points, None, None, rpa[:2])
@@ -239,7 +237,7 @@ def test_make_dig_points():
def test_redundant():
- """Test some of the redundant properties of info"""
+ """Test some of the redundant properties of info."""
# Indexing
info = create_info(ch_names=['a', 'b', 'c'], sfreq=1000., ch_types=None)
assert_equal(info['ch_names'][0], 'a')
@@ -259,7 +257,7 @@ def test_redundant():
def test_merge_info():
- """Test merging of multiple Info objects"""
+ """Test merging of multiple Info objects."""
info_a = create_info(ch_names=['a', 'b', 'c'], sfreq=1000., ch_types=None)
info_b = create_info(ch_names=['d', 'e', 'f'], sfreq=1000., ch_types=None)
info_merged = _merge_info([info_a, info_b])
@@ -279,10 +277,17 @@ def test_merge_info():
# Check that you must supply Info
assert_raises(ValueError, _force_update_info, info_a,
dict([('sfreq', 1000.)]))
+ # KIT System-ID
+ info_a['kit_system_id'] = 50
+ assert_equal(_merge_info((info_a, info_b))['kit_system_id'], 50)
+ info_b['kit_system_id'] = 50
+ assert_equal(_merge_info((info_a, info_b))['kit_system_id'], 50)
+ info_b['kit_system_id'] = 60
+ assert_raises(ValueError, _merge_info, (info_a, info_b))
def test_check_consistency():
- """Test consistency check of Info objects"""
+ """Test consistency check of Info objects."""
info = create_info(ch_names=['a', 'b', 'c'], sfreq=1000.)
# This should pass
@@ -326,4 +331,44 @@ def test_check_consistency():
assert_raises(RuntimeError, info2._check_consistency)
+def test_anonymize():
+    """Check that sensitive information can be anonymized."""
+ assert_raises(ValueError, anonymize_info, 'foo')
+
+ # Fake some subject data
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
+ raw.info['subject_info'] = dict(id=1, his_id='foobar', last_name='bar',
+ first_name='bar', birthday=(1987, 4, 8),
+ sex=0, hand=1)
+
+ orig_file_id = raw.info['file_id']['secs']
+ orig_meas_id = raw.info['meas_id']['secs']
+ # Test instance method
+ events = read_events(event_name)
+ epochs = Epochs(raw, events[:1], 2, 0., 0.1, add_eeg_ref=False)
+ for inst in [raw, epochs]:
+ assert_true('subject_info' in inst.info.keys())
+ assert_true(inst.info['subject_info'] is not None)
+ assert_true(inst.info['file_id']['secs'] != 0)
+ assert_true(inst.info['meas_id']['secs'] != 0)
+ assert_true(np.any(inst.info['meas_date'] != [0, 0]))
+ inst.anonymize()
+ assert_true('subject_info' not in inst.info.keys())
+ assert_equal(inst.info['file_id']['secs'], 0)
+ assert_equal(inst.info['meas_id']['secs'], 0)
+ assert_equal(inst.info['meas_date'], [0, 0])
+
+ # When we write out with raw.save, these get overwritten with the
+ # new save time
+ tempdir = _TempDir()
+ out_fname = op.join(tempdir, 'test_subj_info_raw.fif')
+ raw.save(out_fname, overwrite=True)
+ raw = read_raw_fif(out_fname, add_eeg_ref=False)
+ assert_true(raw.info.get('subject_info') is None)
+ assert_array_equal(raw.info['meas_date'], [0, 0])
+ # XXX mne.io.write.write_id necessarily writes secs
+ assert_true(raw.info['file_id']['secs'] != orig_file_id)
+ assert_true(raw.info['meas_id']['secs'] != orig_meas_id)
+
+
run_tests_if_main()
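test_anonymize also covers the instance-method counterpart of anonymize_info; a short sketch (file names hypothetical):

    from mne.io import read_raw_fif

    raw = read_raw_fif('subject_raw.fif', add_eeg_ref=False)
    raw.anonymize()  # same effect as anonymize_info(raw.info)
    raw.save('subject_anon_raw.fif', overwrite=True)
    # NB: saving re-stamps the file_id/meas_id secs with the save time,
    # as noted in the XXX comment in the test above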
diff --git a/mne/io/tests/test_pick.py b/mne/io/tests/test_pick.py
index a1b1d40..89bac51 100644
--- a/mne/io/tests/test_pick.py
+++ b/mne/io/tests/test_pick.py
@@ -9,7 +9,8 @@ import numpy as np
from mne import (pick_channels_regexp, pick_types, Epochs,
read_forward_solution, rename_channels,
pick_info, pick_channels, __file__, create_info)
-from mne.io import Raw, RawArray, read_raw_bti, read_raw_kit, read_info
+from mne.io import (read_raw_fif, RawArray, read_raw_bti, read_raw_kit,
+ read_info)
from mne.io.pick import (channel_indices_by_type, channel_type,
pick_types_forward, _picks_by_type)
from mne.io.constants import FIFF
@@ -45,7 +46,8 @@ def test_pick_refs():
infos.append(raw_bti.info)
# CTF
fname_ctf_raw = op.join(io_dir, 'tests', 'data', 'test_ctf_comp_raw.fif')
- raw_ctf = Raw(fname_ctf_raw, compensation=2)
+ raw_ctf = read_raw_fif(fname_ctf_raw, add_eeg_ref=False)
+ raw_ctf.apply_gradient_compensation(2)
infos.append(raw_ctf.info)
for info in infos:
info['bads'] = []
@@ -107,13 +109,14 @@ def test_pick_seeg_ecog():
assert_equal(channel_type(info, i), types[i])
raw = RawArray(np.zeros((len(names), 10)), info)
events = np.array([[1, 0, 0], [2, 0, 0]])
- epochs = Epochs(raw, events, {'event': 0}, -1e-5, 1e-5)
+ epochs = Epochs(raw, events, {'event': 0}, -1e-5, 1e-5, add_eeg_ref=False)
evoked = epochs.average(pick_types(epochs.info, meg=True, seeg=True))
e_seeg = evoked.copy().pick_types(meg=False, seeg=True)
for l, r in zip(e_seeg.ch_names, [names[4], names[5], names[7]]):
assert_equal(l, r)
# Deal with constant debacle
- raw = Raw(op.join(io_dir, 'tests', 'data', 'test_chpi_raw_sss.fif'))
+ raw = read_raw_fif(op.join(io_dir, 'tests', 'data',
+ 'test_chpi_raw_sss.fif'), add_eeg_ref=False)
assert_equal(len(pick_types(raw.info, meg=False, seeg=True, ecog=True)), 0)
@@ -140,6 +143,18 @@ def test_pick_bio():
assert_array_equal(idx['bio'], [4, 5, 6])
+def test_pick_fnirs():
+ """Test picking fNIRS channels."""
+ names = 'A1 A2 Fz O hbo1 hbo2 hbr1'.split()
+ types = 'mag mag eeg eeg hbo hbo hbr'.split()
+ info = create_info(names, 1024., types)
+ idx = channel_indices_by_type(info)
+ assert_array_equal(idx['mag'], [0, 1])
+ assert_array_equal(idx['eeg'], [2, 3])
+ assert_array_equal(idx['hbo'], [4, 5])
+ assert_array_equal(idx['hbr'], [6])
+
+
def _check_fwd_n_chan_consistent(fwd, n_expected):
n_ok = len(fwd['info']['ch_names'])
n_sol = fwd['sol']['data'].shape[0]
@@ -149,8 +164,7 @@ def _check_fwd_n_chan_consistent(fwd, n_expected):
@testing.requires_testing_data
def test_pick_forward_seeg_ecog():
- """Test picking forward with SEEG and ECoG
- """
+ """Test picking forward with SEEG and ECoG."""
fwd = read_forward_solution(fname_meeg)
counts = channel_indices_by_type(fwd['info'])
for key in counts.keys():
@@ -197,7 +211,6 @@ def test_pick_forward_seeg_ecog():
def test_picks_by_channels():
"""Test creating pick_lists"""
-
rng = np.random.RandomState(909)
test_data = rng.random_sample((4, 2000))
@@ -245,7 +258,7 @@ def test_clean_info_bads():
raw_file = op.join(op.dirname(__file__), 'io', 'tests', 'data',
'test_raw.fif')
- raw = Raw(raw_file)
+ raw = read_raw_fif(raw_file, add_eeg_ref=False)
# select eeg channels
picks_eeg = pick_types(raw.info, meg=False, eeg=True)
diff --git a/mne/io/tests/test_proc_history.py b/mne/io/tests/test_proc_history.py
index 876a097..aa4a58c 100644
--- a/mne/io/tests/test_proc_history.py
+++ b/mne/io/tests/test_proc_history.py
@@ -4,7 +4,7 @@
import numpy as np
import os.path as op
-from mne import io
+from mne.io import read_info
from mne.io.constants import FIFF
from mne.io.proc_history import _get_sss_rank
from nose.tools import assert_true, assert_equal
@@ -14,9 +14,9 @@ raw_fname = op.join(base_dir, 'test_chpi_raw_sss.fif')
def test_maxfilter_io():
- """test maxfilter io"""
- raw = io.read_raw_fif(raw_fname)
- mf = raw.info['proc_history'][1]['max_info']
+ """Test maxfilter io."""
+ info = read_info(raw_fname)
+ mf = info['proc_history'][1]['max_info']
assert_true(mf['sss_info']['frame'], FIFF.FIFFV_COORD_HEAD)
# based on manual 2.0, rev. 5.0 page 23
@@ -24,7 +24,7 @@ def test_maxfilter_io():
assert_true(mf['sss_info']['out_order'] <= 5)
assert_true(mf['sss_info']['nchan'] > len(mf['sss_info']['components']))
- assert_equal(raw.ch_names[:mf['sss_info']['nchan']],
+ assert_equal(info['ch_names'][:mf['sss_info']['nchan']],
mf['sss_ctc']['proj_items_chs'])
assert_equal(mf['sss_ctc']['decoupler'].shape,
(mf['sss_info']['nchan'], mf['sss_info']['nchan']))
@@ -39,9 +39,8 @@ def test_maxfilter_io():
def test_maxfilter_get_rank():
- """test maxfilter rank lookup"""
- raw = io.read_raw_fif(raw_fname)
- mf = raw.info['proc_history'][0]['max_info']
+ """Test maxfilter rank lookup."""
+ mf = read_info(raw_fname)['proc_history'][0]['max_info']
rank1 = mf['sss_info']['nfree']
rank2 = _get_sss_rank(mf)
assert_equal(rank1, rank2)
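The rewritten tests read only the measurement info instead of constructing a full Raw object. A hedged sketch of the SSS rank lookup they cover (file name hypothetical):

    from mne.io import read_info
    from mne.io.proc_history import _get_sss_rank
    info = read_info('test_chpi_raw_sss.fif')  # any Maxfiltered FIF file
    mf = info['proc_history'][0]['max_info']
    assert _get_sss_rank(mf) == mf['sss_info']['nfree']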
diff --git a/mne/io/tests/test_raw.py b/mne/io/tests/test_raw.py
index c5969b0..b46c148 100644
--- a/mne/io/tests/test_raw.py
+++ b/mne/io/tests/test_raw.py
@@ -8,7 +8,7 @@ from nose.tools import assert_equal, assert_true
from mne import concatenate_raws
from mne.datasets import testing
-from mne.io import Raw
+from mne.io import read_raw_fif
from mne.utils import _TempDir
@@ -64,7 +64,7 @@ def _test_raw_reader(reader, test_preloading=True, **kwargs):
# Test saving and reading
out_fname = op.join(tempdir, 'test_raw.fif')
raw.save(out_fname, tmax=raw.times[-1], overwrite=True, buffer_size_sec=1)
- raw3 = Raw(out_fname)
+ raw3 = read_raw_fif(out_fname, add_eeg_ref=False)
assert_equal(set(raw.info.keys()), set(raw3.info.keys()))
assert_allclose(raw3[0:20][0], full_data[0:20], rtol=1e-6,
atol=1e-20) # atol is very small but > 0
@@ -75,6 +75,8 @@ def _test_raw_reader(reader, test_preloading=True, **kwargs):
assert_true(not math.isnan(raw.info['highpass']))
assert_true(not math.isnan(raw.info['lowpass']))
+ assert_equal(raw3.info['kit_system_id'], raw.info['kit_system_id'])
+
# Make sure concatenation works
first_samp = raw.first_samp
last_samp = raw.last_samp
@@ -86,7 +88,7 @@ def _test_raw_reader(reader, test_preloading=True, **kwargs):
def _test_concat(reader, *args):
- """Test concatenation of raw classes that allow not preloading"""
+ """Test concatenation of raw classes that allow not preloading."""
data = None
for preload in (True, False):
@@ -119,10 +121,10 @@ def _test_concat(reader, *args):
@testing.requires_testing_data
def test_time_index():
- """Test indexing of raw times"""
+ """Test indexing of raw times."""
raw_fname = op.join(op.dirname(__file__), '..', '..', 'io', 'tests',
'data', 'test_raw.fif')
- raw = Raw(raw_fname)
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
# Test original (non-rounding) indexing behavior
orig_inds = raw.time_as_index(raw.times)
diff --git a/mne/io/tests/test_reference.py b/mne/io/tests/test_reference.py
index b01fb88..1cddc25 100644
--- a/mne/io/tests/test_reference.py
+++ b/mne/io/tests/test_reference.py
@@ -11,15 +11,16 @@ import numpy as np
from nose.tools import assert_true, assert_equal, assert_raises
from numpy.testing import assert_array_equal, assert_allclose
-from mne import pick_channels, pick_types, Evoked, Epochs, read_events
+from mne import (pick_channels, pick_types, Evoked, Epochs, read_events,
+ set_eeg_reference, set_bipolar_reference,
+ add_reference_channels)
from mne.epochs import _BaseEpochs
+from mne.io import read_raw_fif
from mne.io.constants import FIFF
-from mne.io import (set_eeg_reference, set_bipolar_reference,
- add_reference_channels)
from mne.io.proj import _has_eeg_average_ref_proj
from mne.io.reference import _apply_reference
from mne.datasets import testing
-from mne.io import Raw
+from mne.utils import run_tests_if_main
warnings.simplefilter('always') # enable b/c these tests throw warnings
@@ -30,8 +31,7 @@ ave_fname = op.join(data_dir, 'sample_audvis_trunc-ave.fif')
def _test_reference(raw, reref, ref_data, ref_from):
- """Helper function to test whether a reference has been correctly
- applied."""
+ """Test whether a reference has been correctly applied."""
# Separate EEG channels from other channel types
picks_eeg = pick_types(raw.info, meg=False, eeg=True, exclude='bads')
picks_other = pick_types(raw.info, meg=True, eeg=False, eog=True,
@@ -72,8 +72,8 @@ def _test_reference(raw, reref, ref_data, ref_from):
@testing.requires_testing_data
def test_apply_reference():
- """Test base function for rereferencing"""
- raw = Raw(fif_fname, preload=True)
+ """Test base function for rereferencing."""
+ raw = read_raw_fif(fif_fname, preload=True, add_eeg_ref=False)
# Rereference raw data by creating a copy of original data
reref, ref_data = _apply_reference(
@@ -93,11 +93,11 @@ def test_apply_reference():
assert_true(raw is reref)
# Test re-referencing Epochs object
- raw = Raw(fif_fname, preload=False, add_eeg_ref=False)
+ raw = read_raw_fif(fif_fname, preload=False, add_eeg_ref=False)
events = read_events(eve_fname)
picks_eeg = pick_types(raw.info, meg=False, eeg=True)
epochs = Epochs(raw, events=events, event_id=1, tmin=-0.2, tmax=0.5,
- picks=picks_eeg, preload=True)
+ picks=picks_eeg, preload=True, add_eeg_ref=False)
reref, ref_data = _apply_reference(
epochs.copy(), ref_from=['EEG 001', 'EEG 002'])
assert_true(reref.info['custom_ref_applied'])
@@ -111,14 +111,14 @@ def test_apply_reference():
_test_reference(evoked, reref, ref_data, ['EEG 001', 'EEG 002'])
# Test invalid input
- raw_np = Raw(fif_fname, preload=False)
+ raw_np = read_raw_fif(fif_fname, preload=False, add_eeg_ref=False)
assert_raises(RuntimeError, _apply_reference, raw_np, ['EEG 001'])
@testing.requires_testing_data
def test_set_eeg_reference():
- """Test rereference eeg data"""
- raw = Raw(fif_fname, preload=True)
+ """Test rereference eeg data."""
+ raw = read_raw_fif(fif_fname, preload=True, add_eeg_ref=False)
raw.info['projs'] = []
# Test setting an average reference
@@ -145,8 +145,8 @@ def test_set_eeg_reference():
@testing.requires_testing_data
def test_set_bipolar_reference():
- """Test bipolar referencing"""
- raw = Raw(fif_fname, preload=True)
+ """Test bipolar referencing."""
+ raw = read_raw_fif(fif_fname, preload=True, add_eeg_ref=False)
reref = set_bipolar_reference(raw, 'EEG 001', 'EEG 002', 'bipolar',
{'kind': FIFF.FIFFV_EOG_CH,
'extra': 'some extra value'})
@@ -181,6 +181,21 @@ def test_set_bipolar_reference():
reref = set_bipolar_reference(raw, 'EEG 001', 'EEG 002')
assert_true('EEG 001-EEG 002' in reref.ch_names)
+ # Set multiple references at once
+ reref = set_bipolar_reference(
+ raw,
+ ['EEG 001', 'EEG 003'],
+ ['EEG 002', 'EEG 004'],
+ ['bipolar1', 'bipolar2'],
+ [{'kind': FIFF.FIFFV_EOG_CH, 'extra': 'some extra value'},
+ {'kind': FIFF.FIFFV_EOG_CH, 'extra': 'some extra value'}],
+ )
+ a = raw.copy().pick_channels(['EEG 001', 'EEG 002', 'EEG 003', 'EEG 004'])
+ a = np.array([a._data[0, :] - a._data[1, :],
+ a._data[2, :] - a._data[3, :]])
+ b = reref.copy().pick_channels(['bipolar1', 'bipolar2'])._data
+ assert_allclose(a, b)
+
# Test creating a bipolar reference that doesn't involve EEG channels:
# it should not set the custom_ref_applied flag
reref = set_bipolar_reference(raw, 'MEG 0111', 'MEG 0112',
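The new block above tests passing lists of anodes and cathodes so that several bipolar channels are created in one call. A usage sketch under that API (keyword names as in the library's signature):

    from mne import set_bipolar_reference
    reref = set_bipolar_reference(raw, anode=['EEG 001', 'EEG 003'],
                                  cathode=['EEG 002', 'EEG 004'],
                                  ch_name=['bipolar1', 'bipolar2'])
    # 'bipolar1' holds EEG 001 - EEG 002, 'bipolar2' holds EEG 003 - EEG 004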
@@ -203,6 +218,7 @@ def test_set_bipolar_reference():
def _check_channel_names(inst, ref_names):
+ """Check channel names."""
if isinstance(ref_names, str):
ref_names = [ref_names]
@@ -217,7 +233,8 @@ def _check_channel_names(inst, ref_names):
@testing.requires_testing_data
def test_add_reference():
- raw = Raw(fif_fname, preload=True)
+ """Test adding a reference."""
+ raw = read_raw_fif(fif_fname, preload=True, add_eeg_ref=False)
picks_eeg = pick_types(raw.info, meg=False, eeg=True)
# check if channel already exists
assert_raises(ValueError, add_reference_channels,
@@ -234,12 +251,31 @@ def test_add_reference():
assert_equal(raw.info['nchan'], orig_nchan + 1)
_check_channel_names(raw, 'Ref')
+ # for Neuromag fif's, the reference electrode location is placed in
+ # elements [3:6] of each "data" electrode location
+ assert_allclose(raw.info['chs'][-1]['loc'][:3],
+ raw.info['chs'][picks_eeg[0]]['loc'][3:6], 1e-6)
+
ref_idx = raw.ch_names.index('Ref')
ref_data, _ = raw[ref_idx]
assert_array_equal(ref_data, 0)
- raw = Raw(fif_fname, preload=True)
+ # add reference channel to Raw when no digitization points exist
+ raw = read_raw_fif(fif_fname, add_eeg_ref=False).crop(0, 1).load_data()
picks_eeg = pick_types(raw.info, meg=False, eeg=True)
+ del raw.info['dig']
+
+ raw_ref = add_reference_channels(raw, 'Ref', copy=True)
+
+ assert_equal(raw_ref._data.shape[0], raw._data.shape[0] + 1)
+ assert_array_equal(raw._data[picks_eeg, :], raw_ref._data[picks_eeg, :])
+ _check_channel_names(raw_ref, 'Ref')
+
+ orig_nchan = raw.info['nchan']
+ raw = add_reference_channels(raw, 'Ref', copy=False)
+ assert_array_equal(raw._data, raw_ref._data)
+ assert_equal(raw.info['nchan'], orig_nchan + 1)
+ _check_channel_names(raw, 'Ref')
# Test adding an existing channel as reference channel
assert_raises(ValueError, add_reference_channels, raw,
@@ -260,12 +296,22 @@ def test_add_reference():
assert_array_equal(ref_data, 0)
# add reference channel to epochs
- raw = Raw(fif_fname, preload=True)
+ raw = read_raw_fif(fif_fname, preload=True, add_eeg_ref=False)
events = read_events(eve_fname)
picks_eeg = pick_types(raw.info, meg=False, eeg=True)
epochs = Epochs(raw, events=events, event_id=1, tmin=-0.2, tmax=0.5,
- picks=picks_eeg, preload=True)
+ picks=picks_eeg, preload=True, add_eeg_ref=False)
+ # default: proj=True, after which adding a Ref channel is prohibited
+ assert_raises(RuntimeError, add_reference_channels, epochs, 'Ref')
+
+ # create epochs in delayed mode, allowing removal of CAR when re-reffing
+ epochs = Epochs(raw, events=events, event_id=1, tmin=-0.2, tmax=0.5,
+ picks=picks_eeg, preload=True, proj='delayed',
+ add_eeg_ref=False)
epochs_ref = add_reference_channels(epochs, 'Ref', copy=True)
+ # CAR after custom reference is an Error
+ assert_raises(RuntimeError, epochs_ref.set_eeg_reference)
+
assert_equal(epochs_ref._data.shape[1], epochs._data.shape[1] + 1)
_check_channel_names(epochs_ref, 'Ref')
ref_idx = epochs_ref.ch_names.index('Ref')
@@ -276,12 +322,15 @@ def test_add_reference():
epochs_ref.get_data()[:, picks_eeg, :])
# add two reference channels to epochs
- raw = Raw(fif_fname, preload=True)
+ raw = read_raw_fif(fif_fname, preload=True, add_eeg_ref=False)
events = read_events(eve_fname)
picks_eeg = pick_types(raw.info, meg=False, eeg=True)
+ # create epochs in delayed mode, allowing removal of CAR when re-reffing
epochs = Epochs(raw, events=events, event_id=1, tmin=-0.2, tmax=0.5,
- picks=picks_eeg, preload=True)
- epochs_ref = add_reference_channels(epochs, ['M1', 'M2'], copy=True)
+ picks=picks_eeg, preload=True, proj='delayed',
+ add_eeg_ref=False)
+ with warnings.catch_warnings(record=True): # multiple set zero
+ epochs_ref = add_reference_channels(epochs, ['M1', 'M2'], copy=True)
assert_equal(epochs_ref._data.shape[1], epochs._data.shape[1] + 2)
_check_channel_names(epochs_ref, ['M1', 'M2'])
ref_idx = epochs_ref.ch_names.index('M1')
@@ -295,11 +344,13 @@ def test_add_reference():
epochs_ref.get_data()[:, picks_eeg, :])
# add reference channel to evoked
- raw = Raw(fif_fname, preload=True)
+ raw = read_raw_fif(fif_fname, preload=True, add_eeg_ref=False)
events = read_events(eve_fname)
picks_eeg = pick_types(raw.info, meg=False, eeg=True)
+ # create epochs in delayed mode, allowing removal of CAR when re-reffing
epochs = Epochs(raw, events=events, event_id=1, tmin=-0.2, tmax=0.5,
- picks=picks_eeg, preload=True)
+ picks=picks_eeg, preload=True, proj='delayed',
+ add_eeg_ref=False)
evoked = epochs.average()
evoked_ref = add_reference_channels(evoked, 'Ref', copy=True)
assert_equal(evoked_ref.data.shape[0], evoked.data.shape[0] + 1)
@@ -312,13 +363,16 @@ def test_add_reference():
evoked_ref.data[picks_eeg, :])
# add two reference channels to evoked
- raw = Raw(fif_fname, preload=True)
+ raw = read_raw_fif(fif_fname, preload=True, add_eeg_ref=False)
events = read_events(eve_fname)
picks_eeg = pick_types(raw.info, meg=False, eeg=True)
+ # create epochs in delayed mode, allowing removal of CAR when re-reffing
epochs = Epochs(raw, events=events, event_id=1, tmin=-0.2, tmax=0.5,
- picks=picks_eeg, preload=True)
+ picks=picks_eeg, preload=True, proj='delayed',
+ add_eeg_ref=False)
evoked = epochs.average()
- evoked_ref = add_reference_channels(evoked, ['M1', 'M2'], copy=True)
+ with warnings.catch_warnings(record=True): # multiple set zero
+ evoked_ref = add_reference_channels(evoked, ['M1', 'M2'], copy=True)
assert_equal(evoked_ref.data.shape[0], evoked.data.shape[0] + 2)
_check_channel_names(evoked_ref, ['M1', 'M2'])
ref_idx = evoked_ref.ch_names.index('M1')
@@ -330,6 +384,8 @@ def test_add_reference():
evoked_ref.data[picks_eeg, :])
# Test invalid inputs
- raw_np = Raw(fif_fname, preload=False)
+ raw_np = read_raw_fif(fif_fname, preload=False, add_eeg_ref=False)
assert_raises(RuntimeError, add_reference_channels, raw_np, ['Ref'])
assert_raises(ValueError, add_reference_channels, raw, 1)
+
+run_tests_if_main()
diff --git a/mne/io/write.py b/mne/io/write.py
index 56e49f9..997db75 100644
--- a/mne/io/write.py
+++ b/mne/io/write.py
@@ -3,18 +3,19 @@
#
# License: BSD (3-clause)
-from ..externals.six import string_types, b
-import time
-import numpy as np
-from scipy import linalg
+from gzip import GzipFile
import os.path as op
import re
+import time
import uuid
+import numpy as np
+from scipy import linalg
+
from .constants import FIFF
from ..utils import logger
from ..externals.jdcal import jcal2jd
-from ..fixes import gzip_open
+from ..externals.six import string_types, b
def _write(fid, data, kind, data_size, FIFFT_TYPE, dtype):
@@ -247,7 +248,7 @@ def start_file(fname, id_=None):
logger.debug('Writing using gzip')
# defaults to compression level 9, which is barely smaller but much
# slower. 2 offers a good compromise.
- fid = gzip_open(fname, "wb", compresslevel=2)
+ fid = GzipFile(fname, "wb", compresslevel=2)
else:
logger.debug('Writing using normal I/O')
fid = open(fname, "wb")
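The removed mne.fixes.gzip_open shim is dropped in favor of the stdlib class, which covers this use directly. A minimal sketch of the write path:

    from gzip import GzipFile
    # compresslevel=2 trades slightly larger files for much faster writes
    fid = GzipFile('test_raw.fif.gz', 'wb', compresslevel=2)
    fid.write(b'...')  # FIFF tag bytes would be written here
    fid.close()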
diff --git a/mne/label.py b/mne/label.py
index a56319a..1b51761 100644
--- a/mne/label.py
+++ b/mne/label.py
@@ -14,10 +14,9 @@ import re
import numpy as np
from scipy import linalg, sparse
-from .fixes import digitize, in1d
from .utils import (get_subjects_dir, _check_subject, logger, verbose, warn,
_check_copy_dep)
-from .source_estimate import (morph_data, SourceEstimate,
+from .source_estimate import (morph_data, SourceEstimate, _center_of_mass,
spatial_src_connectivity)
from .source_space import add_source_space_distances
from .surface import read_surface, fast_cross_3d, mesh_edges, mesh_dist
@@ -361,7 +360,7 @@ class Label(object):
raise TypeError("Need: Label or BiHemiLabel. Got: %r" % other)
if self.hemi == other.hemi:
- keep = in1d(self.vertices, other.vertices, True, invert=True)
+ keep = np.in1d(self.vertices, other.vertices, True, invert=True)
else:
keep = np.arange(len(self.vertices))
@@ -422,7 +421,7 @@ class Label(object):
elif self.hemi == 'rh':
hemi_src = src[1]
- if not np.all(in1d(self.vertices, hemi_src['vertno'])):
+ if not np.all(np.in1d(self.vertices, hemi_src['vertno'])):
msg = "Source space does not contain all of the label's vertices"
raise ValueError(msg)
@@ -436,11 +435,11 @@ class Label(object):
nearest = hemi_src['nearest']
# find new vertices
- include = in1d(nearest, self.vertices, False)
+ include = np.in1d(nearest, self.vertices, False)
vertices = np.nonzero(include)[0]
# values
- nearest_in_label = digitize(nearest[vertices], self.vertices, True)
+ nearest_in_label = np.digitize(nearest[vertices], self.vertices, True)
values = self.values[nearest_in_label]
# pos
pos = hemi_src['rr'][vertices]
@@ -485,7 +484,7 @@ class Label(object):
n_jobs : int
Number of jobs to run in parallel
copy : bool
- This parameter has been deprecated and will be removed in 0.13.
+ This parameter has been deprecated and will be removed in 0.14.
Use inst.copy() instead.
Whether to return a new instance or modify in place.
verbose : bool, str, int, or None
@@ -541,7 +540,7 @@ class Label(object):
n_jobs : int
Number of jobs to run in parallel.
copy : bool
- This parameter has been deprecated and will be removed in 0.13.
+ This parameter has been deprecated and will be removed in 0.14.
Use inst.copy() instead.
Whether to return a new instance or modify in place.
verbose : bool, str, int, or None
@@ -583,7 +582,7 @@ class Label(object):
smooth=smooth, subjects_dir=subjects_dir,
warn=False, n_jobs=n_jobs)
inds = np.nonzero(stc.data)[0]
- label = _check_copy_dep(self, copy, default=True)
+ label = _check_copy_dep(self, copy)
label.values = stc.data[inds, :].ravel()
label.pos = np.zeros((len(inds), 3))
if label.hemi == 'lh':
@@ -658,7 +657,7 @@ class Label(object):
if vertices is None:
vertices = np.arange(10242)
- label_verts = vertices[in1d(vertices, self.vertices)]
+ label_verts = vertices[np.in1d(vertices, self.vertices)]
return label_verts
def get_tris(self, tris, vertices=None):
@@ -679,7 +678,7 @@ class Label(object):
The subset of tris used by the label
"""
vertices_ = self.get_vertices_used(vertices)
- selection = np.all(in1d(tris, vertices_).reshape(tris.shape),
+ selection = np.all(np.in1d(tris, vertices_).reshape(tris.shape),
axis=1)
label_tris = tris[selection]
if len(np.unique(label_tris)) < len(vertices_):
@@ -698,6 +697,62 @@ class Label(object):
return label_tris
+ def center_of_mass(self, subject=None, restrict_vertices=False,
+ subjects_dir=None, surf='sphere'):
+ """Compute the center of mass of the label
+
+ This function computes the spatial center of mass on the surface
+ as in [1]_.
+
+ Parameters
+ ----------
+ subject : string | None
+ The subject the label is defined for.
+ restrict_vertices : bool | array of int | instance of SourceSpaces
+ If True, returned vertex will be one from the label. Otherwise,
+ it could be any vertex from surf. If an array of int, the
+ returned vertex will come from that array. If instance of
+ SourceSpaces (as of 0.13), the returned vertex will be from
+ the given source space. For most accurate estimates, do not
+ restrict vertices.
+ subjects_dir : str, or None
+ Path to the SUBJECTS_DIR. If None, the path is obtained by using
+ the environment variable SUBJECTS_DIR.
+ surf : str
+ The surface to use for Euclidean distance center of mass
+ finding. The default here is "sphere", which finds the center
+ of mass on the spherical surface to help avoid potential issues
+ with cortical folding.
+
+ Returns
+ -------
+ vertex : int
+ Vertex of the spatial center of mass for the inferred hemisphere,
+ with each vertex weighted by its label value.
+
+ See Also
+ --------
+ SourceEstimate.center_of_mass
+ vertex_to_mni
+
+ Notes
+ -----
+ .. versionadded:: 0.13
+
+ References
+ ----------
+ .. [1] Larson and Lee, "The cortical dynamics underlying effective
+ switching of auditory spatial attention", NeuroImage 2012.
+ """
+ if not isinstance(surf, string_types):
+ raise TypeError('surf must be a string, got %s' % (type(surf),))
+ subject = _check_subject(self.subject, subject)
+ if np.any(self.values < 0):
+ raise ValueError('Cannot compute COM with negative values')
+ vertex = _center_of_mass(self.vertices, self.values, self.hemi, surf,
+ subject, subjects_dir, restrict_vertices)
+ return vertex
+
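A hedged usage sketch of the new Label.center_of_mass method (label file and subject name hypothetical):

    import mne
    label = mne.read_label('Aud-lh.label', subject='sample')
    # vertex index on the spherical surface, weighted by the label values
    vertex = label.center_of_mass(subject='sample', restrict_vertices=True,
                                  surf='sphere')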
class BiHemiLabel(object):
"""A freesurfer/MNE label with vertices in both hemispheres
@@ -1002,7 +1057,7 @@ def _split_label_contig(label_to_split, subject=None, subjects_dir=None):
for div, name, color in zip(label_divs, names, colors):
# Get indices of dipoles within this division of the label
verts = np.array(sorted(list(div)))
- vert_indices = in1d(verts_arr, verts, assume_unique=True)
+ vert_indices = np.in1d(verts_arr, verts, assume_unique=True)
# Set label attributes
pos = label_to_split.pos[vert_indices]
diff --git a/mne/minimum_norm/tests/test_inverse.py b/mne/minimum_norm/tests/test_inverse.py
index 8247e85..0ba0374 100644
--- a/mne/minimum_norm/tests/test_inverse.py
+++ b/mne/minimum_norm/tests/test_inverse.py
@@ -16,7 +16,7 @@ from mne.source_estimate import read_source_estimate, VolSourceEstimate
from mne import (read_cov, read_forward_solution, read_evokeds, pick_types,
pick_types_forward, make_forward_solution,
convert_forward_solution, Covariance)
-from mne.io import Raw, Info
+from mne.io import read_raw_fif, Info
from mne.minimum_norm.inverse import (apply_inverse, read_inverse_operator,
apply_inverse_raw, apply_inverse_epochs,
make_inverse_operator,
@@ -60,24 +60,28 @@ last_keys = [None] * 10
def read_forward_solution_meg(*args, **kwargs):
+ """Read MEG forward."""
fwd = read_forward_solution(*args, **kwargs)
fwd = pick_types_forward(fwd, meg=True, eeg=False)
return fwd
def read_forward_solution_eeg(*args, **kwargs):
+ """Read EEG forward."""
fwd = read_forward_solution(*args, **kwargs)
fwd = pick_types_forward(fwd, meg=False, eeg=True)
return fwd
def _get_evoked():
+ """Get evoked data."""
evoked = read_evokeds(fname_data, condition=0, baseline=(None, 0))
evoked.crop(0, 0.2)
return evoked
def _compare(a, b):
+ """Compare two python objects."""
global last_keys
skip_types = ['whitener', 'proj', 'reginv', 'noisenorm', 'nchan',
'command_line', 'working_dir', 'mri_file', 'mri_id']
@@ -115,6 +119,7 @@ def _compare(a, b):
def _compare_inverses_approx(inv_1, inv_2, evoked, rtol, atol,
check_depth=True):
+ """Compare inverses."""
# depth prior
if check_depth:
if inv_1['depth_prior'] is not None:
@@ -148,6 +153,7 @@ def _compare_inverses_approx(inv_1, inv_2, evoked, rtol, atol,
def _compare_io(inv_op, out_file_ext='.fif'):
+ """Compare inverse IO."""
tempdir = _TempDir()
if out_file_ext == '.fif':
out_file = op.join(tempdir, 'test-inv.fif')
@@ -165,8 +171,7 @@ def _compare_io(inv_op, out_file_ext='.fif'):
@testing.requires_testing_data
def test_warn_inverse_operator():
- """Test MNE inverse warning without average EEG projection
- """
+ """Test MNE inverse warning without average EEG projection."""
bad_info = copy.deepcopy(_get_evoked().info)
bad_info['projs'] = list()
fwd_op = read_forward_solution(fname_fwd, surf_ori=True)
@@ -380,7 +385,7 @@ def test_make_inverse_operator_diag():
"""Test MNE inverse computation with diagonal noise cov
"""
evoked = _get_evoked()
- noise_cov = read_cov(fname_cov).as_diag(copy=False)
+ noise_cov = read_cov(fname_cov).as_diag()
fwd_op = read_forward_solution(fname_fwd, surf_ori=True)
inv_op = make_inverse_operator(evoked.info, fwd_op, noise_cov,
loose=0.2, depth=0.8)
@@ -464,11 +469,10 @@ def test_io_inverse_operator():
@testing.requires_testing_data
def test_apply_mne_inverse_raw():
- """Test MNE with precomputed inverse operator on Raw
- """
+ """Test MNE with precomputed inverse operator on Raw."""
start = 3
stop = 10
- raw = Raw(fname_raw)
+ raw = read_raw_fif(fname_raw, add_eeg_ref=False)
label_lh = read_label(fname_label % 'Aud-lh')
_, times = raw[0, start:stop]
inverse_operator = read_inverse_operator(fname_full)
@@ -498,9 +502,8 @@ def test_apply_mne_inverse_raw():
@testing.requires_testing_data
def test_apply_mne_inverse_fixed_raw():
- """Test MNE with fixed-orientation inverse operator on Raw
- """
- raw = Raw(fname_raw)
+ """Test MNE with fixed-orientation inverse operator on Raw."""
+ raw = read_raw_fif(fname_raw, add_eeg_ref=False)
start = 3
stop = 10
_, times = raw[0, start:stop]
@@ -538,13 +541,12 @@ def test_apply_mne_inverse_fixed_raw():
@testing.requires_testing_data
def test_apply_mne_inverse_epochs():
- """Test MNE with precomputed inverse operator on Epochs
- """
+ """Test MNE with precomputed inverse operator on Epochs."""
inverse_operator = read_inverse_operator(fname_full)
label_lh = read_label(fname_label % 'Aud-lh')
label_rh = read_label(fname_label % 'Aud-rh')
event_id, tmin, tmax = 1, -0.2, 0.5
- raw = Raw(fname_raw)
+ raw = read_raw_fif(fname_raw, add_eeg_ref=False)
picks = pick_types(raw.info, meg=True, eeg=False, stim=True, ecg=True,
eog=True, include=['STI 014'], exclude='bads')
@@ -553,7 +555,8 @@ def test_apply_mne_inverse_epochs():
events = read_events(fname_event)[:15]
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- baseline=(None, 0), reject=reject, flat=flat)
+ baseline=(None, 0), reject=reject, flat=flat,
+ add_eeg_ref=False)
stcs = apply_inverse_epochs(epochs, inverse_operator, lambda2, "dSPM",
label=label_lh, pick_ori="normal")
inverse_operator = prepare_inverse_operator(inverse_operator, nave=1,
@@ -603,8 +606,7 @@ def test_apply_mne_inverse_epochs():
@testing.requires_testing_data
def test_make_inverse_operator_bads():
- """Test MNE inverse computation given a mismatch of bad channels
- """
+ """Test MNE inverse computation given a mismatch of bad channels."""
fwd_op = read_forward_solution_meg(fname_fwd, surf_ori=True)
evoked = _get_evoked()
noise_cov = read_cov(fname_cov)
diff --git a/mne/minimum_norm/tests/test_time_frequency.py b/mne/minimum_norm/tests/test_time_frequency.py
index 0aa1ec3..c8c03ef 100644
--- a/mne/minimum_norm/tests/test_time_frequency.py
+++ b/mne/minimum_norm/tests/test_time_frequency.py
@@ -6,7 +6,8 @@ from nose.tools import assert_true
import warnings
from mne.datasets import testing
-from mne import io, find_events, Epochs, pick_types
+from mne import find_events, Epochs, pick_types
+from mne.io import read_raw_fif
from mne.utils import run_tests_if_main
from mne.label import read_label
from mne.minimum_norm.inverse import (read_inverse_operator,
@@ -31,12 +32,11 @@ warnings.simplefilter('always')
@testing.requires_testing_data
def test_tfr_with_inverse_operator():
- """Test time freq with MNE inverse computation"""
-
+ """Test time freq with MNE inverse computation."""
tmin, tmax, event_id = -0.2, 0.5, 1
# Setup for reading the raw data
- raw = io.read_raw_fif(fname_data)
+ raw = read_raw_fif(fname_data, add_eeg_ref=False)
events = find_events(raw, stim_channel='STI 014')
inverse_operator = read_inverse_operator(fname_inv)
inv = prepare_inverse_operator(inverse_operator, nave=1,
@@ -53,7 +53,7 @@ def test_tfr_with_inverse_operator():
events3 = events[:3] # take 3 events to keep the computation time low
epochs = Epochs(raw, events3, event_id, tmin, tmax, picks=picks,
baseline=(None, 0), reject=dict(grad=4000e-13, eog=150e-6),
- preload=True)
+ preload=True, add_eeg_ref=False)
# Compute a source estimate per frequency band
bands = dict(alpha=[10, 10])
@@ -78,7 +78,7 @@ def test_tfr_with_inverse_operator():
# Compute a source estimate per frequency band
epochs = Epochs(raw, events[:10], event_id, tmin, tmax, picks=picks,
baseline=(None, 0), reject=dict(grad=4000e-13, eog=150e-6),
- preload=True)
+ preload=True, add_eeg_ref=False)
frequencies = np.arange(7, 30, 2) # define frequencies of interest
power, phase_lock = source_induced_power(epochs, inv,
@@ -94,8 +94,8 @@ def test_tfr_with_inverse_operator():
@testing.requires_testing_data
def test_source_psd():
- """Test source PSD computation in label"""
- raw = io.read_raw_fif(fname_data)
+ """Test source PSD computation in label."""
+ raw = read_raw_fif(fname_data, add_eeg_ref=False)
inverse_operator = read_inverse_operator(fname_inv)
label = read_label(fname_label)
tmin, tmax = 0, 20 # seconds
@@ -114,9 +114,8 @@ def test_source_psd():
@testing.requires_testing_data
def test_source_psd_epochs():
- """Test multi-taper source PSD computation in label from epochs"""
-
- raw = io.read_raw_fif(fname_data)
+ """Test multi-taper source PSD computation in label from epochs."""
+ raw = read_raw_fif(fname_data, add_eeg_ref=False)
inverse_operator = read_inverse_operator(fname_inv)
label = read_label(fname_label)
@@ -132,7 +131,7 @@ def test_source_psd_epochs():
events = find_events(raw, stim_channel='STI 014')
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- baseline=(None, 0), reject=reject)
+ baseline=(None, 0), reject=reject, add_eeg_ref=False)
# only look at one epoch
epochs.drop_bad()
diff --git a/mne/minimum_norm/time_frequency.py b/mne/minimum_norm/time_frequency.py
index eb87303..6924b93 100644
--- a/mne/minimum_norm/time_frequency.py
+++ b/mne/minimum_norm/time_frequency.py
@@ -103,7 +103,7 @@ def source_band_induced_power(epochs, inverse_operator, bands, label=None,
power = [power - mean(power_baseline)] / std(power_baseline)).
pca : bool
If True, the true dimension of data is estimated before running
- the time frequency transforms. It reduces the computation times
+ the time-frequency transforms. It reduces the computation times
e.g. with a dataset that was maxfiltered (true dim is 64).
n_jobs : int
Number of jobs to run in parallel.
@@ -341,7 +341,7 @@ def source_induced_power(epochs, inverse_operator, frequencies, label=None,
power = [power - mean(power_baseline)] / std(power_baseline)).
pca : bool
If True, the true dimension of data is estimated before running
- the time frequency transforms. It reduces the computation times
+ the time-frequency transforms. It reduces the computation times
e.g. with a dataset that was maxfiltered (true dim is 64).
n_jobs : int
Number of jobs to run in parallel.
@@ -413,7 +413,7 @@ def compute_source_psd(raw, inverse_operator, lambda2=1. / 9., method="dSPM",
The number of averages used to scale the noise covariance matrix.
pca: bool
If True, the true dimension of data is estimated before running
- the time frequency transforms. It reduces the computation times
+ the time-frequency transforms. It reduces the computation times
e.g. with a dataset that was maxfiltered (true dim is 64).
prepared : bool
If True, do not call `prepare_inverse_operator`.
@@ -630,7 +630,7 @@ def compute_source_psd_epochs(epochs, inverse_operator, lambda2=1. / 9.,
The number of averages used to scale the noise covariance matrix.
pca : bool
If True, the true dimension of data is estimated before running
- the time frequency transforms. It reduces the computation times
+ the time-frequency transforms. It reduces the computation times
e.g. with a dataset that was maxfiltered (true dim is 64).
inv_split : int or None
Split inverse operator into inv_split parts in order to save memory.
diff --git a/mne/preprocessing/__init__.py b/mne/preprocessing/__init__.py
index 4792927..ab173a6 100644
--- a/mne/preprocessing/__init__.py
+++ b/mne/preprocessing/__init__.py
@@ -12,7 +12,7 @@ from .ssp import compute_proj_ecg, compute_proj_eog
from .eog import find_eog_events, create_eog_epochs
from .ecg import find_ecg_events, create_ecg_epochs
from .ica import (ICA, ica_find_eog_events, ica_find_ecg_events,
- get_score_funcs, read_ica, run_ica)
+ get_score_funcs, read_ica, run_ica, corrmap)
from .bads import find_outliers
from .stim import fix_stim_artifact
from .maxwell import maxwell_filter
diff --git a/mne/preprocessing/_fine_cal.py b/mne/preprocessing/_fine_cal.py
new file mode 100644
index 0000000..f9c8b03
--- /dev/null
+++ b/mne/preprocessing/_fine_cal.py
@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+# Authors: Eric Larson <larson.eric.d at gmail.com>
+
+# License: BSD (3-clause)
+
+import numpy as np
+
+from ..utils import check_fname, _check_fname
+
+
+def read_fine_calibration(fname):
+ """Read fine calibration information from a .dat file
+
+ The fine calibration typically includes improved sensor locations,
+ calibration coefficients, and gradiometer imbalance information.
+
+ Parameters
+ ----------
+ fname : str
+ The filename.
+
+ Returns
+ -------
+ calibration : dict
+ Fine calibration information.
+ """
+ # Read new sensor locations
+ _check_fname(fname, overwrite=True, must_exist=True)
+ check_fname(fname, 'cal', ('.dat',))
+ ch_names = list()
+ locs = list()
+ imb_cals = list()
+ with open(fname, 'r') as fid:
+ for line in fid:
+ if line[0] in '#\n':
+ continue
+ vals = line.strip().split()
+ if len(vals) not in [14, 16]:
+ raise RuntimeError('Error parsing fine calibration file, '
+ 'should have 14 or 16 entries per line '
+ 'but found %s on line:\n%s'
+ % (len(vals), line))
+ # `vals` contains channel number
+ ch_name = vals[0]
+ if len(ch_name) in (3, 4): # heuristic for Neuromag fix
+ try:
+ ch_name = int(ch_name)
+ except ValueError: # something other than e.g. 113 or 2642
+ pass
+ else:
+ ch_name = 'MEG' + '%04d' % ch_name
+ ch_names.append(ch_name)
+ # (x, y, z), x-norm 3-vec, y-norm 3-vec, z-norm 3-vec
+ locs.append(np.array([float(x) for x in vals[1:13]]))
+ # and 1 or 3 imbalance terms
+ imb_cals.append([float(x) for x in vals[13:]])
+ locs = np.array(locs)
+ return dict(ch_names=ch_names, locs=locs, imb_cals=imb_cals)
+
+
+def write_fine_calibration(fname, calibration):
+ """Write fine calibration information to a .dat file
+
+ Parameters
+ ----------
+ fname : str
+ The filename to write out.
+ calibration : dict
+ Fine calibration information.
+ """
+ _check_fname(fname, overwrite=True)
+ check_fname(fname, 'cal', ('.dat',))
+
+ with open(fname, 'wb') as cal_file:
+ for ci, chan in enumerate(calibration['ch_names']):
+ # Write string containing 1) channel, 2) loc info, 3) calib info
+ # with field widths (e.g., %.6f) chosen to match how Elekta writes
+ # them out
+ cal_line = np.concatenate([calibration['locs'][ci],
+ calibration['imb_cals'][ci]]).round(6)
+ cal_str = str(chan) + ' ' + ' '.join(map(lambda x: "%.6f" % x,
+ cal_line))
+
+ cal_file.write((cal_str + '\n').encode('ASCII'))
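A round-trip sketch for the new fine-calibration IO (file names hypothetical; per the reader above, each line carries 14 or 16 values, i.e. 1 or 3 imbalance terms):

    from mne.preprocessing._fine_cal import (read_fine_calibration,
                                             write_fine_calibration)
    cal = read_fine_calibration('sss_cal.dat')
    print(cal['ch_names'][0], cal['locs'].shape)  # e.g. MEG0113 (306, 12)
    write_fine_calibration('sss_cal_roundtrip.dat', cal)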
diff --git a/mne/preprocessing/ecg.py b/mne/preprocessing/ecg.py
index 822af12..d79c69e 100644
--- a/mne/preprocessing/ecg.py
+++ b/mne/preprocessing/ecg.py
@@ -53,7 +53,9 @@ def qrs_detector(sfreq, ecg, thresh_value=0.6, levels=2.5, n_thresh=3,
win_size = int(round((60.0 * sfreq) / 120.0))
filtecg = band_pass_filter(ecg, sfreq, l_freq, h_freq,
- filter_length=filter_length)
+ filter_length=filter_length,
+ l_trans_bandwidth=0.5, h_trans_bandwidth=0.5,
+ phase='zero-double', fir_window='hann')
ecg_abs = np.abs(filtecg)
init = int(sfreq)
@@ -224,26 +226,26 @@ def _get_ecg_channel_index(ch_name, inst):
@verbose
-def create_ecg_epochs(raw, ch_name=None, event_id=999, picks=None,
- tmin=-0.5, tmax=0.5, l_freq=8, h_freq=16, reject=None,
- flat=None, baseline=None, preload=True,
- keep_ecg=False, verbose=None):
+def create_ecg_epochs(raw, ch_name=None, event_id=999, picks=None, tmin=-0.5,
+ tmax=0.5, l_freq=8, h_freq=16, reject=None, flat=None,
+ baseline=None, preload=True, keep_ecg=False,
+ verbose=None):
"""Conveniently generate epochs around ECG artifact events
-
Parameters
----------
raw : instance of Raw
The raw data
ch_name : None | str
The name of the channel to use for ECG peak detection.
- If None (default), a synthetic ECG channel is created from
+ If None (default), ECG channel is used if present. If None and no
+ ECG channel is present, a synthetic ECG channel is created from
cross channel average. Synthetic channel can only be created from
'meg' channels.
event_id : int
The index to assign to found events
picks : array-like of int | None (default)
- Indices of channels to include (if None, all channels are used).
+ Indices of channels to include. If None, all channels are used.
tmin : float
Start time before event.
tmax : float
@@ -268,19 +270,20 @@ def create_ecg_epochs(raw, ch_name=None, event_id=999, picks=None,
Valid keys are 'grad' | 'mag' | 'eeg' | 'eog' | 'ecg', and values
are floats that set the minimum acceptable peak-to-peak amplitude.
If flat is None then no rejection is done.
- baseline : tuple or list of length 2, or None
+ baseline : tuple | list of length 2 | None
The time interval to apply rescaling / baseline correction.
If None do not apply it. If baseline is (a, b)
the interval is between "a (s)" and "b (s)".
If a is None the beginning of the data is used
and if b is None then b is set to the end of the interval.
- If baseline is equal ot (None, None) all the time
+ If baseline is equal to (None, None) all the time
interval is used. If None, no correction is applied.
preload : bool
Preload epochs or not.
keep_ecg : bool
- When ECG is synthetically created (after picking),
- should it be added to the epochs? Defaults to False.
+ When ECG is synthetically created (after picking), should it be added
+ to the epochs? Must be False when synthetic channel is not used.
+ Defaults to False.
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
@@ -289,16 +292,13 @@ def create_ecg_epochs(raw, ch_name=None, event_id=999, picks=None,
ecg_epochs : instance of Epochs
Data epoched around ECG r-peaks.
"""
- not_has_ecg = 'ecg' not in raw and ch_name is None
- if not_has_ecg:
- ecg, times = _make_ecg(raw, None, None, verbose)
+ has_ecg = 'ecg' in raw or ch_name is not None
events, _, _, ecg = find_ecg_events(
raw, ch_name=ch_name, event_id=event_id, l_freq=l_freq, h_freq=h_freq,
- return_ecg=True,
- verbose=verbose)
+ return_ecg=True, verbose=verbose)
- if not_has_ecg:
+ if not has_ecg:
ecg_raw = RawArray(
ecg[None],
create_info(ch_names=['ECG-SYN'],
@@ -309,23 +309,19 @@ def create_ecg_epochs(raw, ch_name=None, event_id=999, picks=None,
ecg_raw.info[k] = v
raw.add_channels([ecg_raw])
- if picks is None and not keep_ecg:
- picks = pick_types(raw.info, meg=True, eeg=True, ecg=False,
- ref_meg=False)
- elif picks is None and keep_ecg and not_has_ecg:
- picks = pick_types(raw.info, meg=True, eeg=True, ecg=True,
- ref_meg=False)
- elif keep_ecg and not_has_ecg:
- picks_extra = pick_types(raw.info, meg=False, eeg=False, ecg=True,
- ref_meg=False)
- picks = np.concatenate([picks, picks_extra])
-
+ if keep_ecg:
+ if has_ecg:
+ raise ValueError('keep_ecg can be True only if the ECG channel is '
+ 'created synthetically.')
+ else:
+ picks = np.append(picks, raw.ch_names.index('ECG-SYN'))
# create epochs around ECG events and baseline (important)
ecg_epochs = Epochs(raw, events=events, event_id=event_id,
- tmin=tmin, tmax=tmax, proj=False,
+ tmin=tmin, tmax=tmax, proj=False, flat=flat,
picks=picks, reject=reject, baseline=baseline,
- verbose=verbose, preload=preload)
- if ecg is not None:
+ verbose=verbose, preload=preload, add_eeg_ref=False)
+
+ if not has_ecg:
raw.drop_channels(['ECG-SYN'])
return ecg_epochs
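With this change keep_ecg=True is only valid when the ECG channel is synthesized, and the synthetic channel is appended to the picks array. A hedged usage sketch:

    import mne
    from mne.preprocessing import create_ecg_epochs
    picks = mne.pick_types(raw.info, meg=True, eeg=True)  # raw has no ECG ch
    ecg_epochs = create_ecg_epochs(raw, picks=picks, keep_ecg=True)
    evoked_ecg = ecg_epochs.average()  # includes the synthetic 'ECG-SYN'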
diff --git a/mne/preprocessing/eog.py b/mne/preprocessing/eog.py
index cd96bdc..4e0b8c1 100644
--- a/mne/preprocessing/eog.py
+++ b/mne/preprocessing/eog.py
@@ -70,16 +70,19 @@ def _find_eog_events(eog, event_id, l_freq, h_freq, sampling_rate, first_samp,
# filtering to remove dc offset so that we know which is blink and saccades
fmax = np.minimum(45, sampling_rate / 2.0 - 0.75) # protect Nyquist
- filteog = np.array([band_pass_filter(x, sampling_rate, 2, fmax,
- filter_length=filter_length)
- for x in eog])
+ filteog = np.array([band_pass_filter(
+ x, sampling_rate, 2, fmax, filter_length=filter_length,
+ l_trans_bandwidth=0.5, h_trans_bandwidth=0.5, phase='zero-double',
+ fir_window='hann') for x in eog])
temp = np.sqrt(np.sum(filteog ** 2, axis=1))
indexmax = np.argmax(temp)
# easier to detect peaks with filtering.
- filteog = band_pass_filter(eog[indexmax], sampling_rate, l_freq, h_freq,
- filter_length=filter_length)
+ filteog = band_pass_filter(
+ eog[indexmax], sampling_rate, l_freq, h_freq,
+ filter_length=filter_length, l_trans_bandwidth=0.5,
+ h_trans_bandwidth=0.5, phase='zero-double', fir_window='hann')
# detecting eog blinks and generating event file
@@ -205,5 +208,5 @@ def create_eog_epochs(raw, ch_name=None, event_id=998, picks=None,
eog_epochs = Epochs(raw, events=events, event_id=event_id,
tmin=tmin, tmax=tmax, proj=False, reject=reject,
flat=flat, picks=picks, baseline=baseline,
- preload=preload)
+ preload=preload, add_eeg_ref=False)
return eog_epochs
diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py
index 76113a4..6d080bc 100644
--- a/mne/preprocessing/ica.py
+++ b/mne/preprocessing/ica.py
@@ -35,6 +35,7 @@ from ..io.base import _BaseRaw
from ..epochs import _BaseEpochs
from ..viz import (plot_ica_components, plot_ica_scores,
plot_ica_sources, plot_ica_overlay)
+from ..viz.ica import plot_ica_properties
from ..viz.utils import (_prepare_trellis, tight_layout, plt_show,
_setup_vmin_vmax)
from ..viz.topomap import (_prepare_topo_plot, _check_outlines,
@@ -44,12 +45,14 @@ from ..channels.channels import _contains_ch_type, ContainsMixin
from ..io.write import start_file, end_file, write_id
from ..utils import (check_version, logger, check_fname, verbose,
_reject_data_segments, check_random_state,
- _get_fast_dot, compute_corr, _check_copy_dep)
+ _get_fast_dot, compute_corr, _get_inst_data,
+ copy_function_doc_to_method_doc)
from ..fixes import _get_args
from ..filter import band_pass_filter
from .bads import find_outliers
from .ctps_ import ctps
from ..externals.six import string_types, text_type
+from ..io.pick import channel_type
__all__ = ['ICA', 'ica_find_ecg_events', 'ica_find_eog_events',
@@ -88,6 +91,27 @@ def get_score_funcs():
return score_funcs
+def _check_for_unsupported_ica_channels(picks, info):
+ """Check for channels in picks that are not considered
+ valid channels. Accepted channels are the data channels
+ ('seeg', 'ecog', 'eeg', 'hbo', 'hbr', 'mag', and 'grad') and 'eog'.
+ This prevents the program from crashing without
+ feedback when a bad channel is provided to ICA whitening.
+ """
+ if picks is None:
+ return
+ elif len(picks) == 0:
+ raise ValueError('No channels provided to ICA')
+ types = _DATA_CH_TYPES_SPLIT + ['eog']
+ chs = list(set([channel_type(info, j) for j in picks]))
+ check = all([ch in types for ch in chs])
+ if not check:
+ raise ValueError('Invalid channel type(s) passed for ICA.\n'
+ 'Only the following channel types are supported: {0}\n'
+ 'The following types were passed: {1}\n'
+ .format(types, chs))
+
+
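A sketch of what the new guard rejects, namely any pick that is not a data or EOG channel:

    import mne
    info = mne.create_info(['MEG 0113', 'STI 014'], 1000., ['grad', 'stim'])
    _check_for_unsupported_ica_channels([0, 1], info)  # raises ValueError
    _check_for_unsupported_ica_channels([0], info)     # passes silently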
class ICA(ContainsMixin):
"""M/EEG signal decomposition using Independent Component Analysis (ICA)
@@ -164,7 +188,7 @@ class ICA(ContainsMixin):
The number of components used for PCA dimensionality reduction.
verbose : bool, str, int, or None
See above.
- ``pca_components_` : ndarray
+ ``pca_components_`` : ndarray
If fit, the PCA components
``pca_mean_`` : ndarray
If fit, the mean vector used to center the data before doing the PCA.
@@ -184,9 +208,9 @@ class ICA(ContainsMixin):
again. To dump this 'artifact memory' say: ica.exclude = []
info : None | instance of Info
The measurement info copied from the object fitted.
- `n_samples_` : int
+ ``n_samples_`` : int
the number of samples used on fit.
- `labels_` : dict
+ ``labels_`` : dict
A dictionary of independent component indices, grouped by types of
independent components. This attribute is set by some of the artifact
detection functions.
@@ -298,7 +322,8 @@ class ICA(ContainsMixin):
within ``start`` and ``stop`` are used.
reject : dict | None
Rejection parameters based on peak-to-peak amplitude.
- Valid keys are 'grad', 'mag', 'eeg', 'seeg', 'ecog', 'eog', 'ecg'.
+ Valid keys are 'grad', 'mag', 'eeg', 'seeg', 'ecog', 'eog', 'ecg',
+ 'hbo', 'hbr'.
If reject is None then no rejection is done. Example::
reject = dict(grad=4000e-13, # T / m (gradiometers)
@@ -310,7 +335,8 @@ class ICA(ContainsMixin):
It only applies if `inst` is of type Raw.
flat : dict | None
Rejection parameters based on flatness of signal.
- Valid keys are 'grad', 'mag', 'eeg', 'seeg', 'ecog', 'eog', 'ecg'.
+ Valid keys are 'grad', 'mag', 'eeg', 'seeg', 'ecog', 'eog', 'ecg',
+ 'hbo', 'hbr'.
Values are floats that set the minimum acceptable peak-to-peak
amplitude. If flat is None then no rejection is done.
It only applies if `inst` is of type Raw.
@@ -326,13 +352,21 @@ class ICA(ContainsMixin):
self : instance of ICA
Returns the modified instance.
"""
- if isinstance(inst, _BaseRaw):
- self._fit_raw(inst, picks, start, stop, decim, reject, flat,
- tstep, verbose)
- elif isinstance(inst, _BaseEpochs):
- self._fit_epochs(inst, picks, decim, verbose)
+ if isinstance(inst, _BaseRaw) or isinstance(inst, _BaseEpochs):
+ _check_for_unsupported_ica_channels(picks, inst.info)
+ if isinstance(inst, _BaseRaw):
+ self._fit_raw(inst, picks, start, stop, decim, reject, flat,
+ tstep, verbose)
+ elif isinstance(inst, _BaseEpochs):
+ self._fit_epochs(inst, picks, decim, verbose)
else:
raise ValueError('Data input must be of Raw or Epochs type')
+
+ # sort ICA components by explained variance
+ var = _ica_explained_variance(self, inst)
+ var_ord = var.argsort()[::-1]
+ _sort_components(self, var_ord, copy=False)
+
return self
def _reset(self):
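After this change the stored components are reordered on fit so that component 0 explains the most variance in the fitted data. A usage sketch:

    from mne.preprocessing import ICA
    ica = ICA(n_components=0.95, random_state=42)
    ica.fit(raw)  # mixing/unmixing matrices now sorted by explained variance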
@@ -440,7 +474,7 @@ class ICA(ContainsMixin):
# Scale (z-score) the data by channel type
info = pick_info(info, picks)
pre_whitener = np.empty([len(data), 1])
- for ch_type in _DATA_CH_TYPES_SPLIT:
+ for ch_type in _DATA_CH_TYPES_SPLIT + ['eog']:
if _contains_ch_type(info, ch_type):
if ch_type == 'seeg':
this_picks = pick_types(info, meg=False, seeg=True)
@@ -448,8 +482,16 @@ class ICA(ContainsMixin):
this_picks = pick_types(info, meg=False, ecog=True)
elif ch_type == 'eeg':
this_picks = pick_types(info, meg=False, eeg=True)
- else:
+ elif ch_type in ('mag', 'grad'):
this_picks = pick_types(info, meg=ch_type)
+ elif ch_type == 'eog':
+ this_picks = pick_types(info, meg=False, eog=True)
+ elif ch_type in ('hbo', 'hbr'):
+ this_picks = pick_types(info, meg=False, fnirs=ch_type)
+ else:
+ raise RuntimeError('Should not be reached. '
+ 'Unsupported channel {0}'
+ .format(ch_type))
pre_whitener[this_picks] = np.std(data[this_picks])
data /= pre_whitener
elif not has_pre_whitener and self.noise_cov is not None:
@@ -467,13 +509,19 @@ class ICA(ContainsMixin):
def _fit(self, data, max_pca_components, fit_type):
"""Aux function """
- from sklearn.decomposition import RandomizedPCA
random_state = check_random_state(self.random_state)
- # XXX fix copy==True later. Bug in sklearn, see PR #2273
- pca = RandomizedPCA(n_components=max_pca_components, whiten=True,
- copy=True, random_state=random_state)
+ if not check_version('sklearn', '0.18'):
+ from sklearn.decomposition import RandomizedPCA
+ # XXX fix copy==True later. Bug in sklearn, see PR #2273
+ pca = RandomizedPCA(n_components=max_pca_components, whiten=True,
+ copy=True, random_state=random_state)
+
+ else:
+ from sklearn.decomposition import PCA
+ pca = PCA(n_components=max_pca_components, copy=True, whiten=True,
+ svd_solver='randomized', random_state=random_state)
if isinstance(self.n_components, float):
# compute full feature variance before doing PCA
@@ -508,9 +556,12 @@ class ICA(ContainsMixin):
# the things to store for PCA
self.pca_mean_ = pca.mean_
self.pca_components_ = pca.components_
- # unwhiten pca components and put scaling in unmixintg matrix later.
self.pca_explained_variance_ = exp_var = pca.explained_variance_
- self.pca_components_ *= np.sqrt(exp_var[:, None])
+ if not check_version('sklearn', '0.18'):
+ # unwhiten pca components and put scaling in unmixing matrix later.
+ # RandomizedPCA applies the whitening to the components
+ # but not the new PCA class.
+ self.pca_components_ *= np.sqrt(exp_var[:, None])
del pca
# update number of components
self.n_components_ = sel.stop
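sklearn 0.18 deprecates RandomizedPCA, hence the version switch above. An equivalent construction with the new class (note the differing whitening semantics the comment above describes):

    from sklearn.decomposition import PCA
    pca = PCA(n_components=30, copy=True, whiten=True,
              svd_solver='randomized', random_state=0)
    # unlike RandomizedPCA, this class does not whiten pca.components_,
    # so no unwhitening of the stored components is needed afterwards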
@@ -963,7 +1014,7 @@ class ICA(ContainsMixin):
raise ValueError('Method "%s" not supported.' % method)
# sort indices by scores
ecg_idx = ecg_idx[np.abs(scores[ecg_idx]).argsort()[::-1]]
- if not hasattr(self, 'labels_'):
+ if not hasattr(self, 'labels_') or self.labels_ is None:
self.labels_ = dict()
self.labels_['ecg'] = list(ecg_idx)
self.labels_['ecg/%s' % ch_name] = list(ecg_idx)
@@ -987,7 +1038,7 @@ class ICA(ContainsMixin):
Object to compute sources from.
ch_name : str
The name of the channel to use for EOG peak detection.
- The argument is mandatory if the dataset contains no ECG
+ The argument is mandatory if the dataset contains no EOG
channels.
threshold : int | float
The value above which a feature is classified as outlier.
@@ -1033,7 +1084,7 @@ class ICA(ContainsMixin):
if inst.ch_names != self.ch_names:
inst = inst.copy().pick_channels(self.ch_names)
- if not hasattr(self, 'labels_'):
+ if not hasattr(self, 'labels_') or self.labels_ is None:
self.labels_ = dict()
for ii, (eog_ch, target) in enumerate(zip(eog_chs, targets)):
@@ -1065,9 +1116,8 @@ class ICA(ContainsMixin):
return self.labels_['eog'], scores
- def apply(self, inst, include=None, exclude=None,
- n_pca_components=None, start=None, stop=None,
- copy=None):
+ def apply(self, inst, include=None, exclude=None, n_pca_components=None,
+ start=None, stop=None):
"""Remove selected components from the signal.
Given the unmixing matrix, transform data,
@@ -1095,12 +1145,7 @@ class ICA(ContainsMixin):
stop : int | float | None
Last sample to not include. If float, data will be interpreted as
time in seconds. If None, data will be used to the last sample.
- copy : bool
- This parameter has been deprecated and will be removed in 0.13.
- Use inst.copy() instead.
- Whether to return a new instance or modify in place.
"""
- inst = _check_copy_dep(inst, copy)
if isinstance(inst, _BaseRaw):
out = self._apply_raw(raw=inst, include=include,
exclude=exclude,
@@ -1297,217 +1342,48 @@ class ICA(ContainsMixin):
"""
return deepcopy(self)
+ @copy_function_doc_to_method_doc(plot_ica_components)
def plot_components(self, picks=None, ch_type=None, res=64, layout=None,
vmin=None, vmax=None, cmap='RdBu_r', sensors=True,
colorbar=False, title=None, show=True, outlines='head',
- contours=6, image_interp='bilinear', head_pos=None):
- """Project unmixing matrix on interpolated sensor topography.
-
- Parameters
- ----------
- picks : int | array-like | None
- The indices of the sources to be plotted.
- If None all are plotted in batches of 20.
- ch_type : 'mag' | 'grad' | 'planar1' | 'planar2' | 'eeg' | None
- The channel type to plot. For 'grad', the gradiometers are
- collected in pairs and the RMS for each pair is plotted.
- If None, then first available channel type from order given
- above is used. Defaults to None.
- res : int
- The resolution of the topomap image (n pixels along each side).
- layout : None | Layout
- Layout instance specifying sensor positions (does not need to
- be specified for Neuromag data). If possible, the correct layout is
- inferred from the data.
- vmin : float | callable
- The value specfying the lower bound of the color range.
- If None, and vmax is None, -vmax is used. Else np.min(data).
- If callable, the output equals vmin(data).
- vmax : float | callable
- The value specfying the upper bound of the color range.
- If None, the maximum absolute value is used. If vmin is None,
- but vmax is not, defaults to np.min(data).
- If callable, the output equals vmax(data).
- cmap : matplotlib colormap
- Colormap.
- sensors : bool | str
- Add markers for sensor locations to the plot. Accepts matplotlib
- plot format string (e.g., 'r+' for red plusses). If True, a circle
- will be used (via .add_artist). Defaults to True.
- colorbar : bool
- Plot a colorbar.
- title : str | None
- Title to use.
- show : bool
- Call pyplot.show() at the end.
- outlines : 'head' | 'skirt' | dict | None
- The outlines to be drawn. If 'head', the default head scheme will
- be drawn. If 'skirt' the head scheme will be drawn, but sensors are
- allowed to be plotted outside of the head circle. If dict, each key
- refers to a tuple of x and y positions, the values in 'mask_pos'
- will serve as image mask, and the 'autoshrink' (bool) field will
- trigger automated shrinking of the positions due to points outside
- the outline. Alternatively, a matplotlib patch object can be passed
- for advanced masking options, either directly or as a function that
- returns patches (required for multi-axis plots). If None, nothing
- will be drawn. Defaults to 'head'.
- contours : int | False | None
- The number of contour lines to draw. If 0, no contours will
- be drawn.
- image_interp : str
- The image interpolation to be used. All matplotlib options are
- accepted.
- head_pos : dict | None
- If None (default), the sensors are positioned such that they span
- the head circle. If dict, can have entries 'center' (tuple) and
- 'scale' (tuple) for what the center and scale of the head should be
- relative to the electrode locations.
-
- Returns
- -------
- fig : instance of matplotlib.pyplot.Figure
- The figure object.
- """
- return plot_ica_components(self, picks=picks,
- ch_type=ch_type,
- res=res, layout=layout, vmax=vmax,
- cmap=cmap,
- sensors=sensors, colorbar=colorbar,
- title=title, show=show,
+ contours=6, image_interp='bilinear', head_pos=None,
+ inst=None):
+ return plot_ica_components(self, picks=picks, ch_type=ch_type,
+ res=res, layout=layout, vmin=vmin,
+ vmax=vmax, cmap=cmap, sensors=sensors,
+ colorbar=colorbar, title=title, show=show,
outlines=outlines, contours=contours,
image_interp=image_interp,
- head_pos=head_pos)
-
+ head_pos=head_pos, inst=inst)
+
+ @copy_function_doc_to_method_doc(plot_ica_properties)
+ def plot_properties(self, inst, picks=None, axes=None, dB=True,
+ plot_std=True, topomap_args=None, image_args=None,
+ psd_args=None, figsize=None, show=True):
+ return plot_ica_properties(self, inst, picks=picks, axes=axes,
+ dB=dB, plot_std=plot_std,
+ topomap_args=topomap_args,
+ image_args=image_args, psd_args=psd_args,
+ figsize=figsize, show=show)
+
+ @copy_function_doc_to_method_doc(plot_ica_sources)
def plot_sources(self, inst, picks=None, exclude=None, start=None,
stop=None, title=None, show=True, block=False):
- """Plot estimated latent sources given the unmixing matrix.
-
- Typical usecases:
-
- 1. plot evolution of latent sources over time based on (Raw input)
- 2. plot latent source around event related time windows (Epochs input)
- 3. plot time-locking in ICA space (Evoked input)
-
-
- Parameters
- ----------
- inst : instance of mne.io.Raw, mne.Epochs, mne.Evoked
- The object to plot the sources from.
- picks : ndarray | None.
- The components to be displayed. If None, plot will show the
- sources in the order as fitted.
- exclude : array_like of int
- The components marked for exclusion. If None (default), ICA.exclude
- will be used.
- start : int
- X-axis start index. If None from the beginning.
- stop : int
- X-axis stop index. If None to the end.
- title : str | None
- The figure title. If None a default is provided.
- show : bool
- If True, all open plots will be shown.
- block : bool
- Whether to halt program execution until the figure is closed.
- Useful for interactive selection of components in raw and epoch
- plotter. For evoked, this parameter has no effect. Defaults to
- False.
-
- Returns
- -------
- fig : instance of pyplot.Figure
- The figure.
-
- Notes
- -----
- For raw and epoch instances, it is possible to select components for
- exclusion by clicking on the line. The selected components are added to
- ``ica.exclude`` on close. The independent components can be viewed as
- topographies by clicking on the component name on the left of of the
- main axes. The topography view tries to infer the correct electrode
- layout from the data. This should work at least for Neuromag data.
-
- .. versionadded:: 0.10.0
- """
-
return plot_ica_sources(self, inst=inst, picks=picks, exclude=exclude,
- title=title, start=start, stop=stop, show=show,
+ start=start, stop=stop, title=title, show=show,
block=block)
+ @copy_function_doc_to_method_doc(plot_ica_scores)
def plot_scores(self, scores, exclude=None, labels=None, axhline=None,
title='ICA component scores', figsize=(12, 6),
show=True):
- """Plot scores related to detected components.
-
- Use this function to assess how well your score describes outlier
- sources and how well you were detecting them.
-
- Parameters
- ----------
- scores : array_like of float, shape (n ica components,) | list of array
- Scores based on arbitrary metric to characterize ICA components.
- exclude : array_like of int
- The components marked for exclusion. If None (default), ICA.exclude
- will be used.
- labels : str | list | 'ecg' | 'eog' | None
- The labels to consider for the axes tests. Defaults to None.
- If list, should match the outer shape of `scores`.
- If 'ecg' or 'eog', the ``labels_`` attributes will be looked up.
- Note that '/' is used internally for sublabels specifying ECG and
- EOG channels.
- axhline : float
- Draw horizontal line to e.g. visualize rejection threshold.
- title : str
- The figure title.
- figsize : tuple of int
- The figure size. Defaults to (12, 6).
- show : bool
- If True, all open plots will be shown.
-
- Returns
- -------
- fig : instance of matplotlib.pyplot.Figure
- The figure object.
- """
return plot_ica_scores(
ica=self, scores=scores, exclude=exclude, labels=labels,
axhline=axhline, title=title, figsize=figsize, show=show)
+ @copy_function_doc_to_method_doc(plot_ica_overlay)
def plot_overlay(self, inst, exclude=None, picks=None, start=None,
stop=None, title=None, show=True):
- """Overlay of raw and cleaned signals given the unmixing matrix.
-
- This method helps visualizing signal quality and artifact rejection.
-
- Parameters
- ----------
- inst : instance of mne.io.Raw or mne.Evoked
- The signals to be compared given the ICA solution. If Raw input,
- The raw data are displayed before and after cleaning. In a second
- panel the cross channel average will be displayed. Since dipolar
- sources will be canceled out this display is sensitive to
- artifacts. If evoked input, butterfly plots for clean and raw
- signals will be superimposed.
- exclude : array_like of int
- The components marked for exclusion. If None (default), ICA.exclude
- will be used.
- picks : array-like of int | None (default)
- Indices of channels to include (if None, all channels
- are used that were included on fitting).
- start : int
- X-axis start index. If None from the beginning.
- stop : int
- X-axis stop index. If None to the end.
- title : str
- The figure title.
- show : bool
- If True, all open plots will be shown.
-
- Returns
- -------
- fig : instance of pyplot.Figure
- The figure.
- """
return plot_ica_overlay(self, inst=inst, exclude=exclude, picks=picks,
start=start, stop=stop, title=title, show=show)
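The three plotting methods above now take their docstrings from the
corresponding plot_ica_* functions via copy_function_doc_to_method_doc.
The decorator itself is defined elsewhere and not shown in this hunk; a
minimal sketch of the idea, assuming its only job is to copy the wrapped
function's docstring onto the method, could look like:

    def copy_function_doc_to_method_doc(source):
        """Sketch: reuse ``source``'s docstring for the decorated method."""
        def decorator(method):
            # keep a single point of truth for the documentation
            method.__doc__ = source.__doc__
            return method
        return decorator
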
@@ -1759,6 +1635,68 @@ def _find_sources(sources, target, score_func):
return scores
+def _ica_explained_variance(ica, inst, normalize=False):
+ """Checks variance accounted for by each component in supplied data.
+
+ Parameters
+ ----------
+ ica : ICA
+ Instance of `mne.preprocessing.ICA`.
+ inst : Raw | Epochs | Evoked
+ Data to explain with ICA. Instance of Raw, Epochs or Evoked.
+ normalize : bool
+ Whether to normalize the variance.
+
+ Returns
+ -------
+ var : array
+ Variance explained by each component.
+ """
+ # check that ica is an ICA instance and inst is Raw, Epochs or Evoked
+ if not isinstance(ica, ICA):
+ raise TypeError('first argument must be an instance of ICA.')
+ if not isinstance(inst, (_BaseRaw, _BaseEpochs, Evoked)):
+ raise TypeError('second argument must be an instance of either Raw, '
+ 'Epochs or Evoked.')
+
+ source_data = _get_inst_data(ica.get_sources(inst))
+
+ # if epochs - reshape to channels x timesamples
+ if isinstance(inst, _BaseEpochs):
+ n_epochs, n_chan, n_samp = source_data.shape
+ source_data = source_data.transpose(1, 0, 2).reshape(
+ (n_chan, n_epochs * n_samp))
+
+ n_chan, n_samp = source_data.shape
+ var = np.sum(ica.mixing_matrix_**2, axis=0) * np.sum(
+ source_data**2, axis=1) / (n_chan * n_samp - 1)
+ if normalize:
+ var /= var.sum()
+ return var
+
+
+def _sort_components(ica, order, copy=True):
+ """Change the order of components in ica solution."""
+ assert ica.n_components_ == len(order)
+ if copy:
+ ica = ica.copy()
+
+ # reorder components
+ ica.mixing_matrix_ = ica.mixing_matrix_[:, order]
+ ica.unmixing_matrix_ = ica.unmixing_matrix_[order, :]
+
+ # reorder labels, excludes etc.
+ if isinstance(order, np.ndarray):
+ order = list(order)
+ if ica.exclude:
+ ica.exclude = [order.index(ic) for ic in ica.exclude]
+ if hasattr(ica, 'labels_'):
+ for k in ica.labels_.keys():
+ ica.labels_[k] = [order.index(ic) for ic in ica.labels_[k]]
+
+ return ica
+
+
def _serialize(dict_, outer_sep=';', inner_sep=':'):
"""Aux function"""
s = []
@@ -1834,7 +1772,8 @@ def _write_ica(fid, ica):
# samples on fit
n_samples = getattr(ica, 'n_samples_', None)
ica_misc = {'n_samples_': (None if n_samples is None else int(n_samples)),
- 'labels_': getattr(ica, 'labels_', None)}
+ 'labels_': getattr(ica, 'labels_', None),
+ 'method': getattr(ica, 'method', None)}
write_string(fid, FIFF.FIFF_MNE_ICA_INTERFACE_PARAMS,
_serialize(ica_init))
@@ -1969,6 +1908,8 @@ def read_ica(fname):
ica.n_samples_ = ica_misc['n_samples_']
if 'labels_' in ica_misc:
ica.labels_ = ica_misc['labels_']
+ if 'method' in ica_misc:
+ ica.method = ica_misc['method']
logger.info('Ready.')
@@ -2361,28 +2302,19 @@ def corrmap(icas, template, threshold="auto", label=None, ch_type="eeg",
original Corrmap)
Defaults to "auto".
label : None | str
- If not None, categorised ICs are stored in a dictionary "labels_" under
- the given name. Preexisting entries will be appended to (excluding
- repeats), not overwritten. If None, a dry run is performed and
- the supplied ICs are not changed.
+ If not None, categorised ICs are stored in a dictionary ``labels_``
+ under the given name. Preexisting entries will be appended to
+ (excluding repeats), not overwritten. If None, a dry run is performed
+ and the supplied ICs are not changed.
ch_type : 'mag' | 'grad' | 'planar1' | 'planar2' | 'eeg'
- The channel type to plot. Defaults to 'eeg'.
+ The channel type to plot. Defaults to 'eeg'.
plot : bool
Should constructed template and selected maps be plotted? Defaults
to True.
show : bool
Show figures if True.
- layout : None | Layout | list of Layout
- Layout instance specifying sensor positions (does not need to be
- specified for Neuromag data). Or a list of Layout if projections
- are from different sensor types.
- cmap : None | matplotlib colormap
- Colormap for the plot. If ``None``, defaults to 'Reds_r' for norm data,
- otherwise to 'RdBu_r'.
- sensors : bool | str
- Add markers for sensor locations to the plot. Accepts matplotlib plot
- format string (e.g., 'r+' for red plusses). If True, a circle will be
- used (via .add_artist). Defaults to True.
+ verbose : bool, str, int, or None
+ If not None, override default verbose level (see mne.verbose).
outlines : 'head' | dict | None
The outlines to be drawn. If 'head', a head scheme will be drawn. If
dict, each key refers to a tuple of x and y positions. The values in
@@ -2392,10 +2324,19 @@ def corrmap(icas, template, threshold="auto", label=None, ch_type="eeg",
outline. Moreover, a matplotlib patch object can be passed for
advanced masking options, either directly or as a function that returns
patches (required for multi-axis plots).
+ layout : None | Layout | list of Layout
+ Layout instance specifying sensor positions (does not need to be
+ specified for Neuromag data). Or a list of Layout if projections
+ are from different sensor types.
+ sensors : bool | str
+ Add markers for sensor locations to the plot. Accepts matplotlib plot
+ format string (e.g., 'r+' for red plusses). If True, a circle will be
+ used (via .add_artist). Defaults to True.
contours : int | False | None
The number of contour lines to draw. If 0, no contours will be drawn.
- verbose : bool, str, int, or None
- If not None, override default verbose level (see mne.verbose).
+ cmap : None | matplotlib colormap
+ Colormap for the plot. If ``None``, defaults to 'Reds_r' for norm data,
+ otherwise to 'RdBu_r'.
Returns
-------
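Taken together, the two private helpers added in this file
(_ica_explained_variance and _sort_components) suggest how components can
be reordered by the variance they explain. A hedged sketch, assuming a
fitted ICA instance ``ica`` and a preloaded ``raw`` (both hypothetical
names here), and noting that these helpers are private and may change:

    import numpy as np

    # rank components by explained variance, largest first
    var = _ica_explained_variance(ica, raw, normalize=True)
    order = np.argsort(var)[::-1]
    ica_sorted = _sort_components(ica, order, copy=True)
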
diff --git a/mne/preprocessing/infomax_.py b/mne/preprocessing/infomax_.py
index 7deb657..8b5b328 100644
--- a/mne/preprocessing/infomax_.py
+++ b/mne/preprocessing/infomax_.py
@@ -13,94 +13,98 @@ from ..utils import logger, verbose, check_random_state, random_permutation
@verbose
def infomax(data, weights=None, l_rate=None, block=None, w_change=1e-12,
- anneal_deg=60., anneal_step=0.9, extended=False, n_subgauss=1,
- kurt_size=6000, ext_blocks=1, max_iter=200,
- random_state=None, blowup=1e4, blowup_fac=0.5, n_small_angle=20,
- use_bias=True, verbose=None):
- """Run the (extended) Infomax ICA decomposition on raw data
-
- based on the publications of Bell & Sejnowski 1995 (Infomax)
- and Lee, Girolami & Sejnowski, 1999 (extended Infomax)
+ anneal_deg=60., anneal_step=0.9, extended=True, n_subgauss=1,
+ kurt_size=6000, ext_blocks=1, max_iter=200, random_state=None,
+ blowup=1e4, blowup_fac=0.5, n_small_angle=20, use_bias=True,
+ verbose=None):
+ """Run (extended) Infomax ICA decomposition on raw data.
Parameters
----------
data : np.ndarray, shape (n_samples, n_features)
- The data to unmix.
- w_init : np.ndarray, shape (n_features, n_features)
- The initialized unmixing matrix. Defaults to None. If None, the
- identity matrix is used.
+ The whitened data to unmix.
+ weights : np.ndarray, shape (n_features, n_features)
+ The initialized unmixing matrix.
+ Defaults to None, which means the identity matrix is used.
l_rate : float
This quantity indicates the relative size of the change in weights.
- Note. Smaller learining rates will slow down the procedure.
- Defaults to 0.010d / alog(n_features ^ 2.0)
+ .. note:: Smaller learning rates will slow down the ICA procedure.
+ Defaults to 0.01 / log(n_features ** 2).
block : int
- The block size of randomly chosen data segment.
- Defaults to floor(sqrt(n_times / 3d))
+ The block size of randomly chosen data segments.
+ Defaults to floor(sqrt(n_times / 3.)).
w_change : float
The change at which to stop iteration. Defaults to 1e-12.
anneal_deg : float
- The angle at which (in degree) the learning rate will be reduced.
- Defaults to 60.0
+ The angle (in degrees) at which the learning rate will be reduced.
+ Defaults to 60.0.
anneal_step : float
The factor by which the learning rate will be reduced once
``anneal_deg`` is exceeded:
l_rate *= anneal_step
- Defaults to 0.9
+ Defaults to 0.9.
extended : bool
- Wheather to use the extended infomax algorithm or not. Defaults to
- True.
+ Whether to use the extended Infomax algorithm or not.
+ Defaults to True.
n_subgauss : int
The number of subgaussian components. Only considered for extended
- Infomax.
+ Infomax. Defaults to 1.
kurt_size : int
The window size for kurtosis estimation. Only considered for extended
- Infomax.
+ Infomax. Defaults to 6000.
ext_blocks : int
- Only considered for extended Infomax.
- If positive, it denotes the number of blocks after which to recompute
- the Kurtosis, which is used to estimate the signs of the sources.
- In this case the number of sub-gaussian sources is automatically
- determined.
+ Only considered for extended Infomax. If positive, denotes the number
+ of blocks after which to recompute the kurtosis, which is used to
+ estimate the signs of the sources. In this case, the number of
+ sub-gaussian sources is automatically determined.
If negative, the number of sub-gaussian sources to be used is fixed
- and equal to n_subgauss. In this case the Kurtosis is not estimated.
+ and equal to n_subgauss. In this case, the kurtosis is not estimated.
+ Defaults to 1.
max_iter : int
The maximum number of iterations. Defaults to 200.
random_state : int | np.random.RandomState
- If random_state is an int, use random_state as seed of the random
- number generator.
- If random_state is already a np.random.RandomState instance, use
- random_state as random number generator.
+ If random_state is an int, use random_state to seed the random number
+ generator. If random_state is already a np.random.RandomState instance,
+ use random_state as random number generator.
blowup : float
The maximum difference allowed between two successive estimations of
- the unmixing matrix. Defaults to 1e4
+ the unmixing matrix. Defaults to 10000.
blowup_fac : float
- The factor by which the learning rate will be reduced if the
- difference between two successive estimations of the
- unmixing matrix exceededs ``blowup``:
+ The factor by which the learning rate will be reduced if the difference
+ between two successive estimations of the unmixing matrix exceeds
+ ``blowup``:
l_rate *= blowup_fac
- Defaults to 0.5
+ Defaults to 0.5.
n_small_angle : int | None
The maximum number of allowed steps in which the angle between two
successive estimations of the unmixing matrix is less than
- ``anneal_deg``.
- If None, this parameter is not taken into account to stop the
- iterations.
- Defaults to 20
+ ``anneal_deg``. If None, this parameter is not taken into account to
+ stop the iterations.
+ Defaults to 20.
use_bias : bool
This quantity indicates if the bias should be computed.
- Defaults to True
+ Defaults to True.
verbose : bool, str, int, or None
- If not None, override default verbose level (see mne.verbose).
+ If not None, override default verbosity level (see mne.verbose).
Returns
-------
- unmixing_matrix : np.ndarray of float, shape (n_features, n_features)
+ unmixing_matrix : np.ndarray, shape (n_features, n_features)
The linear unmixing operator.
+
+ References
+ ----------
+ [1] A. J. Bell, T. J. Sejnowski. An information-maximization approach to
+ blind separation and blind deconvolution. Neural Computation, 7(6),
+ 1129-1159, 1995.
+ [2] T. W. Lee, M. Girolami, T. J. Sejnowski. Independent component analysis
+ using an extended infomax algorithm for mixed subgaussian and
+ supergaussian sources. Neural Computation, 11(2), 417-441, 1999.
"""
from scipy.stats import kurtosis
rng = check_random_state(random_state)
- # define some default parameter
+ # define some default parameters
max_weight = 1e8
restart_fac = 0.9
min_l_rate = 1e-10
@@ -116,26 +120,25 @@ def infomax(data, weights=None, l_rate=None, block=None, w_change=1e-12,
n_samples, n_features = data.shape
n_features_square = n_features ** 2
- # check input parameter
- # heuristic default - may need adjustment for
- # large or tiny data sets
+ # check input parameters
+ # heuristic default - may need adjustment for large or tiny data sets
if l_rate is None:
l_rate = 0.01 / math.log(n_features ** 2.0)
if block is None:
block = int(math.floor(math.sqrt(n_samples / 3.0)))
- logger.info('computing%sInfomax ICA' % ' Extended ' if extended is True
- else ' ')
+ logger.info('computing%sInfomax ICA' % (' Extended ' if extended else ' '))
- # collect parameter
+ # collect parameters
nblock = n_samples // block
lastt = (nblock - 1) * block + 1
# initialize training
if weights is None:
- # initialize weights as identity matrix
weights = np.identity(n_features, dtype=np.float64)
+ else:
+ weights = weights.T
BI = block * np.identity(n_features, dtype=np.float64)
bias = np.zeros((n_features, 1), dtype=np.float64)
@@ -150,7 +153,7 @@ def infomax(data, weights=None, l_rate=None, block=None, w_change=1e-12,
initial_ext_blocks = ext_blocks # save the initial value in case of reset
# for extended Infomax
- if extended is True:
+ if extended:
signs = np.ones(n_features)
for k in range(n_subgauss):
@@ -173,7 +176,7 @@ def infomax(data, weights=None, l_rate=None, block=None, w_change=1e-12,
u = np.dot(data[permute[t:t + block], :], weights)
u += np.dot(bias, onesrow).T
- if extended is True:
+ if extended:
# extended ICA update
y = np.tanh(u)
weights += l_rate * np.dot(weights,
@@ -206,10 +209,8 @@ def infomax(data, weights=None, l_rate=None, block=None, w_change=1e-12,
break
# ICA kurtosis estimation
- if extended is True:
-
+ if extended:
if ext_blocks > 0 and blockno % ext_blocks == 0:
-
if kurt_size < n_samples:
rp = np.floor(rng.uniform(0, 1, kurt_size) *
(n_samples - 1))
@@ -239,8 +240,7 @@ def infomax(data, weights=None, l_rate=None, block=None, w_change=1e-12,
ext_blocks = np.fix(ext_blocks * signcount_step)
signcount = 0
- # here we continue after the for
- # loop over the ICA training blocks
+ # here we continue after the for loop over the ICA training blocks
# if weights in bounds:
if not wts_blowup:
oldwtchange = weights - oldweights
@@ -262,10 +262,10 @@ def infomax(data, weights=None, l_rate=None, block=None, w_change=1e-12,
oldweights = weights.copy()
if angledelta > anneal_deg:
l_rate *= anneal_step # anneal learning rate
- # accumulate angledelta until anneal_deg reached l_rates
+ # accumulate angledelta until anneal_deg reaches l_rate
olddelta = delta
oldchange = change
- count_small_angle = 0 # reset count when angle delta is large
+ count_small_angle = 0 # reset count when angledelta is large
else:
if step == 1: # on first step only
olddelta = delta # initialize
@@ -282,8 +282,7 @@ def infomax(data, weights=None, l_rate=None, block=None, w_change=1e-12,
elif change > blowup:
l_rate *= blowup_fac
- # restart if weights blow up
- # (for lowering l_rate)
+ # restart if weights blow up (for lowering l_rate)
else:
step = 0 # start again
wts_blowup = 0 # re-initialize variables
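For orientation, a minimal sketch of calling the reworked infomax()
directly. It assumes the input has already been whitened (the function
does not whiten internally, as the revised docstring now states), and the
random toy data below are purely illustrative:

    import numpy as np
    from mne.preprocessing.infomax_ import infomax

    rng = np.random.RandomState(42)
    data = rng.randn(1000, 4)  # shape (n_samples, n_features), assumed white

    unmixing = infomax(data, extended=True, random_state=42, max_iter=200)
    sources = np.dot(data, unmixing.T)  # recovered source time courses
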
diff --git a/mne/preprocessing/maxfilter.py b/mne/preprocessing/maxfilter.py
index c4d955b..6b0fb63 100644
--- a/mne/preprocessing/maxfilter.py
+++ b/mne/preprocessing/maxfilter.py
@@ -9,7 +9,7 @@ import os
from ..bem import fit_sphere_to_headshape
-from ..io import Raw
+from ..io import read_raw_fif
from ..utils import logger, verbose, warn
from ..externals.six.moves import map
@@ -136,7 +136,7 @@ def apply_maxfilter(in_fname, out_fname, origin=None, frame='device',
# determine the head origin if necessary
if origin is None:
logger.info('Estimating head origin from headshape points..')
- raw = Raw(in_fname)
+ raw = read_raw_fif(in_fname, add_eeg_ref=False)
r, o_head, o_dev = fit_sphere_to_headshape(raw.info, units='mm')
raw.close()
logger.info('[done]')
diff --git a/mne/preprocessing/maxwell.py b/mne/preprocessing/maxwell.py
index c2c72b4..6071485 100644
--- a/mne/preprocessing/maxwell.py
+++ b/mne/preprocessing/maxwell.py
@@ -6,6 +6,7 @@
# License: BSD (3-clause)
+from functools import partial
from math import factorial
from os import path as op
@@ -17,7 +18,7 @@ from ..bem import _check_origin
from ..chpi import quat_to_rot, rot_to_quat
from ..transforms import (_str_to_frame, _get_trans, Transform, apply_trans,
_find_vector_rotation)
-from ..forward import _concatenate_coils, _prep_meg_channels
+from ..forward import _concatenate_coils, _prep_meg_channels, _create_meg_coils
from ..surface import _normalize_vectors
from ..io.constants import FIFF
from ..io.proc_history import _read_ctc
@@ -25,7 +26,7 @@ from ..io.write import _generate_meas_id, _date_now
from ..io import _loc_to_coil_trans, _BaseRaw
from ..io.pick import pick_types, pick_info, pick_channels
from ..utils import verbose, logger, _clean_names, warn, _time_mask
-from ..fixes import _get_args, partial
+from ..fixes import _get_args, _safe_svd
from ..externals.six import string_types
from ..channels.channels import _get_T1T2_mag_inds
@@ -40,7 +41,8 @@ def maxwell_filter(raw, origin='auto', int_order=8, ext_order=3,
calibration=None, cross_talk=None, st_duration=None,
st_correlation=0.98, coord_frame='head', destination=None,
regularize='in', ignore_ref=False, bad_condition='error',
- head_pos=None, st_fixed=True, st_only=False, verbose=None):
+ head_pos=None, st_fixed=True, st_only=False, mag_scale=100.,
+ verbose=None):
"""Apply Maxwell filter to data using multipole moments
.. warning:: Automatic bad channel detection is not currently implemented.
@@ -136,6 +138,16 @@ def maxwell_filter(raw, origin='auto', int_order=8, ext_order=3,
.. versionadded:: 0.12
+ mag_scale : float | str
+ The magnetometer scale-factor used to bring the magnetometers
+ to approximately the same order of magnitude as the gradiometers
+ (default 100.), as they have different units (T vs T/m).
+ Can be ``'auto'`` to use the reciprocal of the physical distance
+ between the gradiometer pickup loops (e.g., 0.0168 m yields
+ 59.5 for VectorView).
+
+ .. versionadded:: 0.13
+
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose)
@@ -247,11 +259,12 @@ def maxwell_filter(raw, origin='auto', int_order=8, ext_order=3,
recon_trans = _check_destination(destination, raw.info, head_frame)
if st_duration is not None:
st_duration = float(st_duration)
- if not 0. < st_duration <= raw.times[-1]:
+ if not 0. < st_duration <= raw.times[-1] + 1. / raw.info['sfreq']:
raise ValueError('st_duration (%0.1fs) must be between 0 and the '
'duration of the data (%0.1fs).'
% (st_duration, raw.times[-1]))
st_correlation = float(st_correlation)
+ st_duration = int(round(st_duration * raw.info['sfreq']))
if not 0. < st_correlation <= 1:
raise ValueError('st_correlation must be between 0. and 1.')
if not isinstance(bad_condition, string_types) or \
@@ -279,9 +292,13 @@ def maxwell_filter(raw, origin='auto', int_order=8, ext_order=3,
raw, add_channels=add_channels)
del raw
_remove_meg_projs(raw_sss) # remove MEG projectors, they won't apply now
- info, times = raw_sss.info, raw_sss.times
- meg_picks, mag_picks, grad_picks, good_picks, coil_scale, mag_or_fine = \
- _get_mf_picks(info, int_order, ext_order, ignore_ref, mag_scale=100.)
+ info = raw_sss.info
+ meg_picks, mag_picks, grad_picks, good_picks, mag_or_fine = \
+ _get_mf_picks(info, int_order, ext_order, ignore_ref)
+
+ # Magnetometers are scaled to improve numerical stability
+ coil_scale, mag_scale = _get_coil_scale(
+ meg_picks, mag_picks, grad_picks, mag_scale, info)
#
# Fine calibration processing (load fine cal and overwrite sensor geometry)
@@ -289,11 +306,11 @@ def maxwell_filter(raw, origin='auto', int_order=8, ext_order=3,
sss_cal = dict()
if calibration is not None:
calibration, sss_cal = _update_sensor_geometry(info, calibration,
- head_frame, ignore_ref)
+ ignore_ref)
mag_or_fine.fill(True) # all channels now have some mag-type data
# Determine/check the origin of the expansion
- origin = _check_origin(origin, raw_sss.info, coord_frame, disp=True)
+ origin = _check_origin(origin, info, coord_frame, disp=True)
origin.setflags(write=False)
n_in, n_out = _get_n_moments([int_order, ext_order])
@@ -312,7 +329,9 @@ def maxwell_filter(raw, origin='auto', int_order=8, ext_order=3,
if len(missing) > 0:
warn('Not all cross-talk channels in raw:\n%s' % missing)
ctc_picks = pick_channels(ctc_chs,
- [info['ch_names'][c] for c in good_picks])
+ [info['ch_names'][c]
+ for c in meg_picks[good_picks]])
+ assert len(ctc_picks) == len(good_picks) # otherwise we errored
ctc = sss_ctc['decoupler'][ctc_picks][:, ctc_picks]
# I have no idea why, but MF transposes this for storage..
sss_ctc['decoupler'] = sss_ctc['decoupler'].T.tocsc()
@@ -322,8 +341,7 @@ def maxwell_filter(raw, origin='auto', int_order=8, ext_order=3,
#
# Translate to destination frame (always use non-fine-cal bases)
#
- exp = dict(origin=origin, int_order=int_order, ext_order=0,
- head_frame=head_frame)
+ exp = dict(origin=origin, int_order=int_order, ext_order=0)
all_coils = _prep_mf_coils(info, ignore_ref)
S_recon = _trans_sss_basis(exp, all_coils, recon_trans, coil_scale)
exp['ext_order'] = ext_order
@@ -341,29 +359,33 @@ def maxwell_filter(raw, origin='auto', int_order=8, ext_order=3,
# Reconstruct raw file object with spatiotemporal processed data
max_st = dict()
if st_duration is not None:
- max_st.update(job=10, subspcorr=st_correlation, buflen=st_duration)
+ max_st.update(job=10, subspcorr=st_correlation,
+ buflen=st_duration / info['sfreq'])
logger.info(' Processing data using tSSS with st_duration=%s'
- % st_duration)
+ % max_st['buflen'])
st_when = 'before' if st_fixed else 'after' # relative to movecomp
else:
- st_duration = min(raw_sss.times[-1], 10.) # chunk size
+ # st_duration from here on will act like the chunk size
+ st_duration = max(int(round(10. * info['sfreq'])), 1)
st_correlation = None
st_when = 'never'
+ st_duration = min(len(raw_sss.times), st_duration)
del st_fixed
- # Generate time points to break up data into windows
- chunk_times = np.arange(times[0], times[-1], st_duration)
- read_lims = raw_sss.time_as_index(chunk_times)
- len_last_buf = raw_sss.times[-1] - raw_sss.times[read_lims[-1]]
- if len_last_buf == st_duration:
+ # Generate time points to break up data into equal-length windows
+ read_lims = np.arange(0, len(raw_sss.times) + 1, st_duration)
+ if len(read_lims) == 1:
read_lims = np.concatenate([read_lims, [len(raw_sss.times)]])
- else:
- # len_last_buf < st_dur so fold it into the previous buffer
+ if read_lims[-1] != len(raw_sss.times):
read_lims[-1] = len(raw_sss.times)
- if st_correlation is not None:
+ # len_last_buf < st_dur so fold it into the previous buffer
+ if st_correlation is not None and len(read_lims) > 2:
logger.info(' Spatiotemporal window did not fit evenly into '
'raw object. The final %0.2f seconds were lumped '
- 'onto the previous window.' % len_last_buf)
+ 'onto the previous window.'
+ % ((read_lims[-1] - read_lims[-2]) / info['sfreq'],))
+ assert len(read_lims) >= 2
+ assert read_lims[0] == 0 and read_lims[-1] == len(raw_sss.times)
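To see what the window computation above produces, a small worked example
with made-up numbers (25 s of data at 1000 Hz, 10 s chunks, all counted in
samples):

    import numpy as np

    n_times, st_duration = 25000, 10000
    read_lims = np.arange(0, n_times + 1, st_duration)  # [0, 10000, 20000]
    if read_lims[-1] != n_times:
        read_lims[-1] = n_times  # fold the short tail into the last window
    # read_lims -> [0, 10000, 25000]: one 10 s window, one 15 s window
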
#
# Do the heavy lifting
@@ -373,7 +395,7 @@ def maxwell_filter(raw, origin='auto', int_order=8, ext_order=3,
# (and transform pos[1] to times)
head_pos[1] = raw_sss.time_as_index(head_pos[1], use_rounding=True)
# Compute the first bit of pos_data for cHPI reporting
- if raw_sss.info['dev_head_t'] is not None and head_pos[0] is not None:
+ if info['dev_head_t'] is not None and head_pos[0] is not None:
this_pos_quat = np.concatenate([
rot_to_quat(info['dev_head_t']['trans'][:3, :3]),
info['dev_head_t']['trans'][:3, 3],
@@ -385,19 +407,24 @@ def maxwell_filter(raw, origin='auto', int_order=8, ext_order=3,
cal=calibration, regularize=regularize,
exp=exp, ignore_ref=ignore_ref, coil_scale=coil_scale,
grad_picks=grad_picks, mag_picks=mag_picks, good_picks=good_picks,
- mag_or_fine=mag_or_fine, bad_condition=bad_condition)
+ mag_or_fine=mag_or_fine, bad_condition=bad_condition,
+ mag_scale=mag_scale)
S_decomp, pS_decomp, reg_moments, n_use_in = _get_this_decomp_trans(
- info['dev_head_t'])
+ info['dev_head_t'], t=0.)
reg_moments_0 = reg_moments.copy()
# Loop through buffer windows of data
+ n_sig = int(np.floor(np.log10(max(len(read_lims), 0)))) + 1
pl = 's' if len(read_lims) != 2 else ''
logger.info(' Processing %s data chunk%s of (at least) %0.1f sec'
- % (len(read_lims) - 1, pl, st_duration))
- for start, stop in zip(read_lims[:-1], read_lims[1:]):
- t_str = '% 8.2f - % 8.2f sec' % tuple(raw_sss.times[[start, stop - 1]])
+ % (len(read_lims) - 1, pl, st_duration / info['sfreq']))
+ for ii, (start, stop) in enumerate(zip(read_lims[:-1], read_lims[1:])):
+ rel_times = raw_sss.times[start:stop]
+ t_str = '%8.3f - %8.3f sec' % tuple(rel_times[[0, -1]])
+ t_str += ('(#%d/%d)'
+ % (ii + 1, len(read_lims) - 1)).rjust(2 * n_sig + 5)
# Get original data
- orig_data = raw_sss._data[good_picks, start:stop]
+ orig_data = raw_sss._data[meg_picks[good_picks], start:stop]
# This could just be np.empty if not st_only, but shouldn't be slow
# this way so might as well just always take the original data
out_meg_data = raw_sss._data[meg_picks, start:stop]
@@ -422,7 +449,7 @@ def maxwell_filter(raw, origin='auto', int_order=8, ext_order=3,
if avg_trans is not None:
# if doing movecomp
S_decomp_st, pS_decomp_st, _, n_use_in_st = \
- _get_this_decomp_trans(avg_trans, verbose=False)
+ _get_this_decomp_trans(avg_trans, t=rel_times[0])
else:
S_decomp_st, pS_decomp_st = S_decomp, pS_decomp
n_use_in_st = n_use_in
@@ -446,7 +473,7 @@ def maxwell_filter(raw, origin='auto', int_order=8, ext_order=3,
# previous interval)
if trans is not None:
S_decomp, pS_decomp, reg_moments, n_use_in = \
- _get_this_decomp_trans(trans, verbose=False)
+ _get_this_decomp_trans(trans, t=rel_times[rel_start])
# Determine multipole moments for this interval
mm_in = np.dot(pS_decomp[:n_use_in],
@@ -477,7 +504,7 @@ def maxwell_filter(raw, origin='auto', int_order=8, ext_order=3,
if st_when == 'after':
_do_tSSS(out_meg_data, orig_in_data, resid, st_correlation,
n_positions, t_str)
- else:
+ elif st_when == 'never' and head_pos[0] is not None:
pl = 's' if n_positions > 1 else ''
logger.info(' Used % 2d head position%s for %s'
% (n_positions, pl, t_str))
@@ -493,6 +520,37 @@ def maxwell_filter(raw, origin='auto', int_order=8, ext_order=3,
return raw_sss
+def _get_coil_scale(meg_picks, mag_picks, grad_picks, mag_scale, info):
+ """Helper to get the magnetometer scale factor"""
+ if isinstance(mag_scale, string_types):
+ if mag_scale != 'auto':
+ raise ValueError('mag_scale must be a float or "auto", got "%s"'
+ % mag_scale)
+ if len(mag_picks) in (0, len(meg_picks)):
+ mag_scale = 100. # only one coil type, doesn't matter
+ logger.info(' Setting mag_scale=%0.2f because only one '
+ 'coil type is present' % mag_scale)
+ else:
+ # Find our physical distance between gradiometer pickup loops
+ # ("base line")
+ coils = _create_meg_coils(pick_info(info, meg_picks)['chs'],
+ 'accurate')
+ grad_base = set(coils[pick]['base'] for pick in grad_picks)
+ if len(grad_base) != 1 or list(grad_base)[0] <= 0:
+ raise RuntimeError('Could not automatically determine '
+ 'mag_scale, could not find one '
+ 'proper gradiometer distance from: %s'
+ % list(grad_base))
+ grad_base = list(grad_base)[0]
+ mag_scale = 1. / grad_base
+ logger.info(' Setting mag_scale=%0.2f based on gradiometer '
+ 'distance %0.2f mm' % (mag_scale, 1000 * grad_base))
+ mag_scale = float(mag_scale)
+ coil_scale = np.ones((len(meg_picks), 1))
+ coil_scale[mag_picks] = mag_scale
+ return coil_scale, mag_scale
+
+
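A quick numerical check of the 'auto' branch above, using the VectorView
baseline quoted in the maxwell_filter docstring (0.0168 m); the pick
indices are made up for illustration:

    import numpy as np

    grad_base = 0.0168            # gradiometer baseline in meters
    mag_scale = 1. / grad_base    # ~59.5, matching the docstring example

    meg_picks = np.arange(6)      # hypothetical: six MEG channels
    mag_picks = np.array([2, 5])  # hypothetical magnetometer indices
    coil_scale = np.ones((len(meg_picks), 1))
    coil_scale[mag_picks] = mag_scale
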
def _remove_meg_projs(inst):
"""Helper to remove inplace existing MEG projectors (assumes inactive)"""
meg_picks = pick_types(inst.info, meg=True, exclude=[])
@@ -620,10 +678,10 @@ def _do_tSSS(clean_data, orig_in_data, resid, st_correlation,
np.asarray_chkfinite(resid)
t_proj = _overlap_projector(orig_in_data, resid, st_correlation)
# Apply projector according to Eq. 12 in [2]_
- msg = (' Projecting % 2d intersecting tSSS components '
+ msg = (' Projecting %2d intersecting tSSS components '
'for %s' % (t_proj.shape[1], t_str))
if n_positions > 1:
- msg += ' (across % 2d positions)' % n_positions
+ msg += ' (across %2d positions)' % n_positions
logger.info(msg)
clean_data -= np.dot(np.dot(clean_data, t_proj), t_proj.T)
@@ -693,6 +751,11 @@ def _check_pos(pos, head_frame, raw, st_fixed, sfreq):
raise ValueError('Head position time points must be greater than '
'first sample offset, but found %0.4f < %0.4f'
% (t[0], t_off))
+ max_dist = np.sqrt(np.sum(pos[:, 4:7] ** 2, axis=1)).max()
+ if max_dist > 1.:
+ warn('Found a distance greater than 1 m (%0.3g m) from the device '
+ 'origin, positions may be invalid and Maxwell filtering could '
+ 'fail' % (max_dist,))
dev_head_ts = np.zeros((len(t), 4, 4))
dev_head_ts[:, 3, 3] = 1.
dev_head_ts[:, :3, 3] = pos[:, 4:7]
@@ -701,33 +764,24 @@ def _check_pos(pos, head_frame, raw, st_fixed, sfreq):
return pos
-@verbose
def _get_decomp(trans, all_coils, cal, regularize, exp, ignore_ref,
coil_scale, grad_picks, mag_picks, good_picks, mag_or_fine,
- bad_condition, verbose=None):
- """Helper to get a decomposition matrix"""
+ bad_condition, t, mag_scale):
+ """Helper to get a decomposition matrix and pseudoinverse matrices"""
#
# Fine calibration processing (point-like magnetometers and calib. coeffs)
#
- S_decomp = _trans_sss_basis(exp, all_coils, trans, coil_scale)
- if cal is not None:
- # Compute point-like mags to incorporate gradiometer imbalance
- cal['grad_cals'] = _sss_basis_point(exp, trans, cal, ignore_ref)
- # Add point like magnetometer data to bases.
- S_decomp[grad_picks, :] += cal['grad_cals']
- # Scale magnetometers by calibration coefficient
- S_decomp[mag_picks, :] /= cal['mag_cals']
- # We need to be careful about KIT gradiometers
- S_decomp = S_decomp[good_picks]
+ S_decomp = _get_s_decomp(exp, all_coils, trans, coil_scale, cal,
+ ignore_ref, grad_picks, mag_picks, good_picks,
+ mag_scale)
#
# Regularization
#
- reg_moments, n_use_in = _regularize(regularize, exp, S_decomp, mag_or_fine)
- S_decomp = S_decomp.take(reg_moments, axis=1)
+ S_decomp, pS_decomp, sing, reg_moments, n_use_in = _regularize(
+ regularize, exp, S_decomp, mag_or_fine, t=t)
# Pseudo-inverse of total multipolar moment basis set (Part of Eq. 37)
- pS_decomp, sing = _col_norm_pinv(S_decomp.copy())
cond = sing[0] / sing[-1]
logger.debug(' Decomposition matrix condition: %0.1f' % cond)
if bad_condition != 'ignore' and cond >= 1000.:
@@ -743,13 +797,31 @@ def _get_decomp(trans, all_coils, cal, regularize, exp, ignore_ref,
return S_decomp, pS_decomp, reg_moments, n_use_in
-def _regularize(regularize, exp, S_decomp, mag_or_fine):
+def _get_s_decomp(exp, all_coils, trans, coil_scale, cal, ignore_ref,
+ grad_picks, mag_picks, good_picks, mag_scale):
+ """Helper to get S_decomp"""
+ S_decomp = _trans_sss_basis(exp, all_coils, trans, coil_scale)
+ if cal is not None:
+ # Compute point-like mags to incorporate gradiometer imbalance
+ grad_cals = _sss_basis_point(exp, trans, cal, ignore_ref, mag_scale)
+ # Add point like magnetometer data to bases.
+ S_decomp[grad_picks, :] += grad_cals
+ # Scale magnetometers by calibration coefficient
+ S_decomp[mag_picks, :] /= cal['mag_cals']
+ # We need to be careful about KIT gradiometers
+ S_decomp = S_decomp[good_picks]
+ return S_decomp
+
+
+@verbose
+def _regularize(regularize, exp, S_decomp, mag_or_fine, t, verbose=None):
"""Regularize a decomposition matrix"""
# ALWAYS regularize the out components according to norm, since
# gradiometer-only setups (e.g., KIT) can have zero first-order
# components
int_order, ext_order = exp['int_order'], exp['ext_order']
n_in, n_out = _get_n_moments([int_order, ext_order])
+ t_str = '%8.3f' % t
if regularize is not None: # regularize='in'
logger.info(' Computing regularization')
in_removes, out_removes = _regularize_in(
@@ -762,15 +834,18 @@ def _regularize(regularize, exp, S_decomp, mag_or_fine):
out_removes)
n_use_in = len(reg_in_moments)
n_use_out = len(reg_out_moments)
- if regularize is not None or n_use_out != n_out:
- logger.info(' Using %s/%s inside and %s/%s outside harmonic '
- 'components' % (n_use_in, n_in, n_use_out, n_out))
reg_moments = np.concatenate((reg_in_moments, reg_out_moments))
- return reg_moments, n_use_in
+ S_decomp = S_decomp.take(reg_moments, axis=1)
+ pS_decomp, sing = _col_norm_pinv(S_decomp.copy())
+ if regularize is not None or n_use_out != n_out:
+ logger.info(' Using %s/%s harmonic components for %s '
+ '(%s/%s in, %s/%s out)'
+ % (n_use_in + n_use_out, n_in + n_out, t_str,
+ n_use_in, n_in, n_use_out, n_out))
+ return S_decomp, pS_decomp, sing, reg_moments, n_use_in
-def _get_mf_picks(info, int_order, ext_order, ignore_ref=False,
- mag_scale=100.):
+def _get_mf_picks(info, int_order, ext_order, ignore_ref=False):
"""Helper to pick types for Maxwell filtering"""
# Check for T1/T2 mag types
mag_inds_T1T2 = _get_T1T2_mag_inds(info)
@@ -798,16 +873,13 @@ def _get_mf_picks(info, int_order, ext_order, ignore_ref=False,
ref_meg = False if ignore_ref else 'grad'
grad_picks = pick_types(meg_info, meg='grad', ref_meg=ref_meg, exclude=[])
assert len(mag_picks) + len(grad_picks) == len(meg_info['ch_names'])
- # Magnetometers are scaled by 100 to improve numerical stability
- coil_scale = np.ones((len(meg_picks), 1))
- coil_scale[mag_picks] = 100.
# Determine which are magnetometers for external basis purposes
mag_or_fine = np.zeros(len(meg_picks), bool)
mag_or_fine[mag_picks] = True
# KIT gradiometers are marked as having units T, not T/M (argh)
# We need a separate variable for this because KIT grads should be
# treated mostly like magnetometers (e.g., scaled by 100) for reg
- mag_or_fine[np.array([ch['coil_type'] == FIFF.FIFFV_COIL_KIT_GRAD
+ mag_or_fine[np.array([ch['coil_type'] & 0xFFFF == FIFF.FIFFV_COIL_KIT_GRAD
for ch in meg_info['chs']], bool)] = False
msg = (' Processing %s gradiometers and %s magnetometers'
% (len(grad_picks), len(mag_picks)))
@@ -815,8 +887,7 @@ def _get_mf_picks(info, int_order, ext_order, ignore_ref=False,
if n_kit > 0:
msg += ' (of which %s are actually KIT gradiometers)' % n_kit
logger.info(msg)
- return (meg_picks, mag_picks, grad_picks, good_picks, coil_scale,
- mag_or_fine)
+ return meg_picks, mag_picks, grad_picks, good_picks, mag_or_fine
def _check_regularize(regularize):
@@ -830,14 +901,11 @@ def _check_usable(inst):
"""Helper to ensure our data are clean"""
if inst.proj:
raise RuntimeError('Projectors cannot be applied to data.')
- if hasattr(inst, 'comp'):
- if inst.comp is not None:
- raise RuntimeError('Maxwell filter cannot be done on compensated '
- 'channels.')
- else:
- if len(inst.info['comps']) > 0: # more conservative check
- raise RuntimeError('Maxwell filter cannot be done on data that '
- 'might have been compensated.')
+ current_comp = inst.compensation_grade
+ if current_comp not in (0, None):
+ raise RuntimeError('Maxwell filter cannot be done on compensated '
+ 'channels, but data have been compensated with '
+ 'grade %s.' % current_comp)
def _col_norm_pinv(x):
@@ -976,7 +1044,7 @@ def _sss_basis_basic(exp, coils, mag_scale=100., method='standard'):
S_in = S_tot[:, :n_in]
S_out = S_tot[:, n_in:]
coil_scale = np.ones((len(coils), 1))
- coil_scale[_get_mag_mask(coils)] = 100.
+ coil_scale[_get_mag_mask(coils)] = mag_scale
# Compute internal/external basis vectors (exclude degree 0; L/RHS Eq. 5)
for degree in range(1, max(int_order, ext_order) + 1):
@@ -1540,8 +1608,7 @@ if 'check_finite' in _get_args(linalg.svd):
def _orth_overwrite(A):
"""Helper to create a slightly more efficient 'orth'"""
# adapted from scipy/linalg/decomp_svd.py
- u, s = linalg.svd(A, overwrite_a=True, full_matrices=False,
- **check_disable)[:2]
+ u, s = _safe_svd(A, full_matrices=False, **check_disable)[:2]
M, N = A.shape
eps = np.finfo(float).eps
tol = max(M, N) * np.amax(s) * eps
@@ -1618,45 +1685,63 @@ def _read_fine_cal(fine_cal):
return cal_chs, cal_ch_numbers
-def _update_sensor_geometry(info, fine_cal, head_frame, ignore_ref):
+def _update_sensor_geometry(info, fine_cal, ignore_ref):
"""Helper to replace sensor geometry information and reorder cal_chs"""
+ from ._fine_cal import read_fine_calibration
logger.info(' Using fine calibration %s' % op.basename(fine_cal))
- cal_chs, cal_ch_numbers = _read_fine_cal(fine_cal)
-
- # Check that we ended up with correct channels
- meg_info = pick_info(info, pick_types(info, meg=True, exclude=[]))
- clean_meg_names = _clean_names(meg_info['ch_names'],
- remove_whitespace=True)
- cal_names = [c['ch_name'] for c in cal_chs]
- order = pick_channels(cal_names, clean_meg_names)
- if meg_info['nchan'] != len(order):
- raise RuntimeError('Not all MEG channels found in fine calibration '
- 'file, missing:\n%s'
- % sorted(list(set(clean_meg_names) -
- set(cal_names))))
- # ensure they're ordered like our data
- cal_chs = [cal_chs[ii] for ii in order]
+ fine_cal = read_fine_calibration(fine_cal) # filename -> dict
+ ch_names = _clean_names(info['ch_names'], remove_whitespace=True)
+ info_order = pick_channels(ch_names, fine_cal['ch_names'])
+ meg_picks = pick_types(info, meg=True, exclude=[])
+ if len(set(info_order) - set(meg_picks)) != 0:
+ # this should never happen
+ raise RuntimeError('Found channels in cal file that are not marked '
+ 'as MEG channels in the data file')
+ if len(info_order) != len(meg_picks):
+ raise RuntimeError(
+ 'Not all MEG channels found in fine calibration file, missing:\n%s'
+ % sorted(list(set(ch_names[pick] for pick in meg_picks) -
+ set(fine_cal['ch_names']))))
+ rev_order = np.argsort(info_order)
+ rev_grad = rev_order[np.in1d(meg_picks,
+ pick_types(info, meg='grad', exclude=()))]
+ rev_mag = rev_order[np.in1d(meg_picks,
+ pick_types(info, meg='mag', exclude=()))]
+
+ # Determine gradiometer imbalances and magnetometer calibrations
+ grad_imbalances = np.array([fine_cal['imb_cals'][ri] for ri in rev_grad]).T
+ if grad_imbalances.shape[0] not in [1, 3]:
+ raise ValueError('Must have 1 (x) or 3 (x, y, z) point-like ' +
+ 'magnetometers. Currently have %i' %
+ grad_imbalances.shape[0])
+ mag_cals = np.array([fine_cal['imb_cals'][ri] for ri in rev_mag])
+ del rev_order, rev_grad, rev_mag
+ # Now let's actually construct our point-like adjustment coils for grads
+ grad_coilsets = _get_grad_point_coilsets(
+ info, n_types=len(grad_imbalances), ignore_ref=ignore_ref)
+ calibration = dict(grad_imbalances=grad_imbalances,
+ grad_coilsets=grad_coilsets, mag_cals=mag_cals)
# Replace sensor locations (and track differences) for fine calibration
- ang_shift = np.zeros((len(cal_chs), 3))
+ ang_shift = np.zeros((len(fine_cal['ch_names']), 3))
used = np.zeros(len(info['chs']), bool)
cal_corrs = list()
- coil_types = list()
- grad_picks = pick_types(meg_info, meg='grad')
+ cal_chans = list()
+ grad_picks = pick_types(info, meg='grad', exclude=())
adjust_logged = False
- clean_info_names = _clean_names(info['ch_names'], remove_whitespace=True)
- for ci, cal_ch in enumerate(cal_chs):
- idx = clean_info_names.index(cal_ch['ch_name'])
- assert not used[idx]
- used[idx] = True
- info_ch = info['chs'][idx]
- coil_types.append(info_ch['coil_type'])
+ for ci, info_idx in enumerate(info_order):
+ assert ch_names[info_idx] == fine_cal['ch_names'][ci]
+ assert not used[info_idx]
+ used[info_idx] = True
+ info_ch = info['chs'][info_idx]
+ ch_num = int(fine_cal['ch_names'][ci].lstrip('MEG').lstrip('0'))
+ cal_chans.append([ch_num, info_ch['coil_type']])
# Some .dat files might only rotate EZ, so we must check first that
# EX and EY are orthogonal to EZ. If not, we find the rotation between
# the original and fine-cal ez, and rotate EX and EY accordingly:
ch_coil_rot = _loc_to_coil_trans(info_ch['loc'])[:3, :3]
- cal_loc = cal_ch['loc'].copy()
+ cal_loc = fine_cal['locs'][ci].copy()
cal_coil_rot = _loc_to_coil_trans(cal_loc)[:3, :3]
if np.max([np.abs(np.dot(cal_coil_rot[:, ii], cal_coil_rot[:, 2]))
for ii in range(2)]) > 1e-6: # X or Y not orthogonal
@@ -1669,51 +1754,34 @@ def _update_sensor_geometry(info, fine_cal, head_frame, ignore_ref):
cal_loc[3:] = np.dot(this_trans, ch_coil_rot).T.ravel()
# calculate shift angle
- v1 = _loc_to_coil_trans(cal_ch['loc'])[:3, :3]
+ v1 = _loc_to_coil_trans(cal_loc)[:3, :3]
_normalize_vectors(v1)
v2 = _loc_to_coil_trans(info_ch['loc'])[:3, :3]
_normalize_vectors(v2)
ang_shift[ci] = np.sum(v1 * v2, axis=0)
- if idx in grad_picks:
- extra = [1., cal_ch['calib_coeff'][0]]
+ if info_idx in grad_picks:
+ extra = [1., fine_cal['imb_cals'][ci][0]]
else:
- extra = [cal_ch['calib_coeff'][0], 0.]
+ extra = [fine_cal['imb_cals'][ci][0], 0.]
cal_corrs.append(np.concatenate([extra, cal_loc]))
# Adjust channel normal orientations with those from fine calibration
# Channel positions are not changed
info_ch['loc'][3:] = cal_loc[3:]
- assert (info_ch['coord_frame'] == cal_ch['coord_frame'] ==
- FIFF.FIFFV_COORD_DEVICE)
- cal_chans = [[sc, ct] for sc, ct in zip(cal_ch_numbers, coil_types)]
+ assert (info_ch['coord_frame'] == FIFF.FIFFV_COORD_DEVICE)
+ assert used[meg_picks].all()
+ assert not used[np.setdiff1d(np.arange(len(used)), meg_picks)].any()
+ # This gets written to the Info struct
sss_cal = dict(cal_corrs=np.array(cal_corrs),
cal_chans=np.array(cal_chans))
+ # Log quantification of sensor changes
# Deal with numerical precision giving absolute vals slightly more than 1.
np.clip(ang_shift, -1., 1., ang_shift)
np.rad2deg(np.arccos(ang_shift), ang_shift) # Convert to degrees
-
- # Log quantification of sensor changes
logger.info(' Adjusted coil positions by (μ ± σ): '
'%0.1f° ± %0.1f° (max: %0.1f°)' %
(np.mean(ang_shift), np.std(ang_shift),
np.max(np.abs(ang_shift))))
-
- # Determine gradiometer imbalances and magnetometer calibrations
- grad_picks = pick_types(info, meg='grad', exclude=[])
- mag_picks = pick_types(info, meg='mag', exclude=[])
- grad_imbalances = np.array([cal_chs[ii]['calib_coeff']
- for ii in grad_picks]).T
- if grad_imbalances.shape[0] not in [1, 3]:
- raise ValueError('Must have 1 (x) or 3 (x, y, z) point-like ' +
- 'magnetometers. Currently have %i' %
- grad_imbalances.shape[0])
- mag_cals = np.array([cal_chs[ii]['calib_coeff'] for ii in mag_picks])
-
- # Now let's actually construct our point-like adjustment coils for grads
- grad_coilsets = _get_grad_point_coilsets(
- info, n_types=len(grad_imbalances), ignore_ref=ignore_ref)
- calibration = dict(grad_imbalances=grad_imbalances,
- grad_coilsets=grad_coilsets, mag_cals=mag_cals)
return calibration, sss_cal
@@ -1845,10 +1913,10 @@ def _regularize_in(int_order, ext_order, S_decomp, mag_or_fine):
'Removing in component %s: l=%s, m=%+0.0f'
% (tuple(eigs[ii]) + (eigs[ii, 0] / eigs[ii, 1],
ri, degrees[ri], orders[ri])))
- logger.info(' Resulting information: %0.1f bits/sample '
- '(%0.1f%% of peak %0.1f)'
- % (I_tots[lim_idx], 100 * I_tots[lim_idx] / max_info,
- max_info))
+ logger.debug(' Resulting information: %0.1f bits/sample '
+ '(%0.1f%% of peak %0.1f)'
+ % (I_tots[lim_idx], 100 * I_tots[lim_idx] / max_info,
+ max_info))
return in_removes, out_removes
@@ -1901,6 +1969,7 @@ def _trans_sss_basis(exp, all_coils, trans=None, coil_scale=100.):
if trans is not None:
if not isinstance(trans, Transform):
trans = Transform('meg', 'head', trans)
+ assert not np.isnan(trans['trans']).any()
all_coils = (apply_trans(trans, all_coils[0]),
apply_trans(trans, all_coils[1], move=False),
) + all_coils[2:]
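At the user level, the new option is reached through maxwell_filter. A
hedged end-to-end sketch on a hypothetical raw file (the path is made up);
mag_scale='auto' derives the scale factor from the gradiometer baseline as
implemented in _get_coil_scale above:

    from mne.io import read_raw_fif
    from mne.preprocessing import maxwell_filter

    raw = read_raw_fif('sample_raw.fif', add_eeg_ref=False)  # hypothetical
    raw_sss = maxwell_filter(raw, st_duration=10., mag_scale='auto')
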
diff --git a/mne/preprocessing/ssp.py b/mne/preprocessing/ssp.py
index 758e55b..ccaf05b 100644
--- a/mne/preprocessing/ssp.py
+++ b/mne/preprocessing/ssp.py
@@ -180,10 +180,13 @@ def _compute_exg_proj(mode, raw, raw_event, tmin, tmax,
picks = pick_types(my_info, meg=True, eeg=True, eog=True, ref_meg=False,
exclude='bads')
raw.filter(l_freq, h_freq, picks=picks, filter_length=filter_length,
- n_jobs=n_jobs, method=filter_method, iir_params=iir_params)
+ n_jobs=n_jobs, method=filter_method, iir_params=iir_params,
+ l_trans_bandwidth=0.5, h_trans_bandwidth=0.5,
+ phase='zero-double')
epochs = Epochs(raw, events, None, tmin, tmax, baseline=None, preload=True,
- picks=picks, reject=reject, flat=flat, proj=True)
+ picks=picks, reject=reject, flat=flat, proj=True,
+ add_eeg_ref=False)
epochs.drop_bad()
if epochs.events.shape[0] < 1:
diff --git a/mne/preprocessing/tests/test_ecg.py b/mne/preprocessing/tests/test_ecg.py
index 92b6602..165b9db 100644
--- a/mne/preprocessing/tests/test_ecg.py
+++ b/mne/preprocessing/tests/test_ecg.py
@@ -1,10 +1,12 @@
import os.path as op
+import warnings
from nose.tools import assert_true, assert_equal
-from mne.io import Raw
+from mne.io import read_raw_fif
from mne import pick_types
from mne.preprocessing.ecg import find_ecg_events, create_ecg_epochs
+from mne.utils import run_tests_if_main
data_path = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data')
raw_fname = op.join(data_path, 'test_raw.fif')
@@ -13,15 +15,15 @@ proj_fname = op.join(data_path, 'test-proj.fif')
def test_find_ecg():
- """Test find ECG peaks"""
- raw = Raw(raw_fname)
+ """Test find ECG peaks."""
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
# once with mag-trick
# once with characteristic channel
for ch_name in ['MEG 1531', None]:
events, ch_ECG, average_pulse, ecg = find_ecg_events(
- raw, event_id=999, ch_name=None, return_ecg=True)
- assert_equal(len(raw.times), len(ecg))
+ raw, event_id=999, ch_name=ch_name, return_ecg=True)
+ assert_equal(raw.n_times, ecg.shape[-1])
n_events = len(events)
_, times = raw[0, :]
assert_true(55 < average_pulse < 60)
@@ -42,3 +44,15 @@ def test_find_ecg():
eog=False, ecg=True, emg=False, ref_meg=False,
exclude='bads')
assert_true(len(picks) == 1)
+
+ ecg_epochs = create_ecg_epochs(raw, ch_name='MEG 2641')
+ assert_true('MEG 2641' in ecg_epochs.ch_names)
+
+ # test with user provided ecg channel
+ raw.info['projs'] = list()
+ with warnings.catch_warnings(record=True) as w:
+ raw.set_channel_types({'MEG 2641': 'ecg'})
+ assert_true(len(w) == 1 and 'unit for channel' in str(w[0].message))
+ create_ecg_epochs(raw)
+
+run_tests_if_main()
diff --git a/mne/preprocessing/tests/test_eeglab_infomax.py b/mne/preprocessing/tests/test_eeglab_infomax.py
index 0711441..90ffcd8 100644
--- a/mne/preprocessing/tests/test_eeglab_infomax.py
+++ b/mne/preprocessing/tests/test_eeglab_infomax.py
@@ -1,27 +1,28 @@
+import os.path as op
+import warnings
+
import numpy as np
+from numpy.testing import assert_almost_equal
-from scipy.linalg import svd
+from scipy.linalg import svd, pinv
+import scipy.io as sio
-from mne.io import Raw
+from mne.io import read_raw_fif
from mne import pick_types
-
-import scipy.io as sio
-from scipy.linalg import pinv
from mne.preprocessing.infomax_ import infomax
-from numpy.testing import assert_almost_equal
-from mne.utils import random_permutation
+from mne.utils import random_permutation, slow_test
from mne.datasets import testing
-import os.path as op
base_dir = op.join(op.dirname(__file__), 'data')
def generate_data_for_comparing_against_eeglab_infomax(ch_type, random_state):
+ """Generate data."""
data_dir = op.join(testing.data_path(download=False), 'MEG', 'sample')
raw_fname = op.join(data_dir, 'sample_audvis_trunc_raw.fif')
- raw = Raw(raw_fname, preload=True)
+ raw = read_raw_fif(raw_fname, preload=True, add_eeg_ref=False)
if ch_type == 'eeg':
picks = pick_types(raw.info, meg=False, eeg=True, exclude='bads')
@@ -34,7 +35,13 @@ def generate_data_for_comparing_against_eeglab_infomax(ch_type, random_state):
idx_perm = random_permutation(picks.shape[0], random_state)
picks = picks[idx_perm[:number_of_channels_to_use]]
- raw.filter(1, 45, n_jobs=2)
+ with warnings.catch_warnings(record=True): # deprecated params
+ raw.filter(1, 45, picks=picks)
+ # Eventually we will need to add these, but for now having none of
+ # them is a nice deprecation sanity check.
+ # filter_length='10s',
+ # l_trans_bandwidth=0.5, h_trans_bandwidth=0.5,
+ # phase='zero-double', fir_window='hann') # use the old way
X = raw[picks, :][0][:, ::20]
# Subtract the mean
@@ -55,150 +62,120 @@ def generate_data_for_comparing_against_eeglab_infomax(ch_type, random_state):
return Y
+ at slow_test
@testing.requires_testing_data
def test_mne_python_vs_eeglab():
- """ Test eeglab vs mne_python infomax code.
- """
+ """ Test eeglab vs mne_python infomax code."""
random_state = 42
- methods = ['infomax', 'infomax', 'extended_infomax', 'extended_infomax']
- list_ch_types = ['eeg', 'mag', 'eeg', 'mag']
-
- for method, ch_type in zip(methods, list_ch_types):
-
- if method == 'infomax':
- if ch_type == 'eeg':
- eeglab_results_file = 'eeglab_infomax_results_eeg_data.mat'
- elif ch_type == 'mag':
- eeglab_results_file = 'eeglab_infomax_results_meg_data.mat'
-
- elif method == 'extended_infomax':
-
- if ch_type == 'eeg':
- eeglab_results_file = ('eeglab_extended_infomax_results_eeg_'
- 'data.mat')
- elif ch_type == 'mag':
- eeglab_results_file = ('eeglab_extended_infomax_results_meg_'
- 'data.mat')
-
- Y = generate_data_for_comparing_against_eeglab_infomax(ch_type,
- random_state)
- N = Y.shape[0]
- T = Y.shape[1]
-
- # For comparasion against eeglab, make sure the following
- # parameters have the same value in mne_python and eeglab:
- #
- # - starting point
- # - random state
- # - learning rate
- # - block size
- # - blowup parameter
- # - blowup_fac parameter
- # - tolerance for stopping the algorithm
- # - number of iterations
- # - anneal_step parameter
- #
- # Notes:
- # * By default, eeglab whiten the data using the "sphering transform"
- # instead of pca. The mne_python infomax code does not
- # whiten the data. To make sure both mne_python and eeglab starts
- # from the same point (i.e., the same matrix), we need to make sure
- # to whiten the data outside, and pass these whiten data to
- # mne_python and eeglab. Finally, we need to tell eeglab that
- # the input data is already whiten, this can be done by calling
- # eeglab with the following syntax:
- #
- # % Run infomax
- # [unmixing,sphere,meanvar,bias,signs,lrates,sources,y] = ...
- # runica( Y, 'sphering', 'none');
- #
- # % Run extended infomax
- # [unmixing,sphere,meanvar,bias,signs,lrates,sources,y] = ...
- # runica( Y, 'sphering', 'none', 'extended', 1);
- #
- # By calling eeglab using the former code, we are using its default
- # parameters, which are specified below in the section
- # "EEGLAB default parameters".
- #
- # * eeglab does not expose a parameter for fixing the random state.
- # Therefore, to accomplish this, we need to edit the runica.m
- # file located at /path_to_eeglab/functions/sigprocfunc/runica.m
- #
- # i) Comment the line related with the random number generator
- # (line 812).
- # ii) Then, add the following line just below line 812:
- # rng(42); %use 42 as random seed.
- #
- # * eeglab does not have the parameter "n_small_angle",
- # so we need to disable it for making a fair comparison.
- #
- # * Finally, we need to take the unmixing matrix estimated by the
- # mne_python infomax implementation and order the components
- # in the same way that eeglab does. This is done below in the section
- # "Order the components in the same way that eeglab does".
-
- ###############################################################
- # EEGLAB default parameters
- ###############################################################
- l_rate_eeglab = 0.00065 / np.log(N)
- block_eeglab = int(np.ceil(np.min([5 * np.log(T), 0.3 * T])))
- blowup_eeglab = 1e9
- blowup_fac_eeglab = 0.8
- max_iter_eeglab = 512
-
- if method == 'infomax':
- anneal_step_eeglab = 0.9
- use_extended = False
-
- elif method == 'extended_infomax':
- anneal_step_eeglab = 0.98
- use_extended = True
-
- if N > 32:
- w_change_eeglab = 1e-7
- else:
- w_change_eeglab = 1e-6
- ###############################################################
-
- # Call mne_python infomax version using the following sintax
- # to obtain the same result than eeglab version
- unmixing = infomax(Y.T, extended=use_extended,
- random_state=random_state,
- max_iter=max_iter_eeglab,
- l_rate=l_rate_eeglab,
- block=block_eeglab,
- w_change=w_change_eeglab,
- blowup=blowup_eeglab,
- blowup_fac=blowup_fac_eeglab,
- n_small_angle=None,
- anneal_step=anneal_step_eeglab
- )
-
- #######################################################################
- # Order the components in the same way that eeglab does
- #######################################################################
-
- sources = np.dot(unmixing, Y)
- mixing = pinv(unmixing)
-
- mvar = np.sum(mixing ** 2, axis=0) * \
- np.sum(sources ** 2, axis=1) / (N * T - 1)
- windex = np.argsort(mvar)[::-1]
-
- unmixing_ordered = unmixing[windex, :]
- #######################################################################
-
- #######################################################################
- # Load the eeglab results, then compare the unmixing matrices estimated
- # by mne_python and eeglab. To make the comparison use the
- # \ell_inf norm:
- # ||unmixing_mne_python - unmixing_eeglab||_inf
- #######################################################################
-
- eeglab_data = sio.loadmat(op.join(base_dir, eeglab_results_file))
- unmixing_eeglab = eeglab_data['unmixing_eeglab']
-
- maximum_difference = np.max(np.abs(unmixing_ordered - unmixing_eeglab))
-
- assert_almost_equal(maximum_difference, 1e-12, decimal=10)
+ methods = ['infomax', 'extended_infomax']
+ ch_types = ['eeg', 'mag']
+ for ch_type in ch_types:
+ Y = generate_data_for_comparing_against_eeglab_infomax(
+ ch_type, random_state)
+ N, T = Y.shape
+ for method in methods:
+ eeglab_results_file = ('eeglab_%s_results_%s_data.mat'
+ % (method,
+ dict(eeg='eeg', mag='meg')[ch_type]))
+
+ # For comparison against eeglab, make sure the following
+ # parameters have the same value in mne_python and eeglab:
+ #
+ # - starting point
+ # - random state
+ # - learning rate
+ # - block size
+ # - blowup parameter
+ # - blowup_fac parameter
+ # - tolerance for stopping the algorithm
+ # - number of iterations
+ # - anneal_step parameter
+ #
+ # Notes:
+ # * By default, eeglab whitens the data using the "sphering
+ # transform" instead of PCA. The mne_python infomax code does not
+ # whiten the data. To make sure both mne_python and eeglab start
+ # from the same point (i.e., the same matrix), we need to whiten
+ # the data outside and pass the whitened data to mne_python and
+ # eeglab. Finally, we need to tell eeglab that the input data are
+ # already whitened; this can be done by calling eeglab with the
+ # following syntax:
+ #
+ # % Run infomax
+ # [unmixing,sphere,meanvar,bias,signs,lrates,sources,y] = ...
+ # runica( Y, 'sphering', 'none');
+ #
+ # % Run extended infomax
+ # [unmixing,sphere,meanvar,bias,signs,lrates,sources,y] = ...
+ # runica( Y, 'sphering', 'none', 'extended', 1);
+ #
+ # By calling eeglab using the former code, we are using its
+ # default parameters, which are specified below in the section
+ # "EEGLAB default parameters".
+ #
+ # * eeglab does not expose a parameter for fixing the random state.
+ # Therefore, to accomplish this, we need to edit the runica.m
+ # file located at /path_to_eeglab/functions/sigprocfunc/runica.m
+ #
+ # i) Comment the line related with the random number generator
+ # (line 812).
+ # ii) Then, add the following line just below line 812:
+ # rng(42); %use 42 as random seed.
+ #
+ # * eeglab does not have the parameter "n_small_angle",
+ # so we need to disable it for making a fair comparison.
+ #
+ # * Finally, we need to take the unmixing matrix estimated by the
+ # mne_python infomax implementation and order the components
+ # in the same way that eeglab does. This is done below in the
+ # section "Order the components in the same way that eeglab does"
+
+ # EEGLAB default parameters
+ l_rate_eeglab = 0.00065 / np.log(N)
+ block_eeglab = int(np.ceil(np.min([5 * np.log(T), 0.3 * T])))
+ blowup_eeglab = 1e9
+ blowup_fac_eeglab = 0.8
+ max_iter_eeglab = 512
+
+ if method == 'infomax':
+ anneal_step_eeglab = 0.9
+ use_extended = False
+
+ elif method == 'extended_infomax':
+ anneal_step_eeglab = 0.98
+ use_extended = True
+
+ w_change_eeglab = 1e-7 if N > 32 else 1e-6
+
+ # Call the mne_python infomax version using the following syntax
+ # to obtain the same result as the eeglab version
+ unmixing = infomax(
+ Y.T, extended=use_extended, random_state=random_state,
+ max_iter=max_iter_eeglab, l_rate=l_rate_eeglab,
+ block=block_eeglab, w_change=w_change_eeglab,
+ blowup=blowup_eeglab, blowup_fac=blowup_fac_eeglab,
+ n_small_angle=None, anneal_step=anneal_step_eeglab)
+
+ # Order the components in the same way that eeglab does
+ sources = np.dot(unmixing, Y)
+ mixing = pinv(unmixing)
+
+ mvar = np.sum(mixing ** 2, axis=0) * \
+ np.sum(sources ** 2, axis=1) / (N * T - 1)
+ windex = np.argsort(mvar)[::-1]
+
+ unmixing_ordered = unmixing[windex, :]
+
+ # Load the eeglab results, then compare the unmixing matrices
+ # estimated by mne_python and eeglab. To make the comparison use
+ # the \ell_inf norm:
+ # ||unmixing_mne_python - unmixing_eeglab||_inf
+
+ eeglab_data = sio.loadmat(op.join(base_dir, eeglab_results_file))
+ unmixing_eeglab = eeglab_data['unmixing_eeglab']
+
+ maximum_difference = np.max(np.abs(unmixing_ordered -
+ unmixing_eeglab))
+
+ assert_almost_equal(maximum_difference, 1e-12, decimal=10)
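The ordering code above follows eeglab's runica.m, which sorts components by the mean projected variance each one explains in the data. A minimal standalone sketch of that ordering, using made-up whitened data and a made-up unmixing matrix rather than the test fixtures:

    import numpy as np
    from scipy.linalg import pinv

    rng = np.random.RandomState(42)
    N, T = 4, 1000               # channels, samples
    Y = rng.randn(N, T)          # hypothetical whitened data
    unmixing = rng.randn(N, N)   # stand-in for the estimated unmixing matrix

    sources = np.dot(unmixing, Y)
    mixing = pinv(unmixing)
    # Mean projected variance of each component (as in runica.m)
    mvar = (np.sum(mixing ** 2, axis=0) *
            np.sum(sources ** 2, axis=1) / (N * T - 1))
    order = np.argsort(mvar)[::-1]           # descending variance
    unmixing_ordered = unmixing[order, :]
    # The \ell_inf comparison against a reference matrix would then be
    # np.max(np.abs(unmixing_ordered - unmixing_reference))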
diff --git a/mne/preprocessing/tests/test_eog.py b/mne/preprocessing/tests/test_eog.py
index 97220dd..eb7afa3 100644
--- a/mne/preprocessing/tests/test_eog.py
+++ b/mne/preprocessing/tests/test_eog.py
@@ -1,7 +1,7 @@
import os.path as op
from nose.tools import assert_true
-from mne.io import Raw
+from mne.io import read_raw_fif
from mne.preprocessing.eog import find_eog_events
data_path = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data')
@@ -11,8 +11,8 @@ proj_fname = op.join(data_path, 'test-proj.fif')
def test_find_eog():
- """Test find EOG peaks"""
- raw = Raw(raw_fname)
+ """Test find EOG peaks."""
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
events = find_eog_events(raw)
n_events = len(events)
assert_true(n_events == 4)
diff --git a/mne/preprocessing/tests/test_fine_cal.py b/mne/preprocessing/tests/test_fine_cal.py
new file mode 100644
index 0000000..fe27aea
--- /dev/null
+++ b/mne/preprocessing/tests/test_fine_cal.py
@@ -0,0 +1,40 @@
+# Author: Mark Wronkiewicz <wronk at uw.edu>
+#
+# License: BSD (3-clause)
+
+import os.path as op
+import warnings
+
+from mne.datasets import testing
+from mne.preprocessing._fine_cal import (read_fine_calibration,
+ write_fine_calibration)
+from mne.utils import _TempDir, object_hash, run_tests_if_main
+from nose.tools import assert_equal
+
+warnings.simplefilter('always') # Always throw warnings
+
+# Define fine calibration filepaths
+data_path = testing.data_path(download=False)
+fine_cal_fname = op.join(data_path, 'SSS', 'sss_cal_3053.dat')
+fine_cal_fname_3d = op.join(data_path, 'SSS', 'sss_cal_3053_3d.dat')
+
+
+@testing.requires_testing_data
+def test_read_write_fine_cal():
+ """Test round trip reading/writing of fine calibration .dat file"""
+ temp_dir = _TempDir()
+ temp_fname = op.join(temp_dir, 'fine_cal_temp.dat')
+
+ for fname in [fine_cal_fname, fine_cal_fname_3d]:
+ # Load fine calibration file
+ fine_cal_dict = read_fine_calibration(fname)
+
+ # Save temp version of fine calibration file
+ write_fine_calibration(temp_fname, fine_cal_dict)
+ fine_cal_dict_reload = read_fine_calibration(temp_fname)
+
+ # Load temp version of fine calibration file and compare hashes
+ assert_equal(object_hash(fine_cal_dict),
+ object_hash(fine_cal_dict_reload))
+
+run_tests_if_main()
diff --git a/mne/preprocessing/tests/test_ica.py b/mne/preprocessing/tests/test_ica.py
index fd8df37..a15e3e2 100644
--- a/mne/preprocessing/tests/test_ica.py
+++ b/mne/preprocessing/tests/test_ica.py
@@ -9,19 +9,22 @@ import os
import os.path as op
import warnings
-from nose.tools import assert_true, assert_raises, assert_equal
+from nose.tools import assert_true, assert_raises, assert_equal, assert_false
import numpy as np
from numpy.testing import (assert_array_almost_equal, assert_array_equal,
assert_allclose)
from scipy import stats
from itertools import product
-from mne import Epochs, read_events, pick_types
+from mne import Epochs, read_events, pick_types, create_info, EpochsArray
from mne.cov import read_cov
from mne.preprocessing import (ICA, ica_find_ecg_events, ica_find_eog_events,
read_ica, run_ica)
-from mne.preprocessing.ica import get_score_funcs, corrmap, _get_ica_map
-from mne.io import Raw, Info
+from mne.preprocessing.ica import (get_score_funcs, corrmap, _get_ica_map,
+ _ica_explained_variance, _sort_components)
+from mne.io import read_raw_fif, Info, RawArray
+from mne.io.meas_info import _kind_dict
+from mne.io.pick import _DATA_CH_TYPES_SPLIT
from mne.tests.common import assert_naming
from mne.utils import (catch_logging, _TempDir, requires_sklearn, slow_test,
run_tests_if_main)
@@ -36,7 +39,6 @@ warnings.simplefilter('always') # enable b/c these tests throw warnings
data_dir = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data')
raw_fname = op.join(data_dir, 'test_raw.fif')
event_name = op.join(data_dir, 'test-eve.fif')
-evoked_nf_name = op.join(data_dir, 'test-nf-ave.fif')
test_cov_name = op.join(data_dir, 'test-cov.fif')
event_id, tmin, tmax = 1, -0.2, 0.2
@@ -52,16 +54,16 @@ except:
@requires_sklearn
def test_ica_full_data_recovery():
- """Test recovery of full data when no source is rejected"""
+ """Test recovery of full data when no source is rejected."""
# Most basic recovery
- raw = Raw(raw_fname).crop(0.5, stop, copy=False)
- raw.load_data()
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
+ raw.crop(0.5, stop, copy=False).load_data()
events = read_events(event_name)
picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
eog=False, exclude='bads')[:10]
with warnings.catch_warnings(record=True): # bad proj
epochs = Epochs(raw, events[:4], event_id, tmin, tmax, picks=picks,
- baseline=(None, 0), preload=True)
+ baseline=(None, 0), preload=True, add_eeg_ref=False)
evoked = epochs.average()
n_channels = 5
data = raw._data[:n_channels].copy()
@@ -111,10 +113,10 @@ def test_ica_full_data_recovery():
@requires_sklearn
def test_ica_rank_reduction():
- """Test recovery ICA rank reduction"""
+ """Test recovery ICA rank reduction."""
# Most basic recovery
- raw = Raw(raw_fname).crop(0.5, stop, copy=False)
- raw.load_data()
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
+ raw.crop(0.5, stop, copy=False).load_data()
picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
eog=False, exclude='bads')[:10]
n_components = 5
@@ -140,9 +142,9 @@ def test_ica_rank_reduction():
@requires_sklearn
def test_ica_reset():
- """Test ICA resetting"""
- raw = Raw(raw_fname).crop(0.5, stop, copy=False)
- raw.load_data()
+ """Test ICA resetting."""
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
+ raw.crop(0.5, stop, copy=False).load_data()
picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
eog=False, exclude='bads')[:10]
@@ -168,9 +170,9 @@ def test_ica_reset():
@requires_sklearn
def test_ica_core():
- """Test ICA on raw and epochs"""
- raw = Raw(raw_fname).crop(1.5, stop, copy=False)
- raw.load_data()
+ """Test ICA on raw and epochs."""
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
+ raw.crop(1.5, stop, copy=False).load_data()
picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
eog=False, exclude='bads')
# XXX. The None cases helped reveal bugs but are time consuming.
@@ -179,7 +181,7 @@ def test_ica_core():
picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
eog=False, exclude='bads')
epochs = Epochs(raw, events[:4], event_id, tmin, tmax, picks=picks,
- baseline=(None, 0), preload=True)
+ baseline=(None, 0), preload=True, add_eeg_ref=False)
noise_cov = [None, test_cov]
# removed None cases to speed up...
n_components = [2, 1.0] # for future dbg add cases
@@ -271,19 +273,19 @@ def test_ica_core():
@slow_test
@requires_sklearn
def test_ica_additional():
- """Test additional ICA functionality"""
+ """Test additional ICA functionality."""
tempdir = _TempDir()
stop2 = 500
- raw = Raw(raw_fname).crop(1.5, stop, copy=False)
- raw.load_data()
- picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
- eog=False, exclude='bads')
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
+ raw.crop(1.5, stop, copy=False).load_data()
+ # XXX This breaks the tests :(
+ # raw.info['bads'] = [raw.ch_names[1]]
test_cov = read_cov(test_cov_name)
events = read_events(event_name)
picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
eog=False, exclude='bads')
epochs = Epochs(raw, events[:4], event_id, tmin, tmax, picks=picks,
- baseline=(None, 0), preload=True)
+ baseline=(None, 0), preload=True, add_eeg_ref=False)
# test if n_components=None works
with warnings.catch_warnings(record=True):
ica = ICA(n_components=None,
@@ -294,7 +296,7 @@ def test_ica_additional():
picks2 = pick_types(raw.info, meg=True, stim=False, ecg=False,
eog=True, exclude='bads')
epochs_eog = Epochs(raw, events[:4], event_id, tmin, tmax, picks=picks2,
- baseline=(None, 0), preload=True)
+ baseline=(None, 0), preload=True, add_eeg_ref=False)
test_cov2 = test_cov.copy()
ica = ICA(noise_cov=test_cov2, n_components=3, max_pca_components=4,
@@ -350,6 +352,15 @@ def test_ica_additional():
with warnings.catch_warnings(record=True):
ica.fit(raw, picks=None, decim=3)
assert_true(ica.n_components_ == 4)
+ ica_var = _ica_explained_variance(ica, raw, normalize=True)
+ assert_true(np.all(ica_var[:-1] >= ica_var[1:]))
+
+ # test ica sorting
+ ica.exclude = [0]
+ ica.labels_ = dict(blink=[0], think=[1])
+ ica_sorted = _sort_components(ica, [3, 2, 1, 0], copy=True)
+ assert_equal(ica_sorted.exclude, [3])
+ assert_equal(ica_sorted.labels_, dict(blink=[3], think=[2]))
# epochs extraction from raw fit
assert_raises(RuntimeError, ica.get_sources, epochs)
@@ -367,7 +378,7 @@ def test_ica_additional():
assert_true(sources.shape[1] == ica.n_components_)
for exclude in [[], [0]]:
- ica.exclude = [0]
+ ica.exclude = exclude
ica.labels_ = {'foo': [0]}
ica.save(test_ica_fname)
ica_read = read_ica(test_ica_fname)
@@ -387,18 +398,24 @@ def test_ica_additional():
# test filtering
d1 = ica_raw._data[0].copy()
- with warnings.catch_warnings(record=True): # dB warning
- ica_raw.filter(4, 20)
+ ica_raw.filter(4, 20, l_trans_bandwidth='auto',
+ h_trans_bandwidth='auto', filter_length='auto',
+ phase='zero', fir_window='hamming')
+ assert_equal(ica_raw.info['lowpass'], 20.)
+ assert_equal(ica_raw.info['highpass'], 4.)
assert_true((d1 != ica_raw._data[0]).any())
d1 = ica_raw._data[0].copy()
- with warnings.catch_warnings(record=True): # dB warning
- ica_raw.notch_filter([10])
+ ica_raw.notch_filter([10], filter_length='auto', trans_bandwidth=10,
+ phase='zero', fir_window='hamming')
assert_true((d1 != ica_raw._data[0]).any())
ica.n_pca_components = 2
+ ica.method = 'fake'
ica.save(test_ica_fname)
ica_read = read_ica(test_ica_fname)
assert_true(ica.n_pca_components == ica_read.n_pca_components)
+ assert_equal(ica.method, ica_read.method)
+ assert_equal(ica.labels_, ica_read.labels_)
# check type consistency
attrs = ('mixing_matrix_ unmixing_matrix_ pca_components_ '
@@ -460,6 +477,11 @@ def test_ica_additional():
assert_equal(len(scores), ica.n_components_)
idx, scores = ica.find_bads_ecg(raw, method='correlation')
assert_equal(len(scores), ica.n_components_)
+
+ idx, scores = ica.find_bads_eog(raw)
+ assert_equal(len(scores), ica.n_components_)
+
+ ica.labels_ = None
idx, scores = ica.find_bads_ecg(epochs, method='ctps')
assert_equal(len(scores), ica.n_components_)
assert_raises(ValueError, ica.find_bads_ecg, epochs.average(),
@@ -467,8 +489,6 @@ def test_ica_additional():
assert_raises(ValueError, ica.find_bads_ecg, raw,
method='crazy-coupling')
- idx, scores = ica.find_bads_eog(raw)
- assert_equal(len(scores), ica.n_components_)
raw.info['chs'][raw.ch_names.index('EOG 061') - 1]['kind'] = 202
idx, scores = ica.find_bads_eog(raw)
assert_true(isinstance(scores, list))
@@ -517,7 +537,7 @@ def test_ica_additional():
test_ica_fname = op.join(op.abspath(op.curdir), 'test-ica_raw.fif')
ica.n_components = np.int32(ica.n_components)
ica_raw.save(test_ica_fname, overwrite=True)
- ica_raw2 = Raw(test_ica_fname, preload=True)
+ ica_raw2 = read_raw_fif(test_ica_fname, preload=True, add_eeg_ref=False)
assert_allclose(ica_raw._data, ica_raw2._data, rtol=1e-5, atol=1e-4)
ica_raw2.close()
os.remove(test_ica_fname)
@@ -541,8 +561,9 @@ def test_ica_additional():
@requires_sklearn
def test_run_ica():
- """Test run_ica function"""
- raw = Raw(raw_fname, preload=True).crop(1.5, stop, copy=False)
+ """Test run_ica function."""
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
+ raw.crop(1.5, stop, copy=False).load_data()
params = []
params += [(None, -1, slice(2), [0, 1])] # variance, kurtosis idx
params += [(None, 'MEG 1531')] # ECG / EOG channel params
@@ -556,9 +577,9 @@ def test_run_ica():
@requires_sklearn
def test_ica_reject_buffer():
- """Test ICA data raw buffer rejection"""
- raw = Raw(raw_fname).crop(1.5, stop, copy=False)
- raw.load_data()
+ """Test ICA data raw buffer rejection."""
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
+ raw.crop(1.5, stop, copy=False).load_data()
picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
eog=False, exclude='bads')
ica = ICA(n_components=3, max_pca_components=4, n_pca_components=4)
@@ -574,9 +595,9 @@ def test_ica_reject_buffer():
@requires_sklearn
def test_ica_twice():
- """Test running ICA twice"""
- raw = Raw(raw_fname).crop(1.5, stop, copy=False)
- raw.load_data()
+ """Test running ICA twice."""
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
+ raw.crop(1.5, stop, copy=False).load_data()
picks = pick_types(raw.info, meg='grad', exclude='bads')
n_components = 0.9
max_pca_components = None
@@ -597,10 +618,67 @@ def test_ica_twice():
@requires_sklearn
def test_fit_params():
- """Test fit_params for ICA"""
+ """Test fit_params for ICA."""
assert_raises(ValueError, ICA, fit_params=dict(extended=True))
fit_params = {}
ICA(fit_params=fit_params) # test no side effects
assert_equal(fit_params, {})
+
+@requires_sklearn
+def test_bad_channels():
+ """Test exception when unsupported channels are used."""
+ chs = [i for i in _kind_dict]
+ data_chs = _DATA_CH_TYPES_SPLIT + ['eog']
+ chs_bad = list(set(chs) - set(data_chs))
+ info = create_info(len(chs), 500, chs)
+ data = np.random.rand(len(chs), 50)
+ raw = RawArray(data, info)
+ data = np.random.rand(100, len(chs), 50)
+ epochs = EpochsArray(data, info)
+
+ n_components = 0.9
+ ica = ICA(n_components=n_components, method='fastica')
+
+ for inst in [raw, epochs]:
+ for ch in chs_bad:
+ # Test case for only bad channels
+ picks_bad1 = pick_types(inst.info, meg=False,
+ **{str(ch): True})
+ # Test case for good and bad channels
+ picks_bad2 = pick_types(inst.info, meg=True,
+ **{str(ch): True})
+ assert_raises(ValueError, ica.fit, inst, picks=picks_bad1)
+ assert_raises(ValueError, ica.fit, inst, picks=picks_bad2)
+ assert_raises(ValueError, ica.fit, inst, picks=[])
+
+
+@requires_sklearn
+def test_eog_channel():
+ """Test that EOG channel is included when performing ICA."""
+ raw = read_raw_fif(raw_fname, preload=True, add_eeg_ref=False)
+ events = read_events(event_name)
+ picks = pick_types(raw.info, meg=True, stim=True, ecg=False,
+ eog=True, exclude='bads')
+ epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
+ baseline=(None, 0), preload=True,
+ add_eeg_ref=False)
+ n_components = 0.9
+ ica = ICA(n_components=n_components, method='fastica')
+ # Test case for MEG and EOG data. Should have EOG channel
+ for inst in [raw, epochs]:
+ picks1a = pick_types(inst.info, meg=True, stim=False, ecg=False,
+ eog=False, exclude='bads')[:4]
+ picks1b = pick_types(inst.info, meg=False, stim=False, ecg=False,
+ eog=True, exclude='bads')
+ picks1 = np.append(picks1a, picks1b)
+ ica.fit(inst, picks=picks1)
+ assert_true(any('EOG' in ch for ch in ica.ch_names))
+ # Test case for MEG data. Should have no EOG channel
+ for inst in [raw, epochs]:
+ picks1 = pick_types(inst.info, meg=True, stim=False, ecg=False,
+ eog=False, exclude='bads')[:5]
+ ica.fit(inst, picks=picks1)
+ assert_false(any('EOG' in ch for ch in ica.ch_names))
+
run_tests_if_main()
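test_eog_channel above verifies that explicitly picked EOG channels survive ICA fitting. A runnable sketch of the picking pattern on synthetic data (channel names, counts, and sampling rate are made up):

    import numpy as np
    from mne import create_info, pick_types
    from mne.io import RawArray
    from mne.preprocessing import ICA

    ch_names = ['MEG %03d' % i for i in range(8)] + ['EOG 001', 'EOG 002']
    info = create_info(ch_names, sfreq=200., ch_types=['mag'] * 8 + ['eog'] * 2)
    raw = RawArray(np.random.RandomState(0).randn(10, 2000), info)

    picks_meg = pick_types(raw.info, meg=True, eog=False, exclude='bads')
    picks_eog = pick_types(raw.info, meg=False, eog=True, exclude='bads')
    picks = np.append(picks_meg, picks_eog)  # MEG indices, then EOG indices

    ica = ICA(n_components=5, method='fastica')
    ica.fit(raw, picks=picks)
    assert any('EOG' in ch for ch in ica.ch_names)  # EOG was kept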
diff --git a/mne/preprocessing/tests/test_infomax.py b/mne/preprocessing/tests/test_infomax.py
index d8d9a72..7491d02 100644
--- a/mne/preprocessing/tests/test_infomax.py
+++ b/mne/preprocessing/tests/test_infomax.py
@@ -14,7 +14,7 @@ from scipy import stats
from scipy import linalg
from mne.preprocessing.infomax_ import infomax
-from mne.utils import requires_sklearn, run_tests_if_main
+from mne.utils import requires_sklearn, run_tests_if_main, check_version
def center_and_norm(x, axis=-1):
@@ -37,7 +37,7 @@ def center_and_norm(x, axis=-1):
def test_infomax_blowup():
""" Test the infomax algorithm blowup condition
"""
- from sklearn.decomposition import RandomizedPCA
+
# scipy.stats uses the global RNG:
np.random.seed(0)
n_samples = 100
@@ -56,7 +56,7 @@ def test_infomax_blowup():
center_and_norm(m)
- X = RandomizedPCA(n_components=2, whiten=True).fit_transform(m.T)
+ X = _get_pca().fit_transform(m.T)
k_ = infomax(X, extended=True, l_rate=0.1)
s_ = np.dot(k_, X.T)
@@ -78,7 +78,6 @@ def test_infomax_blowup():
def test_infomax_simple():
""" Test the infomax algorithm on very simple data.
"""
- from sklearn.decomposition import RandomizedPCA
rng = np.random.RandomState(0)
# scipy.stats uses the global RNG:
np.random.seed(0)
@@ -102,7 +101,7 @@ def test_infomax_simple():
algos = [True, False]
for algo in algos:
- X = RandomizedPCA(n_components=2, whiten=True).fit_transform(m.T)
+ X = _get_pca().fit_transform(m.T)
k_ = infomax(X, extended=algo)
s_ = np.dot(k_, X.T)
@@ -124,12 +123,24 @@ def test_infomax_simple():
assert_almost_equal(np.dot(s2_, s2) / n_samples, 1, decimal=1)
+def test_infomax_weights_ini():
+ """ Test the infomax algorithm when user provides an initial weights matrix.
+ """
+
+ X = np.random.random((3, 100))
+ weights = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float64)
+
+ w1 = infomax(X, max_iter=0, weights=weights, extended=True)
+ w2 = infomax(X, max_iter=0, weights=weights, extended=False)
+
+ assert_almost_equal(w1, weights)
+ assert_almost_equal(w2, weights)
+
+
@requires_sklearn
def test_non_square_infomax():
""" Test non-square infomax
"""
- from sklearn.decomposition import RandomizedPCA
-
rng = np.random.RandomState(0)
n_samples = 200
@@ -151,9 +162,8 @@ def test_non_square_infomax():
m += 0.1 * rng.randn(n_observed, n_samples)
center_and_norm(m)
- pca = RandomizedPCA(n_components=2, whiten=True, random_state=rng)
m = m.T
- m = pca.fit_transform(m)
+ m = _get_pca(rng).fit_transform(m)
# we need extended since input signals are sub-gaussian
unmixing_ = infomax(m, random_state=rng, extended=True)
s_ = np.dot(unmixing_, m.T)
@@ -176,4 +186,16 @@ def test_non_square_infomax():
assert_almost_equal(np.dot(s1_, s1) / n_samples, 1, decimal=2)
assert_almost_equal(np.dot(s2_, s2) / n_samples, 1, decimal=2)
+
+def _get_pca(rng=None):
+ if not check_version('sklearn', '0.18'):
+ from sklearn.decomposition import RandomizedPCA
+ return RandomizedPCA(n_components=2, whiten=True,
+ random_state=rng)
+ else:
+ from sklearn.decomposition import PCA
+ return PCA(n_components=2, whiten=True, svd_solver='randomized',
+ random_state=rng)
+
+
run_tests_if_main()
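The new _get_pca helper works around scikit-learn 0.18 deprecating RandomizedPCA in favor of PCA(svd_solver='randomized'). A sketch of the same version guard for use elsewhere (the function name here is made up):

    from mne.utils import check_version

    def get_randomized_pca(n_components=2, rng=None):
        """Whitening randomized PCA across scikit-learn versions."""
        if check_version('sklearn', '0.18'):
            from sklearn.decomposition import PCA
            return PCA(n_components=n_components, whiten=True,
                       svd_solver='randomized', random_state=rng)
        from sklearn.decomposition import RandomizedPCA
        return RandomizedPCA(n_components=n_components, whiten=True,
                             random_state=rng)

    # Usage: X_white = get_randomized_pca(rng=0).fit_transform(X)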
diff --git a/mne/preprocessing/tests/test_maxwell.py b/mne/preprocessing/tests/test_maxwell.py
index 38d9c5c..04df920 100644
--- a/mne/preprocessing/tests/test_maxwell.py
+++ b/mne/preprocessing/tests/test_maxwell.py
@@ -17,14 +17,15 @@ from mne.chpi import read_head_pos, filter_chpi
from mne.forward import _prep_meg_channels
from mne.cov import _estimate_rank_meeg_cov
from mne.datasets import testing
-from mne.io import Raw, proc_history, read_info, read_raw_bti, read_raw_kit
+from mne.io import (read_raw_fif, proc_history, read_info, read_raw_bti,
+ read_raw_kit, _BaseRaw)
from mne.preprocessing.maxwell import (
maxwell_filter, _get_n_moments, _sss_basis_basic, _sh_complex_to_real,
_sh_real_to_complex, _sh_negate, _bases_complex_to_real, _trans_sss_basis,
_bases_real_to_complex, _sph_harm, _prep_mf_coils)
from mne.tests.common import assert_meg_snr
from mne.utils import (_TempDir, run_tests_if_main, slow_test, catch_logging,
- requires_version, object_diff)
+ requires_version, object_diff, buggy_mkl_svd)
from mne.externals.six import PY3
warnings.simplefilter('always') # Always throw warnings
@@ -76,6 +77,20 @@ ctc_mgh_fname = op.join(sss_path, 'ct_sparse_mgh.fif')
sample_fname = op.join(data_path, 'MEG', 'sample',
'sample_audvis_trunc_raw.fif')
+triux_path = op.join(data_path, 'SSS', 'TRIUX')
+tri_fname = op.join(triux_path, 'triux_bmlhus_erm_raw.fif')
+tri_sss_fname = op.join(triux_path, 'triux_bmlhus_erm_raw_sss.fif')
+tri_sss_reg_fname = op.join(triux_path, 'triux_bmlhus_erm_regIn_raw_sss.fif')
+tri_sss_st4_fname = op.join(triux_path, 'triux_bmlhus_erm_st4_raw_sss.fif')
+tri_sss_ctc_fname = op.join(triux_path, 'triux_bmlhus_erm_ctc_raw_sss.fif')
+tri_sss_cal_fname = op.join(triux_path, 'triux_bmlhus_erm_cal_raw_sss.fif')
+tri_sss_ctc_cal_fname = op.join(
+ triux_path, 'triux_bmlhus_erm_ctc_cal_raw_sss.fif')
+tri_sss_ctc_cal_reg_in_fname = op.join(
+ triux_path, 'triux_bmlhus_erm_ctc_cal_regIn_raw_sss.fif')
+tri_ctc_fname = op.join(triux_path, 'ct_sparse_BMLHUS.fif')
+tri_cal_fname = op.join(triux_path, 'sss_cal_BMLHUS.dat')
+
io_dir = op.join(op.dirname(__file__), '..', '..', 'io')
fname_ctf_raw = op.join(io_dir, 'tests', 'data', 'test_ctf_comp_raw.fif')
@@ -95,20 +110,26 @@ bads = ['MEG0912', 'MEG1722', 'MEG2213', 'MEG0132', 'MEG1312', 'MEG0432',
def _assert_n_free(raw_sss, lower, upper=None):
- """Helper to check the DOF"""
+ """Check the DOF."""
upper = lower if upper is None else upper
n_free = raw_sss.info['proc_history'][0]['max_info']['sss_info']['nfree']
assert_true(lower <= n_free <= upper,
'nfree fail: %s <= %s <= %s' % (lower, n_free, upper))
+def read_crop(fname, lims=(0, None)):
+ """Read and crop."""
+ return read_raw_fif(fname, allow_maxshield='yes',
+ add_eeg_ref=False).crop(*lims)
+
+
@slow_test
@testing.requires_testing_data
def test_movement_compensation():
- """Test movement compensation"""
+ """Test movement compensation."""
temp_dir = _TempDir()
lims = (0, 4, False)
- raw = Raw(raw_fname, allow_maxshield='yes', preload=True).crop(*lims)
+ raw = read_crop(raw_fname, lims).load_data()
head_pos = read_head_pos(pos_fname)
#
@@ -116,20 +137,20 @@ def test_movement_compensation():
#
raw_sss = maxwell_filter(raw, head_pos=head_pos, origin=mf_head_origin,
regularize=None, bad_condition='ignore')
- assert_meg_snr(raw_sss, Raw(sss_movecomp_fname).crop(*lims),
+ assert_meg_snr(raw_sss, read_crop(sss_movecomp_fname, lims),
4.6, 12.4, chpi_med_tol=58)
# IO
temp_fname = op.join(temp_dir, 'test_raw_sss.fif')
raw_sss.save(temp_fname)
- raw_sss = Raw(temp_fname)
- assert_meg_snr(raw_sss, Raw(sss_movecomp_fname).crop(*lims),
+ raw_sss = read_crop(temp_fname)
+ assert_meg_snr(raw_sss, read_crop(sss_movecomp_fname, lims),
4.6, 12.4, chpi_med_tol=58)
#
# Movement compensation, regularization, no tSSS
#
raw_sss = maxwell_filter(raw, head_pos=head_pos, origin=mf_head_origin)
- assert_meg_snr(raw_sss, Raw(sss_movecomp_reg_in_fname).crop(*lims),
+ assert_meg_snr(raw_sss, read_crop(sss_movecomp_reg_in_fname, lims),
0.5, 1.9, chpi_med_tol=121)
#
@@ -143,25 +164,25 @@ def test_movement_compensation():
assert_equal(len(w), 1)
assert_true('is untested' in str(w[0].message))
# Neither match is particularly good because our algorithm actually differs
- assert_meg_snr(raw_sss_mv, Raw(sss_movecomp_reg_in_st4s_fname).crop(*lims),
+ assert_meg_snr(raw_sss_mv, read_crop(sss_movecomp_reg_in_st4s_fname, lims),
0.6, 1.3)
tSSS_fname = op.join(sss_path, 'test_move_anon_st4s_raw_sss.fif')
- assert_meg_snr(raw_sss_mv, Raw(tSSS_fname).crop(*lims),
+ assert_meg_snr(raw_sss_mv, read_crop(tSSS_fname, lims),
0.6, 1.0, chpi_med_tol=None)
- assert_meg_snr(Raw(sss_movecomp_reg_in_st4s_fname), Raw(tSSS_fname),
- 0.8, 1.0, chpi_med_tol=None)
+ assert_meg_snr(read_crop(sss_movecomp_reg_in_st4s_fname),
+ read_crop(tSSS_fname), 0.8, 1.0, chpi_med_tol=None)
#
# Movement compensation, regularization, tSSS at the beginning
#
raw_sss_mc = maxwell_filter(raw_nohpi, head_pos=head_pos, st_duration=4.,
origin=mf_head_origin)
- assert_meg_snr(raw_sss_mc, Raw(tSSS_fname).crop(*lims),
+ assert_meg_snr(raw_sss_mc, read_crop(tSSS_fname, lims),
0.6, 1.0, chpi_med_tol=None)
assert_meg_snr(raw_sss_mc, raw_sss_mv, 0.6, 1.4)
# some degenerate cases
- raw_erm = Raw(erm_fname, allow_maxshield='yes')
+ raw_erm = read_crop(erm_fname)
assert_raises(ValueError, maxwell_filter, raw_erm, coord_frame='meg',
head_pos=head_pos) # can't do ERM file
assert_raises(ValueError, maxwell_filter, raw,
@@ -171,7 +192,15 @@ def test_movement_compensation():
head_pos_bad = head_pos.copy()
head_pos_bad[0, 0] = raw.first_samp / raw.info['sfreq'] - 1e-2
assert_raises(ValueError, maxwell_filter, raw, head_pos=head_pos_bad)
+
+ head_pos_bad = head_pos.copy()
+ head_pos_bad[0, 4] = 1. # off by more than 1 m
+ with warnings.catch_warnings(record=True) as w:
+ maxwell_filter(raw, head_pos=head_pos_bad, bad_condition='ignore')
+ assert_true(any('greater than 1 m' in str(ww.message) for ww in w))
+
# make sure numerical error doesn't screw it up, though
+ head_pos_bad = head_pos.copy()
head_pos_bad[0, 0] = raw.first_samp / raw.info['sfreq'] - 5e-4
raw_sss_tweak = maxwell_filter(raw, head_pos=head_pos_bad,
origin=mf_head_origin)
@@ -180,8 +209,7 @@ def test_movement_compensation():
@slow_test
def test_other_systems():
- """Test Maxwell filtering on KIT, BTI, and CTF files
- """
+ """Test Maxwell filtering on KIT, BTI, and CTF files."""
# KIT
kit_dir = op.join(io_dir, 'kit', 'tests', 'data')
sqd_path = op.join(kit_dir, 'test.sqd')
@@ -193,6 +221,9 @@ def test_other_systems():
assert_raises(RuntimeError, maxwell_filter, raw_kit)
raw_sss = maxwell_filter(raw_kit, origin=(0., 0., 0.04), ignore_ref=True)
_assert_n_free(raw_sss, 65, 65)
+ raw_sss_auto = maxwell_filter(raw_kit, origin=(0., 0., 0.04),
+ ignore_ref=True, mag_scale='auto')
+ assert_allclose(raw_sss._data, raw_sss_auto._data)
# XXX this KIT origin fit is terrible! Eventually we should get a
# corrected HSP file with proper coverage
with warnings.catch_warnings(record=True):
@@ -232,22 +263,33 @@ def test_other_systems():
bti_hs = op.join(bti_dir, 'test_hs_linux')
with warnings.catch_warnings(record=True): # weight table
raw_bti = read_raw_bti(bti_pdf, bti_config, bti_hs, preload=False)
+ picks = pick_types(raw_bti.info, meg='mag', exclude=())
+ power = np.sqrt(np.sum(raw_bti[picks][0] ** 2))
raw_sss = maxwell_filter(raw_bti)
_assert_n_free(raw_sss, 70)
+ _assert_shielding(raw_sss, power, 0.5)
+ raw_sss_auto = maxwell_filter(raw_bti, mag_scale='auto', verbose=True)
+ _assert_shielding(raw_sss_auto, power, 0.7)
# CTF
- raw_ctf = Raw(fname_ctf_raw, compensation=2)
+ raw_ctf = read_crop(fname_ctf_raw)
+ assert_equal(raw_ctf.compensation_grade, 3)
assert_raises(RuntimeError, maxwell_filter, raw_ctf) # compensated
- raw_ctf = Raw(fname_ctf_raw)
+ raw_ctf.apply_gradient_compensation(0)
assert_raises(ValueError, maxwell_filter, raw_ctf) # cannot fit headshape
raw_sss = maxwell_filter(raw_ctf, origin=(0., 0., 0.04))
_assert_n_free(raw_sss, 68)
+ _assert_shielding(raw_sss, raw_ctf, 1.8)
raw_sss = maxwell_filter(raw_ctf, origin=(0., 0., 0.04), ignore_ref=True)
_assert_n_free(raw_sss, 70)
+ _assert_shielding(raw_sss, raw_ctf, 12)
+ raw_sss_auto = maxwell_filter(raw_ctf, origin=(0., 0., 0.04),
+ ignore_ref=True, mag_scale='auto')
+ assert_allclose(raw_sss._data, raw_sss_auto._data)
def test_spherical_harmonics():
- """Test spherical harmonic functions"""
+ """Test spherical harmonic functions."""
from scipy.special import sph_harm
az, pol = np.meshgrid(np.linspace(0, 2 * np.pi, 30),
np.linspace(0, np.pi, 20))
@@ -267,7 +309,7 @@ def test_spherical_harmonics():
def test_spherical_conversions():
- """Test spherical harmonic conversions"""
+ """Test spherical harmonic conversions."""
# Test our real<->complex conversion functions
az, pol = np.meshgrid(np.linspace(0, 2 * np.pi, 30),
np.linspace(0, np.pi, 20))
@@ -286,7 +328,7 @@ def test_spherical_conversions():
@testing.requires_testing_data
def test_multipolar_bases():
- """Test multipolar moment basis calculation using sensor information"""
+ """Test multipolar moment basis calculation using sensor information."""
from scipy.io import loadmat
# Test our basis calculations
info = read_info(raw_fname)
@@ -346,11 +388,11 @@ def test_multipolar_bases():
@testing.requires_testing_data
def test_basic():
- """Test Maxwell filter basic version"""
+ """Test Maxwell filter basic version."""
# Load testing data (raw, SSS std origin, SSS non-standard origin)
- raw = Raw(raw_fname, allow_maxshield='yes').crop(0., 1., copy=False)
- raw_err = Raw(raw_fname, proj=True, allow_maxshield='yes')
- raw_erm = Raw(erm_fname, allow_maxshield='yes')
+ raw = read_crop(raw_fname, (0., 1.))
+ raw_err = read_crop(raw_fname).apply_proj()
+ raw_erm = read_crop(erm_fname)
assert_raises(RuntimeError, maxwell_filter, raw_err)
assert_raises(TypeError, maxwell_filter, 1.) # not a raw
assert_raises(ValueError, maxwell_filter, raw, int_order=20) # too many
@@ -368,7 +410,7 @@ def test_basic():
bad_condition='ignore')
assert_equal(len(raw_sss.info['projs']), 1) # avg EEG
assert_equal(raw_sss.info['projs'][0]['desc'], 'Average EEG reference')
- assert_meg_snr(raw_sss, Raw(sss_std_fname), 200., 1000.)
+ assert_meg_snr(raw_sss, read_crop(sss_std_fname), 200., 1000.)
py_cal = raw_sss.info['proc_history'][0]['max_info']['sss_cal']
assert_equal(len(py_cal), 0)
py_ctc = raw_sss.info['proc_history'][0]['max_info']['sss_ctc']
@@ -380,10 +422,10 @@ def test_basic():
# Test SSS computation at non-standard head origin
raw_sss = maxwell_filter(raw, origin=[0., 0.02, 0.02], regularize=None,
bad_condition='ignore')
- assert_meg_snr(raw_sss, Raw(sss_nonstd_fname), 250., 700.)
+ assert_meg_snr(raw_sss, read_crop(sss_nonstd_fname), 250., 700.)
# Test SSS computation at device origin
- sss_erm_std = Raw(sss_erm_std_fname)
+ sss_erm_std = read_crop(sss_erm_std_fname)
raw_sss = maxwell_filter(raw_erm, coord_frame='meg',
origin=mf_meg_origin, regularize=None,
bad_condition='ignore')
@@ -399,18 +441,15 @@ def test_basic():
proc_history._get_sss_rank(sss_info))
# Degenerate cases
- raw_bad = raw.copy()
- raw_bad.comp = True
- assert_raises(RuntimeError, maxwell_filter, raw_bad)
- del raw_bad
assert_raises(ValueError, maxwell_filter, raw, coord_frame='foo')
assert_raises(ValueError, maxwell_filter, raw, origin='foo')
assert_raises(ValueError, maxwell_filter, raw, origin=[0] * 4)
+ assert_raises(ValueError, maxwell_filter, raw, mag_scale='foo')
@testing.requires_testing_data
def test_maxwell_filter_additional():
- """Test processing of Maxwell filtered data"""
+ """Test processing of Maxwell filtered data."""
# TODO: Future tests integrate with mne/io/tests/test_proc_history
@@ -422,7 +461,7 @@ def test_maxwell_filter_additional():
raw_fname = op.join(data_path, 'SSS', file_name + '_raw.fif')
# Use 2.0 seconds of data to get stable cov. estimate
- raw = Raw(raw_fname, allow_maxshield='yes').crop(0., 2., copy=False)
+ raw = read_crop(raw_fname, (0., 2.))
# Get MEG channels, compute Maxwell filtered data
raw.load_data()
@@ -435,7 +474,7 @@ def test_maxwell_filter_additional():
tempdir = _TempDir()
test_outname = op.join(tempdir, 'test_raw_sss.fif')
raw_sss.save(test_outname)
- raw_sss_loaded = Raw(test_outname, preload=True)
+ raw_sss_loaded = read_crop(test_outname).load_data()
# Some numerical imprecision since save uses 'single' fmt
assert_allclose(raw_sss_loaded[:][0], raw_sss[:][0],
@@ -457,20 +496,20 @@ def test_maxwell_filter_additional():
@slow_test
@testing.requires_testing_data
def test_bads_reconstruction():
- """Test Maxwell filter reconstruction of bad channels"""
- raw = Raw(raw_fname, allow_maxshield='yes').crop(0., 1.)
+ """Test Maxwell filter reconstruction of bad channels."""
+ raw = read_crop(raw_fname, (0., 1.))
raw.info['bads'] = bads
raw_sss = maxwell_filter(raw, origin=mf_head_origin, regularize=None,
bad_condition='ignore')
- assert_meg_snr(raw_sss, Raw(sss_bad_recon_fname), 300.)
+ assert_meg_snr(raw_sss, read_crop(sss_bad_recon_fname), 300.)
@requires_svd_convergence
@testing.requires_testing_data
def test_spatiotemporal_maxwell():
- """Test Maxwell filter (tSSS) spatiotemporal processing"""
+ """Test Maxwell filter (tSSS) spatiotemporal processing."""
# Load raw testing data
- raw = Raw(raw_fname, allow_maxshield='yes')
+ raw = read_crop(raw_fname)
# Test that window is less than length of data
assert_raises(ValueError, maxwell_filter, raw, st_duration=1000.)
@@ -483,7 +522,7 @@ def test_spatiotemporal_maxwell():
# Load tSSS data depending on st_duration and get data
tSSS_fname = op.join(sss_path,
'test_move_anon_st%0ds_raw_sss.fif' % st_duration)
- tsss_bench = Raw(tSSS_fname)
+ tsss_bench = read_crop(tSSS_fname)
# Because Elekta's tSSS sometimes(!) lumps the tail window of data
# onto the previous buffer if it's shorter than st_duration, we have to
# crop the data here to compensate for Elekta's tSSS behavior.
@@ -522,10 +561,9 @@ def test_spatiotemporal_maxwell():
@requires_svd_convergence
@testing.requires_testing_data
def test_spatiotemporal_only():
- """Test tSSS-only processing"""
+ """Test tSSS-only processing."""
# Load raw testing data
- raw = Raw(raw_fname,
- allow_maxshield='yes').crop(0, 2, copy=False).load_data()
+ raw = read_crop(raw_fname, (0, 2)).load_data()
picks = pick_types(raw.info, meg='mag', exclude=())
power = np.sqrt(np.sum(raw[picks][0] ** 2))
# basics
@@ -574,11 +612,11 @@ def test_spatiotemporal_only():
@testing.requires_testing_data
def test_fine_calibration():
- """Test Maxwell filter fine calibration"""
+ """Test Maxwell filter fine calibration."""
# Load testing data (raw, SSS std origin, SSS non-standard origin)
- raw = Raw(raw_fname, allow_maxshield='yes').crop(0., 1., copy=False)
- sss_fine_cal = Raw(sss_fine_cal_fname)
+ raw = read_crop(raw_fname, (0., 1.))
+ sss_fine_cal = read_crop(sss_fine_cal_fname)
# Test 1D SSS fine calibration
raw_sss = maxwell_filter(raw, calibration=fine_cal_fname,
@@ -601,7 +639,7 @@ def test_fine_calibration():
origin=mf_head_origin, regularize=None,
bad_condition='ignore')
assert_meg_snr(raw_sss_3D, sss_fine_cal, 1.0, 6.)
- raw_ctf = Raw(fname_ctf_raw)
+ raw_ctf = read_crop(fname_ctf_raw).apply_gradient_compensation(0)
assert_raises(RuntimeError, maxwell_filter, raw_ctf, origin=(0., 0., 0.04),
calibration=fine_cal_fname)
@@ -609,7 +647,7 @@ def test_fine_calibration():
@slow_test
@testing.requires_testing_data
def test_regularization():
- """Test Maxwell filter regularization"""
+ """Test Maxwell filter regularization."""
# Load testing data (raw, SSS std origin, SSS non-standard origin)
min_tols = (100., 2.6, 1.0)
med_tols = (1000., 21.4, 3.7)
@@ -620,8 +658,8 @@ def test_regularization():
sss_samp_reg_in_fname)
comp_tols = [0, 1, 4]
for ii, rf in enumerate(raw_fnames):
- raw = Raw(rf, allow_maxshield='yes').crop(0., 1.)
- sss_reg_in = Raw(sss_fnames[ii])
+ raw = read_crop(rf, (0., 1.))
+ sss_reg_in = read_crop(sss_fnames[ii])
# Test "in" regularization
raw_sss = maxwell_filter(raw, coord_frame=coord_frames[ii],
@@ -629,27 +667,32 @@ def test_regularization():
assert_meg_snr(raw_sss, sss_reg_in, min_tols[ii], med_tols[ii], msg=rf)
# check components match
- py_info = raw_sss.info['proc_history'][0]['max_info']['sss_info']
- assert_true(py_info is not None)
- assert_true(len(py_info) > 0)
- mf_info = sss_reg_in.info['proc_history'][0]['max_info']['sss_info']
- n_in = None
- for inf in py_info, mf_info:
- if n_in is None:
- n_in = _get_n_moments(inf['in_order'])
- else:
- assert_equal(n_in, _get_n_moments(inf['in_order']))
- assert_equal(inf['components'][:n_in].sum(), inf['nfree'])
- assert_allclose(py_info['nfree'], mf_info['nfree'],
- atol=comp_tols[ii], err_msg=rf)
+ _check_reg_match(raw_sss, sss_reg_in, comp_tols[ii])
+
+
+def _check_reg_match(sss_py, sss_mf, comp_tol):
+ """Helper to check regularization."""
+ info_py = sss_py.info['proc_history'][0]['max_info']['sss_info']
+ assert_true(info_py is not None)
+ assert_true(len(info_py) > 0)
+ info_mf = sss_mf.info['proc_history'][0]['max_info']['sss_info']
+ n_in = None
+ for inf in (info_py, info_mf):
+ if n_in is None:
+ n_in = _get_n_moments(inf['in_order'])
+ else:
+ assert_equal(n_in, _get_n_moments(inf['in_order']))
+ assert_equal(inf['components'][:n_in].sum(), inf['nfree'])
+ assert_allclose(info_py['nfree'], info_mf['nfree'],
+ atol=comp_tol, err_msg=sss_py._filenames[0])
@testing.requires_testing_data
def test_cross_talk():
- """Test Maxwell filter cross-talk cancellation"""
- raw = Raw(raw_fname, allow_maxshield='yes').crop(0., 1., copy=False)
+ """Test Maxwell filter cross-talk cancellation."""
+ raw = read_crop(raw_fname, (0., 1.))
raw.info['bads'] = bads
- sss_ctc = Raw(sss_ctc_fname)
+ sss_ctc = read_crop(sss_ctc_fname)
raw_sss = maxwell_filter(raw, cross_talk=ctc_fname,
origin=mf_head_origin, regularize=None,
bad_condition='ignore')
@@ -661,7 +704,7 @@ def test_cross_talk():
mf_ctc = sss_ctc.info['proc_history'][0]['max_info']['sss_ctc']
del mf_ctc['block_id'] # we don't write this
assert_equal(object_diff(py_ctc, mf_ctc), '')
- raw_ctf = Raw(fname_ctf_raw)
+ raw_ctf = read_crop(fname_ctf_raw).apply_gradient_compensation(0)
assert_raises(ValueError, maxwell_filter, raw_ctf) # cannot fit headshape
raw_sss = maxwell_filter(raw_ctf, origin=(0., 0., 0.04))
_assert_n_free(raw_sss, 68)
@@ -681,13 +724,13 @@ def test_cross_talk():
@testing.requires_testing_data
def test_head_translation():
- """Test Maxwell filter head translation"""
- raw = Raw(raw_fname, allow_maxshield='yes').crop(0., 1., copy=False)
+ """Test Maxwell filter head translation."""
+ raw = read_crop(raw_fname, (0., 1.))
# First try with an unchanged destination
raw_sss = maxwell_filter(raw, destination=raw_fname,
origin=mf_head_origin, regularize=None,
bad_condition='ignore')
- assert_meg_snr(raw_sss, Raw(sss_std_fname).crop(0., 1.), 200.)
+ assert_meg_snr(raw_sss, read_crop(sss_std_fname, (0., 1.)), 200.)
# Now with default
with warnings.catch_warnings(record=True):
with catch_logging() as log:
@@ -695,7 +738,7 @@ def test_head_translation():
origin=mf_head_origin, regularize=None,
bad_condition='ignore', verbose='warning')
assert_true('over 25 mm' in log.getvalue())
- assert_meg_snr(raw_sss, Raw(sss_trans_default_fname), 125.)
+ assert_meg_snr(raw_sss, read_crop(sss_trans_default_fname), 125.)
destination = np.eye(4)
destination[2, 3] = 0.04
assert_allclose(raw_sss.info['dev_head_t']['trans'], destination)
@@ -706,7 +749,7 @@ def test_head_translation():
origin=mf_head_origin, regularize=None,
bad_condition='ignore', verbose='warning')
assert_true('= 25.6 mm' in log.getvalue())
- assert_meg_snr(raw_sss, Raw(sss_trans_sample_fname), 350.)
+ assert_meg_snr(raw_sss, read_crop(sss_trans_sample_fname), 350.)
assert_allclose(raw_sss.info['dev_head_t']['trans'],
read_info(sample_fname)['dev_head_t']['trans'])
# Degenerate cases
@@ -720,8 +763,12 @@ def test_head_translation():
# http://ieeexplore.ieee.org/xpl/articleDetails.jsp?arnumber=1495874
def _assert_shielding(raw_sss, erm_power, shielding_factor, meg='mag'):
- """Helper to assert a minimum shielding factor using empty-room power"""
- picks = pick_types(raw_sss.info, meg=meg)
+ """Helper to assert a minimum shielding factor using empty-room power."""
+ picks = pick_types(raw_sss.info, meg=meg, ref_meg=False)
+ if isinstance(erm_power, _BaseRaw):
+ picks_erm = pick_types(erm_power.info, meg=meg, ref_meg=False)
+ assert_allclose(picks, picks_erm)
+ erm_power = np.sqrt((erm_power[picks_erm][0] ** 2).sum())
sss_power = raw_sss[picks][0].ravel()
sss_power = np.sqrt(np.sum(sss_power * sss_power))
factor = erm_power / sss_power
@@ -729,30 +776,48 @@ def _assert_shielding(raw_sss, erm_power, shielding_factor, meg='mag'):
'Shielding factor %0.3f < %0.3f' % (factor, shielding_factor))
+@buggy_mkl_svd
@slow_test
@requires_svd_convergence
@testing.requires_testing_data
def test_shielding_factor():
- """Test Maxwell filter shielding factor using empty room"""
- raw_erm = Raw(erm_fname, allow_maxshield='yes', preload=True)
+ """Test Maxwell filter shielding factor using empty room."""
+ raw_erm = read_crop(erm_fname).load_data()
picks = pick_types(raw_erm.info, meg='mag')
erm_power = raw_erm[picks][0]
erm_power = np.sqrt(np.sum(erm_power * erm_power))
+ erm_power_grad = raw_erm[pick_types(raw_erm.info, meg='grad')][0]
+ erm_power_grad = np.sqrt(np.sum(erm_power_grad * erm_power_grad))
# Vanilla SSS (second value would be for meg=True instead of meg='mag')
- _assert_shielding(Raw(sss_erm_std_fname), erm_power, 10) # 1.5)
+ _assert_shielding(read_crop(sss_erm_std_fname), erm_power, 10) # 1.5)
raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None)
_assert_shielding(raw_sss, erm_power, 12) # 1.5)
+ _assert_shielding(raw_sss, erm_power_grad, 0.45, 'grad') # 1.5)
+
+ # Using different mag_scale values
+ raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None,
+ mag_scale='auto')
+ _assert_shielding(raw_sss, erm_power, 12)
+ _assert_shielding(raw_sss, erm_power_grad, 0.48, 'grad')
+ raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None,
+ mag_scale=1.) # not a good choice
+ _assert_shielding(raw_sss, erm_power, 7.3)
+ _assert_shielding(raw_sss, erm_power_grad, 0.2, 'grad')
+ raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None,
+ mag_scale=1000., bad_condition='ignore')
+ _assert_shielding(raw_sss, erm_power, 4.0)
+ _assert_shielding(raw_sss, erm_power_grad, 0.1, 'grad')
# Fine cal
- _assert_shielding(Raw(sss_erm_fine_cal_fname), erm_power, 12) # 2.0)
+ _assert_shielding(read_crop(sss_erm_fine_cal_fname), erm_power, 12) # 2.0)
raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None,
origin=mf_meg_origin,
calibration=fine_cal_fname)
_assert_shielding(raw_sss, erm_power, 12) # 2.0)
# Crosstalk
- _assert_shielding(Raw(sss_erm_ctc_fname), erm_power, 12) # 2.1)
+ _assert_shielding(read_crop(sss_erm_ctc_fname), erm_power, 12) # 2.1)
raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None,
origin=mf_meg_origin,
cross_talk=ctc_fname)
@@ -766,7 +831,7 @@ def test_shielding_factor():
_assert_shielding(raw_sss, erm_power, 13) # 2.2)
# tSSS
- _assert_shielding(Raw(sss_erm_st_fname), erm_power, 37) # 5.8)
+ _assert_shielding(read_crop(sss_erm_st_fname), erm_power, 37) # 5.8)
raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None,
origin=mf_meg_origin, st_duration=1.)
_assert_shielding(raw_sss, erm_power, 37) # 5.8)
@@ -784,7 +849,7 @@ def test_shielding_factor():
_assert_shielding(raw_sss, erm_power, 38) # 5.98)
# Fine cal + Crosstalk + tSSS
- _assert_shielding(Raw(sss_erm_st1FineCalCrossTalk_fname),
+ _assert_shielding(read_crop(sss_erm_st1FineCalCrossTalk_fname),
erm_power, 39) # 6.07)
raw_sss = maxwell_filter(raw_erm, coord_frame='meg', regularize=None,
calibration=fine_cal_fname, origin=mf_meg_origin,
@@ -792,8 +857,8 @@ def test_shielding_factor():
_assert_shielding(raw_sss, erm_power, 39) # 6.05)
# Fine cal + Crosstalk + tSSS + Reg-in
- _assert_shielding(Raw(sss_erm_st1FineCalCrossTalkRegIn_fname), erm_power,
- 57) # 6.97)
+ _assert_shielding(read_crop(sss_erm_st1FineCalCrossTalkRegIn_fname),
+ erm_power, 57) # 6.97)
raw_sss = maxwell_filter(raw_erm, calibration=fine_cal_fname,
cross_talk=ctc_fname, st_duration=1.,
origin=mf_meg_origin,
@@ -803,6 +868,13 @@ def test_shielding_factor():
cross_talk=ctc_fname, st_duration=1.,
coord_frame='meg', regularize='in')
_assert_shielding(raw_sss, erm_power, 58) # 7.0)
+ _assert_shielding(raw_sss, erm_power_grad, 1.6, 'grad')
+ raw_sss = maxwell_filter(raw_erm, calibration=fine_cal_fname,
+ cross_talk=ctc_fname, st_duration=1.,
+ coord_frame='meg', regularize='in',
+ mag_scale='auto')
+ _assert_shielding(raw_sss, erm_power, 51)
+ _assert_shielding(raw_sss, erm_power_grad, 1.5, 'grad')
raw_sss = maxwell_filter(raw_erm, calibration=fine_cal_fname_3d,
cross_talk=ctc_fname, st_duration=1.,
coord_frame='meg', regularize='in')
@@ -827,7 +899,7 @@ def test_shielding_factor():
@requires_svd_convergence
@testing.requires_testing_data
def test_all():
- """Test maxwell filter using all options"""
+ """Test maxwell filter using all options."""
raw_fnames = (raw_fname, raw_fname, erm_fname, sample_fname)
sss_fnames = (sss_st1FineCalCrossTalkRegIn_fname,
sss_st1FineCalCrossTalkRegInTransSample_fname,
@@ -840,7 +912,7 @@ def test_all():
coord_frames = ('head', 'head', 'meg', 'head')
ctcs = (ctc_fname, ctc_fname, ctc_fname, ctc_mgh_fname)
mins = (3.5, 3.5, 1.2, 0.9)
- meds = (10.9, 10.4, 3.2, 6.)
+ meds = (10.8, 10.4, 3.2, 6.)
st_durs = (1., 1., 1., None)
destinations = (None, sample_fname, None, None)
origins = (mf_head_origin,
@@ -848,13 +920,55 @@ def test_all():
mf_meg_origin,
mf_head_origin)
for ii, rf in enumerate(raw_fnames):
- raw = Raw(rf, allow_maxshield='yes').crop(0., 1.)
+ raw = read_crop(rf, (0., 1.))
with warnings.catch_warnings(record=True): # head fit off-center
sss_py = maxwell_filter(
raw, calibration=fine_cals[ii], cross_talk=ctcs[ii],
st_duration=st_durs[ii], coord_frame=coord_frames[ii],
destination=destinations[ii], origin=origins[ii])
- sss_mf = Raw(sss_fnames[ii])
+ sss_mf = read_crop(sss_fnames[ii])
assert_meg_snr(sss_py, sss_mf, mins[ii], meds[ii], msg=rf)
+
+@slow_test
+@requires_svd_convergence
+@testing.requires_testing_data
+def test_triux():
+ """Test TRIUX system support."""
+ raw = read_crop(tri_fname, (0, 0.999))
+ raw.fix_mag_coil_types()
+ # standard
+ sss_py = maxwell_filter(raw, coord_frame='meg', regularize=None)
+ assert_meg_snr(sss_py, read_crop(tri_sss_fname), 37, 700)
+ # cross-talk
+ sss_py = maxwell_filter(raw, coord_frame='meg', regularize=None,
+ cross_talk=tri_ctc_fname)
+ assert_meg_snr(sss_py, read_crop(tri_sss_ctc_fname), 35, 700)
+ # fine cal
+ sss_py = maxwell_filter(raw, coord_frame='meg', regularize=None,
+ calibration=tri_cal_fname)
+ assert_meg_snr(sss_py, read_crop(tri_sss_cal_fname), 31, 360)
+ # ctc+cal
+ sss_py = maxwell_filter(raw, coord_frame='meg', regularize=None,
+ calibration=tri_cal_fname,
+ cross_talk=tri_ctc_fname)
+ assert_meg_snr(sss_py, read_crop(tri_sss_ctc_cal_fname), 31, 350)
+ # regularization
+ sss_py = maxwell_filter(raw, coord_frame='meg', regularize='in')
+ sss_mf = read_crop(tri_sss_reg_fname)
+ assert_meg_snr(sss_py, sss_mf, 0.6, 9)
+ _check_reg_match(sss_py, sss_mf, 1)
+ # all three
+ sss_py = maxwell_filter(raw, coord_frame='meg', regularize='in',
+ calibration=tri_cal_fname,
+ cross_talk=tri_ctc_fname)
+ sss_mf = read_crop(tri_sss_ctc_cal_reg_in_fname)
+ assert_meg_snr(sss_py, sss_mf, 0.6, 9)
+ _check_reg_match(sss_py, sss_mf, 1)
+ # tSSS
+ raw = read_crop(tri_fname).fix_mag_coil_types()
+ sss_py = maxwell_filter(raw, coord_frame='meg', regularize=None,
+ st_duration=4., verbose=True)
+ assert_meg_snr(sss_py, read_crop(tri_sss_st4_fname), 700., 1600)
+
run_tests_if_main()
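_assert_shielding above defines the shielding factor as the ratio of RMS power in the empty-room data to RMS power in the Maxwell-filtered result over matching channel picks. A standalone sketch with made-up arrays standing in for the recordings:

    import numpy as np

    def shielding_factor(erm_data, sss_data):
        """Empty-room RMS power divided by post-SSS RMS power."""
        return np.sqrt(np.sum(erm_data ** 2)) / np.sqrt(np.sum(sss_data ** 2))

    rng = np.random.RandomState(0)
    erm = rng.randn(100, 1000)         # hypothetical empty-room data
    sss = 0.1 * rng.randn(100, 1000)   # hypothetical Maxwell-filtered data
    print(shielding_factor(erm, sss))  # ~10x shielding in this toy case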
diff --git a/mne/preprocessing/tests/test_ssp.py b/mne/preprocessing/tests/test_ssp.py
index 8a6f3a7..0ec478c 100644
--- a/mne/preprocessing/tests/test_ssp.py
+++ b/mne/preprocessing/tests/test_ssp.py
@@ -5,7 +5,7 @@ from nose.tools import assert_true, assert_equal
from numpy.testing import assert_array_almost_equal
import numpy as np
-from mne.io import Raw
+from mne.io import read_raw_fif
from mne.io.proj import make_projector, activate_proj
from mne.preprocessing.ssp import compute_proj_ecg, compute_proj_eog
from mne.utils import run_tests_if_main
@@ -19,8 +19,8 @@ eog_times = np.array([0.5, 2.3, 3.6, 14.5])
def test_compute_proj_ecg():
- """Test computation of ECG SSP projectors"""
- raw = Raw(raw_fname).crop(0, 10, copy=False)
+ """Test computation of ECG SSP projectors."""
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False).crop(0, 10, copy=False)
raw.load_data()
for average in [False, True]:
# For speed, let's not filter here (must also not reject then)
@@ -29,7 +29,8 @@ def test_compute_proj_ecg():
average=average, avg_ref=True,
no_proj=True, l_freq=None,
h_freq=None, reject=None,
- tmax=dur_use, qrs_threshold=0.5)
+ tmax=dur_use, qrs_threshold=0.5,
+ filter_length=6000)
assert_true(len(projs) == 7)
# heart rate at least 0.5 Hz, but less than 3 Hz
assert_true(events.shape[0] > 0.5 * dur_use and
@@ -55,13 +56,13 @@ def test_compute_proj_ecg():
average=average, avg_ref=True,
no_proj=True, l_freq=None,
h_freq=None, tmax=dur_use)
- assert_equal(len(w), 1)
+ assert_true(len(w) >= 1)
assert_equal(projs, None)
def test_compute_proj_eog():
- """Test computation of EOG SSP projectors"""
- raw = Raw(raw_fname).crop(0, 10, copy=False)
+ """Test computation of EOG SSP projectors."""
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False).crop(0, 10, copy=False)
raw.load_data()
for average in [False, True]:
n_projs_init = len(raw.info['projs'])
@@ -69,7 +70,8 @@ def test_compute_proj_eog():
bads=['MEG 2443'], average=average,
avg_ref=True, no_proj=False,
l_freq=None, h_freq=None,
- reject=None, tmax=dur_use)
+ reject=None, tmax=dur_use,
+ filter_length=6000)
assert_true(len(projs) == (7 + n_projs_init))
assert_true(np.abs(events.shape[0] -
np.sum(np.less(eog_times, dur_use))) <= 1)
@@ -94,26 +96,26 @@ def test_compute_proj_eog():
avg_ref=True, no_proj=False,
l_freq=None, h_freq=None,
tmax=dur_use)
- assert_equal(len(w), 1)
+ assert_true(len(w) >= 1)
assert_equal(projs, None)
def test_compute_proj_parallel():
- """Test computation of ExG projectors using parallelization"""
- raw_0 = Raw(raw_fname).crop(0, 10, copy=False)
+ """Test computation of ExG projectors using parallelization."""
+ raw_0 = read_raw_fif(raw_fname, add_eeg_ref=False).crop(0, 10, copy=False)
raw_0.load_data()
raw = raw_0.copy()
projs, _ = compute_proj_eog(raw, n_mag=2, n_grad=2, n_eeg=2,
bads=['MEG 2443'], average=False,
avg_ref=True, no_proj=False, n_jobs=1,
l_freq=None, h_freq=None, reject=None,
- tmax=dur_use)
+ tmax=dur_use, filter_length=6000)
raw_2 = raw_0.copy()
projs_2, _ = compute_proj_eog(raw_2, n_mag=2, n_grad=2, n_eeg=2,
bads=['MEG 2443'], average=False,
avg_ref=True, no_proj=False, n_jobs=2,
l_freq=None, h_freq=None, reject=None,
- tmax=dur_use)
+ tmax=dur_use, filter_length=6000)
projs = activate_proj(projs)
projs_2 = activate_proj(projs_2)
projs, _, _ = make_projector(projs, raw_2.info['ch_names'],
diff --git a/mne/preprocessing/tests/test_stim.py b/mne/preprocessing/tests/test_stim.py
index eb290c4..0ce802e 100644
--- a/mne/preprocessing/tests/test_stim.py
+++ b/mne/preprocessing/tests/test_stim.py
@@ -8,7 +8,7 @@ import numpy as np
from numpy.testing import assert_array_almost_equal
from nose.tools import assert_true, assert_raises
-from mne.io import Raw
+from mne.io import read_raw_fif
from mne.io.pick import pick_types
from mne.event import read_events
from mne.epochs import Epochs
@@ -20,20 +20,20 @@ event_fname = op.join(data_path, 'test-eve.fif')
def test_fix_stim_artifact():
- """Test fix stim artifact"""
+ """Test fix stim artifact."""
events = read_events(event_fname)
- raw = Raw(raw_fname, preload=False)
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
assert_raises(RuntimeError, fix_stim_artifact, raw)
- raw = Raw(raw_fname, preload=True)
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False, preload=True)
# use window before stimulus in epochs
tmin, tmax, event_id = -0.2, 0.5, 1
picks = pick_types(raw.info, meg=True, eeg=True,
eog=True, stim=False, exclude='bads')
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- preload=True, reject=None)
+ preload=True, reject=None, add_eeg_ref=False)
e_start = int(np.ceil(epochs.info['sfreq'] * epochs.tmin))
tmin, tmax = -0.045, -0.015
tmin_samp = int(-0.035 * epochs.info['sfreq']) - e_start
@@ -72,7 +72,8 @@ def test_fix_stim_artifact():
# get epochs from raw with fixed data
tmin, tmax, event_id = -0.2, 0.5, 1
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- preload=True, reject=None, baseline=None)
+ preload=True, reject=None, baseline=None,
+ add_eeg_ref=False)
e_start = int(np.ceil(epochs.info['sfreq'] * epochs.tmin))
tmin_samp = int(-0.035 * epochs.info['sfreq']) - e_start
tmax_samp = int(-0.015 * epochs.info['sfreq']) - e_start
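The sample-index arithmetic in test_fix_stim_artifact converts window times into offsets within each epoch. A worked example assuming a made-up ~600 Hz sampling rate:

    import numpy as np

    sfreq = 600.17       # hypothetical sampling frequency
    tmin_epoch = -0.2    # epoch starts 200 ms before the trigger
    e_start = int(np.ceil(sfreq * tmin_epoch))  # first sample relative to t=0
    tmin_samp = int(-0.035 * sfreq) - e_start   # window start within the epoch
    tmax_samp = int(-0.015 * sfreq) - e_start   # window end within the epoch
    print(e_start, tmin_samp, tmax_samp)        # -120 99 111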
diff --git a/mne/preprocessing/tests/test_xdawn.py b/mne/preprocessing/tests/test_xdawn.py
index 6f46134..458a9b2 100644
--- a/mne/preprocessing/tests/test_xdawn.py
+++ b/mne/preprocessing/tests/test_xdawn.py
@@ -1,28 +1,29 @@
# Authors: Alexandre Barachant <alexandre.barachant at gmail.com>
+# Jean-Remi King <jeanremi.king at gmail.com>
#
# License: BSD (3-clause)
import numpy as np
import os.path as op
-from nose.tools import (assert_equal, assert_raises)
-from numpy.testing import assert_array_equal
-from mne import (io, Epochs, read_events, pick_types,
- compute_raw_covariance)
+from nose.tools import assert_equal, assert_raises, assert_true
+from numpy.testing import assert_array_equal, assert_array_almost_equal
+from mne import Epochs, read_events, pick_types, compute_raw_covariance
+from mne.io import read_raw_fif
from mne.utils import requires_sklearn, run_tests_if_main
-from mne.preprocessing.xdawn import Xdawn
+from mne.preprocessing.xdawn import Xdawn, _XdawnTransformer
base_dir = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data')
raw_fname = op.join(base_dir, 'test_raw.fif')
event_name = op.join(base_dir, 'test-eve.fif')
-evoked_nf_name = op.join(base_dir, 'test-nf-ave.fif')
tmin, tmax = -0.1, 0.2
event_id = dict(cond2=2, cond3=3)
def _get_data():
- raw = io.read_raw_fif(raw_fname, add_eeg_ref=False, verbose=False,
- preload=True)
+ """Get data."""
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False, verbose=False,
+ preload=True)
events = read_events(event_name)
picks = pick_types(raw.info, meg=False, eeg=True, stim=False,
ecg=False, eog=False,
@@ -30,56 +31,58 @@ def _get_data():
return raw, events, picks
-def test_xdawn_init():
+def test_xdawn():
"""Test init of xdawn."""
- # init xdawn with good parameters
+ # Init xdawn with good parameters
Xdawn(n_components=2, correct_overlap='auto', signal_cov=None, reg=None)
- # init xdawn with bad parameters
+ # Init xdawn with bad parameters
assert_raises(ValueError, Xdawn, correct_overlap=42)
def test_xdawn_fit():
"""Test Xdawn fit."""
- # get data
+ # Get data
raw, events, picks = _get_data()
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- preload=True, baseline=None, verbose=False)
+ preload=True, baseline=None, verbose=False,
+ add_eeg_ref=False)
# =========== Basic Fit test =================
- # test base xdawn
- xd = Xdawn(n_components=2, correct_overlap='auto',
- signal_cov=None, reg=None)
+ # Test base xdawn
+ xd = Xdawn(n_components=2, correct_overlap='auto')
xd.fit(epochs)
- # with this parameters, the overlapp correction must be False
- assert_equal(xd.correct_overlap, False)
- # no overlapp correction should give averaged evoked
+ # With these parameters, the overlap correction must be False
+ assert_equal(xd.correct_overlap_, False)
+ # No overlap correction should give averaged evoked
evoked = epochs['cond2'].average()
assert_array_equal(evoked.data, xd.evokeds_['cond2'].data)
# ========== with signal cov provided ====================
- # provide covariance object
+ # Provide covariance object
signal_cov = compute_raw_covariance(raw, picks=picks)
xd = Xdawn(n_components=2, correct_overlap=False,
- signal_cov=signal_cov, reg=None)
+ signal_cov=signal_cov)
xd.fit(epochs)
- # provide ndarray
+ # Provide ndarray
signal_cov = np.eye(len(picks))
xd = Xdawn(n_components=2, correct_overlap=False,
- signal_cov=signal_cov, reg=None)
+ signal_cov=signal_cov)
xd.fit(epochs)
- # provide ndarray of bad shape
+ # Provide ndarray of bad shape
signal_cov = np.eye(len(picks) - 1)
xd = Xdawn(n_components=2, correct_overlap=False,
- signal_cov=signal_cov, reg=None)
+ signal_cov=signal_cov)
assert_raises(ValueError, xd.fit, epochs)
- # provide another type
+ # Provide another type
signal_cov = 42
xd = Xdawn(n_components=2, correct_overlap=False,
- signal_cov=signal_cov, reg=None)
+ signal_cov=signal_cov)
assert_raises(ValueError, xd.fit, epochs)
- # fit with baseline correction and ovverlapp correction should throw an
+ # Fit with baseline correction and overlap correction should throw an
# error
+ # XXX This is a buggy test; the epochs here don't overlap
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- preload=True, baseline=(None, 0), verbose=False)
+ preload=True, baseline=(None, 0), verbose=False,
+ add_eeg_ref=False)
xd = Xdawn(n_components=2, correct_overlap=True)
assert_raises(ValueError, xd.fit, epochs)
@@ -87,60 +90,135 @@ def test_xdawn_fit():
def test_xdawn_apply_transform():
"""Test Xdawn apply and transform."""
- # get data
+ # Get data
raw, events, picks = _get_data()
- epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- preload=True, baseline=None, verbose=False)
+ raw.pick_types(eeg=True, meg=False)
+ epochs = Epochs(raw, events, event_id, tmin, tmax, proj=False,
+ add_eeg_ref=False, preload=True, baseline=None,
+ verbose=False)
n_components = 2
# Fit Xdawn
- xd = Xdawn(n_components=n_components, correct_overlap='auto')
+ xd = Xdawn(n_components=n_components, correct_overlap=False)
xd.fit(epochs)
- # apply on raw
- xd.apply(raw)
- # apply on epochs
- xd.apply(epochs)
- # apply on evoked
- xd.apply(epochs.average())
- # apply on other thing should raise an error
+ # Apply on different types of instances
+ for inst in [raw, epochs.average(), epochs]:
+ denoise = xd.apply(inst)
+ # Apply on other thing should raise an error
assert_raises(ValueError, xd.apply, 42)
- # transform on epochs
+ # Transform on epochs
xd.transform(epochs)
- # transform on ndarray
+ # Transform on ndarray
xd.transform(epochs._data)
- # transform on someting else
+ # Transform on something else
assert_raises(ValueError, xd.transform, 42)
+ # Check numerical results with shuffled epochs
+ np.random.seed(0) # random makes unstable linalg
+ idx = np.arange(len(epochs))
+ np.random.shuffle(idx)
+ xd.fit(epochs[idx])
+ denoise_shfl = xd.apply(epochs)
+ assert_array_almost_equal(denoise['cond2']._data,
+ denoise_shfl['cond2']._data)
+
@requires_sklearn
def test_xdawn_regularization():
"""Test Xdawn with regularization."""
- # get data
+ # Get data
raw, events, picks = _get_data()
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- preload=True, baseline=None, verbose=False)
-
- # test xdawn with overlap correction
- xd = Xdawn(n_components=2, correct_overlap=True,
- signal_cov=None, reg=0.1)
- xd.fit(epochs)
- # ========== with cov regularization ====================
- # ledoit-wolf
- xd = Xdawn(n_components=2, correct_overlap=False,
- signal_cov=np.eye(len(picks)), reg='ledoit_wolf')
- xd.fit(epochs)
- # oas
- xd = Xdawn(n_components=2, correct_overlap=False,
- signal_cov=np.eye(len(picks)), reg='oas')
- xd.fit(epochs)
- # with shrinkage
- xd = Xdawn(n_components=2, correct_overlap=False,
- signal_cov=np.eye(len(picks)), reg=0.1)
+ preload=True, baseline=None, verbose=False,
+ add_eeg_ref=False)
+
+ # Test with overlapping events.
+ # modify events to simulate one overlap
+ events = epochs.events
+ sel = np.where(events[:, 2] == 2)[0][:2]
+ modified_event = events[sel[0]]
+ modified_event[0] += 1
+ epochs.events[sel[1]] = modified_event
+ # Fit and check that overlap was found and applied
+ xd = Xdawn(n_components=2, correct_overlap='auto', reg='oas')
xd.fit(epochs)
- # with bad shrinkage
+ assert_equal(xd.correct_overlap_, True)
+ evoked = epochs['cond2'].average()
+ assert_true(np.sum(np.abs(evoked.data - xd.evokeds_['cond2'].data)))
+
+ # With covariance regularization
+ for reg in [.1, 0.1, 'ledoit_wolf', 'oas']:
+ xd = Xdawn(n_components=2, correct_overlap=False,
+ signal_cov=np.eye(len(picks)), reg=reg)
+ xd.fit(epochs)
+ # With bad shrinkage
xd = Xdawn(n_components=2, correct_overlap=False,
signal_cov=np.eye(len(picks)), reg=2)
assert_raises(ValueError, xd.fit, epochs)
+
+@requires_sklearn
+def test_XdawnTransformer():
+ """Test _XdawnTransformer."""
+ # Get data
+ raw, events, picks = _get_data()
+ epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
+ preload=True, baseline=None, verbose=False,
+ add_eeg_ref=False)
+ X = epochs._data
+ y = epochs.events[:, -1]
+ # Fit
+ xdt = _XdawnTransformer()
+ xdt.fit(X, y)
+ assert_raises(ValueError, xdt.fit, X, y[1:])
+ assert_raises(ValueError, xdt.fit, 'foo')
+
+ # Provide covariance object
+ signal_cov = compute_raw_covariance(raw, picks=picks)
+ xdt = _XdawnTransformer(signal_cov=signal_cov)
+ xdt.fit(X, y)
+ # Provide ndarray
+ signal_cov = np.eye(len(picks))
+ xdt = _XdawnTransformer(signal_cov=signal_cov)
+ xdt.fit(X, y)
+ # Provide ndarray of bad shape
+ signal_cov = np.eye(len(picks) - 1)
+ xdt = _XdawnTransformer(signal_cov=signal_cov)
+ assert_raises(ValueError, xdt.fit, X, y)
+ # Provide another type
+ signal_cov = 42
+ xdt = _XdawnTransformer(signal_cov=signal_cov)
+ assert_raises(ValueError, xdt.fit, X, y)
+
+ # Fit with y as None
+ xdt = _XdawnTransformer()
+ xdt.fit(X)
+
+ # Compare xdawn and _XdawnTransformer
+ xd = Xdawn(correct_overlap=False)
+ xd.fit(epochs)
+
+ xdt = _XdawnTransformer()
+ xdt.fit(X, y)
+ assert_array_almost_equal(xd.filters_['cond2'][:, :2],
+ xdt.filters_.reshape(2, 2, 8)[0].T)
+
+ # Transform testing
+ xdt.transform(X[1:, ...]) # different number of epochs
+ xdt.transform(X[:, :, 1:]) # different number of time
+ assert_raises(ValueError, xdt.transform, X[:, 1:, :])
+ Xt = xdt.transform(X)
+ assert_raises(ValueError, xdt.transform, 42)
+
+ # Inverse transform testing
+ Xinv = xdt.inverse_transform(Xt)
+ assert_equal(Xinv.shape, X.shape)
+ xdt.inverse_transform(Xt[1:, ...])
+ xdt.inverse_transform(Xt[:, :, 1:])
+ # should raise an error if not correct number of components
+ assert_raises(ValueError, xdt.inverse_transform, Xt[:, 1:, :])
+ assert_raises(ValueError, xdt.inverse_transform, 42)
+
+
run_tests_if_main()
diff --git a/mne/preprocessing/xdawn.py b/mne/preprocessing/xdawn.py
index 2e7ad00..eaf9d38 100644
--- a/mne/preprocessing/xdawn.py
+++ b/mne/preprocessing/xdawn.py
@@ -1,160 +1,362 @@
-"""Xdawn implementation."""
# Authors: Alexandre Barachant <alexandre.barachant at gmail.com>
+# Asish Panda <asishrocks95 at gmail.com>
+# Jean-Remi King <jeanremi.king at gmail.com>
#
# License: BSD (3-clause)
-import copy as cp
-
import numpy as np
+import copy as cp
from scipy import linalg
-
-from ..io.base import _BaseRaw
-from ..epochs import _BaseEpochs
-from .. import Covariance, EvokedArray, Evoked, EpochsArray
-from ..io.pick import pick_types
from .ica import _get_fast_dot
+from .. import EvokedArray, Evoked
+from ..cov import Covariance, _regularized_covariance
+from ..decoding import TransformerMixin, BaseEstimator
+from ..epochs import _BaseEpochs, EpochsArray
+from ..io import _BaseRaw
+from ..io.pick import _pick_data_channels
from ..utils import logger
-from ..decoding.mixin import TransformerMixin
-from ..cov import _regularized_covariance
-from ..channels.channels import ContainsMixin
+from ..externals.six import iteritems, itervalues
-def _least_square_evoked(data, events, event_id, tmin, tmax, sfreq):
- """Least square estimation of evoked response from data.
+def _construct_signal_from_epochs(epochs, events, sfreq, tmin):
+ """Reconstruct pseudo continuous signal from epochs."""
+ n_epochs, n_channels, n_times = epochs.shape
+ tmax = tmin + n_times / float(sfreq)
+ start = (np.min(events[:, 0]) + int(tmin * sfreq))
+ stop = (np.max(events[:, 0]) + int(tmax * sfreq) + 1)
+
+ n_samples = stop - start
+ events_pos = events[:, 0] - events[0, 0]
+
+ raw = np.zeros((n_channels, n_samples))
+ for idx in range(n_epochs):
+ onset = events_pos[idx]
+ offset = onset + n_times
+ raw[:, onset:offset] = epochs[idx]
+
+ return raw
+
+
+def _least_square_evoked(epochs_data, events, tmin, sfreq):
+ """Least square estimation of evoked response from epochs data.
Parameters
----------
- data : ndarray, shape (n_channels, n_times)
- The data to estimates evoked
- events : ndarray, shape (n_events, 3)
+ epochs_data : array, shape (n_epochs, n_channels, n_times)
+ The epochs data used to estimate the evoked responses.
+ events : array, shape (n_events, 3)
The events typically returned by the read_events function.
If some events don't match the events of interest as specified
by event_id, they will be ignored.
- event_id : dict
- The id of the events to consider
tmin : float
Start time before event.
- tmax : float
- End time after event.
sfreq : float
Sampling frequency.
Returns
-------
- evokeds_data : dict of ndarray
- A dict of evoked data for each event type in event_id.
- toeplitz : dict of ndarray
- A dict of toeplitz matrix for each event type in event_id.
+ evokeds : array, shape (n_class, n_channels, n_times)
+ A concatenated array of evoked data for each event type.
+ toeplitz : array, shape (n_class, n_times, n_samples)
+ A concatenated array of Toeplitz matrices, one for each event type.
"""
- nmin = int(tmin * sfreq)
- nmax = int(tmax * sfreq)
-
- window = nmax - nmin
- n_samples = data.shape[1]
- toeplitz_mat = dict()
- full_toep = list()
- for eid in event_id:
+
+ n_epochs, n_channels, n_times = epochs_data.shape
+ tmax = tmin + n_times / float(sfreq)
+
+ # Deal with shuffled epochs
+ events = events.copy()
+ events[:, 0] -= events[0, 0] + int(tmin * sfreq)
+
+ # Construct raw signal
+ raw = _construct_signal_from_epochs(epochs_data, events, sfreq, tmin)
+
+ # Compute the independent evoked responses per condition, while correcting
+ # for event overlaps.
+ n_min, n_max = int(tmin * sfreq), int(tmax * sfreq)
+ window = n_max - n_min
+ n_samples = raw.shape[1]
+ toeplitz = list()
+ classes = np.unique(events[:, 2])
+ for ii, this_class in enumerate(classes):
# select events by type
- ix_ev = events[:, -1] == event_id[eid]
+ sel = events[:, 2] == this_class
# build toeplitz matrix
trig = np.zeros((n_samples, 1))
- ix_trig = (events[ix_ev, 0]) + nmin
+ ix_trig = (events[sel, 0]) + n_min
trig[ix_trig] = 1
- toep_mat = linalg.toeplitz(trig[0:window], trig)
- toeplitz_mat[eid] = toep_mat
- full_toep.append(toep_mat)
+ toeplitz.append(linalg.toeplitz(trig[0:window], trig))
# Concatenate toeplitz
- full_toep = np.concatenate(full_toep)
+ toeplitz = np.array(toeplitz)
+ X = np.concatenate(toeplitz)
# least square estimation
- predictor = np.dot(linalg.pinv(np.dot(full_toep, full_toep.T)), full_toep)
- all_evokeds = np.dot(predictor, data.T)
- all_evokeds = np.vsplit(all_evokeds, len(event_id))
+ predictor = np.dot(linalg.pinv(np.dot(X, X.T)), X)
+ evokeds = np.dot(predictor, raw.T)
+ evokeds = np.transpose(np.vsplit(evokeds, len(classes)), (0, 2, 1))
+ return evokeds, toeplitz
- # parse evoked response
- evoked_data = dict()
- for idx, eid in enumerate(event_id):
- evoked_data[eid] = all_evokeds[idx].T
- return evoked_data, toeplitz_mat
+def _fit_xdawn(epochs_data, y, n_components, reg=None, signal_cov=None,
+ events=None, tmin=0., sfreq=1.):
+ """Fit filters and coefs using Xdawn Algorithm.
+ Xdawn is a spatial filtering method designed to improve the signal
+ to signal + noise ratio (SSNR) of the event related responses. Xdawn was
+ originally designed for P300 evoked potential by enhancing the target
+ response with respect to the non-target response. This implementation is a
+ generalization to any type of event related response.
-def _check_overlapp(epochs):
- """check if events are overlapped."""
- isi = np.diff(epochs.events[:, 0])
- window = int((epochs.tmax - epochs.tmin) * epochs.info['sfreq'])
- # Events are overlapped if the minimal inter-stimulus interval is smaller
- # than the time window.
- return isi.min() < window
+ Parameters
+ ----------
+ epochs_data : array, shape (n_epochs, n_channels, n_times)
+ The epochs data.
+ y : array, shape (n_epochs)
+ The epochs class.
+ n_components : int (default 2)
+ The number of components to decompose the signals.
+ reg : float | str | None (default None)
+ If not None, allow regularization for covariance estimation.
+ If float, shrinkage covariance is used (0 <= shrinkage <= 1).
+ If str, optimal shrinkage using Ledoit-Wolf Shrinkage ('ledoit_wolf')
+ or Oracle Approximating Shrinkage ('oas') is used.
+ signal_cov : None | Covariance | array, shape (n_channels, n_channels)
+ The signal covariance used for whitening of the data.
+ If None, the covariance is estimated from the epochs signal.
+ events : array, shape (n_epochs, 3)
+ The epochs events, used to correct for epochs overlap.
+ tmin : float
+ Epochs starting time. Only used if events is passed to correct for
+ epochs overlap.
+ sfreq : float
+ Sampling frequency. Only used if events is passed to correct for
+ epochs overlap.
+ Returns
+ -------
+ filters : array, shape (n_class * n_components, n_channels)
+ The Xdawn components used to decompose the data for each event type.
+ patterns : array, shape (n_class * n_components, n_channels)
+ The Xdawn patterns used to restore the signals for each event type.
+ evokeds : array, shape (n_class, n_channels, n_times)
+ The independent evoked responses per condition.
-def _construct_signal_from_epochs(epochs):
- """Reconstruct pseudo continuous signal from epochs."""
- start = (np.min(epochs.events[:, 0]) +
- int(epochs.tmin * epochs.info['sfreq']))
- stop = (np.max(epochs.events[:, 0]) +
- int(epochs.tmax * epochs.info['sfreq']) + 1)
+ References
+ ----------
+ [1] Rivet, B., Souloumiac, A., Attina, V., & Gibert, G. (2009). xDAWN
+ algorithm to enhance evoked potentials: application to brain-computer
+ interface. Biomedical Engineering, IEEE Transactions on, 56(8), 2035-2043.
+ [2] Rivet, B., Cecotti, H., Souloumiac, A., Maby, E., & Mattout, J. (2011,
+ August). Theoretical analysis of xDAWN algorithm: application to an
+ efficient sensor selection in a P300 BCI. In Signal Processing Conference,
+ 2011 19th European (pp. 1382-1386). IEEE.
- n_samples = stop - start
- epochs_data = epochs.get_data()
- n_epochs, n_channels, n_times = epochs_data.shape
- events_pos = epochs.events[:, 0] - epochs.events[0, 0]
- data = np.zeros((n_channels, n_samples))
- for idx in range(n_epochs):
- onset = events_pos[idx]
- offset = onset + n_times
- data[:, onset:offset] = epochs_data[idx]
+ See Also
+ --------
+ CSP
+ Xdawn
+ """
+ n_epochs, n_channels, n_times = epochs_data.shape
- return data
+ classes = np.unique(y)
+
+ # Retrieve or compute whitening covariance
+ if signal_cov is None:
+ signal_cov = _regularized_covariance(np.hstack(epochs_data), reg)
+ elif isinstance(signal_cov, Covariance):
+ signal_cov = signal_cov.data
+ if not isinstance(signal_cov, np.ndarray) or (
+ not np.array_equal(signal_cov.shape,
+ np.tile(epochs_data.shape[1], 2))):
+ raise ValueError('signal_cov must be None, a covariance instance, '
+ 'or an array of shape (n_chans, n_chans)')
+
+ # Get prototype events
+ if events is not None:
+ evokeds, toeplitzs = _least_square_evoked(
+ epochs_data, events, tmin, sfreq)
+ else:
+ evokeds, toeplitzs = list(), list()
+ for c in classes:
+ # Prototyped response for each class
+ evokeds.append(np.mean(epochs_data[y == c, :, :], axis=0))
+ toeplitzs.append(1.)
+
+ filters = list()
+ patterns = list()
+ for evo, toeplitz in zip(evokeds, toeplitzs):
+ # Estimate covariance matrix of the prototype response
+ evo = np.dot(evo, toeplitz)
+ evo_cov = np.matrix(_regularized_covariance(evo, reg))
+
+ # Fit spatial filters
+ evals, evecs = linalg.eigh(evo_cov, signal_cov)
+ evecs = evecs[:, np.argsort(evals)[::-1]] # sort eigenvectors
+ evecs /= np.apply_along_axis(np.linalg.norm, 0, evecs)
+ _patterns = np.linalg.pinv(evecs.T)
+ filters.append(evecs[:, :n_components].T)
+ patterns.append(_patterns[:, :n_components].T)
+
+ filters = np.concatenate(filters, axis=0)
+ patterns = np.concatenate(patterns, axis=0)
+ evokeds = np.array(evokeds)
+ return filters, patterns, evokeds
+
+
+class _XdawnTransformer(BaseEstimator, TransformerMixin):
+ """Implementation of the Xdawn Algorithm compatible with scikit-learn.
+ Xdawn is a spatial filtering method designed to improve the signal
+ to signal + noise ratio (SSNR) of the event related responses. Xdawn was
+ originally designed for P300 evoked potential by enhancing the target
+ response with respect to the non-target response. This implementation is a
+ generalization to any type of event related response.
-def least_square_evoked(epochs, return_toeplitz=False):
- """Least square estimation of evoked response from a Epochs instance.
+ .. note:: _XdawnTransformer does not correct for epochs overlap. To
+ correct for overlaps, see ``Xdawn``.
Parameters
----------
- epochs : Epochs instance
- An instance of Epochs.
- return_toeplitz : bool (default False)
- If true, compute the toeplitz matrix.
+ n_components : int (default 2)
+ The number of components to decompose the signals.
+ reg : float | str | None (default None)
+ If not None, allow regularization for covariance estimation.
+ If float, shrinkage covariance is used (0 <= shrinkage <= 1).
+ If str, optimal shrinkage using Ledoit-Wolf Shrinkage ('ledoit_wolf')
+ or Oracle Approximating Shrinkage ('oas') is used.
+ signal_cov : None | Covariance | array, shape (n_channels, n_channels)
+ The signal covariance used for whitening of the data.
+ If None, the covariance is estimated from the epochs signal.
- Returns
- -------
- evokeds : dict of evoked instance
- An dict of evoked instance for each event type in epochs.event_id.
- toeplitz : dict of ndarray
- If return_toeplitz is true, return the toeplitz matrix for each event
- type in epochs.event_id.
+ Attributes
+ ----------
+ classes_ : array, shape (n_classes)
+ The event indices of the classes.
+ filters_ : array, shape (n_classes * n_components, n_channels)
+ The Xdawn components used to decompose the data for each event type.
+ patterns_ : array, shape (n_classes * n_components, n_channels)
+ The Xdawn patterns used to restore the signals for each event type.
+
+ References
+ ----------
+ [1] Rivet, B., Souloumiac, A., Attina, V., & Gibert, G. (2009). xDAWN
+ algorithm to enhance evoked potentials: application to brain-computer
+ interface. Biomedical Engineering, IEEE Transactions on, 56(8), 2035-2043.
+ [2] Rivet, B., Cecotti, H., Souloumiac, A., Maby, E., & Mattout, J. (2011,
+ August). Theoretical analysis of xDAWN algorithm: application to an
+ efficient sensor selection in a P300 BCI. In Signal Processing Conference,
+ 2011 19th European (pp. 1382-1386). IEEE.
+
+ See Also
+ --------
+ Xdawn
+ CSP
"""
- if not isinstance(epochs, _BaseEpochs):
- raise ValueError('epochs must be an instance of `mne.Epochs`')
- events = epochs.events.copy()
- events[:, 0] -= events[0, 0] + int(epochs.tmin * epochs.info['sfreq'])
- data = _construct_signal_from_epochs(epochs)
- evoked_data, toeplitz = _least_square_evoked(data, events, epochs.event_id,
- tmin=epochs.tmin,
- tmax=epochs.tmax,
- sfreq=epochs.info['sfreq'])
- evokeds = dict()
- info = cp.deepcopy(epochs.info)
- for name, data in evoked_data.items():
- n_events = len(events[events[:, 2] == epochs.event_id[name]])
- evoked = EvokedArray(data, info, tmin=epochs.tmin,
- comment=name, nave=n_events)
- evokeds[name] = evoked
+ def __init__(self, n_components=2, reg=None, signal_cov=None):
+ """Init."""
+ self.n_components = n_components
+ self.signal_cov = signal_cov
+ self.reg = reg
- if return_toeplitz:
- return evokeds, toeplitz
+ def fit(self, X, y=None):
+ """Fit Xdawn spatial filters.
- return evokeds
+ Parameters
+ ----------
+ X : array, shape (n_epochs, n_channels, n_samples)
+ The target data.
+ y : array, shape (n_epochs,) | None
+ The target labels. If None, Xdawn fit on the average evoked.
+ Returns
+ -------
+ self : Xdawn instance
+ The Xdawn instance.
+ """
+ X, y = self._check_Xy(X, y)
+
+ # Main function
+ self.classes_ = np.unique(y)
+ self.filters_, self.patterns_, _ = _fit_xdawn(
+ X, y, n_components=self.n_components, reg=self.reg,
+ signal_cov=self.signal_cov)
+ return self
+
+ def transform(self, X):
+ """Transform data with spatial filters.
+
+ Parameters
+ ----------
+ X : array, shape (n_epochs, n_channels, n_samples)
+ The target data.
+
+ Returns
+ -------
+ X : array, shape (n_epochs, n_components * n_classes, n_samples)
+ The transformed data.
+ """
+ X, _ = self._check_Xy(X)
+
+ # Check size
+ if self.filters_.shape[1] != X.shape[1]:
+ raise ValueError('X must have %i channels, got %i instead.' % (
+ self.filters_.shape[1], X.shape[1]))
-class Xdawn(TransformerMixin, ContainsMixin):
+ # Transform
+ X = np.dot(self.filters_, X)
+ X = X.transpose((1, 0, 2))
+ return X
+
+ def inverse_transform(self, X):
+ """Remove selected components from the signal.
+ Given the unmixing matrix, transform data, zero out components,
+ and inverse transform the data. This procedure will reconstruct
+ the signals from which the dynamics described by the excluded
+ components are subtracted.
+
+ Parameters
+ ----------
+ X : array, shape (n_epochs, n_components * n_classes, n_times)
+ The transformed data.
+
+ Returns
+ -------
+ X : array, shape (n_epochs, n_channels, n_times)
+ The inverse transformed data.
+ """
+ # Check size
+ X, _ = self._check_Xy(X)
+ n_components, n_channels = self.patterns_.shape
+ n_epochs, n_comp, n_times = X.shape
+ if n_comp != (self.n_components * len(self.classes_)):
+ raise ValueError('X must have %i components, got %i instead' % (
+ self.n_components * len(self.classes_), n_comp))
+
+ # Transform
+ fast_dot = _get_fast_dot()
+ return fast_dot(self.patterns_.T, X).transpose(1, 0, 2)
+
+ def _check_Xy(self, X, y=None):
+ """Check X and y types and dimensions."""
+ # Check data
+ if not isinstance(X, np.ndarray) or X.ndim != 3:
+ raise ValueError('X must be an array of shape (n_epochs, '
+ 'n_channels, n_samples).')
+ if y is None:
+ y = np.ones(len(X))
+ y = np.asarray(y)
+ if len(X) != len(y):
+ raise ValueError('X and y must have the same length')
+ return X, y
+
+
+class Xdawn(_XdawnTransformer):
"""Implementation of the Xdawn Algorithm.
Xdawn is a spatial filtering method designed to improve the signal
@@ -166,14 +368,14 @@ class Xdawn(TransformerMixin, ContainsMixin):
Parameters
----------
n_components : int (default 2)
- The number of components to decompose M/EEG signals.
+ The number of components to decompose the signals.
signal_cov : None | Covariance | ndarray, shape (n_channels, n_channels)
(default None). The signal covariance used for whitening of the data.
if None, the covariance is estimated from the epochs signal.
correct_overlap : 'auto' or bool (default 'auto')
- Apply correction for overlaped ERP for the estimation of evokeds
- responses. if 'auto', the overlapp correction is chosen in function
- of the events in epochs.events.
+ Compute the independent evoked responses per condition, while
+ correcting for event overlaps if any. If 'auto', overlap
+ correction is applied only if the events do overlap.
reg : float | str | None (default None)
if not None, allow regularization for covariance estimation
if float, shrinkage covariance is used (0 <= shrinkage <= 1).
@@ -182,14 +384,18 @@ class Xdawn(TransformerMixin, ContainsMixin):
Attributes
----------
- filters_ : dict of ndarray
+ ``filters_`` : dict of ndarray
If fit, the Xdawn components used to decompose the data for each event
type, else empty.
- patterns_ : dict of ndarray
- If fit, the Xdawn patterns used to restore M/EEG signals for each event
+ ``patterns_`` : dict of ndarray
+ If fit, the Xdawn patterns used to restore the signals for each event
type, else empty.
- evokeds_ : dict of evoked instance
+ ``evokeds_`` : dict of evoked instance
If fit, the evoked response for each event type.
+ ``event_id_`` : dict of event id
+ The event id.
+ ``correct_overlap_`` : bool
+ Whether overlap correction was applied.
Notes
-----
@@ -197,7 +403,6 @@ class Xdawn(TransformerMixin, ContainsMixin):
See Also
--------
- ICA
CSP
References
@@ -211,17 +416,11 @@ class Xdawn(TransformerMixin, ContainsMixin):
efficient sensor selection in a P300 BCI. In Signal Processing Conference,
2011 19th European (pp. 1382-1386). IEEE.
"""
-
def __init__(self, n_components=2, signal_cov=None, correct_overlap='auto',
reg=None):
- """init xdawn."""
- self.n_components = n_components
- self.signal_cov = signal_cov
- self.reg = reg
- self.filters_ = dict()
- self.patterns_ = dict()
- self.evokeds_ = dict()
-
+ """Init."""
+ super(Xdawn, self).__init__(n_components=n_components,
+ signal_cov=signal_cov, reg=reg)
if correct_overlap not in ['auto', True, False]:
raise ValueError('correct_overlap must be a bool or "auto"')
self.correct_overlap = correct_overlap
@@ -232,72 +431,65 @@ class Xdawn(TransformerMixin, ContainsMixin):
Parameters
----------
epochs : Epochs object
- An instance of Epoch on which Xdawn filters will be trained.
+ An instance of Epoch on which Xdawn filters will be fitted.
y : ndarray | None (default None)
- Not used, here for compatibility with decoding API.
+ If None, uses epochs.events[:, 2].
Returns
-------
self : Xdawn instance
The Xdawn instance.
"""
- if self.correct_overlap == 'auto':
- self.correct_overlap = _check_overlapp(epochs)
-
- # Extract signal covariance
- if self.signal_cov is None:
- if self.correct_overlap:
- sig_data = _construct_signal_from_epochs(epochs)
- else:
- sig_data = np.hstack(epochs.get_data())
- self.signal_cov_ = _regularized_covariance(sig_data, self.reg)
- elif isinstance(self.signal_cov, Covariance):
- self.signal_cov_ = self.signal_cov.data
- elif isinstance(self.signal_cov, np.ndarray):
- self.signal_cov_ = self.signal_cov
- else:
- raise ValueError('signal_cov must be None, a covariance instance '
- 'or a ndarray')
-
- # estimates evoked covariance
- self.evokeds_cov_ = dict()
- if self.correct_overlap:
- if epochs.baseline is not None:
- raise ValueError('Baseline correction must be None if overlap '
- 'correction activated')
- evokeds, toeplitz = least_square_evoked(epochs,
- return_toeplitz=True)
- else:
- evokeds = dict()
- toeplitz = dict()
- for eid in epochs.event_id:
- evokeds[eid] = epochs[eid].average()
- toeplitz[eid] = 1.0
- self.evokeds_ = evokeds
-
- for eid in epochs.event_id:
- data = np.dot(evokeds[eid].data, toeplitz[eid])
- self.evokeds_cov_[eid] = _regularized_covariance(data, self.reg)
-
- # estimates spatial filters
- for eid in epochs.event_id:
-
- if self.signal_cov_.shape != self.evokeds_cov_[eid].shape:
- raise ValueError('Size of signal cov must be the same as the'
- ' number of channels in epochs')
-
- evals, evecs = linalg.eigh(self.evokeds_cov_[eid],
- self.signal_cov_)
- evecs = evecs[:, np.argsort(evals)[::-1]] # sort eigenvectors
- evecs /= np.sqrt(np.sum(evecs ** 2, axis=0))
-
- self.filters_[eid] = evecs
- self.patterns_[eid] = linalg.inv(evecs.T)
-
- # store some values
- self.ch_names = epochs.ch_names
- self.exclude = list(range(self.n_components, len(self.ch_names)))
- self.event_id = epochs.event_id
+ # Check data
+ if not isinstance(epochs, _BaseEpochs):
+ raise ValueError('epochs must be an Epochs object.')
+ X = epochs.get_data()
+ X = X[:, _pick_data_channels(epochs.info), :]
+ y = epochs.events[:, 2] if y is None else y
+ self.event_id_ = epochs.event_id
+
+ # Check that no baseline was applied with correct overlap
+ correct_overlap = self.correct_overlap
+ if correct_overlap == 'auto':
+ # Events are overlapped if the minimal inter-stimulus
+ # interval is smaller than the time window.
+ isi = np.diff(np.sort(epochs.events[:, 0]))
+ window = int((epochs.tmax - epochs.tmin) * epochs.info['sfreq'])
+ correct_overlap = isi.min() < window
+
+ if epochs.baseline and correct_overlap:
+ raise ValueError('Cannot apply correct_overlap if epochs'
+ ' were baselined.')
+
+ events, tmin, sfreq = None, 0., 1.
+ if correct_overlap:
+ events = epochs.events
+ tmin = epochs.tmin
+ sfreq = epochs.info['sfreq']
+ self.correct_overlap_ = correct_overlap
+
+ # Note: In this original version of Xdawn we compute and keep all
+ # components. The selection comes at transform().
+ n_components = X.shape[1]
+
+ # Main fitting function
+ filters, patterns, evokeds = _fit_xdawn(
+ X, y, n_components=n_components, reg=self.reg,
+ signal_cov=self.signal_cov, events=events, tmin=tmin, sfreq=sfreq)
+
+ # Re-order filters and patterns according to event_id
+ filters = filters.reshape(-1, n_components, filters.shape[-1])
+ patterns = patterns.reshape(-1, n_components, patterns.shape[-1])
+ self.filters_, self.patterns_, self.evokeds_ = dict(), dict(), dict()
+ idx = np.argsort([value for _, value in iteritems(epochs.event_id)])
+ for eid, this_filter, this_pattern, this_evo in zip(
+ epochs.event_id, filters[idx], patterns[idx], evokeds[idx]):
+ self.filters_[eid] = this_filter.T
+ self.patterns_[eid] = this_pattern.T
+ n_events = len(epochs[eid])
+ evoked = EvokedArray(this_evo, epochs.info, tmin=epochs.tmin,
+ comment=eid, nave=n_events)
+ self.evokeds_[eid] = evoked
return self
def transform(self, epochs):
@@ -310,34 +502,28 @@ class Xdawn(TransformerMixin, ContainsMixin):
Returns
-------
- X : ndarray, shape (n_epochs, n_components * event_types, n_times)
+ X : ndarray, shape (n_epochs, n_components * n_event_types, n_times)
Spatially filtered signals.
"""
if isinstance(epochs, _BaseEpochs):
- data = epochs.get_data()
+ X = epochs.get_data()
elif isinstance(epochs, np.ndarray):
- data = epochs
+ X = epochs
else:
- raise ValueError('Data input must be of Epoch '
- 'type or numpy array')
-
- # create full matrix of spatial filter
- full_filters = list()
- for filt in self.filters_.values():
- full_filters.append(filt[:, 0:self.n_components])
- full_filters = np.concatenate(full_filters, axis=1)
+ raise ValueError('Data input must be of Epoch type or numpy array')
- # Apply spatial filters
- X = np.dot(full_filters.T, data)
- X = X.transpose((1, 0, 2))
- return X
+ filters = [filt[:self.n_components]
+ for filt in itervalues(self.filters_)]
+ filters = np.concatenate(filters, axis=0)
+ X = np.dot(filters, X)
+ return X.transpose((1, 0, 2))
def apply(self, inst, event_id=None, include=None, exclude=None):
"""Remove selected components from the signal.
Given the unmixing matrix, transform data,
zero out components, and inverse transform the data.
- This procedure will reconstruct M/EEG signals from which
+ This procedure will reconstruct the signals from which
the dynamics described by the excluded components is subtracted.
Parameters
@@ -363,28 +549,35 @@ class Xdawn(TransformerMixin, ContainsMixin):
event type in event_id.
"""
if event_id is None:
- event_id = self.event_id
+ event_id = self.event_id_
+
+ if not isinstance(inst, (_BaseRaw, _BaseEpochs, Evoked)):
+ raise ValueError('Data input must be Raw, Epochs or Evoked type')
+ picks = _pick_data_channels(inst.info)
+
+ # Define the components to keep
+ default_exclude = list(range(self.n_components, len(inst.ch_names)))
+ if exclude is None:
+ exclude = default_exclude
+ else:
+ exclude = list(set(list(default_exclude) + list(exclude)))
if isinstance(inst, _BaseRaw):
out = self._apply_raw(raw=inst, include=include, exclude=exclude,
- event_id=event_id)
+ event_id=event_id, picks=picks)
elif isinstance(inst, _BaseEpochs):
- out = self._apply_epochs(epochs=inst, include=include,
+ out = self._apply_epochs(epochs=inst, include=include, picks=picks,
exclude=exclude, event_id=event_id)
elif isinstance(inst, Evoked):
- out = self._apply_evoked(evoked=inst, include=include,
+ out = self._apply_evoked(evoked=inst, include=include, picks=picks,
exclude=exclude, event_id=event_id)
- else:
- raise ValueError('Data input must be Raw, Epochs or Evoked type')
return out
- def _apply_raw(self, raw, include, exclude, event_id):
+ def _apply_raw(self, raw, include, exclude, event_id, picks):
"""Aux method."""
if not raw.preload:
raise ValueError('Raw data must be preloaded to apply Xdawn')
- picks = pick_types(raw.info, meg=False, include=self.ch_names,
- exclude='bads')
raws = dict()
for eid in event_id:
data = raw[picks, :][0]
@@ -397,22 +590,12 @@ class Xdawn(TransformerMixin, ContainsMixin):
raws[eid] = raw_r
return raws
- def _apply_epochs(self, epochs, include, exclude, event_id):
+ def _apply_epochs(self, epochs, include, exclude, event_id, picks):
"""Aux method."""
if not epochs.preload:
raise ValueError('Epochs must be preloaded to apply Xdawn')
- picks = pick_types(epochs.info, meg=False, ref_meg=False,
- include=self.ch_names, exclude='bads')
-
# special case where epochs come picked but fit was 'unpicked'.
- if len(picks) != len(self.ch_names):
- raise RuntimeError('Epochs don\'t match fitted data: %i channels '
- 'fitted but %i channels supplied. \nPlease '
- 'provide Epochs compatible with '
- 'xdawn.ch_names' % (len(self.ch_names),
- len(picks)))
-
epochs_dict = dict()
data = np.hstack(epochs.get_data()[:, picks])
@@ -429,20 +612,8 @@ class Xdawn(TransformerMixin, ContainsMixin):
return epochs_dict
- def _apply_evoked(self, evoked, include, exclude, event_id):
+ def _apply_evoked(self, evoked, include, exclude, event_id, picks):
"""Aux method."""
- picks = pick_types(evoked.info, meg=False, ref_meg=False,
- include=self.ch_names,
- exclude='bads')
-
- # special case where evoked come picked but fit was 'unpicked'.
- if len(picks) != len(self.ch_names):
- raise RuntimeError('Evoked does not match fitted data: %i channels'
- ' fitted but %i channels supplied. \nPlease '
- 'provide an Evoked object that\'s compatible '
- 'with xdawn.ch_names' % (len(self.ch_names),
- len(picks)))
-
data = evoked.data[picks]
evokeds = dict()
@@ -459,22 +630,18 @@ class Xdawn(TransformerMixin, ContainsMixin):
def _pick_sources(self, data, include, exclude, eid):
"""Aux method."""
fast_dot = _get_fast_dot()
- if exclude is None:
- exclude = self.exclude
- else:
- exclude = list(set(list(self.exclude) + list(exclude)))
logger.info('Transforming to Xdawn space')
# Apply unmixing
sources = fast_dot(self.filters_[eid].T, data)
- if include not in (None, []):
+ if include not in (None, list()):
mask = np.ones(len(sources), dtype=np.bool)
mask[np.unique(include)] = False
sources[mask] = 0.
logger.info('Zeroing out %i Xdawn components' % mask.sum())
- elif exclude not in (None, []):
+ elif exclude not in (None, list()):
exclude_ = np.unique(exclude)
sources[exclude_] = 0.
logger.info('Zeroing out %i Xdawn components' % len(exclude_))
@@ -482,3 +649,9 @@ class Xdawn(TransformerMixin, ContainsMixin):
data = fast_dot(self.patterns_[eid], sources)
return data
+
+ def inverse_transform(self):
+ """Not implemented, see Xdawn.apply() instead.
+ """
+ # Exists because of _XdawnTransformer
+ raise NotImplementedError('See Xdawn.apply()')
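
A minimal usage sketch of the refactored Xdawn API above, run on synthetic
EpochsArray data (the channel names, event codes, and dimensions are
illustrative, not part of this change):

    import numpy as np
    import mne
    from mne.preprocessing import Xdawn

    rng = np.random.RandomState(42)
    n_epochs, n_channels, n_times = 40, 8, 50
    info = mne.create_info(['EEG %02d' % i for i in range(n_channels)],
                           sfreq=100., ch_types='eeg')
    data = rng.randn(n_epochs, n_channels, n_times) * 1e-6
    # Non-overlapping events alternating between two conditions
    events = np.c_[np.arange(0, n_epochs * n_times, n_times),
                   np.zeros(n_epochs, int),
                   np.tile([1, 2], n_epochs // 2)]
    epochs = mne.EpochsArray(data, info, events, tmin=0.,
                             event_id=dict(cond1=1, cond2=2))

    xd = Xdawn(n_components=2, correct_overlap=False)
    xd.fit(epochs)
    X = xd.transform(epochs)     # (n_epochs, n_components * n_classes, n_times)
    denoised = xd.apply(epochs)  # dict of Epochs, one entry per event type
    print(X.shape, sorted(denoised))
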
diff --git a/mne/proj.py b/mne/proj.py
index 42082c6..44141c1 100644
--- a/mne/proj.py
+++ b/mne/proj.py
@@ -253,7 +253,7 @@ def compute_proj_raw(raw, start=0, stop=None, duration=1, n_grad=2, n_mag=2,
picks=pick_types(raw.info, meg=True, eeg=True,
eog=True, ecg=True, emg=True,
exclude='bads'),
- reject=reject, flat=flat)
+ reject=reject, flat=flat, add_eeg_ref=False)
data = _compute_cov_epochs(epochs, n_jobs)
info = epochs.info
if not stop:
@@ -285,7 +285,7 @@ def sensitivity_map(fwd, projs=None, ch_type='grad', mode='fixed', exclude=[],
Parameters
----------
- fwd : dict
+ fwd : Forward
The forward operator.
projs : list
List of projection vectors.
diff --git a/mne/realtime/epochs.py b/mne/realtime/epochs.py
index 785a8c3..9e5110b 100644
--- a/mne/realtime/epochs.py
+++ b/mne/realtime/epochs.py
@@ -100,20 +100,21 @@ class RtEpochs(_BaseEpochs):
(will yield equivalent results but be slower).
add_eeg_ref : bool
If True, an EEG average reference will be added (unless one
- already exists).
+ already exists). The default value of True in 0.13 will change to
+ False in 0.14, and the parameter will be removed in 0.15. Use
+ :func:`mne.set_eeg_reference` instead.
isi_max : float
The maximum time in seconds between epochs. If no epoch
arrives in the next isi_max seconds the RtEpochs stops.
find_events : dict
The arguments to the real-time `find_events` method as a dictionary.
If `find_events` is None, then default values are used.
- Valid keys are 'output' | 'consecutive' | 'min_duration' | 'mask'.
Example (also default values)::
find_events = dict(output='onset', consecutive='increasing',
- min_duration=0, mask=0)
+ min_duration=0, mask=0, mask_type='not_and')
- See mne.find_events for detailed explanation of these options.
+ See :func:`mne.find_events` for detailed explanation of these options.
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
Defaults to client.verbose.
@@ -136,8 +137,7 @@ class RtEpochs(_BaseEpochs):
sleep_time=0.1, baseline=(None, 0), picks=None,
name='Unknown', reject=None, flat=None, proj=True,
decim=1, reject_tmin=None, reject_tmax=None, detrend=None,
- add_eeg_ref=True, isi_max=2., find_events=None, verbose=None):
-
+ add_eeg_ref=None, isi_max=2., find_events=None, verbose=None):
info = client.get_measurement_info()
# the measurement info of the data as we receive it
@@ -169,7 +169,8 @@ class RtEpochs(_BaseEpochs):
# find_events default options
self._find_events_kwargs = dict(output='onset',
consecutive='increasing',
- min_duration=0, mask=0)
+ min_duration=0, mask=0,
+ mask_type='not_and')
# update default options if dictionary is provided
if find_events is not None:
self._find_events_kwargs.update(find_events)
diff --git a/mne/realtime/fieldtrip_client.py b/mne/realtime/fieldtrip_client.py
index 3787da3..878f81e 100644
--- a/mne/realtime/fieldtrip_client.py
+++ b/mne/realtime/fieldtrip_client.py
@@ -259,6 +259,8 @@ class FieldTripClient(object):
info = self.info
if picks is not None:
info = pick_info(info, picks)
+ else:
+ picks = range(info['nchan'])
epoch = EpochsArray(data[picks][np.newaxis], info, events)
return epoch
diff --git a/mne/realtime/stim_server_client.py b/mne/realtime/stim_server_client.py
index f06cf0d..21d16c8 100644
--- a/mne/realtime/stim_server_client.py
+++ b/mne/realtime/stim_server_client.py
@@ -88,8 +88,6 @@ class StimServer(object):
Parameters
----------
- ip : str
- IP address of the host where StimServer is running.
port : int
The port to which the stimulation server must bind to.
n_clients : int
@@ -100,10 +98,10 @@ class StimServer(object):
StimClient
"""
- def __init__(self, ip='localhost', port=4218, n_clients=1):
+ def __init__(self, port=4218, n_clients=1):
# Start a threaded TCP server, binding to localhost on specified port
- self._data = _ThreadedTCPServer((ip, port),
+ self._data = _ThreadedTCPServer(('', port),
_TriggerHandler, self)
self.n_clients = n_clients
@@ -246,8 +244,6 @@ class StimClient(object):
@verbose
def __init__(self, host, port=4218, timeout=5.0, verbose=None):
- self._host = host
- self._port = port
try:
logger.info("Setting up client socket")
diff --git a/mne/realtime/tests/test_fieldtrip_client.py b/mne/realtime/tests/test_fieldtrip_client.py
index c17a4a5..bd4eb89 100644
--- a/mne/realtime/tests/test_fieldtrip_client.py
+++ b/mne/realtime/tests/test_fieldtrip_client.py
@@ -74,6 +74,9 @@ def test_fieldtrip_client():
epoch2 = rt_client.get_data_as_epoch(n_samples=5, picks=picks)
n_channels2, n_samples2 = epoch2.get_data().shape[1:]
+ # case of picks=None
+ epoch = rt_client.get_data_as_epoch(n_samples=5)
+
assert_true(tmin_samp2 > tmin_samp1)
assert_true(len(w) >= 1)
assert_equal(n_samples, 5)
diff --git a/mne/realtime/tests/test_mockclient.py b/mne/realtime/tests/test_mockclient.py
index 3e7b158..5bd3217 100644
--- a/mne/realtime/tests/test_mockclient.py
+++ b/mne/realtime/tests/test_mockclient.py
@@ -18,19 +18,21 @@ events = read_events(event_name)
def test_mockclient():
"""Test the RtMockClient."""
- raw = mne.io.read_raw_fif(raw_fname, preload=True, verbose=False)
+ raw = mne.io.read_raw_fif(raw_fname, preload=True, verbose=False,
+ add_eeg_ref=False)
picks = mne.pick_types(raw.info, meg='grad', eeg=False, eog=True,
stim=True, exclude=raw.info['bads'])
event_id, tmin, tmax = 1, -0.2, 0.5
epochs = Epochs(raw, events[:7], event_id=event_id, tmin=tmin, tmax=tmax,
- picks=picks, baseline=(None, 0), preload=True)
+ picks=picks, baseline=(None, 0), preload=True,
+ add_eeg_ref=False)
data = epochs.get_data()
rt_client = MockRtClient(raw)
rt_epochs = RtEpochs(rt_client, event_id, tmin, tmax, picks=picks,
- isi_max=0.5)
+ isi_max=0.5, add_eeg_ref=False)
rt_epochs.start()
rt_client.send_data(rt_epochs, picks, tmin=0, tmax=10, buffer_size=1000)
@@ -44,14 +46,15 @@ def test_mockclient():
def test_get_event_data():
"""Test emulation of realtime data stream."""
- raw = mne.io.read_raw_fif(raw_fname, preload=True, verbose=False)
+ raw = mne.io.read_raw_fif(raw_fname, preload=True, verbose=False,
+ add_eeg_ref=False)
picks = mne.pick_types(raw.info, meg='grad', eeg=False, eog=True,
stim=True, exclude=raw.info['bads'])
event_id, tmin, tmax = 2, -0.1, 0.3
epochs = Epochs(raw, events, event_id=event_id,
tmin=tmin, tmax=tmax, picks=picks, baseline=None,
- preload=True, proj=False)
+ preload=True, proj=False, add_eeg_ref=False)
data = epochs.get_data()[0, :, :]
@@ -66,7 +69,8 @@ def test_get_event_data():
def test_find_events():
"""Test find_events in rt_epochs."""
- raw = mne.io.read_raw_fif(raw_fname, preload=True, verbose=False)
+ raw = mne.io.read_raw_fif(raw_fname, preload=True, verbose=False,
+ add_eeg_ref=False)
picks = mne.pick_types(raw.info, meg='grad', eeg=False, eog=True,
stim=True, exclude=raw.info['bads'])
@@ -94,7 +98,7 @@ def test_find_events():
rt_client = MockRtClient(raw)
rt_epochs = RtEpochs(rt_client, event_id, tmin, tmax, picks=picks,
stim_channel='STI 014', isi_max=0.5,
- find_events=find_events)
+ find_events=find_events, add_eeg_ref=False)
rt_client.send_data(rt_epochs, picks, tmin=0, tmax=10, buffer_size=1000)
rt_epochs.start()
events = [5, 6]
@@ -107,7 +111,7 @@ def test_find_events():
rt_client = MockRtClient(raw)
rt_epochs = RtEpochs(rt_client, event_id, tmin, tmax, picks=picks,
stim_channel='STI 014', isi_max=0.5,
- find_events=find_events)
+ find_events=find_events, add_eeg_ref=False)
rt_client.send_data(rt_epochs, picks, tmin=0, tmax=10, buffer_size=1000)
rt_epochs.start()
events = [5, 6, 5, 6]
@@ -120,7 +124,7 @@ def test_find_events():
rt_client = MockRtClient(raw)
rt_epochs = RtEpochs(rt_client, event_id, tmin, tmax, picks=picks,
stim_channel='STI 014', isi_max=0.5,
- find_events=find_events)
+ find_events=find_events, add_eeg_ref=False)
rt_client.send_data(rt_epochs, picks, tmin=0, tmax=10, buffer_size=1000)
rt_epochs.start()
events = [5]
@@ -133,7 +137,7 @@ def test_find_events():
rt_client = MockRtClient(raw)
rt_epochs = RtEpochs(rt_client, event_id, tmin, tmax, picks=picks,
stim_channel='STI 014', isi_max=0.5,
- find_events=find_events)
+ find_events=find_events, add_eeg_ref=False)
rt_client.send_data(rt_epochs, picks, tmin=0, tmax=10, buffer_size=1000)
rt_epochs.start()
events = [5, 6, 5, 0, 6, 0]
diff --git a/mne/realtime/tests/test_stim_client_server.py b/mne/realtime/tests/test_stim_client_server.py
index b0e5835..b63c8b7 100644
--- a/mne/realtime/tests/test_stim_client_server.py
+++ b/mne/realtime/tests/test_stim_client_server.py
@@ -36,7 +36,7 @@ def test_connection():
thread1.start()
thread2.start()
- with StimServer('localhost', port=4218, n_clients=2) as stim_server:
+ with StimServer(port=4218, n_clients=2) as stim_server:
_server = stim_server
stim_server.start(timeout=10.0) # don't allow test to hang
@@ -54,7 +54,7 @@ def test_connection():
assert_equal(trig1, trig2)
# test timeout for stim_server
- with StimServer('localhost', port=4218) as stim_server:
+ with StimServer(port=4218) as stim_server:
assert_raises(StopIteration, stim_server.start, 0.1)
diff --git a/mne/report.py b/mne/report.py
index 1c6378f..0861ab3 100644
--- a/mne/report.py
+++ b/mne/report.py
@@ -53,10 +53,9 @@ def _fig_to_img(function=None, fig=None, image_format='png',
from matplotlib.figure import Figure
if not isinstance(fig, Figure) and function is None:
from scipy.misc import imread
- mayavi = None
+ mlab = None
try:
- from mayavi import mlab # noqa, mlab imported
- import mayavi
+ from mayavi import mlab # noqa
except: # on some systems importing Mayavi raises SystemExit (!)
warn('Could not import mayavi. Trying to render'
'`mayavi.core.scene.Scene` figure instances'
@@ -70,7 +69,7 @@ def _fig_to_img(function=None, fig=None, image_format='png',
else: # Testing mode
img = np.zeros((2, 2, 3))
- mayavi.mlab.close(fig)
+ mlab.close(fig)
fig = plt.figure()
plt.imshow(img)
plt.axis('off')
@@ -174,6 +173,16 @@ def _is_bad_fname(fname):
return ''
+def _get_fname(fname):
+ """Get fname without -#-"""
+ if '-#-' in fname:
+ fname = fname.split('-#-')[0]
+ else:
+ fname = op.basename(fname)
+ fname = ' ... %s' % fname
+ return fname
+
+
def _get_toc_property(fname):
"""Auxiliary function to assign class names to TOC
list elements to allow toggling with buttons.
@@ -814,6 +823,25 @@ class Report(object):
self._init_render() # Initialize the renderer
+ def __repr__(self):
+ """Print useful info about report."""
+ s = '<Report | %d items' % len(self.fnames)
+ if self.title is not None:
+ s += ' | %s' % self.title
+ fnames = [_get_fname(f) for f in self.fnames]
+ if len(self.fnames) > 4:
+ s += '\n%s' % '\n'.join(fnames[:2])
+ s += '\n ...\n'
+ s += '\n'.join(fnames[-2:])
+ elif len(self.fnames) > 0:
+ s += '\n%s' % '\n'.join(fnames)
+ s += '\n>'
+ return s
+
+ def __len__(self):
+ """The number of items in report."""
+ return len(self.fnames)
+
def _get_id(self):
"""Get id of plot.
"""
@@ -1355,7 +1383,7 @@ class Report(object):
html.append(this_html)
fnames.append(fname)
sectionlabels.append(sectionlabel)
- logger.info('\t... %s' % fname[-20:])
+ logger.info(_get_fname(fname))
color = _is_bad_fname(fname)
div_klass, tooltip, text = _get_toc_property(fname)
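
For clarity, a standalone restatement of the _get_fname logic above; the
'-#-' marker is assumed here to be the suffix Report appends internally to
keep repeated item names unique (the example names are made up):

    import os.path as op

    def get_fname(fname):
        if '-#-' in fname:
            fname = fname.split('-#-')[0]   # drop the de-duplication suffix
        else:
            fname = op.basename(fname)      # otherwise show the bare file name
        return ' ... %s' % fname

    print(get_fname('evoked-#-2'))           # ' ... evoked'
    print(get_fname('/tmp/sub01/raw.fif'))   # ' ... raw.fif'
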
diff --git a/mne/selection.py b/mne/selection.py
index 22d8148..f3e3e9d 100644
--- a/mne/selection.py
+++ b/mne/selection.py
@@ -6,9 +6,16 @@
from os import path
+import numpy as np
+
from .io.meas_info import Info
-from . import pick_types
-from .utils import logger, verbose
+from .io.pick import _pick_data_channels, pick_types
+from .utils import logger, verbose, _get_stim_channel
+
+_SELECTIONS = ['Vertex', 'Left-temporal', 'Right-temporal', 'Left-parietal',
+ 'Right-parietal', 'Left-occipital', 'Right-occipital',
+ 'Left-frontal', 'Right-frontal']
+_EEG_SELECTIONS = ['EEG 1-32', 'EEG 33-64', 'EEG 65-96', 'EEG 97-128']
@verbose
@@ -114,3 +121,68 @@ def read_selection(name, fname=None, info=None, verbose=None):
if spacing == 'new': # "new" or "old" by now, "old" is default
sel = [s.replace('MEG ', 'MEG') for s in sel]
return sel
+
+
+def _divide_to_regions(info, add_stim=True):
+ """Divides channels to regions by positions."""
+ from scipy.stats import zscore
+ picks = _pick_data_channels(info, exclude=[])
+ chs_in_lobe = len(picks) // 4
+ pos = np.array([ch['loc'][:3] for ch in info['chs']])
+ x, y, z = pos.T
+
+ frontal = picks[np.argsort(y[picks])[-chs_in_lobe:]]
+ picks = np.setdiff1d(picks, frontal)
+
+ occipital = picks[np.argsort(y[picks])[:chs_in_lobe]]
+ picks = np.setdiff1d(picks, occipital)
+
+ temporal = picks[np.argsort(z[picks])[:chs_in_lobe]]
+ picks = np.setdiff1d(picks, temporal)
+
+ lt, rt = _divide_side(temporal, x)
+ lf, rf = _divide_side(frontal, x)
+ lo, ro = _divide_side(occipital, x)
+ lp, rp = _divide_side(picks, x) # Parietal lobe from the remaining picks.
+
+ # Because of the way the sides are divided, there may be outliers in the
+ # temporal lobes. Here we switch the sides for these outliers. For other
+ # lobes it is not a big problem because of the vicinity of the lobes.
+ zs = np.abs(zscore(x[rt]))
+ outliers = np.array(rt)[np.where(zs > 2.)[0]]
+ rt = list(np.setdiff1d(rt, outliers))
+
+ zs = np.abs(zscore(x[lt]))
+ outliers = np.append(outliers, (np.array(lt)[np.where(zs > 2.)[0]]))
+ lt = list(np.setdiff1d(lt, outliers))
+
+ l_mean = np.mean(x[lt])
+ r_mean = np.mean(x[rt])
+ for outlier in outliers:
+ if abs(l_mean - x[outlier]) < abs(r_mean - x[outlier]):
+ lt.append(outlier)
+ else:
+ rt.append(outlier)
+
+ if add_stim:
+ stim_ch = _get_stim_channel(None, info, raise_error=False)
+ if len(stim_ch) > 0:
+ for region in [lf, rf, lo, ro, lp, rp, lt, rt]:
+ region.append(info['ch_names'].index(stim_ch[0]))
+ return {'Left-frontal': lf, 'Right-frontal': rf, 'Left-parietal': lp,
+ 'Right-parietal': rp, 'Left-occipital': lo, 'Right-occipital': ro,
+ 'Left-temporal': lt, 'Right-temporal': rt}
+
+
+def _divide_side(lobe, x):
+ """Helper for making a separation between left and right lobe evenly."""
+ lobe = np.asarray(lobe)
+ median = np.median(x[lobe])
+
+ left = lobe[np.where(x[lobe] < median)[0]]
+ right = lobe[np.where(x[lobe] > median)[0]]
+ medians = np.where(x[lobe] == median)[0]
+
+ left = np.sort(np.concatenate([left, lobe[medians[1::2]]]))
+ right = np.sort(np.concatenate([right, lobe[medians[::2]]]))
+ return list(left), list(right)
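
A small standalone illustration of the median-split rule implemented by
_divide_side above (the toy x-positions are made up): channels strictly left
of the median go left, strictly right go right, and channels sitting exactly
at the median are alternated between the two sides to keep the split even:

    import numpy as np

    def divide_side(lobe, x):
        lobe = np.asarray(lobe)
        median = np.median(x[lobe])
        left = lobe[np.where(x[lobe] < median)[0]]
        right = lobe[np.where(x[lobe] > median)[0]]
        medians = np.where(x[lobe] == median)[0]
        # Alternate the ties so both sides end up roughly the same size
        left = np.sort(np.concatenate([left, lobe[medians[1::2]]]))
        right = np.sort(np.concatenate([right, lobe[medians[::2]]]))
        return list(left), list(right)

    x = np.array([-2., -1., 0., 0., 0., 1., 2.])  # channel x-positions
    print(divide_side(np.arange(len(x)), x))
    # -> ([0, 1, 3], [2, 4, 5, 6]): the three ties at 0. are alternated
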
diff --git a/mne/simulation/evoked.py b/mne/simulation/evoked.py
index a88f448..b137222 100644
--- a/mne/simulation/evoked.py
+++ b/mne/simulation/evoked.py
@@ -18,9 +18,16 @@ def simulate_evoked(fwd, stc, info, cov, snr=3., tmin=None, tmax=None,
iir_filter=None, random_state=None, verbose=None):
"""Generate noisy evoked data
+ .. note:: No projections from ``info`` will be present in the
+ output ``evoked``. You can use e.g.
+ :func:`evoked.add_proj <mne.Evoked.add_proj>` or
+ :func:`evoked.add_eeg_average_proj
+ <mne.Evoked.add_eeg_average_proj>`
+ to add them afterward as necessary.
+
Parameters
----------
- fwd : dict
+ fwd : Forward
a forward solution.
stc : SourceEstimate object
The source time courses.
@@ -49,6 +56,12 @@ def simulate_evoked(fwd, stc, info, cov, snr=3., tmin=None, tmax=None,
evoked : Evoked object
The simulated evoked data
+ See Also
+ --------
+ simulate_raw
+ simulate_stc
+ simulate_sparse_stc
+
Notes
-----
.. versionadded:: 0.10.0
@@ -56,8 +69,8 @@ def simulate_evoked(fwd, stc, info, cov, snr=3., tmin=None, tmax=None,
evoked = apply_forward(fwd, stc, info)
if snr < np.inf:
noise = simulate_noise_evoked(evoked, cov, iir_filter, random_state)
- evoked_noise = add_noise_evoked(evoked, noise, snr,
- tmin=tmin, tmax=tmax)
+ evoked_noise = add_noise_evoked(evoked, noise, snr, tmin=tmin,
+ tmax=tmax)
else:
evoked_noise = evoked
return evoked_noise
@@ -98,6 +111,12 @@ def _generate_noise(info, cov, iir_filter, random_state, n_samples, zi=None):
"""Helper to create spatially colored and temporally IIR-filtered noise"""
from scipy.signal import lfilter
noise_cov = pick_channels_cov(cov, include=info['ch_names'], exclude=[])
+ if set(info['ch_names']) != set(noise_cov.ch_names):
+ raise ValueError('Evoked and covariance channel names are not '
+ 'identical. Cannot generate the noise matrix. '
+ 'Channels missing in covariance %s.' %
+ np.setdiff1d(info['ch_names'], noise_cov.ch_names))
+
rng = check_random_state(random_state)
c = np.diag(noise_cov.data) if noise_cov['diag'] else noise_cov.data
mu_channels = np.zeros(len(c))
diff --git a/mne/simulation/raw.py b/mne/simulation/raw.py
index 43a4817..bd16ad4 100644
--- a/mne/simulation/raw.py
+++ b/mne/simulation/raw.py
@@ -44,7 +44,10 @@ def simulate_raw(raw, stc, trans, src, bem, cov='simple',
blink=False, ecg=False, chpi=False, head_pos=None,
mindist=1.0, interp='cos2', iir_filter=None, n_jobs=1,
random_state=None, verbose=None):
- """Simulate raw data with head movements
+ """Simulate raw data
+
+ Head movements can optionally be simulated using the ``head_pos``
+ parameter.
Parameters
----------
@@ -114,6 +117,9 @@ def simulate_raw(raw, stc, trans, src, bem, cov='simple',
See Also
--------
read_head_pos
+ simulate_evoked
+ simulate_stc
+ simulate_sparse_stc
Notes
-----
diff --git a/mne/simulation/source.py b/mne/simulation/source.py
index 20ff6b9..90b6588 100644
--- a/mne/simulation/source.py
+++ b/mne/simulation/source.py
@@ -9,10 +9,13 @@ import numpy as np
from ..source_estimate import SourceEstimate, VolSourceEstimate
from ..source_space import _ensure_src
from ..utils import check_random_state, warn
+
+from ..externals.six import string_types
from ..externals.six.moves import zip
-def select_source_in_label(src, label, random_state=None):
+def select_source_in_label(src, label, random_state=None, location='random',
+ subject=None, subjects_dir=None, surf='sphere'):
"""Select source positions using a label
Parameters
@@ -23,6 +26,34 @@ def select_source_in_label(src, label, random_state=None):
the label (read with mne.read_label)
random_state : None | int | np.random.RandomState
To specify the random generator state.
+ location : str
+ The label location to choose. Can be 'random' (default) or 'center'
+ to use :func:`mne.Label.center_of_mass` (restricting to vertices
+ both in the label and in the source space). Note that for 'center'
+ mode the label values are used as weights.
+
+ .. versionadded:: 0.13
+
+ subject : string | None
+ The subject the label is defined for.
+ Only used with ``location='center'``.
+
+ .. versionadded:: 0.13
+
+ subjects_dir : str, or None
+ Path to the SUBJECTS_DIR. If None, the path is obtained by using
+ the environment variable SUBJECTS_DIR.
+ Only used with ``location='center'``.
+
+ .. versionadded:: 0.13
+
+ surf : str
+ The surface to use for Euclidean distance center of mass
+ finding. The default here is "sphere", which finds the center
+ of mass on the spherical surface to help avoid potential issues
+ with cortical folding.
+
+ .. versionadded:: 0.13
Returns
-------
@@ -33,29 +64,39 @@ def select_source_in_label(src, label, random_state=None):
"""
lh_vertno = list()
rh_vertno = list()
+ if not isinstance(location, string_types) or \
+ location not in ('random', 'center'):
+ raise ValueError('location must be "random" or "center", got %s'
+ % (location,))
rng = check_random_state(random_state)
-
if label.hemi == 'lh':
- src_sel_lh = np.intersect1d(src[0]['vertno'], label.vertices)
- idx_select = rng.randint(0, len(src_sel_lh), 1)
- lh_vertno.append(src_sel_lh[idx_select][0])
+ vertno = lh_vertno
+ hemi_idx = 0
else:
- src_sel_rh = np.intersect1d(src[1]['vertno'], label.vertices)
- idx_select = rng.randint(0, len(src_sel_rh), 1)
- rh_vertno.append(src_sel_rh[idx_select][0])
-
+ vertno = rh_vertno
+ hemi_idx = 1
+ src_sel = np.intersect1d(src[hemi_idx]['vertno'], label.vertices)
+ if location == 'random':
+ idx = src_sel[rng.randint(0, len(src_sel), 1)[0]]
+ else: # 'center'
+ idx = label.center_of_mass(
+ subject, restrict_vertices=src_sel, subjects_dir=subjects_dir,
+ surf=surf)
+ vertno.append(idx)
return lh_vertno, rh_vertno
def simulate_sparse_stc(src, n_dipoles, times,
data_fun=lambda t: 1e-7 * np.sin(20 * np.pi * t),
- labels=None, random_state=None):
+ labels=None, random_state=None, location='random',
+ subject=None, subjects_dir=None, surf='sphere'):
"""Generate sparse (n_dipoles) sources time courses from data_fun
- This function randomly selects n_dipoles vertices in the whole cortex
- or one single vertex in each label if labels is not None. It uses data_fun
- to generate waveforms for each vertex.
+ This function randomly selects ``n_dipoles`` vertices in the whole
+ cortex, or one single vertex in each label (chosen randomly within,
+ or at the center of, the label) if ``labels is not None``. It uses
+ ``data_fun`` to generate waveforms for each vertex.
Parameters
----------
@@ -74,18 +115,57 @@ def simulate_sparse_stc(src, n_dipoles, times,
The labels. The default is None, otherwise its size must be n_dipoles.
random_state : None | int | np.random.RandomState
To specify the random generator state.
+ location : str
+ The label location to choose. Can be 'random' (default) or 'center'
+ to use :func:`mne.Label.center_of_mass`. Note that for 'center'
+ mode the label values are used as weights.
+
+ .. versionadded:: 0.13
+
+ subject : string | None
+ The subject the label is defined for.
+ Only used with ``location='center'``.
+
+ .. versionadded:: 0.13
+
+ subjects_dir : str, or None
+ Path to the SUBJECTS_DIR. If None, the path is obtained by using
+ the environment variable SUBJECTS_DIR.
+ Only used with ``location='center'``.
+
+ .. versionadded:: 0.13
+
+ surf : str
+ The surface to use for Euclidean distance center of mass
+ finding. The default here is "sphere", which finds the center
+ of mass on the spherical surface to help avoid potential issues
+ with cortical folding.
+
+ .. versionadded:: 0.13
Returns
-------
stc : SourceEstimate
The generated source time courses.
+ See Also
+ --------
+ simulate_raw
+ simulate_evoked
+ simulate_stc
+
Notes
-----
.. versionadded:: 0.10.0
"""
rng = check_random_state(random_state)
src = _ensure_src(src, verbose=False)
+ subject_src = src[0].get('subject_his_id')
+ if subject is None:
+ subject = subject_src
+ elif subject_src is not None and subject != subject_src:
+ raise ValueError('subject argument (%s) did not match the source '
+ 'space subject_his_id (%s)' % (subject, subject_src))
data = np.zeros((n_dipoles, len(times)))
for i_dip in range(n_dipoles):
data[i_dip, :] = data_fun(times)
@@ -109,7 +189,8 @@ def simulate_sparse_stc(src, n_dipoles, times,
lh_data = [np.empty((0, data.shape[1]))]
rh_data = [np.empty((0, data.shape[1]))]
for i, label in enumerate(labels):
- lh_vertno, rh_vertno = select_source_in_label(src, label, rng)
+ lh_vertno, rh_vertno = select_source_in_label(
+ src, label, rng, location, subject, subjects_dir, surf)
vertno[0] += lh_vertno
vertno[1] += rh_vertno
if len(lh_vertno) != 0:
@@ -131,7 +212,7 @@ def simulate_sparse_stc(src, n_dipoles, times,
tmin, tstep = times[0], np.diff(times[:2])[0]
assert datas.shape == data.shape
cls = SourceEstimate if len(vs) == 2 else VolSourceEstimate
- stc = cls(datas, vertices=vs, tmin=tmin, tstep=tstep)
+ stc = cls(datas, vertices=vs, tmin=tmin, tstep=tstep, subject=subject)
return stc
@@ -141,20 +222,9 @@ def simulate_stc(src, labels, stc_data, tmin, tstep, value_fun=None):
This function generates a source estimate with extended sources by
filling the labels with the waveforms given in stc_data.
- By default, the vertices within a label are assigned the same waveform.
- The waveforms can be scaled for each vertex by using the label values
- and value_fun. E.g.,
-
- # create a source label where the values are the distance from the center
- labels = circular_source_labels('sample', 0, 10, 0)
-
- # sources with decaying strength (x will be the distance from the center)
- fun = lambda x: exp(- x / 10)
- stc = generate_stc(fwd, labels, stc_data, tmin, tstep, fun)
-
Parameters
----------
- src : list of dict
+ src : instance of SourceSpaces
The source space
labels : list of Labels
The labels
@@ -164,13 +234,21 @@ def simulate_stc(src, labels, stc_data, tmin, tstep, value_fun=None):
The beginning of the timeseries
tstep : float
The time step (1 / sampling frequency)
- value_fun : function
- Function to apply to the label values
+ value_fun : function | None
+ Function to apply to the label values to obtain the waveform
+ scaling for each vertex in the label. If None (default), uniform
+ scaling is used.
Returns
-------
stc : SourceEstimate
The generated source time courses.
+
+ See Also
+ --------
+ simulate_raw
+ simulate_evoked
+ simulate_sparse_stc
"""
if len(labels) != len(stc_data):
raise ValueError('labels and stc_data must have the same length')
@@ -201,6 +279,11 @@ def simulate_stc(src, labels, stc_data, tmin, tstep, value_fun=None):
elif len(vertno[idx]) == 1:
vertno[idx] = vertno[idx][0]
vertno = [np.array(v) for v in vertno]
+ for v, hemi in zip(vertno, ('left', 'right')):
+ d = len(v) - len(np.unique(v))
+ if d > 0:
+ raise RuntimeError('Labels had %s overlaps in the %s hemisphere, '
+ 'they must be non-overlapping' % (d, hemi))
# the data is in the order left, right
data = list()
@@ -216,5 +299,7 @@ def simulate_stc(src, labels, stc_data, tmin, tstep, value_fun=None):
data = np.concatenate(data)
- stc = SourceEstimate(data, vertices=vertno, tmin=tmin, tstep=tstep)
+ subject = src[0].get('subject_his_id')
+ stc = SourceEstimate(data, vertices=vertno, tmin=tmin, tstep=tstep,
+ subject=subject)
return stc
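``simulate_stc``'s inline usage example was dropped from the docstring above; an equivalent self-contained sketch of ``value_fun`` (same illustrative sample-data paths as before; the labels must be non-overlapping):

import numpy as np
import mne
from mne.simulation import simulate_stc

data_path = mne.datasets.sample.data_path()
src = mne.read_source_spaces(data_path +
                             '/subjects/sample/bem/sample-oct-6-src.fif')
labels = [mne.read_label(data_path + '/MEG/sample/labels/%s.label' % name)
          for name in ('Aud-lh', 'Aud-rh')]
stc_data = 1e-9 * np.ones((len(labels), 100))  # one waveform per label
# scale each vertex's waveform by a function of its label value
stc = simulate_stc(src, labels, stc_data, tmin=0., tstep=1e-3,
                   value_fun=lambda x: np.exp(-x / 10.))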
diff --git a/mne/simulation/tests/test_evoked.py b/mne/simulation/tests/test_evoked.py
index 262a670..00418f9 100644
--- a/mne/simulation/tests/test_evoked.py
+++ b/mne/simulation/tests/test_evoked.py
@@ -5,7 +5,8 @@
import os.path as op
import numpy as np
-from numpy.testing import assert_array_almost_equal, assert_array_equal
+from numpy.testing import (assert_array_almost_equal, assert_array_equal,
+ assert_almost_equal)
from nose.tools import assert_true, assert_raises
import warnings
@@ -13,7 +14,7 @@ from mne.datasets import testing
from mne import read_forward_solution
from mne.simulation import simulate_sparse_stc, simulate_evoked
from mne import read_cov
-from mne.io import Raw
+from mne.io import read_raw_fif
from mne import pick_types_forward, read_evokeds
from mne.utils import run_tests_if_main
@@ -32,9 +33,9 @@ cov_fname = op.join(op.dirname(__file__), '..', '..', 'io', 'tests',
@testing.requires_testing_data
def test_simulate_evoked():
- """ Test simulation of evoked data """
+ """Test simulation of evoked data."""
- raw = Raw(raw_fname)
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
fwd = read_forward_solution(fwd_fname, force_fixed=True)
fwd = pick_types_forward(fwd, meg=True, eeg=True, exclude=raw.info['bads'])
cov = read_cov(cov_fname)
@@ -64,6 +65,7 @@ def test_simulate_evoked():
stc_bad = stc.copy()
mv = np.max(fwd['src'][0]['vertno'][fwd['src'][0]['inuse']])
stc_bad.vertices[0][0] = mv + 1
+
assert_raises(RuntimeError, simulate_evoked, fwd, stc_bad,
evoked_template.info, cov, snr, tmin=0.0, tmax=0.2)
evoked_1 = simulate_evoked(fwd, stc, evoked_template.info, cov, np.inf,
@@ -72,4 +74,22 @@ def test_simulate_evoked():
tmin=0.0, tmax=0.2)
assert_array_equal(evoked_1.data, evoked_2.data)
+ # test snr definition in dB
+ evoked_noise = simulate_evoked(fwd, stc, evoked_template.info, cov,
+ snr=snr, tmin=None, tmax=None,
+ iir_filter=None)
+ evoked_clean = simulate_evoked(fwd, stc, evoked_template.info, cov,
+ snr=np.inf, tmin=None, tmax=None,
+ iir_filter=None)
+ noise = evoked_noise.data - evoked_clean.data
+
+ empirical_snr = 10 * np.log10(np.mean((evoked_clean.data ** 2).ravel()) /
+ np.mean((noise ** 2).ravel()))
+
+ assert_almost_equal(snr, empirical_snr, decimal=5)
+
+ cov['names'] = cov.ch_names[:-2] # Error channels are different.
+ assert_raises(ValueError, simulate_evoked, fwd, stc, evoked_template.info,
+ cov, snr=3., tmin=None, tmax=None, iir_filter=None)
+
run_tests_if_main()
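The new assertions pin down the dB definition of ``snr`` used by ``simulate_evoked``. The same arithmetic in isolation, on purely synthetic arrays:

import numpy as np

rng = np.random.RandomState(0)
signal = rng.randn(10, 1000)
noise = rng.randn(10, 1000)
target_snr = 3.  # dB
# scale noise so that 10 * log10(power(signal) / power(noise)) == target_snr
noise *= np.sqrt(np.mean(signal ** 2) /
                 (np.mean(noise ** 2) * 10 ** (target_snr / 10.)))
empirical_snr = 10 * np.log10(np.mean(signal ** 2) / np.mean(noise ** 2))
print(empirical_snr)  # ~3.0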
diff --git a/mne/simulation/tests/test_raw.py b/mne/simulation/tests/test_raw.py
index 45f35b1..0b1a310 100644
--- a/mne/simulation/tests/test_raw.py
+++ b/mne/simulation/tests/test_raw.py
@@ -10,19 +10,20 @@ from copy import deepcopy
import numpy as np
from numpy.testing import assert_allclose, assert_array_equal
-from nose.tools import assert_true, assert_raises
+from nose.tools import assert_true, assert_raises, assert_equal
from mne import (read_source_spaces, pick_types, read_trans, read_cov,
- make_sphere_model, create_info, setup_volume_source_space)
+ make_sphere_model, create_info, setup_volume_source_space,
+ find_events, Epochs, fit_dipole, transform_surface_to,
+ make_ad_hoc_cov, SourceEstimate, setup_source_space)
from mne.chpi import (_calculate_chpi_positions, read_head_pos,
_get_hpi_info, head_pos_to_trans_rot_t)
from mne.tests.test_chpi import _compare_positions
from mne.datasets import testing
from mne.simulation import simulate_sparse_stc, simulate_raw
-from mne.io import Raw, RawArray
+from mne.io import read_raw_fif, RawArray
from mne.time_frequency import psd_welch
from mne.utils import _TempDir, run_tests_if_main, requires_version, slow_test
-from mne.fixes import isclose
warnings.simplefilter('always')
@@ -33,7 +34,8 @@ cov_fname = op.join(data_path, 'MEG', 'sample',
'sample_audvis_trunc-cov.fif')
trans_fname = op.join(data_path, 'MEG', 'sample',
'sample_audvis_trunc-trans.fif')
-bem_path = op.join(data_path, 'subjects', 'sample', 'bem')
+subjects_dir = op.join(data_path, 'subjects')
+bem_path = op.join(subjects_dir, 'sample', 'bem')
src_fname = op.join(bem_path, 'sample-oct-2-src.fif')
bem_fname = op.join(bem_path, 'sample-320-320-320-bem-sol.fif')
@@ -42,7 +44,7 @@ pos_fname = op.join(data_path, 'SSS', 'test_move_anon_raw_subsampled.pos')
def _make_stc(raw, src):
- """Helper to make a STC"""
+ """Helper to make a STC."""
seed = 42
sfreq = raw.info['sfreq'] # Hz
tstep = 1. / sfreq
@@ -53,13 +55,15 @@ def _make_stc(raw, src):
def _get_data():
- """Helper to get some starting data"""
+ """Helper to get some starting data."""
# raw with ECG channel
- raw = Raw(raw_fname).crop(0., 5.0, copy=False).load_data()
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
+ raw.crop(0., 5.0, copy=False).load_data()
data_picks = pick_types(raw.info, meg=True, eeg=True)
other_picks = pick_types(raw.info, meg=False, stim=True, eog=True)
picks = np.sort(np.concatenate((data_picks[::16], other_picks)))
raw = raw.pick_channels([raw.ch_names[p] for p in picks])
+ raw.info.normalize_proj()
ecg = RawArray(np.zeros((1, len(raw.times))),
create_info(['ECG 063'], raw.info['sfreq'], 'ecg'))
for key in ('dev_head_t', 'buffer_size_sec', 'highpass', 'lowpass',
@@ -76,7 +80,7 @@ def _get_data():
@testing.requires_testing_data
def test_simulate_raw_sphere():
- """Test simulation of raw data with sphere model"""
+ """Test simulation of raw data with sphere model."""
seed = 42
raw, src, stc, trans, sphere = _get_data()
assert_true(len(pick_types(raw.info, meg=False, ecg=True)) == 1)
@@ -107,7 +111,8 @@ def test_simulate_raw_sphere():
test_outname = op.join(tempdir, 'sim_test_raw.fif')
raw_sim.save(test_outname)
- raw_sim_loaded = Raw(test_outname, preload=True, proj=False)
+ raw_sim_loaded = read_raw_fif(test_outname, preload=True, proj=False,
+ add_eeg_ref=False)
assert_allclose(raw_sim_loaded[:][0], raw_sim[:][0], rtol=1e-6, atol=1e-20)
del raw_sim, raw_sim_2
# with no cov (no noise) but with artifacts, most time periods should match
@@ -122,11 +127,11 @@ def test_simulate_raw_sphere():
picks = np.arange(len(raw.ch_names))
diff_picks = pick_types(raw.info, meg=False, ecg=ecg, eog=eog)
these_picks = np.setdiff1d(picks, diff_picks)
- close = isclose(raw_sim_3[these_picks][0],
- raw_sim_4[these_picks][0], atol=1e-20)
+ close = np.isclose(raw_sim_3[these_picks][0],
+ raw_sim_4[these_picks][0], atol=1e-20)
assert_true(np.mean(close) > 0.7)
- far = ~isclose(raw_sim_3[diff_picks][0],
- raw_sim_4[diff_picks][0], atol=1e-20)
+ far = ~np.isclose(raw_sim_3[diff_picks][0],
+ raw_sim_4[diff_picks][0], atol=1e-20)
assert_true(np.mean(far) > 0.99)
del raw_sim_3, raw_sim_4
@@ -191,13 +196,15 @@ def test_simulate_raw_sphere():
@testing.requires_testing_data
def test_simulate_raw_bem():
- """Test simulation of raw data with BEM"""
- seed = 42
+ """Test simulation of raw data with BEM."""
raw, src, stc, trans, sphere = _get_data()
- raw_sim_sph = simulate_raw(raw, stc, trans, src, sphere, cov=None,
- ecg=True, blink=True, random_state=seed)
+ src = setup_source_space('sample', None, 'oct1', subjects_dir=subjects_dir)
+ # use different / more complete STC here
+ vertices = [s['vertno'] for s in src]
+ stc = SourceEstimate(np.eye(sum(len(v) for v in vertices)), vertices,
+ 0, 1. / raw.info['sfreq'])
+ raw_sim_sph = simulate_raw(raw, stc, trans, src, sphere, cov=None)
raw_sim_bem = simulate_raw(raw, stc, trans, src, bem_fname, cov=None,
- ecg=True, blink=True, random_state=seed,
n_jobs=2)
# some components (especially radial) might not match that well,
# so just make sure that most components have high correlation
@@ -206,7 +213,25 @@ def test_simulate_raw_bem():
n_ch = len(picks)
corr = np.corrcoef(raw_sim_sph[picks][0], raw_sim_bem[picks][0])
assert_array_equal(corr.shape, (2 * n_ch, 2 * n_ch))
- assert_true(np.median(np.diag(corr[:n_ch, -n_ch:])) > 0.9)
+ assert_true(np.median(np.diag(corr[:n_ch, -n_ch:])) > 0.65)
+ # do some round-trip localization
+ for s in src:
+ transform_surface_to(s, 'head', trans)
+ locs = np.concatenate([s['rr'][s['vertno']] for s in src])
+ tmax = (len(locs) - 1) / raw.info['sfreq']
+ cov = make_ad_hoc_cov(raw.info)
+ # The tolerance for the BEM is surprisingly high (28) but I get the same
+ # result when using MNE-C and Xfit, even when using a proper 5120 BEM :(
+ for use_raw, bem, tol in ((raw_sim_sph, sphere, 1),
+ (raw_sim_bem, bem_fname, 28)):
+ events = find_events(use_raw, 'STI 014')
+ assert_equal(len(locs), 12) # oct1 count
+ evoked = Epochs(use_raw, events, 1, 0, tmax, baseline=None,
+ add_eeg_ref=False).average()
+ assert_equal(len(evoked.times), len(locs))
+ fits = fit_dipole(evoked, cov, bem, trans, min_dist=1.)[0].pos
+ diffs = np.sqrt(np.sum((locs - fits) ** 2, axis=-1)) * 1000
+ assert_true(np.median(diffs) < tol)
@slow_test
@@ -214,8 +239,9 @@ def test_simulate_raw_bem():
@requires_version('scipy', '0.12')
@testing.requires_testing_data
def test_simulate_raw_chpi():
- """Test simulation of raw data with cHPI"""
- raw = Raw(raw_chpi_fname, allow_maxshield='yes')
+ """Test simulation of raw data with cHPI."""
+ raw = read_raw_fif(raw_chpi_fname, allow_maxshield='yes',
+ add_eeg_ref=False)
sphere = make_sphere_model('auto', 'auto', raw.info)
# make sparse spherical source space
sphere_vol = tuple(sphere['r0'] * 1000.) + (sphere.radius * 1000.,)
diff --git a/mne/simulation/tests/test_source.py b/mne/simulation/tests/test_source.py
index ee6eb84..ddd8973 100644
--- a/mne/simulation/tests/test_source.py
+++ b/mne/simulation/tests/test_source.py
@@ -2,7 +2,7 @@ import os.path as op
import numpy as np
from numpy.testing import assert_array_almost_equal, assert_array_equal
-from nose.tools import assert_true
+from nose.tools import assert_true, assert_raises, assert_equal
from mne.datasets import testing
from mne import read_label, read_forward_solution, pick_types_forward
@@ -17,6 +17,7 @@ fname_fwd = op.join(data_path, 'MEG', 'sample',
label_names = ['Aud-lh', 'Aud-rh', 'Vis-rh']
label_names_single_hemi = ['Aud-rh', 'Vis-rh']
+subjects_dir = op.join(data_path, 'subjects')
def read_forward_solution_meg(*args, **kwargs):
@@ -46,6 +47,7 @@ def test_simulate_stc():
stc_data = np.ones((len(labels), n_times))
stc = simulate_stc(fwd['src'], mylabels, stc_data, tmin, tstep)
+ assert_equal(stc.subject, 'sample')
for label in labels:
if label.hemi == 'lh':
@@ -84,6 +86,15 @@ def test_simulate_stc():
res = ((2. * i) ** 2.) * np.ones((len(idx), n_times))
assert_array_almost_equal(stc.data[idx], res)
+ # degenerate conditions
+ label_subset = mylabels[:2]
+ data_subset = stc_data[:2]
+ stc = simulate_stc(fwd['src'], label_subset, data_subset, tmin, tstep, fun)
+ assert_raises(ValueError, simulate_stc, fwd['src'],
+ label_subset, data_subset[:-1], tmin, tstep, fun)
+ assert_raises(RuntimeError, simulate_stc, fwd['src'], label_subset * 2,
+ np.concatenate([data_subset] * 2, axis=0), tmin, tstep, fun)
+
@testing.requires_testing_data
def test_simulate_sparse_stc():
@@ -97,18 +108,40 @@ def test_simulate_sparse_stc():
tstep = 1e-3
times = np.arange(n_times, dtype=np.float) * tstep + tmin
- stc_1 = simulate_sparse_stc(fwd['src'], len(labels), times,
- labels=labels, random_state=0)
-
- assert_true(stc_1.data.shape[0] == len(labels))
- assert_true(stc_1.data.shape[1] == n_times)
-
- # make sure we get the same result when using the same seed
- stc_2 = simulate_sparse_stc(fwd['src'], len(labels), times,
- labels=labels, random_state=0)
-
- assert_array_equal(stc_1.lh_vertno, stc_2.lh_vertno)
- assert_array_equal(stc_1.rh_vertno, stc_2.rh_vertno)
+ assert_raises(ValueError, simulate_sparse_stc, fwd['src'], len(labels),
+ times, labels=labels, location='center', subject='sample',
+ subjects_dir=subjects_dir) # no non-zero values
+ for label in labels:
+ label.values.fill(1.)
+ for location in ('random', 'center'):
+ random_state = 0 if location == 'random' else None
+ stc_1 = simulate_sparse_stc(fwd['src'], len(labels), times,
+ labels=labels, random_state=random_state,
+ location=location,
+ subjects_dir=subjects_dir)
+ assert_equal(stc_1.subject, 'sample')
+
+ assert_true(stc_1.data.shape[0] == len(labels))
+ assert_true(stc_1.data.shape[1] == n_times)
+
+ # make sure we get the same result when using the same seed
+ stc_2 = simulate_sparse_stc(fwd['src'], len(labels), times,
+ labels=labels, random_state=random_state,
+ location=location,
+ subjects_dir=subjects_dir)
+
+ assert_array_equal(stc_1.lh_vertno, stc_2.lh_vertno)
+ assert_array_equal(stc_1.rh_vertno, stc_2.rh_vertno)
+ # Degenerate cases
+ assert_raises(ValueError, simulate_sparse_stc, fwd['src'], len(labels),
+ times, labels=labels, location='center', subject='foo',
+ subjects_dir=subjects_dir) # wrong subject
+ del fwd['src'][0]['subject_his_id']
+ assert_raises(ValueError, simulate_sparse_stc, fwd['src'], len(labels),
+ times, labels=labels, location='center',
+ subjects_dir=subjects_dir) # no subject
+ assert_raises(ValueError, simulate_sparse_stc, fwd['src'], len(labels),
+ times, labels=labels, location='foo') # bad location
@testing.requires_testing_data
diff --git a/mne/source_estimate.py b/mne/source_estimate.py
index 8429f72..ae75df9 100644
--- a/mne/source_estimate.py
+++ b/mne/source_estimate.py
@@ -6,13 +6,13 @@
# License: BSD (3-clause)
import copy
-import os
+import os.path as op
from math import ceil
import warnings
import numpy as np
from scipy import linalg, sparse
-from scipy.sparse import coo_matrix
+from scipy.sparse import coo_matrix, block_diag as sparse_block_diag
from .filter import resample
from .evoked import _get_peak
@@ -20,14 +20,14 @@ from .parallel import parallel_func
from .surface import (read_surface, _get_ico_surface, read_morph_map,
_compute_nearest, mesh_edges)
from .source_space import (_ensure_src, _get_morph_src_reordering,
- _ensure_src_subject)
+ _ensure_src_subject, SourceSpaces)
from .utils import (get_subjects_dir, _check_subject, logger, verbose,
- _time_mask, warn as warn_)
+ _time_mask, warn as warn_, copy_function_doc_to_method_doc)
from .viz import plot_source_estimates
-from .fixes import in1d, sparse_block_diag
from .io.base import ToDataFrameMixin, TimeMixin
-from .externals.six.moves import zip
+
from .externals.six import string_types
+from .externals.six.moves import zip
from .externals.h5io import read_hdf5, write_hdf5
@@ -247,7 +247,7 @@ def read_source_estimate(fname, subject=None):
# make sure corresponding file(s) can be found
ftype = None
- if os.path.exists(fname):
+ if op.exists(fname):
if fname.endswith('-vl.stc') or fname.endswith('-vol.stc') or \
fname.endswith('-vl.w') or fname.endswith('-vol.w'):
ftype = 'volume'
@@ -276,11 +276,11 @@ def read_source_estimate(fname, subject=None):
raise RuntimeError('Unknown extension for file %s' % fname_arg)
if ftype is not 'volume':
- stc_exist = [os.path.exists(f)
+ stc_exist = [op.exists(f)
for f in [fname + '-rh.stc', fname + '-lh.stc']]
- w_exist = [os.path.exists(f)
+ w_exist = [op.exists(f)
for f in [fname + '-rh.w', fname + '-lh.w']]
- h5_exist = os.path.exists(fname + '-stc.h5')
+ h5_exist = op.exists(fname + '-stc.h5')
if all(stc_exist) and (ftype is not 'w'):
ftype = 'surface'
elif all(w_exist):
@@ -505,7 +505,7 @@ class _BaseSourceEstimate(ToDataFrameMixin, TimeMixin):
return self # return self for chaining methods
@verbose
- def resample(self, sfreq, npad=None, window='boxcar', n_jobs=1,
+ def resample(self, sfreq, npad='auto', window='boxcar', n_jobs=1,
verbose=None):
"""Resample data
@@ -532,11 +532,6 @@ class _BaseSourceEstimate(ToDataFrameMixin, TimeMixin):
Note that the sample rate of the original data is inferred from tstep.
"""
- if npad is None:
- npad = 100
- warn_('npad is currently taken to be 100, but will be changed to '
- '"auto" in 0.12. Please set the value explicitly.',
- DeprecationWarning)
# resampling in sensor instead of source space gives a somewhat
# different result, so we don't allow it
self._remove_kernel_sens_data_()
@@ -747,8 +742,8 @@ class _BaseSourceEstimate(ToDataFrameMixin, TimeMixin):
----------
func : callable
The transform to be applied, including parameters (see, e.g.,
- `mne.fixes.partial`). The first parameter of the function is the
- input data. The first return value is the transformed data,
+ :func:`functools.partial`). The first parameter of the function is
+ the input data. The first return value is the transformed data,
remaining outputs are ignored. The first dimension of the
transformed data has to be the same as the first dimension of the
input data.
@@ -823,8 +818,8 @@ class _BaseSourceEstimate(ToDataFrameMixin, TimeMixin):
----------
func : callable
The transform to be applied, including parameters (see, e.g.,
- mne.fixes.partial). The first parameter of the function is the
- input data. The first two dimensions of the transformed data
+ :func:`functools.partial`). The first parameter of the function is
+ the input data. The first two dimensions of the transformed data
should be (i) vertices and (ii) time. Transforms which yield 3D
output (e.g. time-frequency transforms) are valid, so long as the
first two dimensions are vertices and time. In this case, the
@@ -915,6 +910,32 @@ class _BaseSourceEstimate(ToDataFrameMixin, TimeMixin):
return stcs
+def _center_of_mass(vertices, values, hemi, surf, subject, subjects_dir,
+ restrict_vertices):
+ """Helper to find the center of mass on a surface"""
+ if (values == 0).all() or (values < 0).any():
+ raise ValueError('All values must be non-negative and at least one '
+ 'must be non-zero, cannot compute COM')
+ subjects_dir = get_subjects_dir(subjects_dir, raise_error=True)
+ surf = read_surface(op.join(subjects_dir, subject, 'surf',
+ hemi + '.' + surf))
+ if restrict_vertices is True:
+ restrict_vertices = vertices
+ elif restrict_vertices is False:
+ restrict_vertices = np.arange(surf[0].shape[0])
+ elif isinstance(restrict_vertices, SourceSpaces):
+ idx = 1 if restrict_vertices.kind == 'surface' and hemi == 'rh' else 0
+ restrict_vertices = restrict_vertices[idx]['vertno']
+ else:
+ restrict_vertices = np.array(restrict_vertices, int)
+ pos = surf[0][vertices, :].T
+ c_o_m = np.sum(pos * values, axis=1) / np.sum(values)
+ vertex = np.argmin(np.sqrt(np.mean((surf[0][restrict_vertices, :] -
+ c_o_m) ** 2, axis=1)))
+ vertex = restrict_vertices[vertex]
+ return vertex
+
+
class SourceEstimate(_BaseSourceEstimate):
"""Container for surface source estimates
@@ -1047,7 +1068,7 @@ class SourceEstimate(_BaseSourceEstimate):
stc_vertices = self.vertices[1]
# find index of the Label's vertices
- idx = np.nonzero(in1d(stc_vertices, label.vertices))[0]
+ idx = np.nonzero(np.in1d(stc_vertices, label.vertices))[0]
# find output vertices
vertices = stc_vertices[idx]
@@ -1199,15 +1220,19 @@ class SourceEstimate(_BaseSourceEstimate):
return label_tc
def center_of_mass(self, subject=None, hemi=None, restrict_vertices=False,
- subjects_dir=None):
- """Return the vertex on a given surface that is at the center of mass
- of the activity in stc. Note that all activity must occur in a single
- hemisphere, otherwise an error is returned. The "mass" of each point in
- space for computing the spatial center of mass is computed by summing
- across time, and vice-versa for each point in time in computing the
- temporal center of mass. This is useful for quantifying spatio-temporal
- cluster locations, especially when combined with the function
- mne.source_space.vertex_to_mni().
+ subjects_dir=None, surf='sphere'):
+ """Compute the center of mass of activity
+
+ This function computes the spatial center of mass on the surface
+ as well as the temporal center of mass as in [1]_.
+
+ .. note:: All activity must occur in a single hemisphere, otherwise
+ an error is raised. The "mass" of each point in space for
+ computing the spatial center of mass is computed by summing
+ across time, and vice-versa for each point in time in
+ computing the temporal center of mass. This is useful for
+ quantifying spatio-temporal cluster locations, especially
+ when combined with :func:`mne.source_space.vertex_to_mni`.
Parameters
----------
@@ -1218,14 +1243,25 @@ class SourceEstimate(_BaseSourceEstimate):
hemisphere. If None, one of the hemispheres must be all zeroes,
and the center of mass will be calculated for the other
hemisphere (useful for getting COM for clusters).
- restrict_vertices : bool, or array of int
+ restrict_vertices : bool | array of int | instance of SourceSpaces
If True, returned vertex will be one from stc. Otherwise, it could
be any vertex from surf. If an array of int, the returned vertex
- will come from that array. For most accuruate estimates, do not
- restrict vertices.
+            will come from that array. If an instance of SourceSpaces (as of
+            0.13), the returned vertex will be from the given source space.
+            For most accurate estimates, do not restrict vertices.
subjects_dir : str, or None
Path to the SUBJECTS_DIR. If None, the path is obtained by using
the environment variable SUBJECTS_DIR.
+ surf : str
+ The surface to use for Euclidean distance center of mass
+ finding. The default here is "sphere", which finds the center
+ of mass on the spherical surface to help avoid potential issues
+ with cortical folding.
+
+ See Also
+ --------
+ Label.center_of_mass
+ vertex_to_mni
Returns
-------
@@ -1240,12 +1276,16 @@ class SourceEstimate(_BaseSourceEstimate):
Time of the temporal center of mass (weighted by the sum across
source vertices).
- References:
- Used in Larson and Lee, "The cortical dynamics underlying effective
- switching of auditory spatial attention", NeuroImage 2012.
+ References
+ ----------
+ .. [1] Larson and Lee, "The cortical dynamics underlying effective
+ switching of auditory spatial attention", NeuroImage 2012.
"""
+ if not isinstance(surf, string_types):
+ raise TypeError('surf must be a string, got %s' % (type(surf),))
subject = _check_subject(self.subject, subject)
-
+ if np.any(self.data < 0):
+ raise ValueError('Cannot compute COM with negative values')
values = np.sum(self.data, axis=1) # sum across time
vert_inds = [np.arange(len(self.vertices[0])),
np.arange(len(self.vertices[1])) + len(self.vertices[0])]
@@ -1257,115 +1297,27 @@ class SourceEstimate(_BaseSourceEstimate):
hemi = hemi[0]
if hemi not in [0, 1]:
raise ValueError('hemi must be 0 or 1')
-
- subjects_dir = get_subjects_dir(subjects_dir, raise_error=True)
-
- values = values[vert_inds[hemi]]
-
- hemis = ['lh', 'rh']
- surf = os.path.join(subjects_dir, subject, 'surf',
- hemis[hemi] + '.sphere')
-
- if isinstance(surf, string_types): # read in surface
- surf = read_surface(surf)
-
- if restrict_vertices is False:
- restrict_vertices = np.arange(surf[0].shape[0])
- elif restrict_vertices is True:
- restrict_vertices = self.vertices[hemi]
-
- if np.any(self.data < 0):
- raise ValueError('Cannot compute COM with negative values')
-
- pos = surf[0][self.vertices[hemi], :].T
- c_o_m = np.sum(pos * values, axis=1) / np.sum(values)
-
- # Find the vertex closest to the COM
- vertex = np.argmin(np.sqrt(np.mean((surf[0][restrict_vertices, :] -
- c_o_m) ** 2, axis=1)))
- vertex = restrict_vertices[vertex]
-
+ vertices = self.vertices[hemi]
+ values = values[vert_inds[hemi]] # left or right
+ del vert_inds
+ vertex = _center_of_mass(
+ vertices, values, hemi=['lh', 'rh'][hemi], surf=surf,
+ subject=subject, subjects_dir=subjects_dir,
+ restrict_vertices=restrict_vertices)
# do time center of mass by using the values across space
masses = np.sum(self.data, axis=0).astype(float)
t_ind = np.sum(masses * np.arange(self.shape[1])) / np.sum(masses)
t = self.tmin + self.tstep * t_ind
return vertex, hemi, t
+ @copy_function_doc_to_method_doc(plot_source_estimates)
def plot(self, subject=None, surface='inflated', hemi='lh',
- colormap='auto', time_label='time=%0.2f ms',
+ colormap='auto', time_label='auto',
smoothing_steps=10, transparent=None, alpha=1.0,
time_viewer=False, config_opts=None, subjects_dir=None,
- figure=None, views='lat', colorbar=True, clim='auto'):
- """Plot SourceEstimates with PySurfer
-
- Note: PySurfer currently needs the SUBJECTS_DIR environment variable,
- which will automatically be set by this function. Plotting multiple
- SourceEstimates with different values for subjects_dir will cause
- PySurfer to use the wrong FreeSurfer surfaces when using methods of
- the returned Brain object. It is therefore recommended to set the
- SUBJECTS_DIR environment variable or always use the same value for
- subjects_dir (within the same Python session).
-
- Parameters
- ----------
- subject : str | None
- The subject name corresponding to FreeSurfer environment
- variable SUBJECT. If None stc.subject will be used. If that
- is None, the environment will be used.
- surface : str
- The type of surface (inflated, white etc.).
- hemi : str, 'lh' | 'rh' | 'split' | 'both'
- The hemisphere to display.
- colormap : str | np.ndarray of float, shape(n_colors, 3 | 4)
- Name of colormap to use or a custom look up table. If array, must
- be (n x 3) or (n x 4) array for with RGB or RGBA values between
- 0 and 255. If 'auto', either 'hot' or 'mne' will be chosen
- based on whether 'lims' or 'pos_lims' are specified in `clim`.
- time_label : str
- How to print info about the time instant visualized.
- smoothing_steps : int
- The amount of smoothing.
- transparent : bool | None
- If True, use a linear transparency between fmin and fmid.
- None will choose automatically based on colormap type.
- alpha : float
- Alpha value to apply globally to the overlay.
- time_viewer : bool
- Display time viewer GUI.
- config_opts : dict
- Keyword arguments for Brain initialization.
- See pysurfer.viz.Brain.
- subjects_dir : str
- The path to the FreeSurfer subjects reconstructions.
- It corresponds to FreeSurfer environment variable SUBJECTS_DIR.
- figure : instance of mayavi.core.scene.Scene | None
- If None, the last figure will be cleaned and a new figure will
- be created.
- views : str | list
- View to use. See surfer.Brain().
- colorbar : bool
- If True, display colorbar on scene.
- clim : str | dict
- Colorbar properties specification. If 'auto', set clim
- automatically based on data percentiles. If dict, should contain:
-
- kind : str
- Flag to specify type of limits. 'value' or 'percent'.
- lims : list | np.ndarray | tuple of float, 3 elements
- Note: Only use this if 'colormap' is not 'mne'.
- Left, middle, and right bound for colormap.
- pos_lims : list | np.ndarray | tuple of float, 3 elements
- Note: Only use this if 'colormap' is 'mne'.
- Left, middle, and right bound for colormap. Positive values
- will be mirrored directly across zero during colormap
- construction to obtain negative control points.
-
-
- Returns
- -------
- brain : Brain
- A instance of surfer.viz.Brain from PySurfer.
- """
+ figure=None, views='lat', colorbar=True, clim='auto',
+ cortex="classic", size=800, background="black",
+ foreground="white", initial_time=None, time_unit=None):
brain = plot_source_estimates(self, subject, surface=surface,
hemi=hemi, colormap=colormap,
time_label=time_label,
@@ -1375,7 +1327,11 @@ class SourceEstimate(_BaseSourceEstimate):
config_opts=config_opts,
subjects_dir=subjects_dir, figure=figure,
views=views, colorbar=colorbar,
- clim=clim)
+ clim=clim, cortex=cortex, size=size,
+ background=background,
+ foreground=foreground,
+ initial_time=initial_time,
+ time_unit=time_unit)
return brain
@verbose
@@ -1786,7 +1742,7 @@ class MixedSourceEstimate(_BaseSourceEstimate):
colormap='auto', time_label='time=%02.f ms',
smoothing_steps=10,
transparent=None, alpha=1.0, time_viewer=False,
- config_opts={}, subjects_dir=None, figure=None,
+ config_opts=None, subjects_dir=None, figure=None,
views='lat', colorbar=True, clim='auto'):
"""Plot surface source estimates with PySurfer
@@ -1998,8 +1954,8 @@ def _morph_mult(data, e, use_sparse, idx_use_data, idx_use_out=None):
def _get_subject_sphere_tris(subject, subjects_dir):
- spheres = [os.path.join(subjects_dir, subject, 'surf',
- xh + '.sphere.reg') for xh in ['lh', 'rh']]
+ spheres = [op.join(subjects_dir, subject, 'surf',
+ xh + '.sphere.reg') for xh in ['lh', 'rh']]
tris = [read_surface(s)[1] for s in spheres]
return tris
@@ -2224,7 +2180,7 @@ def grade_to_vertices(subject, grade, subjects_dir=None, n_jobs=1,
----------
subject : str
Name of the subject
- grade : int
+ grade : int | list
Resolution of the icosahedral mesh (typically 5). If None, all
vertices will be used (potentially filling the surface). If a list,
then values will be morphed to the set of vertices specified in
@@ -2251,8 +2207,8 @@ def grade_to_vertices(subject, grade, subjects_dir=None, n_jobs=1,
return [np.arange(10242), np.arange(10242)]
subjects_dir = get_subjects_dir(subjects_dir, raise_error=True)
- spheres_to = [os.path.join(subjects_dir, subject, 'surf',
- xh + '.sphere.reg') for xh in ['lh', 'rh']]
+ spheres_to = [op.join(subjects_dir, subject, 'surf',
+ xh + '.sphere.reg') for xh in ['lh', 'rh']]
lhs, rhs = [read_surface(s)[0] for s in spheres_to]
if grade is not None: # fill a subset of vertices
@@ -2276,6 +2232,14 @@ def grade_to_vertices(subject, grade, subjects_dir=None, n_jobs=1,
for xhs in [lhs, rhs])
# Make sure the vertices are ordered
vertices = [np.sort(verts) for verts in vertices]
+ for verts in vertices:
+ if (np.diff(verts) == 0).any():
+ raise ValueError(
+ 'Cannot use icosahedral grade %s with subject %s, '
+ 'mapping %s vertices onto the high-resolution mesh '
+ 'yields repeated vertices, use a lower grade or a '
+ 'list of vertices from an existing source space'
+ % (grade, subject, len(verts)))
else: # potentially fill the surface
vertices = [np.arange(lhs.shape[0]), np.arange(rhs.shape[0])]
@@ -2363,7 +2327,7 @@ def spatio_temporal_src_connectivity(src, n_times, dist=None, verbose=None):
connectivity = spatio_temporal_tris_connectivity(tris, n_times)
# deal with source space only using a subset of vertices
- masks = [in1d(u, s['vertno']) for s, u in zip(src, used_verts)]
+ masks = [np.in1d(u, s['vertno']) for s, u in zip(src, used_verts)]
if sum(u.size for u in used_verts) != connectivity.shape[0] / n_times:
raise ValueError('Used vertices do not match connectivity shape')
if [np.sum(m) for m in masks] != [len(s['vertno']) for s in src]:
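A minimal sketch of the refactored ``center_of_mass``, with all activity confined to the left hemisphere as the method requires (assumes the sample subject's FreeSurfer surfaces are available; the vertex choice is illustrative):

import numpy as np
import mne

subjects_dir = mne.datasets.sample.data_path() + '/subjects'
vertices = [np.arange(10), np.array([], int)]  # lh only, rh empty
data = np.abs(np.random.RandomState(0).randn(10, 5))  # COM needs >= 0 values
stc = mne.SourceEstimate(data, vertices, tmin=0., tstep=1e-3,
                         subject='sample')
vertex, hemi, t = stc.center_of_mass(restrict_vertices=True,
                                     subjects_dir=subjects_dir, surf='sphere')
# hemi == 0 (left); t is the temporal center of mass in seconds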
diff --git a/mne/source_space.py b/mne/source_space.py
index 60ad8fe..440f536 100644
--- a/mne/source_space.py
+++ b/mne/source_space.py
@@ -3,11 +3,14 @@
#
# License: BSD (3-clause)
-import numpy as np
+from copy import deepcopy
+from functools import partial
+from gzip import GzipFile
import os
import os.path as op
+
+import numpy as np
from scipy import sparse, linalg
-from copy import deepcopy
from .io.constants import FIFF
from .io.tree import dir_tree_find
@@ -27,7 +30,6 @@ from .surface import (read_surface, _create_surf_spacing, _get_ico_surface,
from .utils import (get_subjects_dir, run_subprocess, has_freesurfer,
has_nibabel, check_fname, logger, verbose,
check_version, _get_call_line, warn)
-from .fixes import in1d, partial, gzip_open, meshgrid
from .parallel import parallel_func, check_n_jobs
from .transforms import (invert_transform, apply_trans, _print_coord_trans,
combine_transforms, _get_trans,
@@ -53,6 +55,13 @@ def _get_lut_id(lut, label, use_lut):
return lut['id'][mask]
+_src_kind_dict = {
+ 'vol': 'volume',
+ 'surf': 'surface',
+ 'discrete': 'discrete',
+}
+
+
class SourceSpaces(list):
"""Represent a list of source space
@@ -84,24 +93,26 @@ class SourceSpaces(list):
ss_repr = []
for ss in self:
ss_type = ss['type']
+ r = _src_kind_dict[ss_type]
if ss_type == 'vol':
if 'seg_name' in ss:
- r = ("'vol' (%s), n_used=%i"
- % (ss['seg_name'], ss['nuse']))
+ r += " (%s)" % (ss['seg_name'],)
else:
- r = ("'vol', shape=%s, n_used=%i"
- % (repr(ss['shape']), ss['nuse']))
+ r += ", shape=%s" % (ss['shape'],)
elif ss_type == 'surf':
- r = "'surf', n_vertices=%i, n_used=%i" % (ss['np'], ss['nuse'])
- else:
- r = "%r" % ss_type
- coord_frame = ss['coord_frame']
- if isinstance(coord_frame, np.ndarray):
- coord_frame = coord_frame[0]
- r += ', coordinate_frame=%s' % _coord_frame_name(coord_frame)
+ r += (" (%s), n_vertices=%i" % (_get_hemi(ss)[0], ss['np']))
+ r += (', n_used=%i, coordinate_frame=%s'
+ % (ss['nuse'], _coord_frame_name(int(ss['coord_frame']))))
ss_repr.append('<%s>' % r)
- ss_repr = ', '.join(ss_repr)
- return "<SourceSpaces: [{ss}]>".format(ss=ss_repr)
+ return "<SourceSpaces: [%s]>" % ', '.join(ss_repr)
+
+ @property
+ def kind(self):
+ """The kind of source space (surface, volume, discrete)"""
+ ss_types = list(set([ss['type'] for ss in self]))
+ if len(ss_types) != 1:
+ return 'combined'
+ return _src_kind_dict[ss_types[0]]
def __add__(self, other):
return SourceSpaces(list.__add__(self, other))
@@ -1174,7 +1185,7 @@ def _read_talxfm(subject, subjects_dir, mode=None, verbose=None):
def setup_source_space(subject, fname=True, spacing='oct6', surface='white',
overwrite=False, subjects_dir=None, add_dist=True,
n_jobs=1, verbose=None):
- """Setup a source space with subsampling
+ """Setup a bilater hemisphere surface-based source space with subsampling
Parameters
----------
@@ -1206,6 +1217,10 @@ def setup_source_space(subject, fname=True, spacing='oct6', surface='white',
-------
src : list
The source space for each hemisphere.
+
+ See Also
+ --------
+ setup_volume_source_space
"""
cmd = ('setup_source_space(%s, fname=%s, spacing=%s, surface=%s, '
'overwrite=%s, subjects_dir=%s, add_dist=%s, verbose=%s)'
@@ -1364,7 +1379,7 @@ def setup_volume_source_space(subject, fname=None, pos=5.0, mri=None,
surface : str | dict | None
Define source space bounds using a FreeSurfer surface file. Can
also be a dictionary with entries `'rr'` and `'tris'`, such as
- those returned by `read_surface()`.
+ those returned by :func:`mne.read_surface`.
mindist : float
Exclude points closer than this distance (mm) to the bounding surface.
exclude : float
@@ -1389,6 +1404,10 @@ def setup_volume_source_space(subject, fname=None, pos=5.0, mri=None,
compatibility reasons, as most functions expect source spaces
to be provided as lists).
+ See Also
+ --------
+ setup_source_space
+
Notes
-----
To create a discrete source space, `pos` must be a dict, 'mri' must be
@@ -1626,9 +1645,9 @@ def _make_volume_source_space(surf, grid, exclude, mindist, mri=None,
ncol = ns[1]
nplane = nrow * ncol
# x varies fastest, then y, then z (can use unravel to do this)
- rr = meshgrid(np.arange(minn[2], maxn[2] + 1),
- np.arange(minn[1], maxn[1] + 1),
- np.arange(minn[0], maxn[0] + 1), indexing='ij')
+ rr = np.meshgrid(np.arange(minn[2], maxn[2] + 1),
+ np.arange(minn[1], maxn[1] + 1),
+ np.arange(minn[0], maxn[0] + 1), indexing='ij')
x, y, z = rr[2].ravel(), rr[1].ravel(), rr[0].ravel()
rr = np.array([x * grid, y * grid, z * grid]).T
sp = dict(np=npts, nn=np.zeros((npts, 3)), rr=rr,
@@ -1786,7 +1805,7 @@ def _make_volume_source_space(surf, grid, exclude, mindist, mri=None,
old_shape = neigh.shape
neigh = neigh.ravel()
checks = np.where(neigh >= 0)[0]
- removes = np.logical_not(in1d(checks, vertno))
+ removes = np.logical_not(np.in1d(checks, vertno))
neigh[checks[removes]] = -1
neigh.shape = old_shape
neigh = neigh.T
@@ -1826,7 +1845,7 @@ def _get_mgz_header(fname):
('delta', '>f4', (3,)), ('Mdc', '>f4', (3, 3)),
('Pxyz_c', '>f4', (3,))]
header_dtype = np.dtype(header_dtd)
- with gzip_open(fname, 'rb') as fid:
+ with GzipFile(fname, 'rb') as fid:
hdr_str = fid.read(header_dtype.itemsize)
header = np.ndarray(shape=(), dtype=header_dtype,
buffer=hdr_str)
@@ -2479,11 +2498,11 @@ def _get_morph_src_reordering(vertices, src_from, subject_from, subject_to,
# some are omitted during fwd calc), so we must do some indexing magic:
# From all vertices, a subset could be chosen by fwd calc:
- used_vertices = in1d(full_mapping, vertices[ii])
+ used_vertices = np.in1d(full_mapping, vertices[ii])
from_vertices.append(src_from[ii]['vertno'][used_vertices])
remaining_mapping = full_mapping[used_vertices]
if not np.array_equal(np.sort(remaining_mapping), vertices[ii]) or \
- not in1d(vertices[ii], full_mapping).all():
+ not np.in1d(vertices[ii], full_mapping).all():
raise RuntimeError('Could not map vertices, perhaps the wrong '
'subject "%s" was provided?' % subject_from)
diff --git a/mne/stats/cluster_level.py b/mne/stats/cluster_level.py
index 620dbea..0e9d787 100755
--- a/mne/stats/cluster_level.py
+++ b/mne/stats/cluster_level.py
@@ -17,7 +17,6 @@ from scipy import sparse
from .parametric import f_oneway
from ..parallel import parallel_func, check_n_jobs
from ..utils import split_list, logger, verbose, ProgressBar, warn
-from ..fixes import in1d, unravel_index
from ..source_estimate import SourceEstimate
@@ -42,8 +41,8 @@ def _get_clusters_spatial(s, neighbors):
ind = t_inds[icount - 1]
# look across other vertices
buddies = np.where(r)[0]
- buddies = buddies[in1d(s[buddies], neighbors[s[ind]],
- assume_unique=True)]
+ buddies = buddies[np.in1d(s[buddies], neighbors[s[ind]],
+ assume_unique=True)]
t_inds += buddies.tolist()
r[buddies] = False
icount += 1
@@ -152,8 +151,8 @@ def _get_clusters_st_multistep(keepers, neighbors, max_step=1):
# look at current time point across other vertices
buddies = inds[t_border[t[ind]]:t_border[t[ind] + 1]]
buddies = buddies[r[buddies]]
- buddies = buddies[in1d(s[buddies], neighbors[s[ind]],
- assume_unique=True)]
+ buddies = buddies[np.in1d(s[buddies], neighbors[s[ind]],
+ assume_unique=True)]
buddies = np.concatenate((selves, buddies))
t_inds += buddies.tolist()
r[buddies] = False
@@ -176,7 +175,7 @@ def _get_clusters_st(x_in, neighbors, max_step=1):
cl_goods = np.where(x_in)[0]
if len(cl_goods) > 0:
keepers = [np.array([], dtype=int)] * n_times
- row, col = unravel_index(cl_goods, (n_times, n_src))
+ row, col = np.unravel_index(cl_goods, (n_times, n_src))
if isinstance(row, int):
row = [row]
col = [col]
@@ -1490,7 +1489,7 @@ def _reshape_clusters(clusters, sample_shape):
if clusters[0].dtype == bool: # format of mask
clusters = [c.reshape(sample_shape) for c in clusters]
else: # format of indices
- clusters = [unravel_index(c, sample_shape) for c in clusters]
+ clusters = [np.unravel_index(c, sample_shape) for c in clusters]
return clusters
@@ -1520,6 +1519,10 @@ def summarize_clusters_stc(clu, p_thresh=0.05, tstep=1e-3, tmin=0,
Returns
-------
out : instance of SourceEstimate
+ A summary of the clusters. The first time point in this SourceEstimate
+ object is the summation of all the clusters. Subsequent time points
+ contain each individual cluster. The magnitude of the activity
+        corresponds to the number of samples the cluster spans in time.
"""
if vertices is None:
vertices = [np.arange(10242), np.arange(10242)]
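A self-contained sketch of the summary described above, on synthetic data small enough to cluster quickly (the effect size and permutation count are illustrative, so significance of the injected cluster is likely but not guaranteed):

import numpy as np
from mne.stats import (spatio_temporal_cluster_1samp_test,
                       summarize_clusters_stc)

rng = np.random.RandomState(0)
X = rng.randn(8, 20, 10)  # subjects x times x vertices
X[:, 5:10, :5] += 2.  # inject an effect so a cluster should survive
clu = spatio_temporal_cluster_1samp_test(X, n_permutations=128)
# time point 0 sums all significant clusters; later time points hold
# each cluster separately, scaled by its extent in time
stc_sum = summarize_clusters_stc(clu, p_thresh=0.05, tstep=1e-3,
                                 vertices=[np.arange(5), np.arange(5)],
                                 subject='sample')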
diff --git a/mne/stats/parametric.py b/mne/stats/parametric.py
index 37a8e5b..49acff1 100644
--- a/mne/stats/parametric.py
+++ b/mne/stats/parametric.py
@@ -9,7 +9,6 @@ from functools import reduce
from string import ascii_uppercase
from ..externals.six import string_types
-from ..fixes import matrix_rank
# The following function is a rewriting of scipy.stats.f_oneway
# Contrary to the scipy.stats.f_oneway implementation it does not
@@ -180,7 +179,7 @@ def _iter_contrasts(n_subjects, factor_levels, effect_picks):
for i_contrast in range(1, n_factors):
this_contrast = contrast_idx[(n_factors - 1) - i_contrast]
c_ = np.kron(c_, sc[i_contrast][this_contrast])
- df1 = matrix_rank(c_)
+ df1 = np.linalg.matrix_rank(c_)
df2 = df1 * (n_subjects - 1)
yield c_, df1, df2
@@ -198,11 +197,13 @@ def f_threshold_mway_rm(n_subjects, factor_levels, effects='A*B',
effects : str
A string denoting the effect to be returned. The following
mapping is currently supported:
- 'A': main effect of A
- 'B': main effect of B
- 'A:B': interaction effect
- 'A+B': both main effects
- 'A*B': all three effects
+
+ * ``'A'``: main effect of A
+ * ``'B'``: main effect of B
+ * ``'A:B'``: interaction effect
+ * ``'A+B'``: both main effects
+ * ``'A*B'``: all three effects
+
pvalue : float
The p-value to be thresholded.
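With the effects mapping now laid out as a list, a quick sketch of how the threshold helper is typically called:

from mne.stats import f_threshold_mway_rm

# F threshold at p < 0.05 for the interaction ('A:B') in a 2 x 3
# repeated-measures design with 15 subjects
f_thresh = f_threshold_mway_rm(n_subjects=15, factor_levels=[2, 3],
                               effects='A:B', pvalue=0.05)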
diff --git a/mne/stats/regression.py b/mne/stats/regression.py
index e747250..80377bc 100644
--- a/mne/stats/regression.py
+++ b/mne/stats/regression.py
@@ -18,7 +18,6 @@ from ..epochs import _BaseEpochs
from ..evoked import Evoked, EvokedArray
from ..utils import logger, _reject_data_segments, warn
from ..io.pick import pick_types, pick_info
-from ..fixes import in1d
def linear_regression(inst, design_matrix, names=None):
@@ -342,7 +341,7 @@ def _prepare_rerp_preds(n_samples, sfreq, events, event_id=None, tmin=-.1,
ids = ([event_id[cond]]
if isinstance(event_id[cond], int)
else event_id[cond])
- onsets = -(events[in1d(events[:, 2], ids), 0] + tmin_)
+ onsets = -(events[np.in1d(events[:, 2], ids), 0] + tmin_)
values = np.ones((len(onsets), n_lags))
else: # for predictors from covariates, e.g. continuous ones
diff --git a/mne/stats/tests/test_cluster_level.py b/mne/stats/tests/test_cluster_level.py
index 6fb3cda..06b5bb7 100644
--- a/mne/stats/tests/test_cluster_level.py
+++ b/mne/stats/tests/test_cluster_level.py
@@ -1,11 +1,13 @@
+from functools import partial
import os
+import warnings
+
import numpy as np
+from scipy import sparse, linalg, stats
from numpy.testing import (assert_equal, assert_array_equal,
assert_array_almost_equal)
from nose.tools import assert_true, assert_raises
-from scipy import sparse, linalg, stats
-from mne.fixes import partial
-import warnings
+
from mne.parallel import _force_serial
from mne.stats.cluster_level import (permutation_cluster_test,
permutation_cluster_1samp_test,
diff --git a/mne/stats/tests/test_regression.py b/mne/stats/tests/test_regression.py
index d9bafce..5b2f9a2 100644
--- a/mne/stats/tests/test_regression.py
+++ b/mne/stats/tests/test_regression.py
@@ -31,16 +31,15 @@ event_fname = data_path + '/MEG/sample/sample_audvis_trunc_raw-eve.fif'
@testing.requires_testing_data
def test_regression():
- """Test Ordinary Least Squares Regression
- """
+ """Test Ordinary Least Squares Regression."""
tmin, tmax = -0.2, 0.5
event_id = dict(aud_l=1, aud_r=2)
# Setup for reading the raw data
- raw = mne.io.read_raw_fif(raw_fname)
+ raw = mne.io.read_raw_fif(raw_fname, add_eeg_ref=False)
events = mne.read_events(event_fname)[:10]
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True,
- baseline=(None, 0))
+ baseline=(None, 0), add_eeg_ref=False)
picks = np.arange(len(epochs.ch_names))
evoked = epochs.average(picks=picks)
design_matrix = epochs.events[:, 1:].astype(np.float64)
@@ -87,10 +86,10 @@ def test_regression():
@testing.requires_testing_data
def test_continuous_regression_no_overlap():
- """Test regression without overlap correction, on real data"""
+ """Test regression without overlap correction, on real data."""
tmin, tmax = -.1, .5
- raw = mne.io.read_raw_fif(raw_fname, preload=True)
+ raw = mne.io.read_raw_fif(raw_fname, preload=True, add_eeg_ref=False)
raw.apply_proj()
events = mne.read_events(event_fname)
event_id = dict(audio_l=1, audio_r=2)
@@ -98,7 +97,7 @@ def test_continuous_regression_no_overlap():
raw = raw.pick_channels(raw.ch_names[:2])
epochs = mne.Epochs(raw, events, event_id, tmin, tmax,
- baseline=None, reject=None)
+ baseline=None, reject=None, add_eeg_ref=False)
revokeds = linear_regression_raw(raw, events, event_id,
tmin=tmin, tmax=tmax,
@@ -122,7 +121,7 @@ def test_continuous_regression_no_overlap():
def test_continuous_regression_with_overlap():
- """Test regression with overlap correction"""
+ """Test regression with overlap correction."""
signal = np.zeros(100000)
times = [1000, 2500, 3000, 5000, 5250, 7000, 7250, 8000]
events = np.zeros((len(times), 3), int)
diff --git a/mne/surface.py b/mne/surface.py
index 45a2abe..32564ad 100644
--- a/mne/surface.py
+++ b/mne/surface.py
@@ -9,6 +9,7 @@ from os import path as op
import sys
from struct import pack
from glob import glob
+from distutils.version import LooseVersion
import numpy as np
from scipy.sparse import coo_matrix, csr_matrix, eye as speye
@@ -18,13 +19,13 @@ from .io.constants import FIFF
from .io.open import fiff_open
from .io.tree import dir_tree_find
from .io.tag import find_tag
-from .io.write import (write_int, start_file, end_block,
- start_block, end_file, write_string,
- write_float_sparse_rcs)
+from .io.write import (write_int, start_file, end_block, start_block, end_file,
+ write_string, write_float_sparse_rcs)
from .channels.channels import _get_meg_system
from .transforms import transform_surface_to
from .utils import logger, verbose, get_subjects_dir, warn
from .externals.six import string_types
+from .fixes import _read_volume_info, _serialize_volume_info
###############################################################################
@@ -406,13 +407,29 @@ def read_curvature(filepath):
@verbose
-def read_surface(fname, verbose=None):
+def read_surface(fname, read_metadata=False, verbose=None):
"""Load a Freesurfer surface mesh in triangular format
Parameters
----------
fname : str
The name of the file containing the surface.
+ read_metadata : bool
+ Read metadata as key-value pairs.
+ Valid keys:
+
+ * 'head' : array of int
+ * 'valid' : str
+ * 'filename' : str
+ * 'volume' : array of int, shape (3,)
+ * 'voxelsize' : array of float, shape (3,)
+ * 'xras' : array of float, shape (3,)
+ * 'yras' : array of float, shape (3,)
+ * 'zras' : array of float, shape (3,)
+ * 'cras' : array of float, shape (3,)
+
+ .. versionadded:: 0.13.0
+
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
@@ -421,13 +438,25 @@ def read_surface(fname, verbose=None):
rr : array, shape=(n_vertices, 3)
Coordinate points.
tris : int array, shape=(n_faces, 3)
- Triangulation (each line contains indexes for three points which
+ Triangulation (each line contains indices for three points which
together form a face).
+ volume_info : dict-like
+ If read_metadata is true, key-value pairs found in the geometry file.
See Also
--------
write_surface
+ read_tri
"""
+ try:
+ import nibabel as nib
+ has_nibabel = True
+ except ImportError:
+ has_nibabel = False
+ if has_nibabel and LooseVersion(nib.__version__) > LooseVersion('2.1.0'):
+ return nib.freesurfer.read_geometry(fname, read_metadata=read_metadata)
+
+ volume_info = dict()
TRIANGLE_MAGIC = 16777214
QUAD_MAGIC = 16777215
NEW_QUAD_MAGIC = 16777213
@@ -462,6 +491,8 @@ def read_surface(fname, verbose=None):
fnum = np.fromfile(fobj, ">i4", 1)[0]
coords = np.fromfile(fobj, ">f4", vnum * 3).reshape(vnum, 3)
faces = np.fromfile(fobj, ">i4", fnum * 3).reshape(fnum, 3)
+ if read_metadata:
+ volume_info = _read_volume_info(fobj)
else:
raise ValueError("%s does not appear to be a Freesurfer surface"
% fname)
@@ -469,19 +500,25 @@ def read_surface(fname, verbose=None):
% (create_stamp.strip(), len(coords), len(faces)))
coords = coords.astype(np.float) # XXX: due to mayavi bug on mac 32bits
- return coords, faces
+
+ ret = (coords, faces)
+ if read_metadata:
+ if len(volume_info) == 0:
+ warn('No volume information contained in the file')
+ ret += (volume_info,)
+ return ret
@verbose
-def _read_surface_geom(fname, patch_stats=True, norm_rr=False, verbose=None):
+def _read_surface_geom(fname, patch_stats=True, norm_rr=False,
+ read_metadata=False, verbose=None):
"""Load the surface as dict, optionally add the geometry information"""
# based on mne_load_surface_geom() in mne_surface_io.c
if isinstance(fname, string_types):
- rr, tris = read_surface(fname) # mne_read_triangle_file()
- nvert = len(rr)
- ntri = len(tris)
- s = dict(rr=rr, tris=tris, use_tris=tris, ntri=ntri,
- np=nvert)
+ ret = read_surface(fname, read_metadata=read_metadata)
+ nvert = len(ret[0])
+ ntri = len(ret[1])
+ s = dict(rr=ret[0], tris=ret[1], use_tris=ret[1], ntri=ntri, np=nvert)
elif isinstance(fname, dict):
s = fname
else:
@@ -490,6 +527,8 @@ def _read_surface_geom(fname, patch_stats=True, norm_rr=False, verbose=None):
s = _complete_surface_info(s)
if norm_rr is True:
_normalize_vectors(s['rr'])
+ if read_metadata:
+ return s, ret[2]
return s
@@ -671,7 +710,7 @@ def _create_surf_spacing(surf, hemi, subject, stype, sval, ico_surf,
return surf
-def write_surface(fname, coords, faces, create_stamp=''):
+def write_surface(fname, coords, faces, create_stamp='', volume_info=None):
"""Write a triangular Freesurfer surface mesh
Accepts the same data format as is returned by read_surface().
@@ -683,16 +722,42 @@ def write_surface(fname, coords, faces, create_stamp=''):
coords : array, shape=(n_vertices, 3)
Coordinate points.
faces : int array, shape=(n_faces, 3)
- Triangulation (each line contains indexes for three points which
+ Triangulation (each line contains indices for three points which
together form a face).
create_stamp : str
Comment that is written to the beginning of the file. Can not contain
line breaks.
+ volume_info : dict-like or None
+ Key-value pairs to encode at the end of the file.
+ Valid keys:
+
+ * 'head' : array of int
+ * 'valid' : str
+ * 'filename' : str
+ * 'volume' : array of int, shape (3,)
+ * 'voxelsize' : array of float, shape (3,)
+ * 'xras' : array of float, shape (3,)
+ * 'yras' : array of float, shape (3,)
+ * 'zras' : array of float, shape (3,)
+ * 'cras' : array of float, shape (3,)
+
+ .. versionadded:: 0.13.0
See Also
--------
read_surface
+ read_tri
"""
+ try:
+ import nibabel as nib
+ has_nibabel = True
+ except ImportError:
+ has_nibabel = False
+ if has_nibabel and LooseVersion(nib.__version__) > LooseVersion('2.1.0'):
+ nib.freesurfer.io.write_geometry(fname, coords, faces,
+ create_stamp=create_stamp,
+ volume_info=volume_info)
+ return
if len(create_stamp.splitlines()) > 1:
raise ValueError("create_stamp can only contain one line")
@@ -707,6 +772,10 @@ def write_surface(fname, coords, faces, create_stamp=''):
fid.write(np.array(coords, dtype='>f4').tostring())
fid.write(np.array(faces, dtype='>i4').tostring())
+ # Add volume info, if given
+ if volume_info is not None and len(volume_info) > 0:
+ fid.write(_serialize_volume_info(volume_info))
+
###############################################################################
# Decimation
@@ -717,6 +786,7 @@ def _decimate_surface(points, triangles, reduction):
os.environ['ETS_TOOLKIT'] = 'null'
try:
from tvtk.api import tvtk
+ from tvtk.common import configure_input
except ImportError:
raise ValueError('This function requires the TVTK package to be '
'installed')
@@ -724,7 +794,8 @@ def _decimate_surface(points, triangles, reduction):
raise ValueError('The triangles refer to undefined points. '
'Please check your mesh.')
src = tvtk.PolyData(points=points, polys=triangles)
- decimate = tvtk.QuadricDecimation(input=src, target_reduction=reduction)
+ decimate = tvtk.QuadricDecimation(target_reduction=reduction)
+ configure_input(decimate, src)
decimate.update()
out = decimate.output
tris = out.polys.to_array()
@@ -1111,3 +1182,61 @@ def mesh_dist(tris, vert):
axis=1))
dist_matrix = csr_matrix((dist, (edges.row, edges.col)), shape=edges.shape)
return dist_matrix
+
+
+@verbose
+def read_tri(fname_in, swap=False, verbose=None):
+ """Function for reading triangle definitions from an ascii file.
+
+ Parameters
+ ----------
+ fname_in : str
+ Path to surface ASCII file (ending with '.tri').
+ swap : bool
+ Assume the ASCII file vertex ordering is clockwise instead of
+ counterclockwise.
+ verbose : bool, str, int, or None
+ If not None, override default verbose level (see mne.verbose).
+
+ Returns
+ -------
+ rr : array, shape=(n_vertices, 3)
+ Coordinate points.
+ tris : int array, shape=(n_faces, 3)
+ Triangulation (each line contains indices for three points which
+ together form a face).
+
+ Notes
+ -----
+ .. versionadded:: 0.13.0
+
+ See Also
+ --------
+ read_surface
+ write_surface
+ """
+ with open(fname_in, "r") as fid:
+ lines = fid.readlines()
+ n_nodes = int(lines[0])
+ n_tris = int(lines[n_nodes + 1])
+ n_items = len(lines[1].split())
+ if n_items in [3, 6, 14, 17]:
+ inds = range(3)
+ elif n_items in [4, 7]:
+ inds = range(1, 4)
+ else:
+ raise IOError('Unrecognized format of data.')
+ rr = np.array([np.array([float(v) for v in l.split()])[inds]
+ for l in lines[1:n_nodes + 1]])
+ tris = np.array([np.array([int(v) for v in l.split()])[inds]
+ for l in lines[n_nodes + 2:n_nodes + 2 + n_tris]])
+ if swap:
+ tris[:, [2, 1]] = tris[:, [1, 2]]
+ tris -= 1
+ logger.info('Loaded surface from %s with %s nodes and %s triangles.' %
+ (fname_in, n_nodes, n_tris))
+ if n_items in [3, 4]:
+ logger.info('Node normals were not included in the source file.')
+ else:
+ warn('Node normals were not read.')
+ return (rr, tris)
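The metadata round-trip added above, together with the new ``read_tri``, sketched on hypothetical file names (any FreeSurfer surface and ASCII '.tri' file will do):

from mne.surface import read_surface, write_surface, read_tri

# round-trip the volume metadata through a FreeSurfer surface file
rr, tris, volume_info = read_surface('lh.white', read_metadata=True)
write_surface('lh_copy.white', rr, tris, create_stamp='edited copy',
              volume_info=volume_info)
# read an ASCII triangle file; swap=True if vertex ordering is clockwise
rr, tris = read_tri('inner_skull.tri', swap=True)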
diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py
index 0c00613..62ff1fa 100644
--- a/mne/tests/test_annotations.py
+++ b/mne/tests/test_annotations.py
@@ -4,13 +4,14 @@
from datetime import datetime
from nose.tools import assert_raises
-from numpy.testing import assert_array_equal
+from numpy.testing import assert_array_equal, assert_array_almost_equal
import os.path as op
import numpy as np
+from mne import create_info
from mne.utils import run_tests_if_main
-from mne.io import Raw, concatenate_raws
+from mne.io import read_raw_fif, RawArray, concatenate_raws
from mne.annotations import Annotations
from mne.datasets import testing
@@ -21,9 +22,9 @@ fif_fname = op.join(data_dir, 'sample_audvis_trunc_raw.fif')
@testing.requires_testing_data
def test_annotations():
"""Test annotation class."""
- raw = Raw(fif_fname)
+ raw = read_raw_fif(fif_fname, add_eeg_ref=False)
onset = np.array(range(10))
- duration = np.ones(10) + raw.first_samp
+ duration = np.ones(10)
description = np.repeat('test', 10)
dt = datetime.utcnow()
meas_date = raw.info['meas_date']
@@ -36,23 +37,31 @@ def test_annotations():
assert_raises(ValueError, Annotations, onset, [duration, 1], description)
# Test combining annotations with concatenate_raws
- annot = Annotations(onset, duration, description, dt)
- sfreq = raw.info['sfreq']
raw2 = raw.copy()
+ orig_time = (meas_date[0] + meas_date[1] * 0.000001 +
+ raw2.first_samp / raw2.info['sfreq'])
+ annot = Annotations(onset, duration, description, orig_time)
raw2.annotations = annot
+ assert_array_equal(raw2.annotations.onset, onset)
concatenate_raws([raw, raw2])
- assert_array_equal(annot.onset, raw.annotations.onset)
+ assert_array_almost_equal(onset + 20., raw.annotations.onset, decimal=2)
assert_array_equal(annot.duration, raw.annotations.duration)
+ assert_array_equal(raw.annotations.description, np.repeat('test', 10))
+
+ # Test combining with RawArray and orig_times
+ data = np.random.randn(2, 1000) * 10e-12
+ sfreq = 100.
+ info = create_info(ch_names=['MEG1', 'MEG2'], ch_types=['grad'] * 2,
+ sfreq=sfreq)
+ info['meas_date'] = 0
+ raws = []
+ for i, fs in enumerate([1000, 100, 12]):
+ raw = RawArray(data.copy(), info, first_samp=fs)
+ ants = Annotations([1., 2.], [.5, .5], 'x', fs / sfreq)
+ raw.annotations = ants
+ raws.append(raw)
+ raw = concatenate_raws(raws)
+ assert_array_equal(raw.annotations.onset, [1., 2., 11., 12., 21., 22.])
- raw2.annotations = Annotations(onset, duration * 2, description, None)
- last_samp = raw.last_samp - 1
- concatenate_raws([raw, raw2])
- onsets = np.concatenate([onset,
- onset + (last_samp - raw.first_samp) / sfreq])
- assert_array_equal(raw.annotations.onset, onsets)
- assert_array_equal(raw.annotations.onset[:10], onset)
- assert_array_equal(raw.annotations.duration[:10], duration)
- assert_array_equal(raw.annotations.duration[10:], duration * 2)
- assert_array_equal(raw.annotations.description, np.repeat('test', 20))
run_tests_if_main()
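The rewritten test checks that annotation onsets are shifted when raws are concatenated; the same behavior in a self-contained sketch (mirroring the RawArray part of the test):

import numpy as np
from mne import create_info, concatenate_raws
from mne.io import RawArray
from mne.annotations import Annotations

info = create_info(ch_names=['MEG1', 'MEG2'], sfreq=100.,
                   ch_types=['grad', 'grad'])
info['meas_date'] = 0
raws = []
for first_samp in (1000, 100, 12):  # three 10-second segments
    raw = RawArray(np.zeros((2, 1000)), info, first_samp=first_samp)
    raw.annotations = Annotations([1., 2.], [.5, .5], 'test',
                                  orig_time=first_samp / 100.)
    raws.append(raw)
raw = concatenate_raws(raws)
print(raw.annotations.onset)  # [1., 2., 11., 12., 21., 22.]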
diff --git a/mne/tests/test_bem.py b/mne/tests/test_bem.py
index 44a730d..dce68f0 100644
--- a/mne/tests/test_bem.py
+++ b/mne/tests/test_bem.py
@@ -3,7 +3,9 @@
# License: BSD 3 clause
from copy import deepcopy
+from os import remove
import os.path as op
+from shutil import copy
import warnings
import numpy as np
@@ -17,12 +19,17 @@ from mne.preprocessing.maxfilter import fit_sphere_to_headshape
from mne.io.constants import FIFF
from mne.transforms import translation
from mne.datasets import testing
-from mne.utils import run_tests_if_main, _TempDir, slow_test, catch_logging
+from mne.utils import (run_tests_if_main, _TempDir, slow_test, catch_logging,
+ requires_freesurfer)
from mne.bem import (_ico_downsample, _get_ico_map, _order_surfaces,
_assert_complete_surface, _assert_inside,
- _check_surface_size, _bem_find_surface)
+ _check_surface_size, _bem_find_surface, make_flash_bem)
+from mne.surface import read_surface
from mne.io import read_info
+import matplotlib
+matplotlib.use('Agg') # for testing don't use X server
+
warnings.simplefilter('always')
fname_raw = op.join(op.dirname(__file__), '..', 'io', 'tests', 'data',
@@ -66,8 +73,7 @@ def _compare_bem_solutions(sol_a, sol_b):
@testing.requires_testing_data
def test_io_bem():
- """Test reading and writing of bem surfaces and solutions
- """
+ """Test reading and writing of bem surfaces and solutions"""
tempdir = _TempDir()
temp_bem = op.join(tempdir, 'temp-bem.fif')
assert_raises(ValueError, read_bem_surfaces, fname_raw)
@@ -237,11 +243,8 @@ def test_fit_sphere_to_headshape():
info = Info(dig=dig, dev_head_t=dev_head_t)
# Degenerate conditions
- with warnings.catch_warnings(record=True) as w:
- assert_raises(ValueError, fit_sphere_to_headshape, info,
- dig_kinds=(FIFF.FIFFV_POINT_HPI,))
- assert_equal(len(w), 1)
- assert_true(w[0].category == DeprecationWarning)
+ assert_raises(ValueError, fit_sphere_to_headshape, info,
+ dig_kinds=(FIFF.FIFFV_POINT_HPI,))
assert_raises(ValueError, fit_sphere_to_headshape, info,
dig_kinds='foo', units='m')
info['dig'][0]['coord_frame'] = FIFF.FIFFV_COORD_DEVICE
@@ -336,4 +339,41 @@ def test_fit_sphere_to_headshape():
assert_raises(TypeError, fit_sphere_to_headshape, 1, units='m')
+@requires_freesurfer
+@testing.requires_testing_data
+def test_make_flash_bem():
+ """Test computing bem from flash images."""
+ import matplotlib.pyplot as plt
+ tmp = _TempDir()
+ bemdir = op.join(subjects_dir, 'sample', 'bem')
+ flash_path = op.join(subjects_dir, 'sample', 'mri', 'flash')
+
+ for surf in ('inner_skull', 'outer_skull', 'outer_skin'):
+ copy(op.join(bemdir, surf + '.surf'), tmp)
+ copy(op.join(bemdir, surf + '.tri'), tmp)
+ copy(op.join(bemdir, 'inner_skull_tmp.tri'), tmp)
+ copy(op.join(bemdir, 'outer_skin_from_testing.surf'), tmp)
+
+ # This function deletes the tri files at the end.
+ try:
+ make_flash_bem('sample', overwrite=True, subjects_dir=subjects_dir,
+ flash_path=flash_path)
+ for surf in ('inner_skull', 'outer_skull', 'outer_skin'):
+ coords, faces = read_surface(op.join(bemdir, surf + '.surf'))
+ surf = 'outer_skin_from_testing' if surf == 'outer_skin' else surf
+ coords_c, faces_c = read_surface(op.join(tmp, surf + '.surf'))
+ assert_equal(0, faces.min())
+ assert_equal(coords.shape[0], faces.max() + 1)
+ assert_allclose(coords, coords_c)
+ assert_allclose(faces, faces_c)
+ finally:
+ for surf in ('inner_skull', 'outer_skull', 'outer_skin'):
+ remove(op.join(bemdir, surf + '.surf')) # delete symlinks
+ copy(op.join(tmp, surf + '.tri'), bemdir) # return deleted tri
+ copy(op.join(tmp, surf + '.surf'), bemdir) # return moved surf
+ copy(op.join(tmp, 'inner_skull_tmp.tri'), bemdir)
+ copy(op.join(tmp, 'outer_skin_from_testing.surf'), bemdir)
+ plt.close('all')
+
+
run_tests_if_main()
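
Editor's note: the new FreeSurfer-gated test drives make_flash_bem end to end. A hedged usage sketch of the same call (paths are placeholders for a FreeSurfer subjects tree containing flash MRI data):

from mne.bem import make_flash_bem

subjects_dir = '/path/to/subjects'  # placeholder SUBJECTS_DIR
flash_path = subjects_dir + '/sample/mri/flash'
make_flash_bem('sample', overwrite=True, subjects_dir=subjects_dir,
               flash_path=flash_path)
# writes inner_skull.surf, outer_skull.surf and outer_skin.surf into
# <subjects_dir>/sample/bem and removes the intermediate .tri files
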
diff --git a/mne/tests/test_chpi.py b/mne/tests/test_chpi.py
index 8dff103..a6ab8c5 100644
--- a/mne/tests/test_chpi.py
+++ b/mne/tests/test_chpi.py
@@ -8,13 +8,13 @@ from numpy.testing import assert_allclose
from nose.tools import assert_raises, assert_equal, assert_true
import warnings
-from mne.io import Raw
+from mne.io import read_raw_fif
from mne.io.constants import FIFF
-from mne.chpi import (get_chpi_positions, _calculate_chpi_positions,
+from mne.chpi import (_calculate_chpi_positions,
head_pos_to_trans_rot_t, read_head_pos,
- write_head_pos, filter_chpi)
+ write_head_pos, filter_chpi, _get_hpi_info)
from mne.fixes import assert_raises_regex
-from mne.transforms import rot_to_quat, quat_to_rot, _angle_between_quats
+from mne.transforms import rot_to_quat, _angle_between_quats
from mne.utils import (run_tests_if_main, _TempDir, slow_test, catch_logging,
requires_version)
from mne.datasets import testing
@@ -36,8 +36,47 @@ warnings.simplefilter('always')
@testing.requires_testing_data
+def test_chpi_adjust():
+ """Test cHPI logging and adjustment."""
+ raw = read_raw_fif(chpi_fif_fname, allow_maxshield='yes',
+ add_eeg_ref=False)
+ with catch_logging() as log:
+ _get_hpi_info(raw.info, adjust=True, verbose='debug')
+
+ # Ran MaxFilter (with -list, -v, -movecomp, etc.), and got:
+ msg = ['HPIFIT: 5 coils digitized in order 5 1 4 3 2',
+ 'HPIFIT: 3 coils accepted: 1 2 4',
+ 'Hpi coil moments (3 5):',
+ '2.08542e-15 -1.52486e-15 -1.53484e-15',
+ '2.14516e-15 2.09608e-15 7.30303e-16',
+ '-3.2318e-16 -4.25666e-16 2.69997e-15',
+ '5.21717e-16 1.28406e-15 1.95335e-15',
+ '1.21199e-15 -1.25801e-19 1.18321e-15',
+ 'HPIFIT errors: 0.3, 0.3, 5.3, 0.4, 3.2 mm.',
+ 'HPI consistency of isotrak and hpifit is OK.',
+ 'HP fitting limits: err = 5.0 mm, gval = 0.980.',
+ 'Using 5 HPI coils: 83 143 203 263 323 Hz', # actually came earlier
+ ]
+
+ log = log.getvalue().splitlines()
+ assert_true(set(log) == set(msg), '\n' + '\n'.join(set(msg) - set(log)))
+
+ # Then took the raw file, did this:
+ raw.info['dig'][5]['r'][2] += 1.
+ # And checked the result in MaxFilter, which changed the logging as:
+ msg = msg[:8] + [
+ 'HPIFIT errors: 0.3, 0.3, 5.3, 999.7, 3.2 mm.',
+ 'Note: HPI coil 3 isotrak is adjusted by 5.3 mm!',
+ 'Note: HPI coil 5 isotrak is adjusted by 3.2 mm!'] + msg[-2:]
+ with catch_logging() as log:
+ _get_hpi_info(raw.info, adjust=True, verbose='debug')
+ log = log.getvalue().splitlines()
+ assert_true(set(log) == set(msg), '\n' + '\n'.join(set(msg) - set(log)))
+
+
+@testing.requires_testing_data
def test_read_write_head_pos():
- """Test reading and writing head position quaternion parameters"""
+ """Test reading and writing head position quaternion parameters."""
tempdir = _TempDir()
temp_name = op.join(tempdir, 'temp.pos')
# This isn't a 100% valid quat matrix but it should be okay for tests
@@ -56,48 +95,23 @@ def test_read_write_head_pos():
assert_raises(IOError, read_head_pos, temp_name + 'foo')
-def test_get_chpi():
- """Test CHPI position computation
- """
- with warnings.catch_warnings(record=True): # deprecation
- trans0, rot0, _, quat0 = get_chpi_positions(hp_fname, return_quat=True)
- assert_allclose(rot0[0], quat_to_rot(quat0[0]))
- trans0, rot0 = trans0[:-1], rot0[:-1]
- raw = Raw(hp_fif_fname)
- with warnings.catch_warnings(record=True): # deprecation
- out = get_chpi_positions(raw)
- trans1, rot1, t1 = out
- trans1, rot1 = trans1[2:], rot1[2:]
- # these will not be exact because they don't use equiv. time points
- assert_allclose(trans0, trans1, atol=1e-5, rtol=1e-1)
- assert_allclose(rot0, rot1, atol=1e-6, rtol=1e-1)
- # run through input checking
- raw_no_chpi = Raw(test_fif_fname)
- with warnings.catch_warnings(record=True): # deprecation
- assert_raises(TypeError, get_chpi_positions, 1)
- assert_raises(ValueError, get_chpi_positions, hp_fname, [1])
- assert_raises(RuntimeError, get_chpi_positions, raw_no_chpi)
- assert_raises(ValueError, get_chpi_positions, raw, t_step='foo')
- assert_raises(IOError, get_chpi_positions, 'foo')
-
-
@testing.requires_testing_data
def test_hpi_info():
- """Test getting HPI info
- """
+ """Test getting HPI info."""
tempdir = _TempDir()
temp_name = op.join(tempdir, 'temp_raw.fif')
for fname in (chpi_fif_fname, sss_fif_fname):
- raw = Raw(fname, allow_maxshield='yes')
+ raw = read_raw_fif(fname, allow_maxshield='yes', add_eeg_ref=False)
assert_true(len(raw.info['hpi_subsystem']) > 0)
raw.save(temp_name, overwrite=True)
- raw_2 = Raw(temp_name, allow_maxshield='yes')
+ raw_2 = read_raw_fif(temp_name, allow_maxshield='yes',
+ add_eeg_ref=False)
assert_equal(len(raw_2.info['hpi_subsystem']),
len(raw.info['hpi_subsystem']))
def _compare_positions(a, b, max_dist=0.003, max_angle=5.):
- """Compare estimated cHPI positions"""
+ """Compare estimated cHPI positions."""
from scipy.interpolate import interp1d
trans, rot, t = a
trans_est, rot_est, t_est = b
@@ -133,17 +147,17 @@ def _compare_positions(a, b, max_dist=0.003, max_angle=5.):
@requires_version('scipy', '0.11')
@requires_version('numpy', '1.7')
def test_calculate_chpi_positions():
- """Test calculation of cHPI positions
- """
+ """Test calculation of cHPI positions."""
trans, rot, t = head_pos_to_trans_rot_t(read_head_pos(pos_fname))
- raw = Raw(chpi_fif_fname, allow_maxshield='yes', preload=True)
+ raw = read_raw_fif(chpi_fif_fname, allow_maxshield='yes', preload=True,
+ add_eeg_ref=False)
t -= raw.first_samp / raw.info['sfreq']
quats = _calculate_chpi_positions(raw, verbose='debug')
trans_est, rot_est, t_est = head_pos_to_trans_rot_t(quats)
_compare_positions((trans, rot, t), (trans_est, rot_est, t_est), 0.003)
# degenerate conditions
- raw_no_chpi = Raw(test_fif_fname)
+ raw_no_chpi = read_raw_fif(test_fif_fname, add_eeg_ref=False)
assert_raises(RuntimeError, _calculate_chpi_positions, raw_no_chpi)
raw_bad = raw.copy()
for d in raw_bad.info['dig']:
@@ -160,8 +174,7 @@ def test_calculate_chpi_positions():
with catch_logging() as log_file:
_calculate_chpi_positions(raw_bad, verbose=True)
# ignore HPI info header and [done] footer
- for line in log_file.getvalue().strip().split('\n')[4:-1]:
- assert_true('0/5 good' in line)
+ assert_true('0/5 good' in log_file.getvalue().strip().split('\n')[-2])
# half the rate cuts off cHPI coils
with warnings.catch_warnings(record=True): # uint cast suggestion
@@ -172,28 +185,32 @@ def test_calculate_chpi_positions():
@testing.requires_testing_data
def test_chpi_subtraction():
- """Test subtraction of cHPI signals"""
- raw = Raw(chpi_fif_fname, allow_maxshield='yes', preload=True)
+ """Test subtraction of cHPI signals."""
+ raw = read_raw_fif(chpi_fif_fname, allow_maxshield='yes', preload=True,
+ add_eeg_ref=False)
+ raw.info['bads'] = ['MEG0111']
with catch_logging() as log:
filter_chpi(raw, include_line=False, verbose=True)
assert_true('5 cHPI' in log.getvalue())
# MaxFilter doesn't do quite as well as our algorithm with the last bit
raw.crop(0, 16, copy=False)
# remove cHPI status chans
- raw_c = Raw(sss_hpisubt_fname).crop(0, 16, copy=False).load_data()
+ raw_c = read_raw_fif(sss_hpisubt_fname,
+ add_eeg_ref=False).crop(0, 16, copy=False).load_data()
raw_c.pick_types(
meg=True, eeg=True, eog=True, ecg=True, stim=True, misc=True)
assert_meg_snr(raw, raw_c, 143, 624)
# Degenerate cases
- raw_nohpi = Raw(test_fif_fname, preload=True)
+ raw_nohpi = read_raw_fif(test_fif_fname, preload=True, add_eeg_ref=False)
assert_raises(RuntimeError, filter_chpi, raw_nohpi)
# When MaxFilter downsamples, like::
# $ maxfilter -nosss -ds 2 -f test_move_anon_raw.fif \
# -o test_move_anon_ds2_raw.fif
# it can strip out some values of info, which we emulate here:
- raw = Raw(chpi_fif_fname, allow_maxshield='yes')
+ raw = read_raw_fif(chpi_fif_fname, allow_maxshield='yes',
+ add_eeg_ref=False)
with warnings.catch_warnings(record=True): # uint cast suggestion
raw = raw.crop(0, 1).load_data().resample(600., npad='auto')
raw.info['buffer_size_sec'] = np.float64(2.)
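
Editor's note: the subtraction test above reduces to a short workflow; a minimal sketch (the file name is a placeholder, and _calculate_chpi_positions is private API, used here only for illustration):

from mne.io import read_raw_fif
from mne.chpi import filter_chpi, _calculate_chpi_positions

raw = read_raw_fif('raw_with_chpi.fif', allow_maxshield='yes',
                   preload=True, add_eeg_ref=False)
quats = _calculate_chpi_positions(raw)  # head position over time
filter_chpi(raw, include_line=False)    # subtract cHPI, keep line freqs
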
diff --git a/mne/tests/test_cov.py b/mne/tests/test_cov.py
index 4ec2601..7eef394 100644
--- a/mne/tests/test_cov.py
+++ b/mne/tests/test_cov.py
@@ -23,7 +23,7 @@ from mne import (read_cov, write_cov, Epochs, merge_events,
compute_covariance, read_evokeds, compute_proj_raw,
pick_channels_cov, pick_channels, pick_types, pick_info,
make_ad_hoc_cov)
-from mne.io import Raw, RawArray
+from mne.io import read_raw_fif, RawArray, read_info
from mne.tests.common import assert_naming, assert_snr
from mne.utils import (_TempDir, slow_test, requires_sklearn_0_15,
run_tests_if_main)
@@ -42,8 +42,53 @@ erm_cov_fname = op.join(base_dir, 'test_erm-cov.fif')
hp_fif_fname = op.join(base_dir, 'test_chpi_raw_sss.fif')
+def test_cov_mismatch():
+ """Test estimation with MEG<->Head mismatch."""
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False).crop(0, 5).load_data()
+ events = find_events(raw, stim_channel='STI 014')
+ raw.pick_channels(raw.ch_names[:5])
+ raw.add_proj([], remove_existing=True)
+ epochs = Epochs(raw, events, None, tmin=-0.2, tmax=0., preload=True,
+ add_eeg_ref=False)
+ for kind in ('shift', 'None'):
+ epochs_2 = epochs.copy()
+ # This should be fine
+ with warnings.catch_warnings(record=True) as w:
+ compute_covariance([epochs, epochs_2])
+ assert_equal(len(w), 0)
+ if kind == 'shift':
+ epochs_2.info['dev_head_t']['trans'][:3, 3] += 0.001
+ else: # None
+ epochs_2.info['dev_head_t'] = None
+ assert_raises(ValueError, compute_covariance, [epochs, epochs_2])
+ assert_equal(len(w), 0)
+ compute_covariance([epochs, epochs_2], on_mismatch='ignore')
+ assert_equal(len(w), 0)
+ compute_covariance([epochs, epochs_2], on_mismatch='warn')
+ assert_raises(ValueError, compute_covariance, epochs,
+ on_mismatch='x')
+ assert_true(any('transform mismatch' in str(ww.message) for ww in w))
+ # This should work
+ epochs.info['dev_head_t'] = None
+ epochs_2.info['dev_head_t'] = None
+ compute_covariance([epochs, epochs_2], method=None)
+
+
+def test_cov_order():
+ """Test covariance ordering."""
+ info = read_info(raw_fname)
+ # add a MEG channel with a low enough index number to affect EEG if
+ # the order is incorrect
+ info['bads'] += ['MEG 0113']
+ ch_names = [info['ch_names'][pick]
+ for pick in pick_types(info, meg=False, eeg=True)]
+ cov = read_cov(cov_fname)
+ # no avg ref present warning
+ prepare_noise_cov(cov, info, ch_names, verbose='error')
+
+
def test_ad_hoc_cov():
- """Test ad hoc cov creation and I/O"""
+ """Test ad hoc cov creation and I/O."""
tempdir = _TempDir()
out_fname = op.join(tempdir, 'test-cov.fif')
evoked = read_evokeds(ave_fname)[0]
@@ -55,7 +100,7 @@ def test_ad_hoc_cov():
def test_io_cov():
- """Test IO for noise covariance matrices"""
+ """Test IO for noise covariance matrices."""
tempdir = _TempDir()
cov = read_cov(cov_fname)
cov['method'] = 'empirical'
@@ -95,15 +140,15 @@ def test_io_cov():
def test_cov_estimation_on_raw():
- """Test estimation from raw (typically empty room)"""
+ """Test estimation from raw (typically empty room)."""
tempdir = _TempDir()
- raw = Raw(raw_fname, preload=True)
+ raw = read_raw_fif(raw_fname, preload=True, add_eeg_ref=False)
cov_mne = read_cov(erm_cov_fname)
# The pure-string uses the more efficient numpy-based method, the
# list gets triaged to compute_covariance (should be equivalent
# but use more memory)
- for method in ('empirical', ['empirical']):
+ for method in (None, ['empirical']): # None is cast to 'empirical'
cov = compute_raw_covariance(raw, tstep=None, method=method)
assert_equal(cov.ch_names, cov_mne.ch_names)
assert_equal(cov.nfree, cov_mne.nfree)
@@ -132,7 +177,8 @@ def test_cov_estimation_on_raw():
cov = compute_raw_covariance(raw_pick, picks=picks, method=method)
assert_snr(cov.data, cov_mne.data[picks][:, picks], 90) # cutoff samps
# make sure we get a warning with too short a segment
- raw_2 = Raw(raw_fname).crop(0, 1, copy=False)
+ raw_2 = read_raw_fif(raw_fname,
+ add_eeg_ref=False).crop(0, 1, copy=False)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
cov = compute_raw_covariance(raw_2, method=method)
@@ -149,8 +195,8 @@ def test_cov_estimation_on_raw():
@slow_test
@requires_sklearn_0_15
def test_cov_estimation_on_raw_reg():
- """Test estimation from raw with regularization"""
- raw = Raw(raw_fname, preload=True)
+ """Test estimation from raw with regularization."""
+ raw = read_raw_fif(raw_fname, preload=True, add_eeg_ref=False)
raw.info['sfreq'] /= 10.
raw = RawArray(raw._data[:, ::10].copy(), raw.info) # decimate for speed
cov_mne = read_cov(erm_cov_fname)
@@ -164,9 +210,10 @@ def test_cov_estimation_on_raw_reg():
@slow_test
def test_cov_estimation_with_triggers():
- """Test estimation from raw with triggers"""
+ """Test estimation from raw with triggers."""
tempdir = _TempDir()
- raw = Raw(raw_fname, preload=False)
+ raw = read_raw_fif(raw_fname, preload=False, add_eeg_ref=False)
+ raw.set_eeg_reference()
events = find_events(raw, stim_channel='STI 014')
event_ids = [1, 2, 3, 4]
reject = dict(grad=10000e-13, mag=4e-12, eeg=80e-6, eog=150e-6)
@@ -175,7 +222,7 @@ def test_cov_estimation_with_triggers():
events_merged = merge_events(events, event_ids, 1234)
epochs = Epochs(raw, events_merged, 1234, tmin=-0.2, tmax=0,
baseline=(-0.2, -0.1), proj=True,
- reject=reject, preload=True)
+ reject=reject, preload=True, add_eeg_ref=False)
cov = compute_covariance(epochs, keep_sample_mean=True)
cov_mne = read_cov(cov_km_fname)
@@ -191,7 +238,8 @@ def test_cov_estimation_with_triggers():
# cov using a list of epochs and keep_sample_mean=True
epochs = [Epochs(raw, events, ev_id, tmin=-0.2, tmax=0,
- baseline=(-0.2, -0.1), proj=True, reject=reject)
+ baseline=(-0.2, -0.1), proj=True, reject=reject,
+ add_eeg_ref=False)
for ev_id in event_ids]
cov2 = compute_covariance(epochs, keep_sample_mean=True)
@@ -222,9 +270,11 @@ def test_cov_estimation_with_triggers():
# cov with list of epochs with different projectors
epochs = [Epochs(raw, events[:4], event_ids[0], tmin=-0.2, tmax=0,
- baseline=(-0.2, -0.1), proj=True, reject=reject),
+ baseline=(-0.2, -0.1), proj=True, reject=reject,
+ add_eeg_ref=False),
Epochs(raw, events[:4], event_ids[0], tmin=-0.2, tmax=0,
- baseline=(-0.2, -0.1), proj=False, reject=reject)]
+ baseline=(-0.2, -0.1), proj=False, reject=reject,
+ add_eeg_ref=False)]
# these should fail
assert_raises(ValueError, compute_covariance, epochs)
assert_raises(ValueError, compute_covariance, epochs, projs=None)
@@ -237,12 +287,13 @@ def test_cov_estimation_with_triggers():
# test new dict support
epochs = Epochs(raw, events, dict(a=1, b=2, c=3, d=4), tmin=-0.2, tmax=0,
- baseline=(-0.2, -0.1), proj=True, reject=reject)
+ baseline=(-0.2, -0.1), proj=True, reject=reject,
+ add_eeg_ref=False)
compute_covariance(epochs)
def test_arithmetic_cov():
- """Test arithmetic with noise covariance matrices"""
+ """Test arithmetic with noise covariance matrices."""
cov = read_cov(cov_fname)
cov_sum = cov + cov
assert_array_almost_equal(2 * cov.nfree, cov_sum.nfree)
@@ -256,8 +307,8 @@ def test_arithmetic_cov():
def test_regularize_cov():
- """Test cov regularization"""
- raw = Raw(raw_fname, preload=False)
+ """Test cov regularization."""
+ raw = read_raw_fif(raw_fname, preload=False, add_eeg_ref=False)
raw.info['bads'].append(raw.ch_names[0]) # test with bad channels
noise_cov = read_cov(cov_fname)
# Regularize noise cov
@@ -269,8 +320,8 @@ def test_regularize_cov():
assert_true(np.mean(noise_cov['data'] < reg_noise_cov['data']) < 0.08)
-def test_evoked_whiten():
- """Test whitening of evoked data"""
+def test_whiten_evoked():
+ """Test whitening of evoked data."""
evoked = read_evokeds(ave_fname, condition=0, baseline=(None, 0),
proj=True)
cov = read_cov(cov_fname)
@@ -289,10 +340,14 @@ def test_evoked_whiten():
assert_true(np.all(mean_baseline < 1.))
assert_true(np.all(mean_baseline > 0.2))
+ # degenerate
+ cov_bad = pick_channels_cov(cov, include=evoked.ch_names[:10])
+ assert_raises(RuntimeError, whiten_evoked, evoked, cov_bad, picks)
+
@slow_test
def test_rank():
- """Test cov rank estimation"""
+ """Test cov rank estimation."""
# Test that our rank estimation works properly on a simple case
evoked = read_evokeds(ave_fname, condition=0, baseline=(None, 0),
proj=False)
@@ -304,9 +359,9 @@ def test_rank():
assert_true((cov['eig'][1:] > 0).all()) # all else should be > 0
# Now do some more comprehensive tests
- raw_sample = Raw(raw_fname)
+ raw_sample = read_raw_fif(raw_fname, add_eeg_ref=False)
- raw_sss = Raw(hp_fif_fname)
+ raw_sss = read_raw_fif(hp_fif_fname, add_eeg_ref=False)
raw_sss.add_proj(compute_proj_raw(raw_sss))
cov_sample = compute_raw_covariance(raw_sample)
@@ -411,8 +466,7 @@ def test_cov_scaling():
@requires_sklearn_0_15
def test_auto_low_rank():
- """Test probabilistic low rank estimators"""
-
+ """Test probabilistic low rank estimators."""
n_samples, n_features, rank = 400, 20, 10
sigma = 0.1
@@ -460,9 +514,8 @@ def test_auto_low_rank():
@slow_test
@requires_sklearn_0_15
def test_compute_covariance_auto_reg():
- """Test automated regularization"""
-
- raw = Raw(raw_fname, preload=True)
+ """Test automated regularization."""
+ raw = read_raw_fif(raw_fname, preload=True, add_eeg_ref=False)
raw.resample(100, npad='auto') # much faster estimation
events = find_events(raw, stim_channel='STI 014')
event_ids = [1, 2, 3, 4]
@@ -476,7 +529,8 @@ def test_compute_covariance_auto_reg():
raw.info.normalize_proj()
epochs = Epochs(
raw, events_merged, 1234, tmin=-0.2, tmax=0,
- baseline=(-0.2, -0.1), proj=True, reject=reject, preload=True)
+ baseline=(-0.2, -0.1), proj=True, reject=reject, preload=True,
+ add_eeg_ref=False)
epochs = epochs.crop(None, 0)[:10]
method_params = dict(factor_analysis=dict(iter_n_components=[3]),
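
Editor's note: test_cov_mismatch above exercises the new on_mismatch parameter of compute_covariance. A small sketch with synthetic epochs (channel names and values are arbitrary, and it assumes create_info leaves an identity dev_head_t in info):

import numpy as np
from mne import create_info, EpochsArray, compute_covariance

info = create_info(['MEG1', 'MEG2'], 1000., ['grad', 'grad'])
data = np.random.RandomState(0).randn(5, 2, 100) * 1e-12
events = np.column_stack([np.arange(5) * 1000,
                          np.zeros(5, int), np.ones(5, int)])
epochs_a = EpochsArray(data, info, events, tmin=-0.1)
epochs_b = epochs_a.copy()
# assumes create_info left an identity dev_head_t we can perturb
epochs_b.info['dev_head_t']['trans'][:3, 3] += 0.001  # 1 mm shift

# the default (on_mismatch='raise') raises ValueError here;
# 'warn' and 'ignore' let the estimate proceed
cov = compute_covariance([epochs_a, epochs_b], on_mismatch='warn')
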
diff --git a/mne/tests/test_dipole.py b/mne/tests/test_dipole.py
index 6555e64..1db01a6 100644
--- a/mne/tests/test_dipole.py
+++ b/mne/tests/test_dipole.py
@@ -12,14 +12,15 @@ from mne import (read_dipole, read_forward_solution,
SourceEstimate, write_evokeds, fit_dipole,
transform_surface_to, make_sphere_model, pick_types,
pick_info, EvokedArray, read_source_spaces, make_ad_hoc_cov,
- make_forward_solution, Dipole, DipoleFixed)
+ make_forward_solution, Dipole, DipoleFixed, Epochs,
+ make_fixed_length_events)
from mne.simulation import simulate_evoked
from mne.datasets import testing
from mne.utils import (run_tests_if_main, _TempDir, slow_test, requires_mne,
run_subprocess)
from mne.proj import make_eeg_average_ref_proj
-from mne.io import Raw
+from mne.io import read_raw_fif, read_raw_ctf
from mne.surface import _compute_nearest
from mne.bem import _bem_find_surface, read_bem_solution
@@ -39,12 +40,15 @@ fname_trans = op.join(data_path, 'MEG', 'sample',
'sample_audvis_trunc-trans.fif')
fname_fwd = op.join(data_path, 'MEG', 'sample',
'sample_audvis_trunc-meg-eeg-oct-6-fwd.fif')
-fname_xfit_dip = op.join(data_path, 'misc', 'fam_115_LH.fif')
+fname_xfit_dip = op.join(data_path, 'dip', 'fixed_auto.fif')
+fname_xfit_dip_txt = op.join(data_path, 'dip', 'fixed_auto.dip')
+fname_xfit_seq_txt = op.join(data_path, 'dip', 'sequential.dip')
+fname_ctf = op.join(data_path, 'CTF', 'testdata_ctf_short.ds')
subjects_dir = op.join(data_path, 'subjects')
def _compare_dipoles(orig, new):
- """Compare dipole results for equivalence"""
+ """Compare dipole results for equivalence."""
assert_allclose(orig.times, new.times, atol=1e-3, err_msg='times')
assert_allclose(orig.pos, new.pos, err_msg='pos')
assert_allclose(orig.amplitude, new.amplitude, err_msg='amplitude')
@@ -54,6 +58,7 @@ def _compare_dipoles(orig, new):
def _check_dipole(dip, n_dipoles):
+ """Check dipole sizes."""
assert_equal(len(dip), n_dipoles)
assert_equal(dip.pos.shape, (n_dipoles, 3))
assert_equal(dip.ori.shape, (n_dipoles, 3))
@@ -63,7 +68,7 @@ def _check_dipole(dip, n_dipoles):
@testing.requires_testing_data
def test_io_dipoles():
- """Test IO for .dip files"""
+ """Test IO for .dip files."""
tempdir = _TempDir()
dipole = read_dipole(fname_dip)
print(dipole) # test repr
@@ -73,11 +78,27 @@ def test_io_dipoles():
_compare_dipoles(dipole, dipole_new)
+@testing.requires_testing_data
+def test_dipole_fitting_ctf():
+ """Test dipole fitting with CTF data."""
+ raw_ctf = read_raw_ctf(fname_ctf).set_eeg_reference()
+ events = make_fixed_length_events(raw_ctf, 1)
+ evoked = Epochs(raw_ctf, events, 1, 0, 0, baseline=None,
+ add_eeg_ref=False).average()
+ cov = make_ad_hoc_cov(evoked.info)
+ sphere = make_sphere_model((0., 0., 0.))
+ # XXX Eventually we should do some better checks about accuracy, but
+ # for now our CTF phantom fitting tutorials will have to do
+ # (otherwise we need to add that to the testing dataset, which is
+ # a bit too big)
+ fit_dipole(evoked, cov, sphere)
+
+
@slow_test
@testing.requires_testing_data
@requires_mne
def test_dipole_fitting():
- """Test dipole fitting"""
+ """Test dipole fitting."""
amp = 10e-9
tempdir = _TempDir()
rng = np.random.RandomState(0)
@@ -160,13 +181,13 @@ def test_dipole_fitting():
@testing.requires_testing_data
def test_dipole_fitting_fixed():
- """Test dipole fitting with a fixed position"""
+ """Test dipole fitting with a fixed position."""
tpeak = 0.073
sphere = make_sphere_model(head_radius=0.1)
evoked = read_evokeds(fname_evo, baseline=(None, 0))[0]
- evoked.pick_types(meg=True, copy=False)
+ evoked.pick_types(meg=True)
t_idx = np.argmin(np.abs(tpeak - evoked.times))
- evoked_crop = evoked.copy().crop(tpeak, tpeak, copy=False)
+ evoked_crop = evoked.copy().crop(tpeak, tpeak)
assert_equal(len(evoked_crop.times), 1)
cov = read_cov(fname_cov)
dip_seq, resid = fit_dipole(evoked_crop, cov, sphere)
@@ -204,7 +225,7 @@ def test_dipole_fitting_fixed():
@testing.requires_testing_data
def test_len_index_dipoles():
- """Test len and indexing of Dipole objects"""
+ """Test len and indexing of Dipole objects."""
dipole = read_dipole(fname_dip)
d0 = dipole[0]
d1 = dipole[:1]
@@ -220,9 +241,9 @@ def test_len_index_dipoles():
@testing.requires_testing_data
def test_min_distance_fit_dipole():
- """Test dipole min_dist to inner_skull"""
+ """Test dipole min_dist to inner_skull."""
subject = 'sample'
- raw = Raw(fname_raw, preload=True)
+ raw = read_raw_fif(fname_raw, preload=True, add_eeg_ref=False)
# select eeg data
picks = pick_types(raw.info, meg=False, eeg=True, exclude='bads')
@@ -242,7 +263,8 @@ def test_min_distance_fit_dipole():
min_dist = 5. # distance in mm
- dip, residual = fit_dipole(evoked, cov, fname_bem, fname_trans,
+ bem = read_bem_solution(fname_bem)
+ dip, residual = fit_dipole(evoked, cov, bem, fname_trans,
min_dist=min_dist)
dist = _compute_depth(dip, fname_bem, fname_trans, subject, subjects_dir)
@@ -255,7 +277,7 @@ def test_min_distance_fit_dipole():
def _compute_depth(dip, fname_bem, fname_trans, subject, subjects_dir):
- """Compute dipole depth"""
+ """Compute dipole depth."""
trans = _get_trans(fname_trans)[0]
bem = read_bem_solution(fname_bem)
surf = _bem_find_surface(bem, 'inner_skull')
@@ -267,7 +289,7 @@ def _compute_depth(dip, fname_bem, fname_trans, subject, subjects_dir):
@testing.requires_testing_data
def test_accuracy():
- """Test dipole fitting to sub-mm accuracy"""
+ """Test dipole fitting to sub-mm accuracy."""
evoked = read_evokeds(fname_evo)[0].crop(0., 0.,)
evoked.pick_types(meg=True, eeg=False)
evoked.pick_channels([c for c in evoked.ch_names[::4]])
@@ -310,13 +332,21 @@ def test_accuracy():
@testing.requires_testing_data
def test_dipole_fixed():
- """Test reading a fixed-position dipole (from Xfit)"""
+ """Test reading a fixed-position dipole (from Xfit)."""
dip = read_dipole(fname_xfit_dip)
_check_roundtrip_fixed(dip)
+ with warnings.catch_warnings(record=True) as w: # unused fields
+ dip_txt = read_dipole(fname_xfit_dip_txt)
+ assert_true(any('extra fields' in str(ww.message) for ww in w))
+ assert_allclose(dip.info['chs'][0]['loc'][:3], dip_txt.pos[0])
+ assert_allclose(dip_txt.amplitude[0], 12.1e-9)
+ with warnings.catch_warnings(record=True): # unused fields
+ dip_txt_seq = read_dipole(fname_xfit_seq_txt)
+ assert_allclose(dip_txt_seq.gof, [27.3, 46.4, 43.7, 41., 37.3, 32.5])
def _check_roundtrip_fixed(dip):
- """Helper to test roundtrip IO for fixed dipoles"""
+ """Helper to test roundtrip IO for fixed dipoles."""
tempdir = _TempDir()
dip.save(op.join(tempdir, 'test-dip.fif.gz'))
dip_read = read_dipole(op.join(tempdir, 'test-dip.fif.gz'))
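
Editor's note: the new CTF test shows the minimal sphere-model fitting path. A hedged sketch of the same steps (the .ds path is a placeholder):

from mne import (Epochs, make_fixed_length_events, make_ad_hoc_cov,
                 make_sphere_model, fit_dipole)
from mne.io import read_raw_ctf

raw = read_raw_ctf('/path/to/recording.ds').set_eeg_reference()
events = make_fixed_length_events(raw, 1)      # dummy events, id=1
evoked = Epochs(raw, events, 1, 0, 0, baseline=None,
                add_eeg_ref=False).average()   # single-sample epochs
cov = make_ad_hoc_cov(evoked.info)
sphere = make_sphere_model((0., 0., 0.))       # head-centered sphere
dip, residual = fit_dipole(evoked, cov, sphere)
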
diff --git a/mne/tests/test_docstring_parameters.py b/mne/tests/test_docstring_parameters.py
index 5d1bb5c..d38fcdf 100644
--- a/mne/tests/test_docstring_parameters.py
+++ b/mne/tests/test_docstring_parameters.py
@@ -1,19 +1,16 @@
-# TODO inspect for Cython (see sagenb.misc.sageinspect)
from __future__ import print_function
-from nose.plugins.skip import SkipTest
from nose.tools import assert_true
-from os import path as op
import sys
import inspect
import warnings
-import imp
from pkgutil import walk_packages
from inspect import getsource
import mne
-from mne.utils import run_tests_if_main
+from mne.utils import (run_tests_if_main, _doc_special_members,
+ requires_numpydoc)
from mne.fixes import _get_args
public_modules = [
@@ -43,13 +40,6 @@ public_modules = [
'mne.viz',
]
-docscrape_path = op.join(op.dirname(__file__), '..', '..', 'doc', 'sphinxext',
- 'numpy_ext', 'docscrape.py')
-if op.isfile(docscrape_path):
- docscrape = imp.load_source('docscrape', docscrape_path)
-else:
- docscrape = None
-
def get_name(func):
parts = []
@@ -65,17 +55,17 @@ def get_name(func):
# functions to ignore args / docstring of
_docstring_ignores = [
'mne.io.write', # always ignore these
- 'mne.fixes._in1d', # fix function
- 'mne.epochs.average_movements', # deprecated pos param
+ 'mne.decoding.csp.CSP.fit', # deprecated epochs_data
+ 'mne.decoding.csp.CSP.transform' # deprecated epochs_data
]
_tab_ignores = [
- 'mne.channels.tests.test_montage', # demo data has a tab
]
def check_parameters_match(func, doc=None):
"""Helper to check docstring, returns list of incorrect results"""
+ from numpydoc import docscrape
incorrect = []
name_ = get_name(func)
if not name_.startswith('mne.') or name_.startswith('mne.externals'):
@@ -110,10 +100,10 @@ def check_parameters_match(func, doc=None):
return incorrect
+@requires_numpydoc
def test_docstring_parameters():
"""Test module docsting formatting"""
- if docscrape is None:
- raise SkipTest('This must be run from the mne-python source directory')
+ from numpydoc import docscrape
incorrect = []
for name in public_modules:
module = __import__(name, globals())
@@ -121,7 +111,7 @@ def test_docstring_parameters():
module = getattr(module, submod)
classes = inspect.getmembers(module, inspect.isclass)
for cname, cls in classes:
- if cname.startswith('_'):
+ if cname.startswith('_') and cname not in _doc_special_members:
continue
with warnings.catch_warnings(record=True) as w:
cdoc = docscrape.ClassDoc(cls)
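
Editor's note: the parameter check now leans on numpydoc rather than the vendored docscrape copy. A small sketch of the underlying comparison (the example function is hypothetical):

from numpydoc import docscrape
from mne.fixes import _get_args

def checked(a, b=1):
    """Add two numbers.

    Parameters
    ----------
    a : int
        First input.
    b : int
        Second input.
    """
    return a + b

doc = docscrape.FunctionDoc(checked)
doc_params = [p[0] for p in doc['Parameters']]  # names from the docstring
assert doc_params == _get_args(checked)         # must match the signature
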
diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py
index 9803d09..47259fa 100644
--- a/mne/tests/test_epochs.py
+++ b/mne/tests/test_epochs.py
@@ -19,17 +19,18 @@ import matplotlib
from mne import (Epochs, Annotations, read_events, pick_events, read_epochs,
equalize_channels, pick_types, pick_channels, read_evokeds,
- write_evokeds, create_info, make_fixed_length_events)
+ write_evokeds, create_info, make_fixed_length_events,
+ combine_evoked)
+from mne.baseline import rescale
from mne.preprocessing import maxwell_filter
from mne.epochs import (
bootstrap, equalize_epoch_counts, combine_event_ids, add_channels_epochs,
EpochsArray, concatenate_epochs, _BaseEpochs, average_movements)
from mne.utils import (_TempDir, requires_pandas, slow_test,
- clean_warning_registry, run_tests_if_main,
- requires_version)
+ run_tests_if_main, requires_version)
from mne.chpi import read_head_pos, head_pos_to_trans_rot_t
-from mne.io import RawArray, Raw
+from mne.io import RawArray, read_raw_fif
from mne.io.proj import _has_eeg_average_ref_proj
from mne.event import merge_events
from mne.io.constants import FIFF
@@ -59,7 +60,8 @@ rng = np.random.RandomState(42)
def _get_data(preload=False):
- raw = Raw(raw_fname, preload=preload, add_eeg_ref=False, proj=False)
+ """Get data."""
+ raw = read_raw_fif(raw_fname, preload=preload, add_eeg_ref=False)
events = read_events(event_name)
picks = pick_types(raw.info, meg=True, eeg=True, stim=True,
ecg=True, eog=True, include=['STI 014'],
@@ -69,18 +71,38 @@ def _get_data(preload=False):
reject = dict(grad=1000e-12, mag=4e-12, eeg=80e-6, eog=150e-6)
flat = dict(grad=1e-15, mag=1e-15)
-clean_warning_registry() # really clean warning stack
+
+def test_hierarchical():
+ """Test hierarchical access."""
+ raw, events, picks = _get_data()
+ event_id = {'a/1': 1, 'a/2': 2, 'b/1': 3, 'b/2': 4}
+ epochs = Epochs(raw, events, event_id, add_eeg_ref=False, preload=True)
+ epochs_a1 = epochs['a/1']
+ epochs_a2 = epochs['a/2']
+ epochs_b1 = epochs['b/1']
+ epochs_b2 = epochs['b/2']
+ epochs_a = epochs['a']
+ assert_equal(len(epochs_a), len(epochs_a1) + len(epochs_a2))
+ epochs_b = epochs['b']
+ assert_equal(len(epochs_b), len(epochs_b1) + len(epochs_b2))
+ epochs_1 = epochs['1']
+ assert_equal(len(epochs_1), len(epochs_a1) + len(epochs_b1))
+ epochs_2 = epochs['2']
+ assert_equal(len(epochs_2), len(epochs_a2) + len(epochs_b2))
+ epochs_all = epochs[('1', '2')]
+ assert_equal(len(epochs), len(epochs_all))
+ assert_array_equal(epochs.get_data(), epochs_all.get_data())
@slow_test
@testing.requires_testing_data
def test_average_movements():
- """Test movement averaging algorithm
- """
+ """Test movement averaging algorithm."""
# usable data
crop = 0., 10.
origin = (0., 0., 0.04)
- raw = Raw(fname_raw_move, allow_maxshield='yes')
+ raw = read_raw_fif(fname_raw_move, allow_maxshield='yes',
+ add_eeg_ref=False)
raw.info['bads'] += ['MEG2443'] # mark some bad MEG channel
raw.crop(*crop, copy=False).load_data()
raw.filter(None, 20, method='iir')
@@ -88,14 +110,14 @@ def test_average_movements():
picks = pick_types(raw.info, meg=True, eeg=True, stim=True,
ecg=True, eog=True, exclude=())
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks, proj=False,
- preload=True)
+ preload=True, add_eeg_ref=False)
epochs_proj = Epochs(raw, events[:1], event_id, tmin, tmax, picks=picks,
- proj=True, preload=True)
+ proj=True, preload=True, add_eeg_ref=False)
raw_sss_stat = maxwell_filter(raw, origin=origin, regularize=None,
bad_condition='ignore')
del raw
epochs_sss_stat = Epochs(raw_sss_stat, events, event_id, tmin, tmax,
- picks=picks, proj=False)
+ picks=picks, proj=False, add_eeg_ref=False)
evoked_sss_stat = epochs_sss_stat.average()
del raw_sss_stat, epochs_sss_stat
head_pos = read_head_pos(fname_raw_move_pos)
@@ -106,11 +128,8 @@ def test_average_movements():
# SSS-based
assert_raises(TypeError, average_movements, epochs, None)
- with warnings.catch_warnings(record=True) as w:
- warnings.simplefilter('always') # deprecated param, pos -> head_pos
- evoked_move_non = average_movements(epochs, pos=head_pos,
- weight_all=False, origin=origin)
- assert_equal(len(w), 1)
+ evoked_move_non = average_movements(epochs, head_pos=head_pos,
+ weight_all=False, origin=origin)
evoked_move_all = average_movements(epochs, head_pos=head_pos,
weight_all=True, origin=origin)
evoked_stat_all = average_movements(epochs, head_pos=head_pos_stat,
@@ -126,15 +145,16 @@ def test_average_movements():
ev, evoked_std, 1., 1.)
meg_picks = pick_types(evoked_std.info, meg=True, exclude=())
assert_allclose(evoked_move_non.data[meg_picks],
- evoked_move_all.data[meg_picks])
+ evoked_move_all.data[meg_picks], atol=1e-20)
# compare to averaged movecomp version (should be fairly similar)
- raw_sss = Raw(fname_raw_movecomp_sss).crop(*crop, copy=False).load_data()
+ raw_sss = read_raw_fif(fname_raw_movecomp_sss, add_eeg_ref=False)
+ raw_sss.crop(*crop, copy=False).load_data()
raw_sss.filter(None, 20, method='iir')
picks_sss = pick_types(raw_sss.info, meg=True, eeg=True, stim=True,
ecg=True, eog=True, exclude=())
assert_array_equal(picks, picks_sss)
epochs_sss = Epochs(raw_sss, events, event_id, tmin, tmax,
- picks=picks_sss, proj=False)
+ picks=picks_sss, proj=False, add_eeg_ref=False)
evoked_sss = epochs_sss.average()
assert_equal(evoked_std.nave, evoked_sss.nave)
# this should break the non-MEG channels
@@ -155,7 +175,8 @@ def test_average_movements():
destination['trans'])
evoked_miss = average_movements(epochs, head_pos=head_pos[2:],
origin=origin, destination=destination)
- assert_allclose(evoked_miss.data, evoked_move_all.data)
+ assert_allclose(evoked_miss.data, evoked_move_all.data,
+ atol=1e-20)
assert_allclose(evoked_miss.info['dev_head_t']['trans'],
destination['trans'])
@@ -166,14 +187,10 @@ def test_average_movements():
assert_raises(TypeError, average_movements, 'foo', head_pos=head_pos)
assert_raises(RuntimeError, average_movements, epochs_proj,
head_pos=head_pos) # prj
- epochs.info['comps'].append([0])
- assert_raises(RuntimeError, average_movements, epochs, head_pos=head_pos)
- epochs.info['comps'].pop()
def test_reject():
- """Test epochs rejection
- """
+ """Test epochs rejection."""
raw, events, picks = _get_data()
# cull the list just to contain the relevant event
events = events[events[:, 2] == event_id, :]
@@ -182,19 +199,22 @@ def test_reject():
assert_raises(TypeError, pick_types, raw)
picks_meg = pick_types(raw.info, meg=True, eeg=False)
assert_raises(TypeError, Epochs, raw, events, event_id, tmin, tmax,
- picks=picks, preload=False, reject='foo')
+ picks=picks, preload=False, reject='foo',
+ add_eeg_ref=False)
assert_raises(ValueError, Epochs, raw, events, event_id, tmin, tmax,
- picks=picks_meg, preload=False, reject=dict(eeg=1.))
+ picks=picks_meg, preload=False, reject=dict(eeg=1.),
+ add_eeg_ref=False)
# this one is okay because it's not actually requesting rejection
Epochs(raw, events, event_id, tmin, tmax, picks=picks_meg,
- preload=False, reject=dict(eeg=np.inf))
+ preload=False, reject=dict(eeg=np.inf), add_eeg_ref=False)
for val in (None, -1): # protect against older MNE-C types
for kwarg in ('reject', 'flat'):
assert_raises(ValueError, Epochs, raw, events, event_id,
tmin, tmax, picks=picks_meg, preload=False,
- **{kwarg: dict(grad=val)})
+ add_eeg_ref=False, **{kwarg: dict(grad=val)})
assert_raises(KeyError, Epochs, raw, events, event_id, tmin, tmax,
- picks=picks, preload=False, reject=dict(foo=1.))
+ picks=picks, preload=False, reject=dict(foo=1.),
+ add_eeg_ref=False)
data_7 = dict()
keep_idx = [0, 1, 2]
@@ -202,7 +222,7 @@ def test_reject():
for proj in (True, False, 'delayed'):
# no rejection
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- preload=preload)
+ preload=preload, add_eeg_ref=False)
assert_raises(ValueError, epochs.drop_bad, reject='foo')
epochs.drop_bad()
assert_equal(len(epochs), len(events))
@@ -214,7 +234,7 @@ def test_reject():
# with rejection
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- reject=reject, preload=preload)
+ reject=reject, preload=preload, add_eeg_ref=False)
epochs.drop_bad()
assert_equal(len(epochs), len(events) - 4)
assert_array_equal(epochs.selection, selection)
@@ -223,7 +243,7 @@ def test_reject():
# rejection post-hoc
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- preload=preload)
+ preload=preload, add_eeg_ref=False)
epochs.drop_bad()
assert_equal(len(epochs), len(events))
assert_array_equal(epochs.get_data(), data_7[proj])
@@ -237,7 +257,8 @@ def test_reject():
# rejection twice
reject_part = dict(grad=1100e-12, mag=4e-12, eeg=80e-6, eog=150e-6)
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- reject=reject_part, preload=preload)
+ reject=reject_part, preload=preload,
+ add_eeg_ref=False)
epochs.drop_bad()
assert_equal(len(epochs), len(events) - 1)
epochs.drop_bad(reject)
@@ -258,17 +279,19 @@ def test_reject():
# rejection of subset of trials (ensure array ownership)
reject_part = dict(grad=1100e-12, mag=4e-12, eeg=80e-6, eog=150e-6)
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- reject=None, preload=preload)
+ reject=None, preload=preload,
+ add_eeg_ref=False)
epochs = epochs[:-1]
epochs.drop_bad(reject=reject)
assert_equal(len(epochs), len(events) - 4)
assert_array_equal(epochs.get_data(), data_7[proj][keep_idx])
# rejection on annotations
- raw.annotations = Annotations([events[0][0] / raw.info['sfreq']],
- [1], ['BAD'])
+ raw.annotations = Annotations([(events[0][0] - raw.first_samp) /
+ raw.info['sfreq']], [1], ['BAD'])
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=[0],
- reject=None, preload=preload)
+ reject=None, preload=preload,
+ add_eeg_ref=False)
epochs.drop_bad()
assert_equal(len(events) - 1, len(epochs.events))
assert_equal(epochs.drop_log[0][0], 'BAD')
@@ -276,12 +299,11 @@ def test_reject():
def test_decim():
- """Test epochs decimation
- """
+ """Test epochs decimation."""
# First with EpochsArray
- n_epochs, n_channels, n_times = 5, 10, 20
dec_1, dec_2 = 2, 3
decim = dec_1 * dec_2
+ n_epochs, n_channels, n_times = 5, 10, 20
sfreq = 1000.
sfreq_new = sfreq / decim
data = rng.randn(n_epochs, n_channels, n_times)
@@ -298,19 +320,40 @@ def test_decim():
# Now let's do it with some real data
raw, events, picks = _get_data()
+ events = events[events[:, 2] == 1][:2]
+ raw.load_data().pick_channels([raw.ch_names[pick] for pick in picks[::30]])
+ raw.info.normalize_proj()
+ del picks
sfreq_new = raw.info['sfreq'] / decim
- raw.info['lowpass'] = sfreq_new / 4. # suppress aliasing warnings
- epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- preload=False)
+ raw.info['lowpass'] = sfreq_new / 12. # suppress aliasing warnings
assert_raises(ValueError, epochs.decimate, -1)
assert_raises(ValueError, epochs.decimate, 2, offset=-1)
assert_raises(ValueError, epochs.decimate, 2, offset=2)
+ for this_offset in range(decim):
+ epochs = Epochs(raw, events, event_id,
+ tmin=-this_offset / raw.info['sfreq'],
+ tmax=tmax, preload=False, add_eeg_ref=False)
+ idx_offsets = np.arange(decim) + this_offset
+ for offset, idx_offset in zip(np.arange(decim), idx_offsets):
+ expected_times = epochs.times[idx_offset::decim]
+ expected_data = epochs.get_data()[:, :, idx_offset::decim]
+ must_have = offset / float(epochs.info['sfreq'])
+ assert_true(np.isclose(must_have, expected_times).any())
+ ep_decim = epochs.copy().decimate(decim, offset)
+ assert_true(np.isclose(must_have, ep_decim.times).any())
+ assert_allclose(ep_decim.times, expected_times)
+ assert_allclose(ep_decim.get_data(), expected_data)
+ assert_equal(ep_decim.info['sfreq'], sfreq_new)
+
+ # More complex cases
+ epochs = Epochs(raw, events, event_id, tmin, tmax, preload=False,
+ add_eeg_ref=False)
expected_data = epochs.get_data()[:, :, ::decim]
expected_times = epochs.times[::decim]
for preload in (True, False):
# at init
epochs = Epochs(raw, events, event_id, tmin, tmax, decim=decim,
- preload=preload)
+ preload=preload, add_eeg_ref=False)
assert_allclose(epochs.get_data(), expected_data)
assert_allclose(epochs.get_data(), expected_data)
assert_equal(epochs.info['sfreq'], sfreq_new)
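
Editor's note: the offset loop above pins down decimate's indexing: with factor q and offset k it keeps samples k, k+q, k+2q, and so on. A toy check of that slicing (synthetic values, mirroring the expected_times computation in the test):

import numpy as np

q, k = 6, 2                    # decimation factor and offset
times = np.arange(20) / 1000.  # 20 samples at 1 kHz
decimated = times[k::q]        # what epochs.decimate(q, offset=k) keeps
print(decimated)               # -> [ 0.002  0.008  0.014]
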
@@ -318,13 +361,13 @@ def test_decim():
# split between init and afterward
epochs = Epochs(raw, events, event_id, tmin, tmax, decim=dec_1,
- preload=preload).decimate(dec_2)
+ preload=preload, add_eeg_ref=False).decimate(dec_2)
assert_allclose(epochs.get_data(), expected_data)
assert_allclose(epochs.get_data(), expected_data)
assert_equal(epochs.info['sfreq'], sfreq_new)
assert_array_equal(epochs.times, expected_times)
epochs = Epochs(raw, events, event_id, tmin, tmax, decim=dec_2,
- preload=preload).decimate(dec_1)
+ preload=preload, add_eeg_ref=False).decimate(dec_1)
assert_allclose(epochs.get_data(), expected_data)
assert_allclose(epochs.get_data(), expected_data)
assert_equal(epochs.info['sfreq'], sfreq_new)
@@ -332,7 +375,7 @@ def test_decim():
# split between init and afterward, with preload in between
epochs = Epochs(raw, events, event_id, tmin, tmax, decim=dec_1,
- preload=preload)
+ preload=preload, add_eeg_ref=False)
epochs.load_data()
epochs = epochs.decimate(dec_2)
assert_allclose(epochs.get_data(), expected_data)
@@ -340,7 +383,7 @@ def test_decim():
assert_equal(epochs.info['sfreq'], sfreq_new)
assert_array_equal(epochs.times, expected_times)
epochs = Epochs(raw, events, event_id, tmin, tmax, decim=dec_2,
- preload=preload)
+ preload=preload, add_eeg_ref=False)
epochs.load_data()
epochs = epochs.decimate(dec_1)
assert_allclose(epochs.get_data(), expected_data)
@@ -350,7 +393,7 @@ def test_decim():
# decimate afterward
epochs = Epochs(raw, events, event_id, tmin, tmax,
- preload=preload).decimate(decim)
+ preload=preload, add_eeg_ref=False).decimate(decim)
assert_allclose(epochs.get_data(), expected_data)
assert_allclose(epochs.get_data(), expected_data)
assert_equal(epochs.info['sfreq'], sfreq_new)
@@ -358,7 +401,7 @@ def test_decim():
# decimate afterward, with preload in between
epochs = Epochs(raw, events, event_id, tmin, tmax,
- preload=preload)
+ preload=preload, add_eeg_ref=False)
epochs.load_data()
epochs.decimate(decim)
assert_allclose(epochs.get_data(), expected_data)
@@ -368,8 +411,7 @@ def test_decim():
def test_base_epochs():
- """Test base epochs class
- """
+ """Test base epochs class."""
raw = _get_data()[0]
epochs = _BaseEpochs(raw.info, None, np.ones((1, 3), int),
event_id, tmin, tmax)
@@ -383,13 +425,13 @@ def test_base_epochs():
@requires_version('scipy', '0.14')
def test_savgol_filter():
- """Test savgol filtering
- """
+ """Test savgol filtering."""
h_freq = 10.
raw, events = _get_data()[:2]
- epochs = Epochs(raw, events, event_id, tmin, tmax)
+ epochs = Epochs(raw, events, event_id, tmin, tmax, add_eeg_ref=False)
assert_raises(RuntimeError, epochs.savgol_filter, 10.)
- epochs = Epochs(raw, events, event_id, tmin, tmax, preload=True)
+ epochs = Epochs(raw, events, event_id, tmin, tmax, preload=True,
+ add_eeg_ref=False)
freqs = fftpack.fftfreq(len(epochs.times), 1. / epochs.info['sfreq'])
data = np.abs(fftpack.fft(epochs.get_data()))
match_mask = np.logical_and(freqs >= 0, freqs <= h_freq / 2.)
@@ -406,14 +448,16 @@ def test_savgol_filter():
def test_epochs_hash():
- """Test epoch hashing
- """
+ """Test epoch hashing."""
raw, events = _get_data()[:2]
- epochs = Epochs(raw, events, event_id, tmin, tmax)
+ epochs = Epochs(raw, events, event_id, tmin, tmax,
+ add_eeg_ref=False)
assert_raises(RuntimeError, epochs.__hash__)
- epochs = Epochs(raw, events, event_id, tmin, tmax, preload=True)
+ epochs = Epochs(raw, events, event_id, tmin, tmax, preload=True,
+ add_eeg_ref=False)
assert_equal(hash(epochs), hash(epochs))
- epochs_2 = Epochs(raw, events, event_id, tmin, tmax, preload=True)
+ epochs_2 = Epochs(raw, events, event_id, tmin, tmax, preload=True,
+ add_eeg_ref=False)
assert_equal(hash(epochs), hash(epochs_2))
# do NOT use assert_equal here, failing output is terrible
assert_true(pickle.dumps(epochs) == pickle.dumps(epochs_2))
@@ -431,27 +475,46 @@ def test_event_ordering():
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
Epochs(raw, eve, event_id, tmin, tmax,
- baseline=(None, 0), reject=reject, flat=flat)
+ baseline=(None, 0), reject=reject, flat=flat,
+ add_eeg_ref=False)
assert_equal(len(w), ii)
if ii > 0:
assert_true('chronologically' in '%s' % w[-1].message)
def test_epochs_bad_baseline():
- """Test Epochs initialization with bad baseline parameters
- """
+ """Test Epochs initialization with bad baseline parameters."""
raw, events = _get_data()[:2]
- assert_raises(ValueError, Epochs, raw, events, None, -0.1, 0.3, (-0.2, 0))
- assert_raises(ValueError, Epochs, raw, events, None, -0.1, 0.3, (0, 0.4))
+ assert_raises(ValueError, Epochs, raw, events, None, -0.1, 0.3, (-0.2, 0),
+ add_eeg_ref=False)
+ assert_raises(ValueError, Epochs, raw, events, None, -0.1, 0.3, (0, 0.4),
+ add_eeg_ref=False)
+ assert_raises(ValueError, Epochs, raw, events, None, -0.1, 0.3, (0.1, 0),
+ add_eeg_ref=False)
+ assert_raises(ValueError, Epochs, raw, events, None, 0.1, 0.3, (None, 0),
+ add_eeg_ref=False)
+ assert_raises(ValueError, Epochs, raw, events, None, -0.3, -0.1, (0, None),
+ add_eeg_ref=False)
+ epochs = Epochs(raw, events, None, 0.1, 0.3, baseline=None,
+ add_eeg_ref=False)
+ assert_raises(RuntimeError, epochs.apply_baseline, (0.1, 0.2))
+ epochs.load_data()
+ assert_raises(ValueError, epochs.apply_baseline, (None, 0))
+ assert_raises(ValueError, epochs.apply_baseline, (0, None))
+ # put some rescale options here, too
+ data = np.arange(100, dtype=float)
+ assert_raises(ValueError, rescale, data, times=data, baseline=(-2, -1))
+ rescale(data.copy(), times=data, baseline=(2, 2)) # ok
+ assert_raises(ValueError, rescale, data, times=data, baseline=(2, 1))
+ assert_raises(ValueError, rescale, data, times=data, baseline=(100, 101))
def test_epoch_combine_ids():
- """Test combining event ids in epochs compared to events
- """
+ """Test combining event ids in epochs compared to events."""
raw, events, picks = _get_data()
epochs = Epochs(raw, events, {'a': 1, 'b': 2, 'c': 3,
'd': 4, 'e': 5, 'f': 32},
- tmin, tmax, picks=picks, preload=False)
+ tmin, tmax, picks=picks, preload=False, add_eeg_ref=False)
events_new = merge_events(events, [1, 2], 12)
epochs_new = combine_event_ids(epochs, ['a', 'b'], {'ab': 12})
assert_equal(epochs_new['ab'].name, 'ab')
@@ -460,29 +523,33 @@ def test_epoch_combine_ids():
def test_epoch_multi_ids():
- """Test epoch selection via multiple/partial keys
- """
+ """Test epoch selection via multiple/partial keys."""
raw, events, picks = _get_data()
epochs = Epochs(raw, events, {'a/b/a': 1, 'a/b/b': 2, 'a/c': 3,
'b/d': 4, 'a_b': 5},
- tmin, tmax, picks=picks, preload=False)
- epochs_regular = epochs[['a', 'b']]
+ tmin, tmax, picks=picks, preload=False, add_eeg_ref=False)
+ epochs_regular = epochs['a/b']
+ epochs_reverse = epochs['b/a']
epochs_multi = epochs[['a/b/a', 'a/b/b']]
- assert_array_equal(epochs_regular.events, epochs_multi.events)
+ assert_array_equal(epochs_multi.events, epochs_regular.events)
+ assert_array_equal(epochs_reverse.events, epochs_regular.events)
+ assert_allclose(epochs_multi.get_data(), epochs_regular.get_data())
+ assert_allclose(epochs_reverse.get_data(), epochs_regular.get_data())
def test_read_epochs_bad_events():
- """Test epochs when events are at the beginning or the end of the file
- """
+ """Test epochs when events are at the beginning or the end of the file."""
raw, events, picks = _get_data()
# Event at the beginning
epochs = Epochs(raw, np.array([[raw.first_samp, 0, event_id]]),
- event_id, tmin, tmax, picks=picks, baseline=(None, 0))
+ event_id, tmin, tmax, picks=picks, baseline=(None, 0),
+ add_eeg_ref=False)
with warnings.catch_warnings(record=True):
evoked = epochs.average()
epochs = Epochs(raw, np.array([[raw.first_samp, 0, event_id]]),
- event_id, tmin, tmax, picks=picks, baseline=(None, 0))
+ event_id, tmin, tmax, picks=picks, baseline=(None, 0),
+ add_eeg_ref=False)
assert_true(repr(epochs)) # test repr
epochs.drop_bad()
assert_true(repr(epochs))
@@ -491,7 +558,8 @@ def test_read_epochs_bad_events():
# Event at the end
epochs = Epochs(raw, np.array([[raw.last_samp, 0, event_id]]),
- event_id, tmin, tmax, picks=picks, baseline=(None, 0))
+ event_id, tmin, tmax, picks=picks, baseline=(None, 0),
+ add_eeg_ref=False)
with warnings.catch_warnings(record=True):
evoked = epochs.average()
@@ -501,29 +569,28 @@ def test_read_epochs_bad_events():
@slow_test
def test_read_write_epochs():
- """Test epochs from raw files with IO as fif file
- """
+ """Test epochs from raw files with IO as fif file."""
raw, events, picks = _get_data(preload=True)
tempdir = _TempDir()
temp_fname = op.join(tempdir, 'test-epo.fif')
temp_fname_no_bl = op.join(tempdir, 'test_no_bl-epo.fif')
baseline = (None, 0)
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- baseline=baseline, preload=True)
+ baseline=baseline, preload=True, add_eeg_ref=False)
epochs_orig = epochs.copy()
epochs_no_bl = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- baseline=None, preload=True)
+ baseline=None, preload=True, add_eeg_ref=False)
assert_true(epochs_no_bl.baseline is None)
evoked = epochs.average()
data = epochs.get_data()
# Bad tmin/tmax parameters
assert_raises(ValueError, Epochs, raw, events, event_id, tmax, tmin,
- baseline=None)
+ baseline=None, add_eeg_ref=False)
epochs_no_id = Epochs(raw, pick_events(events, include=event_id),
None, tmin, tmax, picks=picks,
- baseline=(None, 0))
+ baseline=(None, 0), add_eeg_ref=False)
assert_array_equal(data, epochs_no_id.get_data())
eog_picks = pick_types(raw.info, meg=False, eeg=False, stim=False,
@@ -540,7 +607,7 @@ def test_read_write_epochs():
# decim with lowpass
warnings.simplefilter('always')
epochs_dec = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- baseline=(None, 0), decim=2)
+ baseline=(None, 0), decim=2, add_eeg_ref=False)
assert_equal(len(w), 1)
# decim without lowpass
@@ -575,7 +642,7 @@ def test_read_write_epochs():
for proj in (True, 'delayed', False):
epochs = Epochs(raw, events, event_ids, tmin, tmax, picks=picks,
baseline=(None, 0), proj=proj, reject=reject,
- add_eeg_ref=True)
+ add_eeg_ref=False)
assert_equal(epochs.proj, proj if proj != 'delayed' else False)
data1 = epochs.get_data()
epochs2 = epochs.copy().apply_proj()
@@ -643,7 +710,8 @@ def test_read_write_epochs():
# add reject here so some of the epochs get dropped
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- baseline=(None, 0), reject=reject)
+ baseline=(None, 0), reject=reject,
+ add_eeg_ref=False)
epochs.save(temp_fname)
# ensure bad events are not saved
epochs_read3 = read_epochs(temp_fname, preload=preload)
@@ -678,7 +746,8 @@ def test_read_write_epochs():
# test loading epochs with missing events
epochs = Epochs(raw, events, dict(foo=1, bar=999), tmin, tmax,
- picks=picks, on_missing='ignore')
+ picks=picks, on_missing='ignore',
+ add_eeg_ref=False)
epochs.save(temp_fname)
epochs_read = read_epochs(temp_fname, preload=preload)
assert_allclose(epochs.get_data(), epochs_read.get_data(), **tols)
@@ -705,23 +774,23 @@ def test_read_write_epochs():
def test_epochs_proj():
- """Test handling projection (apply proj in Raw or in Epochs)
- """
+ """Test handling projection (apply proj in Raw or in Epochs)."""
tempdir = _TempDir()
raw, events, picks = _get_data()
exclude = raw.info['bads'] + ['MEG 2443', 'EEG 053'] # bads + 2 more
this_picks = pick_types(raw.info, meg=True, eeg=False, stim=True,
eog=True, exclude=exclude)
epochs = Epochs(raw, events[:4], event_id, tmin, tmax, picks=this_picks,
- baseline=(None, 0), proj=True)
+ baseline=(None, 0), proj=True, add_eeg_ref=False)
assert_true(all(p['active'] is True for p in epochs.info['projs']))
evoked = epochs.average()
assert_true(all(p['active'] is True for p in evoked.info['projs']))
data = epochs.get_data()
- raw_proj = Raw(raw_fname, proj=True)
+ raw_proj = read_raw_fif(raw_fname, add_eeg_ref=False).apply_proj()
epochs_no_proj = Epochs(raw_proj, events[:4], event_id, tmin, tmax,
- picks=this_picks, baseline=(None, 0), proj=False)
+ picks=this_picks, baseline=(None, 0), proj=False,
+ add_eeg_ref=False)
data_no_proj = epochs_no_proj.get_data()
assert_true(all(p['active'] is True for p in epochs_no_proj.info['projs']))
@@ -735,7 +804,8 @@ def test_epochs_proj():
this_picks = pick_types(raw.info, meg=True, eeg=True, stim=True,
eog=True, exclude=exclude)
epochs = Epochs(raw, events[:4], event_id, tmin, tmax, picks=this_picks,
- baseline=(None, 0), proj=True, add_eeg_ref=True)
+ baseline=(None, 0), proj=True, add_eeg_ref=False)
+ epochs.set_eeg_reference().apply_proj()
assert_true(_has_eeg_average_ref_proj(epochs.info['projs']))
epochs = Epochs(raw, events[:4], event_id, tmin, tmax, picks=this_picks,
baseline=(None, 0), proj=True, add_eeg_ref=False)
@@ -744,14 +814,14 @@ def test_epochs_proj():
# make sure we don't add avg ref when a custom ref has been applied
raw.info['custom_ref_applied'] = True
epochs = Epochs(raw, events[:4], event_id, tmin, tmax, picks=this_picks,
- baseline=(None, 0), proj=True)
+ baseline=(None, 0), proj=True, add_eeg_ref=False)
assert_true(not _has_eeg_average_ref_proj(epochs.info['projs']))
# From GH#2200:
# This has no problem
proj = raw.info['projs']
epochs = Epochs(raw, events[:4], event_id, tmin, tmax, picks=this_picks,
- baseline=(None, 0), proj=False)
+ baseline=(None, 0), proj=False, add_eeg_ref=False)
epochs.info['projs'] = []
data = epochs.copy().add_proj(proj).apply_proj().get_data()
# save and reload data
@@ -764,7 +834,7 @@ def test_epochs_proj():
assert_allclose(data, data_2, atol=1e-15, rtol=1e-3)
# adding EEG ref (GH #2727)
- raw = Raw(raw_fname)
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
raw.add_proj([], remove_existing=True)
raw.info['bads'] = ['MEG 2443', 'EEG 053']
picks = pick_types(raw.info, meg=False, eeg=True, stim=True, eog=False,
@@ -776,14 +846,11 @@ def test_epochs_proj():
temp_fname = op.join(tempdir, 'test-epo.fif')
epochs.save(temp_fname)
for preload in (True, False):
- epochs = read_epochs(temp_fname, add_eeg_ref=True, proj=True,
- preload=preload)
+ epochs = read_epochs(temp_fname, proj=False, preload=preload)
+ epochs.set_eeg_reference().apply_proj()
assert_allclose(epochs.get_data().mean(axis=1), 0, atol=1e-15)
- epochs = read_epochs(temp_fname, add_eeg_ref=True, proj=False,
- preload=preload)
- assert_raises(AssertionError, assert_allclose,
- epochs.get_data().mean(axis=1), 0., atol=1e-15)
- epochs.add_eeg_average_proj()
+ epochs = read_epochs(temp_fname, proj=False, preload=preload)
+ epochs.set_eeg_reference()
assert_raises(AssertionError, assert_allclose,
epochs.get_data().mean(axis=1), 0., atol=1e-15)
epochs.apply_proj()
@@ -791,36 +858,35 @@ def test_epochs_proj():
def test_evoked_arithmetic():
- """Test arithmetic of evoked data
- """
+ """Test arithmetic of evoked data."""
raw, events, picks = _get_data()
epochs1 = Epochs(raw, events[:4], event_id, tmin, tmax, picks=picks,
- baseline=(None, 0))
+ baseline=(None, 0), add_eeg_ref=False)
evoked1 = epochs1.average()
epochs2 = Epochs(raw, events[4:8], event_id, tmin, tmax, picks=picks,
- baseline=(None, 0))
+ baseline=(None, 0), add_eeg_ref=False)
evoked2 = epochs2.average()
epochs = Epochs(raw, events[:8], event_id, tmin, tmax, picks=picks,
- baseline=(None, 0))
+ baseline=(None, 0), add_eeg_ref=False)
evoked = epochs.average()
- evoked_sum = evoked1 + evoked2
+ evoked_sum = combine_evoked([evoked1, evoked2], weights='nave')
assert_array_equal(evoked.data, evoked_sum.data)
assert_array_equal(evoked.times, evoked_sum.times)
- assert_true(evoked_sum.nave == (evoked1.nave + evoked2.nave))
- evoked_diff = evoked1 - evoked1
+ assert_equal(evoked_sum.nave, evoked1.nave + evoked2.nave)
+ evoked_diff = combine_evoked([evoked1, evoked1], weights=[1, -1])
assert_array_equal(np.zeros_like(evoked.data), evoked_diff.data)
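
The evoked arithmetic operators are replaced by combine_evoked with explicit
weights. A self-contained sketch of the two modes used above (channel name
and data values are arbitrary):

    import numpy as np
    from mne import combine_evoked, create_info
    from mne.evoked import EvokedArray

    info = create_info(['EEG 001'], 1000., 'eeg')
    ev1 = EvokedArray(np.ones((1, 10)), info, tmin=0., nave=20)
    ev2 = EvokedArray(-np.ones((1, 10)), info, tmin=0., nave=10)
    # 'nave' weights reproduce the old ev1 + ev2:
    # (20 * 1 + 10 * (-1)) / (20 + 10) = 1 / 3
    ev_sum = combine_evoked([ev1, ev2], weights='nave')
    # explicit signed weights reproduce the old ev1 - ev1 (all zeros)
    ev_diff = combine_evoked([ev1, ev1], weights=[1, -1])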
def test_evoked_io_from_epochs():
- """Test IO of evoked data made from epochs
- """
+ """Test IO of evoked data made from epochs."""
tempdir = _TempDir()
raw, events, picks = _get_data()
# offset our tmin so we don't get exactly a zero value when decimating
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
epochs = Epochs(raw, events[:4], event_id, tmin + 0.011, tmax,
- picks=picks, baseline=(None, 0), decim=5)
+ picks=picks, baseline=(None, 0), decim=5,
+ add_eeg_ref=False)
assert_true(len(w) == 1)
evoked = epochs.average()
evoked.info['proj_name'] = '' # Test that empty string shortcuts to None.
@@ -835,7 +901,8 @@ def test_evoked_io_from_epochs():
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
epochs = Epochs(raw, events[:4], event_id, 0.1, tmax,
- picks=picks, baseline=(0.1, 0.2), decim=5)
+ picks=picks, baseline=(0.1, 0.2), decim=5,
+ add_eeg_ref=False)
evoked = epochs.average()
evoked.save(op.join(tempdir, 'evoked-ave.fif'))
evoked2 = read_evokeds(op.join(tempdir, 'evoked-ave.fif'))[0]
@@ -846,7 +913,8 @@ def test_evoked_io_from_epochs():
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
epochs = Epochs(raw, events[:4], event_id, -0.2, tmax,
- picks=picks, baseline=(0.1, 0.2), decim=5)
+ picks=picks, baseline=(0.1, 0.2), decim=5,
+ add_eeg_ref=False)
evoked = epochs.average()
evoked.crop(0.099, None)
assert_allclose(evoked.data, evoked2.data, rtol=1e-4, atol=1e-20)
@@ -854,12 +922,11 @@ def test_evoked_io_from_epochs():
def test_evoked_standard_error():
- """Test calculation and read/write of standard error
- """
+ """Test calculation and read/write of standard error."""
raw, events, picks = _get_data()
tempdir = _TempDir()
epochs = Epochs(raw, events[:4], event_id, tmin, tmax, picks=picks,
- baseline=(None, 0))
+ baseline=(None, 0), add_eeg_ref=False)
evoked = [epochs.average(), epochs.standard_error()]
write_evokeds(op.join(tempdir, 'evoked-ave.fif'), evoked)
evoked2 = read_evokeds(op.join(tempdir, 'evoked-ave.fif'), [0, 1])
@@ -884,13 +951,11 @@ def test_evoked_standard_error():
def test_reject_epochs():
- """Test of epochs rejection
- """
+ """Test of epochs rejection."""
raw, events, picks = _get_data()
events1 = events[events[:, 2] == event_id]
- epochs = Epochs(raw, events1,
- event_id, tmin, tmax, baseline=(None, 0),
- reject=reject, flat=flat)
+ epochs = Epochs(raw, events1, event_id, tmin, tmax, baseline=(None, 0),
+ reject=reject, flat=flat, add_eeg_ref=False)
assert_raises(RuntimeError, len, epochs)
n_events = len(epochs.events)
data = epochs.get_data()
@@ -908,7 +973,7 @@ def test_reject_epochs():
raw_2.info['bads'] = ['MEG 2443']
reject_crazy = dict(grad=1000e-15, mag=4e-15, eeg=80e-9, eog=150e-9)
epochs = Epochs(raw_2, events1, event_id, tmin, tmax, baseline=(None, 0),
- reject=reject_crazy, flat=flat)
+ reject=reject_crazy, flat=flat, add_eeg_ref=False)
epochs.drop_bad()
assert_true(all('MEG 2442' in e for e in epochs.drop_log))
@@ -916,15 +981,15 @@ def test_reject_epochs():
# Invalid reject_tmin/reject_tmax/detrend
assert_raises(ValueError, Epochs, raw, events1, event_id, tmin, tmax,
- reject_tmin=1., reject_tmax=0)
+ reject_tmin=1., reject_tmax=0, add_eeg_ref=False)
assert_raises(ValueError, Epochs, raw, events1, event_id, tmin, tmax,
- reject_tmin=tmin - 1, reject_tmax=1.)
+ reject_tmin=tmin - 1, reject_tmax=1., add_eeg_ref=False)
assert_raises(ValueError, Epochs, raw, events1, event_id, tmin, tmax,
- reject_tmin=0., reject_tmax=tmax + 1)
+ reject_tmin=0., reject_tmax=tmax + 1, add_eeg_ref=False)
epochs = Epochs(raw, events1, event_id, tmin, tmax, picks=picks,
baseline=(None, 0), reject=reject, flat=flat,
- reject_tmin=0., reject_tmax=.1)
+ reject_tmin=0., reject_tmax=.1, add_eeg_ref=False)
data = epochs.get_data()
n_clean_epochs = len(data)
assert_true(n_clean_epochs == 7)
@@ -933,7 +998,8 @@ def test_reject_epochs():
assert_true(epochs.times[epochs._reject_time][-1] <= 0.1)
# Invalid data for _is_good_epoch function
- epochs = Epochs(raw, events1, event_id, tmin, tmax, reject=None, flat=None)
+ epochs = Epochs(raw, events1, event_id, tmin, tmax, reject=None, flat=None,
+ add_eeg_ref=False)
assert_equal(epochs._is_good_epoch(None), (False, ['NO_DATA']))
assert_equal(epochs._is_good_epoch(np.zeros((1, 1))),
(False, ['TOO_SHORT']))
@@ -942,17 +1008,16 @@ def test_reject_epochs():
def test_preload_epochs():
- """Test preload of epochs
- """
+ """Test preload of epochs."""
raw, events, picks = _get_data()
epochs_preload = Epochs(raw, events[:16], event_id, tmin, tmax,
picks=picks, baseline=(None, 0), preload=True,
- reject=reject, flat=flat)
+ reject=reject, flat=flat, add_eeg_ref=False)
data_preload = epochs_preload.get_data()
epochs = Epochs(raw, events[:16], event_id, tmin, tmax, picks=picks,
baseline=(None, 0), preload=False,
- reject=reject, flat=flat)
+ reject=reject, flat=flat, add_eeg_ref=False)
data = epochs.get_data()
assert_array_equal(data_preload, data)
assert_array_almost_equal(epochs_preload.average().data,
@@ -960,12 +1025,11 @@ def test_preload_epochs():
def test_indexing_slicing():
- """Test of indexing and slicing operations
- """
+ """Test of indexing and slicing operations."""
raw, events, picks = _get_data()
epochs = Epochs(raw, events[:20], event_id, tmin, tmax, picks=picks,
baseline=(None, 0), preload=False,
- reject=reject, flat=flat)
+ reject=reject, flat=flat, add_eeg_ref=False)
data_normal = epochs.get_data()
@@ -980,7 +1044,7 @@ def test_indexing_slicing():
for preload in [True, False]:
epochs2 = Epochs(raw, events[:20], event_id, tmin, tmax,
picks=picks, baseline=(None, 0), preload=preload,
- reject=reject, flat=flat)
+ reject=reject, flat=flat, add_eeg_ref=False)
if not preload:
epochs2.drop_bad()
@@ -1018,14 +1082,13 @@ def test_indexing_slicing():
def test_comparision_with_c():
- """Test of average obtained vs C code
- """
+ """Test of average obtained vs C code."""
raw, events = _get_data()[:2]
c_evoked = read_evokeds(evoked_nf_name, condition=0)
- epochs = Epochs(raw, events, event_id, tmin, tmax,
- baseline=None, preload=True,
- reject=None, flat=None)
- evoked = epochs.average()
+ epochs = Epochs(raw, events, event_id, tmin, tmax, baseline=None,
+ preload=True, reject=None, flat=None, add_eeg_ref=False,
+ proj=False)
+ evoked = epochs.set_eeg_reference().apply_proj().average()
sel = pick_channels(c_evoked.ch_names, evoked.ch_names)
evoked_data = evoked.data
c_evoked_data = c_evoked.data[sel]
@@ -1036,18 +1099,17 @@ def test_comparision_with_c():
def test_crop():
- """Test of crop of epochs
- """
+ """Test of crop of epochs."""
raw, events, picks = _get_data()
epochs = Epochs(raw, events[:5], event_id, tmin, tmax, picks=picks,
baseline=(None, 0), preload=False,
- reject=reject, flat=flat)
+ reject=reject, flat=flat, add_eeg_ref=False)
assert_raises(RuntimeError, epochs.crop, None, 0.2) # not preloaded
data_normal = epochs.get_data()
epochs2 = Epochs(raw, events[:5], event_id, tmin, tmax,
picks=picks, baseline=(None, 0), preload=True,
- reject=reject, flat=flat)
+ reject=reject, flat=flat, add_eeg_ref=False)
with warnings.catch_warnings(record=True) as w:
epochs2.crop(-20, 200)
assert_true(len(w) == 2)
@@ -1081,7 +1143,7 @@ def test_crop():
epochs = Epochs(raw, events[:5], event_id, -1, 1,
picks=picks, baseline=(None, 0), preload=True,
- reject=reject, flat=flat)
+ reject=reject, flat=flat, add_eeg_ref=False)
# We include nearest sample, so actually a bit beyond our bounds here
assert_allclose(epochs.tmin, -1.0006410259015925, rtol=1e-12)
assert_allclose(epochs.tmax, 1.0006410259015925, rtol=1e-12)
@@ -1099,12 +1161,12 @@ def test_resample():
raw, events, picks = _get_data()
epochs = Epochs(raw, events[:10], event_id, tmin, tmax, picks=picks,
baseline=(None, 0), preload=False,
- reject=reject, flat=flat)
+ reject=reject, flat=flat, add_eeg_ref=False)
assert_raises(RuntimeError, epochs.resample, 100)
epochs_o = Epochs(raw, events[:10], event_id, tmin, tmax, picks=picks,
baseline=(None, 0), preload=True,
- reject=reject, flat=flat)
+ reject=reject, flat=flat, add_eeg_ref=False)
epochs = epochs_o.copy()
data_normal = cp.deepcopy(epochs.get_data())
@@ -1142,7 +1204,7 @@ def test_resample():
assert_true(epochs_resampled is epochs)
# test proper setting of times (#2645)
- n_trial, n_chan, n_time, sfreq = 1, 1, 10, 1000
+ n_trial, n_chan, n_time, sfreq = 1, 1, 10, 1000.
data = np.zeros((n_trial, n_chan, n_time))
events = np.zeros((n_trial, 3), int)
info = create_info(n_chan, sfreq, 'eeg')
@@ -1155,18 +1217,31 @@ def test_resample():
for e in epochs1, epochs2, epochs:
assert_equal(e.times[0], epochs.tmin)
assert_equal(e.times[-1], epochs.tmax)
+ # test that cropping after resampling works (#3296)
+ this_tmin = -0.002
+ epochs = EpochsArray(data, deepcopy(info), events, tmin=this_tmin)
+ for times in (epochs.times, epochs._raw_times):
+ assert_allclose(times, np.arange(n_time) / sfreq + this_tmin)
+ epochs.resample(info['sfreq'] * 2.)
+ for times in (epochs.times, epochs._raw_times):
+ assert_allclose(times, np.arange(2 * n_time) / (sfreq * 2) + this_tmin)
+ epochs.crop(0, None)
+ for times in (epochs.times, epochs._raw_times):
+ assert_allclose(times, np.arange((n_time - 2) * 2) / (sfreq * 2))
+ epochs.resample(sfreq)
+ for times in (epochs.times, epochs._raw_times):
+ assert_allclose(times, np.arange(n_time - 2) / sfreq)
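
The new block pins down gh-3296, where cropping after resampling left the
cached time axis inconsistent with the data. The arithmetic, restated as a
stand-alone sketch using the same values as the test:

    import numpy as np
    from mne import EpochsArray, create_info

    sfreq = 1000.
    data = np.zeros((1, 1, 10))        # 1 epoch, 1 channel, 10 samples
    events = np.array([[0, 0, 1]])
    epochs = EpochsArray(data, create_info(1, sfreq, 'eeg'), events,
                         tmin=-0.002)
    epochs.resample(2 * sfreq)         # 20 samples at 0.5 ms spacing
    epochs.crop(0, None)               # drops the 4 pre-zero samples
    assert np.allclose(epochs.times, np.arange(16) / (2 * sfreq))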
def test_detrend():
- """Test detrending of epochs
- """
+ """Test detrending of epochs."""
raw, events, picks = _get_data()
# test first-order
epochs_1 = Epochs(raw, events[:4], event_id, tmin, tmax, picks=picks,
- baseline=None, detrend=1)
+ baseline=None, detrend=1, add_eeg_ref=False)
epochs_2 = Epochs(raw, events[:4], event_id, tmin, tmax, picks=picks,
- baseline=None, detrend=None)
+ baseline=None, detrend=None, add_eeg_ref=False)
data_picks = pick_types(epochs_1.info, meg=True, eeg=True,
exclude='bads')
evoked_1 = epochs_1.average()
@@ -1179,9 +1254,11 @@ def test_detrend():
# test zeroth-order case
for preload in [True, False]:
epochs_1 = Epochs(raw, events[:4], event_id, tmin, tmax, picks=picks,
- baseline=(None, None), preload=preload)
+ baseline=(None, None), preload=preload,
+ add_eeg_ref=False)
epochs_2 = Epochs(raw, events[:4], event_id, tmin, tmax, picks=picks,
- baseline=None, preload=preload, detrend=0)
+ baseline=None, preload=preload, detrend=0,
+ add_eeg_ref=False)
a = epochs_1.get_data()
b = epochs_2.get_data()
# All data channels should be almost equal
@@ -1190,35 +1267,34 @@ def test_detrend():
# There are non-M/EEG channels that should not be equal:
assert_true(not np.allclose(a, b))
- assert_raises(ValueError, Epochs, raw, events[:4], event_id, tmin, tmax,
- detrend=2)
+ for value in ['foo', 2, False, True]:
+ assert_raises(ValueError, Epochs, raw, events[:4], event_id,
+ tmin, tmax, detrend=value, add_eeg_ref=False)
def test_bootstrap():
- """Test of bootstrapping of epochs
- """
+ """Test of bootstrapping of epochs."""
raw, events, picks = _get_data()
epochs = Epochs(raw, events[:5], event_id, tmin, tmax, picks=picks,
baseline=(None, 0), preload=True,
- reject=reject, flat=flat)
+ reject=reject, flat=flat, add_eeg_ref=False)
epochs2 = bootstrap(epochs, random_state=0)
assert_true(len(epochs2.events) == len(epochs.events))
assert_true(epochs._data.shape == epochs2._data.shape)
def test_epochs_copy():
- """Test copy epochs
- """
+ """Test copy epochs."""
raw, events, picks = _get_data()
epochs = Epochs(raw, events[:5], event_id, tmin, tmax, picks=picks,
baseline=(None, 0), preload=True,
- reject=reject, flat=flat)
+ reject=reject, flat=flat, add_eeg_ref=False)
copied = epochs.copy()
assert_array_equal(epochs._data, copied._data)
epochs = Epochs(raw, events[:5], event_id, tmin, tmax, picks=picks,
baseline=(None, 0), preload=False,
- reject=reject, flat=flat)
+ reject=reject, flat=flat, add_eeg_ref=False)
copied = epochs.copy()
data = epochs.get_data()
copied_data = copied.get_data()
@@ -1226,11 +1302,10 @@ def test_epochs_copy():
def test_iter_evoked():
- """Test the iterator for epochs -> evoked
- """
+ """Test the iterator for epochs -> evoked."""
raw, events, picks = _get_data()
epochs = Epochs(raw, events[:5], event_id, tmin, tmax, picks=picks,
- baseline=(None, 0))
+ baseline=(None, 0), add_eeg_ref=False)
for ii, ev in enumerate(epochs.iter_evoked()):
x = ev.data
@@ -1239,11 +1314,10 @@ def test_iter_evoked():
def test_subtract_evoked():
- """Test subtraction of Evoked from Epochs
- """
+ """Test subtraction of Evoked from Epochs."""
raw, events, picks = _get_data()
epochs = Epochs(raw, events[:10], event_id, tmin, tmax, picks=picks,
- baseline=(None, 0))
+ baseline=(None, 0), add_eeg_ref=False)
# make sure subtraction fails if data channels are missing
assert_raises(ValueError, epochs.subtract_evoked,
@@ -1257,7 +1331,8 @@ def test_subtract_evoked():
# use preloading and SSP from the start
epochs2 = Epochs(raw, events[:10], event_id, tmin, tmax, picks=picks,
- baseline=(None, 0), preload=True, proj=True)
+ baseline=(None, 0), preload=True, proj=True,
+ add_eeg_ref=False)
evoked = epochs2.average()
epochs2.subtract_evoked(evoked)
@@ -1272,12 +1347,13 @@ def test_subtract_evoked():
def test_epoch_eq():
- """Test epoch count equalization and condition combining
- """
+ """Test epoch count equalization and condition combining."""
raw, events, picks = _get_data()
# equalizing epochs objects
- epochs_1 = Epochs(raw, events, event_id, tmin, tmax, picks=picks)
- epochs_2 = Epochs(raw, events, event_id_2, tmin, tmax, picks=picks)
+ epochs_1 = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
+ add_eeg_ref=False)
+ epochs_2 = Epochs(raw, events, event_id_2, tmin, tmax, picks=picks,
+ add_eeg_ref=False)
epochs_1.drop_bad() # make sure drops are logged
assert_true(len([l for l in epochs_1.drop_log if not l]) ==
len(epochs_1.events))
@@ -1290,15 +1366,17 @@ def test_epoch_eq():
assert_true(epochs_1.events.shape[0] != epochs_2.events.shape[0])
equalize_epoch_counts([epochs_1, epochs_2], method='mintime')
assert_true(epochs_1.events.shape[0] == epochs_2.events.shape[0])
- epochs_3 = Epochs(raw, events, event_id, tmin, tmax, picks=picks)
- epochs_4 = Epochs(raw, events, event_id_2, tmin, tmax, picks=picks)
+ epochs_3 = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
+ add_eeg_ref=False)
+ epochs_4 = Epochs(raw, events, event_id_2, tmin, tmax, picks=picks,
+ add_eeg_ref=False)
equalize_epoch_counts([epochs_3, epochs_4], method='truncate')
assert_true(epochs_1.events.shape[0] == epochs_3.events.shape[0])
assert_true(epochs_3.events.shape[0] == epochs_4.events.shape[0])
# equalizing conditions
epochs = Epochs(raw, events, {'a': 1, 'b': 2, 'c': 3, 'd': 4},
- tmin, tmax, picks=picks, reject=reject)
+ tmin, tmax, picks=picks, reject=reject, add_eeg_ref=False)
epochs.drop_bad() # make sure drops are logged
assert_true(len([l for l in epochs.drop_log if not l]) ==
len(epochs.events))
@@ -1351,7 +1429,7 @@ def test_epoch_eq():
# equalizing with hierarchical tags
epochs = Epochs(raw, events, {'a/x': 1, 'b/x': 2, 'a/y': 3, 'b/y': 4},
- tmin, tmax, picks=picks, reject=reject)
+ tmin, tmax, picks=picks, reject=reject, add_eeg_ref=False)
cond1, cond2 = ['a', ['b/x', 'b/y']], [['a/x', 'a/y'], 'b']
es = [epochs.copy().equalize_event_counts(c, copy=False)[0]
for c in (cond1, cond2)]
@@ -1371,44 +1449,47 @@ def test_epoch_eq():
def test_access_by_name():
- """Test accessing epochs by event name and on_missing for rare events
- """
+ """Test accessing epochs by event name and on_missing for rare events."""
tempdir = _TempDir()
raw, events, picks = _get_data()
# Test various invalid inputs
assert_raises(ValueError, Epochs, raw, events, {1: 42, 2: 42}, tmin,
- tmax, picks=picks)
+ tmax, picks=picks, add_eeg_ref=False)
assert_raises(ValueError, Epochs, raw, events, {'a': 'spam', 2: 'eggs'},
- tmin, tmax, picks=picks)
+ tmin, tmax, picks=picks, add_eeg_ref=False)
assert_raises(ValueError, Epochs, raw, events, {'a': 'spam', 2: 'eggs'},
- tmin, tmax, picks=picks)
+ tmin, tmax, picks=picks, add_eeg_ref=False)
assert_raises(ValueError, Epochs, raw, events, 'foo', tmin, tmax,
- picks=picks)
+ picks=picks, add_eeg_ref=False)
assert_raises(ValueError, Epochs, raw, events, ['foo'], tmin, tmax,
- picks=picks)
+ picks=picks, add_eeg_ref=False)
# Test accessing non-existent events (assumes 12345678 does not exist)
event_id_illegal = dict(aud_l=1, does_not_exist=12345678)
assert_raises(ValueError, Epochs, raw, events, event_id_illegal,
- tmin, tmax)
+ tmin, tmax, add_eeg_ref=False)
# Test on_missing
assert_raises(ValueError, Epochs, raw, events, 1, tmin, tmax,
- on_missing='foo')
+ on_missing='foo', add_eeg_ref=False)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
- Epochs(raw, events, event_id_illegal, tmin, tmax, on_missing='warning')
+ Epochs(raw, events, event_id_illegal, tmin, tmax, on_missing='warning',
+ add_eeg_ref=False)
nw = len(w)
assert_true(1 <= nw <= 2)
- Epochs(raw, events, event_id_illegal, tmin, tmax, on_missing='ignore')
+ Epochs(raw, events, event_id_illegal, tmin, tmax, on_missing='ignore',
+ add_eeg_ref=False)
assert_equal(len(w), nw)
# Test constructing epochs with a list of ints as events
- epochs = Epochs(raw, events, [1, 2], tmin, tmax, picks=picks)
+ epochs = Epochs(raw, events, [1, 2], tmin, tmax, picks=picks,
+ add_eeg_ref=False)
for k, v in epochs.event_id.items():
assert_equal(int(k), v)
- epochs = Epochs(raw, events, {'a': 1, 'b': 2}, tmin, tmax, picks=picks)
+ epochs = Epochs(raw, events, {'a': 1, 'b': 2}, tmin, tmax, picks=picks,
+ add_eeg_ref=False)
assert_raises(KeyError, epochs.__getitem__, 'bar')
data = epochs['a'].get_data()
@@ -1416,7 +1497,7 @@ def test_access_by_name():
assert_true(len(data) == len(event_a))
epochs = Epochs(raw, events, {'a': 1, 'b': 2}, tmin, tmax, picks=picks,
- preload=True)
+ preload=True, add_eeg_ref=False)
assert_raises(KeyError, epochs.__getitem__, 'bar')
temp_fname = op.join(tempdir, 'test-epo.fif')
epochs.save(temp_fname)
@@ -1430,7 +1511,7 @@ def test_access_by_name():
assert_array_equal(epochs2['a'].events, epochs['a'].events)
epochs3 = Epochs(raw, events, {'a': 1, 'b': 2, 'c': 3, 'd': 4},
- tmin, tmax, picks=picks, preload=True)
+ tmin, tmax, picks=picks, preload=True, add_eeg_ref=False)
assert_equal(list(sorted(epochs3[('a', 'b')].event_id.values())),
[1, 2])
epochs4 = epochs['a']
@@ -1451,9 +1532,10 @@ def test_access_by_name():
@requires_pandas
def test_to_data_frame():
- """Test epochs Pandas exporter"""
+ """Test epochs Pandas exporter."""
raw, events, picks = _get_data()
- epochs = Epochs(raw, events, {'a': 1, 'b': 2}, tmin, tmax, picks=picks)
+ epochs = Epochs(raw, events, {'a': 1, 'b': 2}, tmin, tmax, picks=picks,
+ add_eeg_ref=False)
assert_raises(ValueError, epochs.to_data_frame, index=['foo', 'bar'])
assert_raises(ValueError, epochs.to_data_frame, index='qux')
assert_raises(ValueError, epochs.to_data_frame, np.arange(400))
@@ -1479,12 +1561,11 @@ def test_to_data_frame():
def test_epochs_proj_mixin():
- """Test SSP proj methods from ProjMixin class
- """
+ """Test SSP proj methods from ProjMixin class."""
raw, events, picks = _get_data()
for proj in [True, False]:
epochs = Epochs(raw, events[:4], event_id, tmin, tmax, picks=picks,
- baseline=(None, 0), proj=proj)
+ baseline=(None, 0), proj=proj, add_eeg_ref=False)
assert_true(all(p['active'] == proj for p in epochs.info['projs']))
@@ -1510,21 +1591,22 @@ def test_epochs_proj_mixin():
# catch no-gos.
# wrong proj argument
assert_raises(ValueError, Epochs, raw, events[:4], event_id, tmin, tmax,
- picks=picks, baseline=(None, 0), proj='crazy')
+ picks=picks, baseline=(None, 0), proj='crazy',
+ add_eeg_ref=False)
for preload in [True, False]:
epochs = Epochs(raw, events[:4], event_id, tmin, tmax, picks=picks,
baseline=(None, 0), proj='delayed', preload=preload,
- add_eeg_ref=True, reject=reject)
+ reject=reject, add_eeg_ref=False).set_eeg_reference()
epochs_proj = Epochs(
raw, events[:4], event_id, tmin, tmax, picks=picks,
- baseline=(None, 0), proj=True, preload=preload, add_eeg_ref=True,
- reject=reject)
+ baseline=(None, 0), proj=True, preload=preload, add_eeg_ref=False,
+ reject=reject).set_eeg_reference().apply_proj()
epochs_noproj = Epochs(
raw, events[:4], event_id, tmin, tmax, picks=picks,
- baseline=(None, 0), proj=False, preload=preload, add_eeg_ref=True,
- reject=reject)
+ baseline=(None, 0), proj=False, preload=preload, add_eeg_ref=False,
+ reject=reject).set_eeg_reference()
assert_allclose(epochs.copy().apply_proj().get_data(),
epochs_proj.get_data(), rtol=1e-10, atol=1e-25)
@@ -1549,14 +1631,15 @@ def test_epochs_proj_mixin():
# test mixin against manual application
epochs = Epochs(raw, events[:4], event_id, tmin, tmax, picks=picks,
- baseline=None, proj=False, add_eeg_ref=True)
+ baseline=None, proj=False,
+ add_eeg_ref=False).set_eeg_reference()
data = epochs.get_data().copy()
epochs.apply_proj()
assert_allclose(np.dot(epochs._projector, data[0]), epochs._data[0])
def test_delayed_epochs():
- """Test delayed projection on Epochs"""
+ """Test delayed projection on Epochs."""
raw, events, picks = _get_data()
events = events[:10]
picks = np.concatenate([pick_types(raw.info, meg=True, eeg=True)[::22],
@@ -1570,11 +1653,12 @@ def test_delayed_epochs():
raw.info['lowpass'] = 40. # fake the LP info so no warnings
for decim in (1, 3):
proj_data = Epochs(raw, events, event_id, tmin, tmax, proj=True,
- reject=reject, decim=decim)
+ reject=reject, decim=decim, add_eeg_ref=False)
use_tmin = proj_data.tmin
proj_data = proj_data.get_data()
noproj_data = Epochs(raw, events, event_id, tmin, tmax, proj=False,
- reject=reject, decim=decim).get_data()
+ reject=reject, decim=decim,
+ add_eeg_ref=False).get_data()
assert_equal(proj_data.shape, noproj_data.shape)
assert_equal(proj_data.shape[0], n_epochs)
for preload in (True, False):
@@ -1585,7 +1669,8 @@ def test_delayed_epochs():
if ii in (0, 1):
epochs = Epochs(raw, events, event_id, tmin, tmax,
proj=proj, reject=reject,
- preload=preload, decim=decim)
+ preload=preload, decim=decim,
+ add_eeg_ref=False)
else:
fake_events = np.zeros((len(comp), 3), int)
fake_events[:, 0] = np.arange(len(comp))
@@ -1618,11 +1703,10 @@ def test_delayed_epochs():
def test_drop_epochs():
- """Test dropping of epochs.
- """
+ """Test dropping of epochs."""
raw, events, picks = _get_data()
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- baseline=(None, 0))
+ baseline=(None, 0), add_eeg_ref=False)
events1 = events[events[:, 2] == event_id]
# Bound checks
@@ -1650,15 +1734,15 @@ def test_drop_epochs():
def test_drop_epochs_mult():
- """Test that subselecting epochs or making less epochs is equivalent"""
+ """Test that subselecting epochs or making less epochs is equivalent."""
raw, events, picks = _get_data()
for preload in [True, False]:
epochs1 = Epochs(raw, events, {'a': 1, 'b': 2},
tmin, tmax, picks=picks, reject=reject,
- preload=preload)['a']
+ preload=preload, add_eeg_ref=False)['a']
epochs2 = Epochs(raw, events, {'a': 1},
tmin, tmax, picks=picks, reject=reject,
- preload=preload)
+ preload=preload, add_eeg_ref=False)
if preload:
# In the preload case you cannot know the bads if already ignored
@@ -1680,7 +1764,7 @@ def test_drop_epochs_mult():
def test_contains():
- """Test membership API"""
+ """Test membership API."""
raw, events = _get_data(True)[:2]
# Add seeg channel
seeg = RawArray(np.zeros((1, len(raw.times))),
@@ -1699,7 +1783,7 @@ def test_contains():
picks_contains = pick_types(raw.info, meg=meg, eeg=eeg, seeg=seeg)
epochs = Epochs(raw, events, {'a': 1, 'b': 2}, tmin, tmax,
picks=picks_contains, reject=None,
- preload=False)
+ preload=False, add_eeg_ref=False)
if eeg:
test = 'eeg'
elif seeg:
@@ -1714,12 +1798,11 @@ def test_contains():
def test_drop_channels_mixin():
- """Test channels-dropping functionality
- """
+ """Test channels-dropping functionality."""
raw, events = _get_data()[:2]
# here without picks to get additional coverage
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=None,
- baseline=(None, 0), preload=True)
+ baseline=(None, 0), preload=True, add_eeg_ref=False)
drop_ch = epochs.ch_names[:3]
ch_names = epochs.ch_names[3:]
@@ -1735,14 +1818,13 @@ def test_drop_channels_mixin():
def test_pick_channels_mixin():
- """Test channel-picking functionality
- """
+ """Test channel-picking functionality."""
raw, events, picks = _get_data()
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- baseline=(None, 0), preload=True)
+ baseline=(None, 0), preload=True, add_eeg_ref=False)
ch_names = epochs.ch_names[:3]
epochs.preload = False
- assert_raises(RuntimeError, epochs.drop_channels, ['foo'])
+ assert_raises(RuntimeError, epochs.drop_channels, [ch_names[0]])
epochs.preload = True
ch_names_orig = epochs.ch_names
dummy = epochs.copy().pick_channels(ch_names)
@@ -1756,15 +1838,15 @@ def test_pick_channels_mixin():
# Invalid picks
assert_raises(ValueError, Epochs, raw, events, event_id, tmin, tmax,
- picks=[])
+ picks=[], add_eeg_ref=False)
def test_equalize_channels():
- """Test equalization of channels
- """
+ """Test equalization of channels."""
raw, events, picks = _get_data()
epochs1 = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- baseline=(None, 0), proj=False, preload=True)
+ baseline=(None, 0), proj=False, preload=True,
+ add_eeg_ref=False)
epochs2 = epochs1.copy()
ch_names = epochs1.ch_names[2:]
epochs1.drop_channels(epochs1.ch_names[:1])
@@ -1776,12 +1858,13 @@ def test_equalize_channels():
def test_illegal_event_id():
- """Test handling of invalid events ids"""
+ """Test handling of invalid events ids."""
raw, events, picks = _get_data()
event_id_illegal = dict(aud_l=1, does_not_exist=12345678)
assert_raises(ValueError, Epochs, raw, events, event_id_illegal, tmin,
- tmax, picks=picks, baseline=(None, 0), proj=False)
+ tmax, picks=picks, baseline=(None, 0), proj=False,
+ add_eeg_ref=False)
def test_add_channels_epochs():
@@ -1790,7 +1873,8 @@ def test_add_channels_epochs():
def make_epochs(picks, proj):
return Epochs(raw, events, event_id, tmin, tmax, baseline=(None, 0),
- reject=None, preload=True, proj=proj, picks=picks)
+ reject=None, preload=True, proj=proj, picks=picks,
+ add_eeg_ref=False)
picks = pick_types(raw.info, meg=True, eeg=True, exclude='bads')
picks_meg = pick_types(raw.info, meg=True, eeg=False, exclude='bads')
@@ -1804,7 +1888,8 @@ def test_add_channels_epochs():
epochs_meg.info._check_consistency()
epochs_eeg.info._check_consistency()
- epochs2 = add_channels_epochs([epochs_meg, epochs_eeg])
+ epochs2 = add_channels_epochs([epochs_meg, epochs_eeg],
+ add_eeg_ref=False)
assert_equal(len(epochs.info['projs']), len(epochs2.info['projs']))
assert_equal(len(epochs.info.keys()), len(epochs_meg.info.keys()))
@@ -1821,85 +1906,85 @@ def test_add_channels_epochs():
epochs_meg2 = epochs_meg.copy()
epochs_meg2.info['meas_date'] += 10
- add_channels_epochs([epochs_meg2, epochs_eeg])
+ add_channels_epochs([epochs_meg2, epochs_eeg], add_eeg_ref=False)
epochs_meg2 = epochs_meg.copy()
epochs2.info['filename'] = epochs2.info['filename'].upper()
- epochs2 = add_channels_epochs([epochs_meg, epochs_eeg])
+ epochs2 = add_channels_epochs([epochs_meg, epochs_eeg],
+ add_eeg_ref=False)
epochs_meg2 = epochs_meg.copy()
epochs_meg2.events[3, 2] -= 1
assert_raises(ValueError, add_channels_epochs,
- [epochs_meg2, epochs_eeg])
+ [epochs_meg2, epochs_eeg], add_eeg_ref=False)
assert_raises(ValueError, add_channels_epochs,
- [epochs_meg, epochs_eeg[:2]])
+ [epochs_meg, epochs_eeg[:2]], add_eeg_ref=False)
epochs_meg.info['chs'].pop(0)
epochs_meg.info._update_redundant()
assert_raises(RuntimeError, add_channels_epochs,
- [epochs_meg, epochs_eeg])
+ [epochs_meg, epochs_eeg], add_eeg_ref=False)
epochs_meg2 = epochs_meg.copy()
epochs_meg2.info['sfreq'] = None
assert_raises(RuntimeError, add_channels_epochs,
- [epochs_meg2, epochs_eeg])
+ [epochs_meg2, epochs_eeg], add_eeg_ref=False)
epochs_meg2 = epochs_meg.copy()
epochs_meg2.info['sfreq'] += 10
assert_raises(RuntimeError, add_channels_epochs,
- [epochs_meg2, epochs_eeg])
+ [epochs_meg2, epochs_eeg], add_eeg_ref=False)
epochs_meg2 = epochs_meg.copy()
epochs_meg2.info['chs'][1]['ch_name'] = epochs_meg2.info['ch_names'][0]
epochs_meg2.info._update_redundant()
assert_raises(RuntimeError, add_channels_epochs,
- [epochs_meg2, epochs_eeg])
+ [epochs_meg2, epochs_eeg], add_eeg_ref=False)
epochs_meg2 = epochs_meg.copy()
epochs_meg2.info['dev_head_t']['to'] += 1
assert_raises(ValueError, add_channels_epochs,
- [epochs_meg2, epochs_eeg])
+ [epochs_meg2, epochs_eeg], add_eeg_ref=False)
epochs_meg2 = epochs_meg.copy()
epochs_meg2.info['dev_head_t']['to'] += 1
assert_raises(ValueError, add_channels_epochs,
- [epochs_meg2, epochs_eeg])
+ [epochs_meg2, epochs_eeg], add_eeg_ref=False)
epochs_meg2 = epochs_meg.copy()
epochs_meg2.info['expimenter'] = 'foo'
assert_raises(RuntimeError, add_channels_epochs,
- [epochs_meg2, epochs_eeg])
+ [epochs_meg2, epochs_eeg], add_eeg_ref=False)
epochs_meg2 = epochs_meg.copy()
epochs_meg2.preload = False
assert_raises(ValueError, add_channels_epochs,
- [epochs_meg2, epochs_eeg])
+ [epochs_meg2, epochs_eeg], add_eeg_ref=False)
epochs_meg2 = epochs_meg.copy()
epochs_meg2.times += 0.4
assert_raises(NotImplementedError, add_channels_epochs,
- [epochs_meg2, epochs_eeg])
+ [epochs_meg2, epochs_eeg], add_eeg_ref=False)
epochs_meg2 = epochs_meg.copy()
epochs_meg2.times += 0.5
assert_raises(NotImplementedError, add_channels_epochs,
- [epochs_meg2, epochs_eeg])
+ [epochs_meg2, epochs_eeg], add_eeg_ref=False)
epochs_meg2 = epochs_meg.copy()
epochs_meg2.baseline = None
assert_raises(NotImplementedError, add_channels_epochs,
- [epochs_meg2, epochs_eeg])
+ [epochs_meg2, epochs_eeg], add_eeg_ref=False)
epochs_meg2 = epochs_meg.copy()
epochs_meg2.event_id['b'] = 2
assert_raises(NotImplementedError, add_channels_epochs,
- [epochs_meg2, epochs_eeg])
+ [epochs_meg2, epochs_eeg], add_eeg_ref=False)
def test_array_epochs():
- """Test creating epochs from array
- """
+ """Test creating epochs from array."""
import matplotlib.pyplot as plt
tempdir = _TempDir()
@@ -1982,11 +2067,10 @@ def test_array_epochs():
def test_concatenate_epochs():
- """Test concatenate epochs"""
+ """Test concatenate epochs."""
raw, events, picks = _get_data()
- epochs = Epochs(
- raw=raw, events=events, event_id=event_id, tmin=tmin, tmax=tmax,
- picks=picks)
+ epochs = Epochs(raw=raw, events=events, event_id=event_id, tmin=tmin,
+ tmax=tmax, picks=picks, add_eeg_ref=False)
epochs2 = epochs.copy()
epochs_list = [epochs, epochs2]
epochs_conc = concatenate_epochs(epochs_list)
@@ -2018,17 +2102,28 @@ def test_concatenate_epochs():
epochs2.baseline = (-0.1, None)
assert_raises(ValueError, concatenate_epochs, [epochs, epochs2])
+ # check that dev_head_t is the same
+ epochs2 = epochs.copy()
+ concatenate_epochs([epochs, epochs2]) # should work
+ epochs2.info['dev_head_t']['trans'][:3, 3] += 0.0001
+ assert_raises(ValueError, concatenate_epochs, [epochs, epochs2])
+ assert_raises(TypeError, concatenate_epochs, 'foo')
+ assert_raises(TypeError, concatenate_epochs, [epochs, 'foo'])
+ epochs2.info['dev_head_t'] = None
+ assert_raises(ValueError, concatenate_epochs, [epochs, epochs2])
+ epochs.info['dev_head_t'] = None
+ concatenate_epochs([epochs, epochs2]) # should work
+
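
concatenate_epochs now refuses to combine recordings made at different head
positions. A sketch of the check, assuming mne.transforms.Transform defaults
to an identity matrix when no transformation is given:

    import numpy as np
    from mne import EpochsArray, concatenate_epochs, create_info
    from mne.transforms import Transform

    info = create_info(2, 1000., 'eeg')
    info['dev_head_t'] = Transform('meg', 'head')   # identity (assumed)
    events = np.array([[0, 0, 1]])
    e1 = EpochsArray(np.zeros((1, 2, 5)), info, events)
    e2 = e1.copy()
    concatenate_epochs([e1, e2])   # matching head positions: fine
    # a 0.1 mm shift in the device-to-head translation must now raise
    e2.info['dev_head_t']['trans'][:3, 3] += 0.0001
    try:
        concatenate_epochs([e1, e2])
    except ValueError:
        print('head positions differ, as expected')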
def test_add_channels():
- """Test epoch splitting / re-appending channel types
- """
+ """Test epoch splitting / re-appending channel types."""
raw, events, picks = _get_data()
epoch_nopre = Epochs(
raw=raw, events=events, event_id=event_id, tmin=tmin, tmax=tmax,
- picks=picks)
+ picks=picks, add_eeg_ref=False)
epoch = Epochs(
raw=raw, events=events, event_id=event_id, tmin=tmin, tmax=tmax,
- picks=picks, preload=True)
+ picks=picks, preload=True, add_eeg_ref=False)
epoch_eeg = epoch.copy().pick_types(meg=False, eeg=True)
epoch_meg = epoch.copy().pick_types(meg=True)
epoch_stim = epoch.copy().pick_types(meg=False, stim=True)
@@ -2073,9 +2168,9 @@ def test_seeg_ecog():
def test_default_values():
"""Test default event_id, tmax tmin values are working correctly"""
raw, events = _get_data()[:2]
- epoch_1 = Epochs(raw, events[:1], preload=True)
+ epoch_1 = Epochs(raw, events[:1], preload=True, add_eeg_ref=False)
epoch_2 = Epochs(raw, events[:1], event_id=None, tmin=-0.2, tmax=0.5,
- preload=True)
+ preload=True, add_eeg_ref=False)
assert_equal(hash(epoch_1), hash(epoch_2))
diff --git a/mne/tests/test_event.py b/mne/tests/test_event.py
index 5df9092..e1dd4d6 100644
--- a/mne/tests/test_event.py
+++ b/mne/tests/test_event.py
@@ -4,14 +4,17 @@ import os
from nose.tools import assert_true, assert_raises
import numpy as np
from numpy.testing import (assert_array_almost_equal, assert_array_equal,
- assert_equal)
+ assert_equal, assert_allclose)
import warnings
from mne import (read_events, write_events, make_fixed_length_events,
- find_events, pick_events, find_stim_steps, io, pick_channels)
+ find_events, pick_events, find_stim_steps, pick_channels,
+ read_evokeds, Epochs)
+from mne.io import read_raw_fif
from mne.tests.common import assert_naming
from mne.utils import _TempDir, run_tests_if_main
-from mne.event import define_target_events, merge_events
+from mne.event import define_target_events, merge_events, AcqParserFIF
+from mne.datasets import testing
warnings.simplefilter('always')
@@ -22,6 +25,11 @@ fname_1 = op.join(base_dir, 'test-1-eve.fif')
fname_txt = op.join(base_dir, 'test-eve.eve')
fname_txt_1 = op.join(base_dir, 'test-eve-1.eve')
+# for testing Elekta averager
+elekta_base_dir = op.join(testing.data_path(download=False), 'misc')
+fname_raw_elekta = op.join(elekta_base_dir, 'test_elekta_3ch_raw.fif')
+fname_ave_elekta = op.join(elekta_base_dir, 'test_elekta-ave.fif')
+
# using mne_process_raw --raw test_raw.fif --eventsout test-mpr-eve.eve:
fname_txt_mpr = op.join(base_dir, 'test-mpr-eve.eve')
fname_old_txt = op.join(base_dir, 'test-eve-old-style.eve')
@@ -29,8 +37,8 @@ raw_fname = op.join(base_dir, 'test_raw.fif')
def test_fix_stim():
- """Test fixing stim STI016 for Neuromag"""
- raw = io.read_raw_fif(raw_fname, preload=True)
+ """Test fixing stim STI016 for Neuromag."""
+ raw = read_raw_fif(raw_fname, preload=True, add_eeg_ref=False)
# 32768 (016) + 3 (002+001) bits gets incorrectly coded during acquisition
raw._data[raw.ch_names.index('STI 014'), :3] = [0, -32765, 0]
with warnings.catch_warnings(record=True) as w:
@@ -43,12 +51,12 @@ def test_fix_stim():
def test_add_events():
- """Test adding events to a Raw file"""
+ """Test adding events to a Raw file."""
# need preload
- raw = io.read_raw_fif(raw_fname, preload=False)
+ raw = read_raw_fif(raw_fname, preload=False, add_eeg_ref=False)
events = np.array([[raw.first_samp, 0, 1]])
assert_raises(RuntimeError, raw.add_events, events, 'STI 014')
- raw = io.read_raw_fif(raw_fname, preload=True)
+ raw = read_raw_fif(raw_fname, preload=True, add_eeg_ref=False)
orig_events = find_events(raw, 'STI 014')
# add some events
events = np.array([raw.first_samp, 0, 1])
@@ -66,7 +74,7 @@ def test_add_events():
def test_merge_events():
- """Test event merging"""
+ """Test event merging."""
events_orig = [[1, 0, 1], [3, 0, 2], [10, 0, 3], [20, 0, 4]]
events_replacement = \
@@ -98,7 +106,7 @@ def test_merge_events():
def test_io_events():
- """Test IO for events"""
+ """Test IO for events."""
tempdir = _TempDir()
# Test binary fif IO
events = read_events(fname) # Use as the gold standard
@@ -119,7 +127,7 @@ def test_io_events():
assert_array_almost_equal(events, events2)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
- events2 = read_events(fname_txt_mpr)
+ events2 = read_events(fname_txt_mpr, mask=0, mask_type='not_and')
assert_true(sum('first row of' in str(ww.message) for ww in w) == 1)
assert_array_almost_equal(events, events2)
@@ -169,9 +177,9 @@ def test_io_events():
def test_find_events():
- """Test find events in raw file"""
+ """Test find events in raw file."""
events = read_events(fname)
- raw = io.read_raw_fif(raw_fname, preload=True)
+ raw = read_raw_fif(raw_fname, preload=True, add_eeg_ref=False)
# let's test the defaulting behavior while we're at it
extra_ends = ['', '_1']
orig_envs = [os.getenv('MNE_STIM_CHANNEL%s' % s) for s in extra_ends]
@@ -181,7 +189,7 @@ def test_find_events():
events2 = find_events(raw)
assert_array_almost_equal(events, events2)
# now test with mask
- events11 = find_events(raw, mask=3)
+ events11 = find_events(raw, mask=3, mask_type='not_and')
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
events22 = read_events(fname, mask=3)
@@ -203,14 +211,33 @@ def test_find_events():
# 1 == '0b1', 2 == '0b10', 3 == '0b11', 4 == '0b100'
assert_raises(TypeError, find_events, raw, mask="0")
- assert_array_equal(find_events(raw, shortest_event=1, mask=1),
+ assert_raises(ValueError, find_events, raw, mask=0, mask_type='blah')
+ # testing mask_type; the default is 'not_and'
+ assert_array_equal(find_events(raw, shortest_event=1, mask=1,
+ mask_type='not_and'),
[[2, 0, 2], [4, 2, 4]])
- assert_array_equal(find_events(raw, shortest_event=1, mask=2),
+ assert_array_equal(find_events(raw, shortest_event=1, mask=2,
+ mask_type='not_and'),
[[1, 0, 1], [3, 0, 1], [4, 1, 4]])
- assert_array_equal(find_events(raw, shortest_event=1, mask=3),
+ assert_array_equal(find_events(raw, shortest_event=1, mask=3,
+ mask_type='not_and'),
[[4, 0, 4]])
- assert_array_equal(find_events(raw, shortest_event=1, mask=4),
+ assert_array_equal(find_events(raw, shortest_event=1, mask=4,
+ mask_type='not_and'),
+ [[1, 0, 1], [2, 1, 2], [3, 2, 3]])
+ # testing with mask_type = 'and'
+ assert_array_equal(find_events(raw, shortest_event=1, mask=1,
+ mask_type='and'),
+ [[1, 0, 1], [3, 0, 1]])
+ assert_array_equal(find_events(raw, shortest_event=1, mask=2,
+ mask_type='and'),
+ [[2, 0, 2]])
+ assert_array_equal(find_events(raw, shortest_event=1, mask=3,
+ mask_type='and'),
[[1, 0, 1], [2, 1, 2], [3, 2, 3]])
+ assert_array_equal(find_events(raw, shortest_event=1, mask=4,
+ mask_type='and'),
+ [[4, 0, 4]])
# test empty events channel
raw._data[stim_channel_idx, :] = 0
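
Both mask_type modes are plain bitwise operations on the trigger values, so
the expected rows above can be derived by hand. A numpy sketch of the masking
step only, mirroring the semantics the assertions imply (the actual
implementation lives in mne.event):

    import numpy as np

    triggers = np.array([1, 2, 3, 4])   # 0b001, 0b010, 0b011, 0b100
    mask = 3                            # 0b011
    # mask_type='not_and': clear the masked bits
    print(np.bitwise_and(triggers, ~mask))   # [0 0 0 4] -> only 4 survives
    # mask_type='and': keep only the masked bits
    print(np.bitwise_and(triggers, mask))    # [1 2 3 0] -> 4 is zeroed out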
@@ -318,7 +345,7 @@ def test_find_events():
def test_pick_events():
- """Test pick events in a events ndarray"""
+ """Test pick events in a events ndarray."""
events = np.array([[1, 0, 1],
[2, 1, 0],
[3, 0, 4],
@@ -338,8 +365,8 @@ def test_pick_events():
def test_make_fixed_length_events():
- """Test making events of a fixed length"""
- raw = io.read_raw_fif(raw_fname)
+ """Test making events of a fixed length."""
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
events = make_fixed_length_events(raw, id=1)
assert_equal(events.shape[1], 3)
events_zero = make_fixed_length_events(raw, 1, first_samp=False)
@@ -353,12 +380,17 @@ def test_make_fixed_length_events():
# With bad limits (no resulting events)
assert_raises(ValueError, make_fixed_length_events, raw, 1,
tmin, tmax - 1e-3, duration)
+ # not a Raw instance, bad id, or bad duration
+ assert_raises(ValueError, make_fixed_length_events, raw, 2.3)
+ assert_raises(ValueError, make_fixed_length_events, 'not raw', 2)
+ assert_raises(ValueError, make_fixed_length_events, raw, 23, tmin, tmax,
+ 'abc')
def test_define_events():
- """Test defining response events"""
+ """Test defining response events."""
events = read_events(fname)
- raw = io.read_raw_fif(raw_fname)
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
events_, _ = define_target_events(events, 5, 32, raw.info['sfreq'],
.2, 0.7, 42, 99)
n_target = events[events[:, 2] == 5].shape[0]
@@ -385,4 +417,75 @@ def test_define_events():
assert_array_equal(true_lag_fill, lag_fill)
assert_array_equal(true_lag_nofill, lag_nofill)
+
+@testing.requires_testing_data
+def test_acqparser():
+ """ Test AcqParserFIF """
+ # no acquisition parameters
+ assert_raises(ValueError, AcqParserFIF, {'acq_pars': ''})
+ # invalid acquisition parameters
+ assert_raises(ValueError, AcqParserFIF, {'acq_pars': 'baaa'})
+ assert_raises(ValueError, AcqParserFIF, {'acq_pars': 'ERFVersion\n1'})
+ # test oldish file
+ raw = read_raw_fif(raw_fname, preload=False)
+ acqp = AcqParserFIF(raw.info)
+ # test __repr__()
+ assert_true(repr(acqp))
+ # old file should trigger compat mode
+ assert_true(acqp.compat)
+ # count events and categories
+ assert_equal(len(acqp.categories), 6)
+ assert_equal(len(acqp._categories), 17)
+ assert_equal(len(acqp.events), 6)
+ assert_equal(len(acqp._events), 17)
+ # get category
+ assert_true(acqp['Surprise visual'])
+ # test TRIUX file
+ raw = read_raw_fif(fname_raw_elekta, preload=False)
+ acqp = AcqParserFIF(raw.info)
+ # test __repr__()
+ assert_true(repr(acqp))
+ # this file should not be in compatibility mode
+ assert_true(not acqp.compat)
+ # nonexisting category
+ assert_raises(KeyError, acqp.__getitem__, 'does not exist')
+ assert_raises(KeyError, acqp.get_condition, raw, 'foo')
+ # category not a string
+ assert_raises(ValueError, acqp.__getitem__, 0)
+ # number of events / categories
+ assert_equal(len(acqp), 7)
+ assert_equal(len(acqp.categories), 7)
+ assert_equal(len(acqp._categories), 32)
+ assert_equal(len(acqp.events), 6)
+ assert_equal(len(acqp._events), 32)
+ # get category
+ assert_true(acqp['Test event 5'])
+
+
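AcqParserFIF parses the Elekta averaging categories stored in
raw.info['acq_pars'], and get_condition turns a category into keyword
arguments for Epochs, as the averaging test below exercises. A usage sketch
with a hypothetical file name:

    import mne
    from mne.event import AcqParserFIF

    raw = mne.io.read_raw_fif('elekta_raw.fif', preload=True)
    acqp = AcqParserFIF(raw.info)
    print(repr(acqp))   # summary of categories and events
    for cat in acqp.categories:
        cond = acqp.get_condition(raw, cat)
        evoked = mne.Epochs(raw, baseline=(-.05, 0), **cond).average()
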
+@testing.requires_testing_data
+def test_acqparser_averaging():
+ """ Test averaging with AcqParserFIF vs. Elekta software """
+ raw = read_raw_fif(fname_raw_elekta, preload=True)
+ acqp = AcqParserFIF(raw.info)
+ for cat in acqp.categories:
+ # XXX datasets match only when baseline is applied to both;
+ # it is unclear where the relative DC shift comes from
+ cond = acqp.get_condition(raw, cat)
+ eps = Epochs(raw, baseline=(-.05, 0), **cond)
+ ev = eps.average()
+ ev_ref = read_evokeds(fname_ave_elekta, cat['comment'],
+ baseline=(-.05, 0), proj=False)
+ ev_mag = ev.copy()
+ ev_mag.pick_channels(['MEG0111'])
+ ev_grad = ev.copy()
+ ev_grad.pick_channels(['MEG2643', 'MEG1622'])
+ ev_ref_mag = ev_ref.copy()
+ ev_ref_mag.pick_channels(['MEG0111'])
+ ev_ref_grad = ev_ref.copy()
+ ev_ref_grad.pick_channels(['MEG2643', 'MEG1622'])
+ assert_allclose(ev_mag.data, ev_ref_mag.data,
+ rtol=0, atol=1e-15) # tol = 1 fT
+ assert_allclose(ev_grad.data, ev_ref_grad.data,
+ rtol=0, atol=1e-13) # tol = 1 fT/cm
+
run_tests_if_main()
diff --git a/mne/tests/test_evoked.py b/mne/tests/test_evoked.py
index f3b680e..45794f9 100644
--- a/mne/tests/test_evoked.py
+++ b/mne/tests/test_evoked.py
@@ -16,9 +16,10 @@ from numpy.testing import (assert_array_almost_equal, assert_equal,
from nose.tools import assert_true, assert_raises, assert_not_equal
from mne import (equalize_channels, pick_types, read_evokeds, write_evokeds,
- grand_average, combine_evoked, create_info)
+ grand_average, combine_evoked, create_info, read_events,
+ Epochs, EpochsArray)
from mne.evoked import _get_peak, Evoked, EvokedArray
-from mne.epochs import EpochsArray
+from mne.io import read_raw_fif
from mne.tests.common import assert_naming
from mne.utils import (_TempDir, requires_pandas, slow_test, requires_version,
run_tests_if_main)
@@ -26,16 +27,57 @@ from mne.externals.six.moves import cPickle as pickle
warnings.simplefilter('always')
-fname = op.join(op.dirname(__file__), '..', 'io', 'tests', 'data',
- 'test-ave.fif')
-fname_gz = op.join(op.dirname(__file__), '..', 'io', 'tests', 'data',
- 'test-ave.fif.gz')
+base_dir = op.join(op.dirname(__file__), '..', 'io', 'tests', 'data')
+fname = op.join(base_dir, 'test-ave.fif')
+fname_gz = op.join(base_dir, 'test-ave.fif.gz')
+raw_fname = op.join(base_dir, 'test_raw.fif')
+event_name = op.join(base_dir, 'test-eve.fif')
+
+
+def test_decim():
+ """Test evoked decimation."""
+ rng = np.random.RandomState(0)
+ n_epochs, n_channels, n_times = 5, 10, 20
+ dec_1, dec_2 = 2, 3
+ decim = dec_1 * dec_2
+ sfreq = 1000.
+ sfreq_new = sfreq / decim
+ data = rng.randn(n_epochs, n_channels, n_times)
+ events = np.array([np.arange(n_epochs), [0] * n_epochs, [1] * n_epochs]).T
+ info = create_info(n_channels, sfreq, 'eeg')
+ info['lowpass'] = sfreq_new / float(decim)
+ epochs = EpochsArray(data, info, events)
+ data_epochs = epochs.copy().decimate(decim).get_data()
+ data_epochs_2 = epochs.copy().decimate(decim, offset=1).get_data()
+ data_epochs_3 = epochs.decimate(dec_1).decimate(dec_2).get_data()
+ assert_array_equal(data_epochs, data[:, :, ::decim])
+ assert_array_equal(data_epochs_2, data[:, :, 1::decim])
+ assert_array_equal(data_epochs, data_epochs_3)
+
+ # Now let's do it with some real data
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
+ events = read_events(event_name)
+ sfreq_new = raw.info['sfreq'] / decim
+ raw.info['lowpass'] = sfreq_new / 4. # suppress aliasing warnings
+ picks = pick_types(raw.info, meg=True, eeg=True, exclude=())
+ epochs = Epochs(raw, events, 1, -0.2, 0.5, picks=picks, preload=True,
+ add_eeg_ref=False)
+ for offset in (0, 1):
+ ev_ep_decim = epochs.copy().decimate(decim, offset).average()
+ ev_decim = epochs.average().decimate(decim, offset)
+ expected_times = epochs.times[offset::decim]
+ assert_allclose(ev_decim.times, expected_times)
+ assert_allclose(ev_ep_decim.times, expected_times)
+ expected_data = epochs.get_data()[:, :, offset::decim].mean(axis=0)
+ assert_allclose(ev_decim.data, expected_data)
+ assert_allclose(ev_ep_decim.data, expected_data)
+ assert_equal(ev_decim.info['sfreq'], sfreq_new)
+ assert_array_equal(ev_decim.times, expected_times)
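
test_decim verifies that decimating epochs and then averaging matches
averaging and then decimating, for either sample offset. The commutation,
restated as a stand-alone sketch (the info['lowpass'] tweak only keeps the
decimation warning-free):

    import numpy as np
    from mne import EpochsArray, create_info

    rng = np.random.RandomState(0)
    data = rng.randn(5, 4, 24)          # epochs x channels x times
    events = np.array([np.arange(5), [0] * 5, [1] * 5]).T
    info = create_info(4, 1000., 'eeg')
    info['lowpass'] = 40.               # well below the decimated Nyquist
    epochs = EpochsArray(data, info, events)
    for offset in (0, 1):
        a = epochs.copy().decimate(6, offset).average()
        b = epochs.average().decimate(6, offset)
        assert np.allclose(a.data, b.data)   # same samples either way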
@requires_version('scipy', '0.14')
def test_savgol_filter():
- """Test savgol filtering
- """
+ """Test savgol filtering."""
h_freq = 10.
evoked = read_evokeds(fname, 0)
freqs = fftpack.fftfreq(len(evoked.times), 1. / evoked.info['sfreq'])
@@ -57,8 +99,7 @@ def test_savgol_filter():
def test_hash_evoked():
- """Test evoked hashing
- """
+ """Test evoked hashing."""
ave = read_evokeds(fname, 0)
ave_2 = read_evokeds(fname, 0)
assert_equal(hash(ave), hash(ave_2))
@@ -71,8 +112,7 @@ def test_hash_evoked():
@slow_test
def test_io_evoked():
- """Test IO for evoked data (fif + gz) with integer and str args
- """
+ """Test IO for evoked data (fif + gz) with integer and str args."""
tempdir = _TempDir()
ave = read_evokeds(fname, 0)
@@ -119,9 +159,9 @@ def test_io_evoked():
assert_equal(av1.comment, av2.comment)
# test warnings on bad filenames
+ fname2 = op.join(tempdir, 'test-bad-name.fif')
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
- fname2 = op.join(tempdir, 'test-bad-name.fif')
write_evokeds(fname2, ave)
read_evokeds(fname2)
assert_naming(w, 'test_evoked.py', 2)
@@ -129,10 +169,24 @@ def test_io_evoked():
# constructor
assert_raises(TypeError, Evoked, fname)
+ # MaxShield
+ fname_ms = op.join(tempdir, 'test-ave.fif')
+ assert_true(ave.info['maxshield'] is False)
+ ave.info['maxshield'] = True
+ ave.save(fname_ms)
+ assert_raises(ValueError, read_evokeds, fname_ms)
+ with warnings.catch_warnings(record=True) as w:
+ aves = read_evokeds(fname_ms, allow_maxshield=True)
+ assert_true(all('Elekta' in str(ww.message) for ww in w))
+ assert_true(all(ave.info['maxshield'] is True for ave in aves))
+ with warnings.catch_warnings(record=True) as w:
+ aves = read_evokeds(fname_ms, allow_maxshield='yes')
+ assert_equal(len(w), 0)
+ assert_true(all(ave.info['maxshield'] is True for ave in aves))
+
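Evoked files saved with the MaxShield flag now refuse to load unless the
caller opts in, mirroring the Raw reader. A sketch of the opt-in, with a
hypothetical file name:

    from mne import read_evokeds

    # read_evokeds('test-ave.fif') would raise ValueError here
    evokeds = read_evokeds('test-ave.fif', allow_maxshield=True)   # warns
    evokeds = read_evokeds('test-ave.fif', allow_maxshield='yes')  # silent
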
def test_shift_time_evoked():
- """ Test for shifting of time scale
- """
+ """ Test for shifting of time scale."""
tempdir = _TempDir()
# Shift backward
ave = read_evokeds(fname, 0)
@@ -172,8 +226,7 @@ def test_shift_time_evoked():
def test_evoked_resample():
- """Test for resampling of evoked data
- """
+ """Test for resampling of evoked data."""
tempdir = _TempDir()
# upsample, write it out, read it in
ave = read_evokeds(fname, 0)
@@ -204,8 +257,7 @@ def test_evoked_resample():
def test_evoked_detrend():
- """Test for detrending evoked data
- """
+ """Test for detrending evoked data."""
ave = read_evokeds(fname, 0)
ave_normal = read_evokeds(fname, 0)
ave.detrend(0)
@@ -217,7 +269,7 @@ def test_evoked_detrend():
@requires_pandas
def test_to_data_frame():
- """Test evoked Pandas exporter"""
+ """Test evoked Pandas exporter."""
ave = read_evokeds(fname, 0)
assert_raises(ValueError, ave.to_data_frame, picks=np.arange(400))
df = ave.to_data_frame()
@@ -229,8 +281,7 @@ def test_to_data_frame():
def test_evoked_proj():
- """Test SSP proj operations
- """
+ """Test SSP proj operations."""
for proj in [True, False]:
ave = read_evokeds(fname, condition=0, proj=proj)
assert_true(all(p['active'] == proj for p in ave.info['projs']))
@@ -258,9 +309,7 @@ def test_evoked_proj():
def test_get_peak():
- """Test peak getter
- """
-
+ """Test peak getter."""
evoked = read_evokeds(fname, condition=0, proj=True)
assert_raises(ValueError, evoked.get_peak, ch_type='mag', tmin=1)
assert_raises(ValueError, evoked.get_peak, ch_type='mag', tmax=0.9)
@@ -301,8 +350,7 @@ def test_get_peak():
def test_drop_channels_mixin():
- """Test channels-dropping functionality
- """
+ """Test channels-dropping functionality."""
evoked = read_evokeds(fname, condition=0, proj=True)
drop_ch = evoked.ch_names[:3]
ch_names = evoked.ch_names[3:]
@@ -312,15 +360,19 @@ def test_drop_channels_mixin():
assert_equal(ch_names, dummy.ch_names)
assert_equal(ch_names_orig, evoked.ch_names)
assert_equal(len(ch_names_orig), len(evoked.data))
+ dummy2 = evoked.copy().drop_channels([drop_ch[0]])
+ assert_equal(dummy2.ch_names, ch_names_orig[1:])
evoked.drop_channels(drop_ch)
assert_equal(ch_names, evoked.ch_names)
assert_equal(len(ch_names), len(evoked.data))
+ for ch_names in ([1, 2], "fake", ["fake"]):
+ assert_raises(ValueError, evoked.drop_channels, ch_names)
+
def test_pick_channels_mixin():
- """Test channel-picking functionality
- """
+ """Test channel-picking functionality."""
evoked = read_evokeds(fname, condition=0, proj=True)
ch_names = evoked.ch_names[:3]
@@ -344,8 +396,7 @@ def test_pick_channels_mixin():
def test_equalize_channels():
- """Test equalization of channels
- """
+ """Test equalization of channels."""
evoked1 = read_evokeds(fname, condition=0, proj=True)
evoked2 = evoked1.copy()
ch_names = evoked1.ch_names[2:]
@@ -357,9 +408,8 @@ def test_equalize_channels():
assert_equal(ch_names, e.ch_names)
-def test_evoked_arithmetic():
- """Test evoked arithmetic
- """
+def test_arithmetic():
+ """Test evoked arithmetic."""
ev = read_evokeds(fname, condition=0)
ev1 = EvokedArray(np.ones_like(ev.data), ev.info, ev.times[0], nave=20)
ev2 = EvokedArray(-np.ones_like(ev.data), ev.info, ev.times[0], nave=10)
@@ -367,22 +417,35 @@ def test_evoked_arithmetic():
# combine_evoked([ev1, ev2]) should be the same as ev1 + ev2:
# data should be added according to their `nave` weights
# nave = ev1.nave + ev2.nave
- ev = ev1 + ev2
+ with warnings.catch_warnings(record=True): # deprecation no weights
+ ev = combine_evoked([ev1, ev2])
assert_equal(ev.nave, ev1.nave + ev2.nave)
assert_allclose(ev.data, 1. / 3. * np.ones_like(ev.data))
- ev = ev1 - ev2
- assert_equal(ev.nave, ev1.nave + ev2.nave)
- assert_equal(ev.comment, ev1.comment + ' - ' + ev2.comment)
- assert_allclose(ev.data, np.ones_like(ev1.data))
+
+ # with same trial counts, a bunch of things should be equivalent
+ for weights in ('nave', 'equal', [0.5, 0.5]):
+ ev = combine_evoked([ev1, ev1], weights=weights)
+ assert_allclose(ev.data, ev1.data)
+ assert_equal(ev.nave, 2 * ev1.nave)
+ ev = combine_evoked([ev1, -ev1], weights=weights)
+ assert_allclose(ev.data, 0., atol=1e-20)
+ assert_equal(ev.nave, 2 * ev1.nave)
+ ev = combine_evoked([ev1, -ev1], weights='equal')
+ assert_allclose(ev.data, 0., atol=1e-20)
+ assert_equal(ev.nave, 2 * ev1.nave)
+ ev = combine_evoked([ev1, -ev2], weights='equal')
+ expected = int(round(1. / (0.25 / ev1.nave + 0.25 / ev2.nave)))
+ assert_equal(expected, 27) # this is reasonable
+ assert_equal(ev.nave, expected)
# default comment behavior if evoked.comment is None
old_comment1 = ev1.comment
old_comment2 = ev2.comment
ev1.comment = None
- with warnings.catch_warnings(record=True):
- warnings.simplefilter('always')
- ev = ev1 - ev2
- assert_equal(ev.comment, 'unknown')
+ ev = combine_evoked([ev1, -ev2], weights=[1, -1])
+ assert_equal(ev.comment.count('unknown'), 2)
+ assert_true('-unknown' in ev.comment)
+ assert_true(' + ' in ev.comment)
ev1.comment = old_comment1
ev2.comment = old_comment2
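
With explicit weights the combined nave is no longer a plain sum; as the
expression in the test implies, it is the effective sample size
1 / sum(w_i ** 2 / nave_i). For naves of 20 and 10 at equal weights:

    naves, weights = [20, 10], [0.5, 0.5]
    nave_eff = 1. / sum(w ** 2 / n for w, n in zip(weights, naves))
    print(int(round(nave_eff)))   # 1 / (0.0125 + 0.025) = 26.67 -> 27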
@@ -412,11 +475,11 @@ def test_evoked_arithmetic():
assert_equal(gave.data.shape, [len(ch_names), evoked1.data.shape[1]])
assert_equal(ch_names, gave.ch_names)
assert_equal(gave.nave, 2)
+ assert_raises(ValueError, grand_average, [1, evoked1])
def test_array_epochs():
- """Test creating evoked from array
- """
+ """Test creating evoked from array."""
tempdir = _TempDir()
# creating
@@ -463,14 +526,14 @@ def test_array_epochs():
def test_time_as_index():
- """Test time as index"""
+ """Test time as index."""
evoked = read_evokeds(fname, condition=0).crop(-.1, .1)
assert_array_equal(evoked.time_as_index([-.1, .1], use_rounding=True),
[0, len(evoked.times) - 1])
def test_add_channels():
- """Test evoked splitting / re-appending channel types"""
+ """Test evoked splitting / re-appending channel types."""
evoked = read_evokeds(fname, condition=0)
evoked.info['buffer_size_sec'] = None
hpi_coils = [{'event_bits': []},
@@ -502,4 +565,18 @@ def test_add_channels():
assert_raises(AssertionError, evoked_meg.add_channels, evoked_badsf)
+def test_evoked_baseline():
+ """Test evoked baseline."""
+ evoked = read_evokeds(fname, condition=0, baseline=None)
+
+ # Here we create a dataset with constant data.
+ evoked = EvokedArray(np.ones_like(evoked.data), evoked.info,
+ evoked.times[0])
+
+ # Mean baseline correction is applied; since the data are equal to their
+ # mean, the resulting data should be a matrix of zeros.
+ evoked.apply_baseline((None, None))
+
+ assert_allclose(evoked.data, np.zeros_like(evoked.data))
+
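apply_baseline((None, None)) subtracts each channel's mean over the whole
time range, so constant data must come out exactly zero. The same invariant
in plain numpy:

    import numpy as np

    data = np.ones((3, 10))   # channels x times
    corrected = data - data.mean(axis=1, keepdims=True)
    assert np.allclose(corrected, 0.)
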
run_tests_if_main()
diff --git a/mne/tests/test_filter.py b/mne/tests/test_filter.py
index fc89752..88ee708 100644
--- a/mne/tests/test_filter.py
+++ b/mne/tests/test_filter.py
@@ -1,84 +1,172 @@
+import os.path as op
+import warnings
+
import numpy as np
from numpy.testing import (assert_array_almost_equal, assert_almost_equal,
assert_array_equal, assert_allclose)
from nose.tools import assert_equal, assert_true, assert_raises
-import warnings
-from scipy.signal import resample as sp_resample
+from scipy.signal import resample as sp_resample, butter
+from mne import create_info
+from mne.io import RawArray, read_raw_fif
from mne.filter import (band_pass_filter, high_pass_filter, low_pass_filter,
band_stop_filter, resample, _resample_stim_channels,
construct_iir_filter, notch_filter, detrend,
- _overlap_add_filter, _smart_pad)
+ _overlap_add_filter, _smart_pad, design_mne_c_filter,
+ estimate_ringing_samples, filter_data)
-from mne.utils import sum_squared, run_tests_if_main, slow_test, catch_logging
+from mne.utils import (sum_squared, run_tests_if_main, slow_test,
+ catch_logging, requires_version, _TempDir,
+ requires_mne, run_subprocess)
warnings.simplefilter('always') # enable b/c these tests throw warnings
rng = np.random.RandomState(0)
+@requires_mne
+def test_mne_c_design():
+ """Test MNE-C filter design"""
+ tempdir = _TempDir()
+ temp_fname = op.join(tempdir, 'test_raw.fif')
+ out_fname = op.join(tempdir, 'test_c_raw.fif')
+ x = np.zeros((1, 10001))
+ x[0, 5000] = 1.
+ time_sl = slice(5000 - 4096, 5000 + 4097)
+ sfreq = 1000.
+ RawArray(x, create_info(1, sfreq, 'eeg')).save(temp_fname)
+
+ tols = dict(rtol=1e-4, atol=1e-4)
+ cmd = ('mne_process_raw', '--projoff', '--raw', temp_fname,
+ '--save', out_fname)
+ run_subprocess(cmd)
+ h = design_mne_c_filter(sfreq, None, 40)
+ h_c = read_raw_fif(out_fname, add_eeg_ref=False)[0][0][0][time_sl]
+ assert_allclose(h, h_c, **tols)
+
+ run_subprocess(cmd + ('--highpass', '5', '--highpassw', '2.5'))
+ h = design_mne_c_filter(sfreq, 5, 40, 2.5)
+ h_c = read_raw_fif(out_fname, add_eeg_ref=False)[0][0][0][time_sl]
+ assert_allclose(h, h_c, **tols)
+
+ run_subprocess(cmd + ('--lowpass', '1000', '--highpass', '10'))
+ h = design_mne_c_filter(sfreq, 10, None, verbose=True)
+ h_c = read_raw_fif(out_fname, add_eeg_ref=False)[0][0][0][time_sl]
+ assert_allclose(h, h_c, **tols)
+
+
+@requires_version('scipy', '0.16')
+def test_estimate_ringing():
+ """Test our ringing estimation function"""
+ # Actual values might differ based on system, so let's be approximate
+ for kind in ('ba', 'sos'):
+ for thresh, lims in ((0.1, (30, 60)), # 47
+ (0.01, (300, 600)), # 475
+ (0.001, (3000, 6000)), # 4758
+ (0.0001, (30000, 60000))): # 37993
+ n_ring = estimate_ringing_samples(butter(3, thresh, output=kind))
+ assert_true(lims[0] <= n_ring <= lims[1],
+ msg='%s %s: %s <= %s <= %s'
+ % (kind, thresh, lims[0], n_ring, lims[1]))
+ with warnings.catch_warnings(record=True) as w:
+ assert_equal(estimate_ringing_samples(butter(4, 0.00001)), 100000)
+ assert_true(any('properly estimate' in str(ww.message) for ww in w))
+
+
def test_1d_filter():
"""Test our private overlap-add filtering function"""
# make some random signals and filters
- for n_signal in (1, 2, 5, 10, 20, 40, 100, 200, 400, 1000, 2000):
+ for n_signal in (1, 2, 3, 5, 10, 20, 40):
x = rng.randn(n_signal)
- for n_filter in (2, 5, 10, 20, 40, 100, 200, 400, 1000, 2000):
- # Don't test n_filter == 1 because scipy can't handle it.
- if n_filter > n_signal:
- continue # only equal or lesser lengths supported
+ for n_filter in (1, 2, 3, 5, 10, 11, 20, 21, 40, 41, 100, 101):
for filter_type in ('identity', 'random'):
if filter_type == 'random':
h = rng.randn(n_filter)
else: # filter_type == 'identity'
h = np.concatenate([[1.], np.zeros(n_filter - 1)])
# ensure we pad the signal the same way for both filters
- n_pad = max(min(n_filter, n_signal - 1), 0)
+ n_pad = n_filter - 1
x_pad = _smart_pad(x, np.array([n_pad, n_pad]))
- for zero_phase in (True, False):
+ for phase in ('zero', 'linear', 'zero-double'):
# compute our expected result the slow way
- if zero_phase:
- x_expected = np.convolve(x_pad, h)[::-1]
- x_expected = np.convolve(x_expected, h)[::-1]
- x_expected = x_expected[len(h) - 1:-(len(h) - 1)]
+ if phase == 'zero':
+ # only allow zero-phase for odd-length filters
+ if n_filter % 2 == 0:
+ assert_raises(RuntimeError, _overlap_add_filter,
+ x[np.newaxis], h, phase=phase)
+ continue
+ shift = (len(h) - 1) // 2
+ x_expected = np.convolve(x_pad, h)
+ x_expected = x_expected[shift:len(x_expected) - shift]
+ elif phase == 'zero-double':
+ shift = len(h) - 1
+ x_expected = np.convolve(x_pad, h)
+ x_expected = np.convolve(x_expected[::-1], h)[::-1]
+ x_expected = x_expected[shift:len(x_expected) - shift]
+ shift = 0
else:
+ shift = 0
x_expected = np.convolve(x_pad, h)
- x_expected = x_expected[:-(len(h) - 1)]
+ x_expected = x_expected[:len(x_expected) - len(h) + 1]
# remove padding
if n_pad > 0:
- x_expected = x_expected[n_pad:-n_pad]
+ x_expected = x_expected[n_pad:len(x_expected) - n_pad]
+ assert_equal(len(x_expected), len(x))
# make sure we actually set things up reasonably
if filter_type == 'identity':
- assert_allclose(x_expected, x)
+ out = x_pad.copy()
+ out = out[shift + n_pad:]
+ out = out[:len(x)]
+ out = np.concatenate((out, np.zeros(max(len(x) -
+ len(out), 0))))
+ assert_equal(len(out), len(x))
+ assert_allclose(out, x_expected)
+ assert_equal(len(x_expected), len(x))
+
# compute our version
for n_fft in (None, 32, 128, 129, 1023, 1024, 1025, 2048):
# need to use .copy() b/c signal gets modified inplace
x_copy = x[np.newaxis, :].copy()
- if (n_fft is not None and n_fft < 2 * n_filter - 1 and
- zero_phase):
- assert_raises(ValueError, _overlap_add_filter,
- x_copy, h, n_fft, zero_phase)
- elif (n_fft is not None and n_fft < n_filter and not
- zero_phase):
+ min_fft = 2 * n_filter - 1
+ if phase == 'zero-double':
+ min_fft = 2 * min_fft - 1
+ if n_fft is not None and n_fft < min_fft:
assert_raises(ValueError, _overlap_add_filter,
- x_copy, h, n_fft, zero_phase)
+ x_copy, h, n_fft, phase=phase)
else:
- # bad len warning
- with warnings.catch_warnings(record=True):
- x_filtered = _overlap_add_filter(
- x_copy, h, n_fft, zero_phase)[0]
- assert_allclose(x_expected, x_filtered)
+ x_filtered = _overlap_add_filter(
+ x_copy, h, n_fft, phase=phase)[0]
+ assert_allclose(x_filtered, x_expected, atol=1e-13)
+@requires_version('scipy', '0.16')
def test_iir_stability():
- """Test IIR filter stability check
- """
+ """Test IIR filter stability check"""
sig = np.empty(1000)
sfreq = 1000
# This will make an unstable filter, should throw RuntimeError
assert_raises(RuntimeError, high_pass_filter, sig, sfreq, 0.6,
- method='iir', iir_params=dict(ftype='butter', order=8))
- # can't pass iir_params if method='fir'
+ method='iir', iir_params=dict(ftype='butter', order=8,
+ output='ba'))
+ # This one should work just fine
+ high_pass_filter(sig, sfreq, 0.6, method='iir',
+ iir_params=dict(ftype='butter', order=8, output='sos'))
+ # bad system type
+ assert_raises(ValueError, high_pass_filter, sig, sfreq, 0.6, method='iir',
+ iir_params=dict(ftype='butter', order=8, output='foo'))
+ # missing ftype
+ assert_raises(RuntimeError, high_pass_filter, sig, sfreq, 0.6,
+ method='iir', iir_params=dict(order=8, output='sos'))
+ # bad ftype
+ assert_raises(RuntimeError, high_pass_filter, sig, sfreq, 0.6,
+ method='iir',
+ iir_params=dict(order=8, ftype='foo', output='sos'))
+ # missing gstop
+ assert_raises(RuntimeError, high_pass_filter, sig, sfreq, 0.6,
+ method='iir', iir_params=dict(gpass=0.5, output='sos'))
+ # can't pass iir_params if method='fft'
assert_raises(ValueError, high_pass_filter, sig, sfreq, 0.1,
- method='fir', iir_params=dict(ftype='butter', order=2))
+ method='fft', iir_params=dict(ftype='butter', order=2,
+ output='sos'))
# method must be string
assert_raises(TypeError, high_pass_filter, sig, sfreq, 0.1,
method=1)
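In the test_1d_filter hunk above, the phase='zero' expectation is built by convolving with an odd-length filter and then cropping half the filter length from each end to undo the group delay. The same bookkeeping in isolation (signal and filter sizes are illustrative):

import numpy as np

rng = np.random.RandomState(0)
x = rng.randn(100)
h = rng.randn(11)               # odd length -> integer group delay
shift = (len(h) - 1) // 2       # group delay of a linear-phase FIR
y = np.convolve(x, h)           # full convolution, len(x) + len(h) - 1
y = y[shift:len(y) - shift]     # crop the delay from both ends
assert len(y) == len(x)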
@@ -86,17 +174,30 @@ def test_iir_stability():
assert_raises(ValueError, high_pass_filter, sig, sfreq, 0.1,
method='blah')
# bad iir_params
+ assert_raises(TypeError, high_pass_filter, sig, sfreq, 0.1,
+ method='iir', iir_params='blah')
assert_raises(ValueError, high_pass_filter, sig, sfreq, 0.1,
- method='fir', iir_params='blah')
+ method='fft', iir_params=dict())
# should pass because default trans_bandwidth is not relevant
- high_pass_filter(sig, 250, 0.5, method='iir',
- iir_params=dict(ftype='butter', order=6))
+ iir_params = dict(ftype='butter', order=2, output='sos')
+ x_sos = high_pass_filter(sig, 250, 0.5, method='iir',
+ iir_params=iir_params)
+ iir_params_sos = construct_iir_filter(iir_params, f_pass=0.5, sfreq=250,
+ btype='highpass')
+ x_sos_2 = high_pass_filter(sig, 250, 0.5, method='iir',
+ iir_params=iir_params_sos)
+ assert_allclose(x_sos[100:-100], x_sos_2[100:-100])
+ x_ba = high_pass_filter(sig, 250, 0.5, method='iir',
+ iir_params=dict(ftype='butter', order=2,
+ output='ba'))
+ # Note that this will fail for higher orders (e.g., 6) showing the
+ # hopefully decreased numerical error of SOS
+ assert_allclose(x_sos[100:-100], x_ba[100:-100])
def test_notch_filters():
- """Test notch filters
- """
+ """Test notch filters"""
# let's use an ugly, prime sfreq for fun
sfreq = 487.0
sig_len_secs = 20
@@ -118,8 +219,9 @@ def test_notch_filters():
tols = [2, 1, 1, 1]
for meth, lf, fl, tol in zip(methods, line_freqs, filter_lengths, tols):
with catch_logging() as log_file:
- b = notch_filter(a, sfreq, lf, filter_length=fl, method=meth,
- verbose='INFO')
+ with warnings.catch_warnings(record=True): # filter_length=None
+ b = notch_filter(a, sfreq, lf, filter_length=fl, method=meth,
+ phase='zero', verbose=True)
if lf is None:
out = log_file.getvalue().split('\n')[:-1]
@@ -178,78 +280,67 @@ def test_resample_stim_channel():
assert_equal(new_data.shape[1], new_data_len)
+@requires_version('scipy', '0.16')
@slow_test
def test_filters():
- """Test low-, band-, high-pass, and band-stop filters plus resampling
- """
- sfreq = 500
- sig_len_secs = 30
+ """Test low-, band-, high-pass, and band-stop filters plus resampling"""
+ sfreq = 100
+ sig_len_secs = 15
a = rng.randn(2, sig_len_secs * sfreq)
# let's test our catchers
for fl in ['blah', [0, 1], 1000.5, '10ss', '10']:
- assert_raises(ValueError, band_pass_filter, a, sfreq, 4, 8,
- filter_length=fl)
+ assert_raises(ValueError, band_pass_filter, a, sfreq, 4, 8, fl,
+ 1.0, 1.0, phase='zero')
for nj in ['blah', 0.5]:
- assert_raises(ValueError, band_pass_filter, a, sfreq, 4, 8, n_jobs=nj)
+ assert_raises(ValueError, band_pass_filter, a, sfreq, 4, 8, 100,
+ 1.0, 1.0, n_jobs=nj, phase='zero', fir_window='hann')
+ assert_raises(ValueError, band_pass_filter, a, sfreq, 4, 8, 100,
+ 1.0, 1.0, phase='zero', fir_window='foo')
# > Nyq/2
- assert_raises(ValueError, band_pass_filter, a, sfreq, 4, sfreq / 2.)
- assert_raises(ValueError, low_pass_filter, a, sfreq, sfreq / 2.)
+ assert_raises(ValueError, band_pass_filter, a, sfreq, 4, sfreq / 2.,
+ 100, 1.0, 1.0, phase='zero', fir_window='hann')
+ assert_raises(ValueError, low_pass_filter, a, sfreq, sfreq / 2.,
+ 100, 1.0, phase='zero', fir_window='hann')
# check our short-filter warning:
with warnings.catch_warnings(record=True) as w:
# Warning for low attenuation
- band_pass_filter(a, sfreq, 1, 8, filter_length=1024)
+ band_pass_filter(a, sfreq, 1, 8, filter_length=256, phase='zero')
+ assert_true(any('attenuation' in str(ww.message) for ww in w))
+ with warnings.catch_warnings(record=True) as w:
# Warning for too short a filter
- band_pass_filter(a, sfreq, 1, 8, filter_length='0.5s')
- assert_true(len(w) >= 2)
+ band_pass_filter(a, sfreq, 1, 8, filter_length='0.5s', phase='zero')
+ assert_true(any('Increase filter_length' in str(ww.message) for ww in w))
# try new default and old default
- for fl in ['10s', '5000ms', None]:
- bp = band_pass_filter(a, sfreq, 4, 8, filter_length=fl)
- bs = band_stop_filter(a, sfreq, 4 - 0.5, 8 + 0.5, filter_length=fl)
- lp = low_pass_filter(a, sfreq, 8, filter_length=fl, n_jobs=2)
- hp = high_pass_filter(lp, sfreq, 4, filter_length=fl)
- assert_array_almost_equal(hp, bp, 2)
- assert_array_almost_equal(bp + bs, a, 1)
-
- # Overlap-add filtering with a fixed filter length
- filter_length = 8192
- bp_oa = band_pass_filter(a, sfreq, 4, 8, filter_length)
- bs_oa = band_stop_filter(a, sfreq, 4 - 0.5, 8 + 0.5, filter_length)
- lp_oa = low_pass_filter(a, sfreq, 8, filter_length)
- hp_oa = high_pass_filter(lp_oa, sfreq, 4, filter_length)
- assert_array_almost_equal(hp_oa, bp_oa, 2)
- # Our filters are no longer quite complementary with linear rolloffs :(
- # this is the tradeoff for stability of the filtering
- # obtained by directly using the result of firwin2 instead of
- # modifying it...
- assert_array_almost_equal(bp_oa + bs_oa, a, 1)
-
- # The two methods should give the same result
- # As filtering for short signals uses a circular convolution (FFT) and
- # the overlap-add filter implements a linear convolution, the signal
- # boundary will be slightly different and we ignore it
- n_edge_ignore = 0
- assert_array_almost_equal(hp[n_edge_ignore:-n_edge_ignore],
- hp_oa[n_edge_ignore:-n_edge_ignore], 2)
+ for fl in ['auto', '10s', '5000ms', 1024]:
+ bp = band_pass_filter(a, sfreq, 4, 8, fl, 1.0, 1.0, phase='zero',
+ fir_window='hamming')
+ bs = band_stop_filter(a, sfreq, 4 - 1.0, 8 + 1.0, fl, 1.0, 1.0,
+ phase='zero', fir_window='hamming')
+ lp = low_pass_filter(a, sfreq, 8, fl, 1.0, n_jobs=2, phase='zero',
+ fir_window='hamming')
+ hp = high_pass_filter(lp, sfreq, 4, fl, 1.0, phase='zero',
+ fir_window='hamming')
+ assert_array_almost_equal(hp, bp, 4)
+ assert_array_almost_equal(bp + bs, a, 4)
# and since these are low-passed, downsampling/upsampling should be close
n_resamp_ignore = 10
- bp_up_dn = resample(resample(bp_oa, 2, 1, n_jobs=2), 1, 2, n_jobs=2)
- assert_array_almost_equal(bp_oa[n_resamp_ignore:-n_resamp_ignore],
+ bp_up_dn = resample(resample(bp, 2, 1, n_jobs=2), 1, 2, n_jobs=2)
+ assert_array_almost_equal(bp[n_resamp_ignore:-n_resamp_ignore],
bp_up_dn[n_resamp_ignore:-n_resamp_ignore], 2)
# note that on systems without CUDA, this line serves as a test for a
# graceful fallback to n_jobs=1
- bp_up_dn = resample(resample(bp_oa, 2, 1, n_jobs='cuda'), 1, 2,
- n_jobs='cuda')
- assert_array_almost_equal(bp_oa[n_resamp_ignore:-n_resamp_ignore],
+ bp_up_dn = resample(resample(bp, 2, 1, n_jobs='cuda'), 1, 2, n_jobs='cuda')
+ assert_array_almost_equal(bp[n_resamp_ignore:-n_resamp_ignore],
bp_up_dn[n_resamp_ignore:-n_resamp_ignore], 2)
# test to make sure our resampling matches scipy's
- bp_up_dn = sp_resample(sp_resample(bp_oa, 2 * bp_oa.shape[-1], axis=-1,
+ bp_up_dn = sp_resample(sp_resample(bp, 2 * bp.shape[-1], axis=-1,
window='boxcar'),
- bp_oa.shape[-1], window='boxcar', axis=-1)
- assert_array_almost_equal(bp_oa[n_resamp_ignore:-n_resamp_ignore],
+ bp.shape[-1], window='boxcar', axis=-1)
+ assert_array_almost_equal(bp[n_resamp_ignore:-n_resamp_ignore],
bp_up_dn[n_resamp_ignore:-n_resamp_ignore], 2)
# make sure we don't alias
@@ -261,31 +352,43 @@ def test_filters():
assert_array_almost_equal(np.zeros_like(sig_gone), sig_gone, 2)
# let's construct some filters
- iir_params = dict(ftype='cheby1', gpass=1, gstop=20)
+ iir_params = dict(ftype='cheby1', gpass=1, gstop=20, output='ba')
iir_params = construct_iir_filter(iir_params, 40, 80, 1000, 'low')
# this should be a third order filter
- assert_true(iir_params['a'].size - 1 == 3)
- assert_true(iir_params['b'].size - 1 == 3)
- iir_params = dict(ftype='butter', order=4)
+ assert_equal(iir_params['a'].size - 1, 3)
+ assert_equal(iir_params['b'].size - 1, 3)
+ iir_params = dict(ftype='butter', order=4, output='ba')
+ iir_params = construct_iir_filter(iir_params, 40, None, 1000, 'low')
+ assert_equal(iir_params['a'].size - 1, 4)
+ assert_equal(iir_params['b'].size - 1, 4)
+ iir_params = dict(ftype='cheby1', gpass=1, gstop=20, output='sos')
+ iir_params = construct_iir_filter(iir_params, 40, 80, 1000, 'low')
+ # this should be a third order filter, which requires 2 SOS ((2, 6))
+ assert_equal(iir_params['sos'].shape, (2, 6))
+ iir_params = dict(ftype='butter', order=4, output='sos')
iir_params = construct_iir_filter(iir_params, 40, None, 1000, 'low')
- assert_true(iir_params['a'].size - 1 == 4)
- assert_true(iir_params['b'].size - 1 == 4)
+ assert_equal(iir_params['sos'].shape, (2, 6))
# check that picks work for 3d array with one channel and picks=[0]
a = rng.randn(5 * sfreq, 5 * sfreq)
b = a[:, None, :]
- with warnings.catch_warnings(record=True) as w:
- a_filt = band_pass_filter(a, sfreq, 4, 8)
- b_filt = band_pass_filter(b, sfreq, 4, 8, picks=[0])
+ a_filt = band_pass_filter(a, sfreq, 4, 8, 400, 2.0, 2.0, phase='zero',
+ fir_window='hamming')
+ b_filt = band_pass_filter(b, sfreq, 4, 8, 400, 2.0, 2.0, picks=[0],
+ phase='zero', fir_window='hamming')
assert_array_equal(a_filt[:, None, :], b_filt)
# check for n-dimensional case
a = rng.randn(2, 2, 2, 2)
- assert_raises(ValueError, band_pass_filter, a, sfreq, Fp1=4, Fp2=8,
- picks=np.array([0, 1]))
+ with warnings.catch_warnings(record=True): # filter too long
+ assert_raises(ValueError, band_pass_filter, a, sfreq, 4, 8, 100,
+ 1.0, 1.0, picks=np.array([0, 1]), phase='zero')
+
+def test_filter_auto():
+ """Test filter auto parameters"""
# test that our overlap-add filtering doesn't introduce strange
# artifacts (from mne_analyze mailing list 2015/06/25)
N = 300
@@ -293,16 +396,35 @@ def test_filters():
lp = 10.
sine_freq = 1.
x = np.ones(N)
- x += np.sin(2 * np.pi * sine_freq * np.arange(N) / sfreq)
- with warnings.catch_warnings(record=True): # filter attenuation
- x_filt = low_pass_filter(x, sfreq, lp, '1s')
+ t = np.arange(N) / sfreq
+ x += np.sin(2 * np.pi * sine_freq * t)
+ x_orig = x.copy()
+ x_filt = low_pass_filter(x, sfreq, lp, 'auto', 'auto', phase='zero',
+ fir_window='hamming')
+ assert_array_equal(x, x_orig)
# the firwin2 function gets us this close
- assert_allclose(x, x_filt, rtol=1e-3, atol=1e-3)
+ assert_allclose(x, x_filt, rtol=1e-4, atol=1e-4)
+ assert_array_equal(x_filt, low_pass_filter(
+ x, sfreq, lp, 'auto', 'auto', phase='zero', fir_window='hamming'))
+ assert_array_equal(x, x_orig)
+ assert_array_equal(x_filt, filter_data(
+ x, sfreq, None, lp, h_trans_bandwidth='auto', phase='zero',
+ fir_window='hamming', filter_length='auto'))
+ assert_array_equal(x, x_orig)
+ assert_array_equal(x_filt, filter_data(
+ x, sfreq, None, lp, h_trans_bandwidth='auto', phase='zero',
+ fir_window='hamming', filter_length='auto', copy=False))
+ assert_array_equal(x, x_filt)
+
+ # degenerate conditions
+ assert_raises(ValueError, filter_data, x, -sfreq, 1, 10)
+ assert_raises(ValueError, filter_data, x, sfreq, 1, sfreq * 0.75)
+ assert_raises(TypeError, filter_data, x.astype(np.float32), sfreq, None,
+ 10, filter_length='auto', h_trans_bandwidth='auto')
def test_cuda():
- """Test CUDA-based filtering
- """
+ """Test CUDA-based filtering"""
# NOTE: don't make test_cuda() the last test, or pycuda might spew
# some warnings about clean-up failing
# Also, using `n_jobs='cuda'` on a non-CUDA system should be fine,
@@ -312,21 +434,28 @@ def test_cuda():
a = rng.randn(sig_len_secs * sfreq)
with catch_logging() as log_file:
- for fl in ['10s', None, 2048]:
- bp = band_pass_filter(a, sfreq, 4, 8, n_jobs=1, filter_length=fl)
- bs = band_stop_filter(a, sfreq, 4 - 0.5, 8 + 0.5, n_jobs=1,
- filter_length=fl)
- lp = low_pass_filter(a, sfreq, 8, n_jobs=1, filter_length=fl)
- hp = high_pass_filter(lp, sfreq, 4, n_jobs=1, filter_length=fl)
-
- bp_c = band_pass_filter(a, sfreq, 4, 8, n_jobs='cuda',
- filter_length=fl, verbose='INFO')
- bs_c = band_stop_filter(a, sfreq, 4 - 0.5, 8 + 0.5, n_jobs='cuda',
- filter_length=fl, verbose='INFO')
- lp_c = low_pass_filter(a, sfreq, 8, n_jobs='cuda',
- filter_length=fl, verbose='INFO')
- hp_c = high_pass_filter(lp, sfreq, 4, n_jobs='cuda',
- filter_length=fl, verbose='INFO')
+ for fl in ['auto', '10s', 2048]:
+ bp = band_pass_filter(a, sfreq, 4, 8, fl, 1.0, 1.0, n_jobs=1,
+ phase='zero', fir_window='hann')
+ bs = band_stop_filter(a, sfreq, 4 - 1.0, 8 + 1.0, fl, 1.0, 1.0,
+ n_jobs=1, phase='zero', fir_window='hann')
+ lp = low_pass_filter(a, sfreq, 8, fl, 1.0, n_jobs=1, phase='zero',
+ fir_window='hann')
+ hp = high_pass_filter(lp, sfreq, 4, fl, 1.0, n_jobs=1,
+ phase='zero', fir_window='hann')
+
+ bp_c = band_pass_filter(a, sfreq, 4, 8, fl, 1.0, 1.0,
+ n_jobs='cuda', verbose='INFO',
+ phase='zero', fir_window='hann')
+ bs_c = band_stop_filter(a, sfreq, 4 - 1.0, 8 + 1.0, fl, 1.0, 1.0,
+ n_jobs='cuda', verbose='INFO',
+ phase='zero', fir_window='hann')
+ lp_c = low_pass_filter(a, sfreq, 8, fl, 1.0,
+ n_jobs='cuda', verbose='INFO',
+ phase='zero', fir_window='hann')
+ hp_c = high_pass_filter(lp, sfreq, 4, fl, 1.0,
+ n_jobs='cuda', verbose='INFO',
+ phase='zero', fir_window='hann')
assert_array_almost_equal(bp, bp_c, 12)
assert_array_almost_equal(bs, bs_c, 12)
@@ -358,8 +487,7 @@ def test_cuda():
def test_detrend():
- """Test zeroth and first order detrending
- """
+ """Test zeroth and first order detrending"""
x = np.arange(10)
assert_array_almost_equal(detrend(x, 1), np.zeros_like(x))
x = np.ones(10)
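estimate_ringing_samples, bracketed in test_estimate_ringing above, measures how long an IIR filter's impulse response stays non-negligible. A direct way to compute that quantity with SciPy alone (the cutoff and threshold here are illustrative, not the values the test uses):

import numpy as np
from scipy.signal import butter, lfilter

b, a = butter(3, 0.01)              # low cutoff -> long ringing
impulse = np.zeros(100000)
impulse[0] = 1.
h = lfilter(b, a, impulse)          # impulse response of the filter
# Index of the last sample above 0.1% of the peak magnitude:
n_ring = np.max(np.where(np.abs(h) > 1e-3 * np.abs(h).max())[0]) + 1
print('ringing length: %d samples' % n_ring)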
diff --git a/mne/tests/test_fixes.py b/mne/tests/test_fixes.py
index bf647e5..65e971b 100644
--- a/mne/tests/test_fixes.py
+++ b/mne/tests/test_fixes.py
@@ -4,193 +4,21 @@
# License: BSD
import numpy as np
+from scipy.signal import filtfilt
-from nose.tools import assert_equal, assert_raises, assert_true
from numpy.testing import assert_array_equal
-from distutils.version import LooseVersion
-from scipy import signal, sparse
from mne.utils import run_tests_if_main
-from mne.fixes import (_in1d, _tril_indices, _copysign, _unravel_index,
- _Counter, _unique, _bincount, _digitize,
- _sparse_block_diag, _matrix_rank, _meshgrid,
- _isclose)
-from mne.fixes import _firwin2 as mne_firwin2
-from mne.fixes import _filtfilt as mne_filtfilt
-
-rng = np.random.RandomState(0)
-
-
-def test_counter():
- """Test Counter replacement"""
- import collections
- try:
- Counter = collections.Counter
- except Exception:
- pass
- else:
- a = Counter([1, 2, 1, 3])
- b = _Counter([1, 2, 1, 3])
- c = _Counter()
- c.update(b)
- for key, count in zip([1, 2, 3], [2, 1, 1]):
- assert_equal(a[key], b[key])
- assert_equal(a[key], c[key])
-
-
-def test_unique():
- """Test unique() replacement
- """
- # skip test for np version < 1.5
- if LooseVersion(np.__version__) < LooseVersion('1.5'):
- return
- for arr in [np.array([]), rng.rand(10), np.ones(10)]:
- # basic
- assert_array_equal(np.unique(arr), _unique(arr))
- # with return_index=True
- x1, x2 = np.unique(arr, return_index=True, return_inverse=False)
- y1, y2 = _unique(arr, return_index=True, return_inverse=False)
- assert_array_equal(x1, y1)
- assert_array_equal(x2, y2)
- # with return_inverse=True
- x1, x2 = np.unique(arr, return_index=False, return_inverse=True)
- y1, y2 = _unique(arr, return_index=False, return_inverse=True)
- assert_array_equal(x1, y1)
- assert_array_equal(x2, y2)
- # with both:
- x1, x2, x3 = np.unique(arr, return_index=True, return_inverse=True)
- y1, y2, y3 = _unique(arr, return_index=True, return_inverse=True)
- assert_array_equal(x1, y1)
- assert_array_equal(x2, y2)
- assert_array_equal(x3, y3)
-
-
-def test_bincount():
- """Test bincount() replacement
- """
- # skip test for np version < 1.6
- if LooseVersion(np.__version__) < LooseVersion('1.6'):
- return
- for minlength in [None, 100]:
- x = _bincount(np.ones(10, int), None, minlength)
- y = np.bincount(np.ones(10, int), None, minlength)
- assert_array_equal(x, y)
-
-
-def test_in1d():
- """Test numpy.in1d() replacement"""
- a = np.arange(10)
- b = a[a % 2 == 0]
- assert_equal(_in1d(a, b).sum(), 5)
-
-
-def test_digitize():
- """Test numpy.digitize() replacement"""
- data = np.arange(9)
- bins = [0, 5, 10]
- left = np.array([1, 1, 1, 1, 1, 2, 2, 2, 2])
- right = np.array([0, 1, 1, 1, 1, 1, 2, 2, 2])
-
- assert_array_equal(_digitize(data, bins), left)
- assert_array_equal(_digitize(data, bins, True), right)
- assert_raises(NotImplementedError, _digitize, data + 0.1, bins, True)
- assert_raises(NotImplementedError, _digitize, data, [0., 5, 10], True)
-
-
-def test_tril_indices():
- """Test numpy.tril_indices() replacement"""
- il1 = _tril_indices(4)
- il2 = _tril_indices(4, -1)
-
- a = np.array([[1, 2, 3, 4],
- [5, 6, 7, 8],
- [9, 10, 11, 12],
- [13, 14, 15, 16]])
-
- assert_array_equal(a[il1],
- np.array([1, 5, 6, 9, 10, 11, 13, 14, 15, 16]))
-
- assert_array_equal(a[il2], np.array([5, 9, 10, 13, 14, 15]))
-
-
-def test_unravel_index():
- """Test numpy.unravel_index() replacement"""
- assert_equal(_unravel_index(2, (2, 3)), (0, 2))
- assert_equal(_unravel_index(2, (2, 2)), (1, 0))
- assert_equal(_unravel_index(254, (17, 94)), (2, 66))
- assert_equal(_unravel_index((2 * 3 + 1) * 6 + 4, (4, 3, 6)), (2, 1, 4))
- assert_array_equal(_unravel_index(np.array([22, 41, 37]), (7, 6)),
- [[3, 6, 6], [4, 5, 1]])
- assert_array_equal(_unravel_index(1621, (6, 7, 8, 9)), (3, 1, 4, 1))
-
-
-def test_copysign():
- """Test numpy.copysign() replacement"""
- a = np.array([-1, 1, -1])
- b = np.array([1, -1, 1])
-
- assert_array_equal(_copysign(a, b), b)
- assert_array_equal(_copysign(b, a), a)
-
-
-def test_firwin2():
- """Test firwin2 backport
- """
- taps1 = mne_firwin2(150, [0.0, 0.5, 1.0], [1.0, 1.0, 0.0])
- taps2 = signal.firwin2(150, [0.0, 0.5, 1.0], [1.0, 1.0, 0.0])
- assert_array_equal(taps1, taps2)
+from mne.fixes import _sosfiltfilt as mne_sosfiltfilt
def test_filtfilt():
- """Test IIR filtfilt replacement
- """
+ """Test SOS filtfilt replacement"""
x = np.r_[1, np.zeros(100)]
# Filter with an impulse
- y = mne_filtfilt([1, 0], [1, 0], x, padlen=0)
+ y = filtfilt([1, 0], [1, 0], x, padlen=0)
+ assert_array_equal(x, y)
+ y = mne_sosfiltfilt(np.array([[1., 0., 0., 1, 0., 0.]]), x, padlen=0)
assert_array_equal(x, y)
-
-
-def test_sparse_block_diag():
- """Test sparse block diag replacement"""
- x = _sparse_block_diag([sparse.eye(2, 2), sparse.eye(2, 2)])
- x = x - sparse.eye(4, 4)
- x.eliminate_zeros()
- assert_equal(len(x.data), 0)
-
-
-def test_rank():
- """Test rank replacement"""
- assert_equal(_matrix_rank(np.ones(10)), 1)
- assert_equal(_matrix_rank(np.eye(10)), 10)
- assert_equal(_matrix_rank(np.ones((10, 10))), 1)
- assert_raises(TypeError, _matrix_rank, np.ones((10, 10, 10)))
-
-
-def test_meshgrid():
- """Test meshgrid replacement
- """
- a = np.arange(10)
- b = np.linspace(0, 1, 5)
- a_grid, b_grid = _meshgrid(a, b, indexing='ij')
- for grid in (a_grid, b_grid):
- assert_equal(grid.shape, (a.size, b.size))
- a_grid, b_grid = _meshgrid(a, b, indexing='xy', copy=True)
- for grid in (a_grid, b_grid):
- assert_equal(grid.shape, (b.size, a.size))
- assert_raises(TypeError, _meshgrid, a, b, foo='a')
- assert_raises(ValueError, _meshgrid, a, b, indexing='foo')
-
-
-def test_isclose():
- """Test isclose replacement
- """
- a = rng.randn(10)
- b = a.copy()
- assert_true(_isclose(a, b).all())
- a[0] = np.inf
- b[0] = np.inf
- a[-1] = np.nan
- b[-1] = np.nan
- assert_true(_isclose(a, b, equal_nan=True).all())
run_tests_if_main()
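test_filtfilt above drives an identity second-order section through MNE's _sosfiltfilt backport. On SciPy releases that ship sosfiltfilt natively, the same identity check can be written directly (a sketch, not part of the commit):

import numpy as np
from scipy.signal import sosfiltfilt

x = np.r_[1., np.zeros(100)]
# One identity section: numerator and denominator both [1, 0, 0].
sos = np.array([[1., 0., 0., 1., 0., 0.]])
y = sosfiltfilt(sos, x, padlen=0)
np.testing.assert_array_equal(x, y)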
diff --git a/mne/tests/test_import_nesting.py b/mne/tests/test_import_nesting.py
index 36d0a20..cc4c264 100644
--- a/mne/tests/test_import_nesting.py
+++ b/mne/tests/test_import_nesting.py
@@ -24,7 +24,7 @@ if len(bad) > 0:
out.append('Found un-nested scipy submodules: %s' % list(bad))
# check sklearn and others
-_sklearn = _pandas = _nose = False
+_sklearn = _pandas = _nose = _mayavi = False
for x in sys.modules.keys():
if x.startswith('sklearn') and not _sklearn:
out.append('Found un-nested sklearn import')
@@ -35,6 +35,9 @@ for x in sys.modules.keys():
if x.startswith('nose') and not _nose:
out.append('Found un-nested nose import')
_nose = True
+ if x.startswith('mayavi') and not _mayavi:
+ out.append('Found un-nested mayavi import')
+ _mayavi = True
if len(out) > 0:
print('\\n' + '\\n'.join(out), end='')
exit(1)
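The hunk above adds mayavi to the set of packages that must not be imported at mne import time. The underlying check is just a scan of sys.modules in a fresh interpreter; a standalone sketch (assumes mne is importable):

import subprocess
import sys

code = (
    "import sys, mne; "
    "heavy = ('sklearn', 'pandas', 'nose', 'mayavi'); "
    "bad = sorted(set(m.split('.')[0] for m in sys.modules) & set(heavy)); "
    "sys.exit('un-nested imports: %s' % bad if bad else 0)"
)
# A fresh interpreter keeps the parent's own imports out of the result.
subprocess.check_call([sys.executable, '-c', code])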
diff --git a/mne/tests/test_label.py b/mne/tests/test_label.py
index 26907fe..c41acff 100644
--- a/mne/tests/test_label.py
+++ b/mne/tests/test_label.py
@@ -3,7 +3,6 @@ import os.path as op
import shutil
import glob
import warnings
-import sys
import numpy as np
from scipy import sparse
@@ -19,7 +18,7 @@ from mne import (read_label, stc_to_label, read_source_estimate,
from mne.label import Label, _blend_colors, label_sign_flip
from mne.utils import (_TempDir, requires_sklearn, get_subjects_dir,
run_tests_if_main, slow_test)
-from mne.fixes import digitize, in1d, assert_is, assert_is_not
+from mne.fixes import assert_is, assert_is_not
from mne.label import _n_colors
from mne.source_space import SourceSpaces
from mne.source_estimate import mesh_edges
@@ -264,21 +263,21 @@ def test_label_in_src():
# construct label from source space vertices
vert_in_src = np.intersect1d(label.vertices, src[0]['vertno'], True)
- where = in1d(label.vertices, vert_in_src)
+ where = np.in1d(label.vertices, vert_in_src)
pos_in_src = label.pos[where]
values_in_src = label.values[where]
label_src = Label(vert_in_src, pos_in_src, values_in_src,
hemi='lh').fill(src)
# check label vertices
- vertices_status = in1d(src[0]['nearest'], label.vertices)
+ vertices_status = np.in1d(src[0]['nearest'], label.vertices)
vertices_in = np.nonzero(vertices_status)[0]
vertices_out = np.nonzero(np.logical_not(vertices_status))[0]
assert_array_equal(label_src.vertices, vertices_in)
- assert_array_equal(in1d(vertices_out, label_src.vertices), False)
+ assert_array_equal(np.in1d(vertices_out, label_src.vertices), False)
# check values
- value_idx = digitize(src[0]['nearest'][vertices_in], vert_in_src, True)
+ value_idx = np.digitize(src[0]['nearest'][vertices_in], vert_in_src, True)
assert_array_equal(label_src.values, values_in_src[value_idx])
# test exception
@@ -397,9 +396,7 @@ def test_read_labels_from_annot():
for label in labels_lh:
assert_true(label.name.endswith('-lh'))
assert_true(label.hemi == 'lh')
- # XXX fails on 2.6 for some reason...
- if sys.version_info[:2] > (2, 6):
- assert_is_not(label.color, None)
+ assert_is_not(label.color, None)
# read labels using annot_fname
annot_fname = op.join(subjects_dir, 'sample', 'label', 'rh.aparc.annot')
@@ -571,7 +568,7 @@ def test_write_labels_to_annot():
label0 = labels_lh[0]
label1 = labels_reloaded[-1]
assert_equal(label1.name, "unknown-lh")
- assert_true(np.all(in1d(label0.vertices, label1.vertices)))
+ assert_true(np.all(np.in1d(label0.vertices, label1.vertices)))
# unnamed labels
labels4 = labels[:]
@@ -713,7 +710,7 @@ def test_morph():
label.values.fill(1)
label = label.morph(None, 'fsaverage', 5, grade, subjects_dir, 1)
label = label.morph('fsaverage', 'sample', 5, None, subjects_dir, 2)
- assert_true(np.mean(in1d(label_orig.vertices, label.vertices)) == 1.0)
+ assert_true(np.in1d(label_orig.vertices, label.vertices).all())
assert_true(len(label.vertices) < 3 * len(label_orig.vertices))
vals.append(label.vertices)
assert_array_equal(vals[0], vals[1])
@@ -746,7 +743,7 @@ def test_grow_labels():
for label, seed, hemi, sh, name in zip(labels, seeds, tgt_hemis,
should_be_in, tgt_names):
assert_true(np.any(label.vertices == seed))
- assert_true(np.all(in1d(sh, label.vertices)))
+ assert_true(np.all(np.in1d(sh, label.vertices)))
assert_equal(label.hemi, hemi)
assert_equal(label.name, name)
@@ -775,6 +772,7 @@ def test_grow_labels():
@testing.requires_testing_data
def test_label_sign_flip():
+ """Test label sign flip computation"""
src = read_source_spaces(src_fname)
label = Label(vertices=src[0]['vertno'][:5], hemi='lh')
src[0]['nn'][label.vertices] = np.array(
@@ -791,4 +789,55 @@ def test_label_sign_flip():
len(idx))
+@testing.requires_testing_data
+def test_label_center_of_mass():
+ """Test computing the center of mass of a label"""
+ stc = read_source_estimate(stc_fname)
+ stc.lh_data[:] = 0
+ vertex_stc = stc.center_of_mass('sample', subjects_dir=subjects_dir)[0]
+ assert_equal(vertex_stc, 124791)
+ label = Label(stc.vertices[1], pos=None, values=stc.rh_data.mean(axis=1),
+ hemi='rh', subject='sample')
+ vertex_label = label.center_of_mass(subjects_dir=subjects_dir)
+ assert_equal(vertex_label, vertex_stc)
+
+ labels = read_labels_from_annot('sample', parc='aparc.a2009s',
+ subjects_dir=subjects_dir)
+ src = read_source_spaces(src_fname)
+ # Try a couple of random ones, one from left and one from right
+ # Visually verified in about the right place using mne_analyze
+ for label, expected in zip([labels[2], labels[3], labels[-5]],
+ [141162, 145221, 55979]):
+ label.values[:] = -1
+ assert_raises(ValueError, label.center_of_mass,
+ subjects_dir=subjects_dir)
+ label.values[:] = 1
+ assert_equal(label.center_of_mass(subjects_dir=subjects_dir), expected)
+ assert_equal(label.center_of_mass(subjects_dir=subjects_dir,
+ restrict_vertices=label.vertices),
+ expected)
+ # restrict to source space
+ idx = 0 if label.hemi == 'lh' else 1
+ # this simple nearest version is not equivalent, but is probably
+ # close enough for many labels (including the test ones):
+ pos = label.pos[np.where(label.vertices == expected)[0][0]]
+ pos = (src[idx]['rr'][src[idx]['vertno']] - pos)
+ pos = np.argmin(np.sum(pos * pos, axis=1))
+ src_expected = src[idx]['vertno'][pos]
+ # see if we actually get the same one
+ src_restrict = np.intersect1d(label.vertices, src[idx]['vertno'])
+ assert_equal(label.center_of_mass(subjects_dir=subjects_dir,
+ restrict_vertices=src_restrict),
+ src_expected)
+ assert_equal(label.center_of_mass(subjects_dir=subjects_dir,
+ restrict_vertices=src),
+ src_expected)
+ # degenerate cases
+ assert_raises(ValueError, label.center_of_mass, subjects_dir=subjects_dir,
+ restrict_vertices='foo')
+ assert_raises(TypeError, label.center_of_mass, subjects_dir=subjects_dir,
+ surf=1)
+ assert_raises(IOError, label.center_of_mass, subjects_dir=subjects_dir,
+ surf='foo')
+
run_tests_if_main()
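test_label_center_of_mass above pins the implementation to known sample-data vertices. Conceptually the computation is a values-weighted centroid of vertex positions snapped to the nearest permitted vertex, which the restrict-to-source-space block in the hunk also does by hand. The bare idea (synthetic positions, not the sample subject):

import numpy as np

rng = np.random.RandomState(0)
pos = rng.randn(50, 3)            # synthetic vertex positions
values = rng.rand(50)             # per-vertex weights (must be positive)
com = np.average(pos, axis=0, weights=values)   # weighted centroid
# Snap back to the nearest actual vertex, like restrict_vertices above.
nearest = np.argmin(np.sum((pos - com) ** 2, axis=1))
print('center-of-mass vertex index: %d' % nearest)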
diff --git a/mne/tests/test_line_endings.py b/mne/tests/test_line_endings.py
index a327a01..41d604b 100644
--- a/mne/tests/test_line_endings.py
+++ b/mne/tests/test_line_endings.py
@@ -17,6 +17,11 @@ skip_files = (
'FreeSurferColorLUT.txt',
'test_edf_stim_channel.txt',
'FieldTrip.py',
+ 'license.txt',
+ # part of testing compatibility with older BV formats is testing
+ # the line endings and coding schemes used there
+ 'test_old_layout_latin1_software_filter.vhdr',
+ 'test_old_layout_latin1_software_filter.vmrk'
)
@@ -50,8 +55,7 @@ def _assert_line_endings(dir_):
def test_line_endings():
- """Test line endings of mne-python
- """
+ """Test line endings of mne-python"""
tempdir = _TempDir()
with open(op.join(tempdir, 'foo'), 'wb') as fid:
fid.write('bad\r\ngood\n'.encode('ascii'))
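The line-ending test above writes a file mixing '\r\n' and '\n' and expects the checker to flag it. Reading in binary mode makes such a check trivial (a sketch with a temporary file; this is not MNE's actual _assert_line_endings helper):

import tempfile

with tempfile.NamedTemporaryFile('wb', suffix='.txt', delete=False) as fid:
    fid.write(b'bad\r\ngood\n')
with open(fid.name, 'rb') as check:
    content = check.read()
# Any carriage return means a CRLF (or bare CR) ending slipped in.
assert b'\r' in content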
diff --git a/mne/tests/test_proj.py b/mne/tests/test_proj.py
index 959e0cc..a91cf25 100644
--- a/mne/tests/test_proj.py
+++ b/mne/tests/test_proj.py
@@ -11,7 +11,7 @@ import copy as cp
import mne
from mne.datasets import testing
from mne import pick_types
-from mne.io import Raw
+from mne.io import read_raw_fif
from mne import compute_proj_epochs, compute_proj_evoked, compute_proj_raw
from mne.io.proj import (make_projector, activate_proj,
_needs_eeg_average_ref_proj)
@@ -19,8 +19,7 @@ from mne.proj import (read_proj, write_proj, make_eeg_average_ref_proj,
_has_eeg_average_ref_proj)
from mne import read_events, Epochs, sensitivity_map, read_source_estimate
from mne.tests.common import assert_naming
-from mne.utils import (_TempDir, run_tests_if_main, clean_warning_registry,
- slow_test)
+from mne.utils import _TempDir, run_tests_if_main, slow_test
warnings.simplefilter('always') # enable b/c these tests throw warnings
@@ -41,9 +40,8 @@ ecg_fname = op.join(sample_path, 'sample_audvis_ecg-proj.fif')
def test_bad_proj():
- """Test dealing with bad projection application
- """
- raw = Raw(raw_fname, preload=True)
+ """Test dealing with bad projection application."""
+ raw = read_raw_fif(raw_fname, preload=True, add_eeg_ref=False)
events = read_events(event_fname)
picks = pick_types(raw.info, meg=True, stim=False, ecg=False,
eog=False, exclude='bads')
@@ -58,11 +56,12 @@ def test_bad_proj():
def _check_warnings(raw, events, picks, count=3):
- """Helper to count warnings"""
+ """Helper to count warnings."""
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
Epochs(raw, events, dict(aud_l=1, vis_l=3),
- -0.2, 0.5, picks=picks, preload=True, proj=True)
+ -0.2, 0.5, picks=picks, preload=True, proj=True,
+ add_eeg_ref=False)
assert_equal(len(w), count)
for ww in w:
assert_true('dangerous' in str(ww.message))
@@ -70,7 +69,7 @@ def _check_warnings(raw, events, picks, count=3):
@testing.requires_testing_data
def test_sensitivity_maps():
- """Test sensitivity map computation"""
+ """Test sensitivity map computation."""
fwd = mne.read_forward_solution(fwd_fname, surf_ori=True)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
@@ -119,17 +118,17 @@ def test_sensitivity_maps():
def test_compute_proj_epochs():
- """Test SSP computation on epochs"""
+ """Test SSP computation on epochs."""
tempdir = _TempDir()
event_id, tmin, tmax = 1, -0.2, 0.3
- raw = Raw(raw_fname, preload=True)
+ raw = read_raw_fif(raw_fname, preload=True, add_eeg_ref=False)
events = read_events(event_fname)
bad_ch = 'MEG 2443'
picks = pick_types(raw.info, meg=True, eeg=False, stim=False, eog=False,
exclude=[])
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- baseline=None, proj=False)
+ baseline=None, proj=False, add_eeg_ref=False)
evoked = epochs.average()
projs = compute_proj_epochs(epochs, n_grad=1, n_mag=1, n_eeg=0, n_jobs=1)
@@ -186,7 +185,6 @@ def test_compute_proj_epochs():
assert_allclose(proj, proj_par, rtol=1e-8, atol=1e-16)
# test warnings on bad filenames
- clean_warning_registry()
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
proj_badname = op.join(tempdir, 'test-bad-name.fif.gz')
@@ -201,7 +199,7 @@ def test_compute_proj_raw():
tempdir = _TempDir()
# Test that the raw projectors work
raw_time = 2.5 # Do shorter amount for speed
- raw = Raw(raw_fname).crop(0, raw_time)
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False).crop(0, raw_time)
raw.load_data()
for ii in (0.25, 0.5, 1, 2):
with warnings.catch_warnings(record=True) as w:
@@ -264,8 +262,8 @@ def test_compute_proj_raw():
def test_make_eeg_average_ref_proj():
- """Test EEG average reference projection"""
- raw = Raw(raw_fname, add_eeg_ref=False, preload=True)
+ """Test EEG average reference projection."""
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False, preload=True)
eeg = mne.pick_types(raw.info, meg=False, eeg=True)
# No average EEG reference
@@ -287,26 +285,27 @@ def test_has_eeg_average_ref_proj():
"""Test checking whether an EEG average reference exists"""
assert_true(not _has_eeg_average_ref_proj([]))
- raw = Raw(raw_fname, add_eeg_ref=True, preload=False)
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False, preload=False)
+ raw.set_eeg_reference()
assert_true(_has_eeg_average_ref_proj(raw.info['projs']))
def test_needs_eeg_average_ref_proj():
"""Test checking whether a recording needs an EEG average reference"""
- raw = Raw(raw_fname, add_eeg_ref=False, preload=False)
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False, preload=False)
assert_true(_needs_eeg_average_ref_proj(raw.info))
- raw = Raw(raw_fname, add_eeg_ref=True, preload=False)
+ raw.set_eeg_reference()
assert_true(not _needs_eeg_average_ref_proj(raw.info))
# No EEG channels
- raw = Raw(raw_fname, add_eeg_ref=False, preload=True)
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False, preload=True)
eeg = [raw.ch_names[c] for c in pick_types(raw.info, meg=False, eeg=True)]
raw.drop_channels(eeg)
assert_true(not _needs_eeg_average_ref_proj(raw.info))
# Custom ref flag set
- raw = Raw(raw_fname, add_eeg_ref=False, preload=False)
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False, preload=False)
raw.info['custom_ref_applied'] = True
assert_true(not _needs_eeg_average_ref_proj(raw.info))
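The test_proj hunks above swap the implicit add_eeg_ref=True for an explicit raw.set_eeg_reference() call. The average reference itself is a projection that removes the across-channel mean; in NumPy (the channel count is illustrative):

import numpy as np

n_ch = 4
# Projector onto the space orthogonal to the constant vector.
P = np.eye(n_ch) - np.ones((n_ch, n_ch)) / n_ch
data = np.random.RandomState(0).randn(n_ch, 10)
ref_data = np.dot(P, data)
# After average referencing, each time point sums to ~0 across channels.
assert np.allclose(ref_data.sum(axis=0), 0.)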
diff --git a/mne/tests/test_report.py b/mne/tests/test_report.py
index f10167b..d13be10 100644
--- a/mne/tests/test_report.py
+++ b/mne/tests/test_report.py
@@ -2,18 +2,19 @@
# Teon Brooks <teon.brooks at gmail.com>
#
# License: BSD (3-clause)
-import sys
+
+import glob
import os
import os.path as op
-import glob
-import warnings
import shutil
+import sys
+import warnings
from nose.tools import assert_true, assert_equal, assert_raises
from nose.plugins.skip import SkipTest
from mne import Epochs, read_events, pick_types, read_evokeds
-from mne.io import Raw
+from mne.io import read_raw_fif
from mne.datasets import testing
from mne.report import Report
from mne.utils import (_TempDir, requires_mayavi, requires_nibabel,
@@ -48,8 +49,7 @@ warnings.simplefilter('always') # enable b/c these tests throw warnings
@testing.requires_testing_data
@requires_PIL
def test_render_report():
- """Test rendering -*.fif files for mne report.
- """
+ """Test rendering -*.fif files for mne report."""
tempdir = _TempDir()
raw_fname_new = op.join(tempdir, 'temp_raw.fif')
event_fname_new = op.join(tempdir, 'temp_raw-eve.fif')
@@ -66,9 +66,10 @@ def test_render_report():
# create and add -epo.fif and -ave.fif files
epochs_fname = op.join(tempdir, 'temp-epo.fif')
evoked_fname = op.join(tempdir, 'temp-ave.fif')
- raw = Raw(raw_fname_new)
+ raw = read_raw_fif(raw_fname_new, add_eeg_ref=False)
picks = pick_types(raw.info, meg='mag', eeg=False) # faster with one type
- epochs = Epochs(raw, read_events(event_fname), 1, -0.2, 0.2, picks=picks)
+ epochs = Epochs(raw, read_events(event_fname), 1, -0.2, 0.2, picks=picks,
+ add_eeg_ref=False)
epochs.save(epochs_fname)
epochs.average().save(evoked_fname)
@@ -79,6 +80,7 @@ def test_render_report():
warnings.simplefilter('always')
report.parse_folder(data_path=tempdir, on_error='raise')
assert_true(len(w) >= 1)
+ assert_true(repr(report))
# Check correct paths and filenames
fnames = glob.glob(op.join(tempdir, '*.fif'))
@@ -89,6 +91,7 @@ def test_render_report():
assert_equal(len(report.fnames), len(fnames))
assert_equal(len(report.html), len(report.fnames))
+ assert_equal(len(report.fnames), len(report))
# Check saving functionality
report.data_path = tempdir
@@ -113,6 +116,7 @@ def test_render_report():
warnings.simplefilter('always')
report.parse_folder(data_path=tempdir, pattern=pattern)
assert_true(len(w) >= 1)
+ assert_true(repr(report))
fnames = glob.glob(op.join(tempdir, '*.raw')) + \
glob.glob(op.join(tempdir, '*.raw'))
@@ -126,8 +130,7 @@ def test_render_report():
@requires_mayavi
@requires_PIL
def test_render_add_sections():
- """Test adding figures/images to section.
- """
+ """Test adding figures/images to section."""
from PIL import Image
tempdir = _TempDir()
import matplotlib.pyplot as plt
@@ -171,6 +174,7 @@ def test_render_add_sections():
report.add_figs_to_section(figs=fig, # test non-list input
captions='random image', scale=1.2)
+ assert_true(repr(report))
@slow_test
@@ -178,8 +182,7 @@ def test_render_add_sections():
@requires_mayavi
@requires_nibabel()
def test_render_mri():
- """Test rendering MRI for mne report.
- """
+ """Test rendering MRI for mne report."""
tempdir = _TempDir()
trans_fname_new = op.join(tempdir, 'temp-trans.fif')
for a, b in [[trans_fname, trans_fname_new]]:
@@ -191,13 +194,13 @@ def test_render_mri():
report.parse_folder(data_path=tempdir, mri_decim=30, pattern='*',
n_jobs=2)
report.save(op.join(tempdir, 'report.html'), open_browser=False)
+ assert_true(repr(report))
@testing.requires_testing_data
@requires_nibabel()
def test_render_mri_without_bem():
- """Test rendering MRI without BEM for mne report.
- """
+ """Test rendering MRI without BEM for mne report."""
tempdir = _TempDir()
os.mkdir(op.join(tempdir, 'sample'))
os.mkdir(op.join(tempdir, 'sample', 'mri'))
@@ -214,8 +217,7 @@ def test_render_mri_without_bem():
@testing.requires_testing_data
@requires_nibabel()
def test_add_htmls_to_section():
- """Test adding html str to mne report.
- """
+ """Test adding html str to mne report."""
report = Report(info_fname=raw_fname,
subject='sample', subjects_dir=subjects_dir)
html = '<b>MNE-Python is AWESOME</b>'
@@ -224,11 +226,11 @@ def test_add_htmls_to_section():
idx = report._sectionlabels.index('report_' + section)
html_compare = report.html[idx]
assert_true(html in html_compare)
+ assert_true(repr(report))
def test_add_slider_to_section():
- """Test adding a slider with a series of images to mne report.
- """
+ """Test adding a slider with a series of images to mne report."""
tempdir = _TempDir()
from matplotlib import pyplot as plt
report = Report(info_fname=raw_fname,
@@ -251,6 +253,7 @@ def test_add_slider_to_section():
def test_validate_input():
+ """Test Report input validation."""
report = Report()
items = ['a', 'b', 'c']
captions = ['Letter A', 'Letter B', 'Letter C']
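The report hunks above add smoke tests for repr(report) and len(report), i.e. for the __repr__ and __len__ protocols the Report class now implements. A toy container showing the pattern under test (not MNE's actual Report):

class TinyReport(object):
    """Toy stand-in for an HTML report container."""

    def __init__(self):
        self.fnames = []

    def __len__(self):
        return len(self.fnames)

    def __repr__(self):
        return '<TinyReport | %d item(s)>' % len(self)

report = TinyReport()
report.fnames.append('temp_raw.fif')
assert len(report) == 1 and repr(report)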
diff --git a/mne/tests/test_selection.py b/mne/tests/test_selection.py
index ba6a8aa..f9b75e3 100644
--- a/mne/tests/test_selection.py
+++ b/mne/tests/test_selection.py
@@ -20,7 +20,7 @@ def test_read_selection():
'Right-parietal', 'Left-occipital', 'Right-occipital',
'Left-frontal', 'Right-frontal']
- raw = read_raw_fif(raw_fname)
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
for i, name in enumerate(sel_names):
sel = read_selection(name)
assert_true(ch_names[i] in sel)
@@ -40,7 +40,7 @@ def test_read_selection():
assert_true(len(set(frontal).intersection(set(occipital))) == 0)
ch_names_new = [ch.replace(' ', '') for ch in ch_names]
- raw_new = read_raw_fif(raw_new_fname)
+ raw_new = read_raw_fif(raw_new_fname, add_eeg_ref=False)
for i, name in enumerate(sel_names):
sel = read_selection(name, info=raw_new.info)
assert_true(ch_names_new[i] in sel)
diff --git a/mne/tests/test_source_estimate.py b/mne/tests/test_source_estimate.py
index f5d2743..d6a907c 100644
--- a/mne/tests/test_source_estimate.py
+++ b/mne/tests/test_source_estimate.py
@@ -49,7 +49,7 @@ rng = np.random.RandomState(0)
@testing.requires_testing_data
-def test_aaspatial_inter_hemi_connectivity():
+def test_spatial_inter_hemi_connectivity():
"""Test spatial connectivity between hemispheres"""
# trivial cases
conn = spatial_inter_hemi_connectivity(fname_src_3, 5e-6)
@@ -140,14 +140,12 @@ def test_volume_stc():
with warnings.catch_warnings(record=True): # nib<->numpy
img = nib.load(vol_fname)
assert_true(img.shape == t1_img.shape + (len(stc.times),))
- assert_array_almost_equal(img.get_affine(), t1_img.get_affine(),
- decimal=5)
+ assert_array_almost_equal(img.affine, t1_img.affine, decimal=5)
# export without saving
img = stc.as_volume(src, dest='mri', mri_resolution=True)
assert_true(img.shape == t1_img.shape + (len(stc.times),))
- assert_array_almost_equal(img.get_affine(), t1_img.get_affine(),
- decimal=5)
+ assert_array_almost_equal(img.affine, t1_img.affine, decimal=5)
except ImportError:
print('Save as nifti test skipped, needs NiBabel')
@@ -271,8 +269,7 @@ def test_stc_arithmetic():
@slow_test
@testing.requires_testing_data
def test_stc_methods():
- """Test stc methods lh_data, rh_data, bin(), center_of_mass(), resample()
- """
+ """Test stc methods lh_data, rh_data, bin, center_of_mass, resample"""
stc = read_source_estimate(fname_stc)
# lh_data / rh_data
@@ -286,6 +283,8 @@ def test_stc_methods():
assert a[0] == bin.data[0, 0]
assert_raises(ValueError, stc.center_of_mass, 'sample')
+ assert_raises(TypeError, stc.center_of_mass, 'sample',
+ subjects_dir=subjects_dir, surf=1)
stc.lh_data[:] = 0
vertex, hemi, t = stc.center_of_mass('sample', subjects_dir=subjects_dir)
assert_true(hemi == 1)
@@ -441,6 +440,11 @@ def test_morph_data():
stc_to1 = stc_from.morph(subject_to, grade=3, smooth=12, buffer_size=1000,
subjects_dir=subjects_dir)
stc_to1.save(op.join(tempdir, '%s_audvis-meg' % subject_to))
+ # Morphing to a density that is too high should raise an informative error
+ # (here we need to push to grade=6, but for some subjects even grade=5
+ # will break)
+ assert_raises(ValueError, stc_to1.morph, subject_from, grade=6,
+ subjects_dir=subjects_dir)
# make sure we can specify vertices
vertices_to = grade_to_vertices(subject_to, grade=3,
subjects_dir=subjects_dir)
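The test_volume_stc hunk above replaces nibabel's deprecated get_affine() with the affine property; both expose the image's 4x4 voxel-to-world matrix. For reference (a sketch assuming nibabel is installed):

import numpy as np
import nibabel as nib

img = nib.Nifti1Image(np.zeros((2, 2, 2)), affine=np.eye(4))
assert img.affine.shape == (4, 4)    # modern spelling
# img.get_affine() returns the same matrix but is deprecated upstream.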
diff --git a/mne/tests/test_source_space.py b/mne/tests/test_source_space.py
index 5fe5b8d..f6b184a 100644
--- a/mne/tests/test_source_space.py
+++ b/mne/tests/test_source_space.py
@@ -205,6 +205,8 @@ def test_discrete_source_space():
# now do MRI
assert_raises(ValueError, setup_volume_source_space, 'sample',
pos=pos_dict, mri=fname_mri)
+ assert_equal(repr(src_new), repr(src_c))
+ assert_equal(src_new.kind, 'discrete')
finally:
if op.isfile(temp_name):
os.remove(temp_name)
@@ -233,6 +235,8 @@ def test_volume_source_space():
assert_raises(IOError, setup_volume_source_space, 'sample', temp_name,
pos=7.0, bem=None, surface='foo', # bad surf
mri=fname_mri, subjects_dir=subjects_dir)
+ assert_equal(repr(src), repr(src_new))
+ assert_equal(src.kind, 'volume')
@testing.requires_testing_data
@@ -254,6 +258,7 @@ def test_other_volume_source_spaces():
mri=fname_mri,
subjects_dir=subjects_dir)
_compare_source_spaces(src, src_new, mode='approx')
+ assert_true('volume, shape' in repr(src))
del src
del src_new
assert_raises(ValueError, setup_volume_source_space, 'sample', temp_name,
@@ -341,6 +346,8 @@ def test_setup_source_space():
subjects_dir=subjects_dir, add_dist=False,
overwrite=True)
_compare_source_spaces(src, src_new, mode='approx')
+ assert_equal(repr(src), repr(src_new))
+ assert_equal(repr(src).count('surface ('), 2)
assert_array_equal(src[0]['vertno'], np.arange(10242))
assert_array_equal(src[1]['vertno'], np.arange(10242))
@@ -531,6 +538,8 @@ def test_combine_source_spaces():
src.save(src_out_name)
src_from_file = read_source_spaces(src_out_name)
_compare_source_spaces(src, src_from_file, mode='approx')
+ assert_equal(repr(src), repr(src_from_file))
+ assert_equal(src.kind, 'combined')
# test that all source spaces are in MRI coordinates
coord_frames = np.array([s['coord_frame'] for s in src])
@@ -551,9 +560,9 @@ def test_combine_source_spaces():
# unrecognized file type
bad_image_fname = op.join(tempdir, 'temp-image.png')
- with warnings.catch_warnings(record=True): # vertices outside vol space
- assert_raises(ValueError, src.export_volume, bad_image_fname,
- verbose='error')
+ # vertices outside vol space warning
+ assert_raises(ValueError, src.export_volume, bad_image_fname,
+ verbose='error')
# mixed coordinate frames
disc3 = disc.copy()
diff --git a/mne/tests/test_surface.py b/mne/tests/test_surface.py
index 7394882..23c7255 100644
--- a/mne/tests/test_surface.py
+++ b/mne/tests/test_surface.py
@@ -13,9 +13,11 @@ from mne import read_surface, write_surface, decimate_surface
from mne.surface import (read_morph_map, _compute_nearest,
fast_cross_3d, get_head_surf, read_curvature,
get_meg_helmet_surf)
-from mne.utils import _TempDir, requires_mayavi, run_tests_if_main, slow_test
+from mne.utils import (_TempDir, requires_mayavi, requires_tvtk,
+ run_tests_if_main, slow_test)
from mne.io import read_info
from mne.transforms import _get_trans
+from mne.io.meas_info import _is_equal_dict
data_path = testing.data_path(download=False)
subjects_dir = op.join(data_path, 'subjects')
@@ -27,8 +29,7 @@ rng = np.random.RandomState(0)
def test_helmet():
- """Test loading helmet surfaces
- """
+ """Test loading helmet surfaces."""
base_dir = op.join(op.dirname(__file__), '..', 'io')
fname_raw = op.join(base_dir, 'tests', 'data', 'test_raw.fif')
fname_kit_raw = op.join(base_dir, 'kit', 'tests', 'data',
@@ -47,8 +48,7 @@ def test_helmet():
@testing.requires_testing_data
def test_head():
- """Test loading the head surface
- """
+ """Test loading the head surface."""
surf_1 = get_head_surf('sample', subjects_dir=subjects_dir)
surf_2 = get_head_surf('sample', 'head', subjects_dir=subjects_dir)
assert_true(len(surf_1['rr']) < len(surf_2['rr'])) # BEM vs dense head
@@ -57,8 +57,7 @@ def test_head():
def test_huge_cross():
- """Test cross product with lots of elements
- """
+ """Test cross product with lots of elements."""
x = rng.rand(100000, 3)
y = rng.rand(1, 3)
z = np.cross(x, y)
@@ -67,7 +66,7 @@ def test_huge_cross():
def test_compute_nearest():
- """Test nearest neighbor searches"""
+ """Test nearest neighbor searches."""
x = rng.randn(500, 3)
x /= np.sqrt(np.sum(x ** 2, axis=1))[:, None]
nn_true = rng.permutation(np.arange(500, dtype=np.int))[:20]
@@ -91,8 +90,7 @@ def test_compute_nearest():
@slow_test
@testing.requires_testing_data
def test_make_morph_maps():
- """Test reading and creating morph maps
- """
+ """Test reading and creating morph maps."""
# make a new fake subjects_dir
tempdir = _TempDir()
for subject in ('sample', 'sample_ds', 'fsaverage_ds'):
@@ -122,25 +120,28 @@ def test_make_morph_maps():
@testing.requires_testing_data
def test_io_surface():
- """Test reading and writing of Freesurfer surface mesh files
- """
+ """Test reading and writing of Freesurfer surface mesh files."""
tempdir = _TempDir()
fname_quad = op.join(data_path, 'subjects', 'bert', 'surf',
'lh.inflated.nofix')
fname_tri = op.join(data_path, 'subjects', 'fsaverage', 'surf',
'lh.inflated')
for fname in (fname_quad, fname_tri):
- pts, tri = read_surface(fname)
- write_surface(op.join(tempdir, 'tmp'), pts, tri)
- c_pts, c_tri = read_surface(op.join(tempdir, 'tmp'))
+ with warnings.catch_warnings(record=True) as w:
+ pts, tri, vol_info = read_surface(fname, read_metadata=True)
+ assert_true(all('No volume info' in str(ww.message) for ww in w))
+ write_surface(op.join(tempdir, 'tmp'), pts, tri, volume_info=vol_info)
+ with warnings.catch_warnings(record=True) as w: # No vol info
+ c_pts, c_tri, c_vol_info = read_surface(op.join(tempdir, 'tmp'),
+ read_metadata=True)
assert_array_equal(pts, c_pts)
assert_array_equal(tri, c_tri)
+ assert_true(_is_equal_dict([vol_info, c_vol_info]))
@testing.requires_testing_data
def test_read_curv():
- """Test reading curvature data
- """
+ """Test reading curvature data."""
fname_curv = op.join(data_path, 'subjects', 'fsaverage', 'surf', 'lh.curv')
fname_surf = op.join(data_path, 'subjects', 'fsaverage', 'surf',
'lh.inflated')
@@ -150,10 +151,10 @@ def test_read_curv():
assert_true(np.logical_or(bin_curv == 0, bin_curv == 1).all())
+ at requires_tvtk
@requires_mayavi
def test_decimate_surface():
- """Test triangular surface decimation
- """
+ """Test triangular surface decimation."""
points = np.array([[-0.00686118, -0.10369860, 0.02615170],
[-0.00713948, -0.10370162, 0.02614874],
[-0.00686208, -0.10368247, 0.02588313],
diff --git a/mne/tests/test_transforms.py b/mne/tests/test_transforms.py
index 9f81adf..841165e 100644
--- a/mne/tests/test_transforms.py
+++ b/mne/tests/test_transforms.py
@@ -54,7 +54,7 @@ def test_io_trans():
assert_raises(RuntimeError, _find_trans, 'sample', subjects_dir=tempdir)
trans0 = read_trans(fname)
fname1 = op.join(tempdir, 'sample', 'test-trans.fif')
- write_trans(fname1, trans0)
+ trans0.save(fname1)
assert_true(fname1 == _find_trans('sample', subjects_dir=tempdir))
trans1 = read_trans(fname1)
diff --git a/mne/tests/test_utils.py b/mne/tests/test_utils.py
index 33f079e..4669bec 100644
--- a/mne/tests/test_utils.py
+++ b/mne/tests/test_utils.py
@@ -8,8 +8,11 @@ import os
import warnings
from mne import read_evokeds
+from mne.datasets import testing
from mne.externals.six.moves import StringIO
-from mne.io import show_fiff
+from mne.io import show_fiff, read_raw_fif
+from mne.epochs import _segment_raw
+from mne.time_frequency import tfr_morlet
from mne.utils import (set_log_level, set_log_file, _TempDir,
get_config, set_config, deprecated, _fetch_file,
sum_squared, estimate_rank,
@@ -21,7 +24,9 @@ from mne.utils import (set_log_level, set_log_file, _TempDir,
set_memmap_min_size, _get_stim_channel, _check_fname,
create_slices, _time_mask, random_permutation,
_get_call_line, compute_corr, sys_info, verbose,
- check_fname, requires_ftp)
+ check_fname, requires_ftp, get_config_path,
+ object_size, buggy_mkl_svd, _get_inst_data,
+ copy_doc, copy_function_doc_to_method_doc)
warnings.simplefilter('always') # enable b/c these tests throw warnings
@@ -32,12 +37,33 @@ fname_raw = op.join(base_dir, 'test_raw.fif')
fname_log = op.join(base_dir, 'test-ave.log')
fname_log_2 = op.join(base_dir, 'test-ave-2.log')
+data_path = testing.data_path(download=False)
+fname_fsaverage_trans = op.join(data_path, 'subjects', 'fsaverage', 'bem',
+ 'fsaverage-trans.fif')
+
def clean_lines(lines=[]):
# Function to scrub filenames for checking logging output (in test_logging)
return [l if 'Reading ' not in l else 'Reading test file' for l in lines]
+def test_buggy_mkl():
+ """Test decorator for buggy MKL issues"""
+ from nose.plugins.skip import SkipTest
+
+ @buggy_mkl_svd
+ def foo(a, b):
+ raise np.linalg.LinAlgError('SVD did not converge')
+ with warnings.catch_warnings(record=True) as w:
+ assert_raises(SkipTest, foo, 1, 2)
+ assert_true(all('convergence error' in str(ww.message) for ww in w))
+
+ @buggy_mkl_svd
+ def bar(c, d, e):
+ raise RuntimeError('SVD did not converge')
+ assert_raises(RuntimeError, bar, 1, 2, 3)
+
+
def test_sys_info():
"""Test info-showing utility
"""
@@ -65,6 +91,49 @@ def test_get_call_line():
assert_equal(my_line, 'my_line = bar() # testing more')
+def test_object_size():
+ """Test object size estimation"""
+ assert_true(object_size(np.ones(10, np.float32)) <
+ object_size(np.ones(10, np.float64)))
+ for lower, upper, obj in ((0, 60, ''),
+ (0, 30, 1),
+ (0, 30, 1.),
+ (0, 60, 'foo'),
+ (0, 150, np.ones(0)),
+ (0, 150, np.int32(1)),
+ (150, 500, np.ones(20)),
+ (100, 400, dict()),
+ (400, 1000, dict(a=np.ones(50))),
+ (200, 900, sparse.eye(20, format='csc')),
+ (200, 900, sparse.eye(20, format='csr'))):
+ size = object_size(obj)
+ assert_true(lower < size < upper,
+ msg='%s < %s < %s:\n%s' % (lower, size, upper, obj))
+
+
+def test_get_inst_data():
+ """Test _get_inst_data"""
+ raw = read_raw_fif(fname_raw, add_eeg_ref=False)
+ raw.crop(tmax=1.)
+ assert_equal(_get_inst_data(raw), raw._data)
+ raw.pick_channels(raw.ch_names[:2])
+
+ epochs = _segment_raw(raw, 0.5)
+ assert_equal(_get_inst_data(epochs), epochs._data)
+
+ evoked = epochs.average()
+ assert_equal(_get_inst_data(evoked), evoked.data)
+
+ evoked.crop(tmax=0.1)
+ picks = list(range(2))
+ freqs = np.array([50., 55.])
+ n_cycles = 3
+ tfr = tfr_morlet(evoked, freqs, n_cycles, return_itc=False, picks=picks)
+ assert_equal(_get_inst_data(tfr), tfr.data)
+
+ assert_raises(TypeError, _get_inst_data, 'foo')
+
+
def test_misc():
"""Test misc utilities"""
assert_equal(_memory_usage(-1)[0], -1)
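test_object_size above only brackets each estimate, since exact byte counts vary across platforms and Python versions. The core recipe behind such a utility is sys.getsizeof plus recursion into containers; a simplified sketch (not the actual mne.utils.object_size):

import sys

def rough_size(obj):
    """Very rough recursive size estimate in bytes."""
    size = sys.getsizeof(obj)
    if isinstance(obj, dict):
        size += sum(rough_size(k) + rough_size(v) for k, v in obj.items())
    elif isinstance(obj, (list, tuple, set)):
        size += sum(rough_size(item) for item in obj)
    return size

assert rough_size(dict(a=[1, 2, 3])) > rough_size(dict())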
@@ -250,13 +319,6 @@ def test_logging():
old_lines = clean_lines(old_log_file.readlines())
with open(fname_log_2, 'r') as old_log_file_2:
old_lines_2 = clean_lines(old_log_file_2.readlines())
- # we changed our logging a little bit
- old_lines = [o.replace('No baseline correction applied...',
- 'No baseline correction applied')
- for o in old_lines]
- old_lines_2 = [o.replace('No baseline correction applied...',
- 'No baseline correction applied')
- for o in old_lines_2]
if op.isfile(test_name):
os.remove(test_name)
@@ -298,8 +360,7 @@ def test_logging():
evoked = read_evokeds(fname_evoked, condition=1)
with open(test_name, 'r') as new_log_file:
new_lines = clean_lines(new_log_file.readlines())
- with open(fname_log, 'r') as old_log_file:
- assert_equal(new_lines, old_lines)
+ assert_equal(new_lines, old_lines)
# check to make sure appending works (and as default, raises a warning)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
@@ -336,15 +397,20 @@ def test_config():
assert_true(len(set_config(None, None)) > 10) # tuple of valid keys
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
- set_config(key, None, home_dir=tempdir)
+ set_config(key, None, home_dir=tempdir, set_env=False)
assert_true(len(w) == 1)
assert_true(get_config(key, home_dir=tempdir) is None)
assert_raises(KeyError, get_config, key, raise_error=True)
with warnings.catch_warnings(record=True):
warnings.simplefilter('always')
- set_config(key, value, home_dir=tempdir)
+ assert_true(key not in os.environ)
+ set_config(key, value, home_dir=tempdir, set_env=True)
+ assert_true(key in os.environ)
assert_true(get_config(key, home_dir=tempdir) == value)
- set_config(key, None, home_dir=tempdir)
+ set_config(key, None, home_dir=tempdir, set_env=True)
+ assert_true(key not in os.environ)
+ set_config(key, None, home_dir=tempdir, set_env=True)
+ assert_true(key not in os.environ)
if old_val is not None:
os.environ[key] = old_val
# Check if get_config with no input returns all config
@@ -354,8 +420,18 @@ def test_config():
warnings.simplefilter('always')
set_config(key, value, home_dir=tempdir)
assert_equal(get_config(home_dir=tempdir), config)
+ # Check what happens when we use a corrupted file
+ json_fname = get_config_path(home_dir=tempdir)
+ with open(json_fname, 'w') as fid:
+ fid.write('foo{}')
+ with warnings.catch_warnings(record=True) as w:
+ assert_equal(get_config(home_dir=tempdir), dict())
+ assert_true(any('not a valid JSON' in str(ww.message) for ww in w))
+ with warnings.catch_warnings(record=True) as w: # non-standard key
+ assert_raises(RuntimeError, set_config, key, 'true', home_dir=tempdir)
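
A minimal sketch of the new set_env behaviour exercised above; the key name
and value here are illustrative:

    import os
    from mne.utils import set_config

    # set_env=True mirrors the config value into os.environ
    set_config('MNE_DATASETS_SAMPLE_PATH', '/tmp/sample', set_env=True)
    assert 'MNE_DATASETS_SAMPLE_PATH' in os.environ
    # setting the value back to None removes it from the file and os.environ
    set_config('MNE_DATASETS_SAMPLE_PATH', None, set_env=True)
    assert 'MNE_DATASETS_SAMPLE_PATH' not in os.environ
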
+@testing.requires_testing_data
def test_show_fiff():
"""Test show_fiff
"""
@@ -366,6 +442,7 @@ def test_show_fiff():
'FIFF_EPOCH']
assert_true(all(key in info for key in keys))
info = show_fiff(fname_raw, read_limit=1024)
+ assert_true('COORD_TRANS' in show_fiff(fname_fsaverage_trans))
@deprecated('message')
@@ -562,4 +639,122 @@ def test_random_permutation():
assert_array_equal(python_randperm, matlab_randperm - 1)
+def test_copy_doc():
+    """Test decorator for copying docstrings."""
+ class A:
+ def m1():
+ """Docstring for m1"""
+ pass
+
+ class B:
+ def m1():
+ pass
+
+ class C (A):
+ @copy_doc(A.m1)
+ def m1():
+ pass
+
+ assert_equal(C.m1.__doc__, 'Docstring for m1')
+ assert_raises(ValueError, copy_doc(B.m1), C.m1)
+
+
+def test_copy_function_doc_to_method_doc():
+    """Test decorator for re-using function docstrings as method docstrings."""
+ def f1(object, a, b, c):
+ """Docstring for f1
+
+ Parameters
+ ----------
+ object : object
+ Some object. This description also has
+
+ blank lines in it.
+ a : int
+ Parameter a
+ b : int
+ Parameter b
+ """
+ pass
+
+ def f2(object):
+ """Docstring for f2
+
+ Parameters
+ ----------
+ object : object
+ Only one parameter
+
+ Returns
+ -------
+ nothing.
+ """
+ pass
+
+ def f3(object):
+ """Docstring for f3
+
+ Parameters
+ ----------
+ object : object
+ Only one parameter
+ """
+ pass
+
+ def f4(object):
+ """Docstring for f4"""
+ pass
+
+ def f5(object):
+ """Docstring for f5
+
+ Parameters
+ ----------
+ Returns
+ -------
+ nothing.
+ """
+ pass
+
+ class A:
+ @copy_function_doc_to_method_doc(f1)
+ def method_f1(self, a, b, c):
+ pass
+
+ @copy_function_doc_to_method_doc(f2)
+ def method_f2(self):
+ "method_f3 own docstring"
+ pass
+
+ @copy_function_doc_to_method_doc(f3)
+ def method_f3(self):
+ pass
+
+ assert_equal(
+ A.method_f1.__doc__,
+ """Docstring for f1
+
+ Parameters
+ ----------
+ a : int
+ Parameter a
+ b : int
+ Parameter b
+ """
+ )
+
+ assert_equal(
+ A.method_f2.__doc__,
+ """Docstring for f2
+
+ Returns
+ -------
+ nothing.
+ method_f3 own docstring"""
+ )
+
+ assert_equal(A.method_f3.__doc__, 'Docstring for f3\n\n ')
+ assert_raises(ValueError, copy_function_doc_to_method_doc(f4), A.method_f1)
+ assert_raises(ValueError, copy_function_doc_to_method_doc(f5), A.method_f1)
+
run_tests_if_main()
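
For orientation, a sketch of the decorator pattern the new copy_doc tests
cover; the class names here are hypothetical:

    from mne.utils import copy_doc

    class Base(object):
        def m1(self):
            """Docstring for m1"""

    class Child(Base):
        @copy_doc(Base.m1)
        def m1(self):
            pass  # new implementation, inherited docstring

    assert Child.m1.__doc__ == 'Docstring for m1'
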
diff --git a/mne/time_frequency/__init__.py b/mne/time_frequency/__init__.py
index 327a5c8..72ab5b2 100644
--- a/mne/time_frequency/__init__.py
+++ b/mne/time_frequency/__init__.py
@@ -2,10 +2,12 @@
"""
from .tfr import (single_trial_power, morlet, tfr_morlet, cwt_morlet,
- AverageTFR, tfr_multitaper, read_tfrs, write_tfrs)
-from .psd import compute_raw_psd, compute_epochs_psd, psd_welch, psd_multitaper
-from .csd import CrossSpectralDensity, compute_epochs_csd
+ AverageTFR, tfr_multitaper, read_tfrs, write_tfrs,
+ EpochsTFR)
+from .psd import psd_welch, psd_multitaper
+from .csd import (CrossSpectralDensity, compute_epochs_csd, csd_epochs,
+ csd_array)
from .ar import fit_iir_model_raw
-from .multitaper import dpss_windows, multitaper_psd
+from .multitaper import dpss_windows
from .stft import stft, istft, stftfreq
from ._stockwell import tfr_stockwell
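
The import changes above track the CSD API rename; a migration sketch,
assuming an existing epochs object (the parameters mirror the tests further
below):

    from mne.time_frequency import csd_epochs  # was: compute_epochs_csd

    data_csd = csd_epochs(epochs, mode='multitaper', fmin=8, fmax=12,
                          tmin=0.04, tmax=0.15)
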
diff --git a/mne/time_frequency/_stockwell.py b/mne/time_frequency/_stockwell.py
index 4a4b867..0af3180 100644
--- a/mne/time_frequency/_stockwell.py
+++ b/mne/time_frequency/_stockwell.py
@@ -29,14 +29,15 @@ def _check_input_st(x_in, n_fft):
elif n_fft < n_times:
raise ValueError("n_fft cannot be smaller than signal size. "
"Got %s < %s." % (n_fft, n_times))
- zero_pad = None
if n_times < n_fft:
warn('The input signal is shorter ({0}) than "n_fft" ({1}). '
'Applying zero padding.'.format(x_in.shape[-1], n_fft))
zero_pad = n_fft - n_times
pad_array = np.zeros(x_in.shape[:-1] + (zero_pad,), x_in.dtype)
x_in = np.concatenate((x_in, pad_array), axis=-1)
- return x_in, n_fft, zero_pad
+ else:
+ zero_pad = 0
+ return x_in, n_fft, zero_pad
def _precompute_st_windows(n_samp, start_f, stop_f, sfreq, width):
@@ -83,7 +84,10 @@ def _st_power_itc(x, start_f, compute_itc, zero_pad, decim, W):
for i_f, window in enumerate(W):
f = start_f + i_f
ST = fftpack.ifft(XX[:, f:f + n_samp] * window)
- TFR = ST[:, :-zero_pad:decim]
+ if zero_pad > 0:
+ TFR = ST[:, :-zero_pad:decim]
+ else:
+ TFR = ST[:, ::decim]
TFR_abs = np.abs(TFR)
if compute_itc:
TFR /= TFR_abs
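
The zero_pad change above guards against a classic slicing pitfall: with the
old default of None, -zero_pad raises a TypeError inside the slice, while a
bare 0 would silently produce an empty result because -0 == 0. A plain-NumPy
sketch of the trap:

    import numpy as np

    a = np.arange(6)
    print(a[:-0])  # [] -- empty, since -0 == 0 turns this into a[:0]
    print(a[::1])  # [0 1 2 3 4 5] -- the branch now taken when zero_pad == 0
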
diff --git a/mne/time_frequency/csd.py b/mne/time_frequency/csd.py
index 4d5d25f..7f618ec 100644
--- a/mne/time_frequency/csd.py
+++ b/mne/time_frequency/csd.py
@@ -12,6 +12,9 @@ from ..utils import logger, verbose, warn
from ..time_frequency.multitaper import (dpss_windows, _mt_spectra,
_csd_from_mt, _psd_from_mt_adaptive)
+from ..utils import deprecated
+from ..externals.six.moves import xrange as range
+
class CrossSpectralDensity(object):
"""Cross-spectral density
@@ -48,11 +51,24 @@ class CrossSpectralDensity(object):
return '<CrossSpectralDensity | %s>' % s
+ at deprecated(("compute_epochs_csd has been deprecated and will be removed in "
+ "0.14, use csd_epochs instead."))
@verbose
def compute_epochs_csd(epochs, mode='multitaper', fmin=0, fmax=np.inf,
fsum=True, tmin=None, tmax=None, n_fft=None,
mt_bandwidth=None, mt_adaptive=False, mt_low_bias=True,
projs=None, verbose=None):
+ return csd_epochs(epochs, mode=mode, fmin=fmin, fmax=fmax,
+ fsum=fsum, tmin=tmin, tmax=tmax, n_fft=n_fft,
+ mt_bandwidth=mt_bandwidth, mt_adaptive=mt_adaptive,
+ mt_low_bias=mt_low_bias, projs=projs, verbose=verbose)
+
+
+@verbose
+def csd_epochs(epochs, mode='multitaper', fmin=0, fmax=np.inf,
+ fsum=True, tmin=None, tmax=None, n_fft=None,
+ mt_bandwidth=None, mt_adaptive=False, mt_low_bias=True,
+ projs=None, verbose=None):
"""Estimate cross-spectral density from epochs
Note: Baseline correction should be used when creating the Epochs.
@@ -155,33 +171,8 @@ def compute_epochs_csd(epochs, mode='multitaper', fmin=0, fmax=np.inf,
# Preparing for computing CSD
logger.info('Computing cross-spectral density from epochs...')
- if mode == 'multitaper':
- # Compute standardized half-bandwidth
- if mt_bandwidth is not None:
- half_nbw = float(mt_bandwidth) * n_times / (2 * sfreq)
- else:
- half_nbw = 2
-
- # Compute DPSS windows
- n_tapers_max = int(2 * half_nbw)
- window_fun, eigvals = dpss_windows(n_times, half_nbw, n_tapers_max,
- low_bias=mt_low_bias)
- n_tapers = len(eigvals)
- logger.info(' using multitaper spectrum estimation with %d DPSS '
- 'windows' % n_tapers)
-
- if mt_adaptive and len(eigvals) < 3:
- warn('Not adaptively combining the spectral estimators due to a '
- 'low number of tapers.')
- mt_adaptive = False
- elif mode == 'fourier':
- logger.info(' using FFT with a Hanning window to estimate spectra')
- window_fun = np.hanning(n_times)
- mt_adaptive = False
- eigvals = 1.
- n_tapers = None
- else:
- raise ValueError('Mode has an invalid value.')
+ window_fun, eigvals, n_tapers, mt_adaptive = _compute_csd_params(
+ n_times, sfreq, mode, mt_bandwidth, mt_low_bias, mt_adaptive)
csds_mean = np.zeros((len(ch_names), len(ch_names), n_freqs),
dtype=complex)
@@ -195,33 +186,8 @@ def compute_epochs_csd(epochs, mode='multitaper', fmin=0, fmax=np.inf,
epoch = epoch[picks_meeg][:, tslice]
# Calculating Fourier transform using multitaper module
- x_mt, _ = _mt_spectra(epoch, window_fun, sfreq, n_fft)
-
- if mt_adaptive:
- # Compute adaptive weights
- _, weights = _psd_from_mt_adaptive(x_mt, eigvals, freq_mask,
- return_weights=True)
- # Tiling weights so that we can easily use _csd_from_mt()
- weights = weights[:, np.newaxis, :, :]
- weights = np.tile(weights, [1, x_mt.shape[0], 1, 1])
- else:
- # Do not use adaptive weights
- if mode == 'multitaper':
- weights = np.sqrt(eigvals)[np.newaxis, np.newaxis, :,
- np.newaxis]
- else:
- # Hack so we can sum over axis=-2
- weights = np.array([1.])[:, None, None, None]
-
- x_mt = x_mt[:, :, freq_mask_mt]
-
- # Calculating CSD
- # Tiling x_mt so that we can easily use _csd_from_mt()
- x_mt = x_mt[:, np.newaxis, :, :]
- x_mt = np.tile(x_mt, [1, x_mt.shape[0], 1, 1])
- y_mt = np.transpose(x_mt, axes=[1, 0, 2, 3])
- weights_y = np.transpose(weights, axes=[1, 0, 2, 3])
- csds_epoch = _csd_from_mt(x_mt, y_mt, weights, weights_y)
+ csds_epoch = _csd_array(epoch, sfreq, window_fun, eigvals, freq_mask,
+ freq_mask_mt, n_fft, mode, mt_adaptive)
# Scaling by number of samples and compensating for loss of power due
# to windowing (see section 11.5.2 in Bendat & Piersol).
@@ -255,3 +221,221 @@ def compute_epochs_csd(epochs, mode='multitaper', fmin=0, fmax=np.inf,
frequencies=frequencies[i],
n_fft=n_fft))
return csds
+
+
+@verbose
+def csd_array(X, sfreq, mode='multitaper', fmin=0, fmax=np.inf,
+ fsum=True, n_fft=None, mt_bandwidth=None,
+ mt_adaptive=False, mt_low_bias=True, verbose=None):
+ """Estimate cross-spectral density from an array.
+
+ .. note:: Results are scaled by sampling frequency for compatibility with
+ Matlab.
+
+ Parameters
+ ----------
+ X : array-like, shape (n_replicates, n_series, n_times)
+        The time series data consisting of n_replicates separate observations
+        of signals with n_series components and of length n_times. For
+        example, n_replicates could be the number of epochs, and n_series the
+        number of vertices in a source space.
+ sfreq : float
+ Sampling frequency of observations.
+ mode : str
+ Spectrum estimation mode can be either: 'multitaper' or 'fourier'.
+ fmin : float
+ Minimum frequency of interest.
+ fmax : float
+ Maximum frequency of interest.
+ fsum : bool
+ Sum CSD values for the frequencies of interest. Summing is performed
+ instead of averaging so that accumulated power is comparable to power
+ in the time domain. If True, a single CSD matrix will be returned. If
+ False, the output will be an array of CSD matrices.
+ n_fft : int | None
+        Length of the FFT. If None, the number of samples in the input
+        signal will be used.
+ mt_bandwidth : float | None
+ The bandwidth of the multitaper windowing function in Hz.
+ Only used in 'multitaper' mode.
+ mt_adaptive : bool
+ Use adaptive weights to combine the tapered spectra into PSD.
+ Only used in 'multitaper' mode.
+ mt_low_bias : bool
+ Only use tapers with more than 90% spectral concentration within
+ bandwidth. Only used in 'multitaper' mode.
+ verbose : bool, str, int, or None
+ If not None, override default verbose level (see mne.verbose).
+
+ Returns
+ -------
+    csd : array, shape (n_series, n_series) if fsum is True, otherwise (n_series, n_series, n_freqs).
+ The computed cross spectral-density (either summed or not).
+ freqs : array
+ Frequencies the cross spectral-density is evaluated at.
+ """ # noqa
+
+ # Check correctness of input data and parameters
+ if fmax < fmin:
+ raise ValueError('fmax must be larger than fmin')
+
+ X = np.asarray(X, dtype=float)
+ if X.ndim != 3:
+ raise ValueError("X must be n_replicates x n_series x n_times.")
+ n_replicates, n_series, n_times = X.shape
+
+ # Preparing frequencies of interest
+ n_fft = n_times if n_fft is None else n_fft
+ orig_frequencies = fftfreq(n_fft, 1. / sfreq)
+ freq_mask = (orig_frequencies > fmin) & (orig_frequencies < fmax)
+ frequencies = orig_frequencies[freq_mask]
+ n_freqs = len(frequencies)
+
+ if n_freqs == 0:
+        raise ValueError('No discrete Fourier transform results within '
+ 'the given frequency window. Please widen either '
+ 'the frequency window or the time window')
+
+ # Preparing for computing CSD
+ logger.info('Computing cross-spectral density from array...')
+ window_fun, eigvals, n_tapers, mt_adaptive = _compute_csd_params(
+ n_times, sfreq, mode, mt_bandwidth, mt_low_bias, mt_adaptive)
+
+ csds_mean = np.zeros((n_series, n_series, n_freqs), dtype=complex)
+
+ # Picking frequencies of interest
+ freq_mask_mt = freq_mask[orig_frequencies >= 0]
+
+ # Compute CSD for each trial
+ for xi in X:
+
+ csds_trial = _csd_array(xi, sfreq, window_fun, eigvals, freq_mask,
+ freq_mask_mt, n_fft, mode, mt_adaptive)
+
+        # Scaling by number of samples and compensating for loss of power due
+ # to windowing (see section 11.5.2 in Bendat & Piersol).
+ if mode == 'fourier':
+ csds_trial /= n_times
+ csds_trial *= 8 / 3.
+
+ # Scaling by sampling frequency for compatibility with Matlab
+ csds_trial /= sfreq
+
+ csds_mean += csds_trial
+
+ csds_mean /= n_replicates
+
+ logger.info('[done]')
+
+    # Summing over frequencies of interest or keeping one CSD matrix per
+    # frequency
+ if fsum is True:
+ csds_mean = np.sum(csds_mean, 2)
+
+ return csds_mean, frequencies
+
+
+def _compute_csd_params(n_times, sfreq, mode, mt_bandwidth, mt_low_bias,
+ mt_adaptive):
+ """ Auxliary function to compute windowing and multitaper parameters.
+
+ Parameters
+ ----------
+ n_times : int
+ Number of time points.
+    sfreq : float
+ Sampling frequency of signal.
+ mode : str
+ Spectrum estimation mode can be either: 'multitaper' or 'fourier'.
+ mt_bandwidth : float | None
+ The bandwidth of the multitaper windowing function in Hz.
+ Only used in 'multitaper' mode.
+ mt_low_bias : bool
+ Only use tapers with more than 90% spectral concentration within
+ bandwidth. Only used in 'multitaper' mode.
+ mt_adaptive : bool
+ Use adaptive weights to combine the tapered spectra into PSD.
+ Only used in 'multitaper' mode.
+
+ Returns
+ -------
+    window_fun : array
+        Window function(s) of length n_times. In 'multitaper' mode this is
+        the first output of `dpss_windows`; in 'fourier' mode it is a
+        Hanning window of length `n_times`.
+    eigvals : array | float
+        Eigenvalues associated with the window functions. Only needed when
+        mode is 'multitaper'; set to 1. when mode is 'fourier'.
+    n_tapers : int | None
+        Number of tapers to use. Only used when mode is 'multitaper'.
+    ret_mt_adaptive : bool
+        Updated value of the `mt_adaptive` argument, since certain parameter
+        values do not allow adaptive spectral estimation.
+ """
+ ret_mt_adaptive = mt_adaptive
+ if mode == 'multitaper':
+ # Compute standardized half-bandwidth
+ if mt_bandwidth is not None:
+ half_nbw = float(mt_bandwidth) * n_times / (2. * sfreq)
+ else:
+ half_nbw = 2.
+
+ # Compute DPSS windows
+ n_tapers_max = int(2 * half_nbw)
+ window_fun, eigvals = dpss_windows(n_times, half_nbw, n_tapers_max,
+ low_bias=mt_low_bias)
+ n_tapers = len(eigvals)
+ logger.info(' using multitaper spectrum estimation with %d DPSS '
+ 'windows' % n_tapers)
+
+ if mt_adaptive and len(eigvals) < 3:
+ warn('Not adaptively combining the spectral estimators due to a '
+ 'low number of tapers.')
+ ret_mt_adaptive = False
+ elif mode == 'fourier':
+ logger.info(' using FFT with a Hanning window to estimate spectra')
+ window_fun = np.hanning(n_times)
+ ret_mt_adaptive = False
+ eigvals = 1.
+ n_tapers = None
+ else:
+ raise ValueError('Mode has an invalid value.')
+
+ return window_fun, eigvals, n_tapers, ret_mt_adaptive
+
+
+def _csd_array(x, sfreq, window_fun, eigvals, freq_mask, freq_mask_mt, n_fft,
+ mode, mt_adaptive):
+ """ Calculating Fourier transform using multitaper module.
+
+ The arguments correspond to the values in `compute_csd_epochs` and
+ `csd_array`.
+ """
+ x_mt, _ = _mt_spectra(x, window_fun, sfreq, n_fft)
+
+ if mt_adaptive:
+ # Compute adaptive weights
+ _, weights = _psd_from_mt_adaptive(x_mt, eigvals, freq_mask,
+ return_weights=True)
+ # Tiling weights so that we can easily use _csd_from_mt()
+ weights = weights[:, np.newaxis, :, :]
+ weights = np.tile(weights, [1, x_mt.shape[0], 1, 1])
+ else:
+ # Do not use adaptive weights
+ if mode == 'multitaper':
+ weights = np.sqrt(eigvals)[np.newaxis, np.newaxis, :, np.newaxis]
+ else:
+ # Hack so we can sum over axis=-2
+ weights = np.array([1.])[:, np.newaxis, np.newaxis, np.newaxis]
+
+ x_mt = x_mt[:, :, freq_mask_mt]
+
+ # Calculating CSD
+ # Tiling x_mt so that we can easily use _csd_from_mt()
+ x_mt = x_mt[:, np.newaxis, :, :]
+ x_mt = np.tile(x_mt, [1, x_mt.shape[0], 1, 1])
+ y_mt = np.transpose(x_mt, axes=[1, 0, 2, 3])
+ weights_y = np.transpose(weights, axes=[1, 0, 2, 3])
+ csds = _csd_from_mt(x_mt, y_mt, weights, weights_y)
+
+ return csds
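
A usage sketch of the new csd_array function on synthetic data; the shapes
follow the docstring above and the values are illustrative:

    import numpy as np
    from mne.time_frequency import csd_array

    sfreq = 256.
    X = np.random.randn(10, 4, 512)  # (n_replicates, n_series, n_times)
    csd, freqs = csd_array(X, sfreq, mode='multitaper', fmin=8, fmax=12)
    # with fsum=True (the default) csd.shape == (4, 4);
    # with fsum=False it is (4, 4, len(freqs))
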
diff --git a/mne/time_frequency/multitaper.py b/mne/time_frequency/multitaper.py
index a5b616e..945ff8b 100644
--- a/mne/time_frequency/multitaper.py
+++ b/mne/time_frequency/multitaper.py
@@ -7,7 +7,7 @@ import numpy as np
from scipy import fftpack, linalg
from ..parallel import parallel_func
-from ..utils import verbose, sum_squared, deprecated, warn
+from ..utils import sum_squared, warn
def tridisolve(d, e, b, overwrite_b=True):
@@ -494,7 +494,7 @@ def _psd_multitaper(x, sfreq, fmin=0, fmax=np.inf, bandwidth=None,
See Also
--------
- mne.io.Raw.plot_psd, mne.Epochs.plot_psd
+ mne.io.Raw.plot_psd, mne.Epochs.plot_psd, csd_epochs
Notes
-----
@@ -559,59 +559,3 @@ def _psd_multitaper(x, sfreq, fmin=0, fmax=np.inf, bandwidth=None,
if ndim_in == 1:
psd = psd[0]
return psd, freqs
-
-
-@deprecated('This will be deprecated in release v0.12, see psd_multitaper.')
-@verbose
-def multitaper_psd(x, sfreq=2 * np.pi, fmin=0, fmax=np.inf, bandwidth=None,
- adaptive=False, low_bias=True, n_jobs=1,
- normalization='length', verbose=None):
- """Compute power spectrum density (PSD) using a multi-taper method
-
- Parameters
- ----------
- x : array, shape=(n_signals, n_times) or (n_times,)
- The data to compute PSD from.
- sfreq : float
- The sampling frequency.
- fmin : float
- The lower frequency of interest.
- fmax : float
- The upper frequency of interest.
- bandwidth : float
- The bandwidth of the multi taper windowing function in Hz.
- adaptive : bool
- Use adaptive weights to combine the tapered spectra into PSD
- (slow, use n_jobs >> 1 to speed up computation).
- low_bias : bool
- Only use tapers with more than 90% spectral concentration within
- bandwidth.
- n_jobs : int
- Number of parallel jobs to use (only used if adaptive=True).
- normalization : str
- Either "full" or "length" (default). If "full", the PSD will
- be normalized by the sampling rate as well as the length of
- the signal (as in nitime).
- verbose : bool, str, int, or None
- If not None, override default verbose level (see mne.verbose).
-
- Returns
- -------
- psd : array, shape=(n_signals, len(freqs)) or (len(freqs),)
- The computed PSD.
- freqs : array
- The frequency points in Hz of the PSD.
-
- See Also
- --------
- mne.io.Raw.plot_psd
- mne.Epochs.plot_psd
-
- Notes
- -----
- .. versionadded:: 0.9.0
- """
- return _psd_multitaper(x=x, sfreq=sfreq, fmin=fmin, fmax=fmax,
- bandwidth=bandwidth, adaptive=adaptive,
- low_bias=low_bias,
- normalization=normalization, n_jobs=n_jobs)
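
With multitaper_psd removed, callers should switch to the instance-based
psd_multitaper; a migration sketch assuming a loaded Raw object:

    from mne.time_frequency import psd_multitaper

    # was: psds, freqs = multitaper_psd(data, sfreq=raw.info['sfreq'])
    psds, freqs = psd_multitaper(raw, fmin=2., fmax=40., low_bias=True)
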
diff --git a/mne/time_frequency/psd.py b/mne/time_frequency/psd.py
index 1de44d1..950efc3 100644
--- a/mne/time_frequency/psd.py
+++ b/mne/time_frequency/psd.py
@@ -6,91 +6,10 @@ import numpy as np
from ..parallel import parallel_func
from ..io.pick import _pick_data_channels
-from ..utils import logger, verbose, deprecated, _time_mask
+from ..utils import logger, verbose, _time_mask
from .multitaper import _psd_multitaper
-@deprecated('This will be deprecated in release v0.12, see psd_welch.')
-@verbose
-def compute_raw_psd(raw, tmin=0., tmax=None, picks=None, fmin=0,
- fmax=np.inf, n_fft=2048, n_overlap=0,
- proj=False, n_jobs=1, verbose=None):
- """Compute power spectral density with average periodograms.
-
- Parameters
- ----------
- raw : instance of Raw
- The raw data.
- tmin : float
- Minimum time instant to consider (in seconds).
- tmax : float | None
- Maximum time instant to consider (in seconds). None will use the
- end of the file.
- picks : array-like of int | None
- The selection of channels to include in the computation.
- If None, take all channels.
- fmin : float
- Min frequency of interest
- fmax : float
- Max frequency of interest
- n_fft : int
- The length of the tapers ie. the windows. The smaller
- it is the smoother are the PSDs.
- n_overlap : int
- The number of points of overlap between blocks. The default value
- is 0 (no overlap).
- proj : bool
- Apply SSP projection vectors.
- n_jobs : int
- Number of CPUs to use in the computation.
- verbose : bool, str, int, or None
- If not None, override default verbose level (see mne.verbose).
-
- Returns
- -------
- psd : array of float
- The PSD for all channels
- freqs: array of float
- The frequencies
-
- See Also
- --------
- psd_welch, psd_multitaper
- """
- from scipy.signal import welch
- from ..io.base import _BaseRaw
- if not isinstance(raw, _BaseRaw):
- raise ValueError('Input must be an instance of Raw')
- tmax = raw.times[-1] if tmax is None else tmax
- start, stop = raw.time_as_index([tmin, tmax])
- picks = slice(None) if picks is None else picks
-
- if proj:
- # Copy first so it's not modified
- raw = raw.copy().apply_proj()
- data, times = raw[picks, start:(stop + 1)]
- n_fft, n_overlap = _check_nfft(len(times), n_fft, n_overlap)
-
- n_fft = int(n_fft)
- Fs = raw.info['sfreq']
-
- logger.info("Effective window size : %0.3f (s)" % (n_fft / float(Fs)))
-
- parallel, my_pwelch, n_jobs = parallel_func(_pwelch, n_jobs=n_jobs,
- verbose=verbose)
-
- freqs = np.arange(n_fft // 2 + 1) * (Fs / n_fft)
- freq_mask = (freqs >= fmin) & (freqs <= fmax)
- freqs = freqs[freq_mask]
-
- psds = np.array(parallel(my_pwelch([channel],
- noverlap=n_overlap, nfft=n_fft, fs=Fs,
- freq_mask=freq_mask, welch_fun=welch)
- for channel in data))[:, 0, :]
-
- return psds, freqs
-
-
def _pwelch(epoch, noverlap, nfft, fs, freq_mask, welch_fun):
"""Aux function"""
return welch_fun(epoch, nperseg=nfft, noverlap=noverlap,
@@ -250,7 +169,8 @@ def psd_welch(inst, fmin=0, fmax=np.inf, tmin=None, tmax=None, n_fft=256,
See Also
--------
- mne.io.Raw.plot_psd, mne.Epochs.plot_psd, psd_multitaper
+ mne.io.Raw.plot_psd, mne.Epochs.plot_psd, psd_multitaper,
+ csd_epochs
Notes
-----
@@ -329,7 +249,7 @@ def psd_multitaper(inst, fmin=0, fmax=np.inf, tmin=None, tmax=None,
See Also
--------
- mne.io.Raw.plot_psd, mne.Epochs.plot_psd, psd_welch
+ mne.io.Raw.plot_psd, mne.Epochs.plot_psd, psd_welch, csd_epochs
Notes
-----
@@ -341,94 +261,3 @@ def psd_multitaper(inst, fmin=0, fmax=np.inf, tmin=None, tmax=None,
bandwidth=bandwidth, adaptive=adaptive,
low_bias=low_bias,
normalization=normalization, n_jobs=n_jobs)
-
-
-@deprecated('This will be deprecated in release v0.12, see psd_welch.')
-@verbose
-def compute_epochs_psd(epochs, picks=None, fmin=0, fmax=np.inf, tmin=None,
- tmax=None, n_fft=256, n_overlap=0, proj=False,
- n_jobs=1, verbose=None):
- """Compute power spectral density with average periodograms.
-
- Parameters
- ----------
- epochs : instance of Epochs
- The epochs.
- picks : array-like of int | None
- The selection of channels to include in the computation.
- If None, take all channels.
- fmin : float
- Min frequency of interest
- fmax : float
- Max frequency of interest
- tmin : float | None
- Min time of interest
- tmax : float | None
- Max time of interest
- n_fft : int
- The length of the tapers ie. the windows. The smaller
- it is the smoother are the PSDs. The default value is 256.
- If ``n_fft > len(epochs.times)``, it will be adjusted down to
- ``len(epochs.times)``.
- n_overlap : int
- The number of points of overlap between blocks. Will be adjusted
- to be <= n_fft.
- proj : bool
- Apply SSP projection vectors.
- n_jobs : int
- Number of CPUs to use in the computation.
- verbose : bool, str, int, or None
- If not None, override default verbose level (see mne.verbose).
-
- Returns
- -------
- psds : ndarray (n_epochs, n_channels, n_freqs)
- The power spectral densities.
- freqs : ndarray, shape (n_freqs,)
- The frequencies.
-
- See Also
- --------
- psd_welch, psd_multitaper
- """
- from scipy.signal import welch
- from ..epochs import _BaseEpochs
- if not isinstance(epochs, _BaseEpochs):
- raise ValueError("Input must be an instance of Epochs")
- n_fft = int(n_fft)
- Fs = epochs.info['sfreq']
- if picks is None:
- picks = _pick_data_channels(epochs.info, with_ref_meg=False)
- n_fft, n_overlap = _check_nfft(len(epochs.times), n_fft, n_overlap)
-
- if tmin is not None or tmax is not None:
- time_mask = _time_mask(epochs.times, tmin, tmax,
- sfreq=epochs.info['sfreq'])
- else:
- time_mask = slice(None)
- if proj:
- # Copy first so it's not modified
- epochs = epochs.copy().apply_proj()
- data = epochs.get_data()[:, picks][:, :, time_mask]
-
- logger.info("Effective window size : %0.3f (s)" % (n_fft / float(Fs)))
-
- freqs = np.arange(n_fft // 2 + 1, dtype=float) * (Fs / n_fft)
- freq_mask = (freqs >= fmin) & (freqs <= fmax)
- freqs = freqs[freq_mask]
- psds = np.empty(data.shape[:-1] + (freqs.size,))
-
- parallel, my_pwelch, n_jobs = parallel_func(_pwelch, n_jobs=n_jobs,
- verbose=verbose)
-
- for idx, fepochs in zip(np.array_split(np.arange(len(data)), n_jobs),
- parallel(my_pwelch(epoch, noverlap=n_overlap,
- nfft=n_fft, fs=Fs,
- freq_mask=freq_mask,
- welch_fun=welch)
- for epoch in np.array_split(data,
- n_jobs))):
- for i_epoch, f_epoch in zip(idx, fepochs):
- psds[i_epoch, :, :] = f_epoch
-
- return psds, freqs
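
Likewise, the removed compute_raw_psd and compute_epochs_psd are both covered
by psd_welch, which accepts Raw, Epochs, or Evoked; a sketch with illustrative
parameters:

    from mne.time_frequency import psd_welch

    # was: compute_raw_psd(raw, ...) / compute_epochs_psd(epochs, ...)
    psds, freqs = psd_welch(raw, fmin=0., fmax=50., n_fft=2048)
    psds, freqs = psd_welch(epochs, fmin=0., fmax=50., n_fft=256)
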
diff --git a/mne/time_frequency/tests/test_ar.py b/mne/time_frequency/tests/test_ar.py
index 146e2fd..15d128d 100644
--- a/mne/time_frequency/tests/test_ar.py
+++ b/mne/time_frequency/tests/test_ar.py
@@ -15,8 +15,7 @@ raw_fname = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data',
@requires_patsy
@requires_statsmodels
def test_yule_walker():
- """Test Yule-Walker against statsmodels
- """
+ """Test Yule-Walker against statsmodels."""
from statsmodels.regression.linear_model import yule_walker as sm_yw
d = np.random.randn(100)
sm_rho, sm_sigma = sm_yw(d, order=2)
@@ -26,9 +25,8 @@ def test_yule_walker():
def test_ar_raw():
- """Test fitting AR model on raw data
- """
- raw = io.read_raw_fif(raw_fname)
+ """Test fitting AR model on raw data."""
+ raw = io.read_raw_fif(raw_fname, add_eeg_ref=False)
# pick MEG gradiometers
picks = pick_types(raw.info, meg='grad', exclude='bads')
picks = picks[:2]
diff --git a/mne/time_frequency/tests/test_csd.py b/mne/time_frequency/tests/test_csd.py
index 753b191..61dfae2 100644
--- a/mne/time_frequency/tests/test_csd.py
+++ b/mne/time_frequency/tests/test_csd.py
@@ -6,9 +6,9 @@ import warnings
import mne
-from mne.io import Raw
+from mne.io import read_raw_fif
from mne.utils import sum_squared
-from mne.time_frequency import compute_epochs_csd, tfr_morlet
+from mne.time_frequency import csd_epochs, csd_array, tfr_morlet
warnings.simplefilter('always')
base_dir = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data')
@@ -16,49 +16,53 @@ raw_fname = op.join(base_dir, 'test_raw.fif')
event_fname = op.join(base_dir, 'test-eve.fif')
-def _get_data():
- # Read raw data
- raw = Raw(raw_fname)
- raw.info['bads'] = ['MEG 2443', 'EEG 053'] # 2 bads channels
-
- # Set picks
- picks = mne.pick_types(raw.info, meg=True, eeg=False, eog=False,
- stim=False, exclude='bads')
-
- # Read several epochs
- event_id, tmin, tmax = 1, -0.2, 0.5
+def _get_data(mode='real'):
+ """Get data."""
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
events = mne.read_events(event_fname)[0:100]
- epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True,
- picks=picks, baseline=(None, 0), preload=True,
- reject=dict(grad=4000e-13, mag=4e-12))
+ if mode == 'real':
+ # Read raw data
+ raw.info['bads'] = ['MEG 2443', 'EEG 053'] # 2 bads channels
+
+ # Set picks
+ picks = mne.pick_types(raw.info, meg=True, eeg=False, eog=False,
+ stim=False, exclude='bads')
- # Create an epochs object with one epoch and one channel of artificial data
- event_id, tmin, tmax = 1, 0.0, 1.0
- epochs_sin = mne.Epochs(raw, events[0:5], event_id, tmin, tmax, proj=True,
+ # Read several epochs
+ event_id, tmin, tmax = 1, -0.2, 0.5
+ epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True,
+ picks=picks, baseline=(None, 0), preload=True,
+ reject=dict(grad=4000e-13, mag=4e-12),
+ add_eeg_ref=False)
+ elif mode == 'sin':
+ # Create an epochs object with one epoch and one channel of artificial
+ # data
+ event_id, tmin, tmax = 1, 0.0, 1.0
+ epochs = mne.Epochs(raw, events[0:5], event_id, tmin, tmax, proj=True,
picks=[0], baseline=(None, 0), preload=True,
- reject=dict(grad=4000e-13))
- freq = 10
- epochs_sin._data = np.sin(2 * np.pi * freq *
- epochs_sin.times)[None, None, :]
- return epochs, epochs_sin
+ reject=dict(grad=4000e-13), add_eeg_ref=False)
+ freq = 10
+ epochs._data = np.sin(2 * np.pi * freq *
+ epochs.times)[None, None, :]
+ return epochs
-def test_compute_epochs_csd():
- """Test computing cross-spectral density from epochs
- """
- epochs, epochs_sin = _get_data()
+
+def test_csd_epochs():
+ """Test computing cross-spectral density from epochs."""
+ epochs = _get_data(mode='real')
# Check that wrong parameters are recognized
- assert_raises(ValueError, compute_epochs_csd, epochs, mode='notamode')
- assert_raises(ValueError, compute_epochs_csd, epochs, fmin=20, fmax=10)
- assert_raises(ValueError, compute_epochs_csd, epochs, fmin=20, fmax=20.1)
- assert_raises(ValueError, compute_epochs_csd, epochs, tmin=0.15, tmax=0.1)
- assert_raises(ValueError, compute_epochs_csd, epochs, tmin=0, tmax=10)
- assert_raises(ValueError, compute_epochs_csd, epochs, tmin=10, tmax=11)
-
- data_csd_mt = compute_epochs_csd(epochs, mode='multitaper', fmin=8,
- fmax=12, tmin=0.04, tmax=0.15)
- data_csd_fourier = compute_epochs_csd(epochs, mode='fourier', fmin=8,
- fmax=12, tmin=0.04, tmax=0.15)
+ assert_raises(ValueError, csd_epochs, epochs, mode='notamode')
+ assert_raises(ValueError, csd_epochs, epochs, fmin=20, fmax=10)
+ assert_raises(ValueError, csd_epochs, epochs, fmin=20, fmax=20.1)
+ assert_raises(ValueError, csd_epochs, epochs, tmin=0.15, tmax=0.1)
+ assert_raises(ValueError, csd_epochs, epochs, tmin=0, tmax=10)
+ assert_raises(ValueError, csd_epochs, epochs, tmin=10, tmax=11)
+
+ data_csd_mt = csd_epochs(epochs, mode='multitaper', fmin=8, fmax=12,
+ tmin=0.04, tmax=0.15)
+ data_csd_fourier = csd_epochs(epochs, mode='fourier', fmin=8, fmax=12,
+ tmin=0.04, tmax=0.15)
# Check shape of the CSD matrix
n_chan = len(data_csd_mt.ch_names)
@@ -84,73 +88,70 @@ def test_compute_epochs_csd():
assert_equal(max_ch_fourier, max_ch_power)
# Maximum CSD should occur for specific channel
- ch_csd_mt = [np.abs(data_csd_mt.data[max_ch_power][i])
- if i != max_ch_power else 0 for i in range(n_chan)]
+ ch_csd_mt = np.abs(data_csd_mt.data[max_ch_power])
+ ch_csd_mt[max_ch_power] = 0.
max_ch_csd_mt = np.argmax(ch_csd_mt)
- ch_csd_fourier = [np.abs(data_csd_fourier.data[max_ch_power][i])
- if i != max_ch_power else 0 for i in range(n_chan)]
+ ch_csd_fourier = np.abs(data_csd_fourier.data[max_ch_power])
+ ch_csd_fourier[max_ch_power] = 0.
max_ch_csd_fourier = np.argmax(ch_csd_fourier)
assert_equal(max_ch_csd_mt, max_ch_csd_fourier)
# Check a list of CSD matrices is returned for multiple frequencies within
# a given range when fsum=False
- csd_fsum = compute_epochs_csd(epochs, mode='fourier', fmin=8, fmax=20,
- fsum=True)
- csds = compute_epochs_csd(epochs, mode='fourier', fmin=8, fmax=20,
- fsum=False)
+ csd_fsum = csd_epochs(epochs, mode='fourier', fmin=8, fmax=20, fsum=True)
+ csds = csd_epochs(epochs, mode='fourier', fmin=8, fmax=20, fsum=False)
freqs = [csd.frequencies[0] for csd in csds]
csd_sum = np.zeros_like(csd_fsum.data)
for csd in csds:
csd_sum += csd.data
- assert(len(csds) == 2)
- assert(len(csd_fsum.frequencies) == 2)
+ assert_equal(len(csds), 2)
+ assert_equal(len(csd_fsum.frequencies), 2)
assert_array_equal(csd_fsum.frequencies, freqs)
assert_array_equal(csd_fsum.data, csd_sum)
-def test_compute_epochs_csd_on_artificial_data():
- """Test computing CSD on artificial data
- """
- epochs, epochs_sin = _get_data()
- sfreq = epochs_sin.info['sfreq']
+def test_csd_epochs_on_artificial_data():
+ """Test computing CSD on artificial data."""
+ epochs = _get_data(mode='sin')
+ sfreq = epochs.info['sfreq']
# Computing signal power in the time domain
- signal_power = sum_squared(epochs_sin._data)
- signal_power_per_sample = signal_power / len(epochs_sin.times)
+ signal_power = sum_squared(epochs._data)
+ signal_power_per_sample = signal_power / len(epochs.times)
# Computing signal power in the frequency domain
- data_csd_fourier = compute_epochs_csd(epochs_sin, mode='fourier')
- data_csd_mt = compute_epochs_csd(epochs_sin, mode='multitaper')
+ data_csd_fourier = csd_epochs(epochs, mode='fourier')
+ data_csd_mt = csd_epochs(epochs, mode='multitaper')
fourier_power = np.abs(data_csd_fourier.data[0, 0]) * sfreq
mt_power = np.abs(data_csd_mt.data[0, 0]) * sfreq
assert_true(abs(fourier_power - signal_power) <= 0.5)
assert_true(abs(mt_power - signal_power) <= 1)
# Power per sample should not depend on time window length
- for tmax in [0.2, 0.4, 0.6, 0.8]:
- for add_n_fft in [30, 0, 30]:
- t_mask = (epochs_sin.times >= 0) & (epochs_sin.times <= tmax)
+ for tmax in [0.2, 0.8]:
+ for add_n_fft in [0, 30]:
+ t_mask = (epochs.times >= 0) & (epochs.times <= tmax)
n_samples = sum(t_mask)
n_fft = n_samples + add_n_fft
- data_csd_fourier = compute_epochs_csd(epochs_sin, mode='fourier',
- tmin=None, tmax=tmax, fmin=0,
- fmax=np.inf, n_fft=n_fft)
- fourier_power_per_sample = np.abs(data_csd_fourier.data[0, 0]) *\
- sfreq / data_csd_fourier.n_fft
+ data_csd_fourier = csd_epochs(epochs, mode='fourier',
+ tmin=None, tmax=tmax, fmin=0,
+ fmax=np.inf, n_fft=n_fft)
+ first_samp = data_csd_fourier.data[0, 0]
+ fourier_power_per_sample = np.abs(first_samp) * sfreq / n_fft
assert_true(abs(signal_power_per_sample -
fourier_power_per_sample) < 0.003)
# Power per sample should not depend on number of tapers
- for n_tapers in [1, 2, 3, 5]:
- for add_n_fft in [30, 0, 30]:
+ for n_tapers in [1, 2, 5]:
+ for add_n_fft in [0, 30]:
mt_bandwidth = sfreq / float(n_samples) * (n_tapers + 1)
- data_csd_mt = compute_epochs_csd(epochs_sin, mode='multitaper',
- tmin=None, tmax=tmax, fmin=0,
- fmax=np.inf,
- mt_bandwidth=mt_bandwidth,
- n_fft=n_fft)
+ data_csd_mt = csd_epochs(epochs, mode='multitaper',
+ tmin=None, tmax=tmax, fmin=0,
+ fmax=np.inf,
+ mt_bandwidth=mt_bandwidth,
+ n_fft=n_fft)
mt_power_per_sample = np.abs(data_csd_mt.data[0, 0]) *\
sfreq / data_csd_mt.n_fft
# The estimate of power gets worse for small time windows when
@@ -161,3 +162,157 @@ def test_compute_epochs_csd_on_artificial_data():
delta = 0.004
assert_true(abs(signal_power_per_sample -
mt_power_per_sample) < delta)
+
+
+def test_compute_csd():
+ """Test computing cross-spectral density from ndarray."""
+ epochs = _get_data(mode='real')
+
+ tmin = 0.04
+ tmax = 0.15
+ tmp = np.where(np.logical_and(epochs.times >= tmin,
+ epochs.times <= tmax))[0]
+
+ picks_meeg = mne.pick_types(epochs[0].info, meg=True, eeg=True, eog=False,
+ ref_meg=False, exclude='bads')
+
+ epochs_data = [e[picks_meeg][:, tmp].copy() for e in epochs]
+ n_trials = len(epochs)
+ n_series = len(picks_meeg)
+ X = np.concatenate(epochs_data, axis=0)
+ X = np.reshape(X, (n_trials, n_series, -1))
+ X_list = epochs_data
+
+ sfreq = epochs.info['sfreq']
+
+ # Check data types and sizes are checked
+ diff_types = [np.random.randn(3, 5), "error"]
+ err_data = [np.random.randn(3, 5), np.random.randn(2, 4)]
+ assert_raises(ValueError, csd_array, err_data, sfreq)
+ assert_raises(ValueError, csd_array, diff_types, sfreq)
+ assert_raises(ValueError, csd_array, np.random.randn(3), sfreq)
+
+ # Check that wrong parameters are recognized
+ assert_raises(ValueError, csd_array, X, sfreq, mode='notamode')
+ assert_raises(ValueError, csd_array, X, sfreq, fmin=20, fmax=10)
+ assert_raises(ValueError, csd_array, X, sfreq, fmin=20, fmax=20.1)
+
+ data_csd_mt, freqs_mt = csd_array(X, sfreq, mode='multitaper',
+ fmin=8, fmax=12)
+ data_csd_fourier, freqs_fft = csd_array(X, sfreq, mode='fourier',
+ fmin=8, fmax=12)
+
+ # Test as list too
+ data_csd_mt_list, freqs_mt_list = csd_array(X_list, sfreq,
+ mode='multitaper',
+ fmin=8, fmax=12)
+ data_csd_fourier_list, freqs_fft_list = csd_array(X_list, sfreq,
+ mode='fourier',
+ fmin=8, fmax=12)
+
+ assert_array_equal(data_csd_mt, data_csd_mt_list)
+ assert_array_equal(data_csd_fourier, data_csd_fourier_list)
+ assert_array_equal(freqs_mt, freqs_mt_list)
+ assert_array_equal(freqs_fft, freqs_fft_list)
+
+ # Check shape of the CSD matrix
+ n_chan = len(epochs.ch_names)
+ assert_equal(data_csd_mt.shape, (n_chan, n_chan))
+ assert_equal(data_csd_fourier.shape, (n_chan, n_chan))
+
+ # Check if the CSD matrix is hermitian
+ assert_array_equal(np.tril(data_csd_mt).T.conj(),
+ np.triu(data_csd_mt))
+ assert_array_equal(np.tril(data_csd_fourier).T.conj(),
+ np.triu(data_csd_fourier))
+
+ # Computing induced power for comparison
+ epochs.crop(tmin=0.04, tmax=0.15)
+ tfr = tfr_morlet(epochs, freqs=[10], n_cycles=0.6, return_itc=False)
+ power = np.mean(tfr.data, 2)
+
+ # Maximum PSD should occur for specific channel
+ max_ch_power = power.argmax()
+ max_ch_mt = data_csd_mt.diagonal().argmax()
+ max_ch_fourier = data_csd_fourier.diagonal().argmax()
+ assert_equal(max_ch_mt, max_ch_power)
+ assert_equal(max_ch_fourier, max_ch_power)
+
+ # Maximum CSD should occur for specific channel
+ ch_csd_mt = np.abs(data_csd_mt[max_ch_power])
+ ch_csd_mt[max_ch_power] = 0.
+ max_ch_csd_mt = np.argmax(ch_csd_mt)
+ ch_csd_fourier = np.abs(data_csd_fourier[max_ch_power])
+ ch_csd_fourier[max_ch_power] = 0.
+ max_ch_csd_fourier = np.argmax(ch_csd_fourier)
+ assert_equal(max_ch_csd_mt, max_ch_csd_fourier)
+
+ # Check a list of CSD matrices is returned for multiple frequencies within
+ # a given range when fsum=False
+ csd_fsum, freqs_fsum = csd_array(X, sfreq, mode='fourier', fmin=8,
+ fmax=20, fsum=True)
+ csds, freqs = csd_array(X, sfreq, mode='fourier', fmin=8, fmax=20,
+ fsum=False)
+
+ csd_sum = np.sum(csds, axis=2)
+
+ assert_equal(csds.shape[2], 2)
+ assert_equal(len(freqs), 2)
+ assert_array_equal(freqs_fsum, freqs)
+ assert_array_equal(csd_fsum, csd_sum)
+
+
+def test_csd_on_artificial_data():
+ """Test computing CSD on artificial data. """
+ epochs = _get_data(mode='sin')
+ sfreq = epochs.info['sfreq']
+
+ # Computing signal power in the time domain
+ signal_power = sum_squared(epochs._data)
+ signal_power_per_sample = signal_power / len(epochs.times)
+
+ # Computing signal power in the frequency domain
+ data_csd_mt, freqs_mt = csd_array(epochs._data, sfreq,
+ mode='multitaper')
+ data_csd_fourier, freqs_fft = csd_array(epochs._data, sfreq,
+ mode='fourier')
+
+ fourier_power = np.abs(data_csd_fourier[0, 0]) * sfreq
+ mt_power = np.abs(data_csd_mt[0, 0]) * sfreq
+ assert_true(abs(fourier_power - signal_power) <= 0.5)
+ assert_true(abs(mt_power - signal_power) <= 1)
+
+ # Power per sample should not depend on time window length
+ for tmax in [0.2, 0.8]:
+ tslice = np.where(epochs.times <= tmax)[0]
+
+ for add_n_fft in [0, 30]:
+ t_mask = (epochs.times >= 0) & (epochs.times <= tmax)
+ n_samples = sum(t_mask)
+ n_fft = n_samples + add_n_fft
+
+ data_csd_fourier, _ = csd_array(epochs._data[:, :, tslice],
+ sfreq, mode='fourier',
+ fmin=0, fmax=np.inf, n_fft=n_fft)
+
+ first_samp = data_csd_fourier[0, 0]
+ fourier_power_per_sample = np.abs(first_samp) * sfreq / n_fft
+ assert_true(abs(signal_power_per_sample -
+ fourier_power_per_sample) < 0.003)
+ # Power per sample should not depend on number of tapers
+ for n_tapers in [1, 2, 5]:
+ for add_n_fft in [0, 30]:
+ mt_bandwidth = sfreq / float(n_samples) * (n_tapers + 1)
+ data_csd_mt, _ = csd_array(epochs._data[:, :, tslice],
+ sfreq, mt_bandwidth=mt_bandwidth,
+ n_fft=n_fft)
+ mt_power_per_sample = np.abs(data_csd_mt[0, 0]) *\
+ sfreq / n_fft
+ # The estimate of power gets worse for small time windows when
+ # more tapers are used
+ if n_tapers == 5 and tmax == 0.2:
+ delta = 0.05
+ else:
+ delta = 0.004
+ assert_true(abs(signal_power_per_sample -
+ mt_power_per_sample) < delta)
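
These artificial-data tests rest on a Parseval-style identity: the CSD
diagonal scaled by sfreq should approximate the time-domain power. A
standalone sketch with a hypothetical 10 Hz sine:

    import numpy as np
    from mne.time_frequency import csd_array
    from mne.utils import sum_squared

    sfreq = 1000.
    t = np.arange(int(sfreq)) / sfreq
    x = np.sin(2 * np.pi * 10. * t)[np.newaxis, np.newaxis, :]
    csd, freqs = csd_array(x, sfreq, mode='fourier')
    # the tests above require agreement to within roughly 0.5
    print(np.abs(csd[0, 0]) * sfreq, sum_squared(x))
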
diff --git a/mne/time_frequency/tests/test_psd.py b/mne/time_frequency/tests/test_psd.py
index d950d25..25b2356 100644
--- a/mne/time_frequency/tests/test_psd.py
+++ b/mne/time_frequency/tests/test_psd.py
@@ -1,14 +1,12 @@
import numpy as np
-import warnings
import os.path as op
from numpy.testing import assert_array_almost_equal, assert_raises
from nose.tools import assert_true
-from mne import io, pick_types, Epochs, read_events
-from mne.io import RawArray
+from mne import pick_types, Epochs, read_events
+from mne.io import RawArray, read_raw_fif
from mne.utils import requires_version, slow_test
-from mne.time_frequency import (compute_raw_psd, compute_epochs_psd,
- psd_welch, psd_multitaper)
+from mne.time_frequency import psd_welch, psd_multitaper
base_dir = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data')
raw_fname = op.join(base_dir, 'test_raw.fif')
@@ -17,9 +15,8 @@ event_fname = op.join(base_dir, 'test-eve.fif')
@requires_version('scipy', '0.12')
def test_psd():
- """Tests the welch and multitaper PSD
- """
- raw = io.read_raw_fif(raw_fname)
+ """Tests the welch and multitaper PSD."""
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
picks_psd = [0, 1]
# Populate raw with sinusoids
@@ -41,109 +38,99 @@ def test_psd():
kws_welch = dict(n_fft=n_fft)
kws_mt = dict(low_bias=True)
funcs = [(psd_welch, kws_welch),
- (psd_multitaper, kws_mt),
- (compute_raw_psd, kws_welch)]
-
- with warnings.catch_warnings(record=True) as w:
- warnings.simplefilter('always')
- for func, kws in funcs:
- kws = kws.copy()
- kws.update(kws_psd)
- psds, freqs = func(raw, proj=False, **kws)
- psds_proj, freqs_proj = func(raw, proj=True, **kws)
-
- assert_true(psds.shape == (len(kws['picks']), len(freqs)))
- assert_true(np.sum(freqs < 0) == 0)
- assert_true(np.sum(psds < 0) == 0)
-
- # Is power found where it should be
- ixs_max = np.argmax(psds, axis=1)
- for ixmax, ifreq in zip(ixs_max, freqs_sig):
- # Find nearest frequency to the "true" freq
- ixtrue = np.argmin(np.abs(ifreq - freqs))
- assert_true(np.abs(ixmax - ixtrue) < 2)
-
- # Make sure the projection doesn't change channels it shouldn't
- assert_array_almost_equal(psds, psds_proj)
- # Array input shouldn't work
- assert_raises(ValueError, func, raw[:3, :20][0])
- assert_true(len(w), 3)
+ (psd_multitaper, kws_mt)]
+
+ for func, kws in funcs:
+ kws = kws.copy()
+ kws.update(kws_psd)
+ psds, freqs = func(raw, proj=False, **kws)
+ psds_proj, freqs_proj = func(raw, proj=True, **kws)
+
+ assert_true(psds.shape == (len(kws['picks']), len(freqs)))
+ assert_true(np.sum(freqs < 0) == 0)
+ assert_true(np.sum(psds < 0) == 0)
+
+ # Is power found where it should be
+ ixs_max = np.argmax(psds, axis=1)
+ for ixmax, ifreq in zip(ixs_max, freqs_sig):
+ # Find nearest frequency to the "true" freq
+ ixtrue = np.argmin(np.abs(ifreq - freqs))
+ assert_true(np.abs(ixmax - ixtrue) < 2)
+
+ # Make sure the projection doesn't change channels it shouldn't
+ assert_array_almost_equal(psds, psds_proj)
+ # Array input shouldn't work
+ assert_raises(ValueError, func, raw[:3, :20][0])
# -- Epochs/Evoked --
events = read_events(event_fname)
events[:, 0] -= first_samp
tmin, tmax, event_id = -0.5, 0.5, 1
epochs = Epochs(raw, events[:10], event_id, tmin, tmax, picks=picks_psd,
- proj=False, preload=True, baseline=None)
+ proj=False, preload=True, baseline=None, add_eeg_ref=False)
evoked = epochs.average()
tmin_full, tmax_full = -1, 1
epochs_full = Epochs(raw, events[:10], event_id, tmin_full, tmax_full,
picks=picks_psd, proj=False, preload=True,
- baseline=None)
+ baseline=None, add_eeg_ref=False)
kws_psd = dict(tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax,
picks=picks_psd) # Common to all
funcs = [(psd_welch, kws_welch),
- (psd_multitaper, kws_mt),
- (compute_epochs_psd, kws_welch)]
-
- with warnings.catch_warnings(record=True) as w:
- warnings.simplefilter('always')
- for func, kws in funcs:
- kws = kws.copy()
- kws.update(kws_psd)
-
- psds, freqs = func(
- epochs[:1], proj=False, **kws)
- psds_proj, freqs_proj = func(
- epochs[:1], proj=True, **kws)
- psds_f, freqs_f = func(
- epochs_full[:1], proj=False, **kws)
-
- # this one will fail if you add for example 0.1 to tmin
- assert_array_almost_equal(psds, psds_f, 27)
- # Make sure the projection doesn't change channels it shouldn't
- assert_array_almost_equal(psds, psds_proj, 27)
-
- # Is power found where it should be
- ixs_max = np.argmax(psds.mean(0), axis=1)
- for ixmax, ifreq in zip(ixs_max, freqs_sig):
- # Find nearest frequency to the "true" freq
- ixtrue = np.argmin(np.abs(ifreq - freqs))
- assert_true(np.abs(ixmax - ixtrue) < 2)
- assert_true(psds.shape == (1, len(kws['picks']), len(freqs)))
- assert_true(np.sum(freqs < 0) == 0)
- assert_true(np.sum(psds < 0) == 0)
-
- # Array input shouldn't work
- assert_raises(ValueError, func, epochs.get_data())
-
- if func is not compute_epochs_psd:
- # Testing evoked (doesn't work w/ compute_epochs_psd)
- psds_ev, freqs_ev = func(
- evoked, proj=False, **kws)
- psds_ev_proj, freqs_ev_proj = func(
- evoked, proj=True, **kws)
-
- # Is power found where it should be
- ixs_max = np.argmax(psds_ev, axis=1)
- for ixmax, ifreq in zip(ixs_max, freqs_sig):
- # Find nearest frequency to the "true" freq
- ixtrue = np.argmin(np.abs(ifreq - freqs_ev))
- assert_true(np.abs(ixmax - ixtrue) < 2)
-
- # Make sure the projection doesn't change channels it shouldn't
- assert_array_almost_equal(psds_ev, psds_ev_proj, 27)
- assert_true(psds_ev.shape == (len(kws['picks']), len(freqs)))
- assert_true(len(w), 3)
+ (psd_multitaper, kws_mt)]
+
+ for func, kws in funcs:
+ kws = kws.copy()
+ kws.update(kws_psd)
+
+ psds, freqs = func(
+ epochs[:1], proj=False, **kws)
+ psds_proj, freqs_proj = func(
+ epochs[:1], proj=True, **kws)
+ psds_f, freqs_f = func(
+ epochs_full[:1], proj=False, **kws)
+
+ # this one will fail if you add for example 0.1 to tmin
+ assert_array_almost_equal(psds, psds_f, 27)
+ # Make sure the projection doesn't change channels it shouldn't
+ assert_array_almost_equal(psds, psds_proj, 27)
+
+ # Is power found where it should be
+ ixs_max = np.argmax(psds.mean(0), axis=1)
+ for ixmax, ifreq in zip(ixs_max, freqs_sig):
+ # Find nearest frequency to the "true" freq
+ ixtrue = np.argmin(np.abs(ifreq - freqs))
+ assert_true(np.abs(ixmax - ixtrue) < 2)
+ assert_true(psds.shape == (1, len(kws['picks']), len(freqs)))
+ assert_true(np.sum(freqs < 0) == 0)
+ assert_true(np.sum(psds < 0) == 0)
+
+ # Array input shouldn't work
+ assert_raises(ValueError, func, epochs.get_data())
+
+        # Testing evoked
+ psds_ev, freqs_ev = func(
+ evoked, proj=False, **kws)
+ psds_ev_proj, freqs_ev_proj = func(
+ evoked, proj=True, **kws)
+
+ # Is power found where it should be
+ ixs_max = np.argmax(psds_ev, axis=1)
+ for ixmax, ifreq in zip(ixs_max, freqs_sig):
+ # Find nearest frequency to the "true" freq
+ ixtrue = np.argmin(np.abs(ifreq - freqs_ev))
+ assert_true(np.abs(ixmax - ixtrue) < 2)
+
+ # Make sure the projection doesn't change channels it shouldn't
+ assert_array_almost_equal(psds_ev, psds_ev_proj, 27)
+ assert_true(psds_ev.shape == (len(kws['picks']), len(freqs)))
@slow_test
@requires_version('scipy', '0.12')
def test_compares_psd():
- """Test PSD estimation on raw for plt.psd and scipy.signal.welch
- """
- raw = io.read_raw_fif(raw_fname)
+ """Test PSD estimation on raw for plt.psd and scipy.signal.welch."""
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
exclude = raw.info['bads'] + ['MEG 2443', 'EEG 053'] # bads + 2 more
diff --git a/mne/time_frequency/tests/test_stockwell.py b/mne/time_frequency/tests/test_stockwell.py
index 65ffa92..9eb1c50 100644
--- a/mne/time_frequency/tests/test_stockwell.py
+++ b/mne/time_frequency/tests/test_stockwell.py
@@ -12,29 +12,45 @@ from numpy.testing import assert_array_almost_equal, assert_allclose
from scipy import fftpack
-from mne import io, read_events, Epochs, pick_types
+from mne import read_events, Epochs
+from mne.io import read_raw_fif
from mne.time_frequency._stockwell import (tfr_stockwell, _st,
- _precompute_st_windows)
+ _precompute_st_windows,
+ _check_input_st,
+ _st_power_itc)
+
from mne.time_frequency.tfr import AverageTFR
base_dir = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data')
raw_fname = op.join(base_dir, 'test_raw.fif')
-event_id, tmin, tmax = 1, -0.2, 0.5
-event_id_2 = 2
-raw = io.read_raw_fif(raw_fname, add_eeg_ref=False)
-event_name = op.join(base_dir, 'test-eve.fif')
-events = read_events(event_name)
-picks = pick_types(raw.info, meg=True, eeg=True, stim=True,
- ecg=True, eog=True, include=['STI 014'],
- exclude='bads')
-reject = dict(grad=1000e-12, mag=4e-12, eeg=80e-6, eog=150e-6)
-flat = dict(grad=1e-15, mag=1e-15)
+def test_stockwell_check_input():
+ """Test input checker for stockwell"""
+ # check for data size equal and unequal to a power of 2
+
+ for last_dim in (127, 128):
+ data = np.zeros((2, 10, last_dim))
+ x_in, n_fft, zero_pad = _check_input_st(data, None)
+
+ assert_equal(x_in.shape, (2, 10, 128))
+ assert_equal(n_fft, 128)
+ assert_equal(zero_pad, 128 - last_dim)
+
+
+def test_stockwell_st_no_zero_pad():
+ """Test stockwell power itc"""
+ data = np.zeros((20, 128))
+ start_f = 1
+ stop_f = 10
+ sfreq = 30
+ width = 2
+ W = _precompute_st_windows(data.shape[-1], start_f, stop_f, sfreq, width)
+ _st_power_itc(data, 10, True, 0, 1, W)
def test_stockwell_core():
- """Test stockwell transform"""
+ """Test stockwell transform."""
# adapted from
# http://vcs.ynic.york.ac.uk/docs/naf/intro/concepts/timefreq.html
sfreq = 1000.0 # make things easy to understand
@@ -75,9 +91,14 @@ def test_stockwell_core():
def test_stockwell_api():
- """Test stockwell functions"""
+ """Test stockwell functions."""
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
+ event_id, tmin, tmax = 1, -0.2, 0.5
+ event_name = op.join(base_dir, 'test-eve.fif')
+ events = read_events(event_name)
epochs = Epochs(raw, events, # XXX pick 2 has epochs of zeros.
- event_id, tmin, tmax, picks=[0, 1, 3], baseline=(None, 0))
+ event_id, tmin, tmax, picks=[0, 1, 3], baseline=(None, 0),
+ add_eeg_ref=False)
for fmin, fmax in [(None, 50), (5, 50), (5, None)]:
with warnings.catch_warnings(record=True): # zero padding
power, itc = tfr_stockwell(epochs, fmin=fmin, fmax=fmax,
diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py
index 532ebb7..c39f672 100644
--- a/mne/time_frequency/tests/test_tfr.py
+++ b/mne/time_frequency/tests/test_tfr.py
@@ -4,15 +4,17 @@ from numpy.testing import assert_array_almost_equal, assert_array_equal
from nose.tools import assert_true, assert_false, assert_equal, assert_raises
import mne
-from mne import io, Epochs, read_events, pick_types, create_info, EpochsArray
+from mne import Epochs, read_events, pick_types, create_info, EpochsArray
+from mne.io import read_raw_fif
from mne.utils import (_TempDir, run_tests_if_main, slow_test, requires_h5py,
grand_average)
from mne.time_frequency import single_trial_power
from mne.time_frequency.tfr import (cwt_morlet, morlet, tfr_morlet,
- _dpss_wavelet, tfr_multitaper,
+ _make_dpss, tfr_multitaper, rescale,
AverageTFR, read_tfrs, write_tfrs,
- combine_tfr, cwt)
-
+ combine_tfr, cwt, _compute_tfr)
+from mne.viz.utils import _fake_click
+from itertools import product
import matplotlib
matplotlib.use('Agg') # for testing don't use X server
@@ -23,7 +25,7 @@ event_fname = op.join(op.dirname(__file__), '..', '..', 'io', 'tests',
def test_morlet():
- """Test morlet with and without zero mean"""
+ """Test morlet with and without zero mean."""
Wz = morlet(1000, [10], 2., zero_mean=True)
W = morlet(1000, [10], 2., zero_mean=False)
@@ -32,15 +34,14 @@ def test_morlet():
def test_time_frequency():
- """Test time frequency transform (PSD and phase lock)
- """
+ """Test the to-be-deprecated time-frequency transform (PSD and ITC)."""
# Set parameters
event_id = 1
tmin = -0.2
tmax = 0.498 # Allows exhaustive decimation testing
# Setup for reading the raw data
- raw = io.read_raw_fif(raw_fname)
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
events = read_events(event_fname)
include = []
@@ -52,13 +53,13 @@ def test_time_frequency():
picks = picks[:2]
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- baseline=(None, 0))
+ baseline=(None, 0), add_eeg_ref=False)
data = epochs.get_data()
times = epochs.times
nave = len(data)
epochs_nopicks = Epochs(raw, events, event_id, tmin, tmax,
- baseline=(None, 0))
+ baseline=(None, 0), add_eeg_ref=False)
freqs = np.arange(6, 20, 5) # define frequencies of interest
n_cycles = freqs / 4.
@@ -75,21 +76,32 @@ def test_time_frequency():
use_fft=True, return_itc=True)
power_, itc_ = tfr_morlet(epochs, freqs=freqs, n_cycles=n_cycles,
use_fft=True, return_itc=True, decim=slice(0, 2))
- # Test picks argument
- power_picks, itc_picks = tfr_morlet(epochs_nopicks, freqs=freqs,
- n_cycles=n_cycles, use_fft=True,
- return_itc=True, picks=picks)
+ # Test picks argument and average parameter
+ assert_raises(ValueError, tfr_morlet, epochs, freqs=freqs,
+ n_cycles=n_cycles, return_itc=True, average=False)
+
+ power_picks, itc_picks = \
+ tfr_morlet(epochs_nopicks,
+ freqs=freqs, n_cycles=n_cycles, use_fft=True,
+ return_itc=True, picks=picks, average=True)
+
+ epochs_power_picks = \
+ tfr_morlet(epochs_nopicks,
+ freqs=freqs, n_cycles=n_cycles, use_fft=True,
+ return_itc=False, picks=picks, average=False)
+ power_picks_avg = epochs_power_picks.average()
# the actual data arrays here are equivalent, too...
assert_array_almost_equal(power.data, power_picks.data)
+ assert_array_almost_equal(power.data, power_picks_avg.data)
assert_array_almost_equal(itc.data, itc_picks.data)
assert_array_almost_equal(power.data, power_evoked.data)
print(itc) # test repr
print(itc.ch_names) # test property
itc += power # test add
- itc -= power # test add
+ itc -= power # test sub
- power.apply_baseline(baseline=(-0.1, 0), mode='logratio')
+ power = power.apply_baseline(baseline=(-0.1, 0), mode='logratio')
assert_true('meg' in power)
assert_true('grad' in power)
@@ -209,10 +221,10 @@ def test_time_frequency():
def test_dpsswavelet():
- """Test DPSS wavelet"""
+ """Test DPSS tapers."""
freqs = np.arange(5, 25, 3)
- Ws = _dpss_wavelet(1000, freqs=freqs, n_cycles=freqs / 2.,
- time_bandwidth=4.0, zero_mean=True)
+ Ws = _make_dpss(1000, freqs=freqs, n_cycles=freqs / 2., time_bandwidth=4.0,
+ zero_mean=True)
assert_true(len(Ws) == 3) # 3 tapers expected
@@ -224,10 +236,10 @@ def test_dpsswavelet():
@slow_test
def test_tfr_multitaper():
- """Test tfr_multitaper"""
+ """Test tfr_multitaper."""
sfreq = 200.0
- ch_names = ['SIM0001', 'SIM0002', 'SIM0003']
- ch_types = ['grad', 'grad', 'grad']
+ ch_names = ['SIM0001', 'SIM0002']
+ ch_types = ['grad', 'grad']
info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
n_times = int(sfreq) # Second long epochs
@@ -252,7 +264,8 @@ def test_tfr_multitaper():
epochs = EpochsArray(data=dat, info=info, events=events, event_id=event_id,
reject=reject)
- freqs = np.arange(5, 100, 3, dtype=np.float)
+ freqs = np.arange(35, 70, 5, dtype=np.float)
+
power, itc = tfr_multitaper(epochs, freqs=freqs, n_cycles=freqs / 2.,
time_bandwidth=4.0)
power2, itc2 = tfr_multitaper(epochs, freqs=freqs, n_cycles=freqs / 2.,
@@ -261,11 +274,27 @@ def test_tfr_multitaper():
power_picks, itc_picks = tfr_multitaper(epochs, freqs=freqs,
n_cycles=freqs / 2.,
time_bandwidth=4.0, picks=picks)
+ power_epochs = tfr_multitaper(epochs, freqs=freqs,
+ n_cycles=freqs / 2., time_bandwidth=4.0,
+ return_itc=False, average=False)
+ power_averaged = power_epochs.average()
power_evoked = tfr_multitaper(epochs.average(), freqs=freqs,
n_cycles=freqs / 2., time_bandwidth=4.0,
- return_itc=False)
+ return_itc=False, average=False).average()
+
+ print(power_evoked) # test repr for EpochsTFR
+
+ assert_raises(ValueError, tfr_multitaper, epochs,
+ freqs=freqs, n_cycles=freqs / 2.,
+ return_itc=True, average=False)
+
# test picks argument
assert_array_almost_equal(power.data, power_picks.data)
+ assert_array_almost_equal(power.data, power_averaged.data)
+ assert_array_almost_equal(power.times, power_epochs.times)
+ assert_array_almost_equal(power.times, power_averaged.times)
+ assert_equal(power.nave, power_averaged.nave)
+ assert_equal(power_epochs.data.shape, (3, 2, 7, 200))
assert_array_almost_equal(itc.data, itc_picks.data)
# one is squared magnitude of the average (evoked) and
# the other is average of the squared magnitudes (epochs PSD)
@@ -291,7 +320,7 @@ def test_tfr_multitaper():
def test_crop():
- """Test TFR cropping"""
+ """Test TFR cropping."""
data = np.zeros((3, 2, 3))
times = np.array([.1, .2, .3])
freqs = np.array([.10, .20])
@@ -306,7 +335,7 @@ def test_crop():
@requires_h5py
def test_io():
- """Test TFR IO capacities"""
+ """Test TFR IO capacities."""
tempdir = _TempDir()
fname = op.join(tempdir, 'test-tfr.h5')
@@ -375,10 +404,29 @@ def test_plot():
tfr.plot_topo(picks=[1, 2])
plt.close('all')
+ fig = tfr.plot(picks=[1], cmap='RdBu_r') # interactive mode on by default
+ fig.canvas.key_press_event('up')
+ fig.canvas.key_press_event(' ')
+ fig.canvas.key_press_event('down')
+
+ cbar = fig.get_axes()[0].CB # Fake dragging with mouse.
+ ax = cbar.cbar.ax
+ _fake_click(fig, ax, (0.1, 0.1))
+ _fake_click(fig, ax, (0.1, 0.2), kind='motion')
+ _fake_click(fig, ax, (0.1, 0.3), kind='release')
+
+ _fake_click(fig, ax, (0.1, 0.1), button=3)
+ _fake_click(fig, ax, (0.1, 0.2), button=3, kind='motion')
+ _fake_click(fig, ax, (0.1, 0.3), kind='release')
+
+ fig.canvas.scroll_event(0.5, 0.5, -0.5) # scroll down
+ fig.canvas.scroll_event(0.5, 0.5, 0.5) # scroll up
+
+ plt.close('all')
+
def test_add_channels():
- """Test tfr splitting / re-appending channel types
- """
+ """Test tfr splitting / re-appending channel types."""
data = np.zeros((6, 2, 3))
times = np.array([.1, .2, .3])
freqs = np.array([.10, .20])
@@ -412,4 +460,119 @@ def test_add_channels():
assert_raises(AssertionError, tfr_meg.add_channels, tfr_badsf)
+def test_compute_tfr():
+ """Test _compute_tfr function."""
+ # Set parameters
+ event_id = 1
+ tmin = -0.2
+ tmax = 0.498 # Allows exhaustive decimation testing
+
+ # Setup for reading the raw data
+ raw = read_raw_fif(raw_fname, add_eeg_ref=False)
+ events = read_events(event_fname)
+
+ exclude = raw.info['bads'] + ['MEG 2443', 'EEG 053'] # bads + 2 more
+
+ # picks MEG gradiometers
+ picks = pick_types(raw.info, meg='grad', eeg=False,
+ stim=False, include=[], exclude=exclude)
+
+ picks = picks[:2]
+ epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
+ baseline=(None, 0), add_eeg_ref=False)
+ data = epochs.get_data()
+ sfreq = epochs.info['sfreq']
+ freqs = np.arange(10, 20, 3).astype(float)
+
+ # Check all combination of options
+ for method, use_fft, zero_mean, output in product(
+ ('multitaper', 'morlet'), (False, True), (False, True),
+ ('complex', 'power', 'phase',
+ 'avg_power_itc', 'avg_power', 'itc')):
+ # Check exception
+ if (method == 'multitaper') and (output == 'phase'):
+ assert_raises(NotImplementedError, _compute_tfr, data, freqs,
+ sfreq, method=method, output=output)
+ continue
+
+ # Check runs
+ out = _compute_tfr(data, freqs, sfreq, method=method,
+ use_fft=use_fft, zero_mean=zero_mean,
+ n_cycles=2., output=output)
+ # Check shapes
+ shape = np.r_[data.shape[:2], len(freqs), data.shape[2]]
+ if ('avg' in output) or ('itc' in output):
+ assert_array_equal(shape[1:], out.shape)
+ else:
+ assert_array_equal(shape, out.shape)
+
+ # Check types
+ if output in ('complex', 'avg_power_itc'):
+ assert_equal(np.complex, out.dtype)
+ else:
+ assert_equal(np.float, out.dtype)
+ assert_true(np.all(np.isfinite(out)))
+
+ # Check that functions are equivalent to
+ # i) single_trial_power: X, shape (n_signals, n_chans, n_times)
+ old_power = single_trial_power(data, sfreq, freqs, n_cycles=2.)
+ new_power = _compute_tfr(data, freqs, sfreq, n_cycles=2.,
+ method='morlet', output='power')
+ assert_array_almost_equal(old_power, new_power)
+ old_power = single_trial_power(data, sfreq, freqs, n_cycles=2.,
+ times=epochs.times, baseline=(-.100, 0),
+ baseline_mode='ratio')
+ new_power = rescale(new_power, epochs.times, (-.100, 0), 'ratio')
+
+ # ii) cwt_morlet: X, shape (n_signals, n_times)
+ old_complex = cwt_morlet(data[0], sfreq, freqs, n_cycles=2.)
+ new_complex = _compute_tfr(data[[0]], freqs, sfreq, n_cycles=2.,
+ method='morlet', output='complex')
+ assert_array_almost_equal(old_complex, new_complex[0])
+
+ # Check errors on invalid params
+ for _data in (None, 'foo', data[0]):
+ assert_raises(ValueError, _compute_tfr, _data, freqs, sfreq)
+ for _freqs in (None, 'foo', [[0]]):
+ assert_raises(ValueError, _compute_tfr, data, _freqs, sfreq)
+ for _sfreq in (None, 'foo'):
+ assert_raises(ValueError, _compute_tfr, data, freqs, _sfreq)
+ for key in ('output', 'method', 'use_fft', 'decim', 'n_jobs'):
+ for value in (None, 'foo'):
+ kwargs = {key: value} # FIXME pep8
+ assert_raises(ValueError, _compute_tfr, data, freqs, sfreq,
+ **kwargs)
+
+ # No time_bandwidth param in morlet
+ assert_raises(ValueError, _compute_tfr, data, freqs, sfreq,
+ method='morlet', time_bandwidth=1)
+ # No phase in multitaper XXX Check ?
+ assert_raises(NotImplementedError, _compute_tfr, data, freqs, sfreq,
+ method='multitaper', output='phase')
+
+ # Inter-trial coherence tests
+ out = _compute_tfr(data, freqs, sfreq, output='itc', n_cycles=2.)
+ assert_true(np.sum(out >= 1) == 0)
+ assert_true(np.sum(out <= 0) == 0)
+
+ # Check decim shapes
+ # 2: multiple of len(times) even
+ # 3: multiple odd
+ # 8: not multiple, even
+ # 9: not multiple, odd
+ for decim in (2, 3, 8, 9, slice(0, 2), slice(1, 3), slice(2, 4)):
+ _decim = slice(None, None, decim) if isinstance(decim, int) else decim
+ n_time = len(np.arange(data.shape[2])[_decim])
+ shape = np.r_[data.shape[:2], len(freqs), n_time]
+ for method in ('multitaper', 'morlet'):
+ # Single trials
+ out = _compute_tfr(data, freqs, sfreq, method=method,
+ decim=decim, n_cycles=2.)
+ assert_array_equal(shape, out.shape)
+ # Averages
+ out = _compute_tfr(data, freqs, sfreq, method=method,
+ decim=decim, output='avg_power',
+ n_cycles=2.)
+ assert_array_equal(shape[1:], out.shape)
+
run_tests_if_main()
diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py
index d40d136..a3b5876 100644
--- a/mne/time_frequency/tfr.py
+++ b/mne/time_frequency/tfr.py
@@ -1,61 +1,51 @@
-"""A module which implements the time frequency estimation.
+"""A module which implements the time-frequency estimation.
Morlet code inspired by Matlab code from Sheraz Khan & Brainstorm & SPM
"""
# Authors : Alexandre Gramfort <alexandre.gramfort at telecom-paristech.fr>
# Hari Bharadwaj <hari at nmr.mgh.harvard.edu>
# Clement Moutard <clement.moutard at polytechnique.org>
+# Jean-Remi King <jeanremi.king at gmail.com>
#
# License : BSD (3-clause)
from copy import deepcopy
+from functools import partial
from math import sqrt
import numpy as np
from scipy import linalg
-from scipy.fftpack import fftn, ifftn
+from scipy.fftpack import fft, ifft
-from ..fixes import partial
from ..baseline import rescale
from ..parallel import parallel_func
-from ..utils import (logger, verbose, _time_mask, warn, check_fname,
- _check_copy_dep)
+from ..utils import (logger, verbose, _time_mask, check_fname, deprecated,
+ sizeof_fmt)
from ..channels.channels import ContainsMixin, UpdateChannelsMixin
+from ..channels.layout import _pair_grad_sensors
from ..io.pick import pick_info, pick_types
from ..io.meas_info import Info
+from ..utils import SizeMixin
from .multitaper import dpss_windows
from ..viz.utils import figure_nobar, plt_show
from ..externals.h5io import write_hdf5, read_hdf5
from ..externals.six import string_types
-def _get_data(inst, return_itc):
- """Get data from Epochs or Evoked instance as epochs x ch x time"""
- from ..epochs import _BaseEpochs
- from ..evoked import Evoked
- if not isinstance(inst, (_BaseEpochs, Evoked)):
- raise TypeError('inst must be Epochs or Evoked')
- if isinstance(inst, _BaseEpochs):
- data = inst.get_data()
- else:
- if return_itc:
- raise ValueError('return_itc must be False for evoked data')
- data = inst.data[np.newaxis, ...].copy()
- return data
-
+# Make wavelet
-def morlet(sfreq, freqs, n_cycles=7, sigma=None, zero_mean=False):
- """Compute Wavelets for the given frequency range
+def morlet(sfreq, freqs, n_cycles=7.0, sigma=None, zero_mean=False):
+ """Compute Morlet wavelets for the given frequency range.
Parameters
----------
sfreq : float
- Sampling Frequency
+ The sampling frequency.
freqs : array
frequency range of interest (1 x Frequencies)
- n_cycles: float | array of float
+ n_cycles : float | array of float, defaults to 7.0
Number of cycles. Fixed number or one per frequency.
- sigma : float, (optional)
+ sigma : float, defaults to None
It controls the width of the wavelet ie its temporal
resolution. If sigma is None the temporal resolution
is adapted with the frequency like for all wavelet transform.
@@ -63,18 +53,14 @@ def morlet(sfreq, freqs, n_cycles=7, sigma=None, zero_mean=False):
If sigma is fixed the temporal resolution is fixed
like for the short time Fourier transform and the number
of oscillations increases with the frequency.
- zero_mean : bool
- Make sure the wavelet is zero mean
+ zero_mean : bool, defaults to False
+ Make sure the wavelet has a mean of zero.
Returns
-------
Ws : list of array
- Wavelets time series
+ The wavelets time series.
- See Also
- --------
- mne.time_frequency.cwt_morlet : Compute time-frequency decomposition
- with Morlet wavelets
"""
Ws = list()
n_cycles = np.atleast_1d(n_cycles)
@@ -107,29 +93,31 @@ def morlet(sfreq, freqs, n_cycles=7, sigma=None, zero_mean=False):
return Ws
-def _dpss_wavelet(sfreq, freqs, n_cycles=7, time_bandwidth=4.0,
- zero_mean=False):
- """Compute Wavelets for the given frequency range
+def _make_dpss(sfreq, freqs, n_cycles=7., time_bandwidth=4.0, zero_mean=False):
+ """Compute discrete prolate spheroidal sequences (DPSS) tapers for the
+ given frequency range.
Parameters
----------
sfreq : float
- Sampling Frequency.
+ The sampling frequency.
freqs : ndarray, shape (n_freqs,)
The frequencies in Hz.
- n_cycles : float | ndarray, shape (n_freqs,)
+ n_cycles : float | ndarray, shape (n_freqs,), defaults to 7.
The number of cycles globally or for each frequency.
- Defaults to 7.
- time_bandwidth : float, (optional)
+ time_bandwidth : float, defaults to 4.0
Time x Bandwidth product.
The number of good tapers (low-bias) is chosen automatically based on
this to equal floor(time_bandwidth - 1).
Default is 4.0, giving 3 good tapers.
+ zero_mean : bool, defaults to False
+ Make sure the wavelet has a mean of zero.
+
Returns
-------
Ws : list of array
- Wavelets time series
+ The wavelets time series.
"""
Ws = list()
if time_bandwidth < 2.0:
@@ -171,20 +159,35 @@ def _dpss_wavelet(sfreq, freqs, n_cycles=7, time_bandwidth=4.0,
return Ws
-def _centered(arr, newsize):
- """Aux Function to center data"""
- # Return the center newsize portion of the array.
- newsize = np.asarray(newsize)
- currsize = np.array(arr.shape)
- startind = (currsize - newsize) // 2
- endind = startind + newsize
- myslice = [slice(startind[k], endind[k]) for k in range(len(endind))]
- return arr[tuple(myslice)]
-
+# Low level convolution
def _cwt(X, Ws, mode="same", decim=1, use_fft=True):
"""Compute cwt with fft based convolutions or temporal convolutions.
Return a generator over signals.
+
+ Parameters
+ ----------
+ X : array of shape (n_signals, n_times)
+ The data.
+ Ws : list of array
+ Wavelets time series.
+ mode : {'full', 'valid', 'same'}
+ See numpy.convolve.
+ decim : int | slice, defaults to 1
+ To reduce memory usage, decimation factor after time-frequency
+ decomposition.
+ If `int`, returns tfr[..., ::decim].
+ If `slice`, returns tfr[..., decim].
+
+ .. note:: Decimation may create aliasing artifacts.
+
+ use_fft : bool, defaults to True
+ Use the FFT for convolutions or not.
+
+ Returns
+ -------
+ out : array, shape (n_signals, n_freqs, n_time_decim)
+ The time-frequency transform of the signals.
"""
if mode not in ['same', 'valid', 'full']:
raise ValueError("`mode` must be 'same', 'valid' or 'full', "
@@ -211,21 +214,22 @@ def _cwt(X, Ws, mode="same", decim=1, use_fft=True):
fft_Ws = np.empty((n_freqs, fsize), dtype=np.complex128)
for i, W in enumerate(Ws):
if len(W) > n_times:
- raise ValueError('Wavelet is too long for such a short signal. '
- 'Reduce the number of cycles.')
+ raise ValueError('At least one of the wavelets is longer than the '
+ 'signal. Use a longer signal or shorter '
+ 'wavelets.')
if use_fft:
- fft_Ws[i] = fftn(W, [fsize])
+ fft_Ws[i] = fft(W, fsize)
# Make generator looping across signals
tfr = np.zeros((n_freqs, n_times_out), dtype=np.complex128)
for x in X:
if use_fft:
- fft_x = fftn(x, [fsize])
+ fft_x = fft(x, fsize)
# Loop across wavelets
for ii, W in enumerate(Ws):
if use_fft:
- ret = ifftn(fft_x * fft_Ws[ii])[:n_times + W.size - 1]
+ ret = ifft(fft_x * fft_Ws[ii])[:n_times + W.size - 1]
else:
ret = np.convolve(x, W, mode=mode)
@@ -245,6 +249,297 @@ def _cwt(X, Ws, mode="same", decim=1, use_fft=True):
yield tfr
+# Loop of convolution: single trial
+
+
+def _compute_tfr(epoch_data, frequencies, sfreq=1.0, method='morlet',
+ n_cycles=7.0, zero_mean=None, time_bandwidth=None,
+ use_fft=True, decim=1, output='complex', n_jobs=1,
+ verbose=None):
+ """Computes time-frequency transforms.
+
+ Parameters
+ ----------
+ epoch_data : array of shape (n_epochs, n_channels, n_times)
+ The epochs.
+ frequencies : array-like of floats, shape (n_freqs,)
+ The frequencies.
+ sfreq : float | int, defaults to 1.0
+ Sampling frequency of the data.
+ method : 'multitaper' | 'morlet', defaults to 'morlet'
+ The time-frequency method. 'morlet' convolves a Morlet wavelet.
+ 'multitaper' uses Morlet wavelets windowed with multiple DPSS
+ tapers.
+ n_cycles : float | array of float, defaults to 7.0
+ Number of cycles in the Morlet wavelet. Fixed number
+ or one per frequency.
+ zero_mean : bool | None, defaults to None
+ None means True for method='multitaper' and False for method='morlet'.
+ If True, make sure the wavelets have a mean of zero.
+ time_bandwidth : float, defaults to None
+ If None and method=multitaper, will be set to 4.0 (3 tapers).
+ Time x (Full) Bandwidth product. Only applies if
+ method == 'multitaper'. The number of good tapers (low-bias) is
+ chosen automatically based on this to equal floor(time_bandwidth - 1).
+ use_fft : bool, defaults to True
+ Use the FFT for convolutions or not.
+ decim : int | slice, defaults to 1
+ To reduce memory usage, decimation factor after time-frequency
+ decomposition.
+ If `int`, returns tfr[..., ::decim].
+ If `slice`, returns tfr[..., decim].
+
+ .. note::
+ Decimation may create aliasing artifacts, yet decimation
+ is done after the convolutions.
+
+ output : str, defaults to 'complex'
+
+ * 'complex' : single trial complex.
+ * 'power' : single trial power.
+ * 'phase' : single trial phase.
+ * 'avg_power' : average of single trial power.
+ * 'itc' : inter-trial coherence.
+ * 'avg_power_itc' : average of single trial power and inter-trial
+ coherence across trials.
+
+ n_jobs : int, defaults to 1
+ The number of jobs to run in parallel. The parallelization
+ is implemented across channels.
+ verbose : bool, str, int, or None, defaults to None
+ If not None, override default verbose level (see mne.verbose).
+
+ Returns
+ -------
+ out : array
+ Time frequency transform of epoch_data. If output is in ['complex',
+ 'phase', 'power'], then shape of out is (n_epochs, n_chans, n_freqs,
+ n_times), else it is (n_chans, n_freqs, n_times). If output is
+ 'avg_power_itc', the real values code for 'avg_power' and the
+ imaginary values code for the 'itc': out = avg_power + i * itc
+ """
+ # Check data
+ epoch_data = np.asarray(epoch_data)
+ if epoch_data.ndim != 3:
+ raise ValueError('epoch_data must be of shape '
+ '(n_epochs, n_chans, n_times)')
+
+ # Check params
+ frequencies, sfreq, zero_mean, n_cycles, time_bandwidth, decim = \
+ _check_tfr_param(frequencies, sfreq, method, zero_mean, n_cycles,
+ time_bandwidth, use_fft, decim, output)
+
+ # Setup wavelet
+ if method == 'morlet':
+ W = morlet(sfreq, frequencies, n_cycles=n_cycles, zero_mean=zero_mean)
+ Ws = [W] # to have same dimensionality as the 'multitaper' case
+
+ elif method == 'multitaper':
+ Ws = _make_dpss(sfreq, frequencies, n_cycles=n_cycles,
+ time_bandwidth=time_bandwidth, zero_mean=zero_mean)
+
+ # Check wavelets
+ if len(Ws[0][0]) > epoch_data.shape[2]:
+ raise ValueError('At least one of the wavelets is longer than the '
+ 'signal. Use a longer signal or shorter wavelets.')
+
+ # Initialize output
+ decim = _check_decim(decim)
+ n_freqs = len(frequencies)
+ n_epochs, n_chans, n_times = epoch_data[:, :, decim].shape
+ if output in ('power', 'phase', 'avg_power', 'itc'):
+ dtype = np.float
+ elif output in ('complex', 'avg_power_itc'):
+ # avg_power_itc is stored as power + 1i * itc to keep a
+ # simple dimensionality
+ dtype = np.complex
+
+ if ('avg_' in output) or ('itc' in output):
+ out = np.empty((n_chans, n_freqs, n_times), dtype)
+ else:
+ out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype)
+
+ # Parallel computation
+ parallel, my_cwt, _ = parallel_func(_time_frequency_loop, n_jobs)
+
+ # Parallelization is applied across channels.
+ tfrs = parallel(
+ my_cwt(channel, Ws, output, use_fft, 'same', decim)
+ for channel in epoch_data.transpose(1, 0, 2))
+
+ # FIXME: to avoid overheads we should use np.array_split()
+ for channel_idx, tfr in enumerate(tfrs):
+ out[channel_idx] = tfr
+
+ if ('avg_' not in output) and ('itc' not in output):
+ # This is to enforce that the first dimension is for epochs
+ out = out.transpose(1, 0, 2, 3)
+ return out
+
+
+def _check_tfr_param(frequencies, sfreq, method, zero_mean, n_cycles,
+ time_bandwidth, use_fft, decim, output):
+ """Aux. function to _compute_tfr to check the params validity."""
+ # Check frequencies
+ if not isinstance(frequencies, (list, np.ndarray)):
+ raise ValueError('frequencies must be an array-like, got %s '
+ 'instead.' % type(frequencies))
+ frequencies = np.asarray(frequencies, dtype=float)
+ if frequencies.ndim != 1:
+ raise ValueError('frequencies must be of shape (n_freqs,), got %s '
+ 'instead.' % np.array(frequencies.shape))
+
+ # Check sfreq
+ if not isinstance(sfreq, (float, int)):
+ raise ValueError('sfreq must be a float or an int, got %s '
+ 'instead.' % type(sfreq))
+ sfreq = float(sfreq)
+
+ # Default zero_mean = True if multitaper else False
+ zero_mean = method == 'multitaper' if zero_mean is None else zero_mean
+ if not isinstance(zero_mean, bool):
+ raise ValueError('zero_mean should be of type bool, got %s instead.'
+ % type(zero_mean))
+ frequencies = np.asarray(frequencies)
+
+ if (method == 'multitaper') and (output == 'phase'):
+ raise NotImplementedError(
+ 'This function is not optimized to compute the phase using the '
+ 'multitaper method. Use np.angle of the complex output instead.')
+
+ # Check n_cycles
+ if isinstance(n_cycles, (int, float)):
+ n_cycles = float(n_cycles)
+ elif isinstance(n_cycles, (list, np.ndarray)):
+ n_cycles = np.array(n_cycles)
+ if len(n_cycles) != len(frequencies):
+ raise ValueError('n_cycles must be a float or an array of length '
+ '%i frequencies, got %i cycles instead.' %
+ (len(frequencies), len(n_cycles)))
+ else:
+ raise ValueError('n_cycles must be a float or an array, got %s '
+ 'instead.' % type(n_cycles))
+
+ # Check time_bandwidth
+ if (method == 'morlet') and (time_bandwidth is not None):
+ raise ValueError('time_bandwidth only applies to "multitaper" method.')
+ elif method == 'multitaper':
+ time_bandwidth = (4.0 if time_bandwidth is None
+ else float(time_bandwidth))
+
+ # Check use_fft
+ if not isinstance(use_fft, bool):
+ raise ValueError('use_fft must be a boolean, got %s '
+ 'instead.' % type(use_fft))
+ # Check decim
+ if isinstance(decim, int):
+ decim = slice(None, None, decim)
+ if not isinstance(decim, slice):
+ raise ValueError('decim must be an integer or a slice, '
+ 'got %s instead.' % type(decim))
+
+ # Check output
+ allowed_output = ('complex', 'power', 'phase',
+ 'avg_power_itc', 'avg_power', 'itc')
+ if output not in allowed_output:
+ raise ValueError("Unknown output type. Allowed are %s but "
+ "got %s." % (allowed_output, output))
+
+ if method not in ('multitaper', 'morlet'):
+ raise ValueError('method must be "morlet" or "multitaper", got %s '
+ 'instead.' % (method,))
+
+ return frequencies, sfreq, zero_mean, n_cycles, time_bandwidth, decim
+
+
+def _time_frequency_loop(X, Ws, output, use_fft, mode, decim):
+ """Aux. function to _compute_tfr.
+
+ Loops time-frequency transform across wavelets and epochs.
+
+ Parameters
+ ----------
+ X : array, shape (n_epochs, n_times)
+ The epochs data of a single channel.
+ Ws : list, shape (n_tapers, n_wavelets, n_times)
+ The wavelets.
+ output : str
+
+ * 'complex' : single trial complex.
+ * 'power' : single trial power.
+ * 'phase' : single trial phase.
+ * 'avg_power' : average of single trial power.
+ * 'itc' : inter-trial coherence.
+ * 'avg_power_itc' : average of single trial power and inter-trial
+ coherence across trials.
+
+ use_fft : bool
+ Use the FFT for convolutions or not.
+ mode : {'full', 'valid', 'same'}
+ See numpy.convolve.
+ decim : slice
+ The decimation slice: e.g. power[:, decim]
+ """
+ # Set output type
+ dtype = np.float
+ if output in ['complex', 'avg_power_itc']:
+ dtype = np.complex
+
+ # Init outputs
+ decim = _check_decim(decim)
+ n_epochs, n_times = X[:, decim].shape
+ n_freqs = len(Ws[0])
+ if ('avg_' in output) or ('itc' in output):
+ tfrs = np.zeros((n_freqs, n_times), dtype=dtype)
+ else:
+ tfrs = np.zeros((n_epochs, n_freqs, n_times), dtype=dtype)
+
+ # Loops across tapers.
+ for W in Ws:
+ coefs = _cwt(X, W, mode, decim=decim, use_fft=use_fft)
+
+ # Inter-trial phase locking is apparently computed per taper...
+ if 'itc' in output:
+ plf = np.zeros((n_freqs, n_times), dtype=np.complex)
+
+ # Loop across epochs
+ for epoch_idx, tfr in enumerate(coefs):
+ # Transform complex values
+ if output in ['power', 'avg_power']:
+ tfr = (tfr * tfr.conj()).real # power
+ elif output == 'phase':
+ tfr = np.angle(tfr)
+ elif output == 'avg_power_itc':
+ tfr_abs = np.abs(tfr)
+ plf += tfr / tfr_abs # phase
+ tfr = tfr_abs ** 2 # power
+ elif output == 'itc':
+ plf += tfr / np.abs(tfr) # phase
+ continue # no need to stack anything other than plf
+
+ # Stack or add
+ if ('avg_' in output) or ('itc' in output):
+ tfrs += tfr
+ else:
+ tfrs[epoch_idx] += tfr
+
+ # Compute inter trial coherence
+ if output == 'avg_power_itc':
+ tfrs += 1j * np.abs(plf)
+ elif output == 'itc':
+ tfrs += np.abs(plf)
+
+ # Normalization of average metrics
+ if ('avg_' in output) or ('itc' in output):
+ tfrs /= n_epochs
+
+ # Normalization by number of tapers
+ tfrs /= len(Ws)
+ return tfrs
+
+
+@deprecated("This function will be removed in mne 0.14; use mne.time_frequency"
+ ".tfr_morlet() with average=False instead.")
def cwt_morlet(X, sfreq, freqs, use_fft=True, n_cycles=7.0, zero_mean=False,
decim=1):
"""Compute time freq decomposition with Morlet wavelets
@@ -265,13 +560,15 @@ def cwt_morlet(X, sfreq, freqs, use_fft=True, n_cycles=7.0, zero_mean=False,
n_cycles: float | array of float
Number of cycles. Fixed number or one per frequency.
zero_mean : bool
- Make sure the wavelets are zero mean.
+ Make sure the wavelets have a mean of zero.
decim : int | slice
To reduce memory usage, decimation factor after time-frequency
decomposition.
If `int`, returns tfr[..., ::decim].
- If `slice` returns tfr[..., decim].
- Note that decimation may create aliasing artifacts.
+ If `slice`, returns tfr[..., decim].
+
+ .. note:: Decimation may create aliasing artifacts.
+
Defaults to 1.
Returns
@@ -318,14 +615,16 @@ def cwt(X, Ws, use_fft=True, mode='same', decim=1):
To reduce memory usage, decimation factor after time-frequency
decomposition.
If `int`, returns tfr[..., ::decim].
- If `slice` returns tfr[..., decim].
- Note that decimation may create aliasing artifacts.
+ If `slice`, returns tfr[..., decim].
+
+ .. note:: Decimation may create aliasing artifacts.
+
Defaults to 1.
Returns
-------
tfr : array, shape (n_signals, n_frequencies, n_times)
- The time frequency decompositions.
+ The time-frequency decompositions.
See Also
--------
@@ -344,27 +643,8 @@ def cwt(X, Ws, use_fft=True, mode='same', decim=1):
return tfrs
-def _time_frequency(X, Ws, use_fft, decim):
- """Aux of time_frequency for parallel computing over channels
- """
- decim = _check_decim(decim)
- n_epochs, n_times = X[:, decim].shape
- n_frequencies = len(Ws)
- psd = np.zeros((n_frequencies, n_times)) # PSD
- plf = np.zeros((n_frequencies, n_times), np.complex) # phase lock
-
- mode = 'same'
- tfrs = _cwt(X, Ws, mode, decim=decim, use_fft=use_fft)
-
- for tfr in tfrs:
- tfr_abs = np.abs(tfr)
- psd += tfr_abs ** 2
- plf += tfr / tfr_abs
- psd /= n_epochs
- plf = np.abs(plf) / n_epochs
- return psd, plf
-
-
+@deprecated("This function will be removed in mne 0.14; use mne.time_frequency"
+ ".tfr_morlet() with average=False instead.")
@verbose
def single_trial_power(data, sfreq, frequencies, use_fft=True, n_cycles=7,
baseline=None, baseline_mode='ratio', times=None,
@@ -373,7 +653,7 @@ def single_trial_power(data, sfreq, frequencies, use_fft=True, n_cycles=7,
Parameters
----------
- data : array of shape [n_epochs, n_channels, n_times]
+ data : array, shape (n_epochs, n_channels, n_times)
The epochs
sfreq : float
Sampling rate
@@ -392,24 +672,31 @@ def single_trial_power(data, sfreq, frequencies, use_fft=True, n_cycles=7,
and if b is None then b is set to the end of the interval.
If baseline is equal to (None, None) all the time
interval is used.
- baseline_mode : None | 'ratio' | 'zscore'
+ baseline_mode : None | 'ratio' | 'zscore' | 'mean' | 'percent' | 'logratio' | 'zlogratio'
Do baseline correction with ratio (power is divided by mean
power during baseline) or zscore (power is divided by standard
deviation of power during baseline after subtracting the mean,
- power = [power - mean(power_baseline)] / std(power_baseline))
+ power = [power - mean(power_baseline)] / std(power_baseline)),
+ mean simply subtracts the mean power, percent is the same as
+ applying ratio then mean, logratio is the same as mean but then
+ rendered in log-scale, zlogratio is the same as zscore but data
+ is rendered in log-scale first.
+ If None no baseline correction is applied.
times : array
Required to define baseline
decim : int | slice
To reduce memory usage, decimation factor after time-frequency
decomposition.
If `int`, returns tfr[..., ::decim].
- If `slice` returns tfr[..., decim].
- Note that decimation may create aliasing artifacts.
+ If `slice`, returns tfr[..., decim].
+
+ .. note:: Decimation may create aliasing artifacts.
+
Defaults to 1.
n_jobs : int
The number of epochs to process at the same time
zero_mean : bool
- Make sure the wavelets are zero mean.
+ Make sure the wavelets have a mean of zero.
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
@@ -417,7 +704,7 @@ def single_trial_power(data, sfreq, frequencies, use_fft=True, n_cycles=7,
-------
power : 4D array
Power estimate (Epochs x Channels x Frequencies x Timepoints).
- """
+ """ # noqa
decim = _check_decim(decim)
mode = 'same'
n_frequencies = len(frequencies)
@@ -454,113 +741,258 @@ def single_trial_power(data, sfreq, frequencies, use_fft=True, n_cycles=7,
return power
-def _induced_power_cwt(data, sfreq, frequencies, use_fft=True, n_cycles=7,
- decim=1, n_jobs=1, zero_mean=False):
- """Compute time induced power and inter-trial phase-locking factor
+# Aux function to reduce redundancy between tfr_morlet and tfr_multitaper
+
+def _tfr_aux(method, inst, freqs, decim, return_itc, picks, average,
+ **tfr_params):
+ decim = _check_decim(decim)
+ data = _get_data(inst, return_itc)
+ info = inst.info
+
+ info, data, picks = _prepare_picks(info, data, picks)
+ data = data[:, picks, :]
+
+ if average:
+ if return_itc:
+ output = 'avg_power_itc'
+ else:
+ output = 'avg_power'
+ else:
+ output = 'power'
+ if return_itc:
+ raise ValueError('Inter-trial coherence is not supported'
+ ' with average=False')
+
+ out = _compute_tfr(data, freqs, info['sfreq'], method=method,
+ output=output, decim=decim, **tfr_params)
+ times = inst.times[decim].copy()
+
+ if average:
+ if return_itc:
+ power, itc = out.real, out.imag
+ else:
+ power = out
+ nave = len(data)
+ out = AverageTFR(info, power, times, freqs, nave,
+ method='%s-power' % method)
+ if return_itc:
+ out = (out, AverageTFR(info, itc, times, freqs, nave,
+ method='%s-itc' % method))
+ else:
+ power = out
+ out = EpochsTFR(info, power, times, freqs, method='%s-power' % method)
+
+ return out
+
- The time frequency decomposition is done with Morlet wavelets
+@verbose
+def tfr_morlet(inst, freqs, n_cycles, use_fft=False, return_itc=True, decim=1,
+ n_jobs=1, picks=None, zero_mean=True, average=True,
+ verbose=None):
+ """Compute Time-Frequency Representation (TFR) using Morlet wavelets
Parameters
----------
- data : array
- 3D array of shape [n_epochs, n_channels, n_times]
- sfreq : float
- Sampling frequency.
- frequencies : array
- Array of frequencies of interest
- use_fft : bool
- Compute transform with fft based convolutions or temporal
- convolutions.
- n_cycles : float | array of float
- Number of cycles. Fixed number or one per frequency.
- decim : int | slice
+ inst : Epochs | Evoked
+ The epochs or evoked object.
+ freqs : ndarray, shape (n_freqs,)
+ The frequencies in Hz.
+ n_cycles : float | ndarray, shape (n_freqs,)
+ The number of cycles globally or for each frequency.
+ use_fft : bool, defaults to False
+ Use FFT-based convolution or not.
+ return_itc : bool, defaults to True
+ Return inter-trial coherence (ITC) as well as averaged power.
+ Must be ``False`` for evoked data.
+ decim : int | slice, defaults to 1
To reduce memory usage, decimation factor after time-frequency
decomposition.
If `int`, returns tfr[..., ::decim].
- If `slice` returns tfr[..., decim].
- Note that decimation may create aliasing artifacts.
- Defaults to 1.
- n_jobs : int
- The number of CPUs used in parallel. All CPUs are used in -1.
- Requires joblib package.
- zero_mean : bool
- Make sure the wavelets are zero mean.
+ If `slice`, returns tfr[..., decim].
+
+ .. note:: Decimation may create aliasing artifacts.
+
+ n_jobs : int, defaults to 1
+ The number of jobs to run in parallel.
+ picks : array-like of int | None, defaults to None
+ The indices of the channels to decompose. If None, all available
+ channels are used.
+ zero_mean : bool, defaults to True
+ Make sure the wavelet has a mean of zero.
+
+ .. versionadded:: 0.13.0
+ average : bool, defaults to True
+ If True, average across epochs.
+
+ .. versionadded:: 0.13.0
+ verbose : bool, str, int, or None, defaults to None
+ If not None, override default verbose level (see mne.verbose).
Returns
-------
- power : 2D array
- Induced power (Channels x Frequencies x Timepoints).
- Squared amplitude of time-frequency coefficients.
- phase_lock : 2D array
- Phase locking factor in [0, 1] (Channels x Frequencies x Timepoints)
- """
- decim = _check_decim(decim)
- n_frequencies = len(frequencies)
- n_epochs, n_channels, n_times = data[:, :, decim].shape
+ power : AverageTFR | EpochsTFR
+ The averaged power.
+ itc : AverageTFR | EpochsTFR
+ The inter-trial coherence (ITC). Only returned if return_itc
+ is True.
- # Precompute wavelets for given frequency range to save time
- Ws = morlet(sfreq, frequencies, n_cycles=n_cycles, zero_mean=zero_mean)
+ See Also
+ --------
+ tfr_multitaper, tfr_stockwell
+ """
+ tfr_params = dict(n_cycles=n_cycles, n_jobs=n_jobs, use_fft=use_fft,
+ zero_mean=zero_mean)
+ return _tfr_aux('morlet', inst, freqs, decim, return_itc, picks,
+ average, **tfr_params)
- psd = np.empty((n_channels, n_frequencies, n_times))
- plf = np.empty((n_channels, n_frequencies, n_times))
- # Separate to save memory for n_jobs=1
- parallel, my_time_frequency, _ = parallel_func(_time_frequency, n_jobs)
- psd_plf = parallel(my_time_frequency(data[:, c, :], Ws, use_fft, decim)
- for c in range(n_channels))
- for c, (psd_c, plf_c) in enumerate(psd_plf):
- psd[c, :, :], plf[c, :, :] = psd_c, plf_c
- return psd, plf
+@verbose
+def tfr_multitaper(inst, freqs, n_cycles, time_bandwidth=4.0,
+ use_fft=True, return_itc=True, decim=1,
+ n_jobs=1, picks=None, average=True, verbose=None):
+ """Compute Time-Frequency Representation (TFR) using DPSS tapers.
-def _preproc_tfr(data, times, freqs, tmin, tmax, fmin, fmax, mode,
- baseline, vmin, vmax, dB, sfreq):
- """Aux Function to prepare tfr computation"""
- from ..viz.utils import _setup_vmin_vmax
+ Parameters
+ ----------
+ inst : Epochs | Evoked
+ The epochs or evoked object.
+ freqs : ndarray, shape (n_freqs,)
+ The frequencies in Hz.
+ n_cycles : float | ndarray, shape (n_freqs,)
+ The number of cycles globally or for each frequency.
+ The time-window length is thus T = n_cycles / freq.
+ time_bandwidth : float, defaults to 4.0 (3 good tapers)
+ Time x (Full) Bandwidth product. Should be >= 2.0.
+ Choose this along with n_cycles to get desired frequency resolution.
+ The number of good tapers (least leakage from far away frequencies)
+ is chosen automatically based on this to floor(time_bandwidth - 1).
+ E.g., with freq = 20 Hz and n_cycles = 10, we get time = 0.5 s.
+ If time_bandwidth = 4., then frequency smoothing is (4 / time) = 8 Hz.
+ use_fft : bool, defaults to True
+ Use FFT-based convolution or not.
+ return_itc : bool, defaults to True
+ Return inter-trial coherence (ITC) as well as averaged power.
+ decim : int | slice, defaults to 1
+ To reduce memory usage, decimation factor after time-frequency
+ decomposition.
+ If `int`, returns tfr[..., ::decim].
+ If `slice`, returns tfr[..., decim].
- copy = baseline is not None
- data = rescale(data, times, baseline, mode, copy=copy)
+ .. note:: Decimation may create aliasing artifacts.
- # crop time
- itmin, itmax = None, None
- idx = np.where(_time_mask(times, tmin, tmax, sfreq=sfreq))[0]
- if tmin is not None:
- itmin = idx[0]
- if tmax is not None:
- itmax = idx[-1] + 1
+ n_jobs : int, defaults to 1
+ The number of jobs to run in parallel.
+ picks : array-like of int | None, defaults to None
+ The indices of the channels to decompose. If None, all available
+ channels are used.
+ average : bool, defaults to True
+ If True, average across epochs.
- times = times[itmin:itmax]
+ .. versionadded:: 0.13.0
+ verbose : bool, str, int, or None, defaults to None
+ If not None, override default verbose level (see mne.verbose).
- # crop freqs
- ifmin, ifmax = None, None
- idx = np.where(_time_mask(freqs, fmin, fmax, sfreq=sfreq))[0]
- if fmin is not None:
- ifmin = idx[0]
- if fmax is not None:
- ifmax = idx[-1] + 1
+ Returns
+ -------
+ power : AverageTFR | EpochsTFR
+ The averaged power.
+ itc : AverageTFR | EpochsTFR
+ The inter-trial coherence (ITC). Only returned if return_itc
+ is True.
- freqs = freqs[ifmin:ifmax]
+ See Also
+ --------
+ tfr_morlet, tfr_stockwell
- # crop data
- data = data[:, ifmin:ifmax, itmin:itmax]
+ Notes
+ -----
+ .. versionadded:: 0.9.0
+ """
+ tfr_params = dict(n_cycles=n_cycles, n_jobs=n_jobs, use_fft=use_fft,
+ zero_mean=True, time_bandwidth=time_bandwidth)
+ return _tfr_aux('multitaper', inst, freqs, decim, return_itc, picks,
+ average, **tfr_params)
- times *= 1e3
- if dB:
- data = 10 * np.log10((data * data.conj()).real)
- vmin, vmax = _setup_vmin_vmax(data, vmin, vmax)
- return data, times, freqs, vmin, vmax
+# TFR(s) class
+class _BaseTFR(ContainsMixin, UpdateChannelsMixin, SizeMixin):
+ """Base class for AverageTFR and EpochsTFR."""
+
+ @property
+ def ch_names(self):
+ return self.info['ch_names']
-class AverageTFR(ContainsMixin, UpdateChannelsMixin):
- """Container for Time-Frequency data
+ def crop(self, tmin=None, tmax=None):
+ """Crop data to a given time interval in place
- Can for example store induced power at sensor level or intertrial
- coherence.
+ Parameters
+ ----------
+ tmin : float | None
+ Start time of selection in seconds.
+ tmax : float | None
+ End time of selection in seconds.
- Parameters
- ----------
- info : Info
- The measurement info.
+ Returns
+ -------
+ inst : instance of AverageTFR
+ The modified instance.
+ """
+ mask = _time_mask(self.times, tmin, tmax, sfreq=self.info['sfreq'])
+ self.times = self.times[mask]
+ self.data = self.data[..., mask]
+ return self
+
+ def copy(self):
+ """Return a copy of the instance."""
+ return deepcopy(self)
+
+ @verbose
+ def apply_baseline(self, baseline, mode='mean', verbose=None):
+ """Baseline correct the data
+
+ Parameters
+ ----------
+ baseline : tuple or list of length 2
+ The time interval to apply rescaling / baseline correction.
+ If None do not apply it. If baseline is (a, b)
+ the interval is between "a (s)" and "b (s)".
+ If a is None the beginning of the data is used
+ and if b is None then b is set to the end of the interval.
+ If baseline is equal to (None, None) all the time
+ interval is used.
+ mode : None | 'ratio' | 'zscore' | 'mean' | 'percent' | 'logratio' | 'zlogratio'
+ Do baseline correction with ratio (power is divided by mean
+ power during baseline) or zscore (power is divided by standard
+ deviation of power during baseline after subtracting the mean,
+ power = [power - mean(power_baseline)] / std(power_baseline)),
+ mean simply subtracts the mean power, percent is the same as
+ applying ratio then mean, logratio is the same as mean but then
+ rendered in log-scale, zlogratio is the same as zscore but data
+ is rendered in log-scale first.
+ If None no baseline correction is applied.
+ verbose : bool, str, int, or None
+ If not None, override default verbose level (see mne.verbose).
+
+ Returns
+ -------
+ inst : instance of AverageTFR
+ The modified instance.
+
+ """ # noqa
+ self.data = rescale(self.data, self.times, baseline, mode,
+ copy=False)
+ return self
+
+
+class AverageTFR(_BaseTFR):
+ """Container for Time-Frequency data
+
+ Can for example store induced power at sensor level or inter-trial
+ coherence.
+
+ Parameters
+ ----------
+ info : Info
+ The measurement info.
data : ndarray, shape (n_channels, n_freqs, n_times)
The data.
times : ndarray, shape (n_times,)
@@ -569,12 +1001,10 @@ class AverageTFR(ContainsMixin, UpdateChannelsMixin):
The frequencies in Hz.
nave : int
The number of averaged TFRs.
- comment : str | None
+ comment : str | None, defaults to None
Comment on the data, e.g., the experimental condition.
- Defaults to None.
- method : str | None
+ method : str | None, defaults to None
Comment on the method used to compute the data, e.g., morlet wavelet.
- Defaults to None.
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
@@ -600,43 +1030,14 @@ class AverageTFR(ContainsMixin, UpdateChannelsMixin):
raise ValueError("Number of times and data size don't match"
" (%d != %d)." % (n_times, len(times)))
self.data = data
- self.times = np.asarray(times)
- self.freqs = np.asarray(freqs)
+ self.times = np.array(times, dtype=float)
+ self.freqs = np.array(freqs, dtype=float)
self.nave = nave
self.comment = comment
self.method = method
- @property
- def ch_names(self):
- return self.info['ch_names']
-
- def crop(self, tmin=None, tmax=None, copy=None):
- """Crop data to a given time interval
-
- Parameters
- ----------
- tmin : float | None
- Start time of selection in seconds.
- tmax : float | None
- End time of selection in seconds.
- copy : bool
- This parameter has been deprecated and will be removed in 0.13.
- Use inst.copy() instead.
- Whether to return a new instance or modify in place.
-
- Returns
- -------
- inst : instance of AverageTFR
- The modified instance.
- """
- inst = _check_copy_dep(self, copy)
- mask = _time_mask(inst.times, tmin, tmax, sfreq=self.info['sfreq'])
- inst.times = inst.times[mask]
- inst.data = inst.data[:, :, mask]
- return inst
-
@verbose
- def plot(self, picks=None, baseline=None, mode='mean', tmin=None,
+ def plot(self, picks, baseline=None, mode='mean', tmin=None,
tmax=None, fmin=None, fmax=None, vmin=None, vmax=None,
cmap='RdBu_r', dB=False, colorbar=True, show=True,
title=None, axes=None, layout=None, verbose=None):
@@ -644,8 +1045,8 @@ class AverageTFR(ContainsMixin, UpdateChannelsMixin):
Parameters
----------
- picks : array-like of int | None
- The indices of the channels to plot.
+ picks : array-like of int
+ The indices of the channels to plot, one figure per channel.
baseline : None (default) or tuple of length 2
The time interval to apply baseline correction.
If None do not apply it. If baseline is (a, b)
@@ -654,11 +1055,15 @@ class AverageTFR(ContainsMixin, UpdateChannelsMixin):
and if b is None then b is set to the end of the interval.
If baseline is equal to (None, None) all the time
interval is used.
- mode : None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
+ mode : None | 'ratio' | 'zscore' | 'mean' | 'percent' | 'logratio' | 'zlogratio'
Do baseline correction with ratio (power is divided by mean
power during baseline) or zscore (power is divided by standard
deviation of power during baseline after subtracting the mean,
- power = [power - mean(power_baseline)] / std(power_baseline)).
+ power = [power - mean(power_baseline)] / std(power_baseline)),
+ mean simply subtracts the mean power, percent is the same as
+ applying ratio then mean, logratio is the same as mean but then
+ rendered in log-scale, zlogratio is the same as zscore but data
+ is rendered in log-scale first.
If None no baseline correction is applied.
tmin : None | float
The first time instant to display. If None the first time point
@@ -678,8 +1083,20 @@ class AverageTFR(ContainsMixin, UpdateChannelsMixin):
vmax : float | None
The maximum value on the color scale. If vmax is None, the data
maximum value is used.
- cmap : matplotlib colormap | str
- The colormap to use. Defaults to 'RdBu_r'.
+ cmap : matplotlib colormap | 'interactive' | (colormap, bool)
+ The colormap to use. If tuple, the first value indicates the
+ colormap to use and the second value is a boolean defining
+ interactivity. In interactive mode the colors are adjustable by
+ clicking and dragging the colorbar with left and right mouse
+ button. Left mouse button moves the scale up and down and right
+ mouse button adjusts the range. Hitting space bar resets the range.
+ Up and down arrows can be used to change the colormap. If
+ 'interactive', translates to ('RdBu_r', True). Defaults to
+ 'RdBu_r'.
+
+ .. warning:: Interactive mode works smoothly only for a small
+ number of images.
+
dB : bool
If True, 20*log10 is applied to the data to get dB.
colorbar : bool
@@ -704,7 +1121,7 @@ class AverageTFR(ContainsMixin, UpdateChannelsMixin):
-------
fig : matplotlib.figure.Figure
The figure containing the topography.
- """
+ """ # noqa
from ..viz.topo import _imshow_tfr
import matplotlib.pyplot as plt
times, freqs = self.times.copy(), self.freqs.copy()
@@ -727,6 +1144,10 @@ class AverageTFR(ContainsMixin, UpdateChannelsMixin):
raise RuntimeError('There must be an axes for each picked '
'channel.')
+ if cmap == 'interactive':
+ cmap = ('RdBu_r', True)
+ elif not isinstance(cmap, tuple):
+ cmap = (cmap, True)
for idx in range(len(data)):
if axes is None:
fig = plt.figure()
@@ -742,7 +1163,9 @@ class AverageTFR(ContainsMixin, UpdateChannelsMixin):
colorbar=colorbar, picker=False, cmap=cmap)
if title:
fig.suptitle(title)
- colorbar = False # only one colorbar for multiple axes
+ # Only draw 1 cbar. For interactive mode we pass the ref to cbar.
+ colorbar = ax.CB if cmap[1] else False
+
plt_show(show)
return fig
@@ -772,13 +1195,14 @@ class AverageTFR(ContainsMixin, UpdateChannelsMixin):
if 'mag' in self:
types.append('mag')
if 'grad' in self:
- types.append('grad')
+ if len(_pair_grad_sensors(self.info, topomap_coords=False,
+ raise_error=False)) >= 2:
+ types.append('grad')
+ elif len(types) == 0:
+ return # Don't draw a figure for nothing.
fig = figure_nobar()
- fig.suptitle('{:.2f} s - {:.2f} s, {:.2f} Hz - {:.2f} Hz'.format(tmin,
- tmax,
- fmin,
- fmax),
- y=0.04)
+ fig.suptitle('{0:.2f} s - {1:.2f} s, {2:.2f} Hz - {3:.2f} Hz'.format(
+ tmin, tmax, fmin, fmax), y=0.04)
for idx, ch_type in enumerate(types):
ax = plt.subplot(1, len(types), idx + 1)
plot_tfr_topomap(self, ch_type=ch_type, tmin=tmin, tmax=tmax,
@@ -791,13 +1215,14 @@ class AverageTFR(ContainsMixin, UpdateChannelsMixin):
tmax=None, fmin=None, fmax=None, vmin=None, vmax=None,
layout=None, cmap='RdBu_r', title=None, dB=False,
colorbar=True, layout_scale=0.945, show=True,
- border='none', fig_facecolor='k', font_color='w'):
+ border='none', fig_facecolor='k', fig_background=None,
+ font_color='w'):
"""Plot TFRs in a topography with images
Parameters
----------
picks : array-like of int | None
- The indices of the channels to plot. If None all available
+ The indices of the channels to plot. If None, all available
channels are displayed.
baseline : None (default) or tuple of length 2
The time interval to apply baseline correction.
@@ -807,11 +1232,15 @@ class AverageTFR(ContainsMixin, UpdateChannelsMixin):
and if b is None then b is set to the end of the interval.
If baseline is equal to (None, None) all the time
interval is used.
- mode : None | 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
+ mode : None | 'ratio' | 'zscore' | 'mean' | 'percent' | 'logratio' | 'zlogratio'
Do baseline correction with ratio (power is divided by mean
power during baseline) or zscore (power is divided by standard
deviation of power during baseline after subtracting the mean,
- power = [power - mean(power_baseline)] / std(power_baseline)).
+ power = [power - mean(power_baseline)] / std(power_baseline)),
+ mean simply subtracts the mean power, percent is the same as
+ applying ratio then mean, logratio is the same as mean but then
+ rendered in log-scale, zlogratio is the same as zscore but data
+ is rendered in log-scale first.
If None no baseline correction is applied.
tmin : None | float
The first time instant to display. If None the first time point
@@ -851,6 +1280,9 @@ class AverageTFR(ContainsMixin, UpdateChannelsMixin):
matplotlib borders style to be used for each sensor plot.
fig_facecolor : str | obj
The figure face color. Defaults to black.
+ fig_background : None | array
+ A background image for the figure. This must be a valid input to
+ `matplotlib.pyplot.imshow`. Defaults to None.
font_color: str | obj
The color of tick labels in the colorbar. Defaults to white.
@@ -858,8 +1290,9 @@ class AverageTFR(ContainsMixin, UpdateChannelsMixin):
-------
fig : matplotlib.figure.Figure
The figure containing the topography.
- """
+ """ # noqa
from ..viz.topo import _imshow_tfr, _plot_topo, _imshow_tfr_unified
+ from ..viz import add_background_image
times = self.times.copy()
freqs = self.freqs
data = self.data
@@ -878,8 +1311,8 @@ class AverageTFR(ContainsMixin, UpdateChannelsMixin):
onselect_callback = partial(self._onselect, baseline=baseline,
mode=mode, layout=layout)
- click_fun = partial(_imshow_tfr, tfr=data, freq=freqs, cmap=cmap,
- onselect=onselect_callback)
+ click_fun = partial(_imshow_tfr, tfr=data, freq=freqs,
+ cmap=(cmap, True), onselect=onselect_callback)
imshow = partial(_imshow_tfr_unified, tfr=data, freq=freqs, cmap=cmap,
onselect=onselect_callback)
@@ -890,73 +1323,11 @@ class AverageTFR(ContainsMixin, UpdateChannelsMixin):
x_label='Time (ms)', y_label='Frequency (Hz)',
fig_facecolor=fig_facecolor, font_color=font_color,
unified=True, img=True)
+
+ add_background_image(fig, fig_background)
plt_show(show)
return fig
- def _check_compat(self, tfr):
- """checks that self and tfr have the same time-frequency ranges"""
- assert np.all(tfr.times == self.times)
- assert np.all(tfr.freqs == self.freqs)
-
- def __add__(self, tfr):
- self._check_compat(tfr)
- out = self.copy()
- out.data += tfr.data
- return out
-
- def __iadd__(self, tfr):
- self._check_compat(tfr)
- self.data += tfr.data
- return self
-
- def __sub__(self, tfr):
- self._check_compat(tfr)
- out = self.copy()
- out.data -= tfr.data
- return out
-
- def __isub__(self, tfr):
- self._check_compat(tfr)
- self.data -= tfr.data
- return self
-
- def copy(self):
- """Return a copy of the instance."""
- return deepcopy(self)
-
- def __repr__(self):
- s = "time : [%f, %f]" % (self.times[0], self.times[-1])
- s += ", freq : [%f, %f]" % (self.freqs[0], self.freqs[-1])
- s += ", nave : %d" % self.nave
- s += ', channels : %d' % self.data.shape[0]
- return "<AverageTFR | %s>" % s
-
- @verbose
- def apply_baseline(self, baseline, mode='mean', verbose=None):
- """Baseline correct the data
-
- Parameters
- ----------
- baseline : tuple or list of length 2
- The time interval to apply rescaling / baseline correction.
- If None do not apply it. If baseline is (a, b)
- the interval is between "a (s)" and "b (s)".
- If a is None the beginning of the data is used
- and if b is None then b is set to the end of the interval.
- If baseline is equal to (None, None) all the time
- interval is used.
- mode : 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
- Do baseline correction with ratio (power is divided by mean
- power during baseline) or z-score (power is divided by standard
- deviation of power during baseline after subtracting the mean,
- power = [power - mean(power_baseline)] / std(power_baseline))
- If None, baseline no correction will be performed.
- verbose : bool, str, int, or None
- If not None, override default verbose level (see mne.verbose).
- """
- self.data = rescale(self.data, self.times, baseline, mode,
- copy=False)
-
def plot_topomap(self, tmin=None, tmax=None, fmin=None, fmax=None,
ch_type=None, baseline=None, mode='mean',
layout=None, vmin=None, vmax=None, cmap=None,
@@ -992,12 +1363,16 @@ class AverageTFR(ContainsMixin, UpdateChannelsMixin):
and if b is None then b is set to the end of the interval.
If baseline is equal to (None, None) all the time
interval is used.
- mode : 'logratio' | 'ratio' | 'zscore' | 'mean' | 'percent'
+ mode : None | 'ratio' | 'zscore' | 'mean' | 'percent' | 'logratio' | 'zlogratio'
Do baseline correction with ratio (power is divided by mean
- power during baseline) or z-score (power is divided by standard
+ power during baseline) or zscore (power is divided by standard
deviation of power during baseline after subtracting the mean,
- power = [power - mean(power_baseline)] / std(power_baseline))
- If None, baseline no correction will be performed.
+ power = [power - mean(power_baseline)] / std(power_baseline)),
+ mean simply subtracts the mean power, percent is the same as
+ applying ratio then mean, logratio is the same as mean but then
+ rendered in log-scale, zlogratio is the same as zscore but data
+ is rendered in log-scale first.
+ If None no baseline correction is applied.
layout : None | Layout
Layout instance specifying sensor positions (does not need to
be specified for Neuromag data). If possible, the correct layout
@@ -1013,10 +1388,16 @@ class AverageTFR(ContainsMixin, UpdateChannelsMixin):
The value specifying the upper bound of the color range. If None,
the maximum value is used. If callable, the output equals
vmax(data). Defaults to None.
- cmap : matplotlib colormap | None
- Colormap. If None and the plotted data is all positive, defaults to
- 'Reds'. If None and data contains also negative values, defaults to
- 'RdBu_r'. Defaults to None.
+ cmap : matplotlib colormap | (colormap, bool) | 'interactive' | None
+ Colormap to use. If tuple, the first value indicates the colormap
+ to use and the second value is a boolean defining interactivity. In
+ interactive mode the colors are adjustable by clicking and dragging
+ the colorbar with left and right mouse button. Left mouse button
+ moves the scale up and down and right mouse button adjusts the
+ range. Hitting space bar resets the range. Up and down arrows can
+ be used to change the colormap. If None (default), 'Reds' is used
+ for all positive data, otherwise defaults to 'RdBu_r'. If
+ 'interactive', translates to (None, True).
sensors : bool | str
Add markers for sensor locations to the plot. Accepts matplotlib
plot format string (e.g., 'r+' for red plusses). If True, a circle
@@ -1065,7 +1446,7 @@ class AverageTFR(ContainsMixin, UpdateChannelsMixin):
-------
fig : matplotlib.figure.Figure
The figure containing the topography.
- """
+ """ # noqa
from ..viz import plot_tfr_topomap
return plot_tfr_topomap(self, tmin=tmin, tmax=tmax, fmin=fmin,
fmax=fmax, ch_type=ch_type, baseline=baseline,
@@ -1076,6 +1457,41 @@ class AverageTFR(ContainsMixin, UpdateChannelsMixin):
title=title, axes=axes, show=show,
outlines=outlines, head_pos=head_pos)
+ def _check_compat(self, tfr):
+ """checks that self and tfr have the same time-frequency ranges"""
+ assert np.all(tfr.times == self.times)
+ assert np.all(tfr.freqs == self.freqs)
+
+ def __add__(self, tfr):
+ self._check_compat(tfr)
+ out = self.copy()
+ out.data += tfr.data
+ return out
+
+ def __iadd__(self, tfr):
+ self._check_compat(tfr)
+ self.data += tfr.data
+ return self
+
+ def __sub__(self, tfr):
+ self._check_compat(tfr)
+ out = self.copy()
+ out.data -= tfr.data
+ return out
+
+ def __isub__(self, tfr):
+ self._check_compat(tfr)
+ self.data -= tfr.data
+ return self
+
+ def __repr__(self):
+ s = "time : [%f, %f]" % (self.times[0], self.times[-1])
+ s += ", freq : [%f, %f]" % (self.freqs[0], self.freqs[-1])
+ s += ", nave : %d" % self.nave
+ s += ', channels : %d' % self.data.shape[0]
+ s += ', ~%s' % (sizeof_fmt(self._size),)
+ return "<AverageTFR | %s>" % s
+
def save(self, fname, overwrite=False):
"""Save TFR object to hdf5 file
@@ -1089,334 +1505,77 @@ class AverageTFR(ContainsMixin, UpdateChannelsMixin):
write_tfrs(fname, self, overwrite=overwrite)
-def _prepare_write_tfr(tfr, condition):
- """Aux function"""
- return (condition, dict(times=tfr.times, freqs=tfr.freqs,
- data=tfr.data, info=tfr.info,
- nave=tfr.nave, comment=tfr.comment,
- method=tfr.method))
-
+class EpochsTFR(_BaseTFR):
+ """Container for Time-Frequency data on epochs
-def write_tfrs(fname, tfr, overwrite=False):
- """Write a TFR dataset to hdf5.
+ Can for example store induced power at sensor level.
Parameters
----------
- fname : string
- The file name, which should end with -tfr.h5
- tfr : AverageTFR instance, or list of AverageTFR instances
- The TFR dataset, or list of TFR datasets, to save in one file.
- Note. If .comment is not None, a name will be generated on the fly,
- based on the order in which the TFR objects are passed
- overwrite : bool
- If True, overwrite file (if it exists). Defaults to False.
+ info : Info
+ The measurement info.
+ data : ndarray, shape (n_epochs, n_channels, n_freqs, n_times)
+ The data.
+ times : ndarray, shape (n_times,)
+ The time values in seconds.
+ freqs : ndarray, shape (n_freqs,)
+ The frequencies in Hz.
+ comment : str | None, defaults to None
+ Comment on the data, e.g., the experimental condition.
+ method : str | None, defaults to None
+ Comment on the method used to compute the data, e.g., morlet wavelet.
+ verbose : bool, str, int, or None
+ If not None, override default verbose level (see mne.verbose).
- See Also
- --------
- read_tfrs
+ Attributes
+ ----------
+ ch_names : list
+ The names of the channels.
Notes
-----
- .. versionadded:: 0.9.0
+ .. versionadded:: 0.13.0
"""
- out = []
- if not isinstance(tfr, (list, tuple)):
- tfr = [tfr]
- for ii, tfr_ in enumerate(tfr):
- comment = ii if tfr_.comment is None else tfr_.comment
- out.append(_prepare_write_tfr(tfr_, condition=comment))
- write_hdf5(fname, out, overwrite=overwrite, title='mnepython')
-
+ @verbose
+ def __init__(self, info, data, times, freqs, comment=None,
+ method=None, verbose=None):
+ self.info = info
+ if data.ndim != 4:
+ raise ValueError('data should be 4d. Got %d.' % data.ndim)
+ n_epochs, n_channels, n_freqs, n_times = data.shape
+ if n_channels != len(info['chs']):
+ raise ValueError("Number of channels and data size don't match"
+ " (%d != %d)." % (n_channels, len(info['chs'])))
+ if n_freqs != len(freqs):
+ raise ValueError("Number of frequencies and data size don't match"
+ " (%d != %d)." % (n_freqs, len(freqs)))
+ if n_times != len(times):
+ raise ValueError("Number of times and data size don't match"
+ " (%d != %d)." % (n_times, len(times)))
+ self.data = data
+ self.times = np.array(times, dtype=float)
+ self.freqs = np.array(freqs, dtype=float)
+ self.comment = comment
+ self.method = method
-def read_tfrs(fname, condition=None):
- """
- Read TFR datasets from hdf5 file.
+ def __repr__(self):
+ s = "time : [%f, %f]" % (self.times[0], self.times[-1])
+ s += ", freq : [%f, %f]" % (self.freqs[0], self.freqs[-1])
+ s += ", epochs : %d" % self.data.shape[0]
+ s += ', channels : %d' % self.data.shape[1]
+ s += ', ~%s' % (sizeof_fmt(self._size),)
+ return "<EpochsTFR | %s>" % s
- Parameters
- ----------
- fname : string
- The file name, which should end with -tfr.h5 .
- condition : int or str | list of int or str | None
- The condition to load. If None, all conditions will be returned.
- Defaults to None.
+ def average(self):
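+        """Average the data across epochs, returning an AverageTFR."""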
+ data = np.mean(self.data, axis=0)
+ return AverageTFR(info=self.info.copy(), data=data,
+ times=self.times.copy(), freqs=self.freqs.copy(),
+ nave=self.data.shape[0],
+ method=self.method)
- See Also
- --------
- write_tfrs
- Returns
- -------
- tfrs : list of instances of AverageTFR | instance of AverageTFR
- Depending on `condition` either the TFR object or a list of multiple
- TFR objects.
-
- Notes
- -----
- .. versionadded:: 0.9.0
- """
-
- check_fname(fname, 'tfr', ('-tfr.h5',))
-
- logger.info('Reading %s ...' % fname)
- tfr_data = read_hdf5(fname, title='mnepython')
- for k, tfr in tfr_data:
- tfr['info'] = Info(tfr['info'])
-
- if condition is not None:
- tfr_dict = dict(tfr_data)
- if condition not in tfr_dict:
- keys = ['%s' % k for k in tfr_dict]
- raise ValueError('Cannot find condition ("{0}") in this file. '
- 'I can give you "{1}""'
- .format(condition, " or ".join(keys)))
- out = AverageTFR(**tfr_dict[condition])
- else:
- out = [AverageTFR(**d) for d in list(zip(*tfr_data))[1]]
- return out
-
-
-@verbose
-def tfr_morlet(inst, freqs, n_cycles, use_fft=False, return_itc=True, decim=1,
- n_jobs=1, picks=None, verbose=None):
- """Compute Time-Frequency Representation (TFR) using Morlet wavelets
-
- Parameters
- ----------
- inst : Epochs | Evoked
- The epochs or evoked object.
- freqs : ndarray, shape (n_freqs,)
- The frequencies in Hz.
- n_cycles : float | ndarray, shape (n_freqs,)
- The number of cycles globally or for each frequency.
- use_fft : bool
- The fft based convolution or not.
- return_itc : bool
- Return intertrial coherence (ITC) as well as averaged power.
- Must be ``False`` for evoked data.
- decim : int | slice
- To reduce memory usage, decimation factor after time-frequency
- decomposition.
- If `int`, returns tfr[..., ::decim].
- If `slice` returns tfr[..., decim].
- Note that decimation may create aliasing artifacts.
- Defaults to 1.
- n_jobs : int
- The number of jobs to run in parallel.
- picks : array-like of int | None
- The indices of the channels to plot. If None all available
- channels are displayed.
- verbose : bool, str, int, or None
- If not None, override default verbose level (see mne.verbose).
-
- Returns
- -------
- power : instance of AverageTFR
- The averaged power.
- itc : instance of AverageTFR
- The intertrial coherence (ITC). Only returned if return_itc
- is True.
-
- See Also
- --------
- tfr_multitaper, tfr_stockwell
- """
- decim = _check_decim(decim)
- data = _get_data(inst, return_itc)
- info = inst.info
-
- info, data, picks = _prepare_picks(info, data, picks)
- data = data[:, picks, :]
-
- power, itc = _induced_power_cwt(data, sfreq=info['sfreq'],
- frequencies=freqs,
- n_cycles=n_cycles, n_jobs=n_jobs,
- use_fft=use_fft, decim=decim,
- zero_mean=True)
- times = inst.times[decim].copy()
- nave = len(data)
- out = AverageTFR(info, power, times, freqs, nave, method='morlet-power')
- if return_itc:
- out = (out, AverageTFR(info, itc, times, freqs, nave,
- method='morlet-itc'))
- return out
-
-
-def _prepare_picks(info, data, picks):
- if picks is None:
- picks = pick_types(info, meg=True, eeg=True, ref_meg=False,
- exclude='bads')
- if np.array_equal(picks, np.arange(len(data))):
- picks = slice(None)
- else:
- info = pick_info(info, picks)
-
- return info, data, picks
-
-
-@verbose
-def _induced_power_mtm(data, sfreq, frequencies, time_bandwidth=4.0,
- use_fft=True, n_cycles=7, decim=1, n_jobs=1,
- zero_mean=True, verbose=None):
- """Compute time induced power and inter-trial phase-locking factor
-
- The time frequency decomposition is done with DPSS wavelets
-
- Parameters
- ----------
- data : np.ndarray, shape (n_epochs, n_channels, n_times)
- The input data.
- sfreq : float
- Sampling frequency.
- frequencies : np.ndarray, shape (n_frequencies,)
- Array of frequencies of interest
- time_bandwidth : float
- Time x (Full) Bandwidth product.
- The number of good tapers (low-bias) is chosen automatically based on
- this to equal floor(time_bandwidth - 1). Default is 4.0 (3 tapers).
- use_fft : bool
- Compute transform with fft based convolutions or temporal
- convolutions. Defaults to True.
- n_cycles : float | np.ndarray shape (n_frequencies,)
- Number of cycles. Fixed number or one per frequency. Defaults to 7.
- decim : int | slice
- To reduce memory usage, decimation factor after time-frequency
- decomposition.
- If `int`, returns tfr[..., ::decim].
- If `slice` returns tfr[..., decim].
- Note that decimation may create aliasing artifacts.
- Defaults to 1.
- n_jobs : int
- The number of CPUs used in parallel. All CPUs are used in -1.
- Requires joblib package. Defaults to 1.
- zero_mean : bool
- Make sure the wavelets are zero mean. Defaults to True.
- verbose : bool, str, int, or None
- If not None, override default verbose level (see mne.verbose).
-
- Returns
- -------
- power : np.ndarray, shape (n_channels, n_frequencies, n_times)
- Induced power. Squared amplitude of time-frequency coefficients.
- itc : np.ndarray, shape (n_channels, n_frequencies, n_times)
- Phase locking value.
- """
- decim = _check_decim(decim)
- n_epochs, n_channels, n_times = data[:, :, decim].shape
- logger.info('Data is %d trials and %d channels', n_epochs, n_channels)
- n_frequencies = len(frequencies)
- logger.info('Multitaper time-frequency analysis for %d frequencies',
- n_frequencies)
-
- # Precompute wavelets for given frequency range to save time
- Ws = _dpss_wavelet(sfreq, frequencies, n_cycles=n_cycles,
- time_bandwidth=time_bandwidth, zero_mean=zero_mean)
- n_taps = len(Ws)
- logger.info('Using %d tapers', n_taps)
- n_times_wavelets = Ws[0][0].shape[0]
- if data.shape[2] <= n_times_wavelets:
- warn('Time windows are as long or longer than the epoch. Consider '
- 'reducing n_cycles.')
- psd = np.zeros((n_channels, n_frequencies, n_times))
- itc = np.zeros((n_channels, n_frequencies, n_times))
- parallel, my_time_frequency, _ = parallel_func(_time_frequency,
- n_jobs)
- for m in range(n_taps):
- psd_itc = parallel(my_time_frequency(data[:, c, :], Ws[m], use_fft,
- decim)
- for c in range(n_channels))
- for c, (psd_c, itc_c) in enumerate(psd_itc):
- psd[c, :, :] += psd_c
- itc[c, :, :] += itc_c
- psd /= n_taps
- itc /= n_taps
- return psd, itc
-
-
-@verbose
-def tfr_multitaper(inst, freqs, n_cycles, time_bandwidth=4.0,
- use_fft=True, return_itc=True, decim=1,
- n_jobs=1, picks=None, verbose=None):
- """Compute Time-Frequency Representation (TFR) using DPSS wavelets
-
- Parameters
- ----------
- inst : Epochs | Evoked
- The epochs or evoked object.
- freqs : ndarray, shape (n_freqs,)
- The frequencies in Hz.
- n_cycles : float | ndarray, shape (n_freqs,)
- The number of cycles globally or for each frequency.
- The time-window length is thus T = n_cycles / freq.
- time_bandwidth : float, (optional)
- Time x (Full) Bandwidth product. Should be >= 2.0.
- Choose this along with n_cycles to get desired frequency resolution.
- The number of good tapers (least leakage from far away frequencies)
- is chosen automatically based on this to floor(time_bandwidth - 1).
- Default is 4.0 (3 good tapers).
- E.g., With freq = 20 Hz and n_cycles = 10, we get time = 0.5 s.
- If time_bandwidth = 4., then frequency smoothing is (4 / time) = 8 Hz.
- use_fft : bool
- The fft based convolution or not.
- Defaults to True.
- return_itc : bool
- Return intertrial coherence (ITC) as well as averaged power.
- Defaults to True.
- decim : int | slice
- To reduce memory usage, decimation factor after time-frequency
- decomposition.
- If `int`, returns tfr[..., ::decim].
- If `slice` returns tfr[..., decim].
- Note that decimation may create aliasing artifacts.
- Defaults to 1.
- n_jobs : int
- The number of jobs to run in parallel. Defaults to 1.
- picks : array-like of int | None
- The indices of the channels to plot. If None all available
- channels are displayed.
- verbose : bool, str, int, or None
- If not None, override default verbose level (see mne.verbose).
-
- Returns
- -------
- power : AverageTFR
- The averaged power.
- itc : AverageTFR
- The intertrial coherence (ITC). Only returned if return_itc
- is True.
-
- See Also
- --------
- tfr_multitaper, tfr_stockwell
-
- Notes
- -----
- .. versionadded:: 0.9.0
- """
- decim = _check_decim(decim)
- data = _get_data(inst, return_itc)
- info = inst.info
-
- info, data, picks = _prepare_picks(info, data, picks)
- data = data = data[:, picks, :]
-
- power, itc = _induced_power_mtm(data, sfreq=info['sfreq'],
- frequencies=freqs, n_cycles=n_cycles,
- time_bandwidth=time_bandwidth,
- use_fft=use_fft, decim=decim,
- n_jobs=n_jobs, zero_mean=True,
- verbose='INFO')
- times = inst.times[decim].copy()
- nave = len(data)
- out = AverageTFR(info, power, times, freqs, nave,
- method='mutlitaper-power')
- if return_itc:
- out = (out, AverageTFR(info, itc, times, freqs, nave,
- method='mutlitaper-itc'))
- return out
-
-
-def combine_tfr(all_tfr, weights='nave'):
- """Merge AverageTFR data by weighted addition
+def combine_tfr(all_tfr, weights='nave'):
+ """Merge AverageTFR data by weighted addition.
Create a new AverageTFR instance, using a combination of the supplied
instances as its data. By default, the mean (weighted by trials) is used.
@@ -1469,12 +1628,93 @@ def combine_tfr(all_tfr, weights='nave'):
for t_ in all_tfr[1:])))
tfr.info['bads'] = bads
+    # XXX : should be refactored with combine_evoked function
tfr.data = sum(w * t_.data for w, t_ in zip(weights, all_tfr))
tfr.nave = max(int(1. / sum(w ** 2 / e.nave
for w, e in zip(weights, all_tfr))), 1)
return tfr
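
For reference, a minimal sketch of the weighted merge that ``combine_tfr``
performs; the channel names, shapes, and values below are illustrative only:

    import numpy as np
    import mne
    from mne.time_frequency import AverageTFR
    from mne.time_frequency.tfr import combine_tfr

    info = mne.create_info(['MEG 001', 'MEG 002'], sfreq=1000.,
                           ch_types='grad')
    times = np.arange(50) / 1000.
    freqs = np.arange(8., 13.)
    rng = np.random.RandomState(0)
    tfr1 = AverageTFR(info, rng.rand(2, 5, 50), times, freqs, nave=20)
    tfr2 = AverageTFR(info, rng.rand(2, 5, 50), times, freqs, nave=10)

    # With weights='nave', tfr1 contributes 2/3 and tfr2 contributes 1/3
    combined = combine_tfr([tfr1, tfr2], weights='nave')
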
+# Utils
+
+
+def _get_data(inst, return_itc):
+ """Get data from Epochs or Evoked instance as epochs x ch x time"""
+ from ..epochs import _BaseEpochs
+ from ..evoked import Evoked
+ if not isinstance(inst, (_BaseEpochs, Evoked)):
+ raise TypeError('inst must be Epochs or Evoked')
+ if isinstance(inst, _BaseEpochs):
+ data = inst.get_data()
+ else:
+ if return_itc:
+ raise ValueError('return_itc must be False for evoked data')
+ data = inst.data[np.newaxis, ...].copy()
+ return data
+
+
+def _prepare_picks(info, data, picks):
+ if picks is None:
+ picks = pick_types(info, meg=True, eeg=True, ref_meg=False,
+ exclude='bads')
+ if np.array_equal(picks, np.arange(len(data))):
+ picks = slice(None)
+ else:
+ info = pick_info(info, picks)
+
+ return info, data, picks
+
+
+def _centered(arr, newsize):
+ """Aux Function to center data"""
+ # Return the center newsize portion of the array.
+ newsize = np.asarray(newsize)
+ currsize = np.array(arr.shape)
+ startind = (currsize - newsize) // 2
+ endind = startind + newsize
+ myslice = [slice(startind[k], endind[k]) for k in range(len(endind))]
+ return arr[tuple(myslice)]
+
+
+def _preproc_tfr(data, times, freqs, tmin, tmax, fmin, fmax, mode,
+ baseline, vmin, vmax, dB, sfreq):
+ """Aux Function to prepare tfr computation"""
+ from ..viz.utils import _setup_vmin_vmax
+
+ copy = baseline is not None
+ data = rescale(data, times, baseline, mode, copy=copy)
+
+ # crop time
+ itmin, itmax = None, None
+ idx = np.where(_time_mask(times, tmin, tmax, sfreq=sfreq))[0]
+ if tmin is not None:
+ itmin = idx[0]
+ if tmax is not None:
+ itmax = idx[-1] + 1
+
+ times = times[itmin:itmax]
+
+ # crop freqs
+ ifmin, ifmax = None, None
+ idx = np.where(_time_mask(freqs, fmin, fmax, sfreq=sfreq))[0]
+ if fmin is not None:
+ ifmin = idx[0]
+ if fmax is not None:
+ ifmax = idx[-1] + 1
+
+ freqs = freqs[ifmin:ifmax]
+
+ # crop data
+ data = data[:, ifmin:ifmax, itmin:itmax]
+
+ times *= 1e3
+ if dB:
+ data = 10 * np.log10((data * data.conj()).real)
+
+ vmin, vmax = _setup_vmin_vmax(data, vmin, vmax)
+ return data, times, freqs, vmin, vmax
+
+
def _check_decim(decim):
""" aux function checking the decim parameter """
if isinstance(decim, int):
@@ -1483,3 +1723,92 @@ def _check_decim(decim):
raise TypeError('`decim` must be int or slice, got %s instead'
% type(decim))
return decim
+
+
+# i/o
+
+
+def write_tfrs(fname, tfr, overwrite=False):
+ """Write a TFR dataset to hdf5.
+
+ Parameters
+ ----------
+ fname : string
+ The file name, which should end with -tfr.h5
+ tfr : AverageTFR instance, or list of AverageTFR instances
+ The TFR dataset, or list of TFR datasets, to save in one file.
+        Note: if ``.comment`` is None, a name will be generated on the fly,
+        based on the order in which the TFR objects are passed.
+ overwrite : bool
+ If True, overwrite file (if it exists). Defaults to False.
+
+ See Also
+ --------
+ read_tfrs
+
+ Notes
+ -----
+ .. versionadded:: 0.9.0
+ """
+ out = []
+ if not isinstance(tfr, (list, tuple)):
+ tfr = [tfr]
+ for ii, tfr_ in enumerate(tfr):
+ comment = ii if tfr_.comment is None else tfr_.comment
+ out.append(_prepare_write_tfr(tfr_, condition=comment))
+ write_hdf5(fname, out, overwrite=overwrite, title='mnepython')
+
+
+def _prepare_write_tfr(tfr, condition):
+ """Aux function"""
+ return (condition, dict(times=tfr.times, freqs=tfr.freqs,
+ data=tfr.data, info=tfr.info,
+ nave=tfr.nave, comment=tfr.comment,
+ method=tfr.method))
+
+
+def read_tfrs(fname, condition=None):
+ """
+ Read TFR datasets from hdf5 file.
+
+ Parameters
+ ----------
+ fname : string
+        The file name, which should end with -tfr.h5.
+ condition : int or str | list of int or str | None
+ The condition to load. If None, all conditions will be returned.
+ Defaults to None.
+
+ See Also
+ --------
+ write_tfrs
+
+ Returns
+ -------
+ tfrs : list of instances of AverageTFR | instance of AverageTFR
+ Depending on `condition` either the TFR object or a list of multiple
+ TFR objects.
+
+ Notes
+ -----
+ .. versionadded:: 0.9.0
+ """
+
+ check_fname(fname, 'tfr', ('-tfr.h5',))
+
+ logger.info('Reading %s ...' % fname)
+ tfr_data = read_hdf5(fname, title='mnepython')
+ for k, tfr in tfr_data:
+ tfr['info'] = Info(tfr['info'])
+
+ if condition is not None:
+ tfr_dict = dict(tfr_data)
+ if condition not in tfr_dict:
+ keys = ['%s' % k for k in tfr_dict]
+            raise ValueError('Cannot find condition ("{0}") in this file. '
+                             'The file contains "{1}".'
+                             .format(condition, " or ".join(keys)))
+ out = AverageTFR(**tfr_dict[condition])
+ else:
+ out = [AverageTFR(**d) for d in list(zip(*tfr_data))[1]]
+ return out
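
A minimal round trip through these i/o helpers might look as follows; the
toy data and file name are illustrative, and the HDF5 backend requires h5py:

    import numpy as np
    import mne
    from mne.time_frequency import AverageTFR, write_tfrs, read_tfrs

    info = mne.create_info(['MEG 001', 'MEG 002'], sfreq=1000.,
                           ch_types='grad')
    data = np.random.RandomState(42).rand(2, 5, 100)  # ch x freq x time
    tfr = AverageTFR(info, data, times=np.arange(100) / 1000.,
                     freqs=np.arange(8., 13.), nave=15, comment='toy')

    write_tfrs('example-tfr.h5', tfr, overwrite=True)
    tfr_read = read_tfrs('example-tfr.h5', condition='toy')
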
diff --git a/mne/transforms.py b/mne/transforms.py
index 9437cf5..30c5529 100644
--- a/mne/transforms.py
+++ b/mne/transforms.py
@@ -22,8 +22,6 @@ from .externals.six import string_types
# right/anterior/superior:
als_ras_trans = np.array([[0, -1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0],
[0, 0, 0, 1]])
-# simultaneously convert [m] to [mm]:
-als_ras_trans_mm = als_ras_trans * [0.001, 0.001, 0.001, 1]
_str_to_frame = dict(meg=FIFF.FIFFV_COORD_DEVICE,
@@ -104,6 +102,16 @@ class Transform(dict):
def to_str(self):
return _coord_frame_name(self['to'])
+ def save(self, fname):
+ """Save the transform as -trans.fif file
+
+ Parameters
+ ----------
+ fname : str
+ The name of the file, which should end in '-trans.fif'.
+ """
+ write_trans(fname, self)
+
def _coord_frame_name(cframe):
"""Map integers to human-readable (verbose) names"""
diff --git a/mne/utils.py b/mne/utils.py
index 380cbc8..d8773a4 100644
--- a/mne/utils.py
+++ b/mne/utils.py
@@ -9,6 +9,8 @@ from __future__ import print_function
import atexit
from distutils.version import LooseVersion
from functools import wraps
+import ftplib
+from functools import partial
import hashlib
import inspect
import json
@@ -24,8 +26,8 @@ import subprocess
import sys
import tempfile
import time
+import traceback
import warnings
-import ftplib
import numpy as np
from scipy import linalg, sparse
@@ -34,7 +36,7 @@ from .externals.six.moves import urllib
from .externals.six import string_types, StringIO, BytesIO
from .externals.decorator import decorator
-from .fixes import _get_args, partial
+from .fixes import _get_args
logger = logging.getLogger('mne') # one selection here used across mne-python
logger.propagate = False # don't propagate (in case of multiple imports)
@@ -59,24 +61,39 @@ def nottest(f):
return f
+# # # WARNING # # #
+# This list must also be updated in doc/_templates/class.rst if it is
+# changed here!
+_doc_special_members = ('__contains__', '__getitem__', '__iter__', '__len__',
+ '__call__', '__add__', '__sub__', '__mul__', '__div__',
+ '__neg__', '__hash__')
+
###############################################################################
# RANDOM UTILITIES
-def _check_copy_dep(inst, copy, kind='inst', default=False):
- """Check for copy deprecation for 0.13 and 0.14"""
+def _explain_exception(start=-1, stop=None, prefix='> '):
+ """Explain an exception."""
+ # start=-1 means "only the most recent caller"
+ etype, value, tb = sys.exc_info()
+ string = traceback.format_list(traceback.extract_tb(tb)[start:stop])
+ string = (''.join(string).split('\n') +
+ traceback.format_exception_only(etype, value))
+ string = ':\n' + prefix + ('\n' + prefix).join(string)
+ return string
+
+
+def _check_copy_dep(inst, copy, kind='inst'):
+ """Check for copy deprecation for 0.14"""
# For methods with copy=False default, we only need one release cycle
# for deprecation (0.13). For copy=True, we first need to go to copy=False
# (one cycle; 0.13) then remove copy altogether (one cycle; 0.14).
- if copy or (copy is None and default is True):
- remove_version = '0.14' if default is True else '0.13'
- warn('The copy parameter is deprecated and will be removed in %s. '
- 'In 0.13 the behavior will be to operate in-place '
- '(like copy=False). In 0.12 the default is copy=%s. '
- 'Use %s.copy() if necessary.'
- % (remove_version, default, kind), DeprecationWarning)
- if copy is None:
- copy = default
+ if copy:
+ warn('The copy parameter is deprecated and will be removed in 0.14. '
+ 'In 0.13 the default is copy=False. Use %s.copy() if necessary.'
+ % (kind,), DeprecationWarning)
+ elif copy is None:
+ copy = False
return inst.copy() if copy else inst
@@ -146,6 +163,46 @@ def object_hash(x, h=None):
return int(h.hexdigest(), 16)
+def object_size(x):
+ """Estimate the size of a reasonable python object
+
+ Parameters
+ ----------
+ x : object
+ Object to approximate the size of.
+ Can be anything comprised of nested versions of:
+ {dict, list, tuple, ndarray, str, bytes, float, int, None}.
+
+ Returns
+ -------
+ size : int
+ The estimated size in bytes of the object.
+ """
+    # Note: this will not process object arrays properly (since those only
+    # hold references)
+ if isinstance(x, (bytes, string_types, int, float, type(None))):
+ size = sys.getsizeof(x)
+ elif isinstance(x, np.ndarray):
+ # On newer versions of NumPy, just doing sys.getsizeof(x) works,
+ # but on older ones you always get something small :(
+ size = sys.getsizeof(np.array([])) + x.nbytes
+ elif isinstance(x, np.generic):
+ size = x.nbytes
+ elif isinstance(x, dict):
+ size = sys.getsizeof(x)
+ for key, value in x.items():
+ size += object_size(key)
+ size += object_size(value)
+ elif isinstance(x, (list, tuple)):
+ size = sys.getsizeof(x) + sum(object_size(xx) for xx in x)
+ elif sparse.isspmatrix_csc(x) or sparse.isspmatrix_csr(x):
+ size = sum(sys.getsizeof(xx)
+ for xx in [x, x.data, x.indices, x.indptr])
+ else:
+ raise RuntimeError('unsupported type: %s (%s)' % (type(x), x))
+ return size
+
+
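
A quick sanity check of ``object_size`` on a nested container; exact byte
counts are platform- and version-dependent, so treat the output as indicative:

    import numpy as np
    from mne.utils import object_size

    x = dict(a=np.zeros(1000), b=[1, 2.5, 'three'], c=None)
    print(object_size(x))  # dominated by the 8000-byte float64 array
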
def object_diff(a, b, pre=''):
"""Compute all differences between two python variables
@@ -287,24 +344,27 @@ def warn(message, category=RuntimeWarning):
"""
import mne
root_dir = op.dirname(mne.__file__)
- stacklevel = 1
frame = None
stack = inspect.stack()
last_fname = ''
for fi, frame in enumerate(stack):
- fname = frame[1]
- del frame
+ fname, lineno = frame[1:3]
if fname == '<string>' and last_fname == 'utils.py': # in verbose dec
last_fname = fname
continue
# treat tests as scripts
- if not fname.startswith(root_dir) or \
+ # and don't capture unittest/case.py (assert_raises)
+ if not (fname.startswith(root_dir) or
+ ('unittest' in fname and 'case' in fname)) or \
op.basename(op.dirname(fname)) == 'tests':
- stacklevel = fi + 1
break
last_fname = op.basename(fname)
- del stack
- warnings.warn(message, category, stacklevel=stacklevel)
+ if logger.level <= logging.WARN:
+ # We need to use this instead of warn(message, category, stacklevel)
+ # because we move out of the MNE stack, so warnings won't properly
+ # recognize the module name (and our warnings.simplefilter will fail)
+ warnings.warn_explicit(message, category, fname, lineno,
+ 'mne', globals().get('__warningregistry__', {}))
logger.warning(message)
@@ -468,6 +528,26 @@ def _reject_data_segments(data, reject, flat, decim, info, tstep):
return data, drop_inds
+def _get_inst_data(inst):
+ """get data from MNE object instance like Raw, Epochs or Evoked.
+ Returns a view, not a copy!"""
+ from .io.base import _BaseRaw
+ from .epochs import _BaseEpochs
+ from . import Evoked
+ from .time_frequency.tfr import _BaseTFR
+
+ if isinstance(inst, (_BaseRaw, _BaseEpochs)):
+ if not inst.preload:
+ inst.load_data()
+ return inst._data
+ elif isinstance(inst, (Evoked, _BaseTFR)):
+ return inst.data
+ else:
+ raise TypeError('The argument must be an instance of Raw, Epochs, '
+ 'Evoked, EpochsTFR or AverageTFR, got {0}.'.format(
+ type(inst)))
+
+
class _FormatDict(dict):
"""Helper for pformat()"""
def __missing__(self, key):
@@ -703,6 +783,22 @@ def requires_nibabel(vox2ras_tkr=False):
'Requires nibabel%s' % extra)
+def buggy_mkl_svd(function):
+ """Decorator for tests that make calls to SVD and intermittently fail"""
+ @wraps(function)
+ def dec(*args, **kwargs):
+ try:
+ return function(*args, **kwargs)
+ except np.linalg.LinAlgError as exp:
+ if 'SVD did not converge' in str(exp):
+ from nose.plugins.skip import SkipTest
+ msg = 'Intel MKL SVD convergence error detected, skipping test'
+ warn(msg)
+ raise SkipTest(msg)
+ raise
+ return dec
+
+
def requires_version(library, min_version):
"""Helper for testing"""
return np.testing.dec.skipif(not check_version(library, min_version),
@@ -710,8 +806,9 @@ def requires_version(library, min_version):
% (library, min_version))
-def requires_module(function, name, call):
+def requires_module(function, name, call=None):
"""Decorator to skip test if package is not available"""
+ call = ('import %s' % name) if call is None else call
try:
from nose.plugins.skip import SkipTest
except ImportError:
@@ -731,6 +828,177 @@ def requires_module(function, name, call):
return dec
+def copy_doc(source):
+ """Decorator to copy the docstring from another function.
+
+    The docstring of the source function is prepended to the docstring of
+    the function wrapped by this decorator.
+
+ This is useful when inheriting from a class and overloading a method. This
+ decorator can be used to copy the docstring of the original method.
+
+ Parameters
+ ----------
+ source : function
+ Function to copy the docstring from
+
+ Returns
+ -------
+ wrapper : function
+ The decorated function
+
+ Examples
+ --------
+ >>> class A:
+ ... def m1():
+ ... '''Docstring for m1'''
+ ... pass
+ >>> class B (A):
+ ... @copy_doc(A.m1)
+ ... def m1():
+ ... ''' this gets appended'''
+ ... pass
+ >>> print(B.m1.__doc__)
+ Docstring for m1 this gets appended
+ """
+ def wrapper(func):
+ if source.__doc__ is None or len(source.__doc__) == 0:
+ raise ValueError('Cannot copy docstring: docstring was empty.')
+ doc = source.__doc__
+ if func.__doc__ is not None:
+ doc += func.__doc__
+ func.__doc__ = doc
+ return func
+ return wrapper
+
+
+def copy_function_doc_to_method_doc(source):
+ """Use the docstring from a function as docstring for a method.
+
+    The docstring of the source function is prepended to the docstring of
+    the function wrapped by this decorator. Additionally, the first parameter
+ specified in the docstring of the source function is removed in the new
+ docstring.
+
+ This decorator is useful when implementing a method that just calls a
+    function. This pattern is prevalent in, for example, the plotting
+    functions of MNE.
+
+ Parameters
+ ----------
+ source : function
+ Function to copy the docstring from
+
+ Returns
+ -------
+ wrapper : function
+ The decorated method
+
+ Examples
+ --------
+ >>> def plot_function(object, a, b):
+ ... '''Docstring for plotting function.
+ ...
+ ... Parameters
+ ... ----------
+ ... object : instance of object
+ ... The object to plot
+ ... a : int
+ ... Some parameter
+ ... b : int
+ ... Some parameter
+ ... '''
+ ... pass
+ ...
+ >>> class A:
+ ... @copy_function_doc_to_method_doc(plot_function)
+ ... def plot(self, a, b):
+ ... '''
+ ... Notes
+ ... -----
+ ... .. versionadded:: 0.13.0
+ ... '''
+ ... plot_function(self, a, b)
+ >>> print(A.plot.__doc__)
+ Docstring for plotting function.
+ <BLANKLINE>
+ Parameters
+ ----------
+ a : int
+ Some parameter
+ b : int
+ Some parameter
+ <BLANKLINE>
+ Notes
+ -----
+ .. versionadded:: 0.13.0
+ <BLANKLINE>
+
+ Notes
+ -----
+ The parsing performed is very basic and will break easily on docstrings
+ that are not formatted exactly according to the ``numpydoc`` standard.
+ Always inspect the resulting docstring when using this decorator.
+ """
+ def wrapper(func):
+ doc = source.__doc__.split('\n')
+
+ # Find parameter block
+ for line, text in enumerate(doc[:-2]):
+ if (text.strip() == 'Parameters' and
+ doc[line + 1].strip() == '----------'):
+ parameter_block = line
+ break
+ else:
+ # No parameter block found
+ raise ValueError('Cannot copy function docstring: no parameter '
+ 'block found. To simply copy the docstring, use '
+ 'the @copy_doc decorator instead.')
+
+ # Find first parameter
+ for line, text in enumerate(doc[parameter_block:], parameter_block):
+ if ':' in text:
+ first_parameter = line
+ parameter_indentation = len(text) - len(text.lstrip(' '))
+ break
+ else:
+ raise ValueError('Cannot copy function docstring: no parameters '
+ 'found. To simply copy the docstring, use the '
+ '@copy_doc decorator instead.')
+
+ # Find end of first parameter
+ for line, text in enumerate(doc[first_parameter + 1:],
+ first_parameter + 1):
+ # Ignore empty lines
+ if len(text.strip()) == 0:
+ continue
+
+ line_indentation = len(text) - len(text.lstrip(' '))
+ if line_indentation <= parameter_indentation:
+                # Reached the end of the first parameter
+ first_parameter_end = line
+
+                # If only one parameter is defined, remove the Parameters
+ # heading as well
+ if ':' not in text:
+ first_parameter = parameter_block
+
+ break
+ else:
+ # End of docstring reached
+ first_parameter_end = line
+ first_parameter = parameter_block
+
+ # Copy the docstring, but remove the first parameter
+ doc = ('\n'.join(doc[:first_parameter]) + '\n' +
+ '\n'.join(doc[first_parameter_end:]))
+ if func.__doc__ is not None:
+ doc += func.__doc__
+ func.__doc__ = doc
+ return func
+ return wrapper
+
+
_pandas_call = """
import pandas
version = LooseVersion(pandas.__version__)
@@ -794,10 +1062,8 @@ requires_fs_or_nibabel = partial(requires_module, name='nibabel or Freesurfer',
requires_tvtk = partial(requires_module, name='TVTK',
call='from tvtk.api import tvtk')
-requires_statsmodels = partial(requires_module, name='statsmodels',
- call='import statsmodels')
-requires_patsy = partial(requires_module, name='patsy',
- call='import patsy')
+requires_statsmodels = partial(requires_module, name='statsmodels')
+requires_patsy = partial(requires_module, name='patsy')
requires_pysurfer = partial(requires_module, name='PySurfer',
call='from surfer import Brain')
requires_PIL = partial(requires_module, name='PIL',
@@ -810,11 +1076,10 @@ requires_ftp = partial(
requires_module, name='ftp downloading capability',
call='if int(os.environ.get("MNE_SKIP_FTP_TESTS", 0)):\n'
' raise ImportError')
-requires_nitime = partial(requires_module, name='nitime',
- call='import nitime')
-requires_traits = partial(requires_module, name='traits',
- call='import traits')
-requires_h5py = partial(requires_module, name='h5py', call='import h5py')
+requires_nitime = partial(requires_module, name='nitime')
+requires_traits = partial(requires_module, name='traits')
+requires_h5py = partial(requires_module, name='h5py')
+requires_numpydoc = partial(requires_module, name='numpydoc')
def check_version(library, min_version):
@@ -891,7 +1156,7 @@ def run_subprocess(command, verbose=None, *args, **kwargs):
Parameters
----------
- command : list of str
+ command : list of str | str
Command to run as subprocess (see subprocess.Popen documentation).
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
@@ -925,12 +1190,19 @@ def run_subprocess(command, verbose=None, *args, **kwargs):
'starting with a tilde ("~") character. Such paths are not '
'interpreted correctly from within Python. It is recommended '
'that you use "$HOME" instead of "~".')
-
- logger.info("Running subprocess: %s" % ' '.join(command))
+ if isinstance(command, string_types):
+ command_str = command
+ else:
+ command_str = ' '.join(command)
+ logger.info("Running subprocess: %s" % command_str)
try:
p = subprocess.Popen(command, *args, **kwargs)
except Exception:
- logger.error('Command not found: %s' % (command[0],))
+ if isinstance(command, string_types):
+ command_name = command.split()[0]
+ else:
+ command_name = command[0]
+ logger.error('Command not found: %s' % command_name)
raise
stdout_, stderr = p.communicate()
stdout_ = '' if stdout_ is None else stdout_.decode('utf-8')
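
With this change a single string is accepted in addition to a list of
arguments; an illustrative sketch (the command is arbitrary, and the string
form follows ``subprocess.Popen`` semantics, e.g. with ``shell=True``):

    from mne.utils import run_subprocess

    stdout, stderr = run_subprocess(['echo', 'hello'])
    stdout, stderr = run_subprocess('echo hello', shell=True)
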
@@ -1094,6 +1366,8 @@ def _get_extra_data_path(home_dir=None):
"""Get path to extra data (config, tables, etc.)"""
global _temp_home_dir
if home_dir is None:
+ home_dir = os.environ.get('_MNE_FAKE_HOME_DIR')
+ if home_dir is None:
# this has been checked on OSX64, Linux64, and Win32
if 'nt' == os.name.lower():
home_dir = os.getenv('APPDATA')
@@ -1158,7 +1432,7 @@ def set_cache_dir(cache_dir):
if cache_dir is not None and not op.exists(cache_dir):
raise IOError('Directory %s does not exist' % cache_dir)
- set_config('MNE_CACHE_DIR', cache_dir)
+ set_config('MNE_CACHE_DIR', cache_dir, set_env=False)
def set_memmap_min_size(memmap_min_size):
@@ -1178,7 +1452,7 @@ def set_memmap_min_size(memmap_min_size):
raise ValueError('The size has to be given in kilo-, mega-, or '
'gigabytes, e.g., 100K, 500M, 1G.')
- set_config('MNE_MEMMAP_MIN_SIZE', memmap_min_size)
+ set_config('MNE_MEMMAP_MIN_SIZE', memmap_min_size, set_env=False)
# List the known configuration values
@@ -1193,6 +1467,7 @@ known_config_types = (
'MNE_DATASETS_MISC_PATH',
'MNE_DATASETS_SAMPLE_PATH',
'MNE_DATASETS_SOMATO_PATH',
+ 'MNE_DATASETS_MULTIMODAL_PATH',
'MNE_DATASETS_SPM_FACE_DATASETS_TESTS',
'MNE_DATASETS_SPM_FACE_PATH',
'MNE_DATASETS_TESTING_PATH',
@@ -1204,6 +1479,7 @@ known_config_types = (
'MNE_SKIP_TESTING_DATASET_TESTS',
'MNE_STIM_CHANNEL',
'MNE_USE_CUDA',
+ 'MNE_SKIP_FS_FLASH_CALL',
'SUBJECTS_DIR',
)
@@ -1213,6 +1489,22 @@ known_config_wildcards = (
)
+def _load_config(config_path, raise_error=False):
+ """Helper to safely load a config file"""
+ with open(config_path, 'r') as fid:
+ try:
+ config = json.load(fid)
+ except ValueError:
+ # No JSON object could be decoded --> corrupt file?
+ msg = ('The MNE-Python config file (%s) is not a valid JSON '
+ 'file and might be corrupted' % config_path)
+ if raise_error:
+ raise RuntimeError(msg)
+ warn(msg)
+ config = dict()
+ return config
+
+
def get_config(key=None, default=None, raise_error=False, home_dir=None):
"""Read mne(-python) preference from env, then mne-python config
@@ -1254,16 +1546,14 @@ def get_config(key=None, default=None, raise_error=False, home_dir=None):
key_found = False
val = default
else:
- with open(config_path, 'r') as fid:
- config = json.load(fid)
- if key is None:
- return config
+ config = _load_config(config_path)
+ if key is None:
+ return config
key_found = key in config
val = config.get(key, default)
-
if not key_found and raise_error is True:
meth_1 = 'os.environ["%s"] = VALUE' % key
- meth_2 = 'mne.utils.set_config("%s", VALUE)' % key
+ meth_2 = 'mne.utils.set_config("%s", VALUE, set_env=True)' % key
raise KeyError('Key "%s" not found in environment or in the '
'mne-python config file: %s '
'Try either:'
@@ -1275,7 +1565,7 @@ def get_config(key=None, default=None, raise_error=False, home_dir=None):
return val
-def set_config(key, value, home_dir=None):
+def set_config(key, value, home_dir=None, set_env=None):
"""Set mne-python preference in config
Parameters
@@ -1289,6 +1579,9 @@ def set_config(key, value, home_dir=None):
home_dir : str | None
The folder that contains the .mne config folder.
If None, it is found automatically.
+ set_env : bool
+ If True, update :data:`os.environ` in addition to updating the
+ MNE-Python config file.
See Also
--------
@@ -1305,20 +1598,28 @@ def set_config(key, value, home_dir=None):
if key not in known_config_types and not \
any(k in key for k in known_config_wildcards):
warn('Setting non-standard config type: "%s"' % key)
+ if set_env is None:
+ warnings.warn('set_env defaults to False in 0.13 but will change '
+ 'to True in 0.14, set it explicitly to avoid this '
+ 'warning', DeprecationWarning)
+ set_env = False
# Read all previous values
config_path = get_config_path(home_dir=home_dir)
if op.isfile(config_path):
- with open(config_path, 'r') as fid:
- config = json.load(fid)
+ config = _load_config(config_path, raise_error=True)
else:
config = dict()
logger.info('Attempting to create new mne-python configuration '
'file:\n%s' % config_path)
if value is None:
config.pop(key, None)
+ if set_env and key in os.environ:
+ del os.environ[key]
else:
config[key] = value
+ if set_env:
+ os.environ[key] = value
# Write all values. This may fail if the default directory is not
# writeable.
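
Passing ``set_env`` explicitly avoids the deprecation warning added above.
An illustrative sketch (note that this writes to your user config file):

    import os
    import mne

    mne.set_config('MNE_STIM_CHANNEL', 'STI101', set_env=True)
    assert os.environ['MNE_STIM_CHANNEL'] == 'STI101'
    print(mne.get_config('MNE_STIM_CHANNEL'))  # env is checked first
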
@@ -1643,6 +1944,44 @@ def sizeof_fmt(num):
return '1 byte'
+class SizeMixin(object):
+ """Class to estimate MNE object sizes"""
+ @property
+ def _size(self):
+ """Estimate of the object size"""
+ try:
+ size = object_size(self.info)
+ except Exception:
+ warn('Could not get size for self.info')
+ return -1
+ if hasattr(self, 'data'):
+ size += object_size(self.data)
+ elif hasattr(self, '_data'):
+ size += object_size(self._data)
+ return size
+
+ def __hash__(self):
+ """Hash the object
+
+ Returns
+ -------
+ hash : int
+ The hash
+ """
+ from .evoked import Evoked
+ from .epochs import _BaseEpochs
+ from .io.base import _BaseRaw
+ if isinstance(self, Evoked):
+ return object_hash(dict(info=self.info, data=self.data))
+ elif isinstance(self, (_BaseEpochs, _BaseRaw)):
+ if not self.preload:
+ raise RuntimeError('Cannot hash %s unless data are loaded'
+ % self.__class__.__name__)
+ return object_hash(dict(info=self.info, data=self._data))
+ else:
+ raise RuntimeError('Hashing unknown object type: %s' % type(self))
+
+
def _url_to_local_path(url, path):
"""Mirror a url path in a local destination (keeping folder structure)"""
destination = urllib.parse.urlparse(url).path
@@ -1654,7 +1993,7 @@ def _url_to_local_path(url, path):
return destination
-def _get_stim_channel(stim_channel, info):
+def _get_stim_channel(stim_channel, info, raise_error=True):
"""Helper to determine the appropriate stim_channel
First, 'MNE_STIM_CHANNEL', 'MNE_STIM_CHANNEL_1', 'MNE_STIM_CHANNEL_2', etc.
@@ -1692,17 +2031,19 @@ def _get_stim_channel(stim_channel, info):
if ch_count > 0:
return stim_channel
- if 'STI 014' in info['ch_names']:
+ if 'STI101' in info['ch_names']: # combination channel for newer systems
+ return ['STI101']
+ if 'STI 014' in info['ch_names']: # for older systems
return ['STI 014']
from .io.pick import pick_types
stim_channel = pick_types(info, meg=False, ref_meg=False, stim=True)
if len(stim_channel) > 0:
stim_channel = [info['ch_names'][ch_] for ch_ in stim_channel]
- return stim_channel
-
- raise ValueError("No stim channels found. Consider specifying them "
- "manually using the 'stim_channel' parameter.")
+ elif raise_error:
+ raise ValueError("No stim channels found. Consider specifying them "
+ "manually using the 'stim_channel' parameter.")
+ return stim_channel
def _check_fname(fname, overwrite=False, must_exist=False):
@@ -1790,24 +2131,6 @@ def _clean_names(names, remove_whitespace=False, before_dash=True):
return cleaned
-def clean_warning_registry():
- """Safe way to reset warnings """
- warnings.resetwarnings()
- reg = "__warningregistry__"
- bad_names = ['MovedModule'] # this is in six.py, and causes bad things
- for mod in list(sys.modules.values()):
- if mod.__class__.__name__ not in bad_names and hasattr(mod, reg):
- getattr(mod, reg).clear()
- # hack to deal with old scipy/numpy in tests
- if os.getenv('TRAVIS') == 'true' and sys.version.startswith('2.6'):
- warnings.simplefilter('default')
- try:
- np.rank([])
- except Exception:
- pass
- warnings.simplefilter('always')
-
-
def _check_type_picks(picks):
"""helper to guarantee type integrity of picks"""
err_msg = 'picks must be None, a list or an array of integers'
@@ -2091,8 +2414,7 @@ def grand_average(all_inst, interpolate_bads=True, drop_bads=True):
from .evoked import Evoked
from .time_frequency import AverageTFR
from .channels.channels import equalize_channels
- if not any([(all(isinstance(inst, t) for inst in all_inst)
- for t in (Evoked, AverageTFR))]):
+ if not all(isinstance(inst, (Evoked, AverageTFR)) for inst in all_inst):
raise ValueError("Not all input elements are Evoked or AverageTFR")
# Copy channels to leave the original evoked datasets intact.
@@ -2105,7 +2427,7 @@ def grand_average(all_inst, interpolate_bads=True, drop_bads=True):
else inst for inst in all_inst]
equalize_channels(all_inst) # apply equalize_channels
from .evoked import combine_evoked as combine
- elif isinstance(all_inst[0], AverageTFR):
+ else: # isinstance(all_inst[0], AverageTFR):
from .time_frequency.tfr import combine_tfr as combine
if drop_bads:
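
The simplified check accepts homogeneous lists of Evoked (or AverageTFR)
instances and rejects everything else; a sketch with synthetic evokeds
(channel names and data are illustrative):

    import numpy as np
    import mne

    info = mne.create_info(['EEG 001', 'EEG 002'], sfreq=100.,
                           ch_types='eeg')
    rng = np.random.RandomState(0)
    evokeds = [mne.EvokedArray(rng.randn(2, 50) * 1e-6, info, tmin=0.)
               for _ in range(3)]
    grand_ave = mne.grand_average(evokeds)  # an Evoked instance
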
@@ -2211,3 +2533,16 @@ def sys_info(fid=None, show_paths=False):
extra = ' {%s}%s' % (libs, extra)
out += '%s%s\n' % (version, extra)
print(out, end='', file=fid)
+
+
+class ETSContext(object):
+ """Add more meaningful message to errors generated by ETS Toolkit"""
+ def __enter__(self):
+ pass
+
+ def __exit__(self, type, value, traceback):
+        if (isinstance(value, SystemExit) and
+                isinstance(value.code, string_types) and
+                value.code.startswith(
+                    "This program needs access to the screen")):
+ value.code += ("\nThis can probably be solved by setting "
+ "ETS_TOOLKIT=qt4. On bash, type\n\n $ export "
+ "ETS_TOOLKIT=qt4\n\nand run the command again.")
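
``ETSContext`` only rewrites the ``SystemExit`` message that Mayavi/ETS
raises when no display is available, so typical usage is simply to wrap the
offending call:

    from mne.utils import ETSContext

    with ETSContext():
        pass  # e.g. mne.gui.coregistration() or another Mayavi call here
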
diff --git a/mne/viz/_3d.py b/mne/viz/_3d.py
index 4068918..42cfd43 100644
--- a/mne/viz/_3d.py
+++ b/mne/viz/_3d.py
@@ -12,6 +12,7 @@ from __future__ import print_function
# License: Simplified BSD
import base64
+from distutils.version import LooseVersion
from itertools import cycle
import os.path as op
import warnings
@@ -29,8 +30,6 @@ from ..transforms import (read_trans, _find_trans, apply_trans,
combine_transforms, _get_trans, _ensure_trans,
invert_transform, Transform)
from ..utils import get_subjects_dir, logger, _check_subject, verbose, warn
-from ..fixes import _get_args
-from ..defaults import _handle_default
from .utils import mne_analyze_colormap, _prepare_trellis, COLORS, plt_show
from ..externals.six import BytesIO
@@ -361,6 +360,7 @@ def plot_trans(info, trans='auto', subject=None, subjects_dir=None,
# determine points
meg_rrs, meg_tris = list(), list()
+ hpi_loc = list()
ext_loc = list()
car_loc = list()
eeg_loc = list()
@@ -408,15 +408,19 @@ def plot_trans(info, trans='auto', subject=None, subjects_dir=None,
meg_rrs = np.concatenate(meg_rrs, axis=0)
meg_tris = np.concatenate(meg_tris, axis=0)
if dig:
+ hpi_loc = np.array([d['r'] for d in info['dig']
+ if d['kind'] == FIFF.FIFFV_POINT_HPI])
ext_loc = np.array([d['r'] for d in info['dig']
if d['kind'] == FIFF.FIFFV_POINT_EXTRA])
car_loc = np.array([d['r'] for d in info['dig']
if d['kind'] == FIFF.FIFFV_POINT_CARDINAL])
if coord_frame == 'meg':
t = invert_transform(info['dev_head_t'])
+ hpi_loc = apply_trans(t, hpi_loc)
ext_loc = apply_trans(t, ext_loc)
car_loc = apply_trans(t, car_loc)
elif coord_frame == 'mri':
+ hpi_loc = apply_trans(head_mri_t, hpi_loc)
ext_loc = apply_trans(head_mri_t, ext_loc)
car_loc = apply_trans(head_mri_t, car_loc)
if len(car_loc) == len(ext_loc) == 0:
@@ -441,10 +445,10 @@ def plot_trans(info, trans='auto', subject=None, subjects_dir=None,
mesh.data.cell_data.normals = None
mlab.pipeline.surface(mesh, color=colors[key], opacity=alphas[key])
- datas = (eeg_loc, car_loc, ext_loc)
- colors = ((1., 0., 0.), (1., 1., 0.), (1., 0.5, 0.))
- alphas = (1.0, 0.5, 0.25)
- scales = (0.005, 0.015, 0.0075)
+ datas = (eeg_loc, hpi_loc, car_loc, ext_loc)
+ colors = ((1., 0., 0.), (0., 1., 0.), (1., 1., 0.), (1., 0.5, 0.))
+ alphas = (1.0, 0.5, 0.5, 0.25)
+ scales = (0.005, 0.015, 0.015, 0.0075)
for data, color, alpha, scale in zip(datas, colors, alphas, scales):
if len(data) > 0:
with warnings.catch_warnings(record=True): # traits
@@ -593,11 +597,13 @@ def _limits_to_control_points(clim, stc_data, colormap):
def plot_source_estimates(stc, subject=None, surface='inflated', hemi='lh',
- colormap='auto', time_label='time=%0.2f ms',
+ colormap='auto', time_label='auto',
smoothing_steps=10, transparent=None, alpha=1.0,
time_viewer=False, config_opts=None,
subjects_dir=None, figure=None, views='lat',
- colorbar=True, clim='auto'):
+ colorbar=True, clim='auto', cortex="classic",
+ size=800, background="black", foreground="white",
+ initial_time=None, time_unit=None):
"""Plot SourceEstimates with PySurfer
Note: PySurfer currently needs the SUBJECTS_DIR environment variable,
@@ -625,8 +631,10 @@ def plot_source_estimates(stc, subject=None, surface='inflated', hemi='lh',
be (n x 3) or (n x 4) array for with RGB or RGBA values between
0 and 255. If 'auto', either 'hot' or 'mne' will be chosen
based on whether 'lims' or 'pos_lims' are specified in `clim`.
- time_label : str
- How to print info about the time instant visualized.
+ time_label : str | callable | None
+ Format of the time label (a format string, a function that maps
+ floating point time values to strings, or None for no label). The
+ default is ``time=%0.2f ms``.
smoothing_steps : int
The amount of smoothing
transparent : bool | None
@@ -637,8 +645,7 @@ def plot_source_estimates(stc, subject=None, surface='inflated', hemi='lh',
time_viewer : bool
Display time viewer GUI.
config_opts : dict
- Keyword arguments for Brain initialization.
- See pysurfer.viz.Brain.
+ Deprecated parameter.
subjects_dir : str
The path to the freesurfer subjects reconstructions.
It corresponds to Freesurfer environment variable SUBJECTS_DIR.
@@ -666,21 +673,73 @@ def plot_source_estimates(stc, subject=None, surface='inflated', hemi='lh',
will be mirrored directly across zero during colormap
construction to obtain negative control points.
+ cortex : str or tuple
+        Specifies how binarized curvature values are rendered.
+        Either the name of a preset PySurfer cortex colorscheme (one of
+        'classic', 'bone', 'low_contrast', or 'high_contrast'), or the
+        name of a mayavi colormap, or a tuple with values (colormap, min,
+        max, reverse) to fully specify the curvature colors.
+    size : float or pair of floats
+        The size of the window, in pixels. Can be one number to specify
+ a square window, or the (width, height) of a rectangular window.
+ background : matplotlib color
+ Color of the background of the display window.
+ foreground : matplotlib color
+ Color of the foreground of the display window.
+ initial_time : float | None
+ The time to display on the plot initially. ``None`` to display the
+ first time sample (default).
+ time_unit : 's' | 'ms'
+ Whether time is represented in seconds (expected by PySurfer) or
+ milliseconds. The current default is 'ms', but will change to 's'
+ in MNE 0.14. To avoid a deprecation warning specify ``time_unit``
+ explicitly.
+
Returns
-------
brain : Brain
An instance of surfer.viz.Brain from PySurfer.
"""
+ import surfer
from surfer import Brain, TimeViewer
- config_opts = _handle_default('config_opts', config_opts)
-
import mayavi
- from mayavi import mlab
# import here to avoid circular import problem
from ..source_estimate import SourceEstimate
+ surfer_version = LooseVersion(surfer.__version__)
+ v06 = LooseVersion('0.6')
+ if surfer_version < v06:
+ raise ImportError("This function requires PySurfer 0.6 (you are "
+ "running version %s). You can update PySurfer "
+ "using:\n\n $ pip install -U pysurfer" %
+ surfer.__version__)
+
+ if time_unit is None:
+ if initial_time is not None:
+ warn("The time_unit parameter default will change from 'ms' to "
+ "'s' in MNE 0.14 and be removed in 0.15. To avoid this "
+ "warning specify the parameter explicitly.",
+ DeprecationWarning)
+ time_unit = 'ms'
+ elif time_unit not in ('s', 'ms'):
+ raise ValueError("time_unit needs to be 's' or 'ms', got %r" %
+ (time_unit,))
+
+ if initial_time is not None and surfer_version > v06:
+ kwargs = {'initial_time': initial_time}
+ initial_time = None # don't set it twice
+ else:
+ kwargs = {}
+
+ if time_label == 'auto':
+ if time_unit == 'ms':
+ time_label = 'time=%0.2f ms'
+ else:
+ def time_label(t):
+ return 'time=%0.2f ms' % (t * 1e3)
+
if not isinstance(stc, SourceEstimate):
raise ValueError('stc has to be a surface source estimate')
@@ -688,23 +747,16 @@ def plot_source_estimates(stc, subject=None, surface='inflated', hemi='lh',
raise ValueError('hemi has to be either "lh", "rh", "split", '
'or "both"')
- n_split = 2 if hemi == 'split' else 1
- n_views = 1 if isinstance(views, string_types) else len(views)
+ # check `figure` parameter (This will be performed by PySurfer > 0.6)
if figure is not None:
- # use figure with specified id or create new figure
if isinstance(figure, int):
- figure = mlab.figure(figure, size=(600, 600))
- # make sure it is of the correct type
- if not isinstance(figure, list):
+ # use figure with specified id
+ size_ = size if isinstance(size, (tuple, list)) else (size, size)
+ figure = [mayavi.mlab.figure(figure, size=size_)]
+ elif not isinstance(figure, (list, tuple)):
figure = [figure]
if not all(isinstance(f, mayavi.core.scene.Scene) for f in figure):
raise TypeError('figure must be a mayavi scene or list of scenes')
- # make sure we have the right number of figures
- n_fig = len(figure)
- if not n_fig == n_split * n_views:
- raise RuntimeError('`figure` must be a list with the same '
- 'number of elements as PySurfer plots that '
- 'will be created (%s)' % n_split * n_views)
# convert control points to locations in colormap
ctrl_pts, colormap = _limits_to_control_points(clim, stc.data, colormap)
@@ -728,13 +780,18 @@ def plot_source_estimates(stc, subject=None, surface='inflated', hemi='lh',
hemis = [hemi]
title = subject if len(hemis) > 1 else '%s - %s' % (subject, hemis[0])
- args = _get_args(Brain.__init__)
- kwargs = dict(title=title, figure=figure, config_opts=config_opts,
- subjects_dir=subjects_dir)
- if 'views' in args:
- kwargs['views'] = views
with warnings.catch_warnings(record=True): # traits warnings
- brain = Brain(subject, hemi, surface, **kwargs)
+ brain = Brain(subject, hemi=hemi, surf=surface, curv=True,
+ title=title, cortex=cortex, size=size,
+ background=background, foreground=foreground,
+ figure=figure, subjects_dir=subjects_dir,
+ views=views, config_opts=config_opts)
+
+ if time_unit == 's':
+ times = stc.times
+ else: # time_unit == 'ms'
+ times = 1e3 * stc.times
+
for hemi in hemis:
hemi_idx = 0 if hemi == 'lh' else 1
if hemi_idx == 0:
@@ -742,17 +799,18 @@ def plot_source_estimates(stc, subject=None, surface='inflated', hemi='lh',
else:
data = stc.data[len(stc.vertices[0]):]
vertices = stc.vertices[hemi_idx]
- time = 1e3 * stc.times
with warnings.catch_warnings(record=True): # traits warnings
brain.add_data(data, colormap=colormap, vertices=vertices,
- smoothing_steps=smoothing_steps, time=time,
+ smoothing_steps=smoothing_steps, time=times,
time_label=time_label, alpha=alpha, hemi=hemi,
- colorbar=colorbar)
+ colorbar=colorbar, **kwargs)
# scale colormap and set time (index) to display
brain.scale_data_colormap(fmin=scale_pts[0], fmid=scale_pts[1],
fmax=scale_pts[2], transparent=transparent)
+ if initial_time is not None:
+ brain.set_time(initial_time)
if time_viewer:
TimeViewer(brain)
return brain
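
Putting the new parameters together, a hedged sketch using the sample
dataset (requires PySurfer >= 0.6; the stc file follows the standard sample
layout):

    import mne

    data_path = mne.datasets.sample.data_path()
    stc = mne.read_source_estimate(data_path +
                                   '/MEG/sample/sample_audvis-meg')
    brain = mne.viz.plot_source_estimates(
        stc, subject='sample', hemi='lh', views='lat',
        subjects_dir=data_path + '/subjects',
        size=(800, 600), background='white', foreground='black',
        initial_time=0.1, time_unit='s')  # explicit time_unit, no warning
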
diff --git a/mne/viz/__init__.py b/mne/viz/__init__.py
index f0b4c32..e8cb319 100644
--- a/mne/viz/__init__.py
+++ b/mne/viz/__init__.py
@@ -13,12 +13,12 @@ from .misc import (plot_cov, plot_bem, plot_events, plot_source_spectrogram,
_get_presser, plot_dipole_amplitudes)
from .evoked import (plot_evoked, plot_evoked_image, plot_evoked_white,
plot_snr_estimate, plot_evoked_topo,
- plot_evoked_joint)
+ plot_evoked_joint, plot_compare_evokeds)
from .circle import plot_connectivity_circle, circular_layout
from .epochs import (plot_drop_log, plot_epochs, plot_epochs_psd,
plot_epochs_image)
from .raw import plot_raw, plot_raw_psd, plot_raw_psd_topo
-from .ica import plot_ica_scores, plot_ica_sources, plot_ica_overlay
-from .ica import _plot_sources_raw, _plot_sources_epochs
+from .ica import (plot_ica_scores, plot_ica_sources, plot_ica_overlay,
+ _plot_sources_raw, _plot_sources_epochs, plot_ica_properties)
from .montage import plot_montage
from .decoding import plot_gat_matrix, plot_gat_times
diff --git a/mne/viz/circle.py b/mne/viz/circle.py
index ffa5e73..f7362b4 100644
--- a/mne/viz/circle.py
+++ b/mne/viz/circle.py
@@ -16,7 +16,6 @@ import numpy as np
from .utils import plt_show
from ..externals.six import string_types
-from ..fixes import tril_indices, normalize_colors
def circular_layout(node_names, node_order, start_pos=90, start_between=True,
@@ -253,7 +252,7 @@ def plot_connectivity_circle(con, node_names, indices=None, n_lines=None,
if con.shape[0] != n_nodes or con.shape[1] != n_nodes:
raise ValueError('con has to be 1D or a square matrix')
# we use the lower-triangular part
- indices = tril_indices(n_nodes, -1)
+ indices = np.tril_indices(n_nodes, -1)
con = con[indices]
else:
raise ValueError('con has to be 1D or a square matrix')
@@ -392,8 +391,8 @@ def plot_connectivity_circle(con, node_names, indices=None, n_lines=None,
axes=axes)
if colorbar:
- norm = normalize_colors(vmin=vmin, vmax=vmax)
- sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)
+ sm = plt.cm.ScalarMappable(cmap=colormap,
+ norm=plt.Normalize(vmin, vmax))
sm.set_array(np.linspace(vmin, vmax))
cb = plt.colorbar(sm, ax=axes, use_gridspec=False,
shrink=colorbar_size,
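
Behavior is unchanged by the switch to ``np.tril_indices`` and
``plt.Normalize``; an illustrative call (node names and the random
connectivity matrix are made up):

    import numpy as np
    from mne.viz import plot_connectivity_circle

    con = np.random.RandomState(0).rand(4, 4)  # lower triangle is used
    node_names = ['LF', 'RF', 'LP', 'RP']
    fig, axes = plot_connectivity_circle(con, node_names, n_lines=4)
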
diff --git a/mne/viz/epochs.py b/mne/viz/epochs.py
index 2585458..be4ffde 100644
--- a/mne/viz/epochs.py
+++ b/mne/viz/epochs.py
@@ -9,6 +9,7 @@
#
# License: Simplified BSD
+from collections import Counter
from functools import partial
import copy
@@ -17,19 +18,18 @@ import numpy as np
from ..utils import verbose, get_config, set_config, logger, warn
from ..io.pick import pick_types, channel_type
from ..io.proj import setup_proj
-from ..fixes import Counter, _in1d
from ..time_frequency import psd_multitaper
from .utils import (tight_layout, figure_nobar, _toggle_proj, _toggle_options,
_layout_figure, _setup_vmin_vmax, _channels_changed,
_plot_raw_onscroll, _onclick_help, plt_show,
- _compute_scalings)
+ _compute_scalings, DraggableColorbar)
from ..defaults import _handle_default
def plot_epochs_image(epochs, picks=None, sigma=0., vmin=None,
vmax=None, colorbar=True, order=None, show=True,
units=None, scalings=None, cmap='RdBu_r',
- fig=None, overlay_times=None):
+ fig=None, axes=None, overlay_times=None):
"""Plot Event Related Potential / Fields image
Parameters
@@ -65,12 +65,24 @@ def plot_epochs_image(epochs, picks=None, sigma=0., vmin=None,
The scalings of the channel types to be applied for plotting.
If None, defaults to `scalings=dict(eeg=1e6, grad=1e13, mag=1e15,
eog=1e6)`.
- cmap : matplotlib colormap
- Colormap.
+ cmap : matplotlib colormap | (colormap, bool) | 'interactive'
+ Colormap. If tuple, the first value indicates the colormap to use and
+ the second value is a boolean defining interactivity. In interactive
+ mode the colors are adjustable by clicking and dragging the colorbar
+ with left and right mouse button. Left mouse button moves the scale up
+ and down and right mouse button adjusts the range. Hitting space bar
+ resets the scale. Up and down arrows can be used to change the
+ colormap. If 'interactive', translates to ('RdBu_r', True). Defaults to
+ 'RdBu_r'.
fig : matplotlib figure | None
Figure instance to draw the image to. Figure must contain two axes for
drawing the single trials and evoked responses. If None a new figure is
created. Defaults to None.
+ axes : list of matplotlib axes | None
+ List of axes instances to draw the image, erp and colorbar to.
+ Must be of length three if colorbar is True (with the last list element
+ being the colorbar axes) or two if colorbar is False. If both fig and
+ axes are passed an error is raised. Defaults to None.
overlay_times : array-like, shape (n_epochs,) | None
If not None the parameter is interpreted as time instants in seconds
and is added to the image. It is typically useful to display reaction
@@ -95,18 +107,37 @@ def plot_epochs_image(epochs, picks=None, sigma=0., vmin=None,
raise ValueError('Scalings and units must have the same keys.')
picks = np.atleast_1d(picks)
- if fig is not None and len(picks) > 1:
+ if (fig is not None or axes is not None) and len(picks) > 1:
raise ValueError('Only single pick can be drawn to a figure.')
+ if axes is not None:
+ if fig is not None:
+            raise ValueError('Both figure and axes were passed, please '
+                             'decide between the two.')
+ from .utils import _validate_if_list_of_axes
+ oblig_len = 3 if colorbar else 2
+ _validate_if_list_of_axes(axes, obligatory_len=oblig_len)
+ ax1, ax2 = axes[:2]
+ # if axes were passed - we ignore fig param and get figure from axes
+ fig = ax1.get_figure()
+ if colorbar:
+ ax3 = axes[-1]
evoked = epochs.average(picks)
data = epochs.get_data()[:, picks, :]
+ n_epochs = len(data)
+ data = np.swapaxes(data, 0, 1)
+ if sigma > 0.:
+ for k in range(len(picks)):
+ data[k, :] = ndimage.gaussian_filter1d(
+ data[k, :], sigma=sigma, axis=0)
+
scale_vmin = True if vmin is None else False
scale_vmax = True if vmax is None else False
vmin, vmax = _setup_vmin_vmax(data, vmin, vmax)
- if overlay_times is not None and len(overlay_times) != len(data):
+ if overlay_times is not None and len(overlay_times) != n_epochs:
raise ValueError('size of overlay_times parameter (%s) does not '
'match the number of epochs (%s).'
- % (len(overlay_times), len(data)))
+ % (len(overlay_times), n_epochs))
if overlay_times is not None:
overlay_times = np.array(overlay_times)
@@ -118,7 +149,7 @@ def plot_epochs_image(epochs, picks=None, sigma=0., vmin=None,
% (epochs.tmin, epochs.tmax))
figs = list()
- for i, (this_data, idx) in enumerate(zip(np.swapaxes(data, 0, 1), picks)):
+ for i, (this_data, idx) in enumerate(zip(data, picks)):
if fig is None:
this_fig = plt.figure()
else:
@@ -150,26 +181,28 @@ def plot_epochs_image(epochs, picks=None, sigma=0., vmin=None,
if this_overlay_times is not None:
this_overlay_times = this_overlay_times[this_order]
- if sigma > 0.:
- this_data = ndimage.gaussian_filter1d(this_data, sigma=sigma,
- axis=0)
plt.figure(this_fig.number)
- ax1 = plt.subplot2grid((3, 10), (0, 0), colspan=9, rowspan=2)
- if scale_vmin:
- vmin *= scalings[ch_type]
- if scale_vmax:
- vmax *= scalings[ch_type]
+ if axes is None:
+ ax1 = plt.subplot2grid((3, 10), (0, 0), colspan=9, rowspan=2)
+ ax2 = plt.subplot2grid((3, 10), (2, 0), colspan=9, rowspan=1)
+ if colorbar:
+ ax3 = plt.subplot2grid((3, 10), (0, 9), colspan=1, rowspan=3)
+
+ this_vmin = vmin * scalings[ch_type] if scale_vmin else vmin
+ this_vmax = vmax * scalings[ch_type] if scale_vmax else vmax
+
+ if cmap == 'interactive':
+ cmap = ('RdBu_r', True)
+ elif not isinstance(cmap, tuple):
+ cmap = (cmap, True)
im = ax1.imshow(this_data,
extent=[1e3 * epochs.times[0], 1e3 * epochs.times[-1],
- 0, len(data)],
+ 0, n_epochs],
aspect='auto', origin='lower', interpolation='nearest',
- vmin=vmin, vmax=vmax, cmap=cmap)
+ vmin=this_vmin, vmax=this_vmax, cmap=cmap[0])
if this_overlay_times is not None:
plt.plot(1e3 * this_overlay_times, 0.5 + np.arange(len(this_data)),
'k', linewidth=2)
- ax2 = plt.subplot2grid((3, 10), (2, 0), colspan=9, rowspan=1)
- if colorbar:
- ax3 = plt.subplot2grid((3, 10), (0, 9), colspan=1, rowspan=3)
ax1.set_title(epochs.ch_names[idx])
ax1.set_ylabel('Epochs')
ax1.axis('auto')
@@ -188,10 +221,12 @@ def plot_epochs_image(epochs, picks=None, sigma=0., vmin=None,
ax2.set_ylim([evoked_vmin, evoked_vmax])
ax2.axvline(0, color='m', linewidth=3, linestyle='--')
if colorbar:
- plt.colorbar(im, cax=ax3)
+ cbar = plt.colorbar(im, cax=ax3)
+ if cmap[1]:
+ ax1.CB = DraggableColorbar(cbar, im)
tight_layout(fig=this_fig)
-
plt_show(show)
+
return figs
@@ -403,6 +438,8 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20,
with the f11 key. The number of epochs and channels per view can be
adjusted with the home/end and page down/page up keys. The butterfly
plot can be toggled with the ``b`` key. A right mouse click adds a
vertical line to the plot.
+
+ .. versionadded:: 0.10.0
"""
epochs.drop_bad()
scalings = _compute_scalings(scalings, epochs)
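
As a sketch of the browsing described in the docstring above, assuming an
existing mne.Epochs instance ``epochs``:

    # `epochs` is assumed to be an existing mne.Epochs instance
    fig = epochs.plot(n_epochs=20, n_channels=20, block=False)
    # f11: full screen; home/end and pgup/pgdown: epochs/channels per
    # view; b: toggle butterfly mode; right click: add a vertical line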
@@ -563,7 +600,15 @@ def _prepare_mne_browse_epochs(params, projs, n_channels, n_epochs, scalings,
idxs = pick_types(params['info'], meg=t, ref_meg=False, exclude=[])
if len(idxs) < 1:
continue
- mask = _in1d(idxs, picks, assume_unique=True)
+ mask = np.in1d(idxs, picks, assume_unique=True)
+ inds.append(idxs[mask])
+ types += [t] * len(inds[-1])
+ for t in ['hbo', 'hbr']:
+ idxs = pick_types(params['info'], meg=False, ref_meg=False, fnirs=t,
+ exclude=[])
+ if len(idxs) < 1:
+ continue
+ mask = np.in1d(idxs, picks, assume_unique=True)
inds.append(idxs[mask])
types += [t] * len(inds[-1])
pick_kwargs = dict(meg=False, ref_meg=False, exclude=[])
@@ -575,7 +620,7 @@ def _prepare_mne_browse_epochs(params, projs, n_channels, n_epochs, scalings,
idxs = pick_types(params['info'], **pick_kwargs)
if len(idxs) < 1:
continue
- mask = _in1d(idxs, picks, assume_unique=True)
+ mask = np.in1d(idxs, picks, assume_unique=True)
inds.append(idxs[mask])
types += [ch_type] * len(inds[-1])
pick_kwargs[ch_type] = False
@@ -949,7 +994,8 @@ def _handle_picks(epochs):
exclude=[])
else:
picks = pick_types(epochs.info, meg=True, eeg=True, eog=True, ecg=True,
- seeg=True, ecog=True, ref_meg=False, exclude=[])
+ seeg=True, ecog=True, ref_meg=False, fnirs=True,
+ exclude=[])
return picks
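
A short sketch of the fNIRS picking used above, assuming ``info`` is an
mne.Info instance that contains fNIRS channels:

    import mne

    # `info` is assumed to be an mne.Info instance with fNIRS channels
    hbo_picks = mne.pick_types(info, meg=False, fnirs='hbo', exclude=[])
    hbr_picks = mne.pick_types(info, meg=False, fnirs='hbr', exclude=[])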
@@ -1358,7 +1404,7 @@ def _close_event(event, params):
def _resize_event(event, params):
"""Function to handle resize event"""
size = ','.join([str(s) for s in params['fig'].get_size_inches()])
- set_config('MNE_BROWSE_RAW_SIZE', size)
+ set_config('MNE_BROWSE_RAW_SIZE', size, set_env=False)
_layout_figure(params)
diff --git a/mne/viz/evoked.py b/mne/viz/evoked.py
index dec4e84..37a5c32 100644
--- a/mne/viz/evoked.py
+++ b/mne/viz/evoked.py
@@ -11,27 +11,30 @@ from __future__ import print_function
#
# License: Simplified BSD
+from functools import partial
+
import numpy as np
from ..io.pick import (channel_type, pick_types, _picks_by_type,
- _pick_data_channels)
+ _pick_data_channels, _DATA_CH_TYPES_SPLIT)
from ..externals.six import string_types
from ..defaults import _handle_default
from .utils import (_draw_proj_checkbox, tight_layout, _check_delayed_ssp,
- plt_show, _process_times)
+ plt_show, _process_times, DraggableColorbar)
from ..utils import logger, _clean_names, warn
-from ..fixes import partial
from ..io.pick import pick_info
from .topo import _plot_evoked_topo
from .topomap import (_prepare_topo_plot, plot_topomap, _check_outlines,
_draw_outlines, _prepare_topomap, _topomap_animation)
from ..channels import find_layout
+from ..channels.layout import (_pair_grad_sensors, generate_2d_layout,
+ _auto_topomap_coords)
def _butterfly_onpick(event, params):
"""Helper to add a channel name on click"""
params['need_draw'] = True
- ax = event.artist.get_axes()
+ ax = event.artist.axes
ax_idx = np.where([ax is a for a in params['axes']])[0]
if len(ax_idx) == 0: # this can happen if ax param is used
return # let the other axes handle it
@@ -47,6 +50,7 @@ def _butterfly_onpick(event, params):
text.set_text(ch_name)
text.set_color(event.artist.get_color())
text.set_alpha(1.)
+ text.set_zorder(len(ax.lines)) # to make sure it goes on top of the lines
text.set_path_effects(params['path_effects'])
# do NOT redraw here, since for butterfly plots hundreds of lines could
# potentially be picked -- use on_button_press (happens once per click)
@@ -70,7 +74,14 @@ def _butterfly_on_button_press(event, params):
def _butterfly_onselect(xmin, xmax, ch_types, evoked, text=None):
"""Function for drawing topomaps from the selected area."""
import matplotlib.pyplot as plt
- ch_types = [type for type in ch_types if type in ('eeg', 'grad', 'mag')]
+ ch_types = [type_ for type_ in ch_types if type_ in ('eeg', 'grad', 'mag')]
+ if ('grad' in ch_types and
+ len(_pair_grad_sensors(evoked.info, topomap_coords=False,
+ raise_error=False)) < 2):
+ ch_types.remove('grad')
+ if len(ch_types) == 0:
+ return
+
vert_lines = list()
if text is not None:
text.set_visible(True)
@@ -91,9 +102,8 @@ def _butterfly_onselect(xmin, xmax, ch_types, evoked, text=None):
fig, axarr = plt.subplots(1, len(ch_types), squeeze=False,
figsize=(3 * len(ch_types), 3))
for idx, ch_type in enumerate(ch_types):
- picks, pos, merge_grads, _, ch_type = _prepare_topo_plot(evoked,
- ch_type,
- layout=None)
+ picks, pos, merge_grads, _, ch_type = _prepare_topo_plot(
+ evoked, ch_type, layout=None)
data = evoked.data[picks, minidx:maxidx]
if merge_grads:
from ..channels.layout import _merge_grad_data
@@ -160,7 +170,7 @@ def _plot_evoked(evoked, picks, exclude, unit, show,
scalings, titles, axes, plot_type,
cmap=None, gfp=False, window_title=None,
spatial_colors=False, set_tight_layout=True,
- selectable=True):
+ selectable=True, zorder='unsorted'):
"""Aux function for plot_evoked and plot_evoked_image (cf. docstrings)
Extra param is:
@@ -181,13 +191,16 @@ def _plot_evoked(evoked, picks, exclude, unit, show,
' for interactive SSP selection.')
if isinstance(gfp, string_types) and gfp != 'only':
raise ValueError('gfp must be boolean or "only". Got %s' % gfp)
-
+ if cmap == 'interactive':
+ cmap = (None, True)
+ elif not isinstance(cmap, tuple):
+ cmap = (cmap, True)
scalings = _handle_default('scalings', scalings)
titles = _handle_default('titles', titles)
units = _handle_default('units', units)
# Valid data types ordered for consistency
valid_channel_types = ['eeg', 'grad', 'mag', 'seeg', 'eog', 'ecg', 'emg',
- 'dipole', 'gof', 'bio', 'ecog']
+ 'dipole', 'gof', 'bio', 'ecog', 'hbo', 'hbr']
if picks is None:
picks = list(range(info['nchan']))
@@ -290,31 +303,54 @@ def _plot_evoked(evoked, picks, exclude, unit, show,
else:
layout = find_layout(info, None, exclude=[])
# drop channels that are not in the data
-
used_nm = np.array(_clean_names(info['ch_names']))[idx]
names = np.asarray([name for name in used_nm
if name in layout.names])
name_idx = [layout.names.index(name) for name in names]
if len(name_idx) < len(chs):
warn('Could not find layout for all the channels. '
- 'Legend for spatial colors not drawn.')
- else:
- # find indices for bads
- bads = [np.where(names == bad)[0][0] for bad in
- info['bads'] if bad in names]
- pos, outlines = _check_outlines(layout.pos[:, :2],
- 'skirt', None)
- pos = pos[name_idx]
- _plot_legend(pos, colors, ax, bads, outlines)
+ 'Generating custom layout from channel '
+ 'positions.')
+ xy = _auto_topomap_coords(info, idx, True)
+ layout = generate_2d_layout(
+ xy[idx], ch_names=list(used_nm), name='custom')
+ names = used_nm
+ name_idx = [layout.names.index(name) for name in
+ names]
+
+ # find indices for bads
+ bads = [np.where(names == bad)[0][0] for bad in
+ info['bads'] if bad in names]
+ pos, outlines = _check_outlines(layout.pos[:, :2],
+ 'skirt', None)
+ pos = pos[name_idx]
+ _plot_legend(pos, colors, ax, bads, outlines)
else:
colors = ['k'] * len(idx)
for i in bad_ch_idx:
if i in idx:
colors[idx.index(i)] = 'r'
- for ch_idx in range(len(D)):
- line_list.append(ax.plot(times, D[ch_idx], picker=3.,
- zorder=1,
- color=colors[ch_idx])[0])
+
+ if zorder == 'std':
+ # find the channels with the least activity
+ # to map them in front of the more active ones
+ z_ord = D.std(axis=1).argsort()
+ elif zorder == 'unsorted':
+ z_ord = list(range(D.shape[0]))
+ elif not callable(zorder):
+ error = ('`zorder` must be a function, "std" '
+ 'or "unsorted", not {0}.')
+ raise TypeError(error.format(type(zorder)))
+ else:
+ z_ord = zorder(D)
+
+ # plot channels
+ for ch_idx, z in enumerate(z_ord):
+ line_list.append(
+ ax.plot(times, D[ch_idx], picker=3.,
+ zorder=z + 1 if spatial_colors else 1,
+ color=colors[ch_idx])[0])
+
if gfp: # 'only' or boolean True
gfp_color = 3 * (0.,) if spatial_colors else (0., 1., 0.)
this_gfp = np.sqrt((D * D).mean(axis=0))
@@ -349,9 +385,11 @@ def _plot_evoked(evoked, picks, exclude, unit, show,
elif plot_type == 'image':
im = ax.imshow(D, interpolation='nearest', origin='lower',
extent=[times[0], times[-1], 0, D.shape[0]],
- aspect='auto', cmap=cmap)
+ aspect='auto', cmap=cmap[0])
cbar = plt.colorbar(im, ax=ax)
cbar.ax.set_title(ch_unit)
+ if cmap[1]:
+ ax.CB = DraggableColorbar(cbar, im)
ax.set_ylabel('channels (%s)' % 'index')
else:
raise ValueError("plot_type has to be 'butterfly' or 'image'."
@@ -407,7 +445,8 @@ def _plot_evoked(evoked, picks, exclude, unit, show,
def plot_evoked(evoked, picks=None, exclude='bads', unit=True, show=True,
ylim=None, xlim='tight', proj=False, hline=None, units=None,
scalings=None, titles=None, axes=None, gfp=False,
- window_title=None, spatial_colors=False):
+ window_title=None, spatial_colors=False, zorder='unsorted',
+ selectable=True):
"""Plot evoked data using butteryfly plots
Left click to a line shows the channel name. Selecting an area by clicking
@@ -464,7 +503,27 @@ def plot_evoked(evoked, picks=None, exclude='bads', unit=True, show=True,
coordinates into color values. Spatially similar channels will have
similar colors. Bad channels will be dotted. If False, the good
channels are plotted black and bad channels red. Defaults to False.
-
+ zorder : str | callable
+ Which channels to put in the front or back. Only matters if
+ `spatial_colors` is used.
+ If str, must be `std` or `unsorted` (defaults to `unsorted`). If
+ `std`, data with the lowest standard deviation (weakest effects) will
+ be put in front so that they are not obscured by those with stronger
+ effects. If `unsorted`, channels are z-sorted as in the evoked
+ instance.
+ If callable, must take one argument (a numpy array of the same
+ dimensionality as the evoked raw data) and return a list of
+ unique integers corresponding to the number of channels.
+
+ .. versionadded:: 0.13.0
+
+ selectable : bool
+ Whether to use interactive features. If True (default), it is possible
+ to paint an area to draw topomaps. When False, the interactive features
+ are disabled. Disabling interactive features reduces memory consumption
+ and is useful when using the ``axes`` parameter to draw multi-axes figures.
+
+ .. versionadded:: 0.13.0
Returns
-------
@@ -476,7 +535,8 @@ def plot_evoked(evoked, picks=None, exclude='bads', unit=True, show=True,
hline=hline, units=units, scalings=scalings,
titles=titles, axes=axes, plot_type="butterfly",
gfp=gfp, window_title=window_title,
- spatial_colors=spatial_colors, selectable=True)
+ spatial_colors=spatial_colors, zorder=zorder,
+ selectable=selectable)
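
A sketch of the new ``zorder`` hook, assuming an existing mne.Evoked
instance ``evoked``; ``zorder_by_peak`` is a hypothetical example callable:

    import numpy as np

    def zorder_by_peak(data):
        # data has shape (n_channels, n_times); return one unique
        # integer per channel, used as the z-order height
        return list(np.argsort(np.abs(data).max(axis=1)))

    evoked.plot(spatial_colors=True, zorder=zorder_by_peak)
    evoked.plot(spatial_colors=True, zorder='std')  # weakest effects in front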
def plot_evoked_topo(evoked, layout=None, layout_scale=0.945, color=None,
@@ -642,8 +702,15 @@ def plot_evoked_image(evoked, picks=None, exclude='bads', unit=True, show=True,
The axes to plot to. If list, the list must be a list of Axes of
the same length as the number of channel types. If instance of
Axes, there must be only one channel type plotted.
- cmap : matplotlib colormap
- Colormap.
+ cmap : matplotlib colormap | (colormap, bool) | 'interactive'
+ Colormap. If tuple, the first value indicates the colormap to use and
+ the second value is a boolean defining interactivity. In interactive
+ mode the colors are adjustable by clicking and dragging the colorbar
+ with the left and right mouse buttons. The left button moves the scale
+ up and down, and the right button adjusts the range. Hitting space bar
+ resets the scale. Up and down arrows can be used to change the
+ colormap. If 'interactive', translates to ('RdBu_r', True). Defaults to
+ 'RdBu_r'.
Returns
-------
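
A sketch of the interactive colormap option, assuming an existing
mne.Evoked instance ``evoked``:

    # `evoked` is assumed to be an existing mne.Evoked instance
    evoked.plot_image(cmap='interactive')       # same as ('RdBu_r', True)
    evoked.plot_image(cmap=('viridis', False))  # fixed, non-draggable scale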
@@ -995,9 +1062,9 @@ def plot_evoked_joint(evoked, times="peaks", title='', picks=None,
ts_args : None | dict
A dict of `kwargs` that are forwarded to `evoked.plot` to
style the butterfly plot. `axes` and `show` are ignored.
- If `spatial_colors` is not in this dict, `spatial_colors=True`
- will be passed. Beyond that, if ``None``, no customizable arguments
- will be passed. Defaults to ``None``.
+ If `spatial_colors` is not in this dict, `spatial_colors=True` will
+ be passed; likewise, `zorder='std'` will be passed if `zorder` is
+ not in the dict. Defaults to ``None``.
topomap_args : None | dict
A dict of `kwargs` that are forwarded to `evoked.plot_topomap`
to style the topomaps. `axes` and `show` are ignored. If `times`
@@ -1041,7 +1108,7 @@ def plot_evoked_joint(evoked, times="peaks", title='', picks=None,
evoked.drop_channels(exclude)
info = evoked.info
- data_types = ['eeg', 'grad', 'mag', 'seeg', 'ecog']
+ data_types = ['eeg', 'grad', 'mag', 'seeg', 'ecog', 'hbo', 'hbr']
ch_types = set(ch_type for ch_type in data_types if ch_type in evoked)
# if multiple sensor types: one plot per channel type, recursive call
@@ -1076,7 +1143,7 @@ def plot_evoked_joint(evoked, times="peaks", title='', picks=None,
ts_args_def = dict(picks=None, unit=True, ylim=None, xlim='tight',
proj=False, hline=None, units=None, scalings=None,
titles=None, gfp=False, window_title=None,
- spatial_colors=True)
+ spatial_colors=True, zorder='std')
for key in ts_args_def:
if key not in ts_args:
ts_args_pass[key] = ts_args_def[key]
@@ -1134,3 +1201,447 @@ def plot_evoked_joint(evoked, times="peaks", title='', picks=None,
# show and return it
plt_show(show)
return fig
+
+
+def _ci(arr, ci):
+ """Calculate the `ci`% parametric confidence interval for `arr`.
+ Aux function for plot_compare_evokeds."""
+ from scipy import stats
+ mean, sigma = arr.mean(0), stats.sem(arr, 0)
+ # This is highly convoluted to support 17th century Scipy
+ # XXX Fix when Scipy 0.12 support is dropped!
+ # then it becomes just:
+ # return stats.t.interval(ci, loc=mean, scale=sigma, df=arr.shape[0])
+ return np.asarray([stats.t.interval(ci, arr.shape[0],
+ loc=mean_, scale=sigma_)
+ for mean_, sigma_ in zip(mean, sigma)]).T
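
A quick numerical sketch of ``_ci`` on synthetic data (illustration only):

    import numpy as np

    rng = np.random.RandomState(0)
    arr = rng.randn(20, 5)         # 20 "subjects" x 5 time points
    lower, upper = _ci(arr, 0.95)  # each has length 5
    assert (lower < arr.mean(0)).all() and (arr.mean(0) < upper).all()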
+
+
+def _setup_styles(conditions, style_dict, style, default):
+ """Aux function for plot_compare_evokeds to set linestyles and colors"""
+
+ # check user-supplied style to condition matching
+ tags = set([tag for cond in conditions for tag in cond.split("/")])
+ msg = ("Can't map between conditions and the provided {0}. Make sure "
+ "you have provided keys in the format of '/'-separated tags, "
+ "and that these correspond to '/'-separated tags for the condition "
+ "names (e.g., conditions like 'Visual/Right', and styles like "
+ "'colors=dict(Visual='red'))'. The offending tag was '{1}'.")
+ for key in style_dict:
+ for tag in key.split("/"):
+ if tag not in tags:
+ raise ValueError(msg.format(style, tag))
+
+ # check condition to style matching, and fill in defaults
+ condition_warning = "Condition {0} could not be mapped to a " + style
+ style_warning = ". Using the default of {0}.".format(default)
+ for condition in conditions:
+ if condition not in style_dict:
+ if "/" not in condition:
+ warn(condition_warning.format(condition) + style_warning)
+ style_dict[condition] = default
+ for style_ in style_dict:
+ if style_ in condition.split("/"):
+ style_dict[condition] = style_dict[style_]
+ break
+
+ return style_dict
+
+
+def _truncate_yaxis(axes, ymin, ymax, orig_ymin, orig_ymax, fraction,
+ any_positive, any_negative):
+ """Aux function for truncating the y axis in plot_compare_evokeds"""
+ abs_lims = (orig_ymax if orig_ymax > np.abs(orig_ymin)
+ else np.abs(orig_ymin))
+ ymin_, ymax_ = (-(abs_lims // fraction), abs_lims // fraction)
+ # user supplied ymin and ymax overwrite everything
+ if ymin is not None and ymin > ymin_:
+ ymin_ = ymin
+ if ymax is not None and ymax < ymax_:
+ ymax_ = ymax
+ yticks = (ymin_ if any_negative else 0, ymax_ if any_positive else 0)
+ axes.set_yticks(yticks)
+ ymin_bound, ymax_bound = (-(abs_lims // fraction), abs_lims // fraction)
+ # user supplied ymin and ymax still overwrite everything
+ if ymin is not None and ymin > ymin_bound:
+ ymin_bound = ymin
+ if ymax is not None and ymax < ymax_bound:
+ ymax_bound = ymax
+ precision = 0.25 # round to .25
+ if ymin is None:
+ ymin_bound = round(ymin_bound / precision) * precision
+ if ymax is None:
+ ymax_bound = round(ymax_bound / precision) * precision
+ axes.spines['left'].set_bounds(ymin_bound, ymax_bound)
+ return ymin_bound, ymax_bound
+
+
+def plot_compare_evokeds(evokeds, picks=list(), gfp=False, colors=None,
+ linestyles=['-'], styles=None, vlines=[0.], ci=0.95,
+ truncate_yaxis=True, ylim=dict(), invert_y=False,
+ axes=None, title=None, show=True):
+ """Plot evoked time courses for one or multiple channels and conditions
+
+ This function is useful for comparing ER[P/F]s at a specific location. It
+ plots Evoked data or, if supplied with a list/dict of lists of evoked
+ instances, grand averages plus confidence intervals.
+
+ Parameters
+ ----------
+ evokeds : instance of mne.Evoked | list | dict
+ If a single evoked instance, it is plotted as a time series.
+ If a dict whose values are Evoked objects, the contents are plotted as
+ single time series each and the keys are used as condition labels.
+ If a list of Evokeds, the contents are plotted with indices as labels.
+ If a [dict/list] of lists, the unweighted mean is plotted as a time
+ series and the parametric confidence interval is plotted as a shaded
+ area. All instances must have the same shape (channel count, time
+ points, etc.).
+ picks : int | list of int
+ If int or list of int, the indices of the sensors to average and plot.
+ Must all be of the same channel type.
+ If the selected channels are gradiometers, the corresponding pairs
+ will be selected.
+ If multiple channel types are selected, one figure will be returned for
+ each channel type.
+ If an empty list, `gfp` will be set to True, and the Global Field
+ Power plotted.
+ gfp : bool
+ If True, the channel type wise GFP is plotted.
+ If `picks` is an empty list (default), this is set to True.
+ colors : list | dict | None
+ If a list, will be sequentially used for line colors.
+ If a dict, can map evoked keys or '/'-separated (HED) tags to
+ conditions.
+ For example, if `evokeds` is a dict with the keys "Aud/L", "Aud/R",
+ "Vis/L", "Vis/R", `colors` can be `dict(Aud='r', Vis='b')` to map both
+ Aud/L and Aud/R to the color red and both Visual conditions to blue.
+ If None (default), a sequence of desaturated colors is used.
+ linestyles : list | dict
+ If a list, will be sequentially and repeatedly used for evoked plot
+ linestyles.
+ If a dict, can map the `evoked` keys or '/'-separated (HED) tags to
+ conditions.
+ For example, if evokeds is a dict with the keys "Aud/L", "Aud/R",
+ "Vis/L", "Vis/R", `linestyles` can be `dict(L='--', R='-')` to map both
+ Aud/L and Vis/L to dashed lines and both Right-side conditions to
+ straight lines.
+ styles : dict | None
+ If a dict, keys must map to evoked keys or conditions, and values must
+ be a dict of legal inputs to `matplotlib.pyplot.plot`. These
+ parameters will be passed to the line plot call of the corresponding
+ condition, overriding defaults.
+ E.g., if evokeds is a dict with the keys "Aud/L", "Aud/R",
+ "Vis/L", "Vis/R", `styles` can be `{"Aud/L":{"linewidth":1}}` to set
+ the linewidth for "Aud/L" to 1. Note that HED ('/'-separated) tags are
+ not supported.
+ vlines : list of int
+ A list of integers corresponding to the positions, in seconds,
+ at which to plot dashed vertical lines.
+ ci : float | None
+ If not None and `evokeds` is a [list/dict] of lists, a confidence
+ interval is drawn around the individual time series. This value
+ determines the CI width. E.g., if this value is .95 (the default),
+ the 95% parametric confidence interval is drawn.
+ If None, no shaded confidence band is plotted.
+ truncate_yaxis : bool
+ If True, the left y axis is truncated to half the max value and
+ rounded to .25 to reduce visual clutter. Defaults to True.
+ ylim : dict | None
+ ylim for plots (after scaling has been applied). e.g.
+ ylim = dict(eeg=[-20, 20])
+ Valid keys are eeg, mag, grad, misc. If None, the ylim parameter
+ for each channel equals the pyplot default.
+ invert_y : bool
+ If True, negative values are plotted up (as is sometimes done
+ for ERPs out of tradition). Defaults to False.
+ axes : None | `matplotlib.pyplot.axes` instance | list of `axes`
+ What axes to plot to. If None, a new axes is created.
+ When plotting multiple channel types, can also be a list of axes, one
+ per channel type.
+ title : None | str
+ If str, will be plotted as figure title. If None, the channel
+ names will be shown.
+ show : bool
+ If True, show the figure.
+
+ Returns
+ -------
+ fig : Figure | list of Figures
+ The figure(s) in which the plot is drawn.
+ """
+ import matplotlib.pyplot as plt
+ from ..evoked import Evoked, combine_evoked
+
+ # set up labels and instances
+ if isinstance(evokeds, Evoked):
+ evokeds = dict(Evoked=evokeds) # title becomes 'Evoked'
+ elif not isinstance(evokeds, dict):
+ evokeds = dict((str(ii + 1), evoked)
+ for ii, evoked in enumerate(evokeds))
+ conditions = sorted(list(evokeds.keys()))
+
+ # get and set a few limits and variables (times, channels, units)
+ example = (evokeds[conditions[0]]
+ if isinstance(evokeds[conditions[0]], Evoked)
+ else evokeds[conditions[0]][0])
+ if not isinstance(example, Evoked):
+ raise ValueError("evokeds must be an instance of mne.Evoked "
+ "or a collection of mne.Evoked's")
+ times = example.times
+ tmin, tmax = times[0], times[-1]
+
+ if isinstance(picks, int):
+ picks = [picks]
+ elif len(picks) == 0:
+ warn("No picks, plotting the GFP ...")
+ gfp = True
+ picks = _pick_data_channels(example.info)
+
+ # deal with picks: infer indices and names
+ if gfp is True:
+ ch_names = ['Global Field Power']
+ if len(picks) < 2:
+ raise ValueError("A GFP with less than 2 channels doesn't work, "
+ "please pick more channels.")
+ else:
+ if not isinstance(picks[0], int):
+ msg = "'picks' must be int or a list of int, not {0}."
+ raise ValueError(msg.format(type(picks)))
+ ch_names = [example.ch_names[pick] for pick in picks]
+ ch_types = list(set(channel_type(example.info, pick_)
+ for pick_ in picks))
+ # XXX: could possibly be refactored; plot_joint is doing a similar thing
+ if any([type_ not in _DATA_CH_TYPES_SPLIT for type_ in ch_types]):
+ raise ValueError("Non-data channel picked.")
+ if len(ch_types) > 1:
+ warn("Multiple channel types selected, returning one figure per type.")
+ if axes is not None:
+ from .utils import _validate_if_list_of_axes
+ _validate_if_list_of_axes(axes, obligatory_len=len(ch_types))
+ figs = list()
+ for ii, t in enumerate(ch_types):
+ picks_ = [idx for idx in picks
+ if channel_type(example.info, idx) == t]
+ title_ = "GFP, " + t if not title and gfp is True else title
+ ax_ = axes[ii] if axes is not None else None
+ figs.append(
+ plot_compare_evokeds(
+ evokeds, picks=picks_, gfp=gfp, colors=colors,
+ linestyles=linestyles, styles=styles, vlines=vlines, ci=ci,
+ truncate_yaxis=truncate_yaxis, ylim=ylim,
+ invert_y=invert_y, axes=ax_, title=title_, show=show))
+ return figs
+ else:
+ ch_type = ch_types[0]
+ ymin, ymax = ylim.get(ch_type, [None, None])
+
+ scaling = _handle_default("scalings")[ch_type]
+
+ if ch_type == 'grad' and gfp is not True: # deal with grad pairs
+ from ..channels.layout import _merge_grad_data, _pair_grad_sensors
+ picked_chans = list()
+ pairpicks = _pair_grad_sensors(example.info, topomap_coords=False)
+ for ii in np.arange(0, len(pairpicks), 2):
+ first, second = pairpicks[ii], pairpicks[ii + 1]
+ if first in picks or second in picks:
+ picked_chans.append(first)
+ picked_chans.append(second)
+ picks = list(sorted(set(picked_chans)))
+ ch_names = [example.ch_names[pick] for pick in picks]
+
+ if ymin is None and (gfp is True or ch_type == 'grad'):
+ ymin = 0 # 'grad' and GFP are plotted as all-positive
+
+ # deal with dict/list of lists and the CI
+ if ci is not None and not isinstance(ci, np.float):
+ msg = '"ci" must be float or None, got {0} instead.'
+ raise TypeError(msg.format(type(ci)))
+
+ # if we have a dict/list of lists, we compute the grand average and the CI
+ if not all([isinstance(evoked_, Evoked) for evoked_ in evokeds.values()]):
+ if ci is not None and gfp is not True:
+ # calculate the CI
+ sem_array = dict()
+ for condition in conditions:
+ # this will fail if evokeds do not have the same structure
+ # (e.g. channel count)
+ if ch_type == 'grad' and gfp is not True:
+ data = np.asarray([
+ _merge_grad_data(
+ evoked_.data[picks, :]).mean(0)
+ for evoked_ in evokeds[condition]])
+ else:
+ data = np.asarray([evoked_.data[picks, :].mean(0)
+ for evoked_ in evokeds[condition]])
+ sem_array[condition] = _ci(data, ci)
+
+ # get the grand mean
+ evokeds = dict((cond, combine_evoked(evokeds[cond], weights='equal'))
+ for cond in conditions)
+
+ if gfp is True and ci is not None:
+ warn("Confidence Interval not drawn when plotting GFP.")
+ else:
+ ci = False
+ combine_evoked(list(evokeds.values())) # check if they are compatible
+ # we now have dicts for data ('evokeds' - grand averaged Evoked's)
+ # and the CI ('sem_array') with cond name labels
+
+ # let's plot!
+ if axes is None:
+ fig, axes = plt.subplots(1, 1)
+ fig.set_size_inches(8, 6)
+ else:
+ fig = axes.figure
+
+ # style the individual condition time series
+ # Styles (especially color and linestyle) are pulled from a dict 'styles'.
+ # This dict has one entry per condition. Its color and linestyle entries
+ # are pulled from the 'colors' and 'linestyles' dicts via '/'-tag matching
+ # unless they are overwritten by entries from a user-provided 'styles'.
+
+ # first, check if input is valid
+ if isinstance(styles, dict):
+ for style_ in styles:
+ if style_ not in conditions:
+ raise ValueError("Could not map between 'styles' and "
+ "conditions. Condition " + style_ +
+ " was not found in the supplied data.")
+
+ # second, color
+ # check: is color a list?
+ if (colors is not None and not isinstance(colors, string_types) and
+ not isinstance(colors, dict) and len(colors) > 1):
+ colors = dict((condition, color) for condition, color
+ in zip(conditions, colors))
+
+ if not isinstance(colors, dict): # default colors from M Waskom's Seaborn
+ # XXX should put a good list of default colors into defaults.py
+ colors_ = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00',
+ '#1b9e77', '#d95f02', '#7570b3', '#e7298a', '#66a61e']
+ if len(conditions) > len(colors_):
+ msg = ("Trying to plot more than {0} conditions. We provide"
+ "only {0} default colors. Please supply colors manually.")
+ raise ValueError(msg.format(len(colors_)))
+ colors = dict((condition, color) for condition, color
+ in zip(conditions, colors_))
+ else:
+ colors = _setup_styles(conditions, colors, "color", "grey")
+
+ # third, linestyles
+ if not isinstance(linestyles, dict):
+ linestyles = dict((condition, linestyle) for condition, linestyle in
+ zip(conditions, (linestyles * len(conditions))[:len(conditions)]))
+ else:
+ linestyles = _setup_styles(conditions, linestyles, "linestyle", "-")
+
+ # fourth, put it all together
+ if styles is None:
+ styles = dict()
+ for condition, color, linestyle in zip(conditions, colors, linestyles):
+ styles[condition] = styles.get(condition, dict())
+ styles[condition]['c'] = styles[condition].get('c', colors[condition])
+ styles[condition]['linestyle'] = styles[condition].get(
+ 'linestyle', linestyles[condition])
+ # We now have a 'styles' dict with one entry per condition, specifying at
+ # least color and linestyles.
+
+ # the actual plot
+ any_negative, any_positive = False, False
+ for condition in conditions:
+ # plot the actual data ('d') as a line
+ if ch_type == 'grad' and gfp is False:
+ d = ((_merge_grad_data(evokeds[condition]
+ .data[picks, :])).T * scaling).mean(-1)
+ else:
+ func = np.std if gfp is True else np.mean
+ d = func((evokeds[condition].data[picks, :].T * scaling), -1)
+ axes.plot(times, d, zorder=1000, label=condition, **styles[condition])
+ if any(d > 0):
+ any_positive = True
+ if any(d < 0):
+ any_negative = True
+
+ # plot the confidence interval (standard error of the mean/'sem_')
+ if ci and gfp is not True:
+ sem_ = sem_array[condition]
+ axes.fill_between(times, sem_[0].flatten() * scaling,
+ sem_[1].flatten() * scaling, zorder=100,
+ color=styles[condition]['c'], alpha=.333)
+
+ # truncate the y axis
+ orig_ymin, orig_ymax = axes.get_ylim()[0], axes.get_ylim()[-1]
+ if not any_positive:
+ orig_ymax = 0
+ if not any_negative:
+ orig_ymin = 0
+
+ axes.set_ylim(orig_ymin if ymin is None else ymin,
+ orig_ymax if ymax is None else ymax)
+
+ fraction = 2 if axes.get_ylim()[0] >= 0 else 3
+
+ if truncate_yaxis and ymin is not None and not (ymin > 0):
+ ymin_bound, ymax_bound = _truncate_yaxis(
+ axes, ymin, ymax, orig_ymin, orig_ymax, fraction,
+ any_positive, any_negative)
+ else:
+ if ymin is not None and ymin > 0:
+ warn("ymin is positive, not truncating yaxis")
+ ymax_bound = axes.get_ylim()[-1]
+ y_range = -np.subtract(*axes.get_ylim())
+
+ title = ", ".join(ch_names[:6]) if title is None else title
+ if len(ch_names) > 6 and gfp is False:
+ warn("More than 6 channels, truncating title ...")
+ title += ", ..."
+ axes.set_title(title)
+
+ # style the spines/axes
+ axes.spines["top"].set_position('zero')
+ axes.spines["top"].set_smart_bounds(True)
+
+ axes.tick_params(direction='out')
+ axes.tick_params(right="off")
+
+ current_ymin = axes.get_ylim()[0]
+
+ # plot v lines
+ if invert_y is True and current_ymin < 0:
+ upper_v, lower_v = -ymax_bound, axes.get_ylim()[-1]
+ else:
+ upper_v, lower_v = axes.get_ylim()[0], ymax_bound
+ axes.vlines(vlines, upper_v, lower_v, linestyles='--', colors='k',
+ linewidth=1., zorder=1)
+
+ # set x label
+ axes.set_xlabel('Time (s)')
+ axes.xaxis.get_label().set_verticalalignment('center')
+
+ # set y label and ylabel position
+ axes.set_ylabel(_handle_default("units")[ch_type], rotation=0)
+ ylabel_height = (-(current_ymin / y_range)
+ if 0 > current_ymin # ... if we have negative values
+ else (axes.get_yticks()[-1] / 2 / y_range))
+ axes.yaxis.set_label_coords(-0.05, 1 - ylabel_height
+ if invert_y else ylabel_height)
+ xticks = sorted(list(set([x for x in axes.get_xticks()] + vlines)))
+ axes.set_xticks(xticks)
+ axes.set_xticklabels(xticks)
+ x_extrema = [t for t in xticks if tmax >= t >= tmin]
+ axes.spines['bottom'].set_bounds(x_extrema[0], x_extrema[-1])
+ axes.spines["left"].set_zorder(0)
+
+ # finishing touches
+ if invert_y:
+ axes.invert_yaxis()
+ axes.patch.set_alpha(0)
+ axes.spines['right'].set_color('none')
+ axes.set_xlim(tmin, tmax)
+
+ if len(conditions) > 1:
+ plt.legend(loc='best', ncol=1 + (len(conditions) // 5), frameon=True)
+
+ plt_show(show)
+ return fig
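
A usage sketch for plot_compare_evokeds; ``auds`` and ``viss`` are
hypothetical lists of per-subject mne.Evoked objects:

    from mne.viz import plot_compare_evokeds

    # `auds` and `viss` are assumed lists of per-subject mne.Evoked's
    evokeds = {'Aud/L': auds, 'Vis/L': viss}
    plot_compare_evokeds(evokeds, picks=[3], ci=0.95,
                         colors=dict(Aud='r', Vis='b'),
                         linestyles=dict(L='--'))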
diff --git a/mne/viz/ica.py b/mne/viz/ica.py
index 5de65a0..90cbb7a 100644
--- a/mne/viz/ica.py
+++ b/mne/viz/ica.py
@@ -15,43 +15,21 @@ import numpy as np
from .utils import (tight_layout, _prepare_trellis, _select_bads,
_layout_figure, _plot_raw_onscroll, _mouse_click,
_helper_raw_resize, _plot_raw_onkey, plt_show)
+from .topomap import (_prepare_topo_plot, plot_topomap, _hide_frame,
+ _plot_ica_topomap)
from .raw import _prepare_mne_browse_raw, _plot_raw_traces
-from .epochs import _prepare_mne_browse_epochs
+from .epochs import _prepare_mne_browse_epochs, plot_epochs_image
from .evoked import _butterfly_on_button_press, _butterfly_onpick
-from .topomap import _prepare_topo_plot, plot_topomap, _hide_frame
from ..utils import warn
from ..defaults import _handle_default
from ..io.meas_info import create_info
from ..io.pick import pick_types
from ..externals.six import string_types
-
-
-def _ica_plot_sources_onpick_(event, sources=None, ylims=None):
- """Onpick callback for plot_ica_panel"""
-
- # make sure that the swipe gesture in OS-X doesn't open many figures
- if event.mouseevent.inaxes is None or event.mouseevent.button != 1:
- return
-
- artist = event.artist
- try:
- import matplotlib.pyplot as plt
- plt.figure()
- src_idx = artist._mne_src_idx
- component = artist._mne_component
- plt.plot(sources[src_idx], 'r' if artist._mne_is_bad else 'k')
- plt.ylim(ylims)
- plt.grid(linestyle='-', color='gray', linewidth=.25)
- plt.title('ICA #%i' % component)
- except Exception as err:
- # matplotlib silently ignores exceptions in event handlers, so we print
- # it here to know what went wrong
- print(err)
- raise err
+from ..time_frequency.psd import psd_multitaper
def plot_ica_sources(ica, inst, picks=None, exclude=None, start=None,
- stop=None, show=True, title=None, block=False):
+ stop=None, title=None, show=True, block=False):
"""Plot estimated latent sources given the unmixing matrix.
Typical use cases:
@@ -78,10 +56,10 @@ def plot_ica_sources(ica, inst, picks=None, exclude=None, start=None,
stop : int
X-axis stop index. If None, next 20 are shown, in case of evoked to the
end.
- show : bool
- Show figure if True.
title : str | None
The figure title. If None a default is provided.
+ show : bool
+ Show figure if True.
block : bool
Whether to halt program execution until the figure is closed.
Useful for interactive selection of components in raw and epoch
@@ -118,9 +96,9 @@ def plot_ica_sources(ica, inst, picks=None, exclude=None, start=None,
stop=stop, show=show, title=title,
block=block)
elif isinstance(inst, Evoked):
- sources = ica.get_sources(inst)
if start is not None or stop is not None:
inst = inst.copy().crop(start, stop)
+ sources = ica.get_sources(inst)
fig = _plot_ica_sources_evoked(
evoked=sources, picks=picks, exclude=exclude, title=title,
labels=getattr(ica, 'labels_', None), show=show)
@@ -130,72 +108,253 @@ def plot_ica_sources(ica, inst, picks=None, exclude=None, start=None,
return fig
-def _plot_ica_grid(sources, start, stop,
- source_idx, ncol, exclude,
- title, show):
- """Create panel plots of ICA sources
-
- Clicking on the plot of an individual source opens a new figure showing
- the source.
+def _create_properties_layout(figsize=None):
+ """creates main figure and axes layout used by plot_ica_properties"""
+ import matplotlib.pyplot as plt
+ if figsize is None:
+ figsize = [7., 6.]
+ fig = plt.figure(figsize=figsize, facecolor=[0.95] * 3)
+ ax = list()
+ ax.append(fig.add_axes([0.08, 0.5, 0.3, 0.45], label='topomap'))
+ ax.append(fig.add_axes([0.5, 0.6, 0.45, 0.35], label='image'))
+ ax.append(fig.add_axes([0.5, 0.5, 0.45, 0.1], label='erp'))
+ ax.append(fig.add_axes([0.08, 0.1, 0.32, 0.3], label='spectrum'))
+ ax.append(fig.add_axes([0.5, 0.1, 0.45, 0.25], label='variance'))
+ return fig, ax
+
+
+def plot_ica_properties(ica, inst, picks=None, axes=None, dB=True,
+ plot_std=True, topomap_args=None, image_args=None,
+ psd_args=None, figsize=None, show=True):
+ """Display component properties: topography, epochs image, ERP/ERF,
+ power spectrum and epoch variance.
Parameters
----------
- sources : ndarray
- Sources as drawn from ica.get_sources.
- start : int
- x-axis start index. If None from the beginning.
- stop : int
- x-axis stop index. If None to the end.
- n_components : int
- Number of components fitted.
- source_idx : array-like
- Indices for subsetting the sources.
- ncol : int
- Number of panel-columns.
- title : str
- The figure title. If None a default is provided.
+ ica : instance of mne.preprocessing.ICA
+ The ICA solution.
+ inst : instance of Epochs or Raw
+ The data to use in plotting properties.
+ picks : int | array-like of int | None
+ The components to be displayed. If None, the plot will show the first
+ five sources. If more than one component is chosen, each one will be
+ plotted in a separate figure. Defaults to None.
+ axes : list of matplotlib axes | None
+ List of five matplotlib axes to use in plotting: [topomap_axis,
+ image_axis, erp_axis, spectrum_axis, variance_axis]. If None, a new
+ figure with the relevant axes is created. Defaults to None.
+ dB : bool
+ Whether to plot the spectrum in dB. Defaults to True.
+ plot_std : bool | float
+ Whether to plot the standard deviation in ERP/ERF and spectrum plots.
+ Defaults to True, which plots one standard deviation above/below.
+ If set to a float, controls how many standard deviations are plotted.
+ For example, 2.5 will plot 2.5 standard deviations above/below.
+ topomap_args : dict | None
+ Dictionary of arguments to ``plot_topomap``. If None, doesn't pass any
+ additional arguments. Defaults to None.
+ image_args : dict | None
+ Dictionary of arguments to ``plot_epochs_image``. If None, doesn't pass
+ any additional arguments. Defaults to None.
+ psd_args : dict | None
+ Dictionary of arguments to ``psd_multitaper``. If None, doesn't pass
+ any additional arguments. Defaults to None.
+ figsize : array-like of size (2,) | None
+ Controls the size of the figure. If None, the figure size defaults
+ to [7., 6.].
show : bool
- If True, all open plots will be shown.
- """
- import matplotlib.pyplot as plt
+ Show figure if True.
- if source_idx is None:
- source_idx = np.arange(len(sources))
- elif isinstance(source_idx, list):
- source_idx = np.array(source_idx)
- if exclude is None:
- exclude = []
+ Returns
+ -------
+ fig : list
+ List of matplotlib figures.
+
+ Notes
+ -----
+ .. versionadded:: 0.13
+ """
+ from ..io.base import _BaseRaw
+ from ..epochs import _BaseEpochs
+ from ..preprocessing import ICA
+
+ if not isinstance(inst, (_BaseRaw, _BaseEpochs)):
+ raise ValueError('inst should be an instance of Raw or Epochs,'
+ ' got %s instead.' % type(inst))
+ if not isinstance(ica, ICA):
+ raise ValueError('ica has to be an instance of ICA, '
+ 'got %s instead' % type(ica))
+ if isinstance(plot_std, bool):
+ num_std = 1. if plot_std else 0.
+ elif isinstance(plot_std, (float, int)):
+ num_std = plot_std
+ plot_std = True
+ else:
+ raise ValueError('plot_std has to be a bool, int or float, '
+ 'got %s instead' % type(plot_std))
+
+ # if no picks given - plot the first 5 components
+ picks = list(range(min(5, ica.n_components_))) if picks is None else picks
+ picks = [picks] if isinstance(picks, int) else picks
+ if axes is None:
+ fig, axes = _create_properties_layout(figsize=figsize)
+ else:
+ if len(picks) > 1:
+ raise ValueError('Only a single pick can be drawn '
+ 'to a set of axes.')
+ from .utils import _validate_if_list_of_axes
+ _validate_if_list_of_axes(axes, obligatory_len=5)
+ fig = axes[0].get_figure()
+ psd_args = dict() if psd_args is None else psd_args
+ topomap_args = dict() if topomap_args is None else topomap_args
+ image_args = dict() if image_args is None else image_args
+ for d in (psd_args, topomap_args, image_args):
+ if not isinstance(d, dict):
+ raise ValueError('topomap_args, image_args and psd_args have to be'
+ ' dictionaries, got %s instead.' % type(d))
+ if dB is not None and isinstance(dB, bool) is False:
+ raise ValueError('dB should be bool, got %s instead' %
+ type(dB))
+
+ # calculations
+ # ------------
+ plot_line_at_zero = False
+ if isinstance(inst, _BaseRaw):
+ # break up continuous signal into segments
+ from ..epochs import _segment_raw
+ inst = _segment_raw(inst, segment_length=2., verbose=False,
+ preload=True)
+ if inst.times[0] < 0. and inst.times[-1] > 0.:
+ plot_line_at_zero = True
+
+ epochs_src = ica.get_sources(inst)
+ ica_data = np.swapaxes(epochs_src.get_data()[:, picks, :], 0, 1)
+
+ # spectrum
+ Nyquist = inst.info['sfreq'] / 2.
+ if 'fmax' not in psd_args:
+ psd_args['fmax'] = min(inst.info['lowpass'] * 1.25, Nyquist)
+ plot_lowpass_edge = inst.info['lowpass'] < Nyquist and (
+ psd_args['fmax'] > inst.info['lowpass'])
+ psds, freqs = psd_multitaper(epochs_src, picks=picks, **psd_args)
+
+ def set_title_and_labels(ax, title, xlab, ylab):
+ if title:
+ ax.set_title(title)
+ if xlab:
+ ax.set_xlabel(xlab)
+ if ylab:
+ ax.set_ylabel(ylab)
+ ax.axis('auto')
+ ax.tick_params('both', labelsize=8)
+ ax.axis('tight')
+
+ all_fig = list()
+ # the rest is component-specific
+ for idx, pick in enumerate(picks):
+ if idx > 0:
+ fig, axes = _create_properties_layout(figsize=figsize)
+
+ # spectrum
+ this_psd = psds[:, idx, :]
+ if dB:
+ this_psd = 10 * np.log10(this_psd)
+ psds_mean = this_psd.mean(axis=0)
+ diffs = this_psd - psds_mean
+ # the distribution of power for each frequency bin is highly
+ # skewed, so we compute the std separately for values below and
+ # above the average; this is used for the fill_between shading
+ spectrum_std = [
+ [np.sqrt((d[d < 0] ** 2).mean(axis=0)) for d in diffs.T],
+ [np.sqrt((d[d > 0] ** 2).mean(axis=0)) for d in diffs.T]]
+ spectrum_std = np.array(spectrum_std) * num_std
+
+ # erp std
+ if plot_std:
+ erp = ica_data[idx].mean(axis=0)
+ diffs = ica_data[idx] - erp
+ erp_std = [
+ [np.sqrt((d[d < 0] ** 2).mean(axis=0)) for d in diffs.T],
+ [np.sqrt((d[d > 0] ** 2).mean(axis=0)) for d in diffs.T]]
+ erp_std = np.array(erp_std) * num_std
+
+ # epoch variance
+ epoch_var = np.var(ica_data[idx], axis=1)
+
+ # plotting
+ # --------
+ # component topomap
+ _plot_ica_topomap(ica, pick, show=False, axes=axes[0], **topomap_args)
+
+ # image and erp
+ plot_epochs_image(epochs_src, picks=pick, axes=axes[1:3],
+ colorbar=False, show=False, **image_args)
+
+ # spectrum
+ axes[3].plot(freqs, psds_mean, color='k')
+ if plot_std:
+ axes[3].fill_between(freqs, psds_mean - spectrum_std[0],
+ psds_mean + spectrum_std[1],
+ color='k', alpha=.15)
+ if plot_lowpass_edge:
+ axes[3].axvline(inst.info['lowpass'], lw=2, linestyle='--',
+ color='k', alpha=0.15)
+
+ # epoch variance
+ axes[4].scatter(range(len(epoch_var)), epoch_var, alpha=0.5,
+ facecolor=[0, 0, 0], lw=0)
+
+ # aesthetics
+ # ----------
+ axes[0].set_title('IC #{0:0>3}'.format(pick))
+
+ set_title_and_labels(axes[1], 'epochs image and ERP/ERF', [], 'Epochs')
+
+ # erp
+ set_title_and_labels(axes[2], [], 'time', 'AU')
+ # line color and std
+ axes[2].lines[0].set_color('k')
+ if plot_std:
+ erp_xdata = axes[2].lines[0].get_data()[0]
+ axes[2].fill_between(erp_xdata, erp - erp_std[0],
+ erp + erp_std[1], color='k', alpha=.15)
+ axes[2].autoscale(enable=True, axis='y')
+ axes[2].axis('auto')
+ axes[2].set_xlim(erp_xdata[[0, -1]])
+ # remove half of yticks if more than 5
+ yt = axes[2].get_yticks()
+ if len(yt) > 5:
+ yt = yt[::2]
+ axes[2].yaxis.set_ticks(yt)
+
+ if not plot_line_at_zero:
+ xlims = [1e3 * inst.times[0], 1e3 * inst.times[-1]]
+ for k, ax in enumerate(axes[1:3]):
+ ax.lines[k].remove()
+ ax.set_xlim(xlims)
+
+ # remove xticks - erp plot shows xticks for both image and erp plot
+ axes[1].xaxis.set_ticks([])
+ yt = axes[1].get_yticks()
+ axes[1].yaxis.set_ticks(yt[1:])
+ axes[1].set_ylim([-0.5, ica_data.shape[1] + 0.5])
+
+ # spectrum
+ ylabel = 'dB' if dB else 'power'
+ set_title_and_labels(axes[3], 'spectrum', 'frequency', ylabel)
+ axes[3].yaxis.labelpad = 0
+ axes[3].set_xlim(freqs[[0, -1]])
+ ylim = axes[3].get_ylim()
+ air = np.diff(ylim)[0] * 0.1
+ axes[3].set_ylim(ylim[0] - air, ylim[1] + air)
+
+ # epoch variance
+ set_title_and_labels(axes[4], 'epochs variance', 'epoch', 'AU')
+
+ all_fig.append(fig)
- n_components = len(sources)
- ylims = sources.min(), sources.max()
- xlims = np.arange(sources.shape[-1])[[0, -1]]
- fig, axes = _prepare_trellis(n_components, ncol)
- if title is None:
- fig.suptitle('Reconstructed latent sources', size=16)
- elif title:
- fig.suptitle(title, size=16)
-
- plt.subplots_adjust(wspace=0.05, hspace=0.05)
- my_iter = enumerate(zip(source_idx, axes, sources))
- for i_source, (i_selection, ax, source) in my_iter:
- component = '[%i]' % i_selection
- # plot+ emebed idx and comp. name to use in callback
- color = 'r' if i_selection in exclude else 'k'
- line = ax.plot(source, linewidth=0.5, color=color, picker=1e9)[0]
- vars(line)['_mne_src_idx'] = i_source
- vars(line)['_mne_component'] = i_selection
- vars(line)['_mne_is_bad'] = i_selection in exclude
- ax.set_xlim(xlims)
- ax.set_ylim(ylims)
- ax.text(0.05, .95, component, transform=ax.transAxes,
- verticalalignment='top')
- plt.setp(ax.get_xticklabels(), visible=False)
- plt.setp(ax.get_yticklabels(), visible=False)
- # register callback
- callback = partial(_ica_plot_sources_onpick_, sources=sources, ylims=ylims)
- fig.canvas.mpl_connect('pick_event', callback)
plt_show(show)
- return fig
+ return all_fig
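
A sketch of the new properties plot, assuming a fitted
mne.preprocessing.ICA instance ``ica`` and an mne.Epochs instance
``epochs``:

    import mne

    # `ica` is assumed fitted; `epochs` is an existing mne.Epochs instance
    figs = mne.viz.plot_ica_properties(ica, epochs, picks=[0, 1], dB=True,
                                       psd_args=dict(fmax=40.))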
def _plot_ica_sources_evoked(evoked, picks, exclude, title, show, labels=None):
@@ -225,7 +384,6 @@ def _plot_ica_sources_evoked(evoked, picks, exclude, title, show, labels=None):
fig, axes = plt.subplots(1)
ax = axes
axes = [axes]
- idxs = [0]
times = evoked.times * 1e3
# plot unclassified sources and label excluded ones
@@ -235,7 +393,6 @@ def _plot_ica_sources_evoked(evoked, picks, exclude, title, show, labels=None):
picks = np.arange(evoked.data.shape[0])
picks = np.sort(picks)
idxs = [picks]
- color = 'r'
if labels is not None:
labels_used = [k for k in labels if '/' not in k]
@@ -243,7 +400,7 @@ def _plot_ica_sources_evoked(evoked, picks, exclude, title, show, labels=None):
exclude_labels = list()
for ii in picks:
if ii in exclude:
- line_label = 'ICA %03d' % (ii + 1)
+ line_label = 'IC #%03d' % ii
if labels is not None:
annot = list()
for this_label in labels_used:
@@ -314,11 +471,8 @@ def _plot_ica_sources_evoked(evoked, picks, exclude, title, show, labels=None):
return fig
-def plot_ica_scores(ica, scores,
- exclude=None, labels=None,
- axhline=None,
- title='ICA component scores',
- figsize=(12, 6), show=True):
+def plot_ica_scores(ica, scores, exclude=None, labels=None, axhline=None,
+ title='ICA component scores', figsize=(12, 6), show=True):
"""Plot scores related to detected components.
Use this function to assess how well your score describes outlier
@@ -595,7 +749,7 @@ def _plot_sources_raw(ica, raw, picks, exclude, start, stop, show, title,
eog_chs = pick_types(raw.info, meg=False, eog=True, ref_meg=False)
ecg_chs = pick_types(raw.info, meg=False, ecg=True, ref_meg=False)
data = [orig_data[pick] for pick in picks]
- c_names = ['ICA %03d' % x for x in range(len(orig_data))]
+ c_names = ['IC #%03d' % x for x in range(len(orig_data))]
for eog_idx in eog_chs:
c_names.append(raw.ch_names[eog_idx])
types.append('eog')
@@ -633,15 +787,15 @@ def _plot_sources_raw(ica, raw, picks, exclude, start, stop, show, title,
inds = list(range(len(picks)))
data = np.array(data)
n_channels = min([20, len(picks)])
- params = dict(raw=raw, orig_data=data, data=data[:, 0:t_end],
+ params = dict(raw=raw, orig_data=data, data=data[:, 0:t_end], inds=inds,
ch_start=0, t_start=start, info=info, duration=duration,
ica=ica, n_channels=n_channels, times=times, types=types,
n_times=raw.n_times, bad_color=bad_color, picks=picks)
_prepare_mne_browse_raw(params, title, 'w', color, bad_color, inds,
n_channels)
params['scale_factor'] = 1.0
- params['plot_fun'] = partial(_plot_raw_traces, params=params, inds=inds,
- color=color, bad_color=bad_color)
+ params['plot_fun'] = partial(_plot_raw_traces, params=params, color=color,
+ bad_color=bad_color)
params['update_fun'] = partial(_update_data, params)
params['pick_bads_fun'] = partial(_pick_bads, params=params)
params['label_click_fun'] = partial(_label_clicked, params=params)
@@ -688,8 +842,8 @@ def _pick_bads(event, params):
def _close_event(events, params):
"""Function for excluding the selected components on close."""
info = params['info']
- c_names = ['ICA %03d' % x for x in range(params['ica'].n_components_)]
- exclude = [c_names.index(x) for x in info['bads'] if x.startswith('ICA')]
+ c_names = ['IC #%03d' % x for x in range(params['ica'].n_components_)]
+ exclude = [c_names.index(x) for x in info['bads'] if x.startswith('IC')]
params['ica'].exclude = exclude
@@ -699,7 +853,7 @@ def _plot_sources_epochs(ica, epochs, picks, exclude, start, stop, show,
data = ica._transform_epochs(epochs, concatenate=True)
eog_chs = pick_types(epochs.info, meg=False, eog=True, ref_meg=False)
ecg_chs = pick_types(epochs.info, meg=False, ecg=True, ref_meg=False)
- c_names = ['ICA %03d' % x for x in range(ica.n_components_)]
+ c_names = ['IC #%03d' % x for x in range(ica.n_components_)]
ch_types = np.repeat('misc', ica.n_components_)
for eog_idx in eog_chs:
c_names.append(epochs.ch_names[eog_idx])
@@ -772,7 +926,7 @@ def _close_epochs_event(events, params):
"""Function for excluding the selected components on close."""
info = params['info']
exclude = [info['ch_names'].index(x) for x in info['bads']
- if x.startswith('ICA')]
+ if x.startswith('IC')]
params['ica'].exclude = exclude
@@ -784,6 +938,9 @@ def _label_clicked(pos, params):
if line_idx >= len(params['picks']):
return
ic_idx = [params['picks'][line_idx]]
+ if params['types'][ic_idx[0]] != 'misc':
+ warn('Can only plot ICA components.')
+ return
types = list()
info = params['ica'].info
if len(pick_types(info, meg=False, eeg=True, ref_meg=False)) > 0:
diff --git a/mne/viz/misc.py b/mne/viz/misc.py
index 0c22881..560f3c9 100644
--- a/mne/viz/misc.py
+++ b/mne/viz/misc.py
@@ -20,7 +20,9 @@ import numpy as np
from scipy import linalg
from ..surface import read_surface
+from ..externals.six import string_types
from ..io.proj import make_projector
+from ..source_space import read_source_spaces, SourceSpaces
from ..utils import logger, verbose, get_subjects_dir, warn
from ..io.pick import pick_types
from .utils import tight_layout, COLORS, _prepare_trellis, plt_show
@@ -237,7 +239,7 @@ def plot_source_spectrogram(stcs, freq_bins, tmin=None, tmax=None,
return fig
-def _plot_mri_contours(mri_fname, surf_fnames, orientation='coronal',
+def _plot_mri_contours(mri_fname, surfaces, src, orientation='coronal',
slices=None, show=True):
"""Plot BEM contours on anatomical slices.
@@ -245,9 +247,11 @@ def _plot_mri_contours(mri_fname, surf_fnames, orientation='coronal',
----------
mri_fname : str
The name of the file containing anatomical data.
- surf_fnames : list of str
- The filenames for the BEM surfaces in the format
- ['inner_skull.surf', 'outer_skull.surf', 'outer_skin.surf'].
+ surfaces : list of (str, str) tuples
+ A list containing the BEM surfaces to plot as (filename, color) tuples.
+ Colors should be matplotlib-compatible.
+ src : None | SourceSpaces
+ SourceSpaces object for plotting individual sources.
orientation : str
'coronal' or 'axial' or 'sagittal'
slices : list of int
@@ -263,14 +267,24 @@ def _plot_mri_contours(mri_fname, surf_fnames, orientation='coronal',
import matplotlib.pyplot as plt
import nibabel as nib
- if orientation not in ['coronal', 'axial', 'sagittal']:
+ # plot axes (x, y, z) as data axes (0, 1, 2)
+ if orientation == 'coronal':
+ x, y, z = 0, 1, 2
+ elif orientation == 'axial':
+ x, y, z = 2, 0, 1
+ elif orientation == 'sagittal':
+ x, y, z = 2, 1, 0
+ else:
raise ValueError("Orientation must be 'coronal', 'axial' or "
"'sagittal'. Got %s." % orientation)
# Load the T1 data
nim = nib.load(mri_fname)
data = nim.get_data()
- affine = nim.get_affine()
+ try:
+ affine = nim.affine
+ except AttributeError: # older nibabel
+ affine = nim.get_affine()
n_sag, n_axi, n_cor = data.shape
orientation_name2axis = dict(sagittal=0, axial=1, coronal=2)
@@ -287,12 +301,21 @@ def _plot_mri_contours(mri_fname, surf_fnames, orientation='coronal',
# XXX : next line is a hack don't ask why
trans[:3, -1] = [n_sag // 2, n_axi // 2, n_cor // 2]
- for surf_fname in surf_fnames:
+ for file_name, color in surfaces:
surf = dict()
- surf['rr'], surf['tris'] = read_surface(surf_fname)
+ surf['rr'], surf['tris'] = read_surface(file_name)
# move back surface to MRI coordinate system
surf['rr'] = nib.affines.apply_affine(trans, surf['rr'])
- surfs.append(surf)
+ surfs.append((surf, color))
+
+ src_points = list()
+ if isinstance(src, SourceSpaces):
+ for src_ in src:
+ points = src_['rr'][src_['inuse'].astype(bool)] * 1e3
+ src_points.append(nib.affines.apply_affine(trans, points))
+ elif src is not None:
+ raise TypeError("src needs to be None or SourceSpaces instance, not "
+ "%s" % repr(src))
fig, axs = _prepare_trellis(len(slices), 4)
@@ -308,22 +331,21 @@ def _plot_mri_contours(mri_fname, surf_fnames, orientation='coronal',
# First plot the anatomical data
ax.imshow(dat, cmap=plt.cm.gray)
+ ax.set_autoscale_on(False)
ax.axis('off')
# and then plot the contours on top
- for surf in surfs:
- if orientation == 'coronal':
- ax.tricontour(surf['rr'][:, 0], surf['rr'][:, 1],
- surf['tris'], surf['rr'][:, 2],
- levels=[sl], colors='yellow', linewidths=2.0)
- elif orientation == 'axial':
- ax.tricontour(surf['rr'][:, 2], surf['rr'][:, 0],
- surf['tris'], surf['rr'][:, 1],
- levels=[sl], colors='yellow', linewidths=2.0)
- elif orientation == 'sagittal':
- ax.tricontour(surf['rr'][:, 2], surf['rr'][:, 1],
- surf['tris'], surf['rr'][:, 0],
- levels=[sl], colors='yellow', linewidths=2.0)
+ for surf, color in surfs:
+ ax.tricontour(surf['rr'][:, x], surf['rr'][:, y],
+ surf['tris'], surf['rr'][:, z],
+ levels=[sl], colors=color, linewidths=2.0,
+ zorder=1)
+
+ for sources in src_points:
+ in_slice = np.logical_and(sources[:, z] > sl - 0.5,
+ sources[:, z] < sl + 0.5)
+ ax.scatter(sources[in_slice, x], sources[in_slice, y], marker='.',
+ color='#FF00FF', s=1, zorder=2)
plt.subplots_adjust(left=0., bottom=0., right=1., top=1., wspace=0.,
hspace=0.)
@@ -332,7 +354,7 @@ def _plot_mri_contours(mri_fname, surf_fnames, orientation='coronal',
def plot_bem(subject=None, subjects_dir=None, orientation='coronal',
- slices=None, show=True):
+ slices=None, brain_surfaces=None, src=None, show=True):
"""Plot BEM contours on anatomical slices.
Parameters
@@ -346,6 +368,14 @@ def plot_bem(subject=None, subjects_dir=None, orientation='coronal',
'coronal' or 'axial' or 'sagittal'.
slices : list of int
Slice indices.
+ brain_surfaces : None | str | list of str
+ One or more brain surfaces to plot (optional). Entries should correspond
+ to files in the subject's ``surf`` directory (e.g. ``"white"``).
+ src : None | SourceSpaces | str
+ SourceSpaces instance or path to a source space to plot individual
+ sources as scatter-plot. Only sources lying in the shown slices will be
+ visible, sources that lie between visible slices are not shown. Path
+ can be absolute or relative to the subject's ``bem`` folder.
show : bool
Show figure if True.
@@ -367,21 +397,47 @@ def plot_bem(subject=None, subjects_dir=None, orientation='coronal',
if not op.isdir(bem_path):
raise IOError('Subject bem directory "%s" does not exist' % bem_path)
- surf_fnames = []
- for surf_name in ['*inner_skull', '*outer_skull', '*outer_skin']:
+ surfaces = []
+ for surf_name, color in (('*inner_skull', '#FF0000'),
+ ('*outer_skull', '#FFFF00'),
+ ('*outer_skin', '#FFAA80')):
surf_fname = glob(op.join(bem_path, surf_name + '.surf'))
if len(surf_fname) > 0:
surf_fname = surf_fname[0]
logger.info("Using surface: %s" % surf_fname)
- surf_fnames.append(surf_fname)
-
- if len(surf_fnames) == 0:
+ surfaces.append((surf_fname, color))
+
+ if brain_surfaces is not None:
+ if isinstance(brain_surfaces, string_types):
+ brain_surfaces = (brain_surfaces,)
+ for surf_name in brain_surfaces:
+ for hemi in ('lh', 'rh'):
+ surf_fname = op.join(subjects_dir, subject, 'surf',
+ hemi + '.' + surf_name)
+ if op.exists(surf_fname):
+ surfaces.append((surf_fname, '#00DD00'))
+ else:
+ raise IOError("Surface %s does not exist." % surf_fname)
+
+ if isinstance(src, string_types):
+ if not op.exists(src):
+ src_ = op.join(subjects_dir, subject, 'bem', src)
+ if op.exists(src_):
+ src = src_
+ else:
+ raise IOError("%s does not exist" % src)
+ src = read_source_spaces(src)
+ elif src is not None and not isinstance(src, SourceSpaces):
+ raise TypeError("src needs to be None, str or SourceSpaces instance, "
+ "not %s" % repr(src))
+
+ if len(surfaces) == 0:
raise IOError('No surface files found. Surface files must end with '
'inner_skull.surf, outer_skull.surf or outer_skin.surf')
# Plot the contours
- return _plot_mri_contours(mri_fname, surf_fnames, orientation=orientation,
- slices=slices, show=show)
+ return _plot_mri_contours(mri_fname, surfaces, src, orientation, slices,
+ show)
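
A sketch of the extended plot_bem call, assuming a FreeSurfer subject
'sample' under SUBJECTS_DIR; the source space filename is hypothetical:

    import mne

    # 'sample-oct-6-src.fif' is a hypothetical file in the subject's
    # bem folder (a path relative to it is resolved automatically)
    mne.viz.plot_bem(subject='sample', orientation='axial',
                     brain_surfaces='white', src='sample-oct-6-src.fif')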
def plot_events(events, sfreq=None, first_samp=0, color=None, event_id=None,
diff --git a/mne/viz/raw.py b/mne/viz/raw.py
index 09724e6..47a4a92 100644
--- a/mne/viz/raw.py
+++ b/mne/viz/raw.py
@@ -14,16 +14,18 @@ import numpy as np
from ..externals.six import string_types
from ..io.pick import (pick_types, _pick_data_channels, pick_info,
- _PICK_TYPES_KEYS)
+ _PICK_TYPES_KEYS, pick_channels)
from ..io.proj import setup_proj
from ..utils import verbose, get_config
from ..time_frequency import psd_welch
from .topo import _plot_topo, _plot_timeseries, _plot_timeseries_unified
from .utils import (_toggle_options, _toggle_proj, tight_layout,
- _layout_figure, _plot_raw_onkey, figure_nobar,
- _plot_raw_onscroll, _mouse_click, plt_show,
+ _layout_figure, _plot_raw_onkey, figure_nobar, plt_show,
+ _plot_raw_onscroll, _mouse_click, _find_channel_idx,
_helper_raw_resize, _select_bads, _onclick_help,
- _setup_browser_offsets, _compute_scalings)
+ _setup_browser_offsets, _compute_scalings, plot_sensors,
+ _radio_clicked, _set_radio_button, _handle_topomap_bads,
+ _change_channel_group)
from ..defaults import _handle_default
from ..annotations import _onset_to_seconds
@@ -100,7 +102,8 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=20,
start : float
Initial time to show (can be changed dynamically once plotted).
n_channels : int
- Number of channels to plot at once. Defaults to 20.
+ Number of channels to plot at once. Defaults to 20. Has no effect if
+ ``order`` is 'position' or 'selection'.
bgcolor : color object
Color of the background.
color : dict | color object | None
@@ -129,10 +132,15 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=20,
remove_dc : bool
If True remove DC component when plotting data.
- order : 'type' | 'original' | array
- Order in which to plot data. 'type' groups by channel type,
- 'original' plots in the order of ch_names, array gives the
- indices to use in plotting.
+ order : str | array of int
+ Order in which to plot data. 'type' groups by channel type, 'original'
+ plots in the order of ch_names, 'selection' uses Elekta's channel
+ groupings (only works for Neuromag data), 'position' groups the
+ channels by sensor position. The 'selection' and 'position' modes
+ allow custom selections via the lasso selector on the topomap; holding
+ the ``ctrl`` key while selecting appends to the current selection. If
+ array, only the channels in the array are plotted in the
+ given order. Defaults to 'type'.
show_options : bool
If True, a dialog for options related to projection is shown.
title : str | None
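A hedged sketch of the orderings documented above, assuming ``raw`` is a
loaded Neuromag ``Raw`` instance:

    raw.plot(order='position')    # group channels by sensor location
    raw.plot(order='selection')   # Elekta channel groupings (Neuromag only)
    raw.plot(order=[1, 2, 4, 6])  # plot just these channels, in this order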
@@ -243,9 +251,13 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=20,
for t in ['grad', 'mag']:
inds += [pick_types(info, meg=t, ref_meg=False, exclude=[])]
types += [t] * len(inds[-1])
+ for t in ['hbo', 'hbr']:
+ inds += [pick_types(info, meg=False, ref_meg=False, fnirs=t,
+ exclude=[])]
+ types += [t] * len(inds[-1])
pick_kwargs = dict(meg=False, ref_meg=False, exclude=[])
for key in _PICK_TYPES_KEYS:
- if key != 'meg':
+ if key not in ['meg', 'fnirs']:
pick_kwargs[key] = True
inds += [pick_types(raw.info, **pick_kwargs)]
types += [key] * len(inds[-1])
@@ -258,16 +270,14 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=20,
# put them back to original or modified order for natural plotting
reord = np.argsort(inds)
types = [types[ri] for ri in reord]
- if isinstance(order, str):
+ if isinstance(order, string_types):
if order == 'original':
inds = inds[reord]
+ elif order in ['selection', 'position']:
+ selections, fig_selection = _setup_browser_selection(raw, order)
elif order != 'type':
raise ValueError('Unknown order type %s' % order)
- elif isinstance(order, np.ndarray):
- if not np.array_equal(np.sort(order),
- np.arange(len(info['ch_names']))):
- raise ValueError('order, if array, must have integers from '
- '0 to n_channels - 1')
+ elif isinstance(order, (np.ndarray, list)):
# put back to original order first, then use new order
inds = inds[reord][order]
@@ -288,9 +298,17 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=20,
params = dict(raw=raw, ch_start=0, t_start=start, duration=duration,
info=info, projs=projs, remove_dc=remove_dc, ba=ba,
n_channels=n_channels, scalings=scalings, types=types,
- n_times=n_times, event_times=event_times,
+ n_times=n_times, event_times=event_times, inds=inds,
event_nums=event_nums, clipping=clipping, fig_proj=None)
+ if order in ['selection', 'position']:
+ params['fig_selection'] = fig_selection
+ params['selections'] = selections
+ params['radio_clicked'] = partial(_radio_clicked, params=params)
+ fig_selection.radio.on_clicked(params['radio_clicked'])
+ lasso_callback = partial(_set_custom_selection, params=params)
+ fig_selection.canvas.mpl_connect('lasso_event', lasso_callback)
+
_prepare_mne_browse_raw(params, title, bgcolor, color, bad_color, inds,
n_channels)
@@ -298,17 +316,16 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=20,
event_lines = [params['ax'].plot([np.nan], color=event_color[ev_num])[0]
for ev_num in sorted(event_color.keys())]
- params['plot_fun'] = partial(_plot_raw_traces, params=params, inds=inds,
- color=color, bad_color=bad_color,
- event_lines=event_lines,
+ params['plot_fun'] = partial(_plot_raw_traces, params=params, color=color,
+ bad_color=bad_color, event_lines=event_lines,
event_color=event_color)
if raw.annotations is not None:
segments = list()
segment_colors = dict()
# sort the segments by start time
- order = raw.annotations.onset.argsort(axis=0)
- descriptions = raw.annotations.description[order]
+ ann_order = raw.annotations.onset.argsort(axis=0)
+ descriptions = raw.annotations.description[ann_order]
color_keys = set(descriptions)
color_vals = np.linspace(0, 1, len(color_keys))
for idx, key in enumerate(color_keys):
@@ -317,9 +334,9 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=20,
else:
segment_colors[key] = plt.cm.summer(color_vals[idx])
params['segment_colors'] = segment_colors
- for idx, onset in enumerate(raw.annotations.onset[order]):
+ for idx, onset in enumerate(raw.annotations.onset[ann_order]):
annot_start = _onset_to_seconds(raw, onset)
- annot_end = annot_start + raw.annotations.duration[order][idx]
+ annot_end = annot_start + raw.annotations.duration[ann_order][idx]
segments.append([annot_start, annot_end])
ylim = params['ax_hscroll'].get_ylim()
dscr = descriptions[idx]
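The segment drawing above keys each color to the annotation description; a
minimal sketch, assuming ``raw`` is already loaded:

    from mne import Annotations

    # one 0.5 s segment starting at t = 1 s; its color is derived from
    # the description by the code above
    raw.annotations = Annotations(onset=[1.0], duration=[0.5],
                                  description=['bad_blink'])
    raw.plot()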
@@ -368,6 +385,19 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=20,
# deal with projectors
if show_options is True:
_toggle_options(None, params)
+ # initialize the first selection set
+ if order in ['selection', 'position']:
+ _radio_clicked(fig_selection.radio.labels[0]._text, params)
+ callback_selection_key = partial(_selection_key_press, params=params)
+ callback_selection_scroll = partial(_selection_scroll, params=params)
+ callback_close = partial(_close_event, params=params)
+ params['fig'].canvas.mpl_connect('close_event', callback_close)
+ params['fig_selection'].canvas.mpl_connect('close_event',
+ callback_close)
+ params['fig_selection'].canvas.mpl_connect('key_press_event',
+ callback_selection_key)
+ params['fig_selection'].canvas.mpl_connect('scroll_event',
+ callback_selection_scroll)
try:
plt_show(show, block=block)
@@ -377,6 +407,31 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=20,
return params['fig']
+def _selection_scroll(event, params):
+ """Callback for scroll in selection dialog."""
+ if event.step < 0:
+ _change_channel_group(-1, params)
+ elif event.step > 0:
+ _change_channel_group(1, params)
+
+
+def _selection_key_press(event, params):
+ """Callback for keys in selection dialog."""
+ if event.key == 'down':
+ _change_channel_group(-1, params)
+ elif event.key == 'up':
+ _change_channel_group(1, params)
+ elif event.key == 'escape':
+ _close_event(event, params)
+
+
+def _close_event(event, params):
+ """Callback for closing of raw browser with selections."""
+ import matplotlib.pyplot as plt
+ plt.close(params['fig_selection'])
+ plt.close(params['fig'])
+
+
def _label_clicked(pos, params):
"""Helper function for selecting bad channels."""
labels = params['ax'].yaxis.get_ticklabels()
@@ -385,17 +440,23 @@ def _label_clicked(pos, params):
text = labels[line_idx].get_text()
if len(text) == 0:
return
- ch_idx = params['ch_start'] + line_idx
+ if 'fig_selection' in params:
+ ch_idx = _find_channel_idx(text, params)
+ _handle_topomap_bads(text, params)
+ else:
+ ch_idx = [params['ch_start'] + line_idx]
bads = params['info']['bads']
if text in bads:
while text in bads: # to make sure duplicates are removed
bads.remove(text)
color = vars(params['lines'][line_idx])['def_color']
- params['ax_vscroll'].patches[ch_idx].set_color(color)
+ for idx in ch_idx:
+ params['ax_vscroll'].patches[idx].set_color(color)
else:
bads.append(text)
color = params['bad_color']
- params['ax_vscroll'].patches[ch_idx].set_color(color)
+ for idx in ch_idx:
+ params['ax_vscroll'].patches[idx].set_color(color)
params['raw'].info['bads'] = bads
_plot_update_raw_proj(params, None)
@@ -521,6 +582,12 @@ def plot_raw_psd(raw, tmin=0., tmax=np.inf, fmin=0, fmax=np.inf, proj=False,
# Convert PSDs to dB
if dB:
psds = 10 * np.log10(psds)
+ if np.any(np.isinf(psds)):
+ where = np.flatnonzero(np.isinf(psds.min(1)))
+ chs = [raw.ch_names[i] for i in picks[where]]
+ raise ValueError("Infinite value in PSD for channel(s) %s. "
+ "These channels might be dead." %
+ ', '.join(chs))
unit = 'dB'
else:
unit = 'power'
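Why the new check is needed: a dead (all-zero) channel produces a zero PSD,
which the dB conversion maps to -inf. A tiny self-contained illustration:

    import numpy as np

    psds = np.array([[1e-20, 2e-20],   # ordinary channel
                     [0.0, 0.0]])      # dead channel
    psds_db = 10 * np.log10(psds)      # second row becomes -inf
    print(np.flatnonzero(np.isinf(psds_db.min(1))))  # -> [1]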
@@ -582,18 +649,40 @@ def _prepare_mne_browse_raw(params, title, bgcolor, color, bad_color, inds,
# populate vertical and horizontal scrollbars
info = params['info']
- for ci in range(len(info['ch_names'])):
- this_color = (bad_color if info['ch_names'][inds[ci]] in info['bads']
- else color)
- if isinstance(this_color, dict):
- this_color = this_color[params['types'][inds[ci]]]
- ax_vscroll.add_patch(mpl.patches.Rectangle((0, ci), 1, 1,
- facecolor=this_color,
- edgecolor=this_color))
+ n_ch = len(inds)
+
+ if 'fig_selection' in params:
+ selections = params['selections']
+ labels = [l._text for l in params['fig_selection'].radio.labels]
+ # Flatten the selections dict to a list.
+ cis = [item for sublist in [selections[l] for l in labels] for item
+ in sublist]
+
+ for idx, ci in enumerate(cis):
+ this_color = (bad_color if info['ch_names'][ci] in
+ info['bads'] else color)
+ if isinstance(this_color, dict):
+ this_color = this_color[params['types'][ci]]
+ ax_vscroll.add_patch(mpl.patches.Rectangle((0, idx), 1, 1,
+ facecolor=this_color,
+ edgecolor=this_color))
+ ax_vscroll.set_ylim(len(cis), 0)
+ n_channels = max([len(selections[labels[0]]), n_channels])
+ else:
+ for ci in range(len(inds)):
+ this_color = (bad_color if info['ch_names'][inds[ci]] in
+ info['bads'] else color)
+ if isinstance(this_color, dict):
+ this_color = this_color[params['types'][inds[ci]]]
+ ax_vscroll.add_patch(mpl.patches.Rectangle((0, ci), 1, 1,
+ facecolor=this_color,
+ edgecolor=this_color))
+ ax_vscroll.set_ylim(n_ch, 0)
vsel_patch = mpl.patches.Rectangle((0, 0), 1, n_channels, alpha=0.5,
facecolor='w', edgecolor='w')
ax_vscroll.add_patch(vsel_patch)
params['vsel_patch'] = vsel_patch
+
hsel_patch = mpl.patches.Rectangle((params['t_start'], 0),
params['duration'], 1, edgecolor='k',
facecolor=(0.75, 0.75, 0.75),
@@ -601,18 +690,9 @@ def _prepare_mne_browse_raw(params, title, bgcolor, color, bad_color, inds,
ax_hscroll.add_patch(hsel_patch)
params['hsel_patch'] = hsel_patch
ax_hscroll.set_xlim(0, params['n_times'] / float(info['sfreq']))
- n_ch = len(info['ch_names'])
- ax_vscroll.set_ylim(n_ch, 0)
- ax_vscroll.set_title('Ch.')
- # make shells for plotting traces
- _setup_browser_offsets(params, n_channels)
- ax.set_xlim(params['t_start'], params['t_start'] + params['duration'],
- False)
+ ax_vscroll.set_title('Ch.')
- params['lines'] = [ax.plot([np.nan], antialiased=False, linewidth=0.5)[0]
- for _ in range(n_ch)]
- ax.set_yticklabels(['X' * max([len(ch) for ch in info['ch_names']])])
vertline_color = (0., 0.75, 0.)
params['ax_vertline'] = ax.plot([0, 0], ax.get_ylim(),
color=vertline_color, zorder=-1)[0]
@@ -622,13 +702,22 @@ def _prepare_mne_browse_raw(params, title, bgcolor, color, bad_color, inds,
params['ax_hscroll_vertline'] = ax_hscroll.plot([0, 0], [0, 1],
color=vertline_color,
zorder=2)[0]
+ # make shells for plotting traces
+ _setup_browser_offsets(params, n_channels)
+ ax.set_xlim(params['t_start'], params['t_start'] + params['duration'],
+ False)
+
+ params['lines'] = [ax.plot([np.nan], antialiased=False, linewidth=0.5)[0]
+ for _ in range(n_ch)]
+ ax.set_yticklabels(['X' * max([len(ch) for ch in info['ch_names']])])
-def _plot_raw_traces(params, inds, color, bad_color, event_lines=None,
+def _plot_raw_traces(params, color, bad_color, event_lines=None,
event_color=None):
"""Helper for plotting raw"""
lines = params['lines']
info = params['info']
+ inds = params['inds']
n_channels = params['n_channels']
params['bad_color'] = bad_color
labels = params['ax'].yaxis.get_ticklabels()
@@ -640,7 +729,7 @@ def _plot_raw_traces(params, inds, color, bad_color, event_lines=None,
# n_channels per view >= the number of traces available
if ii >= len(lines):
break
- elif ch_ind < len(info['ch_names']):
+ elif ch_ind < len(inds):
# scale to fit
ch_name = info['ch_names'][inds[ch_ind]]
tick_list += [ch_name]
@@ -722,7 +811,8 @@ def _plot_raw_traces(params, inds, color, bad_color, event_lines=None,
params['ax'].set_xlim(params['times'][0],
params['times'][0] + params['duration'], False)
params['ax'].set_yticklabels(tick_list)
- params['vsel_patch'].set_y(params['ch_start'])
+ if 'fig_selection' not in params:
+ params['vsel_patch'].set_y(params['ch_start'])
params['fig'].canvas.draw()
# XXX This is a hack to make sure this figure gets drawn last
# so that when matplotlib goes to calculate bounds we don't get a
@@ -815,3 +905,70 @@ def plot_raw_psd_topo(raw, tmin=0., tmax=None, fmin=0., fmax=100., proj=False,
except TypeError: # not all versions have this
plt_show(show)
return fig
+
+
+def _set_custom_selection(params):
+ """Callback for setting custom selection by lasso selector."""
+ chs = params['fig_selection'].lasso.selection
+ if len(chs) == 0:
+ return
+ labels = [l._text for l in params['fig_selection'].radio.labels]
+ inds = np.in1d(params['raw'].ch_names, chs)
+ params['selections']['Custom'] = np.where(inds)[0]
+
+ _set_radio_button(labels.index('Custom'), params=params)
+
+
+def _setup_browser_selection(raw, kind):
+ """Helper for organizing browser selections."""
+ import matplotlib.pyplot as plt
+ from matplotlib.widgets import RadioButtons
+ from ..selection import (read_selection, _SELECTIONS, _EEG_SELECTIONS,
+ _divide_to_regions)
+ from ..utils import _get_stim_channel
+ if kind == 'position':
+ order = _divide_to_regions(raw.info)
+ keys = _SELECTIONS[1:] # no 'Vertex'
+ elif kind == 'selection':
+ from ..io import RawFIF, RawArray
+ if not isinstance(raw, (RawFIF, RawArray)):
+ raise ValueError("order='selection' only works for Neuromag data. "
+ "Use order='position' instead.")
+ order = dict()
+ try:
+ stim_ch = _get_stim_channel(None, raw.info)
+ except ValueError:
+ stim_ch = ['']
+ keys = np.concatenate([_SELECTIONS, _EEG_SELECTIONS])
+ stim_ch = pick_channels(raw.ch_names, stim_ch)
+ for key in keys:
+ channels = read_selection(key, info=raw.info)
+ picks = pick_channels(raw.ch_names, channels)
+ if len(picks) == 0:
+ continue # omit empty selections
+ order[key] = np.concatenate([picks, stim_ch])
+
+ misc = pick_types(raw.info, meg=False, eeg=False, stim=True, eog=True,
+ ecg=True, emg=True, ref_meg=False, misc=True, resp=True,
+ chpi=True, exci=True, ias=True, syst=True, seeg=False,
+ bio=True, ecog=False, fnirs=False, exclude=())
+ if len(misc) > 0:
+ order['Misc'] = misc
+ keys = np.concatenate([keys, ['Misc']])
+ fig_selection = figure_nobar(figsize=(2, 6), dpi=80)
+ fig_selection.canvas.set_window_title('Selection')
+ rax = plt.subplot2grid((6, 1), (2, 0), rowspan=4, colspan=1)
+ topo_ax = plt.subplot2grid((6, 1), (0, 0), rowspan=2, colspan=1)
+ keys = np.concatenate([keys, ['Custom']])
+ order.update({'Custom': list()}) # custom selection with lasso
+
+ plot_sensors(raw.info, kind='select', ch_type='all', axes=topo_ax,
+ ch_groups=kind, title='', show=False)
+ fig_selection.radio = RadioButtons(rax, [key for key in keys
+ if key in order.keys()])
+
+ for circle in fig_selection.radio.circles:
+ circle.set_radius(0.02) # make them smaller to prevent overlap
+ circle.set_edgecolor('gray') # make sure the buttons are visible
+
+ return order, fig_selection
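The selection names consumed by ``_setup_browser_selection`` come from
``mne.selection``; a short sketch, assuming ``raw`` is a Neuromag recording:

    from mne import pick_channels
    from mne.selection import read_selection

    chs = read_selection('Left-temporal', info=raw.info)
    picks = pick_channels(raw.ch_names, chs)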
diff --git a/mne/viz/tests/test_3d.py b/mne/viz/tests/test_3d.py
index 1062d23..4add583 100644
--- a/mne/viz/tests/test_3d.py
+++ b/mne/viz/tests/test_3d.py
@@ -69,11 +69,12 @@ def test_plot_sparse_source_estimates():
stc = SourceEstimate(stc_data, vertices, 1, 1)
colormap = 'mne_analyze'
plot_source_estimates(stc, 'sample', colormap=colormap,
- config_opts={'background': (1, 1, 0)},
+ background=(1, 1, 0),
subjects_dir=subjects_dir, colorbar=True,
clim='auto')
assert_raises(TypeError, plot_source_estimates, stc, 'sample',
- figure='foo', hemi='both', clim='auto')
+ figure='foo', hemi='both', clim='auto',
+ subjects_dir=subjects_dir)
# now do sparse version
vertices = sample_src[0]['vertno']
@@ -122,9 +123,8 @@ def test_plot_trans():
ref_meg = False if system == 'KIT' else True
plot_trans(info, trans_fname, subject='sample', meg_sensors=True,
subjects_dir=subjects_dir, ref_meg=ref_meg)
- # KIT ref sensor coil def not defined
- assert_raises(RuntimeError, plot_trans, infos['KIT'], None,
- meg_sensors=True, ref_meg=True)
+ # KIT ref sensor coil def is defined
+ plot_trans(infos['KIT'], None, meg_sensors=True, ref_meg=True)
info = infos['Neuromag']
assert_raises(ValueError, plot_trans, info, trans_fname,
subject='sample', subjects_dir=subjects_dir,
@@ -164,7 +164,7 @@ def test_limits_to_control_points():
stc.plot(colormap='hot', clim='auto', subjects_dir=subjects_dir)
stc.plot(colormap='mne', clim='auto', subjects_dir=subjects_dir)
figs = [mlab.figure(), mlab.figure()]
- assert_raises(RuntimeError, stc.plot, clim='auto', figure=figs,
+ assert_raises(ValueError, stc.plot, clim='auto', figure=figs,
subjects_dir=subjects_dir)
# Test both types of incorrect limits key (lims/pos_lims)
@@ -201,13 +201,13 @@ def test_limits_to_control_points():
warnings.simplefilter('always')
# thresholded maps
stc._data.fill(1.)
- plot_source_estimates(stc, subjects_dir=subjects_dir)
+ plot_source_estimates(stc, subjects_dir=subjects_dir, time_unit='s')
assert_equal(len(w), 0)
stc._data[0].fill(0.)
- plot_source_estimates(stc, subjects_dir=subjects_dir)
+ plot_source_estimates(stc, subjects_dir=subjects_dir, time_unit='s')
assert_equal(len(w), 0)
stc._data.fill(0.)
- plot_source_estimates(stc, subjects_dir=subjects_dir)
+ plot_source_estimates(stc, subjects_dir=subjects_dir, time_unit='s')
assert_equal(len(w), 1)
mlab.close()
diff --git a/mne/viz/tests/test_decoding.py b/mne/viz/tests/test_decoding.py
index 5095b18..401f2da 100644
--- a/mne/viz/tests/test_decoding.py
+++ b/mne/viz/tests/test_decoding.py
@@ -12,7 +12,8 @@ import numpy as np
from mne.epochs import equalize_epoch_counts, concatenate_epochs
from mne.decoding import GeneralizationAcrossTime
-from mne import io, Epochs, read_events, pick_types
+from mne import Epochs, read_events, pick_types
+from mne.io import read_raw_fif
from mne.utils import requires_sklearn, run_tests_if_main
import matplotlib
matplotlib.use('Agg') # for testing don't use X server
@@ -30,7 +31,7 @@ def _get_data(tmin=-0.2, tmax=0.5, event_id=dict(aud_l=1, vis_l=3),
event_id_gen=dict(aud_l=2, vis_l=4), test_times=None):
"""Aux function for testing GAT viz"""
gat = GeneralizationAcrossTime()
- raw = io.read_raw_fif(raw_fname, preload=False)
+ raw = read_raw_fif(raw_fname, preload=False, add_eeg_ref=False)
raw.add_proj([], remove_existing=True)
events = read_events(event_name)
picks = pick_types(raw.info, meg='mag', stim=False, ecg=False,
@@ -40,7 +41,8 @@ def _get_data(tmin=-0.2, tmax=0.5, event_id=dict(aud_l=1, vis_l=3),
# Test on time generalization within one condition
with warnings.catch_warnings(record=True):
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- baseline=(None, 0), preload=True, decim=decim)
+ baseline=(None, 0), preload=True, decim=decim,
+ add_eeg_ref=False)
epochs_list = [epochs[k] for k in event_id]
equalize_epoch_counts(epochs_list)
epochs = concatenate_epochs(epochs_list)
diff --git a/mne/viz/tests/test_epochs.py b/mne/viz/tests/test_epochs.py
index 6adbbfe..ce36a5e 100644
--- a/mne/viz/tests/test_epochs.py
+++ b/mne/viz/tests/test_epochs.py
@@ -13,11 +13,10 @@ from nose.tools import assert_raises
import numpy as np
from numpy.testing import assert_equal
-from mne import io, read_events, Epochs
-from mne import pick_types
-from mne.utils import run_tests_if_main, requires_version
+from mne import read_events, Epochs, pick_types
from mne.channels import read_layout
-
+from mne.io import read_raw_fif
+from mne.utils import run_tests_if_main, requires_version
from mne.viz import plot_drop_log
from mne.viz.utils import _fake_click
@@ -39,19 +38,23 @@ layout = read_layout('Vectorview-all')
def _get_raw():
- return io.read_raw_fif(raw_fname, preload=False)
+ """Get raw data."""
+ return read_raw_fif(raw_fname, preload=False, add_eeg_ref=False)
def _get_events():
+ """Get events."""
return read_events(event_name)
def _get_picks(raw):
+ """Get picks."""
return pick_types(raw.info, meg=True, eeg=False, stim=False,
ecg=False, eog=False, exclude='bads')
def _get_epochs():
+ """Get epochs."""
raw = _get_raw()
events = _get_events()
picks = _get_picks(raw)
@@ -59,18 +62,19 @@ def _get_epochs():
picks = np.round(np.linspace(0, len(picks) + 1, n_chan)).astype(int)
with warnings.catch_warnings(record=True): # bad proj
epochs = Epochs(raw, events[:5], event_id, tmin, tmax, picks=picks,
- baseline=(None, 0))
+ baseline=(None, 0), add_eeg_ref=False)
return epochs
def _get_epochs_delayed_ssp():
+ """Get epochs with delayed SSP."""
raw = _get_raw()
events = _get_events()
picks = _get_picks(raw)
reject = dict(mag=4e-12)
- epochs_delayed_ssp = Epochs(raw, events[:10], event_id, tmin, tmax,
- picks=picks, baseline=(None, 0),
- proj='delayed', reject=reject)
+ epochs_delayed_ssp = Epochs(
+ raw, events[:10], event_id, tmin, tmax, picks=picks,
+ baseline=(None, 0), proj='delayed', reject=reject, add_eeg_ref=False)
return epochs_delayed_ssp
@@ -136,8 +140,8 @@ def test_plot_epochs_image():
epochs = _get_epochs()
epochs.plot_image(picks=[1, 2])
overlay_times = [0.1]
- epochs.plot_image(order=[0], overlay_times=overlay_times)
- epochs.plot_image(overlay_times=overlay_times)
+ epochs.plot_image(order=[0], overlay_times=overlay_times, vmin=0.01)
+ epochs.plot_image(overlay_times=overlay_times, vmin=-0.001, vmax=0.001)
assert_raises(ValueError, epochs.plot_image,
overlay_times=[0.1, 0.2])
assert_raises(ValueError, epochs.plot_image,
diff --git a/mne/viz/tests/test_evoked.py b/mne/viz/tests/test_evoked.py
index 4fd4638..9ce246d 100644
--- a/mne/viz/tests/test_evoked.py
+++ b/mne/viz/tests/test_evoked.py
@@ -4,6 +4,7 @@
# Eric Larson <larson.eric.d at gmail.com>
# Cathy Nangini <cnangini at gmail.com>
# Mainak Jas <mainak at neuro.hut.fi>
+# Jona Sassenhagen <jona.sassenhagen at gmail.com>
#
# License: Simplified BSD
@@ -14,10 +15,11 @@ import numpy as np
from numpy.testing import assert_raises
-from mne import io, read_events, Epochs, pick_types, read_cov
+from mne import read_events, Epochs, pick_types, read_cov
from mne.channels import read_layout
+from mne.io import read_raw_fif
from mne.utils import slow_test, run_tests_if_main
-from mne.viz.evoked import _butterfly_onselect
+from mne.viz.evoked import _butterfly_onselect, plot_compare_evokeds
from mne.viz.utils import _fake_click
# Set our plotters to test mode
@@ -38,19 +40,23 @@ layout = read_layout('Vectorview-all')
def _get_raw():
- return io.read_raw_fif(raw_fname, preload=False)
+ """Get raw data."""
+ return read_raw_fif(raw_fname, preload=False, add_eeg_ref=False)
def _get_events():
+ """Get events."""
return read_events(event_name)
def _get_picks(raw):
+ """Get picks."""
return pick_types(raw.info, meg=True, eeg=False, stim=False,
ecg=False, eog=False, exclude='bads')
def _get_epochs():
+ """Get epochs."""
raw = _get_raw()
raw.add_proj([], remove_existing=True)
events = _get_events()
@@ -58,28 +64,28 @@ def _get_epochs():
# Use a subset of channels for plotting speed
picks = picks[np.round(np.linspace(0, len(picks) - 1, n_chan)).astype(int)]
picks[0] = 2 # make sure we have a magnetometer
- with warnings.catch_warnings(record=True): # proj
- epochs = Epochs(raw, events[:5], event_id, tmin, tmax, picks=picks,
- baseline=(None, 0))
+ epochs = Epochs(raw, events[:5], event_id, tmin, tmax, picks=picks,
+ baseline=(None, 0), add_eeg_ref=False)
epochs.info['bads'] = [epochs.ch_names[-1]]
return epochs
def _get_epochs_delayed_ssp():
+ """Get epochs with delayed SSP."""
raw = _get_raw()
events = _get_events()
picks = _get_picks(raw)
reject = dict(mag=4e-12)
epochs_delayed_ssp = Epochs(raw, events[:10], event_id, tmin, tmax,
picks=picks, baseline=(None, 0),
- proj='delayed', reject=reject)
+ proj='delayed', reject=reject,
+ add_eeg_ref=False)
return epochs_delayed_ssp
@slow_test
def test_plot_evoked():
- """Test plotting of evoked
- """
+ """Test plotting of evoked."""
import matplotlib.pyplot as plt
evoked = _get_epochs().average()
with warnings.catch_warnings(record=True):
@@ -91,9 +97,10 @@ def test_plot_evoked():
[line.get_xdata()[0], line.get_ydata()[0]], 'data')
_fake_click(fig, ax,
[ax.get_xlim()[0], ax.get_ylim()[1]], 'data')
- # plot with bad channels excluded & spatial_colors
+ # plot with bad channels excluded & spatial_colors & zorder
evoked.plot(exclude='bads')
- evoked.plot(exclude=evoked.info['bads'], spatial_colors=True, gfp=True)
+ evoked.plot(exclude=evoked.info['bads'], spatial_colors=True, gfp=True,
+ zorder='std')
# test selective updating of dict keys is working.
evoked.plot(hline=[1], units=dict(mag='femto foo'))
@@ -115,12 +122,12 @@ def test_plot_evoked():
evoked.plot_image(proj=True)
# plot with bad channels excluded
- evoked.plot_image(exclude='bads')
+ evoked.plot_image(exclude='bads', cmap='interactive')
evoked.plot_image(exclude=evoked.info['bads']) # does the same thing
plt.close('all')
evoked.plot_topo() # should auto-find layout
- _butterfly_onselect(0, 200, ['mag'], evoked) # test averaged topomap
+ _butterfly_onselect(0, 200, ['mag', 'grad'], evoked)
plt.close('all')
cov = read_cov(cov_fname)
@@ -128,6 +135,46 @@ def test_plot_evoked():
evoked.plot_white(cov)
evoked.plot_white([cov, cov])
+ # plot_compare_evokeds: test condition contrast, CI, color assignment
+ plot_compare_evokeds(evoked.copy().pick_types(meg='mag'))
+ evoked.rename_channels({'MEG 2142': "MEG 1642"})
+ assert len(plot_compare_evokeds(evoked)) == 2
+ colors = dict(red='r', blue='b')
+ linestyles = dict(red='--', blue='-')
+ red, blue = evoked.copy(), evoked.copy()
+ red.data *= 1.1
+ blue.data *= 0.9
+ plot_compare_evokeds([red, blue], picks=3) # list of evokeds
+ plot_compare_evokeds([[red, evoked], [blue, evoked]],
+ picks=3) # list of lists
+ # test picking & plotting grads
+ contrast = dict()
+ contrast["red/stim"] = list((evoked.copy(), red))
+ contrast["blue/stim"] = list((evoked.copy(), blue))
+ # test a bunch of params at once
+ plot_compare_evokeds(contrast, colors=colors, linestyles=linestyles,
+ picks=[0, 2], vlines=[.01, -.04], invert_y=True,
+ truncate_yaxis=False, ylim=dict(mag=(-10, 10)),
+ styles={"red/stim": {"linewidth": 1}})
+ assert_raises(ValueError, plot_compare_evokeds,
+ contrast, picks='str') # bad picks: not int
+ assert_raises(ValueError, plot_compare_evokeds, evoked, picks=3,
+ colors=dict(fake=1)) # 'fake' not in conds
+ assert_raises(ValueError, plot_compare_evokeds, evoked, picks=3,
+ styles=dict(fake=1)) # 'fake' not in conds
+ assert_raises(ValueError, plot_compare_evokeds, [[1, 2], [3, 4]],
+ picks=3) # evoked must contain Evokeds
+ assert_raises(ValueError, plot_compare_evokeds, evoked, picks=3,
+ styles=dict(err=1)) # bad styles dict
+ assert_raises(ValueError, plot_compare_evokeds, evoked, picks=3,
+ gfp=True) # no single-channel GFP
+ assert_raises(TypeError, plot_compare_evokeds, evoked, picks=3,
+ ci='fake') # ci must be float or None
+ contrast["red/stim"] = red
+ contrast["blue/stim"] = blue
+ plot_compare_evokeds(contrast, picks=[0], colors=['r', 'b'],
+ ylim=dict(mag=(1, 10)))
+
# Hack to test plotting of maxfiltered data
evoked_sss = evoked.copy()
evoked_sss.info['proc_history'] = [dict(max_info=None)]
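The ``plot_compare_evokeds`` calls tested above accept a dict, a list, or a
single Evoked; a hedged sketch with hypothetical conditions:

    from mne.viz import plot_compare_evokeds

    # evoked_aud and evoked_vis are hypothetical Evoked instances
    plot_compare_evokeds(dict(aud=evoked_aud, vis=evoked_vis), picks=[2],
                         colors=dict(aud='r', vis='b'),
                         linestyles=dict(aud='-', vis='--'))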
diff --git a/mne/viz/tests/test_ica.py b/mne/viz/tests/test_ica.py
index 3149fc6..181ff1b 100644
--- a/mne/viz/tests/test_ica.py
+++ b/mne/viz/tests/test_ica.py
@@ -6,13 +6,15 @@
import os.path as op
import warnings
-from numpy.testing import assert_raises
+from numpy.testing import assert_raises, assert_equal, assert_array_equal
+from nose.tools import assert_true
-from mne import io, read_events, Epochs, read_cov
-from mne import pick_types
+from mne import read_events, Epochs, read_cov, pick_types
+from mne.io import read_raw_fif
+from mne.preprocessing import ICA, create_ecg_epochs, create_eog_epochs
from mne.utils import run_tests_if_main, requires_sklearn
+from mne.viz.ica import _create_properties_layout, plot_ica_properties
from mne.viz.utils import _fake_click
-from mne.preprocessing import ICA, create_ecg_epochs, create_eog_epochs
# Set our plotters to test mode
import matplotlib
@@ -29,31 +31,34 @@ event_id, tmin, tmax = 1, -0.1, 0.2
def _get_raw(preload=False):
- return io.read_raw_fif(raw_fname, preload=preload)
+ """Get raw data."""
+ return read_raw_fif(raw_fname, preload=preload, add_eeg_ref=False)
def _get_events():
+ """Get events."""
return read_events(event_name)
def _get_picks(raw):
+ """Get picks."""
return [0, 1, 2, 6, 7, 8, 12, 13, 14] # take only a few channels
def _get_epochs():
+ """Get epochs."""
raw = _get_raw()
events = _get_events()
picks = _get_picks(raw)
with warnings.catch_warnings(record=True): # bad proj
epochs = Epochs(raw, events[:10], event_id, tmin, tmax, picks=picks,
- baseline=(None, 0))
+ baseline=(None, 0), add_eeg_ref=False)
return epochs
@requires_sklearn
def test_plot_ica_components():
- """Test plotting of ICA solutions
- """
+ """Test plotting of ICA solutions."""
import matplotlib.pyplot as plt
raw = _get_raw()
ica = ICA(noise_cov=read_cov(cov_fname), n_components=2,
@@ -64,7 +69,29 @@ def test_plot_ica_components():
warnings.simplefilter('always', UserWarning)
with warnings.catch_warnings(record=True):
for components in [0, [0], [0, 1], [0, 1] * 2, None]:
- ica.plot_components(components, image_interp='bilinear', res=16)
+ ica.plot_components(components, image_interp='bilinear', res=16,
+ colorbar=True)
+
+ # test interactive mode (passing 'inst' arg)
+ plt.close('all')
+ ica.plot_components([0, 1], image_interp='bilinear', res=16, inst=raw)
+
+ fig = plt.gcf()
+ ax = [a for a in fig.get_children() if isinstance(a, plt.Axes)]
+ lbl = ax[1].get_label()
+ _fake_click(fig, ax[1], (0., 0.), xform='data')
+
+ c_fig = plt.gcf()
+ ax = [a for a in c_fig.get_children() if isinstance(a, plt.Axes)]
+ labels = [a.get_label() for a in ax]
+
+ for l in ['topomap', 'image', 'erp', 'spectrum', 'variance']:
+ assert_true(l in labels)
+
+ topomap_ax = ax[labels.index('topomap')]
+ title = topomap_ax.get_title()
+ assert_true(lbl == title)
+
ica.info = None
assert_raises(ValueError, ica.plot_components, 1)
assert_raises(RuntimeError, ica.plot_components, 1, ch_type='mag')
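The interactive mode exercised above in one line: passing ``inst`` makes each
component topomap clickable, opening a properties window for that component.
A sketch, assuming a fitted ``ica`` and a loaded ``raw``:

    ica.plot_components([0, 1], inst=raw)  # click a topomap to inspect it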
@@ -72,12 +99,64 @@ def test_plot_ica_components():
@requires_sklearn
+def test_plot_ica_properties():
+ """Test plotting of ICA properties."""
+ import matplotlib.pyplot as plt
+
+ raw = _get_raw(preload=True)
+ raw.add_proj([], remove_existing=True)
+ events = _get_events()
+ picks = _get_picks(raw)[:6]
+ pick_names = [raw.ch_names[k] for k in picks]
+ raw.pick_channels(pick_names)
+
+ with warnings.catch_warnings(record=True): # bad proj
+ epochs = Epochs(raw, events[:10], event_id, tmin, tmax,
+ baseline=(None, 0), preload=True)
+
+ ica = ICA(noise_cov=read_cov(cov_fname), n_components=2,
+ max_pca_components=2, n_pca_components=2)
+ with warnings.catch_warnings(record=True): # bad proj
+ ica.fit(raw)
+
+ # test _create_properties_layout
+ fig, ax = _create_properties_layout()
+ assert_equal(len(ax), 5)
+
+ topoargs = dict(topomap_args={'res': 10})
+ ica.plot_properties(raw, picks=0, **topoargs)
+ ica.plot_properties(epochs, picks=1, dB=False, plot_std=1.5, **topoargs)
+ ica.plot_properties(epochs, picks=1, image_args={'sigma': 1.5},
+ topomap_args={'res': 10, 'colorbar': True},
+ psd_args={'fmax': 65.}, plot_std=False,
+ figsize=[4.5, 4.5])
+ plt.close('all')
+
+ assert_raises(ValueError, ica.plot_properties, epochs, dB=list('abc'))
+ assert_raises(ValueError, ica.plot_properties, epochs, plot_std=[])
+ assert_raises(ValueError, ica.plot_properties, ica)
+ assert_raises(ValueError, ica.plot_properties, [0.2])
+ assert_raises(ValueError, plot_ica_properties, epochs, epochs)
+ assert_raises(ValueError, ica.plot_properties, epochs,
+ psd_args='not dict')
+
+ fig, ax = plt.subplots(2, 3)
+ ax = ax.ravel()[:-1]
+ ica.plot_properties(epochs, picks=1, axes=ax)
+ fig = ica.plot_properties(raw, picks=[0, 1], **topoargs)
+ assert_equal(len(fig), 2)
+ assert_raises(ValueError, plot_ica_properties, epochs, ica, picks=[0, 1],
+ axes=ax)
+ assert_raises(ValueError, ica.plot_properties, epochs, axes='not axes')
+ plt.close('all')
+
+
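And the direct call to the new properties plot, under the same assumptions:

    ica.plot_properties(raw, picks=0, psd_args={'fmax': 40.},
                        topomap_args={'res': 64})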
+@requires_sklearn
def test_plot_ica_sources():
- """Test plotting of ICA panel
- """
+ """Test plotting of ICA panel."""
import matplotlib.pyplot as plt
- raw = io.read_raw_fif(raw_fname,
- preload=False).crop(0, 1, copy=False).load_data()
+ raw = read_raw_fif(raw_fname, preload=False, add_eeg_ref=False)
+ raw.crop(0, 1, copy=False).load_data()
picks = _get_picks(raw)
epochs = _get_epochs()
raw.pick_channels([raw.ch_names[k] for k in picks])
@@ -85,6 +164,12 @@ def test_plot_ica_sources():
ecg=False, eog=False, exclude='bads')
ica = ICA(n_components=2, max_pca_components=3, n_pca_components=3)
ica.fit(raw, picks=ica_picks)
+ ica.exclude = [1]
+ fig = ica.plot_sources(raw)
+ fig.canvas.key_press_event('escape')
+ # Sadly close_event isn't called on Agg backend and the test always passes.
+ assert_array_equal(ica.exclude, [1])
+
raw.info['bads'] = ['MEG 0113']
assert_raises(RuntimeError, ica.plot_sources, inst=raw)
ica.plot_sources(epochs)
@@ -115,8 +200,7 @@ def test_plot_ica_sources():
@requires_sklearn
def test_plot_ica_overlay():
- """Test plotting of ICA cleaning
- """
+ """Test plotting of ICA cleaning."""
import matplotlib.pyplot as plt
raw = _get_raw(preload=True)
picks = _get_picks(raw)
@@ -140,8 +224,7 @@ def test_plot_ica_overlay():
@requires_sklearn
def test_plot_ica_scores():
- """Test plotting of ICA scores
- """
+ """Test plotting of ICA scores."""
import matplotlib.pyplot as plt
raw = _get_raw()
picks = _get_picks(raw)
diff --git a/mne/viz/tests/test_misc.py b/mne/viz/tests/test_misc.py
index 79d5899..6834924 100644
--- a/mne/viz/tests/test_misc.py
+++ b/mne/viz/tests/test_misc.py
@@ -13,9 +13,10 @@ import warnings
import numpy as np
from numpy.testing import assert_raises
-from mne import (io, read_events, read_cov, read_source_spaces, read_evokeds,
+from mne import (read_events, read_cov, read_source_spaces, read_evokeds,
read_dipole, SourceEstimate)
from mne.datasets import testing
+from mne.io import read_raw_fif
from mne.minimum_norm import read_inverse_operator
from mne.viz import (plot_bem, plot_events, plot_source_spectrogram,
plot_snr_estimate)
@@ -29,6 +30,7 @@ warnings.simplefilter('always') # enable b/c these tests throw warnings
data_path = testing.data_path(download=False)
subjects_dir = op.join(data_path, 'subjects')
+src_fname = op.join(subjects_dir, 'sample', 'bem', 'sample-oct-6-src.fif')
inv_fname = op.join(data_path, 'MEG', 'sample',
'sample_audvis_trunc-meg-eeg-oct-4-meg-inv.fif')
evoked_fname = op.join(data_path, 'MEG', 'sample', 'sample_audvis-ave.fif')
@@ -41,16 +43,17 @@ event_fname = op.join(base_dir, 'test-eve.fif')
def _get_raw():
- return io.read_raw_fif(raw_fname, preload=True)
+ """Get raw data."""
+ return read_raw_fif(raw_fname, preload=True, add_eeg_ref=False)
def _get_events():
+ """Get events."""
return read_events(event_fname)
def test_plot_cov():
- """Test plotting of covariances
- """
+ """Test plotting of covariances."""
raw = _get_raw()
cov = read_cov(cov_fname)
with warnings.catch_warnings(record=True): # bad proj
@@ -60,19 +63,22 @@ def test_plot_cov():
@testing.requires_testing_data
@requires_nibabel()
def test_plot_bem():
- """Test plotting of BEM contours
- """
+ """Test plotting of BEM contours."""
assert_raises(IOError, plot_bem, subject='bad-subject',
subjects_dir=subjects_dir)
assert_raises(ValueError, plot_bem, subject='sample',
subjects_dir=subjects_dir, orientation='bad-ori')
plot_bem(subject='sample', subjects_dir=subjects_dir,
orientation='sagittal', slices=[25, 50])
+ plot_bem(subject='sample', subjects_dir=subjects_dir,
+ orientation='coronal', slices=[25, 50],
+ brain_surfaces='white')
+ plot_bem(subject='sample', subjects_dir=subjects_dir,
+ orientation='coronal', slices=[25, 50], src=src_fname)
def test_plot_events():
- """Test plotting events
- """
+ """Test plotting events."""
event_labels = {'aud_l': 1, 'aud_r': 2, 'vis_l': 3, 'vis_r': 4}
color = {1: 'green', 2: 'yellow', 3: 'red', 4: 'c'}
raw = _get_raw()
@@ -97,8 +103,7 @@ def test_plot_events():
@testing.requires_testing_data
def test_plot_source_spectrogram():
- """Test plotting of source spectrogram
- """
+ """Test plotting of source spectrogram."""
sample_src = read_source_spaces(op.join(subjects_dir, 'sample',
'bem', 'sample-oct-6-src.fif'))
@@ -119,8 +124,7 @@ def test_plot_source_spectrogram():
@slow_test
@testing.requires_testing_data
def test_plot_snr():
- """Test plotting SNR estimate
- """
+ """Test plotting SNR estimate."""
inv = read_inverse_operator(inv_fname)
evoked = read_evokeds(evoked_fname, baseline=(None, 0))[0]
plot_snr_estimate(evoked, inv)
@@ -128,8 +132,7 @@ def test_plot_snr():
@testing.requires_testing_data
def test_plot_dipole_amplitudes():
- """Test plotting dipole amplitudes
- """
+ """Test plotting dipole amplitudes."""
dipoles = read_dipole(dip_fname)
dipoles.plot_amplitudes(show=False)
diff --git a/mne/viz/tests/test_raw.py b/mne/viz/tests/test_raw.py
index f6db8d7..7c4641e 100644
--- a/mne/viz/tests/test_raw.py
+++ b/mne/viz/tests/test_raw.py
@@ -2,12 +2,14 @@
#
# License: Simplified BSD
+import numpy as np
import os.path as op
import warnings
-from numpy.testing import assert_raises
+from numpy.testing import assert_raises, assert_equal
-from mne import io, read_events, pick_types, Annotations
+from mne import read_events, pick_types, Annotations
+from mne.io import read_raw_fif
from mne.utils import requires_version, run_tests_if_main
from mne.viz.utils import _fake_click
from mne.viz import plot_raw, plot_sensors
@@ -24,7 +26,8 @@ event_name = op.join(base_dir, 'test-eve.fif')
def _get_raw():
- raw = io.read_raw_fif(raw_fname, preload=True)
+ """Get raw data."""
+ raw = read_raw_fif(raw_fname, preload=True, add_eeg_ref=False)
# Throws a warning about a changed unit.
with warnings.catch_warnings(record=True):
raw.set_channel_types({raw.ch_names[0]: 'ias'})
@@ -34,9 +37,11 @@ def _get_raw():
def _get_events():
+ """Get events."""
return read_events(event_name)
+@requires_version('matplotlib', '1.2')
def test_plot_raw():
"""Test plotting of raw data."""
import matplotlib.pyplot as plt
@@ -94,6 +99,39 @@ def test_plot_raw():
raw.annotations = annot
fig = plot_raw(raw, events=events, event_color={-1: 'r', 998: 'b'})
plt.close('all')
+ for order in ['position', 'selection', range(len(raw.ch_names))[::-4],
+ [1, 2, 4, 6]]:
+ fig = raw.plot(order=order)
+ x = fig.get_axes()[0].lines[1].get_xdata()[10]
+ y = fig.get_axes()[0].lines[1].get_ydata()[10]
+ _fake_click(fig, data_ax, [x, y], xform='data') # mark bad
+ fig.canvas.key_press_event('down') # change selection
+ _fake_click(fig, fig.get_axes()[2], [0.5, 0.5]) # change channels
+ if order in ('position', 'selection'):
+ sel_fig = plt.figure(1)
+ topo_ax = sel_fig.axes[1]
+ _fake_click(sel_fig, topo_ax, [-0.425, 0.20223853],
+ xform='data')
+ fig.canvas.key_press_event('down')
+ fig.canvas.key_press_event('up')
+ fig.canvas.scroll_event(0.5, 0.5, -1) # scroll down
+ fig.canvas.scroll_event(0.5, 0.5, 1) # scroll up
+ _fake_click(sel_fig, topo_ax, [-0.5, 0.], xform='data')
+ _fake_click(sel_fig, topo_ax, [0.5, 0.], xform='data',
+ kind='motion')
+ _fake_click(sel_fig, topo_ax, [0.5, 0.5], xform='data',
+ kind='motion')
+ _fake_click(sel_fig, topo_ax, [-0.5, 0.5], xform='data',
+ kind='release')
+
+ plt.close('all')
+ # test if meas_date has only one element
+ raw.info['meas_date'] = np.array([raw.info['meas_date'][0]],
+ dtype=np.int32)
+ raw.annotations = Annotations([1 + raw.first_samp / raw.info['sfreq']],
+ [5], ['bad'])
+ raw.plot()
+ plt.close('all')
@requires_version('scipy', '0.10')
@@ -133,16 +171,48 @@ def test_plot_raw_psd():
# topo psd
raw.plot_psd_topo()
plt.close('all')
+ # with a flat channel
+ raw[5, :] = 0
+ assert_raises(ValueError, raw.plot_psd)
+@requires_version('matplotlib', '1.2')
def test_plot_sensors():
"""Test plotting of sensor array."""
import matplotlib.pyplot as plt
raw = _get_raw()
fig = raw.plot_sensors('3d')
_fake_click(fig, fig.gca(), (-0.08, 0.67))
- raw.plot_sensors('topomap')
+ raw.plot_sensors('topomap', ch_type='mag')
+ ax = plt.subplot(111)
+ raw.plot_sensors(ch_groups='position', axes=ax)
+ raw.plot_sensors(ch_groups='selection')
+ raw.plot_sensors(ch_groups=[[0, 1, 2], [3, 4]])
+ assert_raises(ValueError, raw.plot_sensors, ch_groups='asd')
assert_raises(TypeError, plot_sensors, raw) # needs to be info
+ assert_raises(ValueError, plot_sensors, raw.info, kind='sasaasd')
+ plt.close('all')
+ fig, sels = raw.plot_sensors('select', show_names=True)
+ ax = fig.axes[0]
+
+ # Click with no sensors
+ _fake_click(fig, ax, (0., 0.), xform='data')
+ _fake_click(fig, ax, (0, 0.), xform='data', kind='release')
+ assert_equal(len(fig.lasso.selection), 0)
+
+ # Lasso with 1 sensor
+ _fake_click(fig, ax, (-0.5, 0.5), xform='data')
+ plt.draw()
+ _fake_click(fig, ax, (0., 0.5), xform='data', kind='motion')
+ _fake_click(fig, ax, (0., 0.), xform='data', kind='motion')
+ fig.canvas.key_press_event('control')
+ _fake_click(fig, ax, (-0.5, 0.), xform='data', kind='release')
+ assert_equal(len(fig.lasso.selection), 1)
+
+ _fake_click(fig, ax, (-0.09, -0.43), xform='data') # single selection
+ assert_equal(len(fig.lasso.selection), 2)
+ _fake_click(fig, ax, (-0.09, -0.43), xform='data') # deselect
+ assert_equal(len(fig.lasso.selection), 1)
plt.close('all')
run_tests_if_main()
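The lasso interactions tested above map to this public call (a sketch;
``raw`` as before):

    fig, sels = raw.plot_sensors('select', show_names=True)
    # drag a lasso around sensors in the figure, then:
    print(fig.lasso.selection)  # names of the selected channels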
diff --git a/mne/viz/tests/test_topo.py b/mne/viz/tests/test_topo.py
index b84c2a4..8c60a38 100644
--- a/mne/viz/tests/test_topo.py
+++ b/mne/viz/tests/test_topo.py
@@ -12,9 +12,9 @@ from collections import namedtuple
import numpy as np
from numpy.testing import assert_raises
-from mne import io, read_events, Epochs
-from mne import pick_channels_evoked
+from mne import read_events, Epochs, pick_channels_evoked
from mne.channels import read_layout
+from mne.io import read_raw_fif
from mne.time_frequency.tfr import AverageTFR
from mne.utils import run_tests_if_main
@@ -39,48 +39,58 @@ layout = read_layout('Vectorview-all')
def _get_raw():
- return io.read_raw_fif(raw_fname, preload=False)
+ """Get raw data."""
+ return read_raw_fif(raw_fname, preload=False, add_eeg_ref=False)
def _get_events():
+ """Get events."""
return read_events(event_name)
def _get_picks(raw):
+ """Get picks."""
return [0, 1, 2, 6, 7, 8, 306, 340, 341, 342] # take only a few channels
def _get_epochs():
+ """Get epochs."""
raw = _get_raw()
+ raw.add_proj([], remove_existing=True)
events = _get_events()
picks = _get_picks(raw)
- with warnings.catch_warnings(record=True): # bad proj
- epochs = Epochs(raw, events[:10], event_id, tmin, tmax, picks=picks,
- baseline=(None, 0), verbose='error')
+ # bad proj warning
+ epochs = Epochs(raw, events[:10], event_id, tmin, tmax, picks=picks,
+ baseline=(None, 0), add_eeg_ref=False)
return epochs
def _get_epochs_delayed_ssp():
+ """Get epochs with delayed SSP."""
raw = _get_raw()
events = _get_events()
picks = _get_picks(raw)
reject = dict(mag=4e-12)
- epochs_delayed_ssp = Epochs(raw, events[:10], event_id, tmin, tmax,
- picks=picks, baseline=(None, 0),
- proj='delayed', reject=reject)
+ epochs_delayed_ssp = Epochs(
+ raw, events[:10], event_id, tmin, tmax, picks=picks,
+ baseline=(None, 0), proj='delayed', reject=reject, add_eeg_ref=False)
return epochs_delayed_ssp
def test_plot_topo():
- """Test plotting of ERP topography
- """
+ """Test plotting of ERP topography."""
import matplotlib.pyplot as plt
# Show topography
evoked = _get_epochs().average()
- plot_evoked_topo(evoked) # should auto-find layout
+ # should auto-find layout
+ plot_evoked_topo([evoked, evoked], merge_grads=True)
# Test jointplot
evoked.plot_joint()
- evoked.plot_joint(title='test', ts_args=dict(spatial_colors=True),
+
+ def return_inds(d): # test passing a function as the zorder arg of evoked.plot
+ return list(range(d.shape[0]))
+ ts_args = dict(spatial_colors=True, zorder=return_inds)
+ evoked.plot_joint(title='test', ts_args=ts_args,
topomap_args=dict(colorbar=True, times=[0.]))
warnings.simplefilter('always', UserWarning)
@@ -124,8 +134,7 @@ def test_plot_topo():
def test_plot_topo_image_epochs():
- """Test plotting of epochs image topography
- """
+ """Test plotting of epochs image topography."""
import matplotlib.pyplot as plt
title = 'ERF images - MNE sample data'
epochs = _get_epochs()
@@ -137,8 +146,7 @@ def test_plot_topo_image_epochs():
def test_plot_tfr_topo():
- """Test plotting of TFR data
- """
+ """Test plotting of TFR data."""
epochs = _get_epochs()
n_freqs = 3
nave = 1
diff --git a/mne/viz/tests/test_topomap.py b/mne/viz/tests/test_topomap.py
index 54c462b..d7ead26 100644
--- a/mne/viz/tests/test_topomap.py
+++ b/mne/viz/tests/test_topomap.py
@@ -14,16 +14,20 @@ from numpy.testing import assert_raises, assert_array_equal
from nose.tools import assert_true, assert_equal
-from mne import io, read_evokeds, read_proj
+from mne import read_evokeds, read_proj
+from mne.io import read_raw_fif
from mne.io.constants import FIFF
+from mne.io.pick import pick_info, channel_indices_by_type
from mne.channels import read_layout, make_eeg_layout
from mne.datasets import testing
from mne.time_frequency.tfr import AverageTFR
-from mne.utils import slow_test
+from mne.utils import slow_test, run_tests_if_main
from mne.viz import plot_evoked_topomap, plot_projs_topomap
-from mne.viz.topomap import (_check_outlines, _onselect, plot_topomap)
-from mne.viz.utils import _find_peaks
+from mne.viz.topomap import (_check_outlines, _onselect, plot_topomap,
+ plot_psds_topomap)
+from mne.viz.utils import _find_peaks, _fake_click
+
# Set our plotters to test mode
import matplotlib
@@ -45,14 +49,14 @@ layout = read_layout('Vectorview-all')
def _get_raw():
- return io.read_raw_fif(raw_fname, preload=False)
+ """Get raw data."""
+ return read_raw_fif(raw_fname, preload=False, add_eeg_ref=False)
@slow_test
@testing.requires_testing_data
def test_plot_topomap():
- """Test topomap plotting
- """
+ """Test topomap plotting."""
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
# evoked
@@ -151,7 +155,7 @@ def test_plot_topomap():
warnings.simplefilter('always')
projs = read_proj(ecg_fname)
projs = [pp for pp in projs if pp['desc'].lower().find('eeg') < 0]
- plot_projs_topomap(projs, res=res)
+ plot_projs_topomap(projs, res=res, colorbar=True)
plt.close('all')
ax = plt.subplot(111)
plot_projs_topomap([projs[0]], res=res, axes=ax) # test axes param
@@ -196,6 +200,27 @@ def test_plot_topomap():
evoked.plot_topomap(times, ch_type='eeg', outlines=outlines)
plt.close('all')
+ # Test interactive cmap
+ fig = plot_evoked_topomap(evoked, times=[0., 0.1], ch_type='eeg',
+ cmap=('Reds', True), title='title')
+ fig.canvas.key_press_event('up')
+ fig.canvas.key_press_event(' ')
+ fig.canvas.key_press_event('down')
+ cbar = fig.get_axes()[0].CB # Fake dragging with mouse.
+ ax = cbar.cbar.ax
+ _fake_click(fig, ax, (0.1, 0.1))
+ _fake_click(fig, ax, (0.1, 0.2), kind='motion')
+ _fake_click(fig, ax, (0.1, 0.3), kind='release')
+
+ _fake_click(fig, ax, (0.1, 0.1), button=3)
+ _fake_click(fig, ax, (0.1, 0.2), button=3, kind='motion')
+ _fake_click(fig, ax, (0.1, 0.3), kind='release')
+
+ fig.canvas.scroll_event(0.5, 0.5, -0.5) # scroll down
+ fig.canvas.scroll_event(0.5, 0.5, 0.5) # scroll up
+
+ plt.close('all')
+
# Pass custom outlines with patch callable
def patch():
return Circle((0.5, 0.4687), radius=.46,
@@ -248,8 +273,7 @@ def test_plot_topomap():
def test_plot_tfr_topomap():
- """Test plotting of TFR data
- """
+ """Test plotting of TFR data."""
import matplotlib as mpl
import matplotlib.pyplot as plt
raw = _get_raw()
@@ -264,14 +288,29 @@ def test_plot_tfr_topomap():
eclick = mpl.backend_bases.MouseEvent('button_press_event',
plt.gcf().canvas, 0, 0, 1)
- eclick.xdata = 0.1
- eclick.ydata = 0.1
+ eclick.xdata = eclick.ydata = 0.1
eclick.inaxes = plt.gca()
erelease = mpl.backend_bases.MouseEvent('button_release_event',
plt.gcf().canvas, 0.9, 0.9, 1)
erelease.xdata = 0.3
erelease.ydata = 0.2
pos = [[0.11, 0.11], [0.25, 0.5], [0.0, 0.2], [0.2, 0.39]]
+ _onselect(eclick, erelease, tfr, pos, 'grad', 1, 3, 1, 3, 'RdBu_r', list())
_onselect(eclick, erelease, tfr, pos, 'mag', 1, 3, 1, 3, 'RdBu_r', list())
+ eclick.xdata = eclick.ydata = 0.
+ erelease.xdata = erelease.ydata = 0.9
tfr._onselect(eclick, erelease, None, 'mean', None)
plt.close('all')
+
+ # test plot_psds_topomap
+ info = raw.info.copy()
+ chan_inds = channel_indices_by_type(info)
+ info = pick_info(info, chan_inds['grad'][:4])
+
+ fig, axes = plt.subplots()
+ freqs = np.arange(3., 9.5)
+ bands = [(4, 8, 'Theta')]
+ psd = np.random.rand(len(info['ch_names']), freqs.shape[0])
+ plot_psds_topomap(psd, freqs, info, bands=bands, axes=[axes])
+
+run_tests_if_main()
diff --git a/mne/viz/tests/test_utils.py b/mne/viz/tests/test_utils.py
index 5f8932c..336661e 100644
--- a/mne/viz/tests/test_utils.py
+++ b/mne/viz/tests/test_utils.py
@@ -8,7 +8,8 @@ import numpy as np
from nose.tools import assert_true, assert_raises
from numpy.testing import assert_allclose
-from mne.viz.utils import compare_fiff, _fake_click, _compute_scalings
+from mne.viz.utils import (compare_fiff, _fake_click, _compute_scalings,
+ _validate_if_list_of_axes)
from mne.viz import ClickableImage, add_background_image, mne_analyze_colormap
from mne.utils import run_tests_if_main
from mne.io import read_raw_fif
@@ -28,8 +29,7 @@ ev_fname = op.join(base_dir, 'test_raw-eve.fif')
def test_mne_analyze_colormap():
- """Test mne_analyze_colormap
- """
+ """Test mne_analyze_colormap."""
assert_raises(ValueError, mne_analyze_colormap, [0])
assert_raises(ValueError, mne_analyze_colormap, [-1, 1, 2])
assert_raises(ValueError, mne_analyze_colormap, [0, 2, 1])
@@ -88,12 +88,15 @@ def test_add_background_image():
for ax in axs:
assert_true(ax.get_aspect() == 'auto')
+ # Make sure passing None as image returns None
+ assert_true(add_background_image(f, None) is None)
+
def test_auto_scale():
"""Test auto-scaling of channels for quick plotting."""
- raw = read_raw_fif(raw_fname, preload=False)
+ raw = read_raw_fif(raw_fname, preload=False, add_eeg_ref=False)
ev = read_events(ev_fname)
- epochs = Epochs(raw, ev)
+ epochs = Epochs(raw, ev, add_eeg_ref=False)
rand_data = np.random.randn(10, 100)
for inst in [raw, epochs]:
@@ -112,9 +115,29 @@ def test_auto_scale():
assert_raises(ValueError, _compute_scalings, scalings_def, rand_data)
epochs = epochs[0].load_data()
- epochs.pick_types(eeg=True, meg=False, copy=False)
+ epochs.pick_types(eeg=True, meg=False)
assert_raises(ValueError, _compute_scalings,
dict(grad='auto'), epochs)
+def test_validate_if_list_of_axes():
+ """Test validation of axes."""
+ import matplotlib.pyplot as plt
+ fig, ax = plt.subplots(2, 2)
+ assert_raises(ValueError, _validate_if_list_of_axes, ax)
+ ax_flat = ax.ravel()
+ ax = ax.ravel().tolist()
+ _validate_if_list_of_axes(ax_flat)
+ _validate_if_list_of_axes(ax_flat, 4)
+ assert_raises(ValueError, _validate_if_list_of_axes, ax_flat, 5)
+ assert_raises(ValueError, _validate_if_list_of_axes, ax, 3)
+ assert_raises(ValueError, _validate_if_list_of_axes, 'error')
+ assert_raises(ValueError, _validate_if_list_of_axes, ['error'] * 2)
+ assert_raises(ValueError, _validate_if_list_of_axes, ax[0])
+ assert_raises(ValueError, _validate_if_list_of_axes, ax, 3)
+ ax_flat[2] = 23
+ assert_raises(ValueError, _validate_if_list_of_axes, ax_flat)
+ _validate_if_list_of_axes(ax, 4)
+
+
run_tests_if_main()
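In brief, what the validator accepts (a sketch of the private helper's
contract, following the test above):

    import matplotlib.pyplot as plt
    from mne.viz.utils import _validate_if_list_of_axes

    fig, ax = plt.subplots(2, 2)
    _validate_if_list_of_axes(ax.ravel(), 4)  # flat array of 4 Axes: ok
    # the 2-D array itself, or a mismatched length, raises ValueError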
diff --git a/mne/viz/topo.py b/mne/viz/topo.py
index 3bcad97..5da18bf 100644
--- a/mne/viz/topo.py
+++ b/mne/viz/topo.py
@@ -16,12 +16,12 @@ import numpy as np
from ..io.constants import Bunch
from ..io.pick import channel_type, pick_types
-from ..fixes import normalize_colors
from ..utils import _clean_names, warn
from ..channels.layout import _merge_grad_data, _pair_grad_sensors, find_layout
from ..defaults import _handle_default
from .utils import (_check_delayed_ssp, COLORS, _draw_proj_checkbox,
- add_background_image, plt_show, _setup_vmin_vmax)
+ add_background_image, plt_show, _setup_vmin_vmax,
+ DraggableColorbar)
def iter_topography(info, layout=None, on_pick=None, fig=None,
@@ -162,8 +162,7 @@ def _plot_topo(info, times, show_func, click_func=None, layout=None,
fig = plt.figure()
if colorbar:
- norm = normalize_colors(vmin=vmin, vmax=vmax)
- sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+ sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin, vmax))
sm.set_array(np.linspace(vmin, vmax))
ax = plt.axes([0.015, 0.025, 1.05, .8], axisbg=fig_facecolor)
cb = fig.colorbar(sm, ax=ax)
@@ -235,6 +234,8 @@ def _plot_topo_onpick(event, show_func):
def _compute_scalings(bn, xlim, ylim):
"""Compute scale factors for a unified plot"""
+ if isinstance(ylim[0], (tuple, list, np.ndarray)):
+ ylim = (ylim[0][0], ylim[1][0])
pos = bn.pos
bn.x_s = pos[2] / (xlim[1] - xlim[0])
bn.x_t = pos[0] - bn.x_s * xlim[0]
@@ -249,13 +250,15 @@ def _check_vlim(vlim):
def _imshow_tfr(ax, ch_idx, tmin, tmax, vmin, vmax, onselect, ylim=None,
tfr=None, freq=None, vline=None, x_label=None, y_label=None,
- colorbar=False, picker=True, cmap='RdBu_r', title=None,
+ colorbar=False, picker=True, cmap=('RdBu_r', True), title=None,
hline=None):
""" Aux function to show time-freq map on topo """
import matplotlib.pyplot as plt
from matplotlib.widgets import RectangleSelector
extent = (tmin, tmax, freq[0], freq[-1])
+ cmap, interactive_cmap = cmap
+
img = ax.imshow(tfr[ch_idx], extent=extent, aspect="auto", origin="lower",
vmin=vmin, vmax=vmax, picker=picker, cmap=cmap)
if isinstance(ax, plt.Axes):
@@ -269,7 +272,12 @@ def _imshow_tfr(ax, ch_idx, tmin, tmax, vmin, vmax, onselect, ylim=None,
if y_label is not None:
plt.ylabel(y_label)
if colorbar:
- plt.colorbar(mappable=img)
+ if isinstance(colorbar, DraggableColorbar):
+ cbar = colorbar.cbar # this happens with multiaxes case
+ else:
+ cbar = plt.colorbar(mappable=img)
+ if interactive_cmap:
+ ax.CB = DraggableColorbar(cbar, img)
if title:
plt.title(title)
if not isinstance(ax, plt.Axes):
@@ -464,9 +472,9 @@ def _plot_evoked_topo(evoked, layout=None, layout_scale=0.945, color=None,
The values at which to show a horizontal line.
fig_facecolor : str | obj
The figure face color. Defaults to black.
- fig_background : None | numpy ndarray
- A background image for the figure. This must work with a call to
- plt.imshow. Defaults to None.
+ fig_background : None | array
+ A background image for the figure. This must be a valid input to
+ `matplotlib.pyplot.imshow`. Defaults to None.
axis_facecolor : str | obj
The face color to be used for each sensor plot. Defaults to black.
font_color : str | obj
@@ -582,13 +590,13 @@ def _plot_evoked_topo(evoked, layout=None, layout_scale=0.945, color=None,
else:
ylim_ = zip(*[np.array(yl) for yl in ylim_])
else:
- raise ValueError('ylim must be None ore a dict')
+ raise TypeError('ylim must be None or a dict. Got %s.' % type(ylim))
data = [e.data for e in evoked]
- show_func = partial(_plot_timeseries_unified, data=data,
- color=color, times=times, vline=vline, hline=hline)
- click_func = partial(_plot_timeseries, data=data,
- color=color, times=times, vline=vline, hline=hline)
+ show_func = partial(_plot_timeseries_unified, data=data, color=color,
+ times=times, vline=vline, hline=hline)
+ click_func = partial(_plot_timeseries, data=data, color=color, times=times,
+ vline=vline, hline=hline)
fig = _plot_topo(info=info, times=times, show_func=show_func,
click_func=click_func, layout=layout,
@@ -598,8 +606,7 @@ def _plot_evoked_topo(evoked, layout=None, layout_scale=0.945, color=None,
axis_facecolor=axis_facecolor, title=title,
x_label='Time (s)', y_label=y_label, unified=True)
- if fig_background is not None:
- add_background_image(fig, fig_background)
+ add_background_image(fig, fig_background)
if proj == 'interactive':
for e in evoked:
@@ -634,8 +641,8 @@ def _plot_update_evoked_topo_proj(params, bools):
def plot_topo_image_epochs(epochs, layout=None, sigma=0., vmin=None,
vmax=None, colorbar=True, order=None, cmap='RdBu_r',
layout_scale=.95, title=None, scalings=None,
- border='none', fig_facecolor='k', font_color='w',
- show=True):
+ border='none', fig_facecolor='k',
+ fig_background=None, font_color='w', show=True):
"""Plot Event Related Potential / Fields image on topographies
Parameters
@@ -675,6 +682,9 @@ def plot_topo_image_epochs(epochs, layout=None, sigma=0., vmin=None,
matplotlib borders style to be used for each sensor plot.
fig_facecolor : str | obj
The figure face color. Defaults to black.
+ fig_background : None | array
+ A background image for the figure. This must be a valid input to
+ `matplotlib.pyplot.imshow`. Defaults to None.
font_color : str | obj
The color of tick labels in the colorbar. Defaults to white.
show : bool
@@ -711,5 +721,6 @@ def plot_topo_image_epochs(epochs, layout=None, sigma=0., vmin=None,
fig_facecolor=fig_facecolor, font_color=font_color,
border=border, x_label='Time (s)', y_label='Epoch',
unified=True, img=True)
+ add_background_image(fig, fig_background)
plt_show(show)
return fig
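# Usage sketch for the new ``fig_background`` parameter of
# plot_topo_image_epochs above -- a minimal, hedged example that assumes the
# MNE sample dataset is available; any array accepted by plt.imshow works:
import numpy as np
import mne
from mne.datasets import sample
raw = mne.io.read_raw_fif(sample.data_path() +
                          '/MEG/sample/sample_audvis_filt-0-40_raw.fif')
events = mne.find_events(raw, stim_channel='STI 014')
epochs = mne.Epochs(raw, events, event_id=1, tmin=-0.2, tmax=0.5)
bg = np.random.RandomState(0).rand(64, 64)  # stand-in background image
fig = mne.viz.plot_topo_image_epochs(epochs, fig_background=bg)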
diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py
index 06da62d..ae5225a 100644
--- a/mne/viz/topomap.py
+++ b/mne/viz/topomap.py
@@ -23,7 +23,8 @@ from ..io.pick import (pick_types, _picks_by_type, channel_type, pick_info,
from ..utils import _clean_names, _time_mask, verbose, logger, warn
from .utils import (tight_layout, _setup_vmin_vmax, _prepare_trellis,
_check_delayed_ssp, _draw_proj_checkbox, figure_nobar,
- plt_show, _process_times)
+ plt_show, _process_times, DraggableColorbar,
+ _validate_if_list_of_axes)
from ..time_frequency import psd_multitaper
from ..defaults import _handle_default
from ..channels.layout import _find_topomap_coords
@@ -135,9 +136,16 @@ def plot_projs_topomap(projs, layout=None, cmap=None, sensors=True,
Layout instance specifying sensor positions (does not need to be
specified for Neuromag data). Or a list of Layout if projections
are from different sensor types.
- cmap : matplotlib colormap | None
- Colormap to use. If None, 'Reds' is used for all positive data,
- otherwise defaults to 'RdBu_r'.
+ cmap : matplotlib colormap | (colormap, bool) | 'interactive' | None
+ Colormap to use. If tuple, the first value indicates the colormap to
+ use and the second value is a boolean defining interactivity. In
+ interactive mode (only works if ``colorbar=True``) the colors are
+ adjustable by clicking and dragging the colorbar with left and right
+ mouse button. Left mouse button moves the scale up and down and right
+ mouse button adjusts the range. Hitting space bar resets the range. Up
+ and down arrows can be used to change the colormap. If None (default),
+ 'Reds' is used for all positive data, otherwise defaults to 'RdBu_r'.
+ If 'interactive', translates to (None, True).
sensors : bool | str
Add markers for sensor locations to the plot. Accepts matplotlib plot
format string (e.g., 'r+' for red plusses). If True, a circle will be
@@ -182,7 +190,7 @@ def plot_projs_topomap(projs, layout=None, cmap=None, sensors=True,
.. versionadded:: 0.9.0
"""
import matplotlib.pyplot as plt
-
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
if layout is None:
from ..channels import read_layout
layout = read_layout('Vectorview-all')
@@ -194,6 +202,10 @@ def plot_projs_topomap(projs, layout=None, cmap=None, sensors=True,
nrows = math.floor(math.sqrt(n_projs))
ncols = math.ceil(n_projs / nrows)
+ if cmap == 'interactive':
+ cmap = (None, True)
+ elif not isinstance(cmap, tuple):
+ cmap = (cmap, True)
if axes is None:
plt.figure()
axes = list()
@@ -232,12 +244,16 @@ def plot_projs_topomap(projs, layout=None, cmap=None, sensors=True,
break
if len(idx):
- plot_topomap(data, pos[:, :2], vmax=None, cmap=cmap,
- sensors=sensors, res=res, axes=axes[proj_idx],
- outlines=outlines, contours=contours,
- image_interp=image_interp, show=False)
+ im = plot_topomap(data, pos[:, :2], vmax=None, cmap=cmap[0],
+ sensors=sensors, res=res, axes=axes[proj_idx],
+ outlines=outlines, contours=contours,
+ image_interp=image_interp, show=False)[0]
if colorbar:
- plt.colorbar()
+ divider = make_axes_locatable(axes[proj_idx])
+ cax = divider.append_axes("right", size="5%", pad=0.05)
+ cbar = plt.colorbar(im, cax=cax, cmap=cmap[0])
+ if cmap[1]:
+ axes[proj_idx].CB = DraggableColorbar(cbar, im)
else:
raise RuntimeError('Cannot find a proper layout for projection %s'
% proj['desc'])
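# Usage sketch for the interactive colormap support added above, assuming
# the MNE sample dataset (the Vectorview layout is the default):
import mne
from mne.datasets import sample
raw = mne.io.read_raw_fif(sample.data_path() +
                          '/MEG/sample/sample_audvis_raw.fif')
# 'interactive' translates to (None, True): drag/scroll the colorbar to
# adjust the scale, press space to reset, up/down to switch colormaps.
fig = mne.viz.plot_projs_topomap(raw.info['projs'], colorbar=True,
                                 cmap='interactive')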
@@ -381,7 +397,7 @@ def plot_topomap(data, pos, vmin=None, vmax=None, cmap=None, sensors=True,
res=64, axes=None, names=None, show_names=False, mask=None,
mask_params=None, outlines='head', image_mask=None,
contours=6, image_interp='bilinear', show=True,
- head_pos=None, onselect=None, axis=None):
+ head_pos=None, onselect=None):
"""Plot a topographic map as image
Parameters
@@ -462,8 +478,6 @@ def plot_topomap(data, pos, vmin=None, vmax=None, cmap=None, sensors=True,
Handle for a function that is called when the user selects a set of
channels by rectangle selection (matplotlib ``RectangleSelector``). If
None interactive selection is disabled. Defaults to None.
- axis : instance of Axes | None
- Deprecated. Will be removed in 0.13. Use ``axes`` instead.
Returns
-------
@@ -540,10 +554,6 @@ def plot_topomap(data, pos, vmin=None, vmax=None, cmap=None, sensors=True,
pos, outlines = _check_outlines(pos, outlines, head_pos)
- if axis is not None:
- axes = axis
- warn('axis parameter is deprecated and will be removed in 0.13. '
- 'Use axes instead.', DeprecationWarning)
ax = axes if axes else plt.gca()
pos_x, pos_y = _prepare_topomap(pos, ax)
if outlines is None:
@@ -709,11 +719,62 @@ def _inside_contour(pos, contour):
return check_mask
+def _plot_ica_topomap(ica, idx=0, ch_type=None, res=64, layout=None,
+ vmin=None, vmax=None, cmap='RdBu_r', colorbar=False,
+ title=None, show=True, outlines='head', contours=6,
+ image_interp='bilinear', head_pos=None, axes=None):
+ """plot single ica map to axes"""
+ import matplotlib as mpl
+ from ..channels import _get_ch_type
+ from ..preprocessing.ica import _get_ica_map
+
+ if ica.info is None:
+ raise RuntimeError('The ICA\'s measurement info is missing. Please '
+ 'fit the ICA or add the corresponding info object.')
+ if not isinstance(axes, mpl.axes.Axes):
+ raise ValueError('axes must be an instance of matplotlib Axes, '
+ 'got %s instead.' % type(axes))
+ ch_type = _get_ch_type(ica, ch_type)
+
+ data = _get_ica_map(ica, components=idx)
+ data_picks, pos, merge_grads, names, _ = _prepare_topo_plot(
+ ica, ch_type, layout)
+ pos, outlines = _check_outlines(pos, outlines, head_pos)
+ if outlines not in (None, 'head'):
+ image_mask, pos = _make_image_mask(outlines, pos, res)
+ else:
+ image_mask = None
+
+ data = np.atleast_2d(data)
+ data = data[:, data_picks]
+
+ if merge_grads:
+ from ..channels.layout import _merge_grad_data
+ data = _merge_grad_data(data)
+ axes.set_title('IC #%03d' % idx, fontsize=12)
+ vmin_, vmax_ = _setup_vmin_vmax(data, vmin, vmax)
+ im = plot_topomap(data.ravel(), pos, vmin=vmin_, vmax=vmax_,
+ res=res, axes=axes, cmap=cmap, outlines=outlines,
+ image_mask=image_mask, contours=contours,
+ image_interp=image_interp, show=show)[0]
+ if colorbar:
+ import matplotlib.pyplot as plt
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
+ divider = make_axes_locatable(axes)
+ cax = divider.append_axes("right", size="5%", pad=0.05)
+ cbar = plt.colorbar(im, cax=cax, format='%3.2f', cmap=cmap)
+ cbar.ax.tick_params(labelsize=12)
+ cbar.set_ticks((vmin_, vmax_))
+ cbar.ax.set_title('AU', fontsize=10)
+ _hide_frame(axes)
+
+
def plot_ica_components(ica, picks=None, ch_type=None, res=64,
layout=None, vmin=None, vmax=None, cmap='RdBu_r',
sensors=True, colorbar=False, title=None,
show=True, outlines='head', contours=6,
- image_interp='bilinear', head_pos=None):
+ image_interp='bilinear', head_pos=None,
+ inst=None):
"""Project unmixing matrix on interpolated sensor topogrpahy.
Parameters
@@ -741,8 +802,20 @@ def plot_ica_components(ica, picks=None, ch_type=None, res=64,
The value specifying the upper bound of the color range.
If None, the maximum absolute value is used. If callable, the output
equals vmax(data). Defaults to None.
- cmap : matplotlib colormap
- Colormap.
+ cmap : matplotlib colormap | (colormap, bool) | 'interactive' | None
+ Colormap to use. If tuple, the first value indicates the colormap to
+ use and the second value is a boolean defining interactivity. In
+ interactive mode the colors are adjustable by clicking and dragging the
+ colorbar with left and right mouse button. Left mouse button moves the
+ scale up and down and right mouse button adjusts the range. Hitting
+ space bar resets the range. Up and down arrows can be used to change
+ the colormap. If None, 'Reds' is used for all positive data,
+ otherwise defaults to 'RdBu_r'. If 'interactive', translates to
+ (None, True). Defaults to 'RdBu_r'.
+
+ .. warning:: Interactive mode works smoothly only for a small number
+ of topomaps.
+
sensors : bool | str
Add markers for sensor locations to the plot. Accepts matplotlib
plot format string (e.g., 'r+' for red plusses). If True, a circle
@@ -774,12 +847,19 @@ def plot_ica_components(ica, picks=None, ch_type=None, res=64,
the head circle. If dict, can have entries 'center' (tuple) and
'scale' (tuple) for what the center and scale of the head should be
relative to the electrode locations.
+ inst : Raw | Epochs | None
+ To be able to see component properties after clicking on component
+ topomap you need to pass relevant data - instances of Raw or Epochs
+ (for example the data that ICA was trained on). This takes effect
+ only when running matplotlib in interactive mode.
Returns
-------
fig : instance of matplotlib.pyplot.Figure or list
The figure object(s).
"""
+ from ..io import _BaseRaw
+ from ..epochs import _BaseEpochs
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid import make_axes_locatable
from ..channels import _get_ch_type
@@ -791,19 +871,24 @@ def plot_ica_components(ica, picks=None, ch_type=None, res=64,
figs = []
for k in range(0, n_components, p):
picks = range(k, min(k + p, n_components))
- fig = plot_ica_components(ica, picks=picks,
- ch_type=ch_type, res=res, layout=layout,
- vmax=vmax, cmap=cmap, sensors=sensors,
+ fig = plot_ica_components(ica, picks=picks, ch_type=ch_type,
+ res=res, layout=layout, vmax=vmax,
+ cmap=cmap, sensors=sensors,
colorbar=colorbar, title=title,
show=show, outlines=outlines,
contours=contours,
- image_interp=image_interp)
+ image_interp=image_interp,
+ head_pos=head_pos, inst=inst)
figs.append(fig)
return figs
elif np.isscalar(picks):
picks = [picks]
ch_type = _get_ch_type(ica, ch_type)
+ if cmap == 'interactive':
+ cmap = ('RdBu_r', True)
+ elif not isinstance(cmap, tuple):
+ cmap = (cmap, False if len(picks) > 2 else True)
data = np.dot(ica.mixing_matrix_[:, picks].T,
ica.pca_components_[:ica.n_components_])
@@ -835,9 +920,10 @@ def plot_ica_components(ica, picks=None, ch_type=None, res=64,
data_ = _merge_grad_data(data_) if merge_grads else data_
vmin_, vmax_ = _setup_vmin_vmax(data_, vmin, vmax)
im = plot_topomap(data_.flatten(), pos, vmin=vmin_, vmax=vmax_,
- res=res, axes=ax, cmap=cmap, outlines=outlines,
+ res=res, axes=ax, cmap=cmap[0], outlines=outlines,
image_mask=image_mask, contours=contours,
image_interp=image_interp, show=False)[0]
+ im.axes.set_label('IC #%03d' % ii)
if colorbar:
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
@@ -845,10 +931,21 @@ def plot_ica_components(ica, picks=None, ch_type=None, res=64,
cbar.ax.tick_params(labelsize=12)
cbar.set_ticks((vmin_, vmax_))
cbar.ax.set_title('AU', fontsize=10)
+ if cmap[1]:
+ ax.CB = DraggableColorbar(cbar, im)
_hide_frame(ax)
tight_layout(fig=fig)
fig.subplots_adjust(top=0.95)
fig.canvas.draw()
+ if isinstance(inst, (_BaseRaw, _BaseEpochs)):
+ def onclick(event, ica=ica, inst=inst):
+ # check which component to plot
+ label = event.inaxes.get_label()
+ if 'IC #' in label:
+ ic = int(label[4:])
+ ica.plot_properties(inst, picks=ic, show=True)
+ fig.canvas.mpl_connect('button_press_event', onclick)
+
plt_show(show)
return fig
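# Usage sketch for the new ``inst`` parameter: with a Raw or Epochs instance
# supplied, clicking a component topomap opens plot_properties for that
# component. Assumes the MNE sample dataset:
import mne
from mne.datasets import sample
from mne.preprocessing import ICA
raw = mne.io.read_raw_fif(sample.data_path() +
                          '/MEG/sample/sample_audvis_raw.fif', preload=True)
ica = ICA(n_components=5, random_state=42)
ica.fit(raw, picks=mne.pick_types(raw.info, meg=True))
fig = mne.viz.plot_ica_components(ica, inst=raw, colorbar=True,
                                  cmap='interactive')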
@@ -910,10 +1007,16 @@ def plot_tfr_topomap(tfr, tmin=None, tmax=None, fmin=None, fmax=None,
The value specifying the upper bound of the color range. If None, the
maximum value is used. If callable, the output equals vmax(data).
Defaults to None.
- cmap : matplotlib colormap | None
- Colormap. If None and the plotted data is all positive, defaults to
- 'Reds'. If None and data contains also negative values, defaults to
- 'RdBu_r'. Defaults to None.
+ cmap : matplotlib colormap | (colormap, bool) | 'interactive' | None
+ Colormap to use. If tuple, the first value indicates the colormap to
+ use and the second value is a boolean defining interactivity. In
+ interactive mode the colors are adjustable by clicking and dragging the
+ colorbar with left and right mouse button. Left mouse button moves the
+ scale up and down and right mouse button adjusts the range. Hitting
+ space bar resets the range. Up and down arrows can be used to change
+ the colormap. If None (default), 'Reds' is used for all positive data,
+ otherwise defaults to 'RdBu_r'. If 'interactive', translates to
+ (None, True).
sensors : bool | str
Add markers for sensor locations to the plot. Accepts matplotlib
plot format string (e.g., 'r+' for red plusses). If True, a circle will
@@ -925,7 +1028,8 @@ def plot_tfr_topomap(tfr, tmin=None, tmax=None, fmin=None, fmax=None,
res : int
The resolution of the topomap image (n pixels along each side).
size : float
- Side length per topomap in inches.
+ Side length per topomap in inches (only applies when plotting multiple
+ topomaps at a time).
cbar_fmt : str
String format for colorbar values.
show_names : bool | callable
@@ -1000,8 +1104,10 @@ def plot_tfr_topomap(tfr, tmin=None, tmax=None, fmin=None, fmax=None,
norm = False if np.min(data) < 0 else True
vmin, vmax = _setup_vmin_vmax(data, vmin, vmax, norm)
- if cmap is None:
- cmap = 'Reds' if norm else 'RdBu_r'
+ if cmap is None or cmap == 'interactive':
+ cmap = ('Reds', True) if norm else ('RdBu_r', True)
+ elif not isinstance(cmap, tuple):
+ cmap = (cmap, True)
if axes is None:
fig = plt.figure()
@@ -1017,21 +1123,23 @@ def plot_tfr_topomap(tfr, tmin=None, tmax=None, fmin=None, fmax=None,
fig_wrapper = list()
selection_callback = partial(_onselect, tfr=tfr, pos=pos, ch_type=ch_type,
itmin=itmin, itmax=itmax, ifmin=ifmin,
- ifmax=ifmax, cmap=cmap, fig=fig_wrapper,
+ ifmax=ifmax, cmap=cmap[0], fig=fig_wrapper,
layout=layout)
im, _ = plot_topomap(data[:, 0], pos, vmin=vmin, vmax=vmax,
- axes=ax, cmap=cmap, image_interp='bilinear',
+ axes=ax, cmap=cmap[0], image_interp='bilinear',
contours=False, names=names, show_names=show_names,
show=False, onselect=selection_callback)
if colorbar:
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
- cbar = plt.colorbar(im, cax=cax, format=cbar_fmt, cmap=cmap)
+ cbar = plt.colorbar(im, cax=cax, format=cbar_fmt, cmap=cmap[0])
cbar.set_ticks((vmin, vmax))
cbar.ax.tick_params(labelsize=12)
cbar.ax.set_title('AU')
+ if cmap[1]:
+ ax.CB = DraggableColorbar(cbar, im)
plt_show(show)
return fig
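# Usage sketch for plot_tfr_topomap with the (colormap, bool) cmap spec.
# ``tfr`` is assumed to be an existing AverageTFR instance, e.g. the output
# of mne.time_frequency.tfr_morlet on epochs:
fig = tfr.plot_topomap(tmin=0.05, tmax=0.15, fmin=8., fmax=12.,
                       cmap=('RdBu_r', True), colorbar=True)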
@@ -1075,9 +1183,20 @@ def plot_evoked_topomap(evoked, times="auto", ch_type=None, layout=None,
The value specifying the upper bound of the color range.
If None, the maximum absolute value is used. If callable, the output
equals vmax(data). Defaults to None.
- cmap : matplotlib colormap | None
- Colormap to use. If None, 'Reds' is used for all positive data,
- otherwise defaults to 'RdBu_r'.
+ cmap : matplotlib colormap | (colormap, bool) | 'interactive' | None
+ Colormap to use. If tuple, the first value indicates the colormap to
+ use and the second value is a boolean defining interactivity. In
+ interactive mode the colors are adjustable by clicking and dragging the
+ colorbar with left and right mouse button. Left mouse button moves the
+ scale up and down and right mouse button adjusts the range. Hitting
+ space bar resets the range. Up and down arrows can be used to change
+ the colormap. If None (default), 'Reds' is used for all positive data,
+ otherwise defaults to 'RdBu_r'. If 'interactive', translates to
+ (None, True).
+
+ .. warning:: Interactive mode works smoothly only for a small number
+ of topomaps.
+
sensors : bool | str
Add markers for sensor locations to the plot. Accepts matplotlib plot
format string (e.g., 'r+' for red plusses). If True, a circle will be
@@ -1269,10 +1388,18 @@ def plot_evoked_topomap(evoked, times="auto", ch_type=None, layout=None,
else:
image_mask = None
+ vlims = [_setup_vmin_vmax(data[:, i], vmin, vmax, norm=merge_grads)
+ for i in range(len(times))]
+ vmin = np.min(vlims)
+ vmax = np.max(vlims)
+ if cmap == 'interactive':
+ cmap = (None, True)
+ elif not isinstance(cmap, tuple):
+ cmap = (cmap, False if len(times) > 2 else True)
for idx, time in enumerate(times):
tp, cn = plot_topomap(data[:, idx], pos, vmin=vmin, vmax=vmax,
sensors=sensors, res=res, names=names,
- show_names=show_names, cmap=cmap,
+ show_names=show_names, cmap=cmap[0],
mask=mask_[:, idx] if mask is not None else None,
mask_params=mask_params, axes=axes[idx],
outlines=outlines, image_mask=image_mask,
@@ -1304,6 +1431,9 @@ def plot_evoked_topomap(evoked, times="auto", ch_type=None, layout=None,
cax.set_title(unit)
cbar = fig.colorbar(images[-1], ax=cax, cax=cax, format=cbar_fmt)
cbar.set_ticks([cbar.vmin, 0, cbar.vmax])
+ if cmap[1]:
+ for im in images:
+ im.axes.CB = DraggableColorbar(cbar, im)
if proj == 'interactive':
_check_delayed_ssp(evoked)
@@ -1318,8 +1448,8 @@ def plot_evoked_topomap(evoked, times="auto", ch_type=None, layout=None,
return fig
-def _plot_topomap_multi_cbar(data, pos, ax, title=None, unit=None,
- vmin=None, vmax=None, cmap='RdBu_r',
+def _plot_topomap_multi_cbar(data, pos, ax, title=None, unit=None, vmin=None,
+ vmax=None, cmap=None, outlines='head',
colorbar=False, cbar_fmt='%3.3f'):
"""Aux Function"""
import matplotlib.pyplot as plt
@@ -1329,11 +1459,15 @@ def _plot_topomap_multi_cbar(data, pos, ax, title=None, unit=None,
vmin = np.min(data) if vmin is None else vmin
vmax = np.max(data) if vmax is None else vmax
+ if cmap == 'interactive':
+ cmap = (None, True)
+ elif not isinstance(cmap, tuple):
+ cmap = (cmap, True)
if title is not None:
ax.set_title(title, fontsize=10)
im, _ = plot_topomap(data, pos, vmin=vmin, vmax=vmax, axes=ax,
- cmap=cmap, image_interp='bilinear', contours=False,
- show=False)
+ cmap=cmap[0], image_interp='bilinear', contours=False,
+ outlines=outlines, show=False)
if colorbar is True:
divider = make_axes_locatable(ax)
@@ -1343,6 +1477,8 @@ def _plot_topomap_multi_cbar(data, pos, ax, title=None, unit=None,
if unit is not None:
cbar.ax.set_title(unit, fontsize=8)
cbar.ax.tick_params(labelsize=8)
+ if cmap[1]:
+ ax.CB = DraggableColorbar(cbar, im)
@verbose
@@ -1352,7 +1488,8 @@ def plot_epochs_psd_topomap(epochs, bands=None, vmin=None, vmax=None,
normalization='length', ch_type=None, layout=None,
cmap='RdBu_r', agg_fun=None, dB=False, n_jobs=1,
normalize=False, cbar_fmt='%0.3f',
- outlines='head', show=True, verbose=None):
+ outlines='head', axes=None, show=True,
+ verbose=None):
"""Plot the topomap of the power spectral density across epochs
Parameters
@@ -1404,9 +1541,16 @@ def plot_epochs_psd_topomap(epochs, bands=None, vmin=None, vmax=None,
file is inferred from the data; if no appropriate layout file was
found, the layout is automatically generated from the sensor
locations.
- cmap : matplotlib colormap
- Colormap. For magnetometers and eeg defaults to 'RdBu_r', else
- 'Reds'.
+ cmap : matplotlib colormap | (colormap, bool) | 'interactive' | None
+ Colormap to use. If tuple, the first value indicates the colormap to
+ use and the second value is a boolean defining interactivity. In
+ interactive mode the colors are adjustable by clicking and dragging the
+ colorbar with left and right mouse button. Left mouse button moves the
+ scale up and down and right mouse button adjusts the range. Hitting
+ space bar resets the range. Up and down arrows can be used to change
+ the colormap. If None (default), 'Reds' is used for all positive data,
+ otherwise defaults to 'RdBu_r'. If 'interactive', translates to
+ (None, True).
agg_fun : callable
The function used to aggregate over frequencies.
Defaults to np.sum. if normalize is True, else np.mean.
@@ -1432,6 +1576,9 @@ def plot_epochs_psd_topomap(epochs, bands=None, vmin=None, vmax=None,
masking options, either directly or as a function that returns patches
(required for multi-axis plots). If None, nothing will be drawn.
Defaults to 'head'.
+ axes : list of axes | None
+ List of axes to plot consecutive topographies to. If None the axes
+ will be created automatically. Defaults to None.
show : bool
Show figure if True.
verbose : bool, str, int, or None
@@ -1462,13 +1609,13 @@ def plot_epochs_psd_topomap(epochs, bands=None, vmin=None, vmax=None,
return plot_psds_topomap(
psds=psds, freqs=freqs, pos=pos, agg_fun=agg_fun, vmin=vmin,
vmax=vmax, bands=bands, cmap=cmap, dB=dB, normalize=normalize,
- cbar_fmt=cbar_fmt, outlines=outlines, show=show)
+ cbar_fmt=cbar_fmt, outlines=outlines, axes=axes, show=show)
def plot_psds_topomap(
psds, freqs, pos, agg_fun=None, vmin=None, vmax=None, bands=None,
- cmap='RdBu_r', dB=True, normalize=False, cbar_fmt='%0.3f',
- outlines='head', show=True):
+ cmap=None, dB=True, normalize=False, cbar_fmt='%0.3f', outlines='head',
+ axes=None, show=True):
"""Plot spatial maps of PSDs
Parameters
@@ -1497,9 +1644,16 @@ def plot_psds_topomap(
bands = [(0, 4, 'Delta'), (4, 8, 'Theta'), (8, 12, 'Alpha'),
(12, 30, 'Beta'), (30, 45, 'Gamma')]
- cmap : matplotlib colormap
- Colormap. For magnetometers and eeg defaults to 'RdBu_r', else
- 'Reds'.
+ cmap : matplotlib colormap | (colormap, bool) | 'interactive' | None
+ Colormap to use. If tuple, the first value indicates the colormap to
+ use and the second value is a boolean defining interactivity. In
+ interactive mode the colors are adjustable by clicking and dragging the
+ colorbar with left and right mouse button. Left mouse button moves the
+ scale up and down and right mouse button adjusts the range. Hitting
+ space bar resets the range. Up and down arrows can be used to change
+ the colormap. If None (default), 'Reds' is used for all positive data,
+ otherwise defaults to 'RdBu_r'. If 'interactive', translates to
+ (None, True).
dB : bool
If True, transform data to decibels (with ``10 * np.log10(data)``)
following the application of `agg_fun`. Only valid if normalize is
@@ -1520,6 +1674,9 @@ def plot_psds_topomap(
masking options, either directly or as a function that returns patches
(required for multi-axis plots). If None, nothing will be drawn.
Defaults to 'head'.
+ axes : list of axes | None
+ List of axes to plot consecutive topographies to. If None the axes
+ will be created automatically. Defaults to None.
show : bool
Show figure if True.
@@ -1543,9 +1700,13 @@ def plot_psds_topomap(
assert np.allclose(psds.sum(axis=-1), 1.)
n_axes = len(bands)
- fig, axes = plt.subplots(1, n_axes, figsize=(2 * n_axes, 1.5))
- if n_axes == 1:
- axes = [axes]
+ if axes is not None:
+ _validate_if_list_of_axes(axes, n_axes)
+ fig = axes[0].figure
+ else:
+ fig, axes = plt.subplots(1, n_axes, figsize=(2 * n_axes, 1.5))
+ if n_axes == 1:
+ axes = [axes]
for ax, (fmin, fmax, title) in zip(axes, bands):
freq_mask = (fmin < freqs) & (freqs < fmax)
@@ -1559,8 +1720,8 @@ def plot_psds_topomap(
else:
unit = 'power'
- _plot_topomap_multi_cbar(data, pos, ax, title=title,
- vmin=vmin, vmax=vmax, cmap=cmap,
+ _plot_topomap_multi_cbar(data, pos, ax, title=title, vmin=vmin,
+ vmax=vmax, cmap=cmap, outlines=outlines,
colorbar=True, unit=unit, cbar_fmt=cbar_fmt)
tight_layout(fig=fig)
fig.canvas.draw()
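# Usage sketch for the new ``axes`` parameter: draw the five default band
# topographies into pre-created axes. ``epochs`` is assumed to be an
# existing mne.Epochs instance:
import matplotlib.pyplot as plt
import mne
fig, axes = plt.subplots(1, 5, figsize=(10, 1.5))
mne.viz.plot_epochs_psd_topomap(epochs, axes=list(axes), show=False)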
@@ -1585,9 +1746,7 @@ def plot_layout(layout, show=True):
Notes
-----
-
.. versionadded:: 0.12.0
-
"""
import matplotlib.pyplot as plt
fig = plt.figure()
@@ -1632,7 +1791,6 @@ def _onselect(eclick, erelease, tfr, pos, ch_type, itmin, itmax, ifmin, ifmax,
data = np.mean(data[indices, ifmin:ifmax, itmin:itmax], axis=0)
chs = [tfr.ch_names[picks[x]] for x in indices]
elif ch_type == 'grad':
- picks = pick_types(tfr.info, meg=ch_type, ref_meg=False)
from ..channels.layout import _pair_grad_sensors
grads = _pair_grad_sensors(tfr.info, layout=layout,
topomap_coords=False)
@@ -1652,8 +1810,14 @@ def _onselect(eclick, erelease, tfr, pos, ch_type, itmin, itmax, ifmin, ifmax,
if not plt.fignum_exists(fig[0].number):
fig[0] = figure_nobar()
ax = fig[0].add_subplot(111)
- itmax = min(itmax, len(tfr.times) - 1)
- ifmax = min(ifmax, len(tfr.freqs) - 1)
+ itmax = len(tfr.times) - 1 if itmax is None else min(itmax,
+ len(tfr.times) - 1)
+ ifmax = len(tfr.freqs) - 1 if ifmax is None else min(ifmax,
+ len(tfr.freqs) - 1)
+ if itmin is None:
+ itmin = 0
+ if ifmin is None:
+ ifmin = 0
extent = (tfr.times[itmin] * 1e3, tfr.times[itmax] * 1e3, tfr.freqs[ifmin],
tfr.freqs[ifmax])
@@ -1685,8 +1849,9 @@ def _prepare_topomap(pos, ax):
def _hide_frame(ax):
"""Helper to hide axis frame for topomaps."""
- ax.set_xticks([])
- ax.set_yticks([])
+ ax.get_yticks()
+ ax.xaxis.set_ticks([])
+ ax.yaxis.set_ticks([])
ax.set_frame_on(False)
@@ -1894,7 +2059,7 @@ def _topomap_animation(evoked, ch_type='mag', times=None, frame_rate=None,
raise ValueError('All times must be inside the evoked time series.')
frames = [np.abs(evoked.times - time).argmin() for time in times]
- blit = False if plt.get_backend() == 'MacOSX' else True
+ blit = False if plt.get_backend() == 'MacOSX' else blit
picks, pos, merge_grads, _, ch_type = _prepare_topo_plot(evoked,
ch_type=ch_type,
layout=None)
@@ -1932,8 +2097,7 @@ def _topomap_animation(evoked, ch_type='mag', times=None, frame_rate=None,
frames=len(frames), interval=interval,
blit=blit)
fig.mne_animation = anim # to make sure anim is not garbage collected
- if show:
- plt.show()
+ plt_show(show, block=False)
if 'line' in params:
# Finally remove the vertical line so it does not appear in saved fig.
params['line'].remove()
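# Usage sketch: _topomap_animation backs the public Evoked.animate_topomap,
# so blit can now be requested explicitly (it is still forced off on the
# MacOSX backend). ``evoked`` is assumed to be an existing mne.Evoked:
fig, anim = evoked.animate_topomap(ch_type='mag', frame_rate=10, blit=False)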
diff --git a/mne/viz/utils.py b/mne/viz/utils.py
index 8eeb7a1..1e740dc 100644
--- a/mne/viz/utils.py
+++ b/mne/viz/utils.py
@@ -17,15 +17,17 @@ import webbrowser
import tempfile
import numpy as np
from copy import deepcopy
+from distutils.version import LooseVersion
from ..channels.layout import _auto_topomap_coords
from ..channels.channels import _contains_ch_type
from ..defaults import _handle_default
from ..io import show_fiff, Info
-from ..io.pick import channel_type, channel_indices_by_type
+from ..io.pick import channel_type, channel_indices_by_type, pick_channels
from ..utils import verbose, set_config, warn
from ..externals.six import string_types
-from ..fixes import _get_argrelmax
+from ..selection import (read_selection, _SELECTIONS, _EEG_SELECTIONS,
+ _divide_to_regions)
COLORS = ['b', 'g', 'r', 'c', 'm', 'y', 'k', '#473C8B', '#458B74',
@@ -107,6 +109,32 @@ def _check_delayed_ssp(container):
raise RuntimeError('No projs found in evoked.')
+def _validate_if_list_of_axes(axes, obligatory_len=None):
+ """ Helper function that validates whether input is a list/array of axes"""
+ import matplotlib as mpl
+ if obligatory_len is not None and not isinstance(obligatory_len, int):
+ raise ValueError('obligatory_len must be None or int, got %d',
+ 'instead' % type(obligatory_len))
+ if not isinstance(axes, (list, np.ndarray)):
+ raise ValueError('axes must be a list or numpy array of matplotlib '
+ 'axes objects, got %s instead.' % type(axes))
+ if isinstance(axes, np.ndarray) and axes.ndim > 1:
+ raise ValueError('if input is a numpy array, it must be '
+ 'one-dimensional. The received numpy array has %d '
+ 'dimensions however. Try using ravel or flatten '
+ 'method of the array.' % axes.ndim)
+ is_correct_type = np.array([isinstance(x, mpl.axes.Axes)
+ for x in axes])
+ if not np.all(is_correct_type):
+ first_bad = np.where(np.logical_not(is_correct_type))[0][0]
+ raise ValueError('axes must be a list or numpy array of matplotlib '
+ 'axes objects while one of the list elements is '
+ '%s.' % type(axes[first_bad]))
+ if obligatory_len is not None and not len(axes) == obligatory_len:
+ raise ValueError('axes must be a list/array of length %d, while the'
+ ' length is %d' % (obligatory_len, len(axes)))
+
+
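# Behavior sketch for the validator above, on plain matplotlib axes:
import matplotlib.pyplot as plt
from mne.viz.utils import _validate_if_list_of_axes
fig, axes = plt.subplots(2, 2)
_validate_if_list_of_axes(axes.ravel(), obligatory_len=4)  # OK: 1-D, 4 axes
try:
    _validate_if_list_of_axes(axes)  # 2-D array raises
except ValueError as err:
    print(err)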
def mne_analyze_colormap(limits=[5, 10, 15], format='mayavi'):
"""Return a colormap similar to that used by mne_analyze
@@ -243,8 +271,9 @@ def _get_help_text(params):
text.append(u'+ or = : \n')
text.append(u'Home : \n')
text.append(u'End : \n')
- text.append(u'Page down : \n')
- text.append(u'Page up : \n')
+ if 'fig_selection' not in params:
+ text.append(u'Page down : \n')
+ text.append(u'Page up : \n')
text.append(u'F11 : \n')
text.append(u'? : \n')
@@ -278,8 +307,9 @@ def _get_help_text(params):
text.append(u'click channel name :\n')
text2.insert(2, 'Navigate channels down\n')
text2.insert(3, 'Navigate channels up\n')
- text2.insert(8, 'Reduce the number of channels per view\n')
- text2.insert(9, 'Increase the number of channels per view\n')
+ if 'fig_selection' not in params:
+ text2.insert(8, 'Reduce the number of channels per view\n')
+ text2.insert(9, 'Increase the number of channels per view\n')
text2.append('Mark bad channel\n')
text2.append('Vertical line at a time instant\n')
text2.append('Mark bad channel\n')
@@ -503,14 +533,17 @@ def figure_nobar(*args, **kwargs):
def _helper_raw_resize(event, params):
"""Helper for resizing"""
size = ','.join([str(s) for s in params['fig'].get_size_inches()])
- set_config('MNE_BROWSE_RAW_SIZE', size)
+ set_config('MNE_BROWSE_RAW_SIZE', size, set_env=False)
_layout_figure(params)
def _plot_raw_onscroll(event, params, len_channels=None):
"""Interpret scroll events"""
+ if 'fig_selection' in params:
+ _change_channel_group(event.step, params)
+ return
if len_channels is None:
- len_channels = len(params['info']['ch_names'])
+ len_channels = len(params['inds'])
orig_start = params['ch_start']
if event.step < 0:
params['ch_start'] = min(params['ch_start'] + params['n_channels'],
@@ -543,17 +576,96 @@ def _plot_raw_time(value, params):
params['hsel_patch'].set_x(value)
+def _radio_clicked(label, params):
+ """Callback for radio buttons in selection dialog."""
+ from .evoked import _rgb
+ labels = [l._text for l in params['fig_selection'].radio.labels]
+ idx = labels.index(label)
+ params['fig_selection'].radio._active_idx = idx
+ channels = params['selections'][label]
+ ax_topo = params['fig_selection'].get_axes()[1]
+ types = np.array([], dtype=int)
+ for this_type in ('mag', 'grad', 'eeg', 'seeg', 'ecog', 'hbo', 'hbr'):
+ if this_type in params['types']:
+ types = np.concatenate(
+ [types, np.where(np.array(params['types']) == this_type)[0]])
+ colors = np.zeros((len(types), 4)) # alpha = 0 by default
+ locs3d = np.array([ch['loc'][:3] for ch in params['info']['chs']])
+ x, y, z = locs3d.T
+ color_vals = _rgb(params['info'], x, y, z)
+ for color_idx, pick in enumerate(types):
+ if pick in channels: # set color and alpha = 1
+ colors[color_idx] = np.append(color_vals[pick], 1.)
+ ax_topo.collections[0]._facecolors = colors
+ params['fig_selection'].canvas.draw()
+
+ nchan = sum([len(params['selections'][l]) for l in labels[:idx]])
+ params['vsel_patch'].set_y(nchan)
+ n_channels = len(channels)
+ params['n_channels'] = n_channels
+ params['inds'] = channels
+ for line in params['lines'][n_channels:]: # To remove lines from view.
+ line.set_xdata([])
+ line.set_ydata([])
+ if n_channels > 0: # Can be 0 with lasso selector.
+ _setup_browser_offsets(params, n_channels)
+ params['plot_fun']()
+
+
+def _set_radio_button(idx, params):
+ """Helper for setting radio button."""
+ # XXX: New version of matplotlib has this implemented for radio buttons,
+ # This function is for compatibility with old versions of mpl.
+ radio = params['fig_selection'].radio
+ radio.circles[radio._active_idx].set_facecolor((1., 1., 1., 1.))
+ radio.circles[idx].set_facecolor((0., 0., 1., 1.))
+ _radio_clicked(radio.labels[idx]._text, params)
+
+
+def _change_channel_group(step, params):
+ """Deal with change of channel group."""
+ radio = params['fig_selection'].radio
+ idx = radio._active_idx
+ if step < 0:
+ if idx < len(radio.labels) - 1:
+ _set_radio_button(idx + 1, params)
+ else:
+ if idx > 0:
+ _set_radio_button(idx - 1, params)
+ return
+
+
+def _handle_change_selection(event, params):
+ """Helper for handling clicks on vertical scrollbar using selections."""
+ radio = params['fig_selection'].radio
+ ydata = event.ydata
+ labels = [label._text for label in radio.labels]
+ offset = 0
+ for idx, label in enumerate(labels):
+ nchans = len(params['selections'][label])
+ offset += nchans
+ if ydata < offset:
+ _set_radio_button(idx, params)
+ return
+
+
def _plot_raw_onkey(event, params):
"""Interpret key presses"""
import matplotlib.pyplot as plt
if event.key == 'escape':
plt.close(params['fig'])
elif event.key == 'down':
+ if 'fig_selection' in params.keys():
+ _change_channel_group(-1, params)
+ return
params['ch_start'] += params['n_channels']
- _channels_changed(params, len(params['info']['ch_names']))
+ _channels_changed(params, len(params['inds']))
elif event.key == 'up':
+ if 'fig_selection' in params.keys():
+ _change_channel_group(1, params)
+ return
params['ch_start'] -= params['n_channels']
- _channels_changed(params, len(params['info']['ch_names']))
+ _channels_changed(params, len(params['inds']))
elif event.key == 'right':
value = params['t_start'] + params['duration']
_plot_raw_time(value, params)
@@ -570,11 +682,11 @@ def _plot_raw_onkey(event, params):
elif event.key == '-':
params['scale_factor'] /= 1.1
params['plot_fun']()
- elif event.key == 'pageup':
+ elif event.key == 'pageup' and 'fig_selection' not in params:
n_channels = params['n_channels'] + 1
_setup_browser_offsets(params, n_channels)
- _channels_changed(params, len(params['info']['ch_names']))
- elif event.key == 'pagedown':
+ _channels_changed(params, len(params['inds']))
+ elif event.key == 'pagedown' and 'fig_selection' not in params:
n_channels = params['n_channels'] - 1
if n_channels == 0:
return
@@ -582,7 +694,7 @@ def _plot_raw_onkey(event, params):
if len(params['lines']) > n_channels: # remove line from view
params['lines'][n_channels].set_xdata([])
params['lines'][n_channels].set_ydata([])
- _channels_changed(params, len(params['info']['ch_names']))
+ _channels_changed(params, len(params['inds']))
elif event.key == 'home':
duration = params['duration'] - 1.0
if duration <= 0:
@@ -621,10 +733,13 @@ def _mouse_click(event, params):
params['label_click_fun'](pos)
# vertical scrollbar changed
if event.inaxes == params['ax_vscroll']:
- ch_start = max(int(event.ydata) - params['n_channels'] // 2, 0)
- if params['ch_start'] != ch_start:
- params['ch_start'] = ch_start
- params['plot_fun']()
+ if 'fig_selection' in params.keys():
+ _handle_change_selection(event, params)
+ else:
+ ch_start = max(int(event.ydata) - params['n_channels'] // 2, 0)
+ if params['ch_start'] != ch_start:
+ params['ch_start'] = ch_start
+ params['plot_fun']()
# horizontal scrollbar changed
elif event.inaxes == params['ax_hscroll']:
_plot_raw_time(event.xdata - params['duration'] / 2, params)
@@ -635,6 +750,40 @@ def _mouse_click(event, params):
params['pick_bads_fun'](event)
+def _handle_topomap_bads(ch_name, params):
+ """Helper for coloring channels in selection topomap when selecting bads"""
+ for type in ('mag', 'grad', 'eeg', 'seeg', 'hbo', 'hbr'):
+ if type in params['types']:
+ types = np.where(np.array(params['types']) == type)[0]
+ break
+ color_ind = np.where(np.array(
+ params['info']['ch_names'])[types] == ch_name)[0]
+ if len(color_ind) > 0:
+ sensors = params['fig_selection'].axes[1].collections[0]
+ this_color = sensors._edgecolors[color_ind][0]
+ if all(this_color == [1., 0., 0., 1.]): # is red
+ sensors._edgecolors[color_ind] = [0., 0., 0., 1.]
+ else: # is black
+ sensors._edgecolors[color_ind] = [1., 0., 0., 1.]
+ params['fig_selection'].canvas.draw()
+
+
+def _find_channel_idx(ch_name, params):
+ """Helper for finding all indices when using selections."""
+ indices = list()
+ offset = 0
+ labels = [l._text for l in params['fig_selection'].radio.labels]
+ for label in labels:
+ if label == 'Custom':
+ continue # Custom selection not included as it shifts the indices.
+ selection = params['selections'][label]
+ hits = np.where(np.array(params['raw'].ch_names)[selection] == ch_name)
+ for idx in hits[0]:
+ indices.append(offset + idx)
+ offset += len(selection)
+ return indices
+
+
def _select_bads(event, params, bads):
"""Helper for selecting bad channels onpick. Returns updated bads list."""
# trade-off, avoid selecting more than one channel when drifts are present
@@ -649,7 +798,12 @@ def _select_bads(event, params, bads):
if ymin <= event.ydata <= ymax:
this_chan = vars(line)['ch_name']
if this_chan in params['info']['ch_names']:
- ch_idx = params['ch_start'] + lines.index(line)
+ if 'fig_selection' in params:
+ ch_idx = _find_channel_idx(this_chan, params)
+ _handle_topomap_bads(this_chan, params)
+ else:
+ ch_idx = [params['ch_start'] + lines.index(line)]
+
if this_chan not in bads:
bads.append(this_chan)
color = params['bad_color']
@@ -660,13 +814,15 @@ def _select_bads(event, params, bads):
color = vars(line)['def_color']
line.set_zorder(0)
line.set_color(color)
- params['ax_vscroll'].patches[ch_idx].set_color(color)
+ for idx in ch_idx:
+ params['ax_vscroll'].patches[idx].set_color(color)
break
else:
x = np.array([event.xdata] * 2)
params['ax_vertline'].set_data(x, np.array(params['ax'].get_ylim()))
params['ax_hscroll_vertline'].set_data(x, np.array([0., 1.]))
params['vertline_t'].set_text('%0.3f' % x[0])
+
return bads
@@ -712,6 +868,8 @@ def _setup_browser_offsets(params, n_channels):
params['ax'].set_yticks(params['offsets'])
params['ax'].set_ylim(ylim)
params['vsel_patch'].set_height(n_channels)
+ line = params['ax_vertline']
+ line.set_data(line._x, np.array(params['ax'].get_ylim()))
class ClickableImage(object):
@@ -729,7 +887,7 @@ class ClickableImage(object):
Parameters
----------
- imdata: ndarray
+ imdata : ndarray
The image that you wish to click on for 2-d points.
**kwargs : dict
Keyword arguments. Passed to ax.imshow.
@@ -761,7 +919,7 @@ class ClickableImage(object):
Parameters
----------
- event: matplotlib event object
+ event : matplotlib event object
The matplotlib object that we use to get x/y position.
"""
mouseevent = event.mouseevent
@@ -804,7 +962,7 @@ class ClickableImage(object):
return lt
-def _fake_click(fig, ax, point, xform='ax', button=1):
+def _fake_click(fig, ax, point, xform='ax', button=1, kind='press'):
"""Helper to fake a click at a relative point within axes."""
if xform == 'ax':
x, y = ax.transAxes.transform_point(point)
@@ -812,10 +970,17 @@ def _fake_click(fig, ax, point, xform='ax', button=1):
x, y = ax.transData.transform_point(point)
else:
raise ValueError('unknown transform')
+ if kind == 'press':
+ func = partial(fig.canvas.button_press_event, x=x, y=y, button=button)
+ elif kind == 'release':
+ func = partial(fig.canvas.button_release_event, x=x, y=y,
+ button=button)
+ elif kind == 'motion':
+ func = partial(fig.canvas.motion_notify_event, x=x, y=y)
try:
- fig.canvas.button_press_event(x, y, button, False, None)
+ func(guiEvent=None)
except Exception: # for old MPL
- fig.canvas.button_press_event(x, y, button, False)
+ func()
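# Usage sketch for the extended test helper: a full click-drag-release
# sequence can now be synthesized on any matplotlib canvas:
import matplotlib.pyplot as plt
from mne.viz.utils import _fake_click
fig, ax = plt.subplots()
_fake_click(fig, ax, (0.2, 0.2), xform='ax', kind='press')
_fake_click(fig, ax, (0.8, 0.8), xform='ax', kind='motion')
_fake_click(fig, ax, (0.8, 0.8), xform='ax', kind='release')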
def add_background_image(fig, im, set_ratios=None):
@@ -831,20 +996,19 @@ def add_background_image(fig, im, set_ratios=None):
Parameters
----------
- fig: plt.figure
+ fig : plt.figure
The figure you wish to add a bg image to.
- im: ndarray
- A numpy array that works with a call to
- plt.imshow(im). This will be plotted
- as the background of the figure.
- set_ratios: None | str
+ im : array, shape (M, N, {3, 4})
+ A background image for the figure. This must be a valid input to
+ `matplotlib.pyplot.imshow`. Defaults to None.
+ set_ratios : None | str
Set the aspect ratio of any axes in fig
to the value in set_ratios. Defaults to None,
which does nothing to axes.
Returns
-------
- ax_im: instance of the create matplotlib axis object
+ ax_im : instance of the created matplotlib axis object
corresponding to the image you added.
Notes
@@ -852,11 +1016,14 @@ def add_background_image(fig, im, set_ratios=None):
.. versionadded:: 0.9.0
"""
+ if im is None:
+ # Don't do anything and return nothing
+ return None
if set_ratios is not None:
for ax in fig.axes:
ax.set_aspect(set_ratios)
- ax_im = fig.add_axes([0, 0, 1, 1])
+ ax_im = fig.add_axes([0, 0, 1, 1], label='background')
ax_im.imshow(im, aspect='auto')
ax_im.set_zorder(-1)
return ax_im
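# Usage sketch for add_background_image, including the new None no-op:
import numpy as np
import matplotlib.pyplot as plt
from mne.viz.utils import add_background_image
fig, ax = plt.subplots()
bg = np.random.RandomState(0).rand(16, 16)  # any imshow-compatible array
ax_im = add_background_image(fig, bg)
assert add_background_image(fig, None) is None  # None now does nothing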
@@ -866,7 +1033,7 @@ def _find_peaks(evoked, npeaks):
"""Helper function for finding peaks from evoked data
Returns ``npeaks`` biggest peaks as a list of time points.
"""
- argrelmax = _get_argrelmax()
+ from scipy.signal import argrelmax
gfp = evoked.data.std(axis=0)
order = len(evoked.times) // 30
if order < 1:
@@ -911,7 +1078,8 @@ def _process_times(inst, times, n_peaks=None, few=False):
def plot_sensors(info, kind='topomap', ch_type=None, title=None,
- show_names=False, show=True):
+ show_names=False, ch_groups=None, axes=None, block=False,
+ show=True):
"""Plot sensors positions.
Parameters
@@ -919,16 +1087,42 @@ def plot_sensors(info, kind='topomap', ch_type=None, title=None,
info : Instance of Info
Info structure containing the channel locations.
kind : str
- Whether to plot the sensors as 3d or as topomap. Available options
- 'topomap', '3d'. Defaults to 'topomap'.
- ch_type : 'mag' | 'grad' | 'eeg' | 'seeg' | None
- The channel type to plot. If None, then channels are chosen in the
- order given above.
+ Whether to plot the sensors as 3d, topomap or as an interactive
+ sensor selection dialog. Available options 'topomap', '3d', 'select'.
+ If 'select', a set of channels can be selected interactively by using
+ lasso selector or clicking while holding control key. The selected
+ channels are returned along with the figure instance. Defaults to
+ 'topomap'.
+ ch_type : None | str
+ The channel type to plot. Available options 'mag', 'grad', 'eeg',
+ 'seeg', 'ecog', 'all'. If ``'all'``, all the available mag, grad, eeg,
+ seeg and ecog channels are plotted. If None (default), then channels
+ are chosen in the order given above.
title : str | None
Title for the figure. If None (default), equals to
``'Sensor positions (%s)' % ch_type``.
show_names : bool
Whether to display all channel names. Defaults to False.
+ ch_groups : 'position' | 'selection' | array of shape (n_groups, picks) | None
+ Channel groups for coloring the sensors. If None (default), the default
+ coloring scheme is used. If 'position', the sensors are divided into
+ 8 regions; if 'selection', MNE's standard channel selections are used
+ (see :func:`mne.read_selection`). See ``order`` kwarg of
+ :func:`mne.viz.plot_raw`. If array, the channels are divided by the
+ picks given in the array.
+
+ .. versionadded:: 0.13.0
+
+ axes : instance of Axes | instance of Axes3D | None
+ Axes to draw the sensors to. If ``kind='3d'``, axes must be an instance
+ of Axes3D. If None (default), a new axes will be created.
+
+ .. versionadded:: 0.13.0
+
+ block : bool
+ Whether to halt program execution until the figure is closed. Defaults
+ to False.
+
+ .. versionadded:: 0.13.0
+
show : bool
Show figure if True. Defaults to True.
@@ -936,6 +1130,8 @@ def plot_sensors(info, kind='topomap', ch_type=None, title=None,
-------
fig : instance of matplotlib figure
Figure containing the sensor topography.
+ selection : list
+ A list of selected channels. Only returned if ``kind=='select'``.
See Also
--------
@@ -950,40 +1146,96 @@ def plot_sensors(info, kind='topomap', ch_type=None, title=None,
.. versionadded:: 0.12.0
"""
- if kind not in ['topomap', '3d']:
- raise ValueError("Kind must be 'topomap' or '3d'.")
+ from .evoked import _rgb
+ if kind not in ['topomap', '3d', 'select']:
+ raise ValueError("Kind must be 'topomap', '3d' or 'select'. Got %s." %
+ kind)
if not isinstance(info, Info):
raise TypeError('info must be an instance of Info not %s' % type(info))
ch_indices = channel_indices_by_type(info)
- allowed_types = ['mag', 'grad', 'eeg', 'seeg']
+ allowed_types = ['mag', 'grad', 'eeg', 'seeg', 'ecog']
if ch_type is None:
for this_type in allowed_types:
if _contains_ch_type(info, this_type):
ch_type = this_type
break
- elif ch_type not in allowed_types:
+ picks = ch_indices[ch_type]
+ elif ch_type == 'all':
+ picks = list()
+ for this_type in allowed_types:
+ picks += ch_indices[this_type]
+ elif ch_type in allowed_types:
+ picks = ch_indices[ch_type]
+ else:
raise ValueError("ch_type must be one of %s not %s!" % (allowed_types,
ch_type))
- picks = ch_indices[ch_type]
- if kind == 'topomap':
- pos = _auto_topomap_coords(info, picks, True)
- else:
- pos = np.asarray([ch['loc'][:3] for ch in info['chs']])[picks]
- def_colors = _handle_default('color')
+
+ if len(picks) == 0:
+ raise ValueError('Could not find any channels of type %s.' % ch_type)
+
+ pos = np.asarray([ch['loc'][:3] for ch in info['chs']])[picks]
ch_names = np.array(info['ch_names'])[picks]
bads = [idx for idx, name in enumerate(ch_names) if name in info['bads']]
- colors = ['red' if i in bads else def_colors[channel_type(info, pick)]
- for i, pick in enumerate(picks)]
- title = 'Sensor positions (%s)' % ch_type if title is None else title
- fig = _plot_sensors(pos, colors, ch_names, title, show_names, show)
+ if ch_groups is None:
+ def_colors = _handle_default('color')
+ colors = ['red' if i in bads else def_colors[channel_type(info, pick)]
+ for i, pick in enumerate(picks)]
+ else:
+ if ch_groups in ['position', 'selection']:
+ if ch_groups == 'position':
+ ch_groups = _divide_to_regions(info, add_stim=False)
+ ch_groups = list(ch_groups.values())
+ else:
+ ch_groups, color_vals = list(), list()
+ for selection in _SELECTIONS + _EEG_SELECTIONS:
+ channels = pick_channels(
+ info['ch_names'], read_selection(selection, info=info))
+ ch_groups.append(channels)
+ color_vals = np.ones((len(ch_groups), 4))
+ for idx, ch_group in enumerate(ch_groups):
+ color_picks = [np.where(picks == ch)[0][0] for ch in ch_group
+ if ch in picks]
+ if len(color_picks) == 0:
+ continue
+ x, y, z = pos[color_picks].T
+ color = np.mean(_rgb(info, x, y, z), axis=0)
+ color_vals[idx, :3] = color # mean of spatial color
+ else:
+ import matplotlib.pyplot as plt
+ colors = np.linspace(0, 1, len(ch_groups))
+ color_vals = [plt.cm.jet(colors[i]) for i in range(len(ch_groups))]
+ if not isinstance(ch_groups, (np.ndarray, list)):
+ raise ValueError("ch_groups must be None, 'position', "
+ "'selection', or an array. Got %s." % ch_groups)
+ colors = np.zeros((len(picks), 4))
+ for pick_idx, pick in enumerate(picks):
+ for ind, value in enumerate(ch_groups):
+ if pick in value:
+ colors[pick_idx] = color_vals[ind]
+ break
+ if kind in ('topomap', 'select'):
+ pos = _auto_topomap_coords(info, picks, True)
+ title = 'Sensor positions (%s)' % ch_type if title is None else title
+ fig = _plot_sensors(pos, colors, bads, ch_names, title, show_names, axes,
+ show, kind == 'select', block=block)
+ if kind == 'select':
+ return fig, fig.lasso.selection
return fig
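# Usage sketch for the new kind='select'. A minimal EEG example built from a
# standard montage; the channel names and montage are chosen for
# illustration only:
import mne
montage = mne.channels.read_montage('standard_1020')
info = mne.create_info(['Fz', 'Cz', 'Pz', 'Oz'], 1000., 'eeg',
                       montage=montage)
# Lasso or ctrl+click sensors; the selected names are returned on close.
fig, selection = mne.viz.plot_sensors(info, kind='select', block=True)
print(selection)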
-def _onpick_sensor(event, fig, ax, pos, ch_names):
+def _onpick_sensor(event, fig, ax, pos, ch_names, show_names):
"""Callback for picked channel in plot_sensors."""
+ if event.mouseevent.key == 'control' and fig.lasso is not None:
+ for ind in event.ind:
+ fig.lasso.select_one(ind)
+
+ return
+ if show_names:
+ return # channel names already visible
ind = event.ind[0] # Just take the first sensor.
ch_name = ch_names[ind]
+
this_pos = pos[ind]
# XXX: Bug in matplotlib won't allow setting the position of existing
@@ -996,22 +1248,36 @@ def _onpick_sensor(event, fig, ax, pos, ch_names):
fig.canvas.draw()
-def _plot_sensors(pos, colors, ch_names, title, show_names, show):
+def _close_event(event, fig):
+ fig.lasso.disconnect()
+
+
+def _plot_sensors(pos, colors, bads, ch_names, title, show_names, ax, show,
+ select, block):
"""Helper function for plotting sensors."""
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from .topomap import _check_outlines, _draw_outlines
- fig = plt.figure()
+ edgecolors = np.repeat('black', len(colors))
+ edgecolors[bads] = 'red'
+ if ax is None:
+ fig = plt.figure()
+ if pos.shape[1] == 3:
+ Axes3D(fig)
+ ax = fig.gca(projection='3d')
+ else:
+ ax = fig.add_subplot(111)
+ else:
+ fig = ax.get_figure()
if pos.shape[1] == 3:
- ax = Axes3D(fig)
- ax = fig.gca(projection='3d')
ax.text(0, 0, 0, '', zorder=1)
- ax.scatter(pos[:, 0], pos[:, 1], pos[:, 2], picker=True, c=colors)
+ ax.scatter(pos[:, 0], pos[:, 1], pos[:, 2], picker=True, c=colors,
+ s=75, edgecolor=edgecolors, linewidth=2)
+
ax.azim = 90
ax.elev = 0
else:
- ax = fig.add_subplot(111)
ax.text(0, 0, '', zorder=1)
ax.set_xticks([])
ax.set_yticks([])
@@ -1019,8 +1285,13 @@ def _plot_sensors(pos, colors, ch_names, title, show_names, show):
hspace=None)
pos, outlines = _check_outlines(pos, 'head')
_draw_outlines(ax, outlines)
- ax.scatter(pos[:, 0], pos[:, 1], picker=True, c=colors)
+ pts = ax.scatter(pos[:, 0], pos[:, 1], picker=True, c=colors, s=75,
+ edgecolor=edgecolors, linewidth=2)
+ if select:
+ fig.lasso = SelectFromCollection(ax, pts, ch_names)
+
+ connect_picker = True
if show_names:
for idx in range(len(pos)):
this_pos = pos[idx]
@@ -1028,12 +1299,16 @@ def _plot_sensors(pos, colors, ch_names, title, show_names, show):
ax.text(this_pos[0], this_pos[1], this_pos[2], ch_names[idx])
else:
ax.text(this_pos[0], this_pos[1], ch_names[idx])
- else:
+ connect_picker = select
+ if connect_picker:
picker = partial(_onpick_sensor, fig=fig, ax=ax, pos=pos,
- ch_names=ch_names)
+ ch_names=ch_names, show_names=show_names)
fig.canvas.mpl_connect('pick_event', picker)
+
fig.suptitle(title)
- plt_show(show)
+ closed = partial(_close_event, fig=fig)
+ fig.canvas.mpl_connect('close_event', closed)
+ plt_show(show, block=block)
return fig
@@ -1107,3 +1382,192 @@ def _compute_scalings(scalings, inst):
scale_factor = np.max(np.abs(scale_factor))
scalings[key] = scale_factor
return scalings
+
+
+class DraggableColorbar(object):
+ """Class for enabling interactive colorbar.
+ See http://www.ster.kuleuven.be/~pieterd/python/html/plotting/interactive_colorbar.html
+ """ # noqa
+ def __init__(self, cbar, mappable):
+ import matplotlib.pyplot as plt
+ self.cbar = cbar
+ self.mappable = mappable
+ self.press = None
+ self.cycle = sorted([i for i in dir(plt.cm) if
+ hasattr(getattr(plt.cm, i), 'N')])
+ self.index = self.cycle.index(cbar.get_cmap().name)
+ self.lims = (self.cbar.norm.vmin, self.cbar.norm.vmax)
+ self.connect()
+
+ def connect(self):
+ """Connect to all the events we need."""
+ self.cidpress = self.cbar.patch.figure.canvas.mpl_connect(
+ 'button_press_event', self.on_press)
+ self.cidrelease = self.cbar.patch.figure.canvas.mpl_connect(
+ 'button_release_event', self.on_release)
+ self.cidmotion = self.cbar.patch.figure.canvas.mpl_connect(
+ 'motion_notify_event', self.on_motion)
+ self.keypress = self.cbar.patch.figure.canvas.mpl_connect(
+ 'key_press_event', self.key_press)
+ self.scroll = self.cbar.patch.figure.canvas.mpl_connect(
+ 'scroll_event', self.on_scroll)
+
+ def on_press(self, event):
+ """Callback for button press."""
+ if event.inaxes != self.cbar.ax:
+ return
+ self.press = event.y
+
+ def key_press(self, event):
+ """Callback for key press."""
+ if event.key == 'down':
+ self.index += 1
+ elif event.key == 'up':
+ self.index -= 1
+ elif event.key == ' ': # space key resets scale
+ self.cbar.norm.vmin = self.lims[0]
+ self.cbar.norm.vmax = self.lims[1]
+ else:
+ return
+ if self.index < 0:
+ self.index = len(self.cycle) - 1
+ elif self.index >= len(self.cycle):
+ self.index = 0
+ cmap = self.cycle[self.index]
+ self.cbar.set_cmap(cmap)
+ self.cbar.draw_all()
+ self.mappable.set_cmap(cmap)
+ self.cbar.patch.figure.canvas.draw()
+
+ def on_motion(self, event):
+ """Callback for mouse movements."""
+ if self.press is None:
+ return
+ if event.inaxes != self.cbar.ax:
+ return
+ yprev = self.press
+ dy = event.y - yprev
+ self.press = event.y
+ scale = self.cbar.norm.vmax - self.cbar.norm.vmin
+ perc = 0.03
+ if event.button == 1:
+ self.cbar.norm.vmin -= (perc * scale) * np.sign(dy)
+ self.cbar.norm.vmax -= (perc * scale) * np.sign(dy)
+ elif event.button == 3:
+ self.cbar.norm.vmin -= (perc * scale) * np.sign(dy)
+ self.cbar.norm.vmax += (perc * scale) * np.sign(dy)
+ self.cbar.draw_all()
+ self.mappable.set_norm(self.cbar.norm)
+ self.cbar.patch.figure.canvas.draw()
+
+ def on_release(self, event):
+ """Callback for release."""
+ self.press = None
+ self.mappable.set_norm(self.cbar.norm)
+ self.cbar.patch.figure.canvas.draw()
+
+ def on_scroll(self, event):
+ """Callback for scroll."""
+ scale = 1.1 if event.step < 0 else 1. / 1.1
+ self.cbar.norm.vmin *= scale
+ self.cbar.norm.vmax *= scale
+ self.cbar.draw_all()
+ self.mappable.set_norm(self.cbar.norm)
+ self.cbar.patch.figure.canvas.draw()
+
+
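# Standalone usage sketch for DraggableColorbar on a plain matplotlib image:
import numpy as np
import matplotlib.pyplot as plt
from mne.viz.utils import DraggableColorbar
fig, ax = plt.subplots()
im = ax.imshow(np.random.RandomState(0).randn(10, 10), cmap='RdBu_r')
cbar = fig.colorbar(im)
ax.CB = DraggableColorbar(cbar, im)  # keep a reference so callbacks live on
plt.show()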
+class SelectFromCollection(object):
+ """Select channels from a matplotlib collection using `LassoSelector`.
+
+ Selected channels are saved in the ``selection`` attribute. This tool
+ highlights selected points by fading other points out (i.e., reducing their
+ alpha values).
+
+ Parameters
+ ----------
+ ax : instance of Axes
+ Axes to interact with.
+ collection : instance of matplotlib collection
+ Collection you want to select from.
+ ch_names : array-like of str
+ Channel names corresponding to the points in the collection; the
+ names of selected points are stored in the ``selection`` attribute.
+ alpha_other : 0 <= float <= 1
+ To highlight a selection, this tool sets all selected points to an
+ alpha value of 1 and non-selected points to ``alpha_other``.
+ Defaults to 0.3.
+
+ Notes
+ -----
+ This tool selects collection objects based on their *origins* (i.e.,
+ ``offsets``) and emits the mpl event 'lasso_event' when the selection
+ is ready.
+ """
+
+ def __init__(self, ax, collection, ch_names, alpha_other=0.3):
+ import matplotlib as mpl
+ if LooseVersion(mpl.__version__) < LooseVersion('1.2.1'):
+ raise ImportError('Interactive selection not possible for '
+ 'matplotlib versions < 1.2.1. Upgrade '
+ 'matplotlib.')
+ from matplotlib.widgets import LassoSelector
+ self.canvas = ax.figure.canvas
+ self.collection = collection
+ self.ch_names = ch_names
+ self.alpha_other = alpha_other
+
+ self.xys = collection.get_offsets()
+ self.Npts = len(self.xys)
+
+ # Ensure that we have separate colors for each object
+ self.fc = collection.get_facecolors()
+ if len(self.fc) == 0:
+ raise ValueError('Collection must have a facecolor')
+ elif len(self.fc) == 1:
+ self.fc = np.tile(self.fc, self.Npts).reshape(self.Npts, -1)
+ self.fc[:, -1] = self.alpha_other # deselect in the beginning
+
+ self.lasso = LassoSelector(ax, onselect=self.on_select,
+ lineprops={'color': 'red', 'linewidth': .5})
+ self.selection = list()
+
+ def on_select(self, verts):
+ """Callback for selecting a subset from the collection."""
+ from matplotlib.path import Path
+ if len(verts) <= 3: # Seems to be a good way to exclude single clicks.
+ return
+
+ path = Path(verts)
+ inds = np.nonzero([path.contains_point(xy) for xy in self.xys])[0]
+ if self.canvas._key == 'control': # Appending selection.
+ sels = [np.where(self.ch_names == c)[0][0] for c in self.selection]
+ inters = set(inds) - set(sels)
+ inds = list(inters.union(set(sels) - set(inds)))
+
+ while len(self.selection) > 0:
+ self.selection.pop(0)
+ self.selection.extend(self.ch_names[inds])
+ self.fc[:, -1] = self.alpha_other
+ self.fc[inds, -1] = 1
+ self.collection.set_facecolors(self.fc)
+ self.canvas.draw_idle()
+ self.canvas.callbacks.process('lasso_event')
+
+ def select_one(self, ind):
+ """Helper for selecting/deselecting one sensor."""
+ ch_name = self.ch_names[ind]
+ if ch_name in self.selection:
+ sel_ind = self.selection.index(ch_name)
+ self.selection.pop(sel_ind)
+ this_alpha = self.alpha_other
+ else:
+ self.selection.append(ch_name)
+ this_alpha = 1
+ self.fc[ind, -1] = this_alpha
+ self.collection.set_facecolors(self.fc)
+ self.canvas.draw_idle()
+ self.canvas.callbacks.process('lasso_event')
+
+ def disconnect(self):
+ """Method for disconnecting the lasso selector."""
+ self.lasso.disconnect_events()
+ self.fc[:, -1] = 1
+ self.collection.set_facecolors(self.fc)
+ self.canvas.draw_idle()
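# Standalone usage sketch for SelectFromCollection on a plain scatter plot:
import numpy as np
import matplotlib.pyplot as plt
from mne.viz.utils import SelectFromCollection
rng = np.random.RandomState(0)
fig, ax = plt.subplots()
pts = ax.scatter(rng.rand(10), rng.rand(10))
ch_names = np.array(['CH%02d' % i for i in range(10)])
lasso = SelectFromCollection(ax, pts, ch_names)
plt.show()  # draw a lasso; lasso.selection then holds the selected names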
diff --git a/setup.cfg b/setup.cfg
index b5ba987..34231e7 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -19,6 +19,7 @@ doc-files = doc
# cover-html = 1
# cover-html-dir = coverage
cover-package = mne
+ignore-files = (?:^\.|^_,|^conf\.py$)
detailed-errors = 1
with-doctest = 1
diff --git a/setup.py b/setup.py
index a60cfd0..9f61e5d 100755
--- a/setup.py
+++ b/setup.py
@@ -75,6 +75,7 @@ if __name__ == "__main__":
'mne.datasets.brainstorm',
'mne.datasets.testing',
'mne.datasets.tests',
+ 'mne.datasets.multimodal',
'mne.externals',
'mne.externals.h5io',
'mne.externals.tempita',
diff --git a/tutorials/plot_artifacts_correction_filtering.py b/tutorials/plot_artifacts_correction_filtering.py
index af7dda7..f886737 100644
--- a/tutorials/plot_artifacts_correction_filtering.py
+++ b/tutorials/plot_artifacts_correction_filtering.py
@@ -1,8 +1,8 @@
"""
.. _tut_artifacts_filter:
-Filtering and Resampling
-========================
+Filtering and resampling data
+=============================
Certain artifacts are restricted to certain frequencies and can therefore
be fixed by filtering. An artifact that typically affects only some
@@ -13,6 +13,10 @@ It is composed of sharp peaks at 50Hz (or 60Hz depending on your
geographical location). Some peaks may also be present at the harmonic
frequencies, i.e. the integer multiples of
the power-line frequency, e.g. 100Hz, 150Hz, ... (or 120Hz, 180Hz, ...).
+
+This tutorial covers some basics of how to filter data in MNE-Python.
+For more in-depth information about filter design in general and in
+MNE-Python in particular, check out :ref:`tut_background_filtering`.
"""
import numpy as np
@@ -27,7 +31,8 @@ tmin, tmax = 0, 20 # use the first 20s of data
# Setup for reading the raw data (save memory by cropping the raw data
# before loading it)
-raw = mne.io.read_raw_fif(raw_fname).crop(tmin, tmax).load_data()
+raw = mne.io.read_raw_fif(raw_fname, add_eeg_ref=False)
+raw.crop(tmin, tmax).load_data()
raw.info['bads'] = ['MEG 2443', 'EEG 053'] # bads + 2 more
fmin, fmax = 2, 300 # look at frequencies between 2 and 300Hz
@@ -48,44 +53,62 @@ raw.plot_psd(area_mode='range', tmax=10.0, picks=picks)
# Removing power-line noise can be done with a Notch filter, directly on the
# Raw object, specifying an array of frequency to be cut off:
-raw.notch_filter(np.arange(60, 241, 60), picks=picks)
+raw.notch_filter(np.arange(60, 241, 60), picks=picks, filter_length='auto',
+ phase='zero')
raw.plot_psd(area_mode='range', tmax=10.0, picks=picks)
###############################################################################
-# Removing power-line noise with low-pas filtering
+# Removing power-line noise with low-pass filtering
# -------------------------------------------------
#
# If you're only interested in low frequencies, below the peaks of power-line
# noise you can simply low pass filter the data.
-raw.filter(None, 50.) # low pass filtering below 50 Hz
+# low pass filtering below 50 Hz
+raw.filter(None, 50., h_trans_bandwidth='auto', filter_length='auto',
+ phase='zero')
raw.plot_psd(area_mode='range', tmax=10.0, picks=picks)
###############################################################################
# High-pass filtering to remove slow drifts
# -----------------------------------------
#
-# If you're only interested in low frequencies, below the peaks of power-line
-# noise you can simply high pass filter the data.
+# To remove slow drifts, you can high-pass filter the data.
+#
+# .. warning:: There can be issues using high-passes greater than 0.1 Hz
+# (see examples in :ref:`tut_filtering_hp_problems`),
+# so apply high-pass filters with caution.
-raw.filter(1., None) # low pass filtering above 1 Hz
+raw.filter(1., None, l_trans_bandwidth='auto', filter_length='auto',
+ phase='zero')
raw.plot_psd(area_mode='range', tmax=10.0, picks=picks)
+
###############################################################################
# To do the low-pass and high-pass filtering in one step you can do
-# a so-called *band-pass* filter by running
+# a so-called *band-pass* filter by running the following:
-raw.filter(1., 50.) # band-pass filtering in the range 1 Hz - 50 Hz
+# band-pass filtering in the range 1 Hz - 50 Hz
+raw.filter(1, 50., l_trans_bandwidth='auto', h_trans_bandwidth='auto',
+ filter_length='auto', phase='zero')
###############################################################################
-# Down-sampling (for performance reasons)
-# ---------------------------------------
+# Downsampling and decimation
+# ---------------------------
#
# When performing experiments where timing is critical, a signal with a high
# sampling rate is desired. However, having a signal with a much higher
# sampling rate than necessary needlessly consumes memory and slows down
-# computations operating on the data. To avoid that, you can down-sample
-# your time series.
+# computations operating on the data. To avoid that, you can downsample
+# your time series. Since downsampling raw data reduces the timing precision
+# of events, it is recommended only for use in procedures that do not require
+# optimal precision, e.g. computing EOG or ECG projectors on long recordings.
+#
+# .. note:: A *downsampling* operation performs a low-pass (to prevent
+# aliasing) followed by *decimation*, which selects every
+# :math:`N^{th}` sample from the signal. See
+# :func:`scipy.signal.resample` and
+# :func:`scipy.signal.resample_poly` for examples.
#
# Data resampling can be done with *resample* methods.
@@ -93,6 +116,16 @@ raw.resample(100, npad="auto") # set sampling frequency to 100Hz
raw.plot_psd(area_mode='range', tmax=10.0, picks=picks)
###############################################################################
-# Since down-sampling reduces the timing precision of events, you might want to
-# first extract epochs and down-sampling the Epochs object. You can do this
-# using the :func:`mne.Epochs.resample` method.
+# To avoid this reduction in precision, the suggested pipeline for
+# processing final data to be analyzed is:
+#
+# 1. Low-pass the data with :meth:`mne.io.Raw.filter`.
+# 2. Extract epochs with :class:`mne.Epochs`.
+# 3. Decimate the Epochs object using :meth:`mne.Epochs.decimate` or the
+# ``decim`` argument to the :class:`mne.Epochs` object.
+#
+# We also provide the convenience methods :meth:`mne.Epochs.resample` and
+# :meth:`mne.Evoked.resample` to downsample or upsample data, but these are
+# less optimal because they will introduce edge artifacts into every epoch,
+# whereas filtering the raw data will introduce edge artifacts only at
+# the start and end of the recording.
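+#
+# A minimal sketch of this pipeline (an editor's illustration, not part of
+# the original tutorial; ``events``, ``event_id``, ``tmin`` and ``tmax`` are
+# hypothetical placeholders):
+#
+# raw.filter(None, 40., h_trans_bandwidth='auto', filter_length='auto',
+#            phase='zero')                                # 1. low-pass Raw
+# epochs = mne.Epochs(raw, events, event_id, tmin, tmax)  # 2. epoch
+# epochs.decimate(4)                                      # 3. decimate by 4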
diff --git a/tutorials/plot_artifacts_correction_ica.py b/tutorials/plot_artifacts_correction_ica.py
index 0d10177..2d900e0 100644
--- a/tutorials/plot_artifacts_correction_ica.py
+++ b/tutorials/plot_artifacts_correction_ica.py
@@ -23,13 +23,13 @@ import mne
from mne.datasets import sample
from mne.preprocessing import ICA
-from mne.preprocessing import create_eog_epochs
+from mne.preprocessing import create_eog_epochs, create_ecg_epochs
# getting some data ready
data_path = sample.data_path()
raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
-raw = mne.io.read_raw_fif(raw_fname, preload=True)
+raw = mne.io.read_raw_fif(raw_fname, preload=True, add_eeg_ref=False)
raw.filter(1, 40, n_jobs=2) # 1Hz high pass is often helpful for fitting ICA
picks_meg = mne.pick_types(raw.info, meg=True, eeg=False, eog=False,
@@ -47,11 +47,16 @@ picks_meg = mne.pick_types(raw.info, meg=True, eeg=False, eog=False,
n_components = 25 # if float, select n_components by explained variance of PCA
method = 'fastica' # for comparison with EEGLAB try "extended-infomax" here
-decim = 3 # we need sufficient statistics, not all time points -> save time
+decim = 3 # we need sufficient statistics, not all time points -> saves time
+
+# we will also set state of the random number generator - ICA is a
+# non-deterministic algorithm, but we want to have the same decomposition
+# and the same order of components each time this tutorial is run
+random_state = 23
###############################################################################
# Define the ICA object instance
-ica = ICA(n_components=n_components, method=method)
+ica = ICA(n_components=n_components, method=method, random_state=random_state)
print(ica)
###############################################################################
@@ -64,8 +69,36 @@ print(ica)
###############################################################################
# Plot ICA components
-ica.plot_components() # can you see some potential bad guys?
+ica.plot_components() # can you spot some potential bad guys?
+
+###############################################################################
+# Component properties
+# --------------------
+#
+# Let's take a closer look at the properties of the first three independent
+# components.
+
+# first, component 0:
+ica.plot_properties(raw, picks=0)
+
+###############################################################################
+# We can see that the data were filtered, so the spectrum plot is not
+# very informative; let's change that:
+ica.plot_properties(raw, picks=0, psd_args={'fmax': 35.})
+
+###############################################################################
+# We can also take a look at multiple different components at once:
+ica.plot_properties(raw, picks=[1, 2], psd_args={'fmax': 35.})
+
+###############################################################################
+# Instead of opening individual figures with component properties, we can
+# also pass an instance of Raw or Epochs in the ``inst`` argument to
+# ``ica.plot_components``. This would allow us to open component properties
+# interactively by clicking on individual component topomaps. In the notebook
+# this works only when running matplotlib in interactive mode
+# (``%matplotlib``).
+
+# uncomment the code below to test the interactive mode of plot_components:
+# ica.plot_components(picks=range(10), inst=raw)
###############################################################################
# Advanced artifact detection
@@ -88,10 +121,18 @@ ica.plot_scores(scores, exclude=eog_inds) # look at r scores of components
ica.plot_sources(eog_average, exclude=eog_inds) # look at source time course
###############################################################################
-# That component is also showing a prototypical average vertical EOG time
-# course.
+# We can take a look at the properties of that component, now using the
+# data epoched with respect to EOG events.
+# We will also use a little bit of smoothing along the trials axis in the
+# epochs image:
+ica.plot_properties(eog_epochs, picks=eog_inds, psd_args={'fmax': 35.},
+ image_args={'sigma': 1.})
+
+###############################################################################
+# That component is showing a prototypical average vertical EOG time course.
#
-# Pay attention to the labels, a customized read-out of the ica.labels_
+# Pay attention to the labels, a customized read-out of the
+# :attr:`ica.labels_ <mne.preprocessing.ICA.labels_>`:
print(ica.labels_)
###############################################################################
@@ -99,14 +140,13 @@ print(ica.labels_)
# by artifact detection functions. You can also manually edit them to annotate
# components.
#
-# Now let's see how we would modify our signals if we would remove this
-# component from the data
+# Now let's see how we would modify our signals if we removed this component
+# from the data:
ica.plot_overlay(eog_average, exclude=eog_inds, show=False)
# red -> before, black -> after. Yes! We remove quite a lot!
# to definitely register this component as a bad one to be removed
# there is the ``ica.exclude`` attribute, a simple Python list
-
ica.exclude.extend(eog_inds)
# from now on the ICA will reject this component even if no exclude
@@ -119,16 +159,24 @@ ica.exclude.extend(eog_inds)
###############################################################################
# Exercise: find and remove ECG artifacts using ICA!
-#
+ecg_epochs = create_ecg_epochs(raw, tmin=-.5, tmax=.5)
+ecg_inds, scores = ica.find_bads_ecg(ecg_epochs, method='ctps')
+ica.plot_properties(ecg_epochs, picks=ecg_inds, psd_args={'fmax': 35.})
+
+
+###############################################################################
# What if we don't have an EOG channel?
# -------------------------------------
#
-# 1) make a bipolar reference from frontal EEG sensors and use as virtual EOG
-# channel. This can be tricky though as you can only hope that the frontal
-# EEG channels only reflect EOG and not brain dynamics in the prefrontal
-# cortex.
-# 2) Go for a semi-automated approach, using template matching.
-# In MNE-Python option 2 is easily achievable and it might be better,
+# We could either:
+#
+# 1. make a bipolar reference from frontal EEG sensors and use it as a
+#    virtual EOG channel (a sketch follows below). This can be tricky
+#    though, as you can only hope that the frontal
+# EEG channels only reflect EOG and not brain dynamics in the prefrontal
+# cortex.
+# 2. go for a semi-automated approach, using template matching.
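+#
+# (For option 1, a rough sketch -- an editor's illustration with hypothetical
+# channel names, adapt them to your montage -- could be
+# ``raw = mne.set_bipolar_reference(raw, anode=['EEG 001'],
+# cathode=['EEG 002'], ch_name=['vEOG'])`` followed by
+# ``ica.find_bads_eog(raw, ch_name='vEOG')``.)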
+#
+# In MNE-Python option 2 is easily achievable and it might give better results,
# so let's have a look at it.
from mne.preprocessing.ica import corrmap # noqa
@@ -137,13 +185,19 @@ from mne.preprocessing.ica import corrmap # noqa
# The idea behind corrmap is that artefact patterns are similar across subjects
# and can thus be identified by correlating the different patterns resulting
# from each solution with a template. The procedure is therefore
-# semi-automatic. Corrmap hence takes at least a list of ICA solutions and a
-# template, that can be an index or an array. As we don't have different
-# subjects or runs available today, here we will fit ICA models to different
-# parts of the recording and then use as a user-defined template the ICA
-# that we just fitted for detecting corresponding components in the three "new"
-# ICAs. The following block of code addresses this point and should not be
-# copied, ok?
+# semi-automatic. :func:`mne.preprocessing.corrmap` hence takes a list of
+# ICA solutions and a template, that can be an index or an array.
+#
+# As we don't have different subjects or runs available today, here we will
+# simulate ICA solutions from different subjects by fitting ICA models to
+# different parts of the same recording. Then we will use one of the components
+# from our original ICA as a template in order to detect sufficiently similar
+# components in the simulated ICAs.
+#
+# The following block of code simulates having ICA solutions from different
+# runs/subjects so it should not be used in real analysis - use independent
+# data sets instead.
+
# We'll start by simulating a group of subjects or runs from a subject
start, stop = [0, len(raw.times) - 1]
intervals = np.linspace(start, stop, 4, dtype=int)
@@ -158,44 +212,46 @@ for ii, start in enumerate(intervals):
icas_from_other_data.append(this_ica)
###############################################################################
-# Do not copy this at home! You start by reading in a collections of ICA
-# solutions, something like
+# Remember, don't do this at home! Start by reading in a collection of ICA
+# solutions instead. Something like:
#
# ``icas = [mne.preprocessing.read_ica(fname) for fname in ica_fnames]``
print(icas_from_other_data)
###############################################################################
-# use our previous ICA as reference.
+# We use our original ICA as reference.
reference_ica = ica
###############################################################################
-# Investigate our reference ICA, here we use the previous fit from above.
+# Investigate our reference ICA:
reference_ica.plot_components()
###############################################################################
# Which one is the bad EOG component?
-# Here we rely on our previous detection algorithm. You will need to decide
-# yourself in that situation where no other detection is available.
-
+# Here we rely on our previous detection algorithm. You would need to decide
+# yourself if no automatic detection was available.
reference_ica.plot_sources(eog_average, exclude=eog_inds)
###############################################################################
# Indeed it looks like an EOG, also in the average time course.
#
-# So our template shall be a tuple like (reference_run_index, component_index):
+# We construct a list where our reference run is the first element. Then we
+# can detect similar components from the other runs using
+# :func:`mne.preprocessing.corrmap`. So our template must be a tuple like
+# (reference_run_index, component_index):
+icas = [reference_ica] + icas_from_other_data
template = (0, eog_inds[0])
###############################################################################
# Now we can do the corrmap.
-fig_template, fig_detected = corrmap(
- icas_from_other_data, template=template, label="blinks", show=True,
- threshold=.8, ch_type='mag')
+fig_template, fig_detected = corrmap(icas, template=template, label="blinks",
+ show=True, threshold=.8, ch_type='mag')
###############################################################################
-# Nice, we have found similar ICs from the other runs!
+# Nice, we have found similar ICs from the other (simulated) runs!
# This is even nicer if we have 20 or 100 ICA solutions in a list.
#
-# You can also use SSP for correcting for artifacts. It is a bit simpler,
-# faster but is less precise than ICA. And it requires that you
-# know the event timing of your artifact.
+# You can also use SSP to correct for artifacts. It is a bit simpler and
+# faster but also less precise than ICA and requires that you know the event
+# timing of your artifact.
# See :ref:`tut_artifacts_correct_ssp`.
diff --git a/tutorials/plot_artifacts_correction_maxwell_filtering.py b/tutorials/plot_artifacts_correction_maxwell_filtering.py
index 84b19c7..eff9419 100644
--- a/tutorials/plot_artifacts_correction_maxwell_filtering.py
+++ b/tutorials/plot_artifacts_correction_maxwell_filtering.py
@@ -24,7 +24,7 @@ fine_cal_fname = data_path + '/SSS/sss_cal_mgh.dat'
###############################################################################
# Preprocess with Maxwell filtering
-raw = mne.io.read_raw_fif(raw_fname)
+raw = mne.io.read_raw_fif(raw_fname, add_eeg_ref=False)
raw.info['bads'] = ['MEG 2443', 'EEG 053', 'MEG 1032', 'MEG 2313'] # set bads
# Here we don't use tSSS (set st_duration) because MGH data is very clean
raw_sss = maxwell_filter(raw, cross_talk=ctc_fname, calibration=fine_cal_fname)
diff --git a/tutorials/plot_artifacts_correction_rejection.py b/tutorials/plot_artifacts_correction_rejection.py
index 04accb9..affdfef 100644
--- a/tutorials/plot_artifacts_correction_rejection.py
+++ b/tutorials/plot_artifacts_correction_rejection.py
@@ -12,7 +12,8 @@ from mne.datasets import sample
data_path = sample.data_path()
raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
-raw = mne.io.read_raw_fif(raw_fname)
+raw = mne.io.read_raw_fif(raw_fname, add_eeg_ref=False)
+raw.set_eeg_reference()
###############################################################################
# .. _marking_bad_channels:
@@ -127,7 +128,8 @@ n_blinks = len(eog_events)
# Center to cover the whole blink with full duration of 0.5s:
onset = eog_events[:, 0] / raw.info['sfreq'] - 0.25
duration = np.repeat(0.5, n_blinks)
-raw.annotations = mne.Annotations(onset, duration, ['bad blink'] * n_blinks)
+raw.annotations = mne.Annotations(onset, duration, ['bad blink'] * n_blinks,
+ orig_time=raw.info['meas_date'])
raw.plot(events=eog_events) # To see the annotated segments.
###############################################################################
@@ -176,7 +178,7 @@ picks_meg = mne.pick_types(raw.info, meg=True, eeg=False, eog=True,
stim=False, exclude='bads')
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True,
picks=picks_meg, baseline=baseline, reject=reject,
- reject_by_annotation=True)
+ reject_by_annotation=True, add_eeg_ref=False)
###############################################################################
# We then drop/reject the bad epochs
diff --git a/tutorials/plot_artifacts_correction_ssp.py b/tutorials/plot_artifacts_correction_ssp.py
index 642d196..bff1cc5 100644
--- a/tutorials/plot_artifacts_correction_ssp.py
+++ b/tutorials/plot_artifacts_correction_ssp.py
@@ -16,7 +16,8 @@ from mne.preprocessing import compute_proj_ecg, compute_proj_eog
data_path = sample.data_path()
raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
-raw = mne.io.read_raw_fif(raw_fname, preload=True)
+raw = mne.io.read_raw_fif(raw_fname, preload=True, add_eeg_ref=False)
+raw.set_eeg_reference()
raw.pick_types(meg=True, ecg=True, eog=True, stim=True)
##############################################################################
diff --git a/tutorials/plot_background_filtering.py b/tutorials/plot_background_filtering.py
new file mode 100644
index 0000000..0db34d7
--- /dev/null
+++ b/tutorials/plot_background_filtering.py
@@ -0,0 +1,951 @@
+# -*- coding: utf-8 -*-
+r"""
+.. _tut_background_filtering:
+
+===================================
+Background information on filtering
+===================================
+
+Here we give some background information on filtering in general,
+and how it is done in MNE-Python in particular.
+Recommended reading for practical applications of digital
+filter design can be found in Parks & Burrus [1]_ and
+Ifeachor and Jervis [2]_, and for filtering in an
+M/EEG context we recommend reading Widmann *et al.* 2015 [7]_.
+To see how to use the default filters in MNE-Python on actual data, see
+the :ref:`tut_artifacts_filter` tutorial.
+
+.. contents::
+ :local:
+
+Problem statement
+=================
+
+The practical issues with filtering electrophysiological data are covered
+well by Widmann *et al.* [7]_, a follow-up to an article in which they
+conclude with this statement:
+
+ Filtering can result in considerable distortions of the time course
+ (and amplitude) of a signal as demonstrated by VanRullen (2011) [[3]_].
+ Thus, filtering should not be used lightly. However, if effects of
+ filtering are cautiously considered and filter artifacts are minimized,
+ a valid interpretation of the temporal dynamics of filtered
+ electrophysiological data is possible and signals missed otherwise
+ can be detected with filtering.
+
+In other words, filtering can increase SNR, but if it is not used carefully,
+it can distort data. Here we hope to cover some filtering basics so
+users can better understand filtering tradeoffs, and why MNE-Python has
+chosen particular defaults.
+
+.. _tut_filtering_basics:
+
+Filtering basics
+================
+
+Let's get some of the basic math down. In the frequency domain, digital
+filters have a transfer function that is given by:
+
+.. math::
+
+    H(z) &= \frac{b_0 + b_1 z^{-1} + b_2 z^{-2} + ... + b_M z^{-M}}
+                 {1 + a_1 z^{-1} + a_2 z^{-2} + ... + a_N z^{-N}} \\
+         &= \frac{\sum_{k=0}^M b_k z^{-k}}{1 + \sum_{k=1}^N a_k z^{-k}}
+
+In the time domain, the numerator coefficients :math:`b_k` and denominator
+coefficients :math:`a_k` can be used to obtain our output data
+:math:`y(n)` in terms of our input data :math:`x(n)` as:
+
+.. math::
+ :label: summations
+
+ y(n) &= b_0 x(n) + b_1 x(n-1) + ... + b_M x(n-M)
+ - a_1 y(n-1) - a_2 y(n - 2) - ... - a_N y(n - N)\\
+         &= \sum_{k=0}^M b_k x(n-k) - \sum_{k=1}^N a_k y(n-k)
+
+In other words, the output at time :math:`n` is determined by a sum over:
+
+ 1. The numerator coefficients :math:`b_k`, which get multiplied by
+ the previous input :math:`x(n-k)` values, and
+ 2. The denominator coefficients :math:`a_k`, which get multiplied by
+ the previous output :math:`y(n-k)` values.
+
+Note that these summations in :eq:`summations` correspond nicely to
+(1) a weighted `moving average`_ and (2) an autoregression_.
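+
+As a quick sanity check (an editor's sketch, not part of the original text),
+the difference equation :eq:`summations` is exactly what
+:func:`scipy.signal.lfilter` computes (with ``a[0] == 1``)::
+
+    import numpy as np
+    from scipy import signal
+
+    x = np.random.RandomState(0).randn(100)
+    b, a = [0.5, 0.5], [1., -0.1]  # example coefficients, a[0] == 1
+    y = signal.lfilter(b, a, x)
+    y_manual = np.zeros_like(x)
+    for n in range(len(x)):  # evaluate the two summations directly
+        y_manual[n] = (sum(b[k] * x[n - k] for k in range(len(b)) if n >= k) -
+                       sum(a[k] * y_manual[n - k]
+                           for k in range(1, len(a)) if n >= k))
+    np.testing.assert_allclose(y, y_manual)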
+
+Filters are broken into two classes: FIR_ (finite impulse response) and
+IIR_ (infinite impulse response) based on these coefficients.
+FIR filters use a finite number of numerator
+coefficients :math:`b_k` (:math:`\forall k, a_k=0`), and thus each output
+value of :math:`y(n)` depends only on the :math:`M` previous input values.
+IIR filters depend on the previous input and output values, and thus can have
+effectively infinite impulse responses.
+
+As outlined in [1]_, FIR and IIR have different tradeoffs:
+
+ * A causal FIR filter can be linear-phase -- i.e., the same time delay
+ across all frequencies -- whereas a causal IIR filter cannot. The phase
+ and group delay characteristics are also usually better for FIR filters.
+ * IIR filters can generally have a steeper cutoff than an FIR filter of
+ equivalent order.
+  * IIR filters are generally less numerically stable, in part due to
+    error accumulating in their recursive calculations.
+
+In MNE-Python we default to using FIR filtering. As noted in Widmann *et al.*
+2015 [7]_:
+
+ Despite IIR filters often being considered as computationally more
+ efficient, they are recommended only when high throughput and sharp
+    cutoffs are required (Ifeachor and Jervis, 2002 [2]_, p. 321),
+ ...FIR filters are easier to control, are always stable, have a
+ well-defined passband, can be corrected to zero-phase without
+ additional computations, and can be converted to minimum-phase.
+ We therefore recommend FIR filters for most purposes in
+ electrophysiological data analysis.
+
+When designing a filter (FIR or IIR), there are always tradeoffs that
+need to be considered, including but not limited to:
+
+ 1. Ripple in the pass-band
+ 2. Attenuation of the stop-band
+ 3. Steepness of roll-off
+ 4. Filter order (i.e., length for FIR filters)
+ 5. Time-domain ringing
+
+In general, the sharper something is in frequency, the broader it is in time,
+and vice-versa. This is a fundamental time-frequency tradeoff, and it will
+show up below.
+
+FIR Filters
+===========
+
+First we will focus on FIR filters, which are the default filters used by
+MNE-Python.
+"""
+
+###############################################################################
+# Designing FIR filters
+# ---------------------
+# Here we'll try designing a low-pass filter, and look at trade-offs in terms
+# of time- and frequency-domain filter characteristics. Later, in
+# :ref:`tut_effect_on_signals`, we'll look at how such filters can affect
+# signals when they are used.
+#
+# First let's import some useful tools for filtering, and set some default
+# values for our data that are reasonable for M/EEG data.
+
+import numpy as np
+from scipy import signal, fftpack
+import matplotlib.pyplot as plt
+
+from mne.time_frequency.tfr import morlet
+
+import mne
+
+sfreq = 1000.
+f_p = 40.
+ylim = [-60, 10] # for dB plots
+xlim = [2, sfreq / 2.]
+blue = '#1f77b4'
+
+###############################################################################
+# Take for example an ideal low-pass filter, which would give a value of 1 in
+# the pass-band (up to frequency :math:`f_p`) and a value of 0 in the stop-band
+# (down to frequency :math:`f_s`) such that :math:`f_p=f_s=40` Hz here
+# (shown to a lower limit of -60 dB for simplicity):
+
+nyq = sfreq / 2. # the Nyquist frequency is half our sample rate
+freq = [0, f_p, f_p, nyq]
+gain = [1, 1, 0, 0]
+
+
+def box_off(ax):
+ ax.grid(zorder=0)
+ for key in ('top', 'right'):
+ ax.spines[key].set_visible(False)
+
+
+def plot_ideal(freq, gain, ax):
+ freq = np.maximum(freq, xlim[0])
+ xs, ys = list(), list()
+ my_freq, my_gain = list(), list()
+ for ii in range(len(freq)):
+ xs.append(freq[ii])
+ ys.append(ylim[0])
+ if ii < len(freq) - 1 and gain[ii] != gain[ii + 1]:
+ xs += [freq[ii], freq[ii + 1]]
+ ys += [ylim[1]] * 2
+ my_freq += np.linspace(freq[ii], freq[ii + 1], 20,
+ endpoint=False).tolist()
+ my_gain += np.linspace(gain[ii], gain[ii + 1], 20,
+ endpoint=False).tolist()
+ else:
+ my_freq.append(freq[ii])
+ my_gain.append(gain[ii])
+ my_gain = 10 * np.log10(np.maximum(my_gain, 10 ** (ylim[0] / 10.)))
+ ax.fill_between(xs, ylim[0], ys, color='r', alpha=0.1)
+ ax.semilogx(my_freq, my_gain, 'r--', alpha=0.5, linewidth=4, zorder=3)
+ xticks = [1, 2, 4, 10, 20, 40, 100, 200, 400]
+ ax.set(xlim=xlim, ylim=ylim, xticks=xticks, xlabel='Frequency (Hz)',
+ ylabel='Amplitude (dB)')
+ ax.set(xticklabels=xticks)
+ box_off(ax)
+
+half_height = np.array(plt.rcParams['figure.figsize']) * [1, 0.5]
+ax = plt.subplots(1, figsize=half_height)[1]
+plot_ideal(freq, gain, ax)
+ax.set(title='Ideal %s Hz lowpass' % f_p)
+mne.viz.tight_layout()
+plt.show()
+
+###############################################################################
+# This filter hypothetically achieves zero ripple in the frequency domain,
+# perfect attenuation, and perfect steepness. However, due to the
+# discontinuity in the frequency response, the filter would require infinite
+# ringing in the time domain (i.e., infinite order) to be realized. Another
+# way to think of this is that a rectangular window in frequency is actually
+# a sinc_ function in time, which requires an infinite number of samples (and
+# thus infinite time) to represent. So although this filter has ideal
+# frequency suppression, it has poor time-domain characteristics.
+#
+# Let's try to naïvely make a brick-wall filter of length 0.1 sec, and look
+# at the filter itself in the time domain and the frequency domain:
+
+n = int(round(0.1 * sfreq)) + 1
+t = np.arange(-n // 2, n // 2) / sfreq # center our sinc
+h = np.sinc(2 * f_p * t) / (4 * np.pi)
+
+
+def plot_filter(h, title, freq, gain, show=True):
+ if h.ndim == 2: # second-order sections
+ sos = h
+ n = mne.filter.estimate_ringing_samples(sos)
+ h = np.zeros(n)
+ h[0] = 1
+ h = signal.sosfilt(sos, h)
+ H = np.ones(512, np.complex128)
+ for section in sos:
+ f, this_H = signal.freqz(section[:3], section[3:])
+ H *= this_H
+ else:
+ f, H = signal.freqz(h)
+ fig, axs = plt.subplots(2)
+ t = np.arange(len(h)) / sfreq
+ axs[0].plot(t, h, color=blue)
+ axs[0].set(xlim=t[[0, -1]], xlabel='Time (sec)',
+ ylabel='Amplitude h(n)', title=title)
+ box_off(axs[0])
+ f *= sfreq / (2 * np.pi)
+ axs[1].semilogx(f, 10 * np.log10((H * H.conj()).real), color=blue,
+ linewidth=2, zorder=4)
+ plot_ideal(freq, gain, axs[1])
+ mne.viz.tight_layout()
+ if show:
+ plt.show()
+
+plot_filter(h, 'Sinc (0.1 sec)', freq, gain)
+
+###############################################################################
+# This is not so good! Making the filter 10 times longer (1 sec) gets us a
+# bit better stop-band suppression, but it still has a lot of ringing in
+# the time domain. Note the x-axis is an order of magnitude longer here:
+
+n = int(round(1. * sfreq)) + 1
+t = np.arange(-n // 2, n // 2) / sfreq
+h = np.sinc(2 * f_p * t) / (4 * np.pi)
+plot_filter(h, 'Sinc (1.0 sec)', freq, gain)
+
+###############################################################################
+# Let's make the stop-band tighter still with a longer filter (10 sec),
+# with a resulting larger x-axis:
+
+n = int(round(10. * sfreq)) + 1
+t = np.arange(-n // 2, n // 2) / sfreq
+h = np.sinc(2 * f_p * t) / (4 * np.pi)
+plot_filter(h, 'Sinc (10.0 sec)', freq, gain)
+
+###############################################################################
+# Now we have very sharp frequency suppression, but our filter rings for the
+# entire 10 seconds. So this naïve method is probably not a good way to build
+# our low-pass filter.
+#
+# Fortunately, there are multiple established methods to design FIR filters
+# based on desired response characteristics. These include:
+#
+# 1. The Remez_ algorithm (:func:`scipy.signal.remez`, `MATLAB firpm`_)
+# 2. Windowed FIR design (:func:`scipy.signal.firwin2`, `MATLAB fir2`_)
+# 3. Least squares designs (:func:`scipy.signal.firls`, `MATLAB firls`_)
+# 4. Frequency-domain design (construct filter in Fourier
+# domain and use an :func:`IFFT <scipy.fftpack.ifft>` to invert it)
+#
+# .. note:: Remez and least squares designs have advantages when there are
+# "do not care" regions in our frequency response. However, we want
+# well controlled responses in all frequency regions.
+# Frequency-domain construction is good when an arbitrary response
+# is desired, but generally less clean (due to sampling issues) than
+#           a windowed approach for more straightforward filter applications.
+#           Since our filters (low-pass, high-pass, band-pass, band-stop)
+#           are fairly simple and we require precise control of all frequency
+# regions, here we will use and explore primarily windowed FIR
+# design.
+#
+# If we relax our frequency-domain filter requirements a little bit, we can
+# use these functions to construct a lowpass filter that instead has a
+# *transition band*, or a region between the pass frequency :math:`f_p`
+# and stop frequency :math:`f_s`, e.g.:
+
+trans_bandwidth = 10 # 10 Hz transition band
+f_s = f_p + trans_bandwidth # = 50 Hz
+
+freq = [0., f_p, f_s, nyq]
+gain = [1., 1., 0., 0.]
+ax = plt.subplots(1, figsize=half_height)[1]
+plot_ideal(freq, gain, ax)
+ax.set(title='%s Hz lowpass with a %s Hz transition' % (f_p, trans_bandwidth))
+mne.viz.tight_layout()
+plt.show()
+
+###############################################################################
+# Accepting a shallower roll-off of the filter in the frequency domain makes
+# our time-domain response potentially much better. We end up with a
+# smoother slope through the transition region, but a *much* cleaner time
+# domain signal. Here again for the 1 sec filter:
+
+n = int(round(1. * sfreq)) + 1  # reset to a 1 sec filter, as described above
+h = signal.firwin2(n, freq, gain, nyq=nyq)
+plot_filter(h, 'Windowed 10-Hz transition (1.0 sec)', freq, gain)
+
+###############################################################################
+# Since our lowpass is around 40 Hz with a 10 Hz transition, we can actually
+# use a shorter filter (5 cycles at 10 Hz = 0.5 sec) and still get okay
+# stop-band attenuation:
+
+n = int(round(sfreq * 0.5)) + 1
+h = signal.firwin2(n, freq, gain, nyq=nyq)
+plot_filter(h, 'Windowed 10-Hz transition (0.5 sec)', freq, gain)
+
+###############################################################################
+# But then if we shorten the filter too much (2 cycles of 10 Hz = 0.2 sec),
+# our effective stop frequency gets pushed out past 60 Hz:
+
+n = int(round(sfreq * 0.2)) + 1
+h = signal.firwin2(n, freq, gain, nyq=nyq)
+plot_filter(h, 'Windowed 10-Hz transition (0.2 sec)', freq, gain)
+
+###############################################################################
+# If we want a filter that is only 0.1 seconds long, we should probably use
+# something more like a 25 Hz transition band (0.2 sec = 5 cycles @ 25 Hz):
+
+trans_bandwidth = 25
+f_s = f_p + trans_bandwidth
+freq = [0, f_p, f_s, nyq]
+h = signal.firwin2(n, freq, gain, nyq=nyq)
+plot_filter(h, 'Windowed 25-Hz transition (0.2 sec)', freq, gain)
+
+###############################################################################
+# .. _tut_effect_on_signals:
+#
+# Applying FIR filters
+# --------------------
+#
+# Now let's look at some practical effects of these filters by applying
+# them to some data.
+#
+# Let's construct a Gaussian-windowed sinusoid (i.e., Morlet imaginary part)
+# plus noise (random + line). Note that the original, clean signal contains
+# frequency content in both the pass band and transition bands of our
+# low-pass filter.
+
+dur = 10.
+center = 2.
+morlet_freq = f_p
+tlim = [center - 0.2, center + 0.2]
+tticks = [tlim[0], center, tlim[1]]
+flim = [20, 70]
+
+x = np.zeros(int(sfreq * dur))
+blip = morlet(sfreq, [morlet_freq], n_cycles=7)[0].imag / 20.
+n_onset = int(center * sfreq) - len(blip) // 2
+x[n_onset:n_onset + len(blip)] += blip
+x_orig = x.copy()
+
+rng = np.random.RandomState(0)
+x += rng.randn(len(x)) / 1000.
+x += np.sin(2. * np.pi * 60. * np.arange(len(x)) / sfreq) / 2000.
+
+###############################################################################
+# Filter it with a shallow cutoff, linear-phase FIR and compensate for
+# the delay:
+
+transition_band = 0.25 * f_p
+f_s = f_p + transition_band
+filter_dur = 6.6 / transition_band # sec
+n = int(sfreq * filter_dur)
+freq = [0., f_p, f_s, sfreq / 2.]
+gain = [1., 1., 0., 0.]
+h = signal.firwin2(n, freq, gain, nyq=sfreq / 2.)
+x_shallow = np.convolve(h, x)[len(h) // 2:]
+
+plot_filter(h, 'MNE-Python 0.14 default', freq, gain)
+
+###############################################################################
+# This is actually set to become the default type of filter used in MNE-Python
+# in 0.14 (see :ref:`tut_filtering_in_python`).
+#
+# Let's also filter with the MNE-Python 0.13 default, which is a
+# long-duration, steep cutoff FIR that gets applied twice:
+
+transition_band = 0.5 # Hz
+f_s = f_p + transition_band
+filter_dur = 10. # sec
+n = int(sfreq * filter_dur)
+freq = [0., f_p, f_s, sfreq / 2.]
+gain = [1., 1., 0., 0.]
+h = signal.firwin2(n, freq, gain, nyq=sfreq / 2.)
+x_steep = np.convolve(np.convolve(h, x)[::-1], h)[::-1][len(h) - 1:-len(h) - 1]
+
+plot_filter(h, 'MNE-Python 0.13 default', freq, gain)
+
+###############################################################################
+# Finally, let's also filter it with the
+# MNE-C default, which is a long-duration steep-slope FIR filter designed
+# using frequency-domain techniques:
+
+h = mne.filter.design_mne_c_filter(sfreq, l_freq=None, h_freq=f_p + 2.5)
+x_mne_c = np.convolve(h, x)[len(h) // 2:]
+
+transition_band = 5 # Hz (default in MNE-C)
+f_s = f_p + transition_band
+freq = [0., f_p, f_s, sfreq / 2.]
+gain = [1., 1., 0., 0.]
+plot_filter(h, 'MNE-C default', freq, gain)
+
+###############################################################################
+# Both the MNE-Python 0.13 and MNE-C filters have excellent frequency
+# attenuation, but it comes at a cost of potential
+# ringing (long-lasting ripples) in the time domain. Ringing can occur with
+# steep filters, especially on signals with frequency content around the
+# transition band. Our Morlet wavelet signal has power in our transition band,
+# and the time-domain ringing is thus more pronounced for the steep-slope,
+# long-duration filter than the shorter, shallower-slope filter:
+
+axs = plt.subplots(1, 2)[1]
+
+
+def plot_signal(x, offset):
+ t = np.arange(len(x)) / sfreq
+ axs[0].plot(t, x + offset)
+ axs[0].set(xlabel='Time (sec)', xlim=t[[0, -1]])
+ box_off(axs[0])
+ X = fftpack.fft(x)
+ freqs = fftpack.fftfreq(len(x), 1. / sfreq)
+ mask = freqs >= 0
+ X = X[mask]
+ freqs = freqs[mask]
+ axs[1].plot(freqs, 20 * np.log10(np.abs(X)))
+ axs[1].set(xlim=xlim)
+
+yticks = np.arange(5) / -30.
+yticklabels = ['Original', 'Noisy', 'FIR-shallow (0.14)', 'FIR-steep (0.13)',
+ 'FIR-steep (MNE-C)']
+plot_signal(x_orig, offset=yticks[0])
+plot_signal(x, offset=yticks[1])
+plot_signal(x_shallow, offset=yticks[2])
+plot_signal(x_steep, offset=yticks[3])
+plot_signal(x_mne_c, offset=yticks[4])
+axs[0].set(xlim=tlim, title='FIR, Lowpass=%d Hz' % f_p, xticks=tticks,
+ ylim=[-0.150, 0.025], yticks=yticks, yticklabels=yticklabels,)
+for text in axs[0].get_yticklabels():
+ text.set(rotation=45, size=8)
+axs[1].set(xlim=flim, ylim=ylim, xlabel='Frequency (Hz)',
+ ylabel='Magnitude (dB)')
+box_off(axs[0])
+box_off(axs[1])
+mne.viz.tight_layout()
+plt.show()
+
+###############################################################################
+# IIR filters
+# ===========
+#
+# MNE-Python also offers IIR filtering functionality that is based on the
+# methods from :mod:`scipy.signal`. Specifically, we use the general-purpose
+# functions :func:`scipy.signal.iirfilter` and :func:`scipy.signal.iirdesign`,
+# which provide unified interfaces to IIR filter design.
+#
+# Designing IIR filters
+# ---------------------
+#
+# Let's continue with our design of a 40 Hz low-pass filter, and look at
+# some trade-offs of different IIR filters.
+#
+# Often the default IIR filter is a `Butterworth filter`_, which is designed
+# to have a *maximally flat pass-band*. Let's look at a few filter orders,
+# i.e., a few different numbers of coefficients, and the resulting steepness
+# of the filter:
+
+sos = signal.iirfilter(2, f_p / nyq, btype='low', ftype='butter', output='sos')
+plot_filter(sos, 'Butterworth order=2', freq, gain)
+
+# Eventually this will just be from scipy signal.sosfiltfilt, but 0.18 is
+# not widely adopted yet (as of June 2016), so we use our wrapper...
+sosfiltfilt = mne.fixes.get_sosfiltfilt()
+x_shallow = sosfiltfilt(sos, x)
+
+###############################################################################
+# The falloff of this filter is not very steep.
+#
+# .. warning:: For brevity, we do not show the phase of these filters here.
+# In the FIR case, we can design linear-phase filters, and
+# compensate for the delay (making the filter acausal) if
+# necessary. This cannot be done
+# with IIR filters, as they have a non-linear phase.
+# As the filter order increases, the
+# phase distortion near and in the transition band worsens.
+# However, if acausal (forward-backward) filtering can be used,
+# e.g. with :func:`scipy.signal.filtfilt`, these phase issues
+# can be mitigated.
+#
+# .. note:: Here we have made use of second-order sections (SOS)
+# by using :func:`scipy.signal.sosfilt` and, under the
+# hood, :func:`scipy.signal.zpk2sos` when passing the
+# ``output='sos'`` keyword argument to
+# :func:`scipy.signal.iirfilter`. The filter definitions
+#           given in :ref:`tut_filtering_basics` use the polynomial
+# numerator/denominator (sometimes called "tf") form ``(b, a)``,
+# which are theoretically equivalent to the SOS form used here.
+# In practice, however, the SOS form can give much better results
+# due to issues with numerical precision (see
+# :func:`scipy.signal.sosfilt` for an example), so SOS should be
+# used when possible to do IIR filtering.
+#
+# Let's increase the order, and note that now we have better attenuation,
+# with a longer impulse response:
+
+sos = signal.iirfilter(8, f_p / nyq, btype='low', ftype='butter', output='sos')
+plot_filter(sos, 'Butterworth order=8', freq, gain)
+x_steep = sosfiltfilt(sos, x)
+
+###############################################################################
+# There are other types of IIR filters that we can use. For a complete list,
+# check out the documentation for :func:`scipy.signal.iirdesign`. Let's
+# try a Chebychev (type I) filter, which trades off ripple in the pass-band
+# to get better attenuation in the stop-band:
+
+sos = signal.iirfilter(8, f_p / nyq, btype='low', ftype='cheby1', output='sos',
+ rp=1) # dB of acceptable pass-band ripple
+plot_filter(sos, 'Chebychev-1 order=8, ripple=1 dB', freq, gain)
+
+###############################################################################
+# And if we can live with even more ripple, we can get it slightly steeper,
+# but the impulse response begins to ring substantially longer (note the
+# different x-axis scale):
+
+sos = signal.iirfilter(8, f_p / nyq, btype='low', ftype='cheby1', output='sos',
+ rp=6)
+plot_filter(sos, 'Chebychev-1 order=8, ripple=6 dB', freq, gain)
+
+###############################################################################
+# Applying IIR filters
+# --------------------
+#
+# Now let's look at how our shallow and steep Butterworth IIR filters
+# perform on our Morlet signal from before:
+
+axs = plt.subplots(1, 2)[1]
+yticks = np.arange(4) / -30.
+yticklabels = ['Original', 'Noisy', 'Butterworth-2', 'Butterworth-8']
+plot_signal(x_orig, offset=yticks[0])
+plot_signal(x, offset=yticks[1])
+plot_signal(x_shallow, offset=yticks[2])
+plot_signal(x_steep, offset=yticks[3])
+axs[0].set(xlim=tlim, title='IIR, Lowpass=%d Hz' % f_p, xticks=tticks,
+ ylim=[-0.125, 0.025], yticks=yticks, yticklabels=yticklabels,)
+for text in axs[0].get_yticklabels():
+ text.set(rotation=45, size=8)
+axs[1].set(xlim=flim, ylim=ylim, xlabel='Frequency (Hz)',
+ ylabel='Magnitude (dB)')
+box_off(axs[0])
+box_off(axs[1])
+mne.viz.tight_layout()
+plt.show()
+
+###############################################################################
+# Some pitfalls of filtering
+# ==========================
+#
+# Multiple recent papers have noted potential risks of drawing
+# errant inferences due to misapplication of filters.
+#
+# Low-pass problems
+# -----------------
+#
+# Filters in general, especially those that are acausal (zero-phase), can make
+# activity appear to occur earlier or later than it truly did. As
+# mentioned in VanRullen 2011 [3]_, investigations of commonly (at the time)
+# used low-pass filters created artifacts when they were applied to simulated
+# data. However, such deleterious effects were minimal in many real-world
+# examples in Rousselet 2012 [5]_.
+#
+# Perhaps more revealing, it was noted in Widmann & Schröger 2012 [6]_ that
+# the problematic low-pass filters from VanRullen 2011 [3]_:
+#
+# 1. Used a least-squares design (like :func:`scipy.signal.firls`) that
+# included "do-not-care" transition regions, which can lead to
+# uncontrolled behavior.
+# 2. Had a filter length that was independent of the transition bandwidth,
+# which can cause excessive ringing and signal distortion.
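+#
+# A hedged illustration of point 1 (an editor's sketch, not from the original
+# text; uncomment to try it -- it requires scipy >= 0.18, and the band edges
+# are arbitrary). The (0-40 Hz) and (50-500 Hz) bands are constrained, while
+# the 40-50 Hz gap between them is a "do-not-care" region left to the
+# optimizer:
+#
+# h_ls = signal.firls(101, [0., 40., 50., 500.], [1., 1., 0., 0.], nyq=500.)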
+#
+# .. _tut_filtering_hp_problems:
+#
+# High-pass problems
+# ------------------
+#
+# When it comes to high-pass filtering, corner frequencies above 0.1 Hz were
+# found in Acunzo *et al.* 2012 [4]_ to:
+#
+#    "...generate a systematic bias easily leading to misinterpretations of
+#    neural activity."
+#
+# In a related paper, Widmann *et al.* 2015 [7]_ similarly suggested a 0.1 Hz
+# highpass, and more evidence of such distortions followed in Tanner *et al.*
+# 2015 [8]_. In data from language ERP studies of semantic and syntactic
+# processing (i.e., N400 and P600), using a high-pass above 0.3 Hz caused
+# significant effects to be introduced implausibly early when compared to the
+# unfiltered data. From this, the authors suggested the optimal high-pass
+# value for language processing to be 0.1 Hz.
+#
+# We can recreate a problematic simulation from Tanner *et al.* 2015 [8]_:
+#
+# "The simulated component is a single-cycle cosine wave with an amplitude
+# of 5µV, onset of 500 ms poststimulus, and duration of 800 ms. The
+# simulated component was embedded in 20 s of zero values to avoid
+# filtering edge effects... Distortions [were] caused by 2 Hz low-pass and
+# high-pass filters... No visible distortion to the original waveform
+# [occurred] with 30 Hz low-pass and 0.01 Hz high-pass filters...
+# Filter frequencies correspond to the half-amplitude (-6 dB) cutoff
+# (12 dB/octave roll-off)."
+#
+# .. note:: This simulated signal contains energy not just within the
+# pass-band, but also within the transition and stop-bands -- perhaps
+# most easily understood because the signal has a non-zero DC value,
+# but also because it is a shifted cosine that has been
+# *windowed* (here multiplied by a rectangular window), which
+# makes the cosine and DC frequencies spread to other frequencies
+# (multiplication in time is convolution in frequency, so multiplying
+# by a rectangular window in the time domain means convolving a sinc
+# function with the impulses at DC and the cosine frequency in the
+# frequency domain).
+#
+
+x = np.zeros(int(2 * sfreq))
+t = np.arange(0, len(x)) / sfreq - 0.2
+onset = np.where(t >= 0.5)[0][0]
+cos_t = np.arange(0, int(sfreq * 0.8)) / sfreq
+sig = 2.5 - 2.5 * np.cos(2 * np.pi * (1. / 0.8) * cos_t)
+x[onset:onset + len(sig)] = sig
+
+# note that scipy normalizes digital filter frequencies to the Nyquist rate
+iir_lp_30 = signal.iirfilter(2, 30. / nyq, btype='lowpass')
+iir_hp_p1 = signal.iirfilter(2, 0.1 / nyq, btype='highpass')
+iir_lp_2 = signal.iirfilter(2, 2. / nyq, btype='lowpass')
+iir_hp_2 = signal.iirfilter(2, 2. / nyq, btype='highpass')
+x_lp_30 = signal.filtfilt(iir_lp_30[0], iir_lp_30[1], x, padlen=0)
+x_hp_p1 = signal.filtfilt(iir_hp_p1[0], iir_hp_p1[1], x, padlen=0)
+x_lp_2 = signal.filtfilt(iir_lp_2[0], iir_lp_2[1], x, padlen=0)
+x_hp_2 = signal.filtfilt(iir_hp_2[0], iir_hp_2[1], x, padlen=0)
+
+xlim = t[[0, -1]]
+ylim = [-2, 6]
+xlabel = 'Time (sec)'
+ylabel = r'Amplitude ($\mu$V)'
+tticks = [0, 0.5, 1.3, t[-1]]
+axs = plt.subplots(2, 2)[1].ravel()
+for ax, x_f, title in zip(axs, [x_lp_2, x_lp_30, x_hp_2, x_hp_p1],
+                          ['LP$_2$', 'LP$_{30}$', 'HP$_2$', 'HP$_{0.1}$']):
+ ax.plot(t, x, color='0.5')
+ ax.plot(t, x_f, color='k', linestyle='--')
+ ax.set(ylim=ylim, xlim=xlim, xticks=tticks,
+ title=title, xlabel=xlabel, ylabel=ylabel)
+ box_off(ax)
+mne.viz.tight_layout()
+plt.show()
+
+###############################################################################
+# Similarly, in a P300 paradigm reported by Kappenman & Luck 2010 [12]_,
+# they found that applying a 1 Hz high-pass decreased the probability of
+# finding a significant difference in the N100 response, likely because
+# the P300 response was smeared (and inverted) in time by the high-pass
+# filter such that it tended to cancel out the increased N100. However,
+# they nonetheless note that some high-passing can still be useful to deal
+# with drifts in the data.
+#
+# Even though these papers generally advise a 0.1 Hz or lower frequency for
+# a high-pass, it is important to keep in mind (as most authors note) that
+# filtering choices should depend on the frequency content of both the
+# signal(s) of interest and the noise to be suppressed. For example, in
+# some of the MNE-Python examples involving :ref:`ch_sample_data`,
+# high-pass values of around 1 Hz are used when looking at auditory
+# or visual N100 responses, because we analyze standard (not deviant) trials
+# and thus expect that contamination by later or slower components will
+# be limited.
+#
+# Baseline problems (or solutions?)
+# ---------------------------------
+#
+# In an evolving discussion, Tanner *et al.* 2015 [8]_ suggest using baseline
+# correction to remove slow drifts in data. However, Maess *et al.* 2016 [9]_
+# suggest that baseline correction, which is a form of high-passing, does
+# not offer substantial advantages over standard high-pass filtering.
+# Tanner *et al.* [10]_ countered that baseline correction can correct for
+# problems with filtering.
+#
+# To see what they mean, consider again our old simulated signal ``x`` from
+# before:
+
+
+def baseline_plot(x):
+ all_axs = plt.subplots(3, 2)[1]
+ for ri, (axs, freq) in enumerate(zip(all_axs, [0.1, 0.3, 0.5])):
+ for ci, ax in enumerate(axs):
+ if ci == 0:
+                iir_hp = signal.iirfilter(4, freq / nyq, btype='highpass',
+                                          output='sos')
+ x_hp = sosfiltfilt(iir_hp, x, padlen=0)
+ else:
+ x_hp -= x_hp[t < 0].mean()
+ ax.plot(t, x, color='0.5')
+ ax.plot(t, x_hp, color='k', linestyle='--')
+ if ri == 0:
+ ax.set(title=('No ' if ci == 0 else '') +
+ 'Baseline Correction')
+ box_off(ax)
+ ax.set(xticks=tticks, ylim=ylim, xlim=xlim, xlabel=xlabel)
+ ax.set_ylabel('%0.1f Hz' % freq, rotation=0,
+ horizontalalignment='right')
+ mne.viz.tight_layout()
+    plt.suptitle('High-pass filtering and baseline correction')
+ plt.show()
+
+baseline_plot(x)
+
+###############################################################################
+# In response, Maess *et al.* 2016 [11]_ note that these simulations do not
+# address cases of pre-stimulus activity that is shared across conditions, as
+# applying baseline correction will effectively copy the topology outside the
+# baseline period. We can see this if we give our signal ``x`` some
+# consistent pre-stimulus activity, which makes everything look bad.
+#
+# .. note:: An important thing to keep in mind with these plots is that they
+# are for a single simulated sensor. In multielectrode recordings
+#           the topology (i.e., spatial pattern) of the pre-stimulus activity
+# will leak into the post-stimulus period. This will likely create a
+# spatially varying distortion of the time-domain signals, as the
+# averaged pre-stimulus spatial pattern gets subtracted from the
+# sensor time courses.
+#
+# Putting some activity in the baseline period:
+
+n_pre = (t < 0).sum()
+sig_pre = 1 - np.cos(2 * np.pi * np.arange(n_pre) / (0.5 * n_pre))
+x[:n_pre] += sig_pre
+baseline_plot(x)
+
+###############################################################################
+# Both groups seem to acknowledge that the choices of filtering cutoffs, and
+# perhaps even the application of baseline correction, depend on the
+# characteristics of the data being investigated, especially when it comes to:
+#
+# 1. The frequency content of the underlying evoked activity relative
+# to the filtering parameters.
+# 2. The validity of the assumption of no consistent evoked activity
+# in the baseline period.
+#
+# We thus recommend carefully applying baseline correction and/or high-pass
+# values based on the characteristics of the data to be analyzed.
+#
+#
+# Filtering defaults
+# ==================
+#
+# .. _tut_filtering_in_python:
+#
+# Defaults in MNE-Python
+# ----------------------
+#
+# Most often, filtering in MNE-Python is done at the :class:`mne.io.Raw` level,
+# and thus :func:`mne.io.Raw.filter` is used. This function under the hood
+# (among other things) calls :func:`mne.filter.filter_data` to actually
+# filter the data, which by default applies a zero-phase FIR filter designed
+# using :func:`scipy.signal.firwin2`. In Widmann *et al.* 2015 [7]_, they
+# suggest a specific set of parameters to use for high-pass filtering,
+# including:
+#
+#    "... providing a transition bandwidth of 25% of the lower passband
+#    edge but, where possible, not lower than 2 Hz and otherwise the
+#    distance from the passband edge to the critical frequency."
+#
+# In practice, this means that for each high-pass value ``l_freq`` or
+# low-pass value ``h_freq`` below, you would get this corresponding
+# ``l_trans_bandwidth`` or ``h_trans_bandwidth``, respectively,
+# if the sample rate were 100 Hz (i.e., Nyquist frequency of 50 Hz):
+#
+# +------------------+-------------------+-------------------+
+# | l_freq or h_freq | l_trans_bandwidth | h_trans_bandwidth |
+# +==================+===================+===================+
+# | 0.01 | 0.01 | 2.0 |
+# +------------------+-------------------+-------------------+
+# | 0.1 | 0.1 | 2.0 |
+# +------------------+-------------------+-------------------+
+# | 1.0 | 1.0 | 2.0 |
+# +------------------+-------------------+-------------------+
+# | 2.0 | 2.0 | 2.0 |
+# +------------------+-------------------+-------------------+
+# | 4.0 | 2.0 | 2.0 |
+# +------------------+-------------------+-------------------+
+# | 8.0 | 2.0 | 2.0 |
+# +------------------+-------------------+-------------------+
+# | 10.0 | 2.5 | 2.5 |
+# +------------------+-------------------+-------------------+
+# | 20.0 | 5.0 | 5.0 |
+# +------------------+-------------------+-------------------+
+# | 40.0 | 10.0 | 10.0 |
+# +------------------+-------------------+-------------------+
+# | 45.0 | 11.25 | 5.0 |
+# +------------------+-------------------+-------------------+
+# | 48.0 | 12.0 | 2.0 |
+# +------------------+-------------------+-------------------+
+#
+# MNE-Python has adopted this definition for its high-pass (and low-pass)
+# transition bandwidth choices when using ``l_trans_bandwidth='auto'`` and
+# ``h_trans_bandwidth='auto'``.
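+#
+# The rule behind this table can be sketched as follows (an editor's
+# illustration; ``auto_trans_bandwidth`` is a hypothetical helper, not an
+# actual MNE-Python function):
+#
+# def auto_trans_bandwidth(freq, side, nyq):
+#     bw = max(freq * 0.25, 2.)  # 25% of the passband edge, at least 2 Hz
+#     # ... but never extending past the nearest critical frequency
+#     return min(bw, freq) if side == 'low' else min(bw, nyq - freq)
+#
+# For example, ``auto_trans_bandwidth(0.1, 'low', 50.)`` gives 0.1 and
+# ``auto_trans_bandwidth(45., 'high', 50.)`` gives 5.0, matching the table.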
+#
+# To choose the filter length automatically with ``filter_length='auto'``,
+# the reciprocal of the shortest transition bandwidth is used to ensure
+# decent attenuation at the stop frequency. Specifically, the reciprocal
+# (in samples) is multiplied by 6.2, 6.6, or 11.0 for the Hann, Hamming,
+# or Blackman windows, respectively, as selected by the ``fir_window``
+# argument.
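+#
+# For example (an editor's worked illustration of this rule): with a 10 Hz
+# shortest transition band at a 1000 Hz sample rate and a Hamming window,
+# the reciprocal of the transition bandwidth is 1 / 10 Hz = 0.1 sec = 100
+# samples, giving a filter length of about 6.6 * 100 = 660 samples
+# (0.66 sec), consistent with the ``6.6 / transition_band`` duration used
+# earlier in this tutorial.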
+#
+# .. note:: These multiplicative factors are double what is given in
+# Ifeachor and Jervis [2]_ (p. 357). The window functions have a
+# smearing effect on the frequency response; I&J thus take the
+# approach of setting the stop frequency as
+# :math:`f_s = f_p + f_{trans} / 2.`, but our stated definitions of
+# :math:`f_s` and :math:`f_{trans}` do not
+# allow us to do this in a nice way. Instead, we increase our filter
+# length to achieve acceptable (20+ dB) attenuation by
+# :math:`f_s = f_p + f_{trans}`, and excellent (50+ dB)
+# attenuation by :math:`f_s + f_{trans}` (and usually earlier).
+#
+# In 0.14, we default to using a Hamming window in filter design, as it
+# provides up to 53 dB of stop-band attenuation with small pass-band ripple.
+#
+# .. note:: In band-pass applications, often a low-pass filter can operate
+# effectively with fewer samples than the high-pass filter, so
+# it is advisable to apply the high-pass and low-pass separately.
+#
+# For more information on how to use the
+# MNE-Python filtering functions with real data, consult the preprocessing
+# tutorial on :ref:`tut_artifacts_filter`.
+#
+# Defaults in MNE-C
+# -----------------
+# MNE-C by default uses:
+#
+# 1. 5 Hz transition band for low-pass filters.
+# 2. 3-sample transition band for high-pass filters.
+# 3. Filter length of 8197 samples.
+#
+# The filter is designed in the frequency domain, creating a linear-phase
+# filter such that the delay is compensated for as is done with the MNE-Python
+# ``phase='zero'`` filtering option.
+#
+# Squared-cosine ramps are used in the transition regions. Because these
+# are used in place of more gradual (e.g., linear) transitions,
+# a given transition width will result in more temporal ringing but also more
+# rapid attenuation than the same transition width in windowed FIR designs.
+#
+# The default filter length will generally have excellent attenuation
+# but long ringing for the sample rates typically encountered in M/EEG data
+# (e.g. 500-2000 Hz).
+#
+# Defaults in other software
+# --------------------------
+# A good but possibly outdated comparison of filtering in various software
+# packages is available in [7]_. Briefly:
+#
+# * EEGLAB
+#     MNE-Python in 0.14 defaults to behavior very similar to that of EEGLAB;
+#     see the `EEGLAB filtering FAQ`_ for more information.
+# * FieldTrip
+#     By default FieldTrip applies a forward-backward Butterworth IIR filter
+#     of order 4 (band-pass and band-stop filters) or 2 (for low-pass and
+#     high-pass filters). Similar filters can be achieved in MNE-Python when
+#     filtering with :meth:`raw.filter(..., method='iir') <mne.io.Raw.filter>`
+#     (see also :func:`mne.filter.construct_iir_filter` for options), as
+#     sketched below. For more information, see e.g. the
+#     `FieldTrip band-pass documentation`_.
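+#
+# A minimal sketch of such a FieldTrip-like filter in MNE-Python (an
+# editor's illustration; the cutoffs are arbitrary placeholders, and no
+# ``raw`` object exists in this tutorial, so the code is left commented out):
+#
+# iir_params = dict(order=4, ftype='butter')
+# raw.filter(l_freq=1., h_freq=40., method='iir', iir_params=iir_params)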
+#
+# Summary
+# =======
+#
+# When filtering, there are always tradeoffs that should be considered.
+# One important tradeoff is between time-domain characteristics (like ringing)
+# and frequency-domain attenuation characteristics (like effective transition
+# bandwidth). Filters with sharp frequency cutoffs can produce outputs that
+# ring for a long time when they operate on signals with frequency content
+# in the transition band. In general, therefore, the wider a transition band
+# that can be tolerated, the better behaved the filter will be in the time
+# domain.
+#
+# References
+# ==========
+#
+# .. [1] Parks TW, Burrus CS (1987). Digital Filter Design.
+# New York: Wiley-Interscience.
+# .. [2] Ifeachor, E. C., & Jervis, B. W. (2002). Digital Signal Processing:
+# A Practical Approach. Prentice Hall.
+# .. [3] Vanrullen, R. (2011). Four common conceptual fallacies in mapping
+# the time course of recognition. Perception Science, 2, 365.
+# .. [4] Acunzo, D. J., MacKenzie, G., & van Rossum, M. C. W. (2012).
+# Systematic biases in early ERP and ERF components as a result
+# of high-pass filtering. Journal of Neuroscience Methods,
+# 209(1), 212–218. http://doi.org/10.1016/j.jneumeth.2012.06.011
+# .. [5] Rousselet, G. A. (2012). Does filtering preclude us from studying
+# ERP time-courses? Frontiers in Psychology, 3(131)
+# .. [6] Widmann, A., & Schröger, E. (2012). Filter effects and filter
+# artifacts in the analysis of electrophysiological data.
+# Perception Science, 233.
+# .. [7] Widmann, A., Schröger, E., & Maess, B. (2015). Digital filter
+# design for electrophysiological data – a practical approach.
+# Journal of Neuroscience Methods, 250, 34–46.
+# .. [8] Tanner, D., Morgan-Short, K., & Luck, S. J. (2015).
+# How inappropriate high-pass filters can produce artifactual effects
+# and incorrect conclusions in ERP studies of language and cognition.
+# Psychophysiology, 52(8), 997–1009. http://doi.org/10.1111/psyp.12437
+# .. [9] Maess, B., Schröger, E., & Widmann, A. (2016).
+# High-pass filters and baseline correction in M/EEG analysis.
+# Commentary on: “How inappropriate high-pass filters can produce
+# artefacts and incorrect conclusions in ERP studies of language
+# and cognition.” Journal of Neuroscience Methods, 266, 164–165.
+# .. [10] Tanner, D., Norton, J. J. S., Morgan-Short, K., & Luck, S. J. (2016).
+# On high-pass filter artifacts (they're real) and baseline correction
+# (it's a good idea) in ERP/ERMF analysis.
+# Journal of Neuroscience Methods, 266, 166–170.
+# .. [11] Maess, B., Schröger, E., & Widmann, A. (2016).
+# High-pass filters and baseline correction in M/EEG analysis-continued
+# discussion. Journal of Neuroscience Methods, 266, 171–172.
+# .. [12] Kappenman E. & Luck, S. (2010). The effects of impedance on data
+# quality and statistical significance in ERP recordings.
+# Psychophysiology, 47, 888-904.
+#
+# .. _FIR: https://en.wikipedia.org/wiki/Finite_impulse_response
+# .. _IIR: https://en.wikipedia.org/wiki/Infinite_impulse_response
+# .. _sinc: https://en.wikipedia.org/wiki/Sinc_function
+# .. _moving average: https://en.wikipedia.org/wiki/Moving_average
+# .. _autoregression: https://en.wikipedia.org/wiki/Autoregressive_model
+# .. _Remez: https://en.wikipedia.org/wiki/Remez_algorithm
+# .. _matlab firpm: http://www.mathworks.com/help/signal/ref/firpm.html
+# .. _matlab fir2: http://www.mathworks.com/help/signal/ref/fir2.html
+# .. _matlab firls: http://www.mathworks.com/help/signal/ref/firls.html
+# .. _Butterworth filter: https://en.wikipedia.org/wiki/Butterworth_filter
+# .. _eeglab filtering faq: https://sccn.ucsd.edu/wiki/Firfilt_FAQ
+# .. _fieldtrip band-pass documentation: http://www.fieldtriptoolbox.org/reference/ft_preproc_bandpassfilter # noqa
diff --git a/tutorials/plot_brainstorm_auditory.py b/tutorials/plot_brainstorm_auditory.py
index a2adb75..ecc7196 100644
--- a/tutorials/plot_brainstorm_auditory.py
+++ b/tutorials/plot_brainstorm_auditory.py
@@ -5,18 +5,18 @@ Brainstorm auditory tutorial dataset
====================================
Here we compute the evoked from raw for the auditory Brainstorm
-tutorial dataset. For comparison, see [1]_ and
-http://neuroimage.usc.edu/brainstorm/Tutorials/Auditory
+tutorial dataset. For comparison, see [1]_ and:
+
+ http://neuroimage.usc.edu/brainstorm/Tutorials/Auditory
Experiment:
- - One subject 2 acquisition runs 6 minutes each.
+ - One subject, 2 acquisition runs of 6 minutes each.
- Each run contains 200 regular beeps and 40 easy deviant beeps.
- Random ISI: between 0.7s and 1.7s, uniformly distributed.
- Button pressed when detecting a deviant with the right index finger.
-The specifications of this dataset were discussed initially on the FieldTrip
-bug tracker:
-http://bugzilla.fcdonders.nl/show_bug.cgi?id=2300
+The specifications of this dataset were discussed initially on the
+`FieldTrip bug tracker <http://bugzilla.fcdonders.nl/show_bug.cgi?id=2300>`_.
References
----------
@@ -82,6 +82,7 @@ raw_erm = read_raw_ctf(erm_fname, preload=preload)
# Data channel array consisted of 274 MEG axial gradiometers, 26 MEG reference
# sensors and 2 EEG electrodes (Cz and Pz).
# In addition:
+#
# - 1 stim channel for marking presentation times for the stimuli
# - 1 audio channel for the sent signal
# - 1 response channel for recording the button presses
@@ -89,6 +90,7 @@ raw_erm = read_raw_ctf(erm_fname, preload=preload)
# - 2 EOG bipolar (vertical and horizontal)
# - 12 head tracking channels
# - 20 unused channels
+#
# The head tracking channels and the unused channels are marked as misc
# channels. Here we define the EOG and ECG channels.
raw.set_channel_types({'HEOG': 'eog', 'VEOG': 'eog', 'ECG': 'ecg'})
@@ -174,7 +176,8 @@ if not use_precomputed:
###############################################################################
# We also low-pass filter the data at 100 Hz to remove the high-frequency components.
if not use_precomputed:
- raw.filter(None, 100.)
+ raw.filter(None, 100., h_trans_bandwidth=0.5, filter_length='10s',
+ phase='zero-double')
###############################################################################
# Epoching and averaging.
@@ -284,7 +287,7 @@ evoked_dev.plot_topomap(times=times, title='Deviant')
# We can see the MMN effect more clearly by looking at the difference between
# the two conditions. P50 and N100 are no longer visible, but MMN/P200 and
# P300 are emphasised.
-evoked_difference = combine_evoked([evoked_dev, evoked_std], weights=[1, -1])
+evoked_difference = combine_evoked([evoked_dev, -evoked_std], weights='equal')
evoked_difference.plot(window_title='Difference', gfp=True)
###############################################################################
@@ -310,7 +313,7 @@ trans = mne.read_trans(trans_fname)
# forward solution from scratch. The head surfaces for constructing a BEM
# solution are read from a file. Since the data only contains MEG channels, we
# only need the inner skull surface for making the forward solution. For more
-# information: :ref:`CHDBBCEJ`, :class:`mne.setup_source_space`,
+# information: :ref:`CHDBBCEJ`, :func:`mne.setup_source_space`,
# :ref:`create_bem_model`, :func:`mne.bem.make_watershed_bem`.
if use_precomputed:
fwd_fname = op.join(data_path, 'MEG', 'bst_auditory',
@@ -337,21 +340,21 @@ del fwd
# Standard condition.
stc_standard = mne.minimum_norm.apply_inverse(evoked_std, inv, lambda2, 'dSPM')
brain = stc_standard.plot(subjects_dir=subjects_dir, subject=subject,
- surface='inflated', time_viewer=False, hemi='lh')
-brain.set_data_time_index(120)
-del stc_standard, evoked_std, brain
+ surface='inflated', time_viewer=False, hemi='lh',
+ initial_time=0.1, time_unit='s')
+del stc_standard, brain
###############################################################################
# Deviant condition.
stc_deviant = mne.minimum_norm.apply_inverse(evoked_dev, inv, lambda2, 'dSPM')
brain = stc_deviant.plot(subjects_dir=subjects_dir, subject=subject,
- surface='inflated', time_viewer=False, hemi='lh')
-brain.set_data_time_index(120)
-del stc_deviant, evoked_dev, brain
+ surface='inflated', time_viewer=False, hemi='lh',
+ initial_time=0.1, time_unit='s')
+del stc_deviant, brain
###############################################################################
# Difference.
stc_difference = apply_inverse(evoked_difference, inv, lambda2, 'dSPM')
brain = stc_difference.plot(subjects_dir=subjects_dir, subject=subject,
- surface='inflated', time_viewer=False, hemi='lh')
-brain.set_data_time_index(150)
+ surface='inflated', time_viewer=False, hemi='lh',
+ initial_time=0.15, time_unit='s')
diff --git a/tutorials/plot_brainstorm_phantom_ctf.py b/tutorials/plot_brainstorm_phantom_ctf.py
new file mode 100644
index 0000000..70daf66
--- /dev/null
+++ b/tutorials/plot_brainstorm_phantom_ctf.py
@@ -0,0 +1,112 @@
+# -*- coding: utf-8 -*-
+"""
+=======================================
+Brainstorm CTF phantom tutorial dataset
+=======================================
+
+Here we compute the evoked from raw for the Brainstorm CTF phantom
+tutorial dataset. For comparison, see [1]_ and:
+
+ http://neuroimage.usc.edu/brainstorm/Tutorials/PhantomCtf
+
+References
+----------
+.. [1] Tadel F, Baillet S, Mosher JC, Pantazis D, Leahy RM.
+ Brainstorm: A User-Friendly Application for MEG/EEG Analysis.
+ Computational Intelligence and Neuroscience, vol. 2011, Article ID
+ 879716, 13 pages, 2011. doi:10.1155/2011/879716
+"""
+
+# Authors: Eric Larson <larson.eric.d at gmail.com>
+#
+# License: BSD (3-clause)
+
+import os.path as op
+import numpy as np
+import matplotlib.pyplot as plt
+
+import mne
+from mne import fit_dipole
+from mne.datasets.brainstorm import bst_phantom_ctf
+from mne.io import read_raw_ctf
+
+print(__doc__)
+
+###############################################################################
+# The data were collected with a CTF system at 2400 Hz.
+data_path = bst_phantom_ctf.data_path()
+
+# Switch to these to use the higher-SNR data:
+# raw_path = op.join(data_path, 'phantom_200uA_20150709_01.ds')
+# dip_freq = 7.
+raw_path = op.join(data_path, 'phantom_20uA_20150603_03.ds')
+dip_freq = 23.
+erm_path = op.join(data_path, 'emptyroom_20150709_01.ds')
+raw = read_raw_ctf(raw_path, preload=True)
+
+###############################################################################
+# The sinusoidal signal is generated on channel HDAC006, so we can use
+# that to obtain precise timing.
+
+sinusoid, times = raw[raw.ch_names.index('HDAC006-4408')]
+plt.figure()
+plt.plot(times[times < 1.], sinusoid.T[times < 1.])
+
+###############################################################################
+# Let's create some events using this signal by thresholding the sinusoid.
+
+events = np.where(np.diff(sinusoid > 0.5) > 0)[1] + raw.first_samp
+events = np.vstack((events, np.zeros_like(events), np.ones_like(events))).T
+
+###############################################################################
+# The CTF software compensation works reasonably well:
+
+raw.plot()
+
+###############################################################################
+# But here we can get slightly better noise suppression, lower localization
+# bias, and a better dipole goodness of fit with spatio-temporal (tSSS)
+# Maxwell filtering:
+
+raw.apply_gradient_compensation(0)  # must undo software compensation first
+mf_kwargs = dict(origin=(0., 0., 0.), st_duration=10.)
+raw = mne.preprocessing.maxwell_filter(raw, **mf_kwargs)
+raw.plot()
+
+###############################################################################
+# Our choice of tmin and tmax should capture exactly one cycle, so
+# we can make the unusual choice of baselining using the entire epoch
+# when creating our evoked data. We also then crop to a single time point
+# (at t=0) because this is a peak in our signal.
+
+tmin = -0.5 / dip_freq
+tmax = -tmin
+epochs = mne.Epochs(raw, events, event_id=1, tmin=tmin, tmax=tmax,
+ baseline=(None, None))
+evoked = epochs.average()
+evoked.plot()
+evoked.crop(0., 0.)
+del raw, epochs
+
+###############################################################################
+# To do a dipole fit, let's use the covariance provided by the empty room
+# recording.
+
+raw_erm = read_raw_ctf(erm_path).apply_gradient_compensation(0)
+raw_erm = mne.preprocessing.maxwell_filter(raw_erm, coord_frame='meg',
+ **mf_kwargs)
+cov = mne.compute_raw_covariance(raw_erm)
+del raw_erm
+sphere = mne.make_sphere_model(r0=(0., 0., 0.), head_radius=None)
+dip = fit_dipole(evoked, cov, sphere)[0]
+
+###############################################################################
+# Compare the actual position with the estimated one.
+
+expected_pos = np.array([18., 0., 49.])
+diff = np.sqrt(np.sum((dip.pos[0] * 1000 - expected_pos) ** 2))
+print('Actual pos: %s mm' % np.array_str(expected_pos, precision=1))
+print('Estimated pos: %s mm' % np.array_str(dip.pos[0] * 1000, precision=1))
+print('Difference: %0.1f mm' % diff)
+print('Amplitude: %0.1f nAm' % (1e9 * dip.amplitude[0]))
+print('GOF: %0.1f %%' % dip.gof[0])
diff --git a/tutorials/plot_brainstorm_phantom_elekta.py b/tutorials/plot_brainstorm_phantom_elekta.py
new file mode 100644
index 0000000..75d3713
--- /dev/null
+++ b/tutorials/plot_brainstorm_phantom_elekta.py
@@ -0,0 +1,106 @@
+# -*- coding: utf-8 -*-
+"""
+==========================================
+Brainstorm Elekta phantom tutorial dataset
+==========================================
+
+Here we compute the evoked from raw for the Brainstorm Elekta phantom
+tutorial dataset. For comparison, see [1]_ and:
+
+ http://neuroimage.usc.edu/brainstorm/Tutorials/PhantomElekta
+
+References
+----------
+.. [1] Tadel F, Baillet S, Mosher JC, Pantazis D, Leahy RM.
+ Brainstorm: A User-Friendly Application for MEG/EEG Analysis.
+ Computational Intelligence and Neuroscience, vol. 2011, Article ID
+ 879716, 13 pages, 2011. doi:10.1155/2011/879716
+"""
+
+# Authors: Eric Larson <larson.eric.d at gmail.com>
+#
+# License: BSD (3-clause)
+
+import os.path as op
+import numpy as np
+
+import mne
+from mne import find_events, fit_dipole
+from mne.datasets.brainstorm import bst_phantom_elekta
+from mne.io import read_raw_fif
+
+print(__doc__)
+
+###############################################################################
+# The data were collected with an Elekta Neuromag VectorView system at 1000 Hz
+# and low-pass filtered at 330 Hz. Here the medium-amplitude (200 nAm) data
+# are read to construct instances of :class:`mne.io.Raw`.
+data_path = bst_phantom_elekta.data_path()
+
+raw_fname = op.join(data_path, 'kojak_all_200nAm_pp_no_chpi_no_ms_raw.fif')
+raw = read_raw_fif(raw_fname, add_eeg_ref=False)
+
+###############################################################################
+# Data channel array consisted of 204 MEG planar gradiometers,
+# 102 axial magnetometers, and 3 stimulus channels. Let's get the events
+# for the phantom, where each dipole (1-32) gets its own event:
+
+events = find_events(raw, 'STI201')
+raw.plot(events=events)
+raw.info['bads'] = ['MEG2421']
+
+###############################################################################
+# The data have strong line frequency (60 Hz and harmonics) and cHPI coil
+# noise (five peaks around 300 Hz). Here we plot only out to 60 seconds
+# to save memory:
+
+raw.plot_psd(tmax=60.)
+
+###############################################################################
+# Let's use Maxwell filtering to clean the data a bit.
+# Ideally we would have the fine calibration and cross-talk information
+# for the site of interest, but we don't, so we just do:
+
+raw.fix_mag_coil_types()
+raw = mne.preprocessing.maxwell_filter(raw, origin=(0., 0., 0.))
+
+###############################################################################
+# We know our phantom produces sinusoidal bursts below 25 Hz, so let's filter.
+
+raw.filter(None, 40., h_trans_bandwidth='auto', filter_length='auto',
+ phase='zero')
+raw.plot(events=events)
+
+###############################################################################
+# Now we epoch our data, average it, and look at the first dipole response.
+# The first peak appears around 3 ms. Because we low-passed at 40 Hz,
+# we can also decimate our data to save memory.
+
+tmin, tmax = -0.1, 0.1
+event_id = list(range(1, 33))
+epochs = mne.Epochs(raw, events, event_id, tmin, tmax, baseline=(None, -0.01),
+ decim=5, preload=True, add_eeg_ref=False)
+epochs['1'].average().plot()
+
+###############################################################################
+# Let's do some dipole fits. The phantom is properly modeled by a single-shell
+# sphere with origin (0., 0., 0.). We compute covariance, then do the fits.
+
+t_peak = 60e-3  # ~60 ms at the largest peak
+sphere = mne.make_sphere_model(r0=(0., 0., 0.), head_radius=None)
+cov = mne.compute_covariance(epochs, tmax=0)
+data = []
+for ii in range(1, 33):
+ evoked = epochs[str(ii)].average().crop(t_peak, t_peak)
+ data.append(evoked.data[:, 0])
+evoked = mne.EvokedArray(np.array(data).T, evoked.info, tmin=0.)
+del epochs, raw
+dip = fit_dipole(evoked, cov, sphere, n_jobs=1)[0]
+
+###############################################################################
+# Now we can compare to the actual locations, taking the difference in mm:
+
+actual_pos = mne.dipole.get_phantom_dipoles(kind='122')[0]
+diffs = 1000 * np.sqrt(np.sum((dip.pos - actual_pos) ** 2, axis=-1))
+print('Differences (mm):\n%s' % diffs[:, np.newaxis])
+print('μ = %s' % (np.mean(diffs),))
diff --git a/tutorials/plot_compute_covariance.py b/tutorials/plot_compute_covariance.py
index 55ebd7f..a928568 100644
--- a/tutorials/plot_compute_covariance.py
+++ b/tutorials/plot_compute_covariance.py
@@ -17,9 +17,10 @@ from mne.datasets import sample
data_path = sample.data_path()
raw_empty_room_fname = op.join(
data_path, 'MEG', 'sample', 'ernoise_raw.fif')
-raw_empty_room = mne.io.read_raw_fif(raw_empty_room_fname)
+raw_empty_room = mne.io.read_raw_fif(raw_empty_room_fname, add_eeg_ref=False)
raw_fname = op.join(data_path, 'MEG', 'sample', 'sample_audvis_raw.fif')
-raw = mne.io.read_raw_fif(raw_fname)
+raw = mne.io.read_raw_fif(raw_fname, add_eeg_ref=False)
+raw.set_eeg_reference()
raw.info['bads'] += ['EEG 053'] # bads + 1 more
###############################################################################
@@ -135,4 +136,4 @@ evoked.plot_white(covs)
# vol. 108, 328-342, NeuroImage.
#
# .. [2] Taulu, S., Simola, J., Kajola, M., 2005. Applications of the signal
-# space separation method. IEEE Trans. Signal Proc. 53, 3359–3372.
+# space separation method. IEEE Trans. Signal Proc. 53, 3359-3372.
diff --git a/tutorials/plot_dipole_fit.py b/tutorials/plot_dipole_fit.py
index 54d8856..811a1ab 100644
--- a/tutorials/plot_dipole_fit.py
+++ b/tutorials/plot_dipole_fit.py
@@ -71,7 +71,7 @@ pred_evoked.plot_topomap(time_format='Predicted field', axes=axes[1],
**plot_params)
# Subtract predicted from measured data (apply equal weights)
-diff = combine_evoked([evoked, pred_evoked], [1, -1])
+diff = combine_evoked([evoked, -pred_evoked], weights='equal')
plot_params['colorbar'] = True
diff.plot_topomap(time_format='Difference', axes=axes[2], **plot_params)
plt.suptitle('Comparison of measured and predicted fields '
diff --git a/tutorials/plot_eeg_erp.py b/tutorials/plot_eeg_erp.py
index c22e168..02003dc 100644
--- a/tutorials/plot_eeg_erp.py
+++ b/tutorials/plot_eeg_erp.py
@@ -7,6 +7,7 @@ EEG processing and Event Related Potentials (ERPs)
For a generic introduction to the computation of ERP and ERF
see :ref:`tut_epoching_and_averaging`. Here we cover the specifics
of EEG, namely:
+
- setting the reference
- using standard montages :class:`mne.channels.Montage`
- Evoked arithmetic (e.g. differences)
@@ -21,7 +22,8 @@ from mne.datasets import sample
data_path = sample.data_path()
raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
event_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif'
-raw = mne.io.read_raw_fif(raw_fname, add_eeg_ref=True, preload=True)
+raw = mne.io.read_raw_fif(raw_fname, add_eeg_ref=False, preload=True)
+raw.set_eeg_reference() # set EEG average reference
###############################################################################
# Let's restrict the data to the EEG channels
@@ -97,7 +99,7 @@ reject = dict(eeg=180e-6, eog=150e-6)
event_id, tmin, tmax = {'left/auditory': 1}, -0.2, 0.5
events = mne.read_events(event_fname)
epochs_params = dict(events=events, event_id=event_id, tmin=tmin, tmax=tmax,
- reject=reject)
+ reject=reject, add_eeg_ref=False)
evoked_no_ref = mne.Epochs(raw_no_ref, **epochs_params).average()
del raw_no_ref # save memory
@@ -151,13 +153,14 @@ print(epochs)
left, right = epochs["left"].average(), epochs["right"].average()
-(left - right).plot_joint() # create and plot difference ERP
+# create and plot difference ERP
+mne.combine_evoked([left, -right], weights='equal').plot_joint()
###############################################################################
-# Note that by default, this is a trial-weighted average. If you have
-# imbalanced trial numbers, consider either equalizing the number of events per
-# condition (using ``Epochs.equalize_event_counts``), or the ``combine_evoked``
-# function.
+# This is an equal-weighting difference. If you have imbalanced trial numbers,
+# you could also consider equalizing the number of events per condition using
+# :meth:`epochs.equalize_event_counts <mne.Epochs.equalize_event_counts>`.
# As an example, first, we create individual ERPs for each condition.
aud_l = epochs["auditory", "left"].average()
@@ -166,12 +169,17 @@ vis_l = epochs["visual", "left"].average()
vis_r = epochs["visual", "right"].average()
all_evokeds = [aud_l, aud_r, vis_l, vis_r]
+print(all_evokeds)
-# This could have been much simplified with a list comprehension:
-# all_evokeds = [epochs[cond] for cond in event_id]
+###############################################################################
+# This can be simplified with a Python list comprehension:
+all_evokeds = [epochs[cond].average() for cond in sorted(event_id.keys())]
+print(all_evokeds)
-# Then, we construct and plot an unweighted average of left vs. right trials.
-mne.combine_evoked(all_evokeds, weights=(1, -1, 1, -1)).plot_joint()
+# Then, we construct and plot an unweighted average of left vs. right trials
+# this way, too:
+mne.combine_evoked(all_evokeds,
+ weights=(0.25, -0.25, 0.25, -0.25)).plot_joint()
###############################################################################
# Often, it makes sense to store Evoked objects in a dictionary or a list -
diff --git a/tutorials/plot_epoching_and_averaging.py b/tutorials/plot_epoching_and_averaging.py
index 22cf3d4..4248fcf 100644
--- a/tutorials/plot_epoching_and_averaging.py
+++ b/tutorials/plot_epoching_and_averaging.py
@@ -17,7 +17,8 @@ import mne
# First let's read in the raw sample data.
data_path = mne.datasets.sample.data_path()
fname = op.join(data_path, 'MEG', 'sample', 'sample_audvis_raw.fif')
-raw = mne.io.read_raw_fif(fname)
+raw = mne.io.read_raw_fif(fname, add_eeg_ref=False)
+raw.set_eeg_reference() # set EEG average reference
###############################################################################
# To create time locked epochs, we first need a set of events that contain the
@@ -109,7 +110,7 @@ picks = mne.pick_types(raw.info, meg=True, eeg=False, eog=True)
baseline = (None, 0.0)
reject = {'mag': 4e-12, 'eog': 200e-6}
epochs = mne.Epochs(raw, events=events, event_id=event_id, tmin=tmin,
- tmax=tmax, reject=reject, picks=picks)
+ tmax=tmax, reject=reject, picks=picks, add_eeg_ref=False)
###############################################################################
# Let's plot the epochs to see the results. The number at the top refers to the
@@ -150,7 +151,7 @@ evoked_right = epochs['Auditory/Right'].average(picks=picks)
epochs_left = epochs['Left']
# ... or to select a very specific subset. This is the same as above:
-evoked_left = epochs['Auditory', 'Left'].average(picks=picks)
+evoked_left = epochs['Left/Auditory'].average(picks=picks)
###############################################################################
# Finally, let's plot the evoked responses.
diff --git a/tutorials/plot_epochs_to_data_frame.py b/tutorials/plot_epochs_to_data_frame.py
index 0c8beca..e99d743 100644
--- a/tutorials/plot_epochs_to_data_frame.py
+++ b/tutorials/plot_epochs_to_data_frame.py
@@ -102,7 +102,8 @@ data_path = sample.data_path()
raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
event_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif'
-raw = mne.io.read_raw_fif(raw_fname)
+raw = mne.io.read_raw_fif(raw_fname, add_eeg_ref=False)
+raw.set_eeg_reference() # set EEG average reference
# For simplicity we will only consider the first 10 epochs
events = mne.read_events(event_fname)[:10]
@@ -119,7 +120,8 @@ reject = dict(grad=4000e-13, eog=150e-6)
event_id = dict(auditory_l=1, auditory_r=2, visual_l=3, visual_r=4)
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True, picks=picks,
- baseline=baseline, preload=True, reject=reject)
+ baseline=baseline, preload=True, reject=reject,
+ add_eeg_ref=False)
###############################################################################
# Export DataFrame
@@ -203,7 +205,7 @@ print(max_latency)
df.condition = df.condition.apply(lambda name: name + ' ')
plt.figure()
-max_latency.plot(kind='barh', title='Latency of Maximum Reponse',
+max_latency.plot(kind='barh', title='Latency of Maximum Response',
color=['steelblue'])
mne.viz.tight_layout()
diff --git a/tutorials/plot_forward.py b/tutorials/plot_forward.py
index 3f542b0..a4d2554 100644
--- a/tutorials/plot_forward.py
+++ b/tutorials/plot_forward.py
@@ -27,7 +27,8 @@ subject = 'sample'
# ------------------------------
#
# To compute a forward operator we need:
-# - a -trans.fif file that contains the coregistration info.
+#
+# - a ``-trans.fif`` file that contains the coregistration info.
# - a source space
# - the BEM surfaces
@@ -56,7 +57,7 @@ subject = 'sample'
# reconstruction the necessary files.
mne.viz.plot_bem(subject=subject, subjects_dir=subjects_dir,
- orientation='coronal')
+ brain_surfaces='white', orientation='coronal')
###############################################################################
# Visualizing the coregistration
@@ -99,9 +100,15 @@ print(src)
###############################################################################
# src contains two parts, one for the left hemisphere (4098 locations) and
-# one for the right hemisphere (4098 locations).
-#
-# Let's write a few lines of mayavi to see what it contains
+# one for the right hemisphere (4098 locations). Sources can be visualized on
+# top of the BEM surfaces.
+
+mne.viz.plot_bem(subject=subject, subjects_dir=subjects_dir,
+ brain_surfaces='white', src=src, orientation='coronal')
+
+###############################################################################
+# However, only sources that lie in the plotted MRI slices are shown.
+# Let's write a few lines of mayavi to see all sources.
import numpy as np # noqa
from mayavi import mlab # noqa
@@ -171,6 +178,7 @@ print("Leadfield size : %d sensors x %d dipoles" % leadfield.shape)
# By looking at :ref:`sphx_glr_auto_examples_forward_plot_read_forward.py`
# plot the sensitivity maps for EEG and compare it with the MEG, can you
# justify the claims that:
+#
# - MEG is not sensitive to radial sources
# - EEG is more sensitive to deep sources
#
diff --git a/tutorials/plot_ica_from_raw.py b/tutorials/plot_ica_from_raw.py
index 1b61783..4b466ad 100644
--- a/tutorials/plot_ica_from_raw.py
+++ b/tutorials/plot_ica_from_raw.py
@@ -26,8 +26,9 @@ from mne.datasets import sample
data_path = sample.data_path()
raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
-raw = mne.io.read_raw_fif(raw_fname, preload=True)
-raw.filter(1, 45, n_jobs=1)
+raw = mne.io.read_raw_fif(raw_fname, preload=True, add_eeg_ref=False)
+raw.filter(1, 45, n_jobs=1, l_trans_bandwidth=0.5, h_trans_bandwidth=0.5,
+ filter_length='10s', phase='zero-double')
###############################################################################
# 1) Fit ICA model using the FastICA algorithm
diff --git a/tutorials/plot_info.py b/tutorials/plot_info.py
index e291b70..031c9f8 100644
--- a/tutorials/plot_info.py
+++ b/tutorials/plot_info.py
@@ -15,7 +15,7 @@ import os.path as op
# when data is imported into MNE-Python and contains details such as:
#
# - date, subject information, and other recording details
-# - the samping rate
+# - the sampling rate
# - information about the data channels (name, type, position, etc.)
# - digitized points
# - sensor–head coordinate transformation matrices
diff --git a/tutorials/plot_introduction.py b/tutorials/plot_introduction.py
index 641dcb4..67dc9d9 100644
--- a/tutorials/plot_introduction.py
+++ b/tutorials/plot_introduction.py
@@ -5,6 +5,8 @@
Basic MEG and EEG data processing
=================================
+.. image:: http://mne-tools.github.io/stable/_static/mne_logo.png
+
MNE-Python reimplements most of MNE-C's (the original MNE command line utils)
functionality and offers transparent scripting.
On top of that it extends MNE-C's functionality considerably
@@ -79,12 +81,15 @@ From raw data to evoked data
.. _ipython: http://ipython.scipy.org/
-Now, launch `ipython`_ (Advanced Python shell) using the QT backend which best
-supported across systems::
+Now, launch `ipython`_ (Advanced Python shell) using the QT backend, which
+is best supported across systems::
$ ipython --matplotlib=qt
First, load the mne package:
+
+.. note:: In IPython, you can press **shift-enter** with a given cell
+ selected to execute it and advance to the next cell:
"""
import mne
@@ -103,7 +108,7 @@ mne.set_log_level('INFO')
# You can set the default level by setting the environment variable
# "MNE_LOGGING_LEVEL", or by having mne-python write preferences to a file:
-mne.set_config('MNE_LOGGING_LEVEL', 'WARNING')
+mne.set_config('MNE_LOGGING_LEVEL', 'WARNING', set_env=True)
##############################################################################
# Note that the location of the mne-python preferences file (for easier manual
@@ -129,7 +134,7 @@ print(raw_fname)
#
# Read data from file:
-raw = mne.io.read_raw_fif(raw_fname)
+raw = mne.io.read_raw_fif(raw_fname, add_eeg_ref=False)
print(raw)
print(raw.info)
@@ -170,7 +175,7 @@ print(events[:5])
# system (e.g., a newer system that uses channel 'STI101' by default), you can
# use the following to set the default stim channel to use for finding events:
-mne.set_config('MNE_STIM_CHANNEL', 'STI101')
+mne.set_config('MNE_STIM_CHANNEL', 'STI101', set_env=True)
##############################################################################
# Events are stored as a 2D numpy array where the first column is the time
@@ -217,7 +222,8 @@ reject = dict(grad=4000e-13, mag=4e-12, eog=150e-6)
# Read epochs:
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True, picks=picks,
- baseline=baseline, preload=False, reject=reject)
+ baseline=baseline, preload=False, reject=reject,
+ add_eeg_ref=False)
print(epochs)
##############################################################################
@@ -276,9 +282,36 @@ evoked2 = mne.read_evokeds(
evoked_fname, condition='Right Auditory', baseline=(None, 0), proj=True)
##############################################################################
-# Compute a contrast:
+# Two evoked objects can be contrasted using :func:`mne.combine_evoked`.
+# This function can use ``weights='equal'``, which provides a simple
+# element-by-element subtraction (and sets the
+# :attr:`mne.Evoked.nave` attribute properly based on the underlying number
+# of trials). The following two calls are equivalent:
+
+contrast = mne.combine_evoked([evoked1, evoked2], weights=[0.5, -0.5])
+contrast = mne.combine_evoked([evoked1, -evoked2], weights='equal')
+print(contrast)
+
+##############################################################################
+# To do a weighted sum based on the number of averages, which will give
+# you what you would have gotten from pooling all trials together in
+# :class:`mne.Epochs` before creating the :class:`mne.Evoked` instance,
+# you can use ``weights='nave'``:
+
+average = mne.combine_evoked([evoked1, evoked2], weights='nave')
+print(average)
+
+##############################################################################
+# Instead of dealing with mismatches in the number of averages, we can use
+# trial-count equalization before computing a contrast, which can have some
+# benefits in inverse imaging (note that here ``weights='nave'`` will
+# give the same result as ``weights='equal'``):
-contrast = evoked1 - evoked2
+epochs_eq = epochs.copy().equalize_event_counts(['aud_l', 'aud_r'])[0]
+evoked1, evoked2 = epochs_eq['aud_l'].average(), epochs_eq['aud_r'].average()
+print(evoked1)
+print(evoked2)
+contrast = mne.combine_evoked([evoked1, -evoked2], weights='equal')
print(contrast)
##############################################################################
@@ -297,7 +330,7 @@ freqs = np.arange(7, 30, 3) # frequencies of interest
from mne.time_frequency import tfr_morlet # noqa
power, itc = tfr_morlet(epochs, freqs=freqs, n_cycles=n_cycles,
return_itc=True, decim=3, n_jobs=1)
-# power.plot()
+power.plot([power.ch_names.index('MEG 1332')])
##############################################################################
# Inverse modeling: MNE and dSPM on evoked and raw data
diff --git a/tutorials/plot_mne_dspm_source_localization.py b/tutorials/plot_mne_dspm_source_localization.py
index e9bd702..7b4ded1 100644
--- a/tutorials/plot_mne_dspm_source_localization.py
+++ b/tutorials/plot_mne_dspm_source_localization.py
@@ -22,7 +22,8 @@ from mne.minimum_norm import (make_inverse_operator, apply_inverse,
data_path = sample.data_path()
raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
-raw = mne.io.read_raw_fif(raw_fname)
+raw = mne.io.read_raw_fif(raw_fname, add_eeg_ref=False)
+raw.set_eeg_reference() # set EEG average reference
events = mne.find_events(raw, stim_channel='STI 014')
event_id = dict(aud_r=1) # event trigger and conditions
@@ -34,8 +35,8 @@ picks = mne.pick_types(raw.info, meg=True, eeg=False, eog=True,
baseline = (None, 0) # means from the first instant to t = 0
reject = dict(grad=4000e-13, mag=4e-12, eog=150e-6)
-epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True,
- picks=picks, baseline=baseline, reject=reject)
+epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True, picks=picks,
+ baseline=baseline, reject=reject, add_eeg_ref=False)
###############################################################################
# Compute regularized noise covariance
@@ -105,27 +106,29 @@ plt.show()
# Here we use peak getter to move visualization to the time point of the peak
# and draw a marker at the maximum peak vertex.
-vertno_max, time_idx = stc.get_peak(hemi='rh', time_as_index=True)
+vertno_max, time_max = stc.get_peak(hemi='rh')
subjects_dir = data_path + '/subjects'
-brain = stc.plot(surface='inflated', hemi='rh', subjects_dir=subjects_dir)
-
-brain.set_data_time_index(time_idx)
+brain = stc.plot(surface='inflated', hemi='rh', subjects_dir=subjects_dir,
+ clim=dict(kind='value', lims=[8, 12, 15]),
+ initial_time=time_max, time_unit='s')
brain.add_foci(vertno_max, coords_as_verts=True, hemi='rh', color='blue',
scale_factor=0.6)
-brain.scale_data_colormap(fmin=8, fmid=12, fmax=15, transparent=True)
brain.show_view('lateral')
###############################################################################
# Morph data to average brain
# ---------------------------
-stc_fsaverage = stc.morph(subject_to='fsaverage', subjects_dir=subjects_dir)
-
-brain_fsaverage = stc_fsaverage.plot(surface='inflated', hemi='rh',
+fs_vertices = [np.arange(10242)] * 2
+morph_mat = mne.compute_morph_matrix('sample', 'fsaverage', stc.vertices,
+ fs_vertices, smooth=None,
subjects_dir=subjects_dir)
-brain_fsaverage.set_data_time_index(time_idx)
-brain_fsaverage.scale_data_colormap(fmin=8, fmid=12, fmax=15, transparent=True)
+stc_fsaverage = stc.morph_precomputed('fsaverage', fs_vertices, morph_mat)
+brain_fsaverage = stc_fsaverage.plot(surface='inflated', hemi='rh',
+ subjects_dir=subjects_dir,
+ clim=dict(kind='value', lims=[8, 12, 15]),
+ initial_time=time_max, time_unit='s')
brain_fsaverage.show_view('lateral')
###############################################################################
diff --git a/tutorials/plot_modifying_data_inplace.py b/tutorials/plot_modifying_data_inplace.py
index 82a02d6..ddcab02 100644
--- a/tutorials/plot_modifying_data_inplace.py
+++ b/tutorials/plot_modifying_data_inplace.py
@@ -21,8 +21,8 @@ from matplotlib import pyplot as plt
# Load an example dataset, the preload flag loads the data into memory now
data_path = op.join(mne.datasets.sample.data_path(), 'MEG',
'sample', 'sample_audvis_raw.fif')
-raw = mne.io.read_raw_fif(data_path, preload=True, verbose=False)
-raw = raw.crop(0, 2)
+raw = mne.io.read_raw_fif(data_path, preload=True, add_eeg_ref=False)
+raw = raw.crop(0, 10)
print(raw)
###############################################################################
@@ -36,29 +36,35 @@ f, (ax, ax2) = plt.subplots(2, 1, figsize=(15, 10))
_ = ax.plot(raw._data[0])
for fband in filt_bands:
raw_filt = raw.copy()
- raw_filt.filter(*fband)
- _ = ax2.plot(raw_filt._data[0])
+ raw_filt.filter(*fband, h_trans_bandwidth='auto', l_trans_bandwidth='auto',
+ filter_length='auto', phase='zero')
+ _ = ax2.plot(raw_filt[0][0][0])
ax2.legend(filt_bands)
ax.set_title('Raw data')
ax2.set_title('Band-pass filtered data')
###############################################################################
# In addition, there are functions for applying the Hilbert transform, which is
-# useful to calculate phase / amplitude of your signal
+# useful to calculate phase / amplitude of your signal.
+
+# Filter signal with a fairly steep filter, then take hilbert transform
-# Filter signal, then take hilbert transform
raw_band = raw.copy()
-raw_band.filter(12, 18)
+raw_band.filter(12, 18, l_trans_bandwidth=2., h_trans_bandwidth=2.,
+ filter_length='auto', phase='zero')
raw_hilb = raw_band.copy()
hilb_picks = mne.pick_types(raw_band.info, meg=False, eeg=True)
raw_hilb.apply_hilbert(hilb_picks)
print(raw_hilb._data.dtype)
###############################################################################
-# Finally, it is possible to apply arbitrary to your data to do what you want.
-# Here we will use this to take the amplitude and phase of the hilbert
-# transformed data. (note that you can use `amplitude=True` in the call to
-# :func:`mne.io.Raw.apply_hilbert` to do this automatically).
+# Finally, it is possible to apply arbitrary functions to your data to do
+# what you want. Here we will use this to take the amplitude and phase of
+# the hilbert transformed data.
+#
+# .. note:: You can also use ``amplitude=True`` in the call to
+# :meth:`mne.io.Raw.apply_hilbert` to do this automatically.
+#
# Take the amplitude and phase
raw_amp = raw_hilb.copy()
diff --git a/tutorials/plot_object_epochs.py b/tutorials/plot_object_epochs.py
index c73ccb3..992775c 100644
--- a/tutorials/plot_object_epochs.py
+++ b/tutorials/plot_object_epochs.py
@@ -28,8 +28,9 @@ from matplotlib import pyplot as plt
data_path = mne.datasets.sample.data_path()
# Load a dataset that contains events
-raw = mne.io.RawFIF(
- op.join(data_path, 'MEG', 'sample', 'sample_audvis_raw.fif'))
+raw = mne.io.read_raw_fif(
+ op.join(data_path, 'MEG', 'sample', 'sample_audvis_raw.fif'),
+ add_eeg_ref=False)
# If your raw object has a stim channel, you can construct an event array
# easily
@@ -57,7 +58,7 @@ event_id = {'Auditory/Left': 1, 'Auditory/Right': 2}
# Expose the raw data as epochs, cut from -0.1 s to 1.0 s relative to the event
# onsets
epochs = mne.Epochs(raw, events, event_id, tmin=-0.1, tmax=1,
- baseline=(None, 0), preload=True)
+ baseline=(None, 0), preload=True, add_eeg_ref=False)
print(epochs)
###############################################################################
diff --git a/tutorials/plot_object_raw.py b/tutorials/plot_object_raw.py
index 3e78a5f..efe6d5c 100644
--- a/tutorials/plot_object_raw.py
+++ b/tutorials/plot_object_raw.py
@@ -13,10 +13,11 @@ import os.path as op
from matplotlib import pyplot as plt
###############################################################################
-# Continuous data is stored in objects of type :class:`Raw <mne.io.RawFIF>`.
+# Continuous data is stored in objects of type :class:`Raw <mne.io.Raw>`.
# The core data structure is simply a 2D numpy array (channels × samples,
-# `._data`) combined with an :class:`Info <mne.Info>` object
-# (`.info`) (:ref:`tut_info_objects`.
+# stored in a private attribute called `._data`) combined with an
+# :class:`Info <mne.Info>` object (`.info` attribute)
+# (see :ref:`tut_info_objects`).
#
# The most common way to load continuous data is from a .fif file. For more
# information on :ref:`loading data from other formats <ch_convert>`, or
@@ -30,7 +31,8 @@ from matplotlib import pyplot as plt
# Load an example dataset, the preload flag loads the data into memory now
data_path = op.join(mne.datasets.sample.data_path(), 'MEG',
'sample', 'sample_audvis_raw.fif')
-raw = mne.io.RawFIF(data_path, preload=True, verbose=False)
+raw = mne.io.read_raw_fif(data_path, preload=True, add_eeg_ref=False)
+raw.set_eeg_reference() # set EEG average reference
# Give the sample rate
print('sample rate:', raw.info['sfreq'], 'Hz')
@@ -38,7 +40,13 @@ print('sample rate:', raw.info['sfreq'], 'Hz')
print('channels x samples:', raw._data.shape)
###############################################################################
-# Information about the channels contained in the :class:`Raw <mne.io.RawFIF>`
+# .. note:: Accessing the `._data` attribute is done here for educational
+# purposes. However, this is a private attribute, as its name starts
+# with an `_`; you should **not** access it directly, but instead
+# rely on the indexing syntax detailed just below.
+
+###############################################################################
+# Information about the channels contained in the :class:`Raw <mne.io.Raw>`
# object is contained in the :class:`Info <mne.Info>` attribute.
# This is essentially a dictionary with a number of relevant fields (see
# :ref:`tut_info_objects`).
@@ -48,21 +56,13 @@ print('channels x samples:', raw._data.shape)
# Indexing data
# -------------
#
-# There are two ways to access the data stored within :class:`Raw
-# <mne.io.RawFIF>` objects. One is by accessing the underlying data array, and
-# the other is to index the :class:`Raw <mne.io.RawFIF>` object directly.
+# To access the data stored within :class:`Raw <mne.io.Raw>` objects,
+# it is possible to index the :class:`Raw <mne.io.Raw>` object.
#
-# To access the data array of :class:`Raw <mne.io.Raw>` objects, use the
-# `_data` attribute. Note that this is only present if `preload==True`.
-
-print('Shape of data array:', raw._data.shape)
-array_data = raw._data[0, :1000]
-_ = plt.plot(array_data)
-
-###############################################################################
-# You can also pass an index directly to the :class:`Raw <mne.io.RawFIF>`
-# object. This will return an array of times, as well as the data representing
-# those timepoints. This may be used even if the data is not preloaded:
+# Indexing a :class:`Raw <mne.io.Raw>` object will return two arrays: an array
+# of times, as well as the data representing those timepoints. This works
+# even if the data is not preloaded, in which case the data will be read from
+# disk when indexing. The syntax is as follows:
# Extract data from the first 5 channels, from 1 s to 3 s.
sfreq = raw.info['sfreq']
@@ -114,11 +114,11 @@ print('Number of channels reduced from', nchan, 'to', raw.info['nchan'])
###############################################################################
# --------------------------------------------------
-# Concatenating :class:`Raw <mne.io.RawFIF>` objects
+# Concatenating :class:`Raw <mne.io.Raw>` objects
# --------------------------------------------------
#
-# :class:`Raw <mne.io.RawFIF>` objects can be concatenated in time by using the
-# :func:`append <mne.io.RawFIF.append>` function. For this to work, they must
+# :class:`Raw <mne.io.Raw>` objects can be concatenated in time by using the
+# :func:`append <mne.io.Raw.append>` function. For this to work, they must
# have the same number of channels and their :class:`Info
# <mne.Info>` structures should be compatible.
diff --git a/tutorials/plot_point_spread.py b/tutorials/plot_point_spread.py
new file mode 100644
index 0000000..4512a32
--- /dev/null
+++ b/tutorials/plot_point_spread.py
@@ -0,0 +1,171 @@
+"""
+.. _point_spread:
+
+Corrupt known signal with point spread
+======================================
+
+The aim of this tutorial is to demonstrate how to put a known signal at
+desired locations in a :class:`mne.SourceEstimate` and then corrupt the
+signal with point-spread by applying a forward and inverse solution.
+"""
+
+import os.path as op
+
+import numpy as np
+
+import mne
+from mne.datasets import sample
+
+from mne.minimum_norm import read_inverse_operator, apply_inverse
+from mne.simulation import simulate_stc, simulate_evoked
+
+###############################################################################
+# First, we set some parameters.
+
+seed = 42
+
+# parameters for inverse method
+method = 'sLORETA'
+snr = 3.
+lambda2 = 1.0 / snr ** 2
+
+# signal simulation parameters
+# do not add extra noise to the known signals
+evoked_snr = np.inf
+T = 100
+times = np.linspace(0, 1, T)
+dt = times[1] - times[0]
+
+# Paths to MEG data
+data_path = sample.data_path()
+subjects_dir = op.join(data_path, 'subjects')
+fname_fwd = op.join(data_path, 'MEG', 'sample',
+ 'sample_audvis-meg-oct-6-fwd.fif')
+fname_inv = op.join(data_path, 'MEG', 'sample',
+ 'sample_audvis-meg-oct-6-meg-fixed-inv.fif')
+
+fname_evoked = op.join(data_path, 'MEG', 'sample',
+ 'sample_audvis-ave.fif')
+
+###############################################################################
+# Load the MEG data
+# -----------------
+
+fwd = mne.read_forward_solution(fname_fwd, force_fixed=True,
+ surf_ori=True)
+fwd['info']['bads'] = []
+inv_op = read_inverse_operator(fname_inv)
+
+raw = mne.io.read_raw_fif(op.join(data_path, 'MEG', 'sample',
+                                  'sample_audvis_raw.fif'))
+events = mne.find_events(raw)
+event_id = {'Auditory/Left': 1, 'Auditory/Right': 2}
+epochs = mne.Epochs(raw, events, event_id, baseline=(None, 0), preload=True)
+epochs.info['bads'] = []
+evoked = epochs.average()
+
+labels = mne.read_labels_from_annot('sample', subjects_dir=subjects_dir)
+label_names = [l.name for l in labels]
+n_labels = len(labels)
+
+###############################################################################
+# Estimate the background noise covariance from the baseline period
+# -----------------------------------------------------------------
+
+cov = mne.compute_covariance(epochs, tmin=None, tmax=0.)
+
+###############################################################################
+# Generate sinusoids in two spatially distant labels
+# --------------------------------------------------
+
+# The known signal is zero everywhere except in the two labels of interest
+signal = np.zeros((n_labels, T))
+idx = label_names.index('inferiorparietal-lh')
+signal[idx, :] = 1e-7 * np.sin(5 * 2 * np.pi * times)
+idx = label_names.index('rostralmiddlefrontal-rh')
+signal[idx, :] = 1e-7 * np.sin(7 * 2 * np.pi * times)
+
+###############################################################################
+# Find the center vertices in source space of each label
+# ------------------------------------------------------
+#
+# We want the known signal in each label to only be active at the center. We
+# create a mask for each label that is 1 at the center vertex and 0 at all
+# other vertices in the label. This mask is then used when simulating
+# source-space data.
+
+hemi_to_ind = {'lh': 0, 'rh': 1}
+for i, label in enumerate(labels):
+ # The `center_of_mass` function needs labels to have values.
+ labels[i].values.fill(1.)
+
+ # Restrict the eligible vertices to be those on the surface under
+ # consideration and within the label.
+ surf_vertices = fwd['src'][hemi_to_ind[label.hemi]]['vertno']
+ restrict_verts = np.intersect1d(surf_vertices, label.vertices)
+ com = labels[i].center_of_mass(subject='sample',
+ subjects_dir=subjects_dir,
+ restrict_vertices=restrict_verts,
+ surf='white')
+
+# Convert the center-of-mass vertex index from the surface vertex list to
+# the Label's vertex list.
+ cent_idx = np.where(label.vertices == com)[0][0]
+
+ # Create a mask with 1 at center vertex and zeros elsewhere.
+ labels[i].values.fill(0.)
+ labels[i].values[cent_idx] = 1.
+
+###############################################################################
+# Create source-space data with known signals
+# -------------------------------------------
+#
+# Put known signals onto surface vertices using the array of signals and
+# the label masks (stored in labels[i].values).
+stc_gen = simulate_stc(fwd['src'], labels, signal, times[0], dt,
+ value_fun=lambda x: x)
+
+###############################################################################
+# Plot original signals
+# ---------------------
+#
+# Note that the original signals are highly concentrated (point) sources.
+#
+kwargs = dict(subjects_dir=subjects_dir, hemi='split', views=['lat', 'med'],
+ smoothing_steps=4, time_unit='s', initial_time=0.05)
+clim = dict(kind='value', pos_lims=[1e-9, 1e-8, 1e-7])
+brain_gen = stc_gen.plot(clim=clim, **kwargs)
+
+###############################################################################
+# Simulate sensor-space signals
+# -----------------------------
+#
+# Use the forward solution and add Gaussian noise to simulate sensor-space
+# (evoked) data from the known source-space signals. The amount of noise is
+# controlled by `evoked_snr` (higher values imply less noise).
+#
+evoked_gen = simulate_evoked(fwd, stc_gen, evoked.info, cov, evoked_snr,
+ tmin=0., tmax=1., random_state=seed)
+
+# Map the simulated sensor-space data to source-space using the inverse
+# operator.
+stc_inv = apply_inverse(evoked_gen, inv_op, lambda2, method=method)
+
+###############################################################################
+# Plot the point-spread of corrupted signal
+# -----------------------------------------
+#
+# Notice that after applying the forward and inverse operators to the known
+# point sources, the point sources have spread across the source space.
+# This spread is due to the minimum norm solution, which lets the signal
+# leak to nearby vertices with similar orientations, so that the signal
+# ends up crossing sulci and gyri.
+brain_inv = stc_inv.plot(**kwargs)
+
+###############################################################################
+# Exercises
+# ---------
+# - Change the `method` parameter to either `dSPM` or `MNE` to explore the
+# effect of the inverse method.
+# - Try setting `evoked_snr` to a small, finite value, e.g. 3., to see the
+# effect of noise.
diff --git a/tutorials/plot_sensors_decoding.py b/tutorials/plot_sensors_decoding.py
index aa79097..cb5a62d 100644
--- a/tutorials/plot_sensors_decoding.py
+++ b/tutorials/plot_sensors_decoding.py
@@ -29,7 +29,8 @@ tmin, tmax = -0.2, 0.5
event_id = dict(aud_l=1, vis_l=3)
# Setup for reading the raw data
-raw = mne.io.read_raw_fif(raw_fname, preload=True)
+raw = mne.io.read_raw_fif(raw_fname, preload=True, add_eeg_ref=False)
+raw.set_eeg_reference() # set EEG average reference
raw.filter(2, None, method='iir') # replace baselining with high-pass
events = mne.read_events(event_fname)
@@ -41,7 +42,7 @@ picks = mne.pick_types(raw.info, meg='grad', eeg=False, stim=True, eog=True,
# Read epochs
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True,
picks=picks, baseline=None, preload=True,
- reject=dict(grad=4000e-13, eog=150e-6))
+ reject=dict(grad=4000e-13, eog=150e-6), add_eeg_ref=False)
epochs_list = [epochs[k] for k in event_id]
mne.epochs.equalize_epoch_counts(epochs_list)
@@ -69,6 +70,8 @@ td.plot(title='Sensor space decoding')
# Generalization Across Time
# --------------------------
#
+# This runs the analysis used in [1]_ and further detailed in [2]_.
+#
# Here we'll use a stratified cross-validation scheme.
# make response vector
@@ -98,3 +101,16 @@ gat.plot_diagonal()
#
# Have a look at the example
# :ref:`sphx_glr_auto_examples_decoding_plot_decoding_csp_space.py`
+#
+# References
+# ==========
+#
+# .. [1] Jean-Remi King, Alexandre Gramfort, Aaron Schurger, Lionel Naccache
+# and Stanislas Dehaene, "Two distinct dynamic modes subtend the
+# detection of unexpected sounds", PLOS ONE, 2013,
+# http://www.ncbi.nlm.nih.gov/pubmed/24475052
+#
+# .. [2] King & Dehaene (2014) 'Characterizing the dynamics of mental
+# representations: the temporal generalization method', Trends In
+# Cognitive Sciences, 18(4), 203-210.
+# http://www.ncbi.nlm.nih.gov/pubmed/24593982
diff --git a/tutorials/plot_sensors_time_frequency.py b/tutorials/plot_sensors_time_frequency.py
index d685257..a4f7904 100644
--- a/tutorials/plot_sensors_time_frequency.py
+++ b/tutorials/plot_sensors_time_frequency.py
@@ -27,7 +27,7 @@ data_path = somato.data_path()
raw_fname = data_path + '/MEG/somato/sef_raw_sss.fif'
# Setup for reading the raw data
-raw = mne.io.read_raw_fif(raw_fname)
+raw = mne.io.read_raw_fif(raw_fname, add_eeg_ref=False)
events = mne.find_events(raw, stim_channel='STI 014')
# picks MEG gradiometers
@@ -38,7 +38,7 @@ event_id, tmin, tmax = 1, -1., 3.
baseline = (None, 0)
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks,
baseline=baseline, reject=dict(grad=4000e-13, eog=350e-6),
- preload=True)
+ preload=True, add_eeg_ref=False)
epochs.resample(150., npad='auto') # resample to reduce computation time
@@ -59,7 +59,7 @@ epochs.plot_psd_topomap(ch_type='grad', normalize=True)
###############################################################################
# Alternatively, you can also create PSDs from Epochs objects with functions
-# that start with psd_ such as
+# that start with ``psd_`` such as
# :func:`mne.time_frequency.psd_multitaper` and
# :func:`mne.time_frequency.psd_welch`.
diff --git a/tutorials/plot_stats_cluster_1samp_test_time_frequency.py b/tutorials/plot_stats_cluster_1samp_test_time_frequency.py
index b7c0e75..8803557 100644
--- a/tutorials/plot_stats_cluster_1samp_test_time_frequency.py
+++ b/tutorials/plot_stats_cluster_1samp_test_time_frequency.py
@@ -26,8 +26,7 @@ import numpy as np
import matplotlib.pyplot as plt
import mne
-from mne import io
-from mne.time_frequency import single_trial_power
+from mne.time_frequency import tfr_morlet
from mne.stats import permutation_cluster_1samp_test
from mne.datasets import sample
@@ -38,12 +37,10 @@ print(__doc__)
# --------------
data_path = sample.data_path()
raw_fname = data_path + '/MEG/sample/sample_audvis_raw.fif'
-event_id = 1
-tmin = -0.3
-tmax = 0.6
+tmin, tmax, event_id = -0.3, 0.6, 1
# Setup for reading the raw data
-raw = io.read_raw_fif(raw_fname)
+raw = mne.io.read_raw_fif(raw_fname)
events = mne.find_events(raw, stim_channel='STI 014')
include = []
@@ -56,47 +53,34 @@ picks = mne.pick_types(raw.info, meg='grad', eeg=False, eog=True,
# Load condition 1
event_id = 1
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks,
- baseline=(None, 0), reject=dict(grad=4000e-13, eog=150e-6))
-data = epochs.get_data() # as 3D matrix
-data *= 1e13 # change unit to fT / cm
-# Time vector
-times = 1e3 * epochs.times # change unit to ms
+ baseline=(None, 0), preload=True,
+ reject=dict(grad=4000e-13, eog=150e-6))
# Take only one channel
-ch_name = raw.info['ch_names'][97]
-data = data[:, 97:98, :]
-
-evoked_data = np.mean(data, 0)
+ch_name = 'MEG 1332'
+epochs.pick_channels([ch_name])
-# data -= evoked_data[None,:,:] # remove evoked component
-# evoked_data = np.mean(data, 0)
+evoked = epochs.average()
-# Factor to down-sample the temporal dimension of the PSD computed by
-# single_trial_power. Decimation occurs after frequency decomposition and can
+# Factor to down-sample the temporal dimension of the TFR computed by
+# tfr_morlet. Decimation occurs after frequency decomposition and can
# be used to reduce memory usage (and possibly computational time of downstream
# operations such as nonparametric statistics) if you don't need high
# spectrotemporal resolution.
decim = 5
frequencies = np.arange(8, 40, 2) # define frequencies of interest
sfreq = raw.info['sfreq'] # sampling in Hz
-epochs_power = single_trial_power(data, sfreq=sfreq, frequencies=frequencies,
- n_cycles=4, n_jobs=1,
- baseline=(-100, 0), times=times,
- baseline_mode='ratio', decim=decim)
+tfr_epochs = tfr_morlet(epochs, frequencies, n_cycles=4., decim=decim,
+ average=False, return_itc=False, n_jobs=1)
-# Crop in time to keep only what is between 0 and 400 ms
-time_mask = (times > 0) & (times < 400)
-evoked_data = evoked_data[:, time_mask]
-times = times[time_mask]
+# Baseline power
+tfr_epochs.apply_baseline(mode='logratio', baseline=(-.100, 0))
-# The time vector reflects the original time points, not the decimated time
-# points returned by single trial power. Be sure to decimate the time mask
-# appropriately.
-epochs_power = epochs_power[..., time_mask[::decim]]
+# Crop in time to keep only what is between 0 and 400 ms
+evoked.crop(0., 0.4)
+tfr_epochs.crop(0., 0.4)
-epochs_power = epochs_power[:, 0, :, :]
-epochs_power = np.log10(epochs_power) # take log of ratio
-# under the null hypothesis epochs_power should be now be 0
+epochs_power = tfr_epochs.data[:, 0, :, :] # take the only channel
###############################################################################
# Compute statistic
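
For reference, a self-contained sketch of the cluster test this tutorial calls here, :func:`mne.stats.permutation_cluster_1samp_test`, on synthetic data (the shapes, threshold, and permutation count are illustrative assumptions):

    import numpy as np
    from mne.stats import permutation_cluster_1samp_test

    rng = np.random.RandomState(0)
    X = rng.randn(20, 16, 60)  # observations x freqs x times
    X[:, 5:10, 20:40] += 1.  # inject a localized effect
    T_obs, clusters, cluster_p_values, H0 = \
        permutation_cluster_1samp_test(X, threshold=2.5, n_permutations=200)
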
@@ -109,17 +93,12 @@ T_obs, clusters, cluster_p_values, H0 = \
###############################################################################
# View time-frequency plots
# -------------------------
-plt.clf()
-plt.subplots_adjust(0.12, 0.08, 0.96, 0.94, 0.2, 0.43)
-plt.subplot(2, 1, 1)
-plt.plot(times, evoked_data.T)
-plt.title('Evoked response (%s)' % ch_name)
-plt.xlabel('time (ms)')
-plt.ylabel('Magnetic Field (fT/cm)')
-plt.xlim(times[0], times[-1])
-plt.ylim(-100, 250)
-plt.subplot(2, 1, 2)
+evoked_data = evoked.data
+times = 1e3 * evoked.times
+
+plt.figure()
+plt.subplots_adjust(0.12, 0.08, 0.96, 0.94, 0.2, 0.43)
# Create new stats image with only significant clusters
T_obs_plot = np.nan * np.ones_like(T_obs)
@@ -129,6 +108,7 @@ for c, p_val in zip(clusters, cluster_p_values):
vmax = np.max(np.abs(T_obs))
vmin = -vmax
+plt.subplot(2, 1, 1)
plt.imshow(T_obs, cmap=plt.cm.gray,
extent=[times[0], times[-1], frequencies[0], frequencies[-1]],
aspect='auto', origin='lower', vmin=vmin, vmax=vmax)
@@ -136,7 +116,10 @@ plt.imshow(T_obs_plot, cmap=plt.cm.RdBu_r,
extent=[times[0], times[-1], frequencies[0], frequencies[-1]],
aspect='auto', origin='lower', vmin=vmin, vmax=vmax)
plt.colorbar()
-plt.xlabel('time (ms)')
+plt.xlabel('Time (ms)')
plt.ylabel('Frequency (Hz)')
plt.title('Induced power (%s)' % ch_name)
+
+ax2 = plt.subplot(2, 1, 2)
+evoked.plot(axes=[ax2])
plt.show()
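
As an aside on the ``apply_baseline(mode='logratio')`` call above: it replaces the manual divide-by-baseline and ``np.log10`` steps of the removed code. A rough, self-contained sketch on synthetic data (the shapes and baseline window are made up):

    import numpy as np

    rng = np.random.RandomState(42)
    power = rng.rand(10, 1, 16, 100) + 1.  # epochs x channels x freqs x times
    baseline = power[..., :20].mean(axis=-1, keepdims=True)  # first 20 samples
    log_power = np.log10(power / baseline)  # ~0 under the null hypothesis
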
diff --git a/tutorials/plot_stats_cluster_methods.py b/tutorials/plot_stats_cluster_methods.py
index 03d5b5c..e80ec0f 100644
--- a/tutorials/plot_stats_cluster_methods.py
+++ b/tutorials/plot_stats_cluster_methods.py
@@ -59,7 +59,8 @@ import numpy as np
from scipy import stats
from functools import partial
import matplotlib.pyplot as plt
-from mpl_toolkits.mplot3d import Axes3D # noqa; this changes hidden mpl vars
+# this changes hidden MPL vars:
+from mpl_toolkits.mplot3d import Axes3D # noqa
from mne.stats import (spatio_temporal_cluster_1samp_test,
bonferroni_correction, ttest_1samp_no_p)
diff --git a/tutorials/plot_stats_cluster_spatio_temporal.py b/tutorials/plot_stats_cluster_spatio_temporal.py
index f56fdbb..059b34d 100644
--- a/tutorials/plot_stats_cluster_spatio_temporal.py
+++ b/tutorials/plot_stats_cluster_spatio_temporal.py
@@ -17,6 +17,7 @@ permutation test across space and time.
import os.path as op
+
import numpy as np
from numpy.random import randn
from scipy import stats as stats
@@ -179,8 +180,7 @@ stc_all_cluster_vis = summarize_clusters_stc(clu, tstep=tstep,
# shows all the clusters, weighted by duration
subjects_dir = op.join(data_path, 'subjects')
# blue blobs are for condition A < condition B, red for A > B
-brain = stc_all_cluster_vis.plot(hemi='both', subjects_dir=subjects_dir,
+brain = stc_all_cluster_vis.plot(hemi='both', views='lateral',
+ subjects_dir=subjects_dir,
time_label='Duration significant (ms)')
-brain.set_data_time_index(0)
-brain.show_view('lateral')
brain.save_image('clusters.png')
diff --git a/tutorials/plot_stats_cluster_spatio_temporal_2samp.py b/tutorials/plot_stats_cluster_spatio_temporal_2samp.py
index 959faca..bbe1da9 100644
--- a/tutorials/plot_stats_cluster_spatio_temporal_2samp.py
+++ b/tutorials/plot_stats_cluster_spatio_temporal_2samp.py
@@ -16,6 +16,7 @@ permutation test across space and time.
# License: BSD (3-clause)
import os.path as op
+
import numpy as np
from scipy import stats as stats
@@ -103,8 +104,6 @@ stc_all_cluster_vis = summarize_clusters_stc(clu, tstep=tstep,
subjects_dir = op.join(data_path, 'subjects')
# blue blobs are for condition A != condition B
brain = stc_all_cluster_vis.plot('fsaverage', hemi='both', colormap='mne',
- subjects_dir=subjects_dir,
+ views='lateral', subjects_dir=subjects_dir,
time_label='Duration significant (ms)')
-brain.set_data_time_index(0)
-brain.show_view('lateral')
brain.save_image('clusters.png')
diff --git a/tutorials/plot_stats_cluster_spatio_temporal_repeated_measures_anova.py b/tutorials/plot_stats_cluster_spatio_temporal_repeated_measures_anova.py
index ef899f7..ab3f429 100644
--- a/tutorials/plot_stats_cluster_spatio_temporal_repeated_measures_anova.py
+++ b/tutorials/plot_stats_cluster_spatio_temporal_repeated_measures_anova.py
@@ -180,7 +180,7 @@ n_conditions = 4
#
# Note: for further details on this ANOVA function, consider the
# corresponding
-# :ref:`time frequency tutorial <tut_stats_cluster_sensor_rANOVA_tfr>`.
+# :ref:`time-frequency tutorial <tut_stats_cluster_sensor_rANOVA_tfr>`.
def stat_fun(*args):
@@ -239,10 +239,8 @@ subjects_dir = op.join(data_path, 'subjects')
# stimulus modality and stimulus location
brain = stc_all_cluster_vis.plot(subjects_dir=subjects_dir, colormap='mne',
+ views='lateral',
time_label='Duration significant (ms)')
-
-brain.set_data_time_index(0)
-brain.show_view('lateral')
brain.save_image('cluster-lh.png')
brain.show_view('medial')
diff --git a/tutorials/plot_stats_cluster_time_frequency.py b/tutorials/plot_stats_cluster_time_frequency.py
index 9124c91..e573cd3 100644
--- a/tutorials/plot_stats_cluster_time_frequency.py
+++ b/tutorials/plot_stats_cluster_time_frequency.py
@@ -27,8 +27,7 @@ import numpy as np
import matplotlib.pyplot as plt
import mne
-from mne import io
-from mne.time_frequency import single_trial_power
+from mne.time_frequency import tfr_morlet
from mne.stats import permutation_cluster_test
from mne.datasets import sample
@@ -39,12 +38,10 @@ print(__doc__)
data_path = sample.data_path()
raw_fname = data_path + '/MEG/sample/sample_audvis_raw.fif'
event_fname = data_path + '/MEG/sample/sample_audvis_raw-eve.fif'
-event_id = 1
-tmin = -0.2
-tmax = 0.5
+tmin, tmax = -0.2, 0.5
# Setup for reading the raw data
-raw = io.read_raw_fif(raw_fname)
+raw = mne.io.read_raw_fif(raw_fname)
events = mne.read_events(event_fname)
include = []
@@ -54,65 +51,50 @@ raw.info['bads'] += ['MEG 2443', 'EEG 053'] # bads + 2 more
picks = mne.pick_types(raw.info, meg='grad', eeg=False, eog=True,
stim=False, include=include, exclude='bads')
-ch_name = raw.info['ch_names'][picks[0]]
+ch_name = 'MEG 1332' # restrict example to one channel
# Load condition 1
reject = dict(grad=4000e-13, eog=150e-6)
event_id = 1
epochs_condition_1 = mne.Epochs(raw, events, event_id, tmin, tmax,
picks=picks, baseline=(None, 0),
- reject=reject)
-data_condition_1 = epochs_condition_1.get_data() # as 3D matrix
-data_condition_1 *= 1e13 # change unit to fT / cm
+ reject=reject, preload=True)
+epochs_condition_1.pick_channels([ch_name])
# Load condition 2
event_id = 2
epochs_condition_2 = mne.Epochs(raw, events, event_id, tmin, tmax,
picks=picks, baseline=(None, 0),
- reject=reject)
-data_condition_2 = epochs_condition_2.get_data() # as 3D matrix
-data_condition_2 *= 1e13 # change unit to fT / cm
-
-# Take only one channel
-data_condition_1 = data_condition_1[:, 97:98, :]
-data_condition_2 = data_condition_2[:, 97:98, :]
-
-# Time vector
-times = 1e3 * epochs_condition_1.times # change unit to ms
+ reject=reject, preload=True)
+epochs_condition_2.pick_channels([ch_name])
###############################################################################
-# Factor to downsample the temporal dimension of the PSD computed by
-# single_trial_power. Decimation occurs after frequency decomposition and can
+# Factor to downsample the temporal dimension of the TFR computed by
+# tfr_morlet. Decimation occurs after frequency decomposition and can
# be used to reduce memory usage (and possibly computational time of downstream
# operations such as nonparametric statistics) if you don't need high
# spectrotemporal resolution.
decim = 2
frequencies = np.arange(7, 30, 3) # define frequencies of interest
-sfreq = raw.info['sfreq'] # sampling in Hz
n_cycles = 1.5
-epochs_power_1 = single_trial_power(data_condition_1, sfreq=sfreq,
- frequencies=frequencies,
- n_cycles=n_cycles, decim=decim)
+tfr_epochs_1 = tfr_morlet(epochs_condition_1, frequencies,
+ n_cycles=n_cycles, decim=decim,
+ return_itc=False, average=False)
-epochs_power_2 = single_trial_power(data_condition_2, sfreq=sfreq,
- frequencies=frequencies,
- n_cycles=n_cycles, decim=decim)
+tfr_epochs_2 = tfr_morlet(epochs_condition_2, frequencies,
+ n_cycles=n_cycles, decim=decim,
+ return_itc=False, average=False)
-epochs_power_1 = epochs_power_1[:, 0, :, :] # only 1 channel to get 3D matrix
-epochs_power_2 = epochs_power_2[:, 0, :, :] # only 1 channel to get 3D matrix
+tfr_epochs_1.apply_baseline(mode='ratio', baseline=(None, 0))
+tfr_epochs_2.apply_baseline(mode='ratio', baseline=(None, 0))
-###############################################################################
-# Compute ratio with baseline power (be sure to correct time vector with
-# decimation factor)
-baseline_mask = times[::decim] < 0
-epochs_baseline_1 = np.mean(epochs_power_1[:, :, baseline_mask], axis=2)
-epochs_power_1 /= epochs_baseline_1[..., np.newaxis]
-epochs_baseline_2 = np.mean(epochs_power_2[:, :, baseline_mask], axis=2)
-epochs_power_2 /= epochs_baseline_2[..., np.newaxis]
+epochs_power_1 = tfr_epochs_1.data[:, 0, :, :] # only 1 channel as 3D matrix
+epochs_power_2 = tfr_epochs_2.data[:, 0, :, :] # only 1 channel as 3D matrix
###############################################################################
# Compute statistic
+# -----------------
threshold = 6.0
T_obs, clusters, cluster_p_values, H0 = \
permutation_cluster_test([epochs_power_1, epochs_power_2],
@@ -120,19 +102,16 @@ T_obs, clusters, cluster_p_values, H0 = \
###############################################################################
# View time-frequency plots
-plt.clf()
-plt.subplots_adjust(0.12, 0.08, 0.96, 0.94, 0.2, 0.43)
-plt.subplot(2, 1, 1)
-evoked_contrast = np.mean(data_condition_1, 0) - np.mean(data_condition_2, 0)
-plt.plot(times, evoked_contrast.T)
-plt.title('Contrast of evoked response (%s)' % ch_name)
-plt.xlabel('time (ms)')
-plt.ylabel('Magnetic Field (fT/cm)')
-plt.xlim(times[0], times[-1])
-plt.ylim(-100, 200)
+# -------------------------
-plt.subplot(2, 1, 2)
+times = 1e3 * epochs_condition_1.times # change unit to ms
+evoked_condition_1 = epochs_condition_1.average()
+evoked_condition_2 = epochs_condition_2.average()
+
+plt.figure()
+plt.subplots_adjust(0.12, 0.08, 0.96, 0.94, 0.2, 0.43)
+plt.subplot(2, 1, 1)
# Create new stats image with only significant clusters
T_obs_plot = np.nan * np.ones_like(T_obs)
for c, p_val in zip(clusters, cluster_p_values):
@@ -141,12 +120,18 @@ for c, p_val in zip(clusters, cluster_p_values):
plt.imshow(T_obs,
extent=[times[0], times[-1], frequencies[0], frequencies[-1]],
- aspect='auto', origin='lower', cmap='RdBu_r')
+ aspect='auto', origin='lower', cmap='gray')
plt.imshow(T_obs_plot,
extent=[times[0], times[-1], frequencies[0], frequencies[-1]],
aspect='auto', origin='lower', cmap='RdBu_r')
-plt.xlabel('time (ms)')
+plt.xlabel('Time (ms)')
plt.ylabel('Frequency (Hz)')
plt.title('Induced power (%s)' % ch_name)
+
+ax2 = plt.subplot(2, 1, 2)
+evoked_contrast = mne.combine_evoked([evoked_condition_1, evoked_condition_2],
+ weights=[1, -1])
+evoked_contrast.plot(axes=ax2)
+
plt.show()
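
Analogously, the two-condition test used above can be exercised in isolation; a sketch on synthetic data (the sizes and injected effect are illustrative; ``threshold=6.0`` mirrors the tutorial):

    import numpy as np
    from mne.stats import permutation_cluster_test

    rng = np.random.RandomState(0)
    cond1 = rng.randn(20, 8, 50)  # epochs x freqs x times
    cond2 = rng.randn(18, 8, 50) + 0.5
    T_obs, clusters, cluster_p_values, H0 = \
        permutation_cluster_test([cond1, cond2], threshold=6.0,
                                 n_permutations=200)
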
diff --git a/tutorials/plot_stats_cluster_time_frequency_repeated_measures_anova.py b/tutorials/plot_stats_cluster_time_frequency_repeated_measures_anova.py
index ad1b7a2..dfff054 100644
--- a/tutorials/plot_stats_cluster_time_frequency_repeated_measures_anova.py
+++ b/tutorials/plot_stats_cluster_time_frequency_repeated_measures_anova.py
@@ -30,8 +30,7 @@ import numpy as np
import matplotlib.pyplot as plt
import mne
-from mne import io
-from mne.time_frequency import single_trial_power
+from mne.time_frequency import tfr_morlet
from mne.stats import f_threshold_mway_rm, f_mway_rm, fdr_correction
from mne.datasets import sample
@@ -43,12 +42,10 @@ print(__doc__)
data_path = sample.data_path()
raw_fname = data_path + '/MEG/sample/sample_audvis_raw.fif'
event_fname = data_path + '/MEG/sample/sample_audvis_raw-eve.fif'
-event_id = 1
-tmin = -0.2
-tmax = 0.5
+tmin, tmax = -0.2, 0.5
# Setup for reading the raw data
-raw = io.read_raw_fif(raw_fname)
+raw = mne.io.read_raw_fif(raw_fname)
events = mne.read_events(event_fname)
include = []
@@ -58,44 +55,41 @@ raw.info['bads'] += ['MEG 2443'] # bads
picks = mne.pick_types(raw.info, meg='grad', eeg=False, eog=True,
stim=False, include=include, exclude='bads')
-ch_name = raw.info['ch_names'][picks[0]]
+ch_name = 'MEG 1332'
# Load conditions
reject = dict(grad=4000e-13, eog=150e-6)
event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4)
epochs = mne.Epochs(raw, events, event_id, tmin, tmax,
- picks=picks, baseline=(None, 0),
+ picks=picks, baseline=(None, 0), preload=True,
reject=reject)
+epochs.pick_channels([ch_name]) # restrict example to one channel
###############################################################################
# We have to make sure all conditions have the same counts, as the ANOVA
# expects a fully balanced data matrix and does not handle imbalances
# gracefully (risk of type-I error).
epochs.equalize_event_counts(event_id, copy=False)
-# Time vector
-times = 1e3 * epochs.times # change unit to ms
-# Factor to down-sample the temporal dimension of the PSD computed by
-# single_trial_power.
+# Factor to down-sample the temporal dimension of the TFR computed by
+# tfr_morlet.
decim = 2
frequencies = np.arange(7, 30, 3) # define frequencies of interest
-sfreq = raw.info['sfreq'] # sampling in Hz
n_cycles = frequencies / frequencies[0]
-baseline_mask = times[::decim] < 0
+zero_mean = False # don't correct the Morlet wavelet to have zero mean
+# To have a true wavelet, zero_mean should be True, but here, for illustration
+# purposes, it helps to spot the evoked response.
###############################################################################
# Create TFR representations for all conditions
# ---------------------------------------------
epochs_power = list()
-for condition in [epochs[k].get_data()[:, 97:98, :] for k in event_id]:
- this_power = single_trial_power(condition, sfreq=sfreq,
- frequencies=frequencies, n_cycles=n_cycles,
- decim=decim)
- this_power = this_power[:, 0, :, :] # we only have one channel.
- # Compute ratio with baseline power (be sure to correct time vector with
- # decimation factor)
- epochs_baseline = np.mean(this_power[:, :, baseline_mask], axis=2)
- this_power /= epochs_baseline[..., np.newaxis]
+for condition in [epochs[k] for k in event_id]:
+ this_tfr = tfr_morlet(condition, frequencies, n_cycles=n_cycles,
+ decim=decim, average=False, zero_mean=zero_mean,
+ return_itc=False)
+ this_tfr.apply_baseline(mode='ratio', baseline=(None, 0))
+ this_power = this_tfr.data[:, 0, :, :] # we only have one channel.
epochs_power.append(this_power)
###############################################################################
@@ -115,7 +109,8 @@ effects = 'A*B' # this is the default signature for computing all effects
# or 'A:B' for the interaction effect only (this notation is borrowed from the
# R formula language)
n_frequencies = len(frequencies)
-n_times = len(times[::decim])
+times = 1e3 * epochs.times[::decim]
+n_times = len(times)
###############################################################################
# Now we'll assemble the data matrix and swap axes so the trial replications
@@ -132,15 +127,15 @@ print(data.shape)
# makes sure the first two dimensions are organized as expected (with A =
# modality and B = location):
#
-# .. table::
+# .. table:: Sample data layout
#
-# ===== ==== ==== ==== ====
-# trial A1B1 A1B2 A2B1 B2B2
-# ===== ==== ==== ==== ====
-# 1 1.34 2.53 0.97 1.74
-# ... .... .... .... ....
-# 56 2.45 7.90 3.09 4.76
-# ===== ==== ==== ==== ====
+# ===== ==== ==== ==== ====
+# trial A1B1 A1B2 A2B1 A2B2
+# ===== ==== ==== ==== ====
+# 1 1.34 2.53 0.97 1.74
+# ... ... ... ... ...
+# 56 2.45 7.90 3.09 4.76
+# ===== ==== ==== ==== ====
#
# Now we're ready to run our repeated measures ANOVA.
#
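
To make the expected layout concrete, a minimal sketch of :func:`mne.stats.f_mway_rm` on random data shaped like the table above (a 2x2 design; all numbers fabricated):

    import numpy as np
    from mne.stats import f_mway_rm

    rng = np.random.RandomState(0)
    # replications (trials) x conditions (A1B1, A1B2, A2B1, A2B2) x observations
    data = rng.randn(56, 4, 100)
    fvals, pvals = f_mway_rm(data, factor_levels=[2, 2], effects='A*B')
    print(len(fvals))  # one set of F-values per effect: A, B, A:B
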
@@ -166,7 +161,7 @@ for effect, sig, effect_label in zip(fvals, pvals, effect_labels):
times[-1], frequencies[0], frequencies[-1]], aspect='auto',
origin='lower')
plt.colorbar()
- plt.xlabel('time (ms)')
+ plt.xlabel('Time (ms)')
plt.ylabel('Frequency (Hz)')
plt.title(r"Time-locked response for '%s' (%s)" % (effect_label, ch_name))
plt.show()
@@ -192,8 +187,8 @@ effects = 'A:B'
def stat_fun(*args):
return f_mway_rm(np.swapaxes(args, 1, 0), factor_levels=factor_levels,
effects=effects, return_pvals=False)[0]
- # The ANOVA returns a tuple f-values and p-values, we will pick the former.
+# The ANOVA returns a tuple of f-values and p-values; we will pick the former.
pthresh = 0.00001 # set threshold rather high to save some time
f_thresh = f_threshold_mway_rm(n_replications, factor_levels, effects,
pthresh)
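
For reference, the threshold helper used here can also be called standalone; a sketch with made-up design parameters (15 replications, 2x2 factor levels, interaction effect only):

    from mne.stats import f_threshold_mway_rm

    f_thresh = f_threshold_mway_rm(15, [2, 2], 'A:B', 0.0001)
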
@@ -204,8 +199,8 @@ T_obs, clusters, cluster_p_values, h0 = mne.stats.permutation_cluster_test(
n_permutations=n_permutations, buffer_size=None)
###############################################################################
-# Create new stats image with only significant clusters
-# -----------------------------------------------------
+# Create new stats image with only significant clusters:
+
good_clusters = np.where(cluster_p_values < .05)[0]
T_obs_plot = np.ma.masked_array(T_obs,
np.invert(clusters[np.squeeze(good_clusters)]))
@@ -215,15 +210,15 @@ for f_image, cmap in zip([T_obs, T_obs_plot], [plt.cm.gray, 'RdBu_r']):
plt.imshow(f_image, cmap=cmap, extent=[times[0], times[-1],
frequencies[0], frequencies[-1]], aspect='auto',
origin='lower')
-plt.xlabel('time (ms)')
+plt.xlabel('Time (ms)')
plt.ylabel('Frequency (Hz)')
-plt.title('Time-locked response for \'modality by location\' (%s)\n'
- ' cluster-level corrected (p <= 0.05)' % ch_name)
+plt.title("Time-locked response for 'modality by location' (%s)\n"
+ " cluster-level corrected (p <= 0.05)" % ch_name)
plt.show()
###############################################################################
-# Now using FDR
-# -------------
+# Now using FDR:
+
mask, _ = fdr_correction(pvals[2])
T_obs_plot2 = np.ma.masked_array(T_obs, np.invert(mask))
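
The FDR helper is easy to try on its own; a minimal sketch on fabricated p-values:

    import numpy as np
    from mne.stats import fdr_correction

    rng = np.random.RandomState(0)
    pvals = rng.uniform(size=100)
    mask, pvals_corrected = fdr_correction(pvals, alpha=0.05)
    print(mask.sum(), 'tests survive FDR correction at alpha=0.05')
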
@@ -233,11 +228,12 @@ for f_image, cmap in zip([T_obs, T_obs_plot2], [plt.cm.gray, 'RdBu_r']):
frequencies[0], frequencies[-1]], aspect='auto',
origin='lower')
-plt.xlabel('time (ms)')
+plt.xlabel('Time (ms)')
plt.ylabel('Frequency (Hz)')
-plt.title('Time-locked response for \'modality by location\' (%s)\n'
- ' FDR corrected (p <= 0.05)' % ch_name)
+plt.title("Time-locked response for 'modality by location' (%s)\n"
+ " FDR corrected (p <= 0.05)" % ch_name)
plt.show()
-# Both, cluster level and FDR correction help getting rid of
+###############################################################################
+# Both cluster-level and FDR correction help get rid of the putative
+# spots we saw in the naive f-images.
diff --git a/tutorials/plot_stats_spatio_temporal_cluster_sensors.py b/tutorials/plot_stats_spatio_temporal_cluster_sensors.py
index bac6a93..8faea54 100644
--- a/tutorials/plot_stats_spatio_temporal_cluster_sensors.py
+++ b/tutorials/plot_stats_spatio_temporal_cluster_sensors.py
@@ -40,7 +40,8 @@ tmax = 0.5
# Setup for reading the raw data
raw = mne.io.read_raw_fif(raw_fname, preload=True)
-raw.filter(1, 30)
+raw.filter(1, 30, l_trans_bandwidth='auto', h_trans_bandwidth='auto',
+ filter_length='auto', phase='zero')
events = mne.read_events(event_fname)
###############################################################################
diff --git a/tutorials/plot_visualize_epochs.py b/tutorials/plot_visualize_epochs.py
index eca13d4..692c62a 100644
--- a/tutorials/plot_visualize_epochs.py
+++ b/tutorials/plot_visualize_epochs.py
@@ -10,10 +10,12 @@ import os.path as op
import mne
data_path = op.join(mne.datasets.sample.data_path(), 'MEG', 'sample')
-raw = mne.io.read_raw_fif(op.join(data_path, 'sample_audvis_raw.fif'))
+raw = mne.io.read_raw_fif(op.join(data_path, 'sample_audvis_raw.fif'),
+ add_eeg_ref=False)
+raw.set_eeg_reference() # set EEG average reference
events = mne.read_events(op.join(data_path, 'sample_audvis_raw-eve.fif'))
picks = mne.pick_types(raw.info, meg='grad')
-epochs = mne.Epochs(raw, events, [1, 2], picks=picks)
+epochs = mne.Epochs(raw, events, [1, 2], picks=picks, add_eeg_ref=False)
###############################################################################
# This tutorial focuses on visualization of epoched data. All of the functions
@@ -44,8 +46,12 @@ epochs.plot(block=True)
# To plot individual channels as an image, where you see all the epochs at one
# glance, you can use the function :func:`mne.Epochs.plot_image`. It shows the
# amplitude of the signal over all the epochs plus an average of the
-# activation.
-epochs.plot_image(97)
+# activation. We explicitly turn the interactive colorbar on (it is also on
+# by default for plotting functions with a colorbar, except the topo plots).
+# In interactive mode you can scale the data and change the colormap with the
+# mouse scroll wheel and the up/down arrow keys. You can also drag the
+# colorbar with the left/right mouse buttons. Hitting the space bar resets
+# the scale.
+epochs.plot_image(97, cmap='interactive')
# You also have functions for plotting channelwise information arranged in the
# shape of the channel array. The image plotting uses automatic scaling by
diff --git a/tutorials/plot_visualize_evoked.py b/tutorials/plot_visualize_evoked.py
index b456f65..c66db59 100644
--- a/tutorials/plot_visualize_evoked.py
+++ b/tutorials/plot_visualize_evoked.py
@@ -21,9 +21,9 @@ evoked = mne.read_evokeds(fname, baseline=(None, 0), proj=True)
print(evoked)
###############################################################################
-# Notice that ``evoked`` is a list of evoked instances. You can read only one
-# of the categories by passing the argument ``condition`` to
-# :func:`mne.read_evokeds`. To make things more simple for this tutorial, we
+# Notice that ``evoked`` is a list of :class:`evoked <mne.Evoked>` instances.
+# You can read only one of the categories by passing the argument ``condition``
+# to :func:`mne.read_evokeds`. To make things simpler for this tutorial, we
# read each instance into a variable.
evoked_l_aud = evoked[0]
evoked_r_aud = evoked[1]
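
A sketch of the ``condition`` argument just mentioned, assuming the same ``fname`` as above (the condition name matches the sample dataset used later in this tutorial):

    evoked_l_aud = mne.read_evokeds(fname, condition='Left Auditory',
                                    baseline=(None, 0), proj=True)
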
@@ -115,6 +115,33 @@ evoked_r_aud.plot_joint(title='right auditory', times=[.07, .105],
ts_args=ts_args, topomap_args=topomap_args)
###############################################################################
+# Sometimes, you may want to compare two conditions at a selection of sensors,
+# or, for example, for the Global Field Power. For this, you can use the
+# function :func:`mne.viz.plot_compare_evokeds`. The easiest way is to create
+# a Python dictionary, where the keys are condition names and the values are
+# :class:`mne.Evoked` objects. If you provide lists of :class:`mne.Evoked`
+# objects, such as those for multiple subjects, the grand average is plotted,
+# along with a confidence interval band; this can be used to contrast
+# conditions for a whole experiment.
+# First, we load the evoked objects into a dictionary, setting the keys to
+# '/'-separated tags. Then, we plot with :func:`mne.viz.plot_compare_evokeds`.
+# The plot is styled with dictionary arguments, again using "/"-separated tags.
+# We plot a MEG channel with a strong auditory response.
+conditions = ["Left Auditory", "Right Auditory", "Left visual", "Right visual"]
+evoked_dict = dict()
+for condition in conditions:
+ evoked_dict[condition.replace(" ", "/")] = mne.read_evokeds(
+ fname, baseline=(None, 0), proj=True, condition=condition)
+print(evoked_dict)
+
+colors = dict(Left="Crimson", Right="CornFlowerBlue")
+linestyles = dict(Auditory='-', visual='--')
+pick = evoked_dict["Left/Auditory"].ch_names.index('MEG 1811')
+
+mne.viz.plot_compare_evokeds(evoked_dict, picks=pick,
+ colors=colors, linestyles=linestyles)
+
+###############################################################################
# We can also plot the activations as images. The time runs along the x-axis
# and the channels along the y-axis. The amplitudes are color coded so that
# the amplitudes from negative to positive translates to shift from blue to
diff --git a/tutorials/plot_visualize_raw.py b/tutorials/plot_visualize_raw.py
index 4d1ce89..fa01f5a 100644
--- a/tutorials/plot_visualize_raw.py
+++ b/tutorials/plot_visualize_raw.py
@@ -10,7 +10,9 @@ import os.path as op
import mne
data_path = op.join(mne.datasets.sample.data_path(), 'MEG', 'sample')
-raw = mne.io.read_raw_fif(op.join(data_path, 'sample_audvis_raw.fif'))
+raw = mne.io.read_raw_fif(op.join(data_path, 'sample_audvis_raw.fif'),
+ add_eeg_ref=False)
+raw.set_eeg_reference() # set EEG average reference
events = mne.read_events(op.join(data_path, 'sample_audvis_raw-eve.fif'))
###############################################################################
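
A note on the ``set_eeg_reference`` call above: with no arguments it adds an average-reference projection. It can also re-reference to explicit channels; a sketch (the channel names are assumptions about the sample dataset):

    raw.set_eeg_reference()  # average reference, added as a projection
    # hypothetical alternative: re-reference to specific channels
    # raw.set_eeg_reference(ref_channels=['EEG 001', 'EEG 002'])
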
@@ -23,25 +25,38 @@ events = mne.read_events(op.join(data_path, 'sample_audvis_raw-eve.fif'))
#
# To visually inspect your raw data, you can use the Python equivalent of
# ``mne_browse_raw``.
-raw.plot(block=True, events=events)
+raw.plot(block=True)
###############################################################################
# The channels are color coded by channel type. Generally MEG channels are
# colored in different shades of blue, whereas EEG channels are black. The
-# channels are also sorted by channel type by default. If you want to use a
-# custom order for the channels, you can use ``order`` parameter of
-# :func:`raw.plot`. The scrollbar on right side of the browser window also
-# tells us that two of the channels are marked as ``bad``. Bad channels are
-# color coded gray. By clicking the lines or channel names on the left, you can
-# mark or unmark a bad channel interactively. You can use +/- keys to adjust
-# the scale (also = works for magnifying the data). Note that the initial
-# scaling factors can be set with parameter ``scalings``. If you don't know the
-# scaling factor for channels, you can automatically set them by passing
-# scalings='auto'. With ``pageup/pagedown`` and ``home/end`` keys you can
-# adjust the amount of data viewed at once. To see all the interactive
-# features, hit ``?`` or click ``help`` in the lower left corner of the
-# browser window.
+# scrollbar on the right side of the browser window also tells us that two of
+# the channels are marked as ``bad``. Bad channels are color coded gray. By
+# clicking the lines or channel names on the left, you can mark or unmark a bad
+# channel interactively. You can use the +/- keys to adjust the scale (the =
+# key also works for magnifying the data). Note that the initial scaling
+# factors can be set with the parameter ``scalings``. If you don't know the
+# scaling factors for the channels, you can set them automatically by passing
+# ``scalings='auto'``. With the ``pageup/pagedown`` and ``home/end`` keys you
+# can adjust the amount of data viewed at once. To see all the interactive
+# features, hit ``?`` or click ``help`` in the lower left corner of the
+# browser window.
#
+# The channels are sorted by channel type by default. You can use the ``order``
+# parameter of :func:`raw.plot <mne.io.Raw.plot>` to group the channels in a
+# different way. ``order='selection'`` uses the same channel groups as MNE-C's
+# mne_browse_raw (see :ref:`CACCJEJD`). The selections are defined in
+# ``mne-python/mne/data/mne_analyze.sel``, and by modifying the channels there,
+# you can define your own selection groups. Notice that this also affects the
+# selections returned by :func:`mne.read_selection`. By default the selections
+# only work for Neuromag data, but ``order='position'`` tries to mimic this
+# behavior for any data with sensor positions available. The channels are
+# grouped by sensor position into 8 evenly sized regions. Notice that for this
+# to work effectively, all the data channels in the channel array must be
+# present. The ``order`` parameter can also be passed as an array of ints
+# (picks) to plot the channels in the given order (see the sketch after this
+# hunk).
+raw.plot(order='selection')
+
+###############################################################################
# We read the events from a file and passed them as a parameter when calling the
# method. The events are plotted as vertical lines so you can see how they
# align with the raw data.
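
Picking up the note above about passing ``order`` as an array of ints: a sketch plotting an arbitrary subset of channels (the indices are illustrative):

    import numpy as np

    raw.plot(order=np.arange(10))  # first ten channels, in the given order
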
@@ -49,9 +64,14 @@ raw.plot(block=True, events=events)
# We can check where the channels reside with ``plot_sensors``. Notice that
# this method (along with many other MNE plotting functions) is callable using
# any MNE data container where the channel information is available.
-raw.plot_sensors(kind='3d', ch_type='mag')
+raw.plot_sensors(kind='3d', ch_type='mag', ch_groups='position')
###############################################################################
+# We used ``ch_groups='position'`` to color code the different regions. It uses
+# the same algorithm for dividing the regions as ``order='position'`` of
+# :func:`raw.plot <mne.io.Raw.plot>`. You can also pass a list of picks to
+# color any channel group with different colors.
+#
# Now let's add some ssp projectors to the raw data. Here we read them from a
# file and plot them.
projs = mne.read_proj(op.join(data_path, 'sample_audvis_eog-proj.fif'))
@@ -82,7 +102,7 @@ raw.plot()
raw.plot_psd()
###############################################################################
-# Plotting channel wise power spectra is just as easy. The layout is inferred
+# Plotting channel-wise power spectra is just as easy. The layout is inferred
# from the data by default when plotting topo plots. This works for most data,
# but it is also possible to define the layouts by hand. Here we select a
# layout with only magnetometer channels and plot it. Then we plot the channel