[med-svn] [python-mne] 295/376: ENH : can pass seed in cluster level stats
Yaroslav Halchenko
debian at onerussian.com
Fri Nov 27 17:23:08 UTC 2015
This is an automated email from the git hooks/post-receive script.
yoh pushed a commit to annotated tag v0.1
in repository python-mne.
commit 7b19fcacd941306f6842766b0a5d22d38ea8c792
Author: Alexandre Gramfort <alexandre.gramfort at inria.fr>
Date: Wed Jun 8 14:00:50 2011 -0400
ENH : can pass seed in cluster level stats
---
mne/stats/cluster_level.py | 34 ++++++++++++++++++++++++----------
1 file changed, 24 insertions(+), 10 deletions(-)
diff --git a/mne/stats/cluster_level.py b/mne/stats/cluster_level.py
index eea5e10..88faea7 100644
--- a/mne/stats/cluster_level.py
+++ b/mne/stats/cluster_level.py
@@ -124,8 +124,9 @@ def _pval_from_histogram(T, H0, tail):
return pval
-def _one_permutation(X_full, slices, stat_fun, tail, threshold, connectivity):
- np.random.shuffle(X_full)
+def _one_permutation(X_full, slices, stat_fun, tail, threshold, connectivity,
+ rng):
+ rng.shuffle(X_full)
X_shuffle_list = [X_full[s] for s in slices]
T_obs_surr = stat_fun(*X_shuffle_list)
_, perm_clusters_sums = _find_clusters(T_obs_surr, threshold, tail,
@@ -140,7 +141,7 @@ def _one_permutation(X_full, slices, stat_fun, tail, threshold, connectivity):
def permutation_cluster_test(X, stat_fun=f_oneway, threshold=1.67,
n_permutations=1000, tail=0,
connectivity=None, n_jobs=1,
- verbose=5):
+ verbose=5, seed=None):
"""Cluster-level statistical permutation test
For a list of 2d-arrays of data, e.g. power values, calculate some
@@ -173,6 +174,8 @@ def permutation_cluster_test(X, stat_fun=f_oneway, threshold=1.67,
If > 0, print some text during computation.
n_jobs : int
Number of permutations to run in parallel (requires joblib package.)
+ seed : int or None
+ Seed the random number generator for results reproducibility.
Returns
-------
@@ -218,9 +221,13 @@ def permutation_cluster_test(X, stat_fun=f_oneway, threshold=1.67,
# Step 2: If we have some clusters, repeat process on permuted data
# -------------------------------------------------------------------
if len(clusters) > 0:
+ if seed is None:
+ seeds = [None] * n_permutations
+ else:
+ seeds = seed + np.arange(n_permutations)
H0 = parallel(my_one_permutation(X_full, slices, stat_fun, tail,
- threshold, connectivity)
- for _ in range(n_permutations))
+ threshold, connectivity, np.random.RandomState(s))
+ for s in seeds)
H0 = np.array(H0)
cluster_pv = _pval_from_histogram(cluster_stats, H0, tail)
return T_obs, clusters, cluster_pv, H0
@@ -239,9 +246,9 @@ def ttest_1samp(X):
def _one_1samp_permutation(n_samples, shape_ones, X_copy, threshold, tail,
- connectivity, stat_fun):
+ connectivity, stat_fun, rng):
# new surrogate data with random sign flip
- signs = np.sign(0.5 - np.random.rand(n_samples, *shape_ones))
+ signs = np.sign(0.5 - rng.rand(n_samples, *shape_ones))
X_copy *= signs
# Recompute statistic on randomized data
@@ -259,7 +266,7 @@ def _one_1samp_permutation(n_samples, shape_ones, X_copy, threshold, tail,
def permutation_cluster_1samp_test(X, threshold=1.67, n_permutations=1000,
tail=0, stat_fun=ttest_1samp,
connectivity=None, n_jobs=1,
- verbose=5):
+ verbose=5, seed=None):
"""Non-parametric cluster-level 1 sample T-test
From a array of observations, e.g. signal amplitudes or power spectrum
@@ -291,6 +298,8 @@ def permutation_cluster_1samp_test(X, threshold=1.67, n_permutations=1000,
If > 0, print some text during computation.
n_jobs : int
Number of permutations to run in parallel (requires joblib package.)
+ seed : int or None
+ Seed the random number generator for results reproducibility.
Returns
@@ -333,9 +342,14 @@ def permutation_cluster_1samp_test(X, threshold=1.67, n_permutations=1000,
# Step 2: If we have some clusters, repeat process on permuted data
# -------------------------------------------------------------------
if len(clusters) > 0:
+ if seed is None:
+ seeds = [None] * n_permutations
+ else:
+ seeds = seed + np.arange(n_permutations)
H0 = parallel(my_one_1samp_permutation(n_samples, shape_ones, X_copy,
- threshold, tail, connectivity, stat_fun)
- for _ in range(n_permutations))
+ threshold, tail, connectivity, stat_fun,
+ np.random.RandomState(s))
+ for s in seeds)
H0 = np.array(H0)
cluster_pv = _pval_from_histogram(cluster_stats, H0, tail)
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git
More information about the debian-med-commit mailing list