[med-svn] [python-mne] 07/52: added indexing and slicing operations for epoch
Yaroslav Halchenko
debian at onerussian.com
Fri Nov 27 17:23:44 UTC 2015
This is an automated email from the git hooks/post-receive script.
yoh pushed a commit to annotated tag v0.2
in repository python-mne.
commit 866040876254090dae33d2abd2d8545f8ba7c335
Author: Martin Luessi <mluessi at nmr.mgh.harvard.edu>
Date: Tue Sep 27 17:26:02 2011 -0400
added indexing and slicing operations for epoch
---
mne/epochs.py | 85 +++++++++++++++++++++++++++++++++++++++++++++---
mne/tests/test_epochs.py | 40 +++++++++++++++++++++++
2 files changed, 120 insertions(+), 5 deletions(-)
diff --git a/mne/epochs.py b/mne/epochs.py
index c1a52da..a3832cc 100644
--- a/mne/epochs.py
+++ b/mne/epochs.py
@@ -79,6 +79,17 @@ class Epochs(object):
Return Evoked object containing averaged epochs as a
2D array [n_channels x n_times].
+ drop_bad_epochs() : None
+ Drop all epochs marked as bad. Should be used before indexing and
+ slicing operations.
+
+ Indexing and Slicing:
+ -------
+ epochs = Epochs(...)
+
+ epochs[idx] : Return epoch with index idx (2D array, [n_channels, n_times])
+
+ epochs[start:stop] : Return Epochs object with a subset of epochs
"""
def __init__(self, raw, events, event_id, tmin, tmax, baseline=(None, 0),
@@ -96,6 +107,7 @@ class Epochs(object):
self.preload = preload
self.reject = reject
self.flat = flat
+ self.bad_dropped = False
# Handle measurement info
self.info = copy.deepcopy(raw.info)
@@ -183,7 +195,9 @@ class Epochs(object):
self._reject_setup()
if self.preload:
- self._data = self._get_data_from_disk()
+ self._data, good_events = self._get_data_from_disk()
+ self.events = self.events[good_events,:]
+ self.bad_dropped = True
def drop_picks(self, bad_picks):
"""Drop some picks
@@ -206,6 +220,28 @@ class Epochs(object):
if self.preload:
self._data = self._data[:, idx, :]
+ def drop_bad_epochs(self):
+ """Drop bad epochs.
+
+ Should be used before slicing operations.
+
+ Warning: Operation is slow since all epochs have to be read from disk
+ """
+ if self.bad_dropped:
+ return
+
+ good = []
+ n_events = len(self.events)
+ for idx in range(n_events):
+ epoch = self._get_epoch_from_disk(idx)
+ if self._is_good_epoch(epoch):
+ good.append(idx)
+
+ self.events = self.events[good,:]
+ self.bad_dropped = True
+
+ print "%d bad epochs dropped" % (n_events - len(good))
+
def _get_epoch_from_disk(self, idx):
"""Load one epoch from disk"""
sfreq = self.raw.info['sfreq']
@@ -235,18 +271,20 @@ class Epochs(object):
data = np.empty((n_events, n_channels, n_times))
cnt = 0
n_reject = 0
+ event_idx = []
for k in range(n_events):
e = self._get_epoch_from_disk(k)
if self._is_good_epoch(e):
data[cnt] = self._get_epoch_from_disk(k)
+ event_idx.append(k)
cnt += 1
else:
n_reject += 1
print "Rejecting %d epochs." % n_reject
- return data[:cnt]
+ return data[:cnt], event_idx
def _is_good_epoch(self, data):
- """Determine is epoch is good
+ """Determine if epoch is good
"""
n_times = len(self.times)
if self.reject is None and self.flat is None:
@@ -268,7 +306,8 @@ class Epochs(object):
if self.preload:
return self._data
else:
- return self._get_data_from_disk()
+ data, _ = self._get_data_from_disk()
+ return data
def _reject_setup(self):
"""Setup reject process
@@ -312,12 +351,48 @@ class Epochs(object):
return epoch
def __repr__(self):
- s = "n_events : %s" % len(self.events)
+ if not self.bad_dropped:
+ s = "n_events : %s (good & bad)" % len(self.events)
+ else:
+ s = "n_events : %s (all good)" % len(self.events)
s += ", tmin : %s (s)" % self.tmin
s += ", tmax : %s (s)" % self.tmax
s += ", baseline : %s" % str(self.baseline)
return "Epochs (%s)" % s
+ def __getslice__(self, start, end):
+ """Return an Epoch object with a subset of epochs.
+ """
+ if not self.bad_dropped:
+ print "Warning: bad epochs have not been dropped, indexing will " \
+ "be inccurate. Use drop_bad_epochs() or preload=True"
+
+ epoch_slice = copy.copy(self)
+ epoch_slice.events = self.events[start:end]
+
+ if self.preload:
+ epoch_slice._data = self._data[start:end]
+
+ return epoch_slice
+
+ def __getitem__(self, index):
+ """Return epoch at index
+ """
+ if index < 0 or index >= len(self.events):
+ raise IndexError("Epoch index out of bounds")
+
+ if self.preload:
+ epoch = epoch = self._data[index]
+ else:
+ epoch = self._get_epoch_from_disk(index)
+
+ if not self._is_good_epoch(epoch):
+ print "Warning: Bad epoch with index %d returned. Use " \
+ "drop_bad_epochs() or preload=True to prevent this." \
+ % (index)
+
+ return epoch
+
def average(self):
"""Compute average of epochs
diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py
index 5d97441..94abf4f 100644
--- a/mne/tests/test_epochs.py
+++ b/mne/tests/test_epochs.py
@@ -70,6 +70,46 @@ def test_preload_epochs():
data_no_preload = epochs.get_data()
assert_array_equal(data_preload, data_no_preload)
+def test_indexing_slicing():
+ """Test of indexing and slicing operations
+ """
+ epochs = Epochs(raw, events[:20], event_id, tmin, tmax, picks=picks,
+ baseline=(None, 0), preload=False,
+ reject=reject, flat=flat)
+
+ data_normal = epochs.get_data()
+
+ n_good_events = data_normal.shape[0]
+
+ # indices for slicing
+ start_index = 1
+ end_index = n_good_events - 1
+
+ assert((end_index - start_index) > 0)
+
+ for preload in [True, False]:
+ epochs2 = Epochs(raw, events[:20], event_id, tmin, tmax,
+ picks=picks, baseline=(None, 0), preload=preload,
+ reject=reject, flat=flat)
+
+ if not preload:
+ epochs2.drop_bad_epochs()
+
+ # get slice
+ epochs2_sliced = epochs2[start_index:end_index]
+
+ # using get_data()
+ data_epochs2_sliced = epochs2_sliced.get_data()
+ assert_array_equal(data_epochs2_sliced, \
+ data_normal[start_index:end_index])
+
+ # using indexing
+ pos = 0
+ for idx in range(start_index, end_index):
+ assert_array_equal(epochs2_sliced[pos], data_normal[idx])
+ pos += 1
+
+
def test_comparision_with_c():
"""Test of average obtained vs C code
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git
More information about the debian-med-commit
mailing list