Mercurial > pylearn
comparison pylearn/dataset_ops/image_patches.py @ 1522:5972fab3cfd2
make ranzato_hinton_2010_op work with float64 for a DLT on mcrbm test.
author | Frederic Bastien <nouiz@nouiz.org> |
---|---|
date | Wed, 31 Oct 2012 16:19:51 -0400 |
parents | 6397233f3ccd |
children |
comparison
equal
deleted
inserted
replaced
1521:6397233f3ccd | 1522:5972fab3cfd2 |
---|---|
72 """Return the pca of the data, which is 10240 x 105 | 72 """Return the pca of the data, which is 10240 x 105 |
73 """ | 73 """ |
74 dct = ranzato_hinton_2010(path) | 74 dct = ranzato_hinton_2010(path) |
75 return dct['whitendata'].astype('float32') | 75 return dct['whitendata'].astype('float32') |
76 | 76 |
| | 77 def ranzato_hinton_2010_whitened_patches_f64(path=None): |
| | 78 """Return the pca of the data, which is 10240 x 105 |
| | 79 """ |
| | 80 dct = ranzato_hinton_2010(path) |
| | 81 return dct['whitendata'].astype('float64') |
77 | 82 |
78 def undo_pca_filters_of_ranzato_hinton_2010(X, path=None): | 83 def undo_pca_filters_of_ranzato_hinton_2010(X, path=None): |
79 """Return tuple (R,G,B,None) of matrices for matrix `X` of filters (one per row) | 84 """Return tuple (R,G,B,None) of matrices for matrix `X` of filters (one per row) |
80 | 85 |
81 Return value can be passed to `image_tiling.tile_raster_images`. | 86 Return value can be passed to `image_tiling.tile_raster_images`. |
108 # the data is provided as PCA-sphered, so rasterizing does not make sense | 113 # the data is provided as PCA-sphered, so rasterizing does not make sense |
109 # TODO: add a param to enable/disable 'PCA', and if disabled, then consider | 114 # TODO: add a param to enable/disable 'PCA', and if disabled, then consider |
110 # rasterizing or not | 115 # rasterizing or not |
111 raise NotImplementedError('only pca data is provided') | 116 raise NotImplementedError('only pca data is provided') |
112 | 117 |
113 if dtype != 'float32': | 118 if dtype == "float64" and fn is ranzato_hinton_2010_whitened_patches: |
| | 119 fn = ranzato_hinton_2010_whitened_patches_f64 |
| | 120 elif dtype != 'float32': |
114 raise NotImplementedError('dtype not float32') | 121 raise NotImplementedError('dtype not float32') |
115 | 122 |
116 op = TensorFnDataset(dtype, | 123 op = TensorFnDataset(dtype, |
117 bcast=(False,), | 124 bcast=(False,), |
118 fn=fn, | 125 fn=fn, |