comparison pylearn/dataset_ops/image_patches.py @ 1522:5972fab3cfd2

Make ranzato_hinton_2010_op work with float64, as needed by a DLT test on mcrbm.
author Frederic Bastien <nouiz@nouiz.org>
date Wed, 31 Oct 2012 16:19:51 -0400
parents 6397233f3ccd
children
comparison
equal deleted inserted replaced
1521:6397233f3ccd 1522:5972fab3cfd2
72 """Return the pca of the data, which is 10240 x 105 72 """Return the pca of the data, which is 10240 x 105
73 """ 73 """
74 dct = ranzato_hinton_2010(path) 74 dct = ranzato_hinton_2010(path)
75 return dct['whitendata'].astype('float32') 75 return dct['whitendata'].astype('float32')
76 76
def ranzato_hinton_2010_whitened_patches_f64(path=None):
    """Return the PCA-whitened patch data (documented as 10240 x 105) as float64.

    Float64 variant of `ranzato_hinton_2010_whitened_patches`; the op builder
    swaps this function in when a 'float64' dataset is requested.

    :param path: forwarded unchanged to `ranzato_hinton_2010`.
    :returns: the 'whitendata' entry of the loaded data, cast via
        `.astype('float64')`.
    """
    whitened = ranzato_hinton_2010(path)['whitendata']
    return whitened.astype('float64')
77 82
78 def undo_pca_filters_of_ranzato_hinton_2010(X, path=None): 83 def undo_pca_filters_of_ranzato_hinton_2010(X, path=None):
79 """Return tuple (R,G,B,None) of matrices for matrix `X` of filters (one per row) 84 """Return tuple (R,G,B,None) of matrices for matrix `X` of filters (one per row)
80 85
81 Return value can be passed to `image_tiling.tile_raster_images`. 86 Return value can be passed to `image_tiling.tile_raster_images`.
108 # the data is provided as PCA-sphered, so rasterizing does not make sense 113 # the data is provided as PCA-sphered, so rasterizing does not make sense
109 # TODO: add a param to enable/disable 'PCA', and if disabled, then consider 114 # TODO: add a param to enable/disable 'PCA', and if disabled, then consider
110 # rasterizing or not 115 # rasterizing or not
111 raise NotImplementedError('only pca data is provided') 116 raise NotImplementedError('only pca data is provided')
112 117
113 if dtype != 'float32': 118 if dtype == "float64" and fn is ranzato_hinton_2010_whitened_patches:
119 fn = ranzato_hinton_2010_whitened_patches_f64
120 elif dtype != 'float32':
114 raise NotImplementedError('dtype not float32') 121 raise NotImplementedError('dtype not float32')
115 122
116 op = TensorFnDataset(dtype, 123 op = TensorFnDataset(dtype,
117 bcast=(False,), 124 bcast=(False,),
118 fn=fn, 125 fn=fn,