comparison transformations/pipeline.py @ 45:f8a92292b299

merge de 4 fevrier
author SylvainPL <sylvain.pannetier.lebeuf@umontreal.ca>
date Thu, 04 Feb 2010 10:27:58 -0500
parents fdb0e0870fb4
children fabf910467b2
comparison
equal deleted inserted replaced
44:5deccb161307 45:f8a92292b299
1 #!/usr/bin/python
2 # coding: utf-8
3
1 from __future__ import with_statement 4 from __future__ import with_statement
2 5
3 import sys, os 6 import sys, os, getopt
4 import numpy 7 import numpy
5 import filetensor as ft 8 import filetensor as ft
6 import random 9 import random
7 10
11 # This is intended to be run as a GIMP script
12 from gimpfu import *
13
14 DEBUG = True
8 BATCH_SIZE = 100 15 BATCH_SIZE = 100
9 16 DEFAULT_NIST_PATH = '/data/lisa/data/nist/by_class/all/all_train_data.ft'
10 #import <modules> and stuff them in mods below 17 ARGS_FILE = os.environ['PIPELINE_ARGS_TMPFILE']
11 18
12 mods = [] 19 if DEBUG:
13 20 import pylab
14 # DANGER: HIGH VOLTAGE -- DO NOT EDIT BELOW THIS LINE 21 pylab.ion()
15 # ----------------------------------------------------------- 22
16 23 #from add_background_image import AddBackground
17 outf = sys.argv[1] 24 #from affine_transform import AffineTransformation
18 paramsf = sys.argv[2] 25 #from PoivreSel import PoivreSel
19 dataf = '/data/lisa/data/nist/by_class/all/all_train_data.ft' 26 from thick import Thick
20 if len(sys.argv) >= 4: 27 #from BruitGauss import BruitGauss
21 dataf = sys.argv[3] 28 #from gimp_script import GIMPTransformation
22 29 #from Rature import Rature
23 train_data = open(dataf, 'rb') 30 #from contrast Contrast
24 31 from local_elastic_distortions import LocalElasticDistorter
25 dim = tuple(ft._read_header(train_data)[3]) 32 from slant import Slant
26 33
27 res_data = numpy.empty(dim, dtype=numpy.int8) 34 MODULE_INSTANCES = [Thick(), LocalElasticDistorter(), Slant()]
28 35
class Pipeline():
    """Run a chain of transformation modules over batches of images.

    Each module must provide ``regenerate_parameters(complexity)``, returning
    a sized sequence of parameter values, and ``transform_image(img)``,
    returning the transformed image.

    Attributes set by the constructor:
      modules           -- the transformation modules, applied in order
      num_batches       -- number of batches the output buffers are sized for
      batch_size        -- number of images expected in every batch
      image_size        -- (height, width) each image is reshaped to
      num_params_stored -- total number of parameters stored per image
      res_data          -- array of transformed images, one flat row each
      params            -- array of the parameters used for each image
    """

    def __init__(self, modules, num_batches, batch_size, image_size=(32,32)):
        self.modules = modules
        self.num_batches = num_batches
        self.batch_size = batch_size
        self.num_params_stored = 0
        self.image_size = image_size

        self.init_memory()

    def init_num_params_stored(self):
        """Count how many parameters all modules store for one image."""
        # just a dummy call to regenerate_parameters() to get the
        # real number of params (only those which are stored)
        self.num_params_stored = 0
        for m in self.modules:
            self.num_params_stored += len(m.regenerate_parameters(0.0))

    def init_memory(self):
        """Allocate the output buffers for images and parameters."""
        self.init_num_params_stored()

        # One extra batch worth of rows is allocated on top of num_batches.
        total = (self.num_batches + 1) * self.batch_size
        num_px = self.image_size[0] * self.image_size[1]

        # numpy.empty leaves the rows uninitialised until run() fills them.
        self.res_data = numpy.empty((total, num_px))
        self.params = numpy.empty((total, self.num_params_stored))

    def run(self, batch_iterator, complexity_iterator):
        """Transform every image of every batch and record the results.

        batch_iterator yields batches of exactly ``batch_size`` images;
        complexity_iterator supplies one complexity value per batch.
        """
        img_size = self.image_size
        num_px = img_size[0] * img_size[1]

        for batch_no, batch in enumerate(batch_iterator):
            complexity = complexity_iterator.next()

            assert len(batch) == self.batch_size

            for img_no, img in enumerate(batch):
                sys.stdout.flush()
                global_idx = batch_no*self.batch_size + img_no

                img = img.reshape(img_size)

                param_idx = 0
                for mod in self.modules:
                    # This used to be done _per batch_,
                    # ie. out of the "for img" loop
                    p = mod.regenerate_parameters(complexity)
                    self.params[global_idx, param_idx:param_idx+len(p)] = p
                    param_idx += len(p)

                    img = mod.transform_image(img)

                self.res_data[global_idx] = img.reshape((num_px,))*255

                # pylab is only imported when DEBUG is set, so the preview
                # must be skipped otherwise instead of raising NameError.
                if DEBUG:
                    pylab.imshow(img)
                    pylab.draw()

    def write_output(self, output_file_path, params_output_file_path):
        """Write the images with ft.write and the parameters with numpy.save."""
        with open(output_file_path, 'wb') as f:
            ft.write(f, self.res_data)

        numpy.save(params_output_file_path, self.params)
