comparison pylearn/dataset_ops/image_patches.py @ 1521:6397233f3ccd

autopep8
author Frederic Bastien <nouiz@nouiz.org>
date Wed, 31 Oct 2012 16:12:57 -0400
parents 9ffe5d6faee3
children 5972fab3cfd2
comparison
equal deleted inserted replaced
1520:61134776e33c 1521:6397233f3ccd
1 import os, numpy 1 import os
2 import numpy
2 import theano 3 import theano
3 4
4 from pylearn.datasets.image_patches import ( 5 from pylearn.datasets.image_patches import (
5 olshausen_field_1996_whitened_images, 6 olshausen_field_1996_whitened_images,
6 extract_random_patches) 7 extract_random_patches)
7 8
8 from .protocol import TensorFnDataset # protocol.py __init__.py 9 from .protocol import TensorFnDataset # protocol.py __init__.py
9 from .memo import memo 10 from .memo import memo
10 11
11 import scipy.io 12 import scipy.io
12 from pylearn.io import image_tiling 13 from pylearn.io import image_tiling
13 from pylearn.datasets.config import get_filepath_in_roots 14 from pylearn.datasets.config import get_filepath_in_roots
14 15
16
@memo
def get_dataset(N, R, C, dtype, center, unitvar):
    """Sample `N` random R x C patches from the Olshausen & Field (1996) whitened images.

    A fixed seed (98234) is used for the RandomState handed to
    `extract_random_patches`, so the same arguments should yield the same
    patches.

    :param N: number of patches to extract.
    :param R: patch height.
    :param C: patch width.
    :param dtype: dtype the patches are cast to.
    :param center: if true, subtract the column-wise (per-pixel) mean.
    :param unitvar: if true, divide every column by its own standard
        deviation, floored at 1e-8 so constant columns do not divide by zero.
    :returns: array of shape (N, R * C), one rasterized patch per row.
    """
    seed = 98234
    rng = numpy.random.RandomState(seed)
    img_stack = olshausen_field_1996_whitened_images()
    patch_stack = extract_random_patches(img_stack, N, R, C, rng)
    # astype() returns a fresh array, so the in-place ops below do not touch
    # whatever extract_random_patches returned.
    rval = patch_stack.astype(dtype).reshape((N, R * C))

    if center:
        rval -= rval.mean(axis=0)
    if unitvar:
        # numpy.maximum gives the element-wise floor we want. The previous
        # numpy.max(std, 1e-8) passed 1e-8 as the `axis` argument: an error
        # in current numpy, and a single global scalar (not a per-column
        # scale) in the old releases that tolerated a float axis.
        rval /= numpy.maximum(rval.std(axis=0), 1e-8)

    return rval
31
29 32
def image_patches(s_idx, dims,
                  split='train', dtype=theano.config.floatX, rasterized=False,
                  center=True,
                  unitvar=True,
                  fn=get_dataset):
    """Return a symbolic variable selecting image patch(es) by index.

    :param s_idx: symbolic index of the patch(es) to fetch. It is wrapped
        modulo N. A scalar index presumably yields a 1-d result and a vector
        index a 2-d result (these are the only ranks handled below).
    :param dims: triple (N, R, C): number of patches, patch height, width.
    :param split: only 'train' is supported.
    :param dtype: dtype of the patch table.
    :param rasterized: must currently be true; patches come back flattened
        to length R * C.
    :param center: forwarded to `fn`.
    :param unitvar: forwarded to `fn`.
    :param fn: table builder, handed to TensorFnDataset together with the
        argument tuple (N, R, C, dtype, center, unitvar).
    :raises NotImplementedError: if split != 'train' or rasterized is false.
    :raises ValueError: if the indexed variable is neither 1-d nor 2-d.
    """
    N, R, C = dims

    if split != 'train':
        raise NotImplementedError(
            'train/test/valid splits for randomly sampled image patches?')

    if not rasterized:
        raise NotImplementedError()

    op = TensorFnDataset(dtype, bcast=(False, ),
                         fn=(fn, (N, R, C, dtype, center, unitvar)),
                         single_shape=(R * C, ))
    x = op(s_idx % N)

    # The `not rasterized` reshapes are unreachable while the guard above
    # stands. They use the caller's (R, C) rather than a hard-coded 20x20, so
    # lifting the guard would produce correctly shaped patches.
    if x.ndim == 1:
        if not rasterized:
            x = x.reshape((R, C))
    elif x.ndim == 2:
        if not rasterized:
            x = x.reshape((x.shape[0], R, C))
    else:
        # Explicit raise instead of `assert False`, which is stripped under -O.
        raise ValueError('what happened? unexpected ndim %d' % x.ndim)

    return x
