[med-svn] [python-mne] 294/376: factoring joblib parallel code
Yaroslav Halchenko
debian at onerussian.com
Fri Nov 27 17:23:08 UTC 2015
This is an automated email from the git hooks/post-receive script.
yoh pushed a commit to annotated tag v0.1
in repository python-mne.
commit bda9bd53e3382d342c66ad5d6548e98e9cd82515
Author: Alexandre Gramfort <alexandre.gramfort at inria.fr>
Date: Wed Jun 8 12:16:05 2011 -0400
factoring joblib parallel code
---
examples/stats/plot_cluster_stats_evoked.py | 3 +-
mne/minimum_norm/time_frequency.py | 19 +-----
mne/parallel.py | 50 +++++++++++++++
mne/stats/cluster_level.py | 95 ++++++++++++++++++-----------
mne/stats/permutations.py | 20 +-----
mne/time_frequency/tfr.py | 27 +++-----
6 files changed, 123 insertions(+), 91 deletions(-)
diff --git a/examples/stats/plot_cluster_stats_evoked.py b/examples/stats/plot_cluster_stats_evoked.py
index 34a4ad1..63d344c 100644
--- a/examples/stats/plot_cluster_stats_evoked.py
+++ b/examples/stats/plot_cluster_stats_evoked.py
@@ -58,7 +58,8 @@ condition2 = condition2[:, 0, :] # take only one channel to get a 2D array
threshold = 6.0
T_obs, clusters, cluster_p_values, H0 = \
permutation_cluster_test([condition1, condition2],
- n_permutations=1000, threshold=threshold, tail=1)
+ n_permutations=1000, threshold=threshold, tail=1,
+ n_jobs=2)
###############################################################################
# Plot
diff --git a/mne/minimum_norm/time_frequency.py b/mne/minimum_norm/time_frequency.py
index 1262669..b36adce 100644
--- a/mne/minimum_norm/time_frequency.py
+++ b/mne/minimum_norm/time_frequency.py
@@ -10,6 +10,7 @@ from ..source_estimate import SourceEstimate
from ..time_frequency.tfr import cwt, morlet
from ..baseline import rescale
from .inverse import combine_xyz, prepare_inverse_operator
+from ..parallel import parallel_func
def _compute_power(data, K, sel, Ws, source_ori, use_fft, Vh):
@@ -89,23 +90,7 @@ def source_induced_power(epochs, inverse_operator, bands, lambda2=1.0 / 9.0,
Number of jobs to run in parallel
"""
- if n_jobs == -1:
- try:
- import multiprocessing
- n_jobs = multiprocessing.cpu_count()
- except ImportError:
- print "multiprocessing not installed. Cannot run in parallel."
- n_jobs = 1
-
- try:
- from scikits.learn.externals.joblib import Parallel, delayed
- parallel = Parallel(n_jobs)
- my_compute_power = delayed(_compute_power)
- except ImportError:
- print "joblib not installed. Cannot run in parallel."
- n_jobs = 1
- my_compute_power = _compute_power
- parallel = list
+ parallel, my_compute_power, n_jobs = parallel_func(_compute_power, n_jobs)
#
# Set up the inverse according to the parameters
diff --git a/mne/parallel.py b/mne/parallel.py
new file mode 100644
index 0000000..12e4164
--- /dev/null
+++ b/mne/parallel.py
@@ -0,0 +1,50 @@
+"""Parallel util function
+"""
+
+# Author: Alexandre Gramfort <gramfort at nmr.mgh.harvard.edu>
+#
+# License: Simplified BSD
+
+
+def parallel_func(func, n_jobs, verbose=5):
+ """Return parallel instance with delayed function
+
+ Util function to use joblib only if available
+
+ Parameters
+ ----------
+ func: callable
+ A function
+ n_jobs: int
+ Number of jobs to run in parallel
+ verbose: int
+ Verbosity level
+
+ Returns
+ -------
+ parallel: instance of joblib.Parallel or list
+ The parallel object
+ my_func: callable
+ func if not parallel or delayed(func)
+ n_jobs: int
+ Number of jobs >= 0
+ """
+ try:
+ from scikits.learn.externals.joblib import Parallel, delayed
+ parallel = Parallel(n_jobs, verbose=verbose)
+ my_func = delayed(func)
+
+ if n_jobs == -1:
+ try:
+ import multiprocessing
+ n_jobs = multiprocessing.cpu_count()
+ except ImportError:
+ print "multiprocessing not installed. Cannot run in parallel."
+ n_jobs = 1
+
+ except ImportError:
+ print "joblib not installed. Cannot run in parallel."
+ n_jobs = 1
+ my_func = func
+ parallel = list
+ return parallel, my_func, n_jobs
diff --git a/mne/stats/cluster_level.py b/mne/stats/cluster_level.py
index c1bda6c..eea5e10 100644
--- a/mne/stats/cluster_level.py
+++ b/mne/stats/cluster_level.py
@@ -10,6 +10,7 @@ import numpy as np
from scipy import stats, sparse, ndimage
from .parametric import f_oneway
+from ..parallel import parallel_func
def _get_components(x_in, connectivity):
@@ -123,9 +124,23 @@ def _pval_from_histogram(T, H0, tail):
return pval
+def _one_permutation(X_full, slices, stat_fun, tail, threshold, connectivity):
+ np.random.shuffle(X_full)
+ X_shuffle_list = [X_full[s] for s in slices]
+ T_obs_surr = stat_fun(*X_shuffle_list)
+ _, perm_clusters_sums = _find_clusters(T_obs_surr, threshold, tail,
+ connectivity)
+
+ if len(perm_clusters_sums) > 0:
+ return np.max(perm_clusters_sums)
+ else:
+ return 0
+
+
def permutation_cluster_test(X, stat_fun=f_oneway, threshold=1.67,
n_permutations=1000, tail=0,
- connectivity=None, verbose=True):
+ connectivity=None, n_jobs=1,
+ verbose=5):
"""Cluster-level statistical permutation test
For a list of 2d-arrays of data, e.g. power values, calculate some
@@ -154,8 +169,10 @@ def permutation_cluster_test(X, stat_fun=f_oneway, threshold=1.67,
Defines connectivity between features. The matrix is assumed to
be symmetric and only the upper triangular half is used.
Defaut is None, i.e, no connectivity.
- verbose: boolean
- If True print some text.
+ verbose : int
+ If > 0, print some text during computation.
+ n_jobs : int
+ Number of permutations to run in parallel (requires joblib package.)
Returns
-------
@@ -195,24 +212,16 @@ def permutation_cluster_test(X, stat_fun=f_oneway, threshold=1.67,
slices = [slice(splits_idx[k], splits_idx[k + 1])
for k in range(len(X))]
+ parallel, my_one_permutation, _ = parallel_func(_one_permutation, n_jobs,
+ verbose)
+
# Step 2: If we have some clusters, repeat process on permuted data
# -------------------------------------------------------------------
if len(clusters) > 0:
- H0 = np.zeros(n_permutations) # histogram
- for i_s in range(n_permutations):
- if verbose:
- print "Permutation %d / %d" % (i_s + 1, n_permutations)
- np.random.shuffle(X_full)
- X_shuffle_list = [X_full[s] for s in slices]
- T_obs_surr = stat_fun(*X_shuffle_list)
- _, perm_clusters_sums = _find_clusters(T_obs_surr, threshold, tail,
- connectivity)
-
- if len(perm_clusters_sums) > 0:
- H0[i_s] = np.max(perm_clusters_sums)
- else:
- H0[i_s] = 0
-
+ H0 = parallel(my_one_permutation(X_full, slices, stat_fun, tail,
+ threshold, connectivity)
+ for _ in range(n_permutations))
+ H0 = np.array(H0)
cluster_pv = _pval_from_histogram(cluster_stats, H0, tail)
return T_obs, clusters, cluster_pv, H0
else:
@@ -229,9 +238,28 @@ def ttest_1samp(X):
return T
+def _one_1samp_permutation(n_samples, shape_ones, X_copy, threshold, tail,
+ connectivity, stat_fun):
+ # new surrogate data with random sign flip
+ signs = np.sign(0.5 - np.random.rand(n_samples, *shape_ones))
+ X_copy *= signs
+
+ # Recompute statistic on randomized data
+ T_obs_surr = stat_fun(X_copy)
+ _, perm_clusters_sums = _find_clusters(T_obs_surr, threshold, tail,
+ connectivity)
+
+ if len(perm_clusters_sums) > 0:
+ idx_max = np.argmax(np.abs(perm_clusters_sums))
+ return perm_clusters_sums[idx_max] # get max with sign info
+ else:
+ return 0.0
+
+
def permutation_cluster_1samp_test(X, threshold=1.67, n_permutations=1000,
tail=0, stat_fun=ttest_1samp,
- connectivity=None):
+ connectivity=None, n_jobs=1,
+ verbose=5):
"""Non-parametric cluster-level 1 sample T-test
From a array of observations, e.g. signal amplitudes or power spectrum
@@ -259,6 +287,11 @@ def permutation_cluster_1samp_test(X, threshold=1.67, n_permutations=1000,
Defines connectivity between features. The matrix is assumed to
be symmetric and only the upper triangular half is used.
Defaut is None, i.e, no connectivity.
+ verbose : int
+ If > 0, print some text during computation.
+ n_jobs : int
+ Number of permutations to run in parallel (requires joblib package.)
+
Returns
-------
@@ -294,26 +327,16 @@ def permutation_cluster_1samp_test(X, threshold=1.67, n_permutations=1000,
clusters, cluster_stats = _find_clusters(T_obs, threshold, tail,
connectivity)
+ parallel, my_one_1samp_permutation, _ = parallel_func(_one_1samp_permutation,
+ n_jobs, verbose)
+
# Step 2: If we have some clusters, repeat process on permuted data
# -------------------------------------------------------------------
if len(clusters) > 0:
- H0 = np.empty(n_permutations) # histogram
- for i_s in range(n_permutations):
- # new surrogate data with random sign flip
- signs = np.sign(0.5 - np.random.rand(n_samples, *shape_ones))
- X_copy *= signs
-
- # Recompute statistic on randomized data
- T_obs_surr = stat_fun(X_copy)
- _, perm_clusters_sums = _find_clusters(T_obs_surr, threshold, tail,
- connectivity)
-
- if len(perm_clusters_sums) > 0:
- idx_max = np.argmax(np.abs(perm_clusters_sums))
- H0[i_s] = perm_clusters_sums[idx_max] # get max with sign info
- else:
- H0[i_s] = 0
-
+ H0 = parallel(my_one_1samp_permutation(n_samples, shape_ones, X_copy,
+ threshold, tail, connectivity, stat_fun)
+ for _ in range(n_permutations))
+ H0 = np.array(H0)
cluster_pv = _pval_from_histogram(cluster_stats, H0, tail)
return T_obs, clusters, cluster_pv, H0
diff --git a/mne/stats/permutations.py b/mne/stats/permutations.py
index 59c9e74..39e141c 100644
--- a/mne/stats/permutations.py
+++ b/mne/stats/permutations.py
@@ -9,6 +9,8 @@
from math import sqrt
import numpy as np
+from ..parallel import parallel_func
+
def bin_perm_rep(ndim, a=0, b=1):
"""bin_perm_rep(ndim) -> ndim permutations with repetitions of (a,b).
@@ -128,23 +130,7 @@ def permutation_t_test(X, n_permutations=10000, tail=0, n_jobs=1):
else:
perms = np.sign(0.5 - np.random.rand(n_permutations, n_samples))
- try:
- from scikits.learn.externals.joblib import Parallel, delayed
- parallel = Parallel(n_jobs)
- my_max_stat = delayed(_max_stat)
- except ImportError:
- print "joblib not installed. Cannot run in parallel."
- n_jobs = 1
- my_max_stat = _max_stat
- parallel = list
-
- if n_jobs == -1:
- try:
- import multiprocessing
- n_jobs = multiprocessing.cpu_count()
- except ImportError:
- print "multiprocessing not installed. Cannot run in parallel."
- n_jobs = 1
+ parallel, my_max_stat, n_jobs = parallel_func(_max_stat, n_jobs)
max_abs = np.concatenate(parallel(my_max_stat(X, X2, p, dof_scaling)
for p in np.array_split(perms, n_jobs)))
diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py
index ff97ed4..6af02ec 100644
--- a/mne/time_frequency/tfr.py
+++ b/mne/time_frequency/tfr.py
@@ -12,6 +12,7 @@ import numpy as np
from scipy import linalg
from scipy.fftpack import fftn, ifftn
from ..baseline import rescale
+from ..parallel import parallel_func
def morlet(Fs, freqs, n_cycles=7, sigma=None):
@@ -276,15 +277,7 @@ def single_trial_power(epochs, Fs, frequencies, use_fft=True, n_cycles=7,
# Precompute wavelets for given frequency range to save time
Ws = morlet(Fs, frequencies, n_cycles=n_cycles)
- try:
- from scikits.learn.externals.joblib import Parallel, delayed
- parallel = Parallel(n_jobs)
- my_cwt = delayed(cwt)
- except ImportError:
- print "joblib not installed. Cannot run in parallel."
- n_jobs = 1
- my_cwt = cwt
- parallel = list
+ parallel, my_cwt, _ = parallel_func(cwt, n_jobs)
print "Computing time-frequency power on single epochs..."
@@ -347,13 +340,9 @@ def induced_power(epochs, Fs, frequencies, use_fft=True, n_cycles=7,
# Precompute wavelets for given frequency range to save time
Ws = morlet(Fs, frequencies, n_cycles=n_cycles)
- try:
- import joblib
- except ImportError:
- print "joblib not installed. Cannot run in parallel."
- n_jobs = 1
+ parallel, my_time_frequency, _ = parallel_func(_time_frequency, n_jobs)
- if n_jobs == 1:
+ if my_time_frequency is _time_frequency: # not parallel
psd = np.empty((n_channels, n_frequencies, n_times))
plf = np.empty((n_channels, n_frequencies, n_times), dtype=np.complex)
@@ -362,11 +351,9 @@ def induced_power(epochs, Fs, frequencies, use_fft=True, n_cycles=7,
psd[c], plf[c] = _time_frequency(X, Ws, use_fft)
else:
- from joblib import Parallel, delayed
- psd_plf = Parallel(n_jobs=n_jobs)(
- delayed(_time_frequency)(
- np.squeeze(epochs[:, c, :]), Ws, use_fft)
- for c in range(n_channels))
+ psd_plf = parallel(my_time_frequency(np.squeeze(epochs[:, c, :]),
+ Ws, use_fft)
+ for c in range(n_channels))
psd = np.zeros((n_channels, n_frequencies, n_times))
plf = np.zeros((n_channels, n_frequencies, n_times), dtype=np.complex)
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git
More information about the debian-med-commit
mailing list