[med-svn] [python-mne] 12/52: Multiple fixes: - indexing and slicing now always return an Epochs object - fixed bugs that occur when Epochs only has one event (len(self.events) is 3 for a single event) - still using a shallow copy to avoid copying raw
Yaroslav Halchenko
debian at onerussian.com
Fri Nov 27 17:23:44 UTC 2015
This is an automated email from the git hooks/post-receive script.
yoh pushed a commit to annotated tag v0.2
in repository python-mne.
commit 6799065d5f906de43c2ffa2bf297a26b7aa419d5
Author: Martin Luessi <mluessi at nmr.mgh.harvard.edu>
Date: Wed Sep 28 15:26:21 2011 -0400
Multiple fixes:
- indexing and slicing now always return an Epochs object
- fixed bugs that occur when Epochs only has one event (with a single event, self.events is a 1-D row, so len(self.events) is 3 instead of 1)
- still using a shallow copy to avoid copying the raw data
---
mne/epochs.py | 82 +++++++++++++++++++++++++-----------------------
mne/tests/test_epochs.py | 6 ++--
2 files changed, 45 insertions(+), 43 deletions(-)
diff --git a/mne/epochs.py b/mne/epochs.py
index f4c44e4..f61f393 100644
--- a/mne/epochs.py
+++ b/mne/epochs.py
@@ -88,9 +88,9 @@ class Epochs(object):
-------
epochs = Epochs(...)
- epochs[idx] : Return epoch with index idx (2D array, [n_channels, n_times])
-
- epochs[start:stop] : Return Epochs object with a subset of epochs
+ epochs[idx] : Epochs
+ Return Epochs object with a subset of epochs (supports single
+ index and python style slicing)
"""
def __init__(self, raw, events, event_id, tmin, tmax, baseline=(None, 0),
@@ -177,7 +177,7 @@ class Epochs(object):
# Select the desired events
selected = np.logical_and(events[:, 1] == 0, events[:, 2] == event_id)
self.events = events[selected]
- n_events = len(self.events)
+ n_events = len(self)
if n_events > 0:
print '%d matching events found' % n_events
@@ -232,7 +232,7 @@ class Epochs(object):
return
good_events = []
- n_events = len(self.events)
+ n_events = len(self)
for idx in range(n_events):
epoch = self._get_epoch_from_disk(idx)
if self._is_good_epoch(epoch):
@@ -246,7 +246,12 @@ class Epochs(object):
def _get_epoch_from_disk(self, idx):
"""Load one epoch from disk"""
sfreq = self.raw.info['sfreq']
- event_samp = self.events[idx, 0]
+
+ if self.events.ndim == 1:
+ #single event
+ event_samp = self.events[0]
+ else:
+ event_samp = self.events[idx, 0]
# Read a data segment
first_samp = self.raw.first_samp
@@ -268,15 +273,15 @@ class Epochs(object):
"""
n_channels = len(self.ch_names)
n_times = len(self.times)
- n_events = len(self.events)
+ n_events = len(self)
data = np.empty((n_events, n_channels, n_times))
cnt = 0
n_reject = 0
event_idx = []
for k in range(n_events):
- e = self._get_epoch_from_disk(k)
- if self._is_good_epoch(e):
- data[cnt] = self._get_epoch_from_disk(k)
+ epoch = self._get_epoch_from_disk(k)
+ if self._is_good_epoch(epoch):
+ data[cnt] = epoch
event_idx.append(k)
cnt += 1
else:
@@ -342,7 +347,7 @@ class Epochs(object):
epoch = self._data[self._current]
self._current += 1
else:
- if self._current >= len(self.events):
+ if self._current >= len(self):
raise StopIteration
epoch = self._get_epoch_from_disk(self._current)
self._current += 1
@@ -353,9 +358,9 @@ class Epochs(object):
def __repr__(self):
if not self.bad_dropped:
- s = "n_events : %s (good & bad)" % len(self.events)
+ s = "n_events : %s (good & bad)" % len(self)
else:
- s = "n_events : %s (all good)" % len(self.events)
+ s = "n_events : %s (all good)" % len(self)
s += ", tmin : %s (s)" % self.tmin
s += ", tmax : %s (s)" % self.tmax
s += ", baseline : %s" % str(self.baseline)
@@ -364,38 +369,35 @@ class Epochs(object):
def __len__(self):
"""Return length (number of events)
"""
- return len(self.events)
+ if self.events.ndim == 1:
+ return 1
+ else:
+ return len(self.events)
- def __getitem__(self, index):
- """Return epoch at index or an Epochs object with a slice of epochs
+ def __getitem__(self, key):
+ """Return an Epochs object with a subset of epochs
"""
- if isinstance(index, slice):
- # return Epochs object with slice of epochs
- if not self.bad_dropped:
- warnings.warn("Bad epochs have not been dropped, indexing "
- "will be inccurate. Use drop_bad_epochs() "
- "or preload=True")
-
- epoch_slice = copy.copy(self)
- epoch_slice.events = self.events[index]
-
- if self.preload:
- epoch_slice._data = self._data[index]
+ print key
+ if not self.bad_dropped:
+ warnings.warn("Bad epochs have not been dropped, indexing "
+ "will be inccurate. Use drop_bad_epochs() "
+ "or preload=True")
- return epoch_slice
+ epochs = copy.copy(self)
+ epochs.events = self.events[key]
- # return single epoch as 2D array
if self.preload:
- epoch = epoch = self._data[index]
- else:
- epoch = self._get_epoch_from_disk(index)
-
- if not self._is_good_epoch(epoch):
- warnings.warn("Bad epoch with index %d returned. "
- "Use drop_bad_epochs() or preload=True "
- "to prevent this." % (index))
+ if isinstance(key, slice):
+ epochs._data = self._data[key]
+ else:
+ #make sure data remains a 3D array
+ n_channels = len(self.ch_names)
+ n_times = len(self.times)
+ data = np.empty((1, n_channels, n_times))
+ data[0, :, :] = self._data[key]
+ epochs._data = data
- return epoch
+ return epochs
def average(self):
"""Compute average of epochs
@@ -409,7 +411,7 @@ class Epochs(object):
evoked.info = copy.deepcopy(self.info)
n_channels = len(self.ch_names)
n_times = len(self.times)
- n_events = len(self.events)
+ n_events = len(self)
if self.preload:
data = np.mean(self._data, axis=0)
else:
diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py
index 545d91a..120cf87 100644
--- a/mne/tests/test_epochs.py
+++ b/mne/tests/test_epochs.py
@@ -96,10 +96,9 @@ def test_indexing_slicing():
if not preload:
epochs2.drop_bad_epochs()
- # get slice
+ # using slicing
epochs2_sliced = epochs2[start_index:end_index]
- # using get_data()
data_epochs2_sliced = epochs2_sliced.get_data()
assert_array_equal(data_epochs2_sliced, \
data_normal[start_index:end_index])
@@ -107,7 +106,8 @@ def test_indexing_slicing():
# using indexing
pos = 0
for idx in range(start_index, end_index):
- assert_array_equal(epochs2_sliced[pos], data_normal[idx])
+ data = epochs2_sliced[pos].get_data()
+ assert_array_equal(data[0], data_normal[idx])
pos += 1
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git
More information about the debian-med-commit
mailing list