Mercurial > pylearn
comparison pylearn/dataset_ops/image_patches.py @ 998:8ba8b08e0442
added the image_patches dataset used in RanzatoHinton2010
modified mcRBM to use it.
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Tue, 24 Aug 2010 16:51:53 -0400 |
parents | 507159eea97e |
children | 976539956475 |
comparison
equal
deleted
inserted
replaced
997:71b0132b694a | 998:8ba8b08e0442 |
---|---|
1 import os, numpy | 1 import os, numpy |
2 import theano | 2 import theano |
3 | 3 |
4 from pylearn.datasets.image_patches import ( | 4 from pylearn.datasets.image_patches import ( |
5 data_root, | |
5 olshausen_field_1996_whitened_images, | 6 olshausen_field_1996_whitened_images, |
6 extract_random_patches) | 7 extract_random_patches) |
7 | 8 |
8 from .protocol import TensorFnDataset # protocol.py __init__.py | 9 from .protocol import TensorFnDataset # protocol.py __init__.py |
9 from .memo import memo | 10 from .memo import memo |
11 | |
12 import scipy.io | |
13 from pylearn.io import image_tiling | |
10 | 14 |
11 @memo | 15 @memo |
12 def get_dataset(N,R,C,dtype,center,unitvar): | 16 def get_dataset(N,R,C,dtype,center,unitvar): |
13 seed=98234 | 17 seed=98234 |
14 rng = numpy.random.RandomState(seed) | 18 rng = numpy.random.RandomState(seed) |
46 else: | 50 else: |
47 assert False, 'what happened?' | 51 assert False, 'what happened?' |
48 | 52 |
49 return x | 53 return x |
50 | 54 |
55 | |
56 | |
@memo
def ranzato_hinton_2010(path=None):
    """Load the mcRBM demo colour-patch ``.mat`` file.

    :param path: location of the ``.mat`` file.  When None, the file
        ``image_patches/mcRBM/training_colorpatches_16x16_demo.mat`` under
        ``data_root()`` is used.
    :returns: whatever ``scipy.io.loadmat`` returns for that file.

    The function is wrapped in ``@memo``; presumably that caches the result
    per argument so the file is only read once -- confirm in ``.memo``.
    """
    mat_path = path
    if mat_path is None:
        # Fall back to the copy shipped under the pylearn data root.
        mat_path = os.path.join(
            data_root(),
            'image_patches',
            'mcRBM',
            'training_colorpatches_16x16_demo.mat')
    return scipy.io.loadmat(mat_path)
def ranzato_hinton_2010_whitened_patches(path=None):
    """Return the ``'whitendata'`` entry of the mcRBM demo file as float32.

    According to the original note this is the PCA of the data, a
    10240 x 105 matrix.

    :param path: forwarded unchanged to ``ranzato_hinton_2010``.
    """
    contents = ranzato_hinton_2010(path)
    whitened = contents['whitendata']
    return whitened.astype('float32')
69 | |
def undo_pca_filters_of_ranzato_hinton_2010(X, path=None):
    """Map PCA-space filters back to pixels and split them per channel.

    :param X: matrix of filters, one filter per row.
    :param path: forwarded unchanged to ``ranzato_hinton_2010``.
    :returns: tuple ``(R, G, B, None)`` of matrices, suitable for passing to
        ``image_tiling.tile_raster_images``.
    """
    inverse_pca = ranzato_hinton_2010(path)['invpcatransf']
    pixels = numpy.dot(X, inverse_pca.T)

    # Columns are taken in three consecutive groups: the first 256, the next
    # 256, and everything from column 512 onward.
    red = pixels[:, :256]
    green = pixels[:, 256:512]
    blue = pixels[:, 512:]
    return (red, green, blue, None)
78 | |
def save_filters_of_ranzato_hinton_2010(X, fname, min_dynamic_range=1e-3, data_path=None):
    """Tile the filters in `X` into one image and save it as `fname`.

    :param X: matrix of filters, one per row, as accepted by
        ``undo_pca_filters_of_ranzato_hinton_2010``.
    :param fname: destination handed to
        ``image_tiling.save_tiled_raster_images``.
    :param min_dynamic_range: forwarded to
        ``image_tiling.tile_raster_images``.
    :param data_path: optional location of the ``.mat`` file, forwarded as
        ``path`` to ``undo_pca_filters_of_ranzato_hinton_2010``.
    """
    channels = undo_pca_filters_of_ranzato_hinton_2010(X, path=data_path)
    tiled = image_tiling.tile_raster_images(
        channels,
        img_shape=(16, 16),
        min_dynamic_range=min_dynamic_range)
    image_tiling.save_tiled_raster_images(tiled, fname)
85 | |
def ranzato_hinton_2010_op(s_idx,
        split='train',
        dtype=theano.config.floatX, rasterized=True,
        center=True,
        unitvar=True):
    """Build the dataset expression for the whitened mcRBM demo patches.

    :param s_idx: index expression; it is reduced modulo 10240 before being
        handed to the dataset op.
    :param split: only ``'train'`` is accepted.
    :param dtype: only ``'float32'`` is accepted.
    :param rasterized: must be true; the data is provided PCA-sphered, so a
        non-rasterized form is not available.
    :param center: accepted but not used by this function.
    :param unitvar: accepted but not used by this function.
    :returns: the result of applying a ``TensorFnDataset`` (backed by
        ``ranzato_hinton_2010_whitened_patches``, single example shape
        ``(105,)``) to ``s_idx % 10240``.
    :raises NotImplementedError: for any unsupported `split`, `rasterized`
        or `dtype` value.
    """
    if split != 'train':
        raise NotImplementedError('train/test/valid splits for randomly sampled image patches?')

    if not rasterized:
        # The data is provided as PCA-sphered, so rasterizing does not make
        # sense.
        # TODO: add a param to enable/disable 'PCA', and if disabled, then
        # consider rasterizing or not.
        raise NotImplementedError('only pca data is provided')

    if dtype != 'float32':
        raise NotImplementedError('dtype not float32')

    n_examples = 10240
    dataset_op = TensorFnDataset(
        dtype,
        bcast=(False,),
        fn=ranzato_hinton_2010_whitened_patches,
        single_shape=(105,))
    return dataset_op(s_idx % n_examples)
111 |