55
56 60
57 61
@memo
def ranzato_hinton_2010(path=None):
    """Load the mcRBM demo colour-patch data with `scipy.io.loadmat`.

    :param path: location of the .mat file. When None, the file
        image_patches/mcRBM/training_colorpatches_16x16_demo.mat is resolved
        through `get_filepath_in_roots`.
    :returns: the dict produced by `scipy.io.loadmat`.
    """
    if path is not None:
        return scipy.io.loadmat(path)
    relative = os.path.join('image_patches', 'mcRBM',
                            'training_colorpatches_16x16_demo.mat')
    return scipy.io.loadmat(get_filepath_in_roots(relative))
69
70
def ranzato_hinton_2010_whitened_patches(path=None):
    """Return the PCA-whitened patch matrix (10240 x 105) as float32.

    :param path: optional .mat file location, forwarded to `ranzato_hinton_2010`.
    """
    whitened = ranzato_hinton_2010(path)['whitendata']
    return whitened.astype('float32')
76
70 77
def undo_pca_filters_of_ranzato_hinton_2010(X, path=None):
    """Project PCA-space filters back to pixel space and split them by channel.

    :param X: matrix of filters in PCA space, one filter per row.
    :param path: optional .mat file location, forwarded to `ranzato_hinton_2010`.
    :returns: tuple (R, G, B, None) holding columns [:256], [256:512] and
        [512:] of the back-projection (presumably one 16x16 channel each).
        It can be passed to `image_tiling.tile_raster_images`.
    """
    inverse_pca = ranzato_hinton_2010(path)['invpcatransf']
    pixels = numpy.dot(X, inverse_pca.T)
    red = pixels[:, :256]
    green = pixels[:, 256:512]
    blue = pixels[:, 512:]
    return red, green, blue, None
79 86
def save_filters_of_ranzato_hinton_2010(X, fname, min_dynamic_range=1e-3, data_path=None):
    """Render PCA-space filters `X` as a grid of 16x16 colour tiles and save it.

    :param X: matrix of filters in PCA space, one filter per row.
    :param fname: destination passed to `image_tiling.save_tiled_raster_images`.
    :param min_dynamic_range: forwarded to `image_tiling.tile_raster_images`.
    :param data_path: optional .mat file location used for the inverse PCA.
    """
    channels = undo_pca_filters_of_ranzato_hinton_2010(X, path=data_path)
    tiled = image_tiling.tile_raster_images(
        channels,
        img_shape=(16, 16),
        min_dynamic_range=min_dynamic_range)
    image_tiling.save_tiled_raster_images(tiled, fname)
93
86 94
87 def ranzato_hinton_2010_op(s_idx, 95 def ranzato_hinton_2010_op(s_idx,
88 split='train', 96 split='train',
89 dtype=theano.config.floatX, rasterized=True, 97 dtype=theano.config.floatX, rasterized=True,
90 center=True, 98 center=True,
91 unitvar=True, 99 unitvar=True,
92 fn=ranzato_hinton_2010_whitened_patches): 100 fn=ranzato_hinton_2010_whitened_patches):
93 N = 10240 101 N = 10240
94 102
95 if split != 'train': 103 if split != 'train':
96 raise NotImplementedError('train/test/valid splits for randomly sampled image patches?') 104 raise NotImplementedError(
105 'train/test/valid splits for randomly sampled image patches?')
97 106
98 if not rasterized: 107 if not rasterized:
99 # the data is provided as PCA-sphered, so rasterizing does not make sense 108 # the data is provided as PCA-sphered, so rasterizing does not make sense
100 # TODO: add a param to enable/disable 'PCA', and if disabled, then consider 109 # TODO: add a param to enable/disable 'PCA', and if disabled, then consider
101 # rasterizing or not 110 # rasterizing or not
106 115
107 op = TensorFnDataset(dtype, 116 op = TensorFnDataset(dtype,
108 bcast=(False,), 117 bcast=(False,),
109 fn=fn, 118 fn=fn,
110 single_shape=(105,)) 119 single_shape=(105,))
111 x = op(s_idx%N) 120 x = op(s_idx % N)
112 return x 121 return x