[pymvpa] RFE + SplitClassifier

Matthias Hampel mhampel at uni-osnabrueck.de
Fri Sep 7 09:03:19 UTC 2012


Hi,

I could not figure out why the SplitClassifier doesn't work with RFE, 
but I found a solution with cross-validation (see code below).
Within the cross-validation, a callback function saves the sensitivity 
map of each iteration. The sensitivity maps could not be mapped back to 
the original dataset directly, so I integrated a small mapping function 
into the callback.

Cheers,
Matthias

     # imports the snippet relies on
     from copy import deepcopy
     import numpy as np
     from mvpa2.suite import *

     # assumed setup, not shown in the original post: `dataset` is already
     # loaded, `sensitivities` collects the maps, and `dsId_list` holds the
     # voxel indices of the full dataset
     sensitivities = []
     dsId_list = dataset.fa.voxel_indices.tolist()

     # callback function
     def store_me(data, node, result):
         # get sensitivity map
         sens = node.measure.mapper.forward(dataset)
         # merge to max values of all maps
         sensmax = sens.get_mapped(maxofabs_sample())
         # make a list of selected voxel indices
         maxid = deepcopy(sensmax.fa.voxel_indices)
         maxid_list = maxid.tolist()
         # initialise sensitivity map with zeros
         sens_mapped = np.zeros((1, dataset.nfeatures))
         ind = 0
         # search for selected voxel
         for i, selvox in enumerate(maxid_list):
             # search for index between last index and end of list
             ind = dsId_list.index(selvox, ind, len(dsId_list))
             # write selected voxel to correct position (as in dataset)
             sens_mapped[0, ind] = abs(sensmax.samples[0, i])
         sensitivities.append(sens_mapped)

     rfesvm_split = SplitClassifier(LinearCSVMC(), OddEvenPartitioner())

     rfe = RFE(rfesvm_split.get_sensitivity_analyzer(
             # take sensitivities per each split, L2 norm, mean, abs them
             postproc=ChainMapper([ FxMapper('features', l2_normed),
                                    FxMapper('samples', np.mean),
                                    FxMapper('samples', np.abs)])),
                   # use the error stored in the confusion matrix of
                   # split classifier
                   ConfusionBasedError(rfesvm_split,
                                       confusion_state='stats'),
                   # we just extract error from confusion, so need to
                   # split dataset
                   Repeater(2),
                   # select 80% of the best on each step
                   fselector=FractionTailSelector(
                       0.80,
                       mode='select', tail='upper'),
                   # and stop whenever error didn't improve for up to 5
                   # steps
                   stopping_criterion=NBackHistoryStopCrit(BestDetector(), 5),
                   # we just extract it from existing confusion
                   train_pmeasure=False,
                   # but we do want to update sensitivities on each step
                   update_sensitivity=True)

     clf = FeatureSelectionClassifier(
             LinearCSVMC(),
             # on features selected via RFE
             rfe,
             # custom description
             descr='LinSVM+RFE(splits_avg)' )

     # (chunks-1) fold cross validation
     # run callback function for each iteration to store sensitivity maps
     cv = CrossValidation(clf, NFoldPartitioner(), callback=store_me)
     error = cv(dataset)

     print 'mean error: %f' % np.mean(error)
     print 'min error: %f' % np.min(error)
     print 'max error: %f' % np.max(error)
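
The mapping step inside the callback can also be sketched without PyMVPA. The following is only an illustration under stated assumptions: the names map_back, all_ids, selected_ids and sens_values are hypothetical, and the toy voxel indices stand in for dataset.fa.voxel_indices. It uses a dictionary lookup instead of the running list search, so it does not rely on the selected voxels appearing in the same order as in the full dataset.

```python
import numpy as np

def map_back(sens_values, selected_ids, all_ids):
    """Scatter values of the selected voxels into a full-length map."""
    # lookup table: voxel index triple -> column in the full dataset
    position = dict((tuple(v), i) for i, v in enumerate(all_ids))
    # start from zeros so eliminated voxels keep a sensitivity of 0
    full = np.zeros((1, len(all_ids)))
    for value, vox in zip(sens_values, selected_ids):
        full[0, position[tuple(vox)]] = abs(value)
    return full

# toy data: five voxels in the full dataset, two of them survive RFE
all_ids = np.array([[8, 29, 7], [8, 31, 11], [9, 23, 9],
                    [54, 35, 9], [54, 36, 9]])
selected_ids = all_ids[[1, 3]]
sens_values = np.array([-0.5, 0.25])
full_map = map_back(sens_values, selected_ids, all_ids)
```

Here full_map has one row and five columns, with the absolute values written at the positions of the second and fourth voxel and zeros elsewhere.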


On 24.07.2012 10:56, Matthias Hampel wrote:
> On 22.07.2012 10:33, Michael Hanke wrote:
>> Hi,
>>
>> On Fri, Jul 20, 2012 at 08:35:25AM +0200, mhampel at uni-osnabrueck.de wrote:
>>> I'm currently working on my master thesis and using the PyMVPA toolbox for
>>> the analysis of my fMRI data. My script for Recursive Feature Elimination
>>> (RFE) is working with a CrossValidation but unfortunately not with a
>>> SplitClassifier. Could you please give me some advice on that?
>>>
>>> In my script (see below) I use the RFE example from the documentation. If
>>> I add a CrossValidation I get an error value for each validation step. But
>>> I'm also interested in the sensitivity maps of each step and I couldn't
>>> figure out if that is possible with CrossValidation. Therefore, I tried to
>>> use a SplitClassifier but I always get the same error message in
>>> self.train(ds).
>>>
>>> Could someone tell me the difference between SplitClassifier and
>>> CrossValidation? I assumed that the SplitClassifier also does a
>>> cross-validation internally. What do I have to change in my code to make
>>> it work?
>> Your assumption is correct and your approach sounds appropriate. Could
>> you please provide some more information on your dataset (``print
>> ds.summary()``) and the actual error message (incl. traceback) that you
>> are observing? At first glance, and without this additional information
>> I can't see an obvious problem.
>>
>> Cheers,
>>
>> Michael
>>
>>
>
> It's good to know that I wasn't completely wrong. Here is the error 
> message and the whole output of ds.summary
>
> Best,
> Matthias
>
> *Error message:*
>
> Traceback (most recent call last):
>   File "RFE_Splitt1.py", line 114, in <module>
>     sens = cv_sensana(dataset)
>   File "/usr/lib/pymodules/python2.6/mvpa2/base/learner.py", line 229, in __call__
>     self.train(ds)
>   File "/usr/lib/pymodules/python2.6/mvpa2/base/learner.py", line 119, in train
>     result = self._train(ds)
>   File "/usr/lib/pymodules/python2.6/mvpa2/measures/base.py", line 782, in _train
>     return clf.train(dataset)
>   File "/usr/lib/pymodules/python2.6/mvpa2/base/learner.py", line 119, in train
>     result = self._train(ds)
>   File "/usr/lib/pymodules/python2.6/mvpa2/clfs/meta.py", line 1211, in _train
>     clf = clf_template.clone()
>   File "/usr/lib/pymodules/python2.6/mvpa2/clfs/base.py", line 326, in clone
>     return deepcopy(self)
>   File "/usr/lib/python2.6/copy.py", line 189, in deepcopy
>     y = _reconstruct(x, rv, 1, memo)
>   File "/usr/lib/python2.6/copy.py", line 338, in _reconstruct
>     state = deepcopy(state, memo)
>   File "/usr/lib/python2.6/copy.py", line 162, in deepcopy
>     y = copier(x, memo)
>   File "/usr/lib/python2.6/copy.py", line 255, in _deepcopy_dict
>     y[deepcopy(key, memo)] = deepcopy(value, memo)
>   File "/usr/lib/python2.6/copy.py", line 189, in deepcopy
>     y = _reconstruct(x, rv, 1, memo)
>   File "/usr/lib/python2.6/copy.py", line 338, in _reconstruct
>     state = deepcopy(state, memo)
>   File "/usr/lib/python2.6/copy.py", line 162, in deepcopy
>     y = copier(x, memo)
>   File "/usr/lib/python2.6/copy.py", line 255, in _deepcopy_dict
>     y[deepcopy(key, memo)] = deepcopy(value, memo)
>   File "/usr/lib/python2.6/copy.py", line 189, in deepcopy
>     y = _reconstruct(x, rv, 1, memo)
>   File "/usr/lib/python2.6/copy.py", line 338, in _reconstruct
>     state = deepcopy(state, memo)
>   File "/usr/lib/python2.6/copy.py", line 162, in deepcopy
>     y = copier(x, memo)
>   File "/usr/lib/python2.6/copy.py", line 255, in _deepcopy_dict
>     y[deepcopy(key, memo)] = deepcopy(value, memo)
>   File "/usr/lib/python2.6/copy.py", line 189, in deepcopy
>     y = _reconstruct(x, rv, 1, memo)
>   File "/usr/lib/python2.6/copy.py", line 338, in _reconstruct
>     state = deepcopy(state, memo)
>   File "/usr/lib/python2.6/copy.py", line 162, in deepcopy
>     y = copier(x, memo)
>   File "/usr/lib/python2.6/copy.py", line 255, in _deepcopy_dict
>     y[deepcopy(key, memo)] = deepcopy(value, memo)
>   File "/usr/lib/python2.6/copy.py", line 162, in deepcopy
>     y = copier(x, memo)
>   File "/usr/lib/python2.6/copy.py", line 228, in _deepcopy_list
>     y.append(deepcopy(a, memo))
>   File "/usr/lib/python2.6/copy.py", line 189, in deepcopy
>     y = _reconstruct(x, rv, 1, memo)
>   File "/usr/lib/python2.6/copy.py", line 338, in _reconstruct
>     state = deepcopy(state, memo)
>   File "/usr/lib/python2.6/copy.py", line 162, in deepcopy
>     y = copier(x, memo)
>   File "/usr/lib/python2.6/copy.py", line 255, in _deepcopy_dict
>     y[deepcopy(key, memo)] = deepcopy(value, memo)
>   File "/usr/lib/python2.6/copy.py", line 189, in deepcopy
>     y = _reconstruct(x, rv, 1, memo)
>   File "/usr/lib/python2.6/copy.py", line 323, in _reconstruct
>     y = callable(*args)
>   File "/usr/lib/python2.6/copy_reg.py", line 93, in __newobj__
>     return cls.__new__(cls, *args)
> TypeError: object.__new__(numpy.ufunc) is not safe, use numpy.ufunc.__new__()
>
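
A note on the traceback above: it ends inside copy.deepcopy, which the SplitClassifier reaches through clone(), while trying to copy a numpy.ufunc. One possible source, and this is an assumption because the failing script is not shown in full, is a ufunc such as np.abs handed to an FxMapper in the postproc chain. A minimal sketch of a workaround under that assumption is to wrap the ufunc in a plain module-level function, which deepcopy treats as atomic. The name abs_samples is hypothetical, and whether this removes the error for a given PyMVPA/NumPy combination is untested.

```python
import copy
import numpy as np

def abs_samples(x):
    # plain Python wrapper around the np.abs ufunc (hypothetical helper)
    return np.abs(x)

# deepcopy hands back plain functions unchanged, so no ufunc object
# has to be reconstructed during the copy
cloned = copy.deepcopy(abs_samples)
values = cloned(np.array([-1.5, 2.0, -3.0]))
```

Such a wrapper could then be passed as FxMapper('samples', abs_samples) in place of np.abs.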
> *ds.summary:*
>
> <bound method Dataset.summary of Dataset(array([[ 1.66507124, 
> 0.98552763,  0.99950778, ..., -0.80142293,
>         -1.55295612, -0.11781561],
>        [-1.2762681 ,  1.06889722, -0.76036066, ..., -0.21888896,
>         -1.43887115, -0.54421524],
>        [ 0.06450247,  0.27004099, -0.85426304, ...,  0.79609439,
>          0.36335411,  0.14502004],
>        ...,
>        [-0.12649097,  0.53453771,  0.45816722, ..., -0.11964775,
>          0.12668331,  0.10950945],
>        [-1.07348862, -0.89344273, -0.7004796 , ...,  0.08824224,
>         -1.11599424,  1.2118099 ],
>        [ 0.75494889,  1.20532665,  1.58585258, ..., -0.63245843,
>          0.89628655, -0.26900842]]), 
> sa=SampleAttributesCollection(items=[ArrayCollectable(name='chunks', 
> doc=None, value=array([ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1., 1.,  
> 1.,  1.,  1.,  1.,
>         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1., 1.,  1.,
>         1.,  1.,  1.,  1.,  2.,  2.,  2.,  2.,  2.,  2.,  2., 2.,  2.,
>         2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.,  2., 2.,  2.,
>         2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.,  3.,  3.,  3., 3.,  3.,
>         3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3., 3.,  3.,
>         3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3., 3.,  4.,
>         4.,  4.,  4.,  4.,  4.,  4.,  4.,  4.,  4.,  4.,  4., 4.,  4.,
>         4.,  4.,  4.,  4.,  4.,  4.,  4.,  4.,  4.,  4.,  4., 4.,  4.,
>         4.,  4.,  4.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5., 5.,  5.,
>         5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5., 5.,  5.,
>         5.,  5.,  5.,  5.,  5.,  5.,  5.,  6.,  6.,  6.,  6., 6.,  6.,
>         6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6., 6.,  6.,
>         6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6., 7.,  7.,
>         7.,  7.,  7.,  7.,  7.,  7.,  7.,  7.,  7.,  7.,  7., 7.,  7.,
>         7.,  7.,  7.,  7.,  7.,  7.,  7.,  7.,  7.,  7.,  7., 7.,  7.,
>         7.,  7.]), length=210), ArrayCollectable(name='time_indices', 
> doc=None, value=array([ 7,   12,   16,   20,   24,   29,   33,   38,   
> 43,   48,   52,
>          57,   62,   66,   70,   75,   80,   85,   89,   94, 98,  103,
>         108,  112,  118,  123,  127,  132,  136,  141,  157, 162,  166,
>         171,  175,  180,  185,  189,  194,  199,  204,  209, 213,  218,
>         223,  228,  233,  238,  243,  248,  252,  257,  261, 266,  270,
>         275,  279,  284,  289,  294,  309,  314,  319,  324, 329,  333,
>         338,  342,  347,  351,  355,  360,  364,  369,  374, 378,  383,
>         387,  392,  397,  401,  406,  411,  415,  420,  425, 429,  433,
>         438,  443,  462,  467,  472,  476,  481,  485,  490, 495,  500,
>         505,  510,  514,  519,  524,  528,  532,  538,  542, 547,  552,
>         556,  561,  566,  571,  575,  580,  585,  589,  595, 599,  612,
>         616,  621,  625,  630,  635,  640,  644,  649,  653, 658,  663,
>         668,  672,  677,  682,  686,  691,  696,  700,  705, 710,  715,
>         720,  725,  729,  734,  738,  742,  746,  759,  764, 769,  774,
>         779,  783,  788,  793,  798,  802,  807,  812,  817, 821,  826,
>         831,  836,  840,  846,  850,  855,  859,  865,  870, 874,  878,
>         883,  888,  892,  897,  908,  913,  918,  922,  927, 931,  936,
>         941,  946,  950,  955,  959,  965,  970,  974,  979, 984,  989,
>         994,  998, 1003, 1008, 1012, 1017, 1022, 1026, 1031, 1036, 1041,
>        1046]), length=210), ArrayCollectable(name='targets', doc=None, 
> value=array(['onsetNP', 'onsetNP', 'onsetNP', 'onsetP', 'onsetP', 
> 'onsetNP',
>        'onsetP', 'onsetNP', 'onsetNP', 'onsetP', 'onsetP', 'onsetP',
>        'onsetP', 'onsetNP', 'onsetP', 'onsetP', 'onsetP', 'onsetNP',
>        'onsetP', 'onsetNP', 'onsetNP', 'onsetNP', 'onsetP', 'onsetNP',
>        'onsetP', 'onsetNP', 'onsetNP', 'onsetP', 'onsetP', 'onsetNP',
>        'onsetP', 'onsetNP', 'onsetNP', 'onsetP', 'onsetNP', 'onsetP',
>        'onsetNP', 'onsetNP', 'onsetP', 'onsetNP', 'onsetNP', 'onsetNP',
>        'onsetP', 'onsetNP', 'onsetP', 'onsetNP', 'onsetNP', 'onsetP',
>        'onsetP', 'onsetNP', 'onsetP', 'onsetP', 'onsetP', 'onsetNP',
>        'onsetP', 'onsetP', 'onsetP', 'onsetNP', 'onsetNP', 'onsetP',
>        'onsetNP', 'onsetP', 'onsetP', 'onsetNP', 'onsetP', 'onsetP',
>        'onsetP', 'onsetNP', 'onsetNP', 'onsetP', 'onsetP', 'onsetNP',
>        'onsetP', 'onsetNP', 'onsetNP', 'onsetNP', 'onsetP', 'onsetP',
>        'onsetP', 'onsetNP', 'onsetNP', 'onsetNP', 'onsetP', 'onsetP',
>        'onsetP', 'onsetNP', 'onsetNP', 'onsetNP', 'onsetP', 'onsetNP',
>        'onsetNP', 'onsetP', 'onsetNP', 'onsetP', 'onsetP', 'onsetNP',
>        'onsetP', 'onsetNP', 'onsetNP', 'onsetP', 'onsetP', 'onsetNP',
>        'onsetP', 'onsetP', 'onsetP', 'onsetNP', 'onsetP', 'onsetNP',
>        'onsetNP', 'onsetNP', 'onsetP', 'onsetNP', 'onsetP', 'onsetP',
>        'onsetNP', 'onsetP', 'onsetNP', 'onsetNP', 'onsetP', 'onsetNP',
>        'onsetNP', 'onsetNP', 'onsetNP', 'onsetNP', 'onsetNP', 'onsetP',
>        'onsetP', 'onsetP', 'onsetP', 'onsetNP', 'onsetP', 'onsetP',
>        'onsetNP', 'onsetP', 'onsetP', 'onsetP', 'onsetP', 'onsetP',
>        'onsetNP', 'onsetNP', 'onsetNP', 'onsetP', 'onsetNP', 'onsetNP',
>        'onsetP', 'onsetNP', 'onsetNP', 'onsetP', 'onsetNP', 'onsetP',
>        'onsetP', 'onsetNP', 'onsetP', 'onsetNP', 'onsetNP', 'onsetP',
>        'onsetNP', 'onsetP', 'onsetNP', 'onsetP', 'onsetNP', 'onsetP',
>        'onsetNP', 'onsetP', 'onsetP', 'onsetNP', 'onsetP', 'onsetNP',
>        'onsetNP', 'onsetP', 'onsetNP', 'onsetNP', 'onsetNP', 'onsetP',
>        'onsetNP', 'onsetP', 'onsetP', 'onsetP', 'onsetNP', 'onsetP',
>        'onsetP', 'onsetNP', 'onsetP', 'onsetP', 'onsetP', 'onsetNP',
>        'onsetP', 'onsetNP', 'onsetP', 'onsetP', 'onsetP', 'onsetP',
>        'onsetP', 'onsetP', 'onsetNP', 'onsetP', 'onsetNP', 'onsetNP',
>        'onsetNP', 'onsetNP', 'onsetNP', 'onsetNP', 'onsetNP', 'onsetNP',
>        'onsetNP', 'onsetP', 'onsetP', 'onsetNP', 'onsetP', 'onsetNP'],
>       dtype='|S7'), length=210), ArrayCollectable(name='time_coords', 
> doc=None, value=array([ 0., 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  
> 0.,  0.,  0.,  0.,
>         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., 0.,  0.,
>         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., 0.,  0.,
>         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., 0.,  0.,
>         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., 0.,  0.,
>         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., 0.,  0.,
>         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., 0.,  0.,
>         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., 0.,  0.,
>         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., 0.,  0.,
>         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., 0.,  0.,
>         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., 0.,  0.,
>         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., 0.,  0.,
>         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., 0.,  0.,
>         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., 0.,  0.,
>         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., 0.,  0.,
>         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., 0.,  0.,
>         0.,  0.]), length=210)]), 
> fa=FeatureAttributesCollection(items=[ArrayCollectable(name='voxel_indices', 
> doc=None, value=array([[ 8, 29,  7],
>        [ 8, 31, 11],
>        [ 9, 23,  9],
>        ...,
>        [54, 35,  9],
>        [54, 35, 10],
>        [54, 36,  9]]), length=39978)]), 
> a=DatasetAttributesCollection(items=[Collectable(name='mapper', 
> doc=None, value=ChainMapper(nodes=[FlattenMapper(shape=(64, 64, 37), 
> auto_train=True, space='voxel_indices'), 
> StaticFeatureSelection(dshape=(151552,), slicearg=array([False, False, 
> False, ..., False, False, False], dtype=bool)), 
> PolyDetrendMapper(polyord=1, chunks_attr='chunks', opt_regs=None, ), 
> ZScoreMapper(param_est=('targets', ['junk']), chunks_attr='chunks', 
> dtype='float32')])), Collectable(name='imgtype', doc=None, 
> value=<class 'nibabel.nifti1.Nifti1Image'>), 
> Collectable(name='voxel_eldim', doc=None, value=(3.0, 3.0, 3.3)), 
> Collectable(name='voxel_dim', doc=None, value=(64, 64, 37)), 
> Collectable(name='imghdr', doc=None, 
> value=<nibabel.nifti1.Nifti1Header object at 0x4c48890>)]))>