97
98
99 ##############################################################################
100 # COMPLEXITY ITERATORS
101 # They're called once every batch, to get the complexity to use for that batch
102 # they must be infinite (should never throw StopIteration when calling next())
103
def range_complexity_iterator(probability_zero, max_complexity):
    """Yield an endless stream of complexity values.

    Each value is 0.0 with probability ``probability_zero``; otherwise it is
    drawn uniformly from 0.0 to ``max_complexity``.

    Raises AssertionError on the first ``next()`` if max_complexity > 1.0.
    """
    assert max_complexity <= 1.0
    while True:
        # The zero/non-zero decision must be redrawn for every value;
        # drawing it once would fix the choice for the iterator's lifetime.
        if numpy.random.uniform(0.0, 1.0) < probability_zero:
            yield 0.0
        else:
            yield numpy.random.uniform(0.0, max_complexity)
114
115 ##############################################################################
116 # DATA ITERATORS
117 # They can be used to interleave different data sources etc.
118
class NistData():
    """Open handle on a NIST filetensor file plus its header dimensions.

    Attributes:
      train_data -- the data file, opened in binary read mode
      dim        -- tuple built from element [3] of the header returned by
                    ft._read_header; callers use dim[0] as the image count
    """

    def __init__(self, nist_path=None):
        """Open ``nist_path``, or DEFAULT_NIST_PATH when it is None."""
        if nist_path is None:
            nist_path = DEFAULT_NIST_PATH
        self.train_data = open(nist_path, 'rb')
        self.dim = tuple(ft._read_header(self.train_data)[3])

    def close(self):
        """Close the underlying data file."""
        self.train_data.close()
124
def just_nist_iterator(nist, batch_size, stop_after=None):
    """Yield consecutive batches of NIST images scaled by 1/255.

    nist       -- object exposing ``dim`` and an open ``train_data`` file
    batch_size -- number of images requested per batch
    stop_after -- optional threshold; iteration ends after yielding the
                  first batch whose starting image index is >= stop_after
    """
    image_count = nist.dim[0]
    for start in xrange(0, image_count, batch_size):
        # ft.read is given the file rewound to its beginning every time.
        nist.train_data.seek(0)
        window = slice(start, start + batch_size)
        raw = ft.read(nist.train_data, window)
        yield raw.astype(numpy.float32)/255

        # NOTE(review): the threshold is compared with an image index here,
        # whereas main() uses the same value as a batch count -- confirm.
        if stop_after is not None and start >= stop_after:
            break
