Mercurial > pylearn
comparison pylearn/dataset_ops/image_patches.py @ 1521:6397233f3ccd
autopep8
author | Frederic Bastien <nouiz@nouiz.org> |
---|---|
date | Wed, 31 Oct 2012 16:12:57 -0400 |
parents | 9ffe5d6faee3 |
children | 5972fab3cfd2 |
comparison
equal
deleted
inserted
replaced
1520:61134776e33c | 1521:6397233f3ccd |
---|---|
1 import os, numpy | 1 import os |
| 2 import numpy |
2 import theano | 3 import theano |
3 | 4 |
4 from pylearn.datasets.image_patches import ( | 5 from pylearn.datasets.image_patches import ( |
5 olshausen_field_1996_whitened_images, | 6 olshausen_field_1996_whitened_images, |
6 extract_random_patches) | 7 extract_random_patches) |
7 | 8 |
8 from .protocol import TensorFnDataset # protocol.py __init__.py | 9 from .protocol import TensorFnDataset # protocol.py __init__.py |
9 from .memo import memo | 10 from .memo import memo |
10 | 11 |
11 import scipy.io | 12 import scipy.io |
12 from pylearn.io import image_tiling | 13 from pylearn.io import image_tiling |
13 from pylearn.datasets.config import get_filepath_in_roots | 14 from pylearn.datasets.config import get_filepath_in_roots |
14 | 15 |
| 16 |
15 @memo | 17 @memo |
16 def get_dataset(N,R,C,dtype,center,unitvar): | 18 def get_dataset(N, R, C, dtype, center, unitvar): |
17 seed=98234 | 19 seed = 98234 |
18 rng = numpy.random.RandomState(seed) | 20 rng = numpy.random.RandomState(seed) |
19 img_stack = olshausen_field_1996_whitened_images() | 21 img_stack = olshausen_field_1996_whitened_images() |
20 patch_stack = extract_random_patches(img_stack, N,R,C,rng) | 22 patch_stack = extract_random_patches(img_stack, N, R, C, rng) |
21 rval = patch_stack.astype(dtype).reshape((N,(R*C))) | 23 rval = patch_stack.astype(dtype).reshape((N, (R * C))) |
22 | 24 |
23 if center: | 25 if center: |
24 rval -= rval.mean(axis=0) | 26 rval -= rval.mean(axis=0) |
25 if unitvar: | 27 if unitvar: |
26 rval /= numpy.max(rval.std(axis=0),1e-8) | 28 rval /= numpy.max(rval.std(axis=0), 1e-8) |
27 | 29 |
28 return rval | 30 return rval |
| 31 |
29 | 32 |
30 def image_patches(s_idx, dims, | 33 def image_patches(s_idx, dims, |
31 split='train', dtype=theano.config.floatX, rasterized=False, | 34 split='train', dtype=theano.config.floatX, rasterized=False, |
32 center=True, | 35 center=True, |
33 unitvar=True, | 36 unitvar=True, |
34 fn=get_dataset): | 37 fn=get_dataset): |
35 N,R,C=dims | 38 N, R, C = dims |
36 | 39 |
37 if split != 'train': | 40 if split != 'train': |
38 raise NotImplementedError('train/test/valid splits for randomly sampled image patches?') | 41 raise NotImplementedError( |
| 42 'train/test/valid splits for randomly sampled image patches?') |
39 | 43 |
40 if not rasterized: | 44 if not rasterized: |
41 raise NotImplementedError() | 45 raise NotImplementedError() |
42 | 46 |
43 op = TensorFnDataset(dtype, bcast=(False,), fn=(fn, (N,R,C,dtype,center,unitvar)), single_shape=(R*C,)) | 47 op = TensorFnDataset(dtype, bcast=(False, ), fn=(fn, (N, R, C, dtype, |
| 48 center, unitvar)), single_shape=(R * C, )) |
44 x = op(s_idx%N) | 49 x = op(s_idx % N) |
45 if x.ndim == 1: | 50 if x.ndim == 1: |
46 if not rasterized: | 51 if not rasterized: |
47 x = x.reshape((20,20)) | 52 x = x.reshape((20, 20)) |
48 elif x.ndim == 2: | 53 elif x.ndim == 2: |
49 if not rasterized: | 54 if not rasterized: |
50 x = x.reshape((x.shape[0], 20,20)) | 55 x = x.reshape((x.shape[0], 20, 20)) |
51 else: | 56 else: |
52 assert False, 'what happened?' | 57 assert False, 'what happened?' |
53 | 58 |
54 return x | 59 return x |
55 | |
56 | 60 |
57 | 61 |
58 @memo | 62 @memo |
59 def ranzato_hinton_2010(path=None): | 63 def ranzato_hinton_2010(path=None): |
60 if path is None: | 64 if path is None: |
61 path = get_filepath_in_roots(os.path.join('image_patches', 'mcRBM', | 65 path = get_filepath_in_roots(os.path.join('image_patches', 'mcRBM', |
62 'training_colorpatches_16x16_demo.mat')) | 66 'training_colorpatches_16x16_demo.mat')) |
63 dct = scipy.io.loadmat(path) | 67 dct = scipy.io.loadmat(path) |
64 return dct | 68 return dct |
| 69 |
| 70 |
65 def ranzato_hinton_2010_whitened_patches(path=None): | 71 def ranzato_hinton_2010_whitened_patches(path=None): |
66 """Return the pca of the data, which is 10240 x 105 | 72 """Return the pca of the data, which is 10240 x 105 |
67 """ | 73 """ |
68 dct = ranzato_hinton_2010(path) | 74 dct = ranzato_hinton_2010(path) |
69 return dct['whitendata'].astype('float32') | 75 return dct['whitendata'].astype('float32') |
| 76 |
70 | 77 |
71 def undo_pca_filters_of_ranzato_hinton_2010(X, path=None): | 78 def undo_pca_filters_of_ranzato_hinton_2010(X, path=None): |
72 """Return tuple (R,G,B,None) of matrices for matrix `X` of filters (one per row) | 79 """Return tuple (R,G,B,None) of matrices for matrix `X` of filters (one per row) |
73 | 80 |
74 Return value can be passed to `image_tiling.tile_raster_images`. | 81 Return value can be passed to `image_tiling.tile_raster_images`. |
75 """ | 82 """ |
76 dct = ranzato_hinton_2010(path) | 83 dct = ranzato_hinton_2010(path) |
77 X = numpy.dot(X, dct['invpcatransf'].T) | 84 X = numpy.dot(X, dct['invpcatransf'].T) |
78 return (X[:,:256], X[:,256:512], X[:,512:], None) | 85 return (X[:, :256], X[:, 256:512], X[:, 512:], None) |
79 | 86 |
80 def save_filters_of_ranzato_hinton_2010(X, fname, min_dynamic_range=1e-3, data_path=None): | 87 def save_filters_of_ranzato_hinton_2010(X, fname, min_dynamic_range=1e-3, data_path=None): |
81 _img = image_tiling.tile_raster_images( | 88 _img = image_tiling.tile_raster_images( |
82 undo_pca_filters_of_ranzato_hinton_2010(X, path=data_path), | 89 undo_pca_filters_of_ranzato_hinton_2010(X, path=data_path), |
83 img_shape=(16,16), | 90 img_shape=(16, 16), |
84 min_dynamic_range=min_dynamic_range) | 91 min_dynamic_range=min_dynamic_range) |
85 image_tiling.save_tiled_raster_images(_img, fname) | 92 image_tiling.save_tiled_raster_images(_img, fname) |
| 93 |
86 | 94 |
87 def ranzato_hinton_2010_op(s_idx, | 95 def ranzato_hinton_2010_op(s_idx, |
88 split='train', | 96 split='train', |
89 dtype=theano.config.floatX, rasterized=True, | 97 dtype=theano.config.floatX, rasterized=True, |
90 center=True, | 98 center=True, |
91 unitvar=True, | 99 unitvar=True, |
92 fn=ranzato_hinton_2010_whitened_patches): | 100 fn=ranzato_hinton_2010_whitened_patches): |
93 N = 10240 | 101 N = 10240 |
94 | 102 |
95 if split != 'train': | 103 if split != 'train': |
96 raise NotImplementedError('train/test/valid splits for randomly sampled image patches?') | 104 raise NotImplementedError( |
| 105 'train/test/valid splits for randomly sampled image patches?') |
97 | 106 |
98 if not rasterized: | 107 if not rasterized: |
99 # the data is provided as PCA-sphered, so rasterizing does not make sense | 108 # the data is provided as PCA-sphered, so rasterizing does not make sense |
100 # TODO: add a param to enable/disable 'PCA', and if disabled, then consider | 109 # TODO: add a param to enable/disable 'PCA', and if disabled, then consider |
101 # rasterizing or not | 110 # rasterizing or not |
106 | 115 |
107 op = TensorFnDataset(dtype, | 116 op = TensorFnDataset(dtype, |
108 bcast=(False,), | 117 bcast=(False,), |
109 fn=fn, | 118 fn=fn, |
110 single_shape=(105,)) | 119 single_shape=(105,)) |
111 x = op(s_idx%N) | 120 x = op(s_idx % N) |
112 return x | 121 return x |