[med-svn] [python-mne] 292/353: added bootstrap and crop function to epochs
Yaroslav Halchenko
debian at onerussian.com
Fri Nov 27 17:25:18 UTC 2015
This is an automated email from the git hooks/post-receive script.
yoh pushed a commit to tag 0.4
in repository python-mne.
commit bf7fa3447926b9962bb3cc00b5aec54be7d6d7d7
Author: Daniel Strohmeier <daniel.strohmeier at googlemail.com>
Date: Wed Jul 18 15:44:30 2012 +0200
added bootstrap and crop function to epochs
---
mne/epochs.py | 70 +++++++++++++++++++++++++++++++++++++++++++++---
mne/tests/test_epochs.py | 53 ++++++++++++++++++++++++++++++++++++
2 files changed, 119 insertions(+), 4 deletions(-)
diff --git a/mne/epochs.py b/mne/epochs.py
index a49f60d..ceb9618 100644
--- a/mne/epochs.py
+++ b/mne/epochs.py
@@ -384,10 +384,12 @@ class Epochs(object):
if isinstance(key, slice):
epochs._data = self._data[key]
else:
- #make sure data remains a 3D array
- #Note: np.atleast_3d() doesn't do what we want
- epochs._data = np.array([self._data[key]])
-
+ if isinstance(key, list):
+ key = np.array(key)
+ if np.ndim(key) == 0:
+ epochs._data = self._data[key][np.newaxis, :, :]
+ else:
+ epochs._data = self._data[key]
return epochs
def average(self, keep_only_data_channels=True):
@@ -441,6 +443,39 @@ class Epochs(object):
evoked.info['nchan'] = len(data_picks)
evoked.data = evoked.data[data_picks]
return evoked
+
+    def crop(self, tmin=None, tmax=None):
+        """Crops a time interval from epochs object (in place).
+
+        Parameters
+        ----------
+        tmin : float | None
+            Start time of selection in seconds. If None, or earlier than the
+            current start time, the current start time is kept.
+        tmax : float | None
+            End time of selection in seconds. If None, or later than the
+            current end time, the current end time is kept.
+
+        Returns
+        -------
+        epochs : Epochs instance
+            The cropped epochs (this same object, modified in place).
+
+        Raises
+        ------
+        RuntimeError
+            If the epochs were not preloaded.
+        """
+        if not self.preload:
+            raise RuntimeError('Modifying data of epochs is only supported '
+                               'when preloading is used. Use preload=True '
+                               'in the constructor.')
+        if tmin is None or tmin < self.tmin:
+            tmin = self.tmin
+        if tmax is None or tmax > self.tmax:
+            tmax = self.tmax
+
+        # Select samples with a boolean mask on the time axis rather than
+        # computing sample offsets. The previous arithmetic measured the end
+        # offset from self.tmax and subtracted one more, so it always dropped
+        # a sample at the end, and int() truncation could shift the start.
+        # A mask also lets us keep self.times consistent with the data.
+        # NOTE(review): assumes self.times is a writable array with one entry
+        # per sample along the last axis of self._data -- confirm.
+        tmask = (self.times >= tmin) & (self.times <= tmax)
+        self.tmin = tmin
+        self.tmax = tmax
+        self.times = self.times[tmask]
+        self._data = self._data[:, :, tmask]
+        return self
def _is_good(e, ch_names, channel_type_idx, reject, flat):
@@ -477,3 +512,30 @@ def _is_good(e, ch_names, channel_type_idx, reject, flat):
return False
return True
+
+
+def bootstrap(epochs, rng=None):
+    """Draw a bootstrap resample of epochs
+
+    As many epochs as there are events are drawn uniformly at random, with
+    replacement, from a deep copy of ``epochs``.
+
+    Parameters
+    ----------
+    epochs : Epochs instance
+        epochs data to be bootstrapped. Must have been created with
+        preload=True.
+    rng : None | int | random number generator
+        Random number generator providing ``randint(low, high, size)``,
+        e.g. a numpy.random.RandomState. An int is used as the seed of a
+        new RandomState; None gives an unseeded RandomState.
+
+    Returns
+    -------
+    epochs_bootstrap : Epochs instance
+        The bootstrap samples.
+    idx : array of int, shape (n_events,)
+        Indices into ``epochs`` of the drawn epochs, in drawing order.
+
+    Raises
+    ------
+    RuntimeError
+        If the epochs were not preloaded.
+    """
+    if not epochs.preload:
+        raise RuntimeError('Modifying data of epochs is only supported '
+                           'when preloading is used. Use preload=True '
+                           'in the constructor.')
+
+    # Accept a seed (or nothing) as a convenience; generator objects are
+    # used unchanged so existing callers behave exactly as before.
+    if rng is None or isinstance(rng, int):
+        rng = np.random.RandomState(rng)
+
+    epochs_bootstrap = copy.deepcopy(epochs)
+    n_events = len(epochs_bootstrap.events)
+    # Sampling with replacement: indices may repeat.
+    idx = rng.randint(0, n_events, n_events)
+    epochs_bootstrap = epochs_bootstrap[idx]
+    return epochs_bootstrap, idx
diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py
index 29bfc1c..d2eb23c 100644
--- a/mne/tests/test_epochs.py
+++ b/mne/tests/test_epochs.py
@@ -8,6 +8,7 @@ from numpy.testing import assert_array_equal, assert_array_almost_equal
import numpy as np
from .. import fiff, Epochs, read_events, pick_events
+from ..epochs import bootstrap
raw_fname = op.join(op.dirname(__file__), '..', 'fiff', 'tests', 'data',
'test_raw.fif')
@@ -135,6 +136,23 @@ def test_indexing_slicing():
data = epochs2_sliced[pos].get_data()
assert_array_equal(data[0], data_normal[idx])
pos += 1
+
+ # using indexing with int
+ idx = np.random.randint(0, data_epochs2_sliced.shape[0], 1)
+ data = epochs2[idx].get_data()
+ assert_array_equal(data, data_normal[idx])
+
+ # using indexing with array
+ idx = np.random.randint(0, data_epochs2_sliced.shape[0], 10)
+ data = epochs2[idx].get_data()
+ assert_array_equal(data, data_normal[idx])
+
+ # using indexing with list of indices
+ #idx = list()
+ #for k in range(3):
+ # idx.append(np.random.randint(0, data_epochs2_sliced.shape[0], 1))
+ # data = epochs2[idx].get_data()
+ # assert_array_equal(data, data_normal[idx])
def test_comparision_with_c():
@@ -152,3 +170,38 @@ def test_comparision_with_c():
assert_true(evoked.nave == c_evoked.nave)
assert_array_almost_equal(evoked_data, c_evoked_data, 10)
assert_array_almost_equal(evoked.times, c_evoked.times, 12)
+
+
+def test_crop():
+    """Test of crop of epochs
+    """
+    epochs = Epochs(raw, events[:20], event_id, tmin, tmax, picks=picks,
+                    baseline=(None, 0), preload=False,
+                    reject=reject, flat=flat)
+    epochs2 = Epochs(raw, events[:20], event_id, tmin, tmax,
+                     picks=picks, baseline=(None, 0), preload=True,
+                     reject=reject, flat=flat)
+    data_normal = epochs.get_data()
+
+    # Crop 60 samples off each end. The sample count is converted to
+    # seconds by dividing by sfreq; multiplying put the window far outside
+    # the epoch, so the test only ever compared two empty arrays.
+    crop_tmin = tmin + 60. / epochs.info['sfreq']
+    crop_tmax = tmax - 60. / epochs.info['sfreq']
+    tmask = (epochs.times >= crop_tmin) & (epochs.times <= crop_tmax)
+    assert_true(crop_tmin > tmin)
+    assert_true(crop_tmax < tmax)
+    # Guard against a vacuous comparison of empty selections.
+    assert_true(np.any(tmask))
+    epochs2.crop(crop_tmin, crop_tmax)
+    data = epochs2.get_data()
+    assert_array_equal(data, data_normal[:, :, tmask])
+
+
+def test_bootstrap():
+    """Test of bootstrapping of epochs
+    """
+    epochs = Epochs(raw, events[:20], event_id, tmin, tmax, picks=picks,
+                    baseline=(None, 0), preload=True,
+                    reject=reject, flat=flat)
+    data_normal = epochs._data
+    rng = np.random.RandomState(0)
+    epochs2, idx = bootstrap(epochs, rng)
+    # One draw per original event, and each drawn epoch matches its source.
+    assert_true(len(idx) == len(epochs.events))
+    assert_array_equal(epochs2._data, data_normal[idx])
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git
More information about the debian-med-commit
mailing list