132
133 ##############################################################################
134 # MAIN
135
def usage():
    """Print the command-line help text to standard output."""
    # A single parenthesised string prints identically under Python 2 and 3.
    print('''
Usage: run_pipeline.sh [-m ...] [-z ...] [-o ...] [-p ...] [-s ...]
    -m, --max-complexity: max complexity to generate for a batch
    -z, --probability-zero: probability of using complexity=0 for a batch
    -o, --output-file: full path to file to use for output of images
    -p, --params-output-file: path to file to output params to
    -s, --stop-after: stop processing early once this threshold is reached
    ''')
144
# See run_pipeline.py
def get_argv():
    """Return the arguments stored in ARGS_FILE, one per line.

    Trailing whitespace, including the newline, is stripped from each line.
    """
    with open(ARGS_FILE) as args_file:
        return [line.rstrip() for line in args_file]
150
# Might be called locally or through dbidispatch. In all cases it should be
# passed to the GIMP executable to be able to use GIMP filters.
def main():
    """Parse the options from get_argv(), run the pipeline, write the output.

    Exits with status 2 after printing the usage text when an option is
    unknown, a probability-like option is outside [0.0, 1.0], or either
    output file is missing.
    """
    max_complexity = 0.5 # default
    probability_zero = 0.1 # default
    output_file_path = None
    params_output_file_path = None
    stop_after = None

    try:
        opts, args = getopt.getopt(
            get_argv(), "m:z:o:p:s:",
            ["max-complexity=", "probability-zero=", "output-file=",
             "params-output-file=", "stop-after="])
    except getopt.GetoptError:
        # sys.exc_info() avoids the version-specific "except X, err" syntax.
        err = sys.exc_info()[1]
        # print help information and exit:
        print(str(err)) # will print something like "option -a not recognized"
        usage()
        sys.exit(2)

    for o, a in opts:
        if o in ('-m', '--max-complexity'):
            max_complexity = float(a)
            # Explicit check: an assert would be stripped under "python -O".
            if not (0.0 <= max_complexity <= 1.0):
                print("max-complexity must be between 0.0 and 1.0")
                usage()
                sys.exit(2)
        elif o in ("-z", "--probability-zero"):
            probability_zero = float(a)
            if not (0.0 <= probability_zero <= 1.0):
                print("probability-zero must be between 0.0 and 1.0")
                usage()
                sys.exit(2)
        elif o in ("-o", "--output-file"):
            output_file_path = a
        elif o in ('-p', "--params-output-file"):
            params_output_file_path = a
        elif o in ('-s', "--stop-after"):
            stop_after = int(a)
        else:
            assert False, "unhandled option"

    if output_file_path is None or params_output_file_path is None:
        print("Must specify both output files.")
        print('')
        usage()
        sys.exit(2)

    nist = NistData()
    # Floor division keeps the batch count an integer on every Python version.
    num_batches = nist.dim[0] // BATCH_SIZE
    # NOTE(review): stop_after is used as a batch count here but compared with
    # an image index in just_nist_iterator -- confirm the intended unit.
    if stop_after:
        num_batches = stop_after
    pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches,
                  batch_size=BATCH_SIZE, image_size=(32,32))
    cpx_it = range_complexity_iterator(probability_zero, max_complexity)
    batch_it = just_nist_iterator(nist, BATCH_SIZE, stop_after)

    pl.run(batch_it, cpx_it)
    pl.write_output(output_file_path, params_output_file_path)
202
# Script entry: the pipeline runs as soon as this file is executed.
main()

# NOTE(review): pdb presumably comes from "from gimpfu import *" -- confirm.
pdb.gimp_quit(0)

# pylab is only imported when DEBUG is set, so it must not be used otherwise.
if DEBUG:
    pylab.ioff()
    pylab.show()
208