[med-svn] [python-mne] 295/353: added comments on the pull request
Yaroslav Halchenko
debian at onerussian.com
Fri Nov 27 17:25:19 UTC 2015
This is an automated email from the git hooks/post-receive script.
yoh pushed a commit to tag 0.4
in repository python-mne.
commit a7331a3e2677e4579582020898c29394d87ac089
Author: Daniel Strohmeier <daniel.strohmeier at googlemail.com>
Date: Wed Jul 18 17:59:33 2012 +0200
added comments on the pull request
---
mne/epochs.py | 67 +++++++++++++++++++++++++++++-------------------
mne/tests/test_epochs.py | 47 ++++++++++++++++++---------------
2 files changed, 66 insertions(+), 48 deletions(-)
diff --git a/mne/epochs.py b/mne/epochs.py
index fdc723d..f7199ef 100644
--- a/mne/epochs.py
+++ b/mne/epochs.py
@@ -1,17 +1,20 @@
# Authors: Alexandre Gramfort <gramfort at nmr.mgh.harvard.edu>
# Matti Hamalainen <msh at nmr.mgh.harvard.edu>
+# Daniel Strohmeier <daniel.strohmeier at tu-ilmenau.de>
#
# License: BSD (3-clause)
-import copy
+import copy as cp
+import warnings
+
import numpy as np
+
import fiff
-import warnings
from .fiff import Evoked
from .fiff.pick import pick_types, channel_indices_by_type
from .fiff.proj import activate_proj, make_eeg_average_ref_proj
from .baseline import rescale
-
+from .utils import check_random_state
class Epochs(object):
"""List of Epochs
@@ -122,7 +125,7 @@ class Epochs(object):
self._bad_dropped = False
# Handle measurement info
- self.info = copy.deepcopy(raw.info)
+ self.info = cp.deepcopy(raw.info)
if picks is not None:
self.info['chs'] = [self.info['chs'][k] for k in picks]
self.info['ch_names'] = [self.info['ch_names'][k] for k in picks]
@@ -377,7 +380,7 @@ class Epochs(object):
warnings.warn("Bad epochs have not been dropped, indexing will be "
"inaccurate. Use drop_bad_epochs() or preload=True")
- epochs = copy.copy(self) # XXX : should use deepcopy but breaks ...
+ epochs = cp.copy(self) # XXX : should use deepcopy but breaks ...
epochs.events = np.atleast_2d(self.events[key])
if self.preload:
@@ -386,6 +389,8 @@ class Epochs(object):
else:
if isinstance(key, list):
key = np.array(key)
+ print key
+ print np.ndim(key)
if np.ndim(key) == 0:
epochs._data = self._data[key][np.newaxis, :, :]
else:
@@ -407,7 +412,7 @@ class Epochs(object):
The averaged epochs
"""
evoked = Evoked(None)
- evoked.info = copy.deepcopy(self.info)
+ evoked.info = cp.deepcopy(self.info)
n_channels = len(self.ch_names)
n_times = len(self.times)
if self.preload:
@@ -444,7 +449,7 @@ class Epochs(object):
evoked.data = evoked.data[data_picks]
return evoked
- def crop(self, tmin, tmax):
+ def crop(self, tmin=None, tmax=None, copy=False):
"""Crops a time interval from epochs object.
Parameters
@@ -453,7 +458,9 @@ class Epochs(object):
Start time of selection in seconds
tmax : float
End time of selection in seconds
-
+ copy : bool
+ If False epochs is cropped in place
+
Returns
-------
epochs : Epochs instance
@@ -463,19 +470,27 @@ class Epochs(object):
raise RuntimeError('Modifying data of epochs is only supported '
'when preloading is used. Use preload=True '
'in the constructor.')
- if tmin < self.tmin:
+ if tmin is None:
tmin = self.tmin
- if tmax > self.tmax:
+ elif tmin < self.tmin:
+ warnings.warn("tmin is not in epochs' time interval."
+ "tmin is set to epochs.tmin")
+ tmin = self.tmin
+ if tmax is None:
+ tmax = self.tmax
+ elif tmax > self.tmax:
+ warnings.warn("tmax is not in epochs' time interval."
+ "tmax is set to epochs.tmax")
tmax = self.tmax
- sfreq = self.info['sfreq']
- first_samp = int((tmin - self.tmin) * sfreq)
- last_samp = int((tmax - self.tmax) * sfreq) - 1
-
- self.tmin = tmin
- self.tmax = tmax
- self._data = self._data[:, :, first_samp:last_samp]
- return self
+ tmask = (self.times >= tmin) & (self.times <= tmax)
+
+ this_epochs = self if not copy else cp.deepcopy(self)
+ this_epochs.tmin = tmin
+ this_epochs.tmax = tmax
+ this_epochs.times = this_epochs.times[tmask]
+ this_epochs._data = this_epochs._data[:, :, tmask]
+ return this_epochs
def _is_good(e, ch_names, channel_type_idx, reject, flat):
@@ -514,15 +529,15 @@ def _is_good(e, ch_names, channel_type_idx, reject, flat):
return True
-def bootstrap(epochs, rng, return_idx=False):
- """Compute average of epochs selected by bootstrapping
+def bootstrap(epochs, random_state=None):
+ """Compute epochs selected by bootstrapping
Parameters
----------
epochs : Epochs instance
epochs data to be bootstrapped
- rng :
- random number generator.
+ random_state : None | int | np.random.RandomState
+ To specify the random generator state
return_idx : bool
If True the selected indices are provided as an output
@@ -536,11 +551,9 @@ def bootstrap(epochs, rng, return_idx=False):
'when preloading is used. Use preload=True '
'in the constructor.')
- epochs_bootstrap = copy.deepcopy(epochs)
+ rng = check_random_state(random_state)
+ epochs_bootstrap = cp.deepcopy(epochs)
n_events = len(epochs_bootstrap.events)
idx = rng.randint(0, n_events, n_events)
epochs_bootstrap = epochs_bootstrap[idx]
- if return_idx:
- return epochs_bootstrap, idx
- else:
- return epochs_bootstrap
+ return epochs_bootstrap
diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py
index c751db3..a7f745b 100644
--- a/mne/tests/test_epochs.py
+++ b/mne/tests/test_epochs.py
@@ -137,22 +137,23 @@ def test_indexing_slicing():
assert_array_equal(data[0], data_normal[idx])
pos += 1
- # using indexing with int
+ # using indexing with an int
idx = np.random.randint(0, data_epochs2_sliced.shape[0], 1)
data = epochs2[idx].get_data()
assert_array_equal(data, data_normal[idx])
- # using indexing with array
+ # using indexing with an array
idx = np.random.randint(0, data_epochs2_sliced.shape[0], 10)
data = epochs2[idx].get_data()
assert_array_equal(data, data_normal[idx])
- # using indexing with list of indices
- #idx = list()
- #for k in range(3):
- # idx.append(np.random.randint(0, data_epochs2_sliced.shape[0], 1))
- # data = epochs2[idx].get_data()
- # assert_array_equal(data, data_normal[idx])
+ # using indexing with a list of indices
+ idx = [0]
+ data = epochs2[idx].get_data()
+ assert_array_equal(data, data_normal[idx])
+ idx = [0, 1]
+ data = epochs2[idx].get_data()
+ assert_array_equal(data, data_normal[idx])
def test_comparision_with_c():
@@ -175,33 +176,37 @@ def test_comparision_with_c():
def test_crop():
"""Test of crop of epochs
"""
- epochs = Epochs(raw, events[:20], event_id, tmin, tmax, picks=picks,
+ epochs = Epochs(raw, events[:5], event_id, tmin, tmax, picks=picks,
baseline=(None, 0), preload=False,
reject=reject, flat=flat)
- epochs2 = Epochs(raw, events[:20], event_id, tmin, tmax,
+ data_normal = epochs.get_data()
+
+ epochs2 = Epochs(raw, events[:5], event_id, tmin, tmax,
picks=picks, baseline=(None, 0), preload=True,
reject=reject, flat=flat)
- data_normal = epochs.get_data()
# indices for slicing
start_tsamp = tmin + 60 * epochs.info['sfreq']
end_tsamp = tmax - 60 * epochs.info['sfreq']
tmask = (epochs.times >= start_tsamp) & (epochs.times <= end_tsamp)
- assert((start_tsamp) > tmin)
- assert((end_tsamp) < tmax)
+ assert_true(start_tsamp > tmin)
+ assert_true(end_tsamp < tmax)
+ epochs3 = epochs2.crop(start_tsamp, end_tsamp, copy=True)
+ data3 = epochs3.get_data()
epochs2.crop(start_tsamp, end_tsamp)
- data = epochs2.get_data()
- assert_array_equal(data, data_normal[:, :, tmask])
-
+ data2 = epochs2.get_data()
+ assert_array_equal(data2, data_normal[:, :, tmask])
+ assert_array_equal(data3, data_normal[:, :, tmask])
+
def test_bootstrap():
- """Test of crop of epochs
+ """Test of bootstrapping of epochs
"""
- epochs = Epochs(raw, events[:20], event_id, tmin, tmax, picks=picks,
+ epochs = Epochs(raw, events[:5], event_id, tmin, tmax, picks=picks,
baseline=(None, 0), preload=True,
reject=reject, flat=flat)
data_normal = epochs._data
- rng = np.random.RandomState(0)
- epochs2, idx = bootstrap(epochs, rng, return_idx=True)
+ epochs2 = bootstrap(epochs, random_state=0)
n_events = len(epochs.events)
- assert_array_equal(epochs2._data, data_normal[idx])
+ assert_true(len(epochs2.events) == len(epochs.events))
+ assert_true(epochs._data.shape == epochs2._data.shape)
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git
More information about the debian-med-commit mailing list