Mercurial > ift6266
comparison transformations/pipeline.py @ 50:ff59670cd1f9
Ajouté l'enregistrement de la complexité, et un strict minimum pour reloader les fichiers d'images et de paramètres
author | fsavard |
---|---|
date | Thu, 04 Feb 2010 14:13:57 -0500 |
parents | fabf910467b2 |
children | c89defea1e65 |
comparison
legend: equal · deleted · inserted · replaced
49:8ce089f30463 | 50:ff59670cd1f9 |
---|---|
12 import random | 12 import random |
13 | 13 |
14 # To debug locally, also call with -s 1 (to stop after 1 batch ~= 100) | 14 # To debug locally, also call with -s 1 (to stop after 1 batch ~= 100) |
15 # (otherwise we allocate all needed memory, might be loonnng and/or crash | 15 # (otherwise we allocate all needed memory, might be loonnng and/or crash |
16 # if, lucky like me, you have an age-old laptop creaking from everywhere) | 16 # if, lucky like me, you have an age-old laptop creaking from everywhere) |
17 DEBUG = True | 17 DEBUG = False |
18 DEBUG_X = False # Debug under X (pylab.show()) | 18 DEBUG_X = False |
19 if DEBUG: | |
20 DEBUG_X = False # Debug under X (pylab.show()) | |
19 | 21 |
20 DEBUG_IMAGES_PATH = None | 22 DEBUG_IMAGES_PATH = None |
21 if DEBUG: | 23 if DEBUG: |
22 # UNTESTED YET | 24 # UNTESTED YET |
23 # To avoid loading NIST if you don't have it handy | 25 # To avoid loading NIST if you don't have it handy |
87 self.num_params_stored += len(m.regenerate_parameters(0.0)) | 89 self.num_params_stored += len(m.regenerate_parameters(0.0)) |
88 | 90 |
89 def init_memory(self): | 91 def init_memory(self): |
90 self.init_num_params_stored() | 92 self.init_num_params_stored() |
91 | 93 |
92 total = (self.num_batches + 1) * self.batch_size | 94 total = self.num_batches * self.batch_size |
93 num_px = self.image_size[0] * self.image_size[1] | 95 num_px = self.image_size[0] * self.image_size[1] |
94 | 96 |
95 self.res_data = numpy.empty((total, num_px)) | 97 self.res_data = numpy.empty((total, num_px)) |
96 self.params = numpy.empty((total, self.num_params_stored)) | 98 # +1 to store complexity |
99 self.params = numpy.empty((total, self.num_params_stored+1)) | |
97 | 100 |
98 def run(self, batch_iterator, complexity_iterator): | 101 def run(self, batch_iterator, complexity_iterator): |
99 img_size = self.image_size | 102 img_size = self.image_size |
100 | 103 |
101 should_hook_after_each = len(AFTER_EACH_MODULE_HOOK) != 0 | 104 should_hook_after_each = len(AFTER_EACH_MODULE_HOOK) != 0 |
112 sys.stdout.flush() | 115 sys.stdout.flush() |
113 global_idx = batch_no*self.batch_size + img_no | 116 global_idx = batch_no*self.batch_size + img_no |
114 | 117 |
115 img = img.reshape(img_size) | 118 img = img.reshape(img_size) |
116 | 119 |
117 param_idx = 0 | 120 param_idx = 1 |
121 # store complexity along with other params | |
122 self.params[global_idx, 0] = complexity | |
118 for mod in self.modules: | 123 for mod in self.modules: |
119 # This used to be done _per batch_, | 124 # This used to be done _per batch_, |
120 # ie. out of the "for img" loop | 125 # ie. out of the "for img" loop |
121 p = mod.regenerate_parameters(complexity) | 126 p = mod.regenerate_parameters(complexity) |
122 self.params[global_idx, param_idx:param_idx+len(p)] = p | 127 self.params[global_idx, param_idx:param_idx+len(p)] = p |
190 self.train_data = open(nist_path, 'rb') | 195 self.train_data = open(nist_path, 'rb') |
191 self.dim = tuple(ft._read_header(self.train_data)[3]) | 196 self.dim = tuple(ft._read_header(self.train_data)[3]) |
192 | 197 |
193 def just_nist_iterator(nist, batch_size, stop_after=None): | 198 def just_nist_iterator(nist, batch_size, stop_after=None): |
194 for i in xrange(0, nist.dim[0], batch_size): | 199 for i in xrange(0, nist.dim[0], batch_size): |
200 if not stop_after is None and i >= stop_after: | |
201 break | |
202 | |
195 nist.train_data.seek(0) | 203 nist.train_data.seek(0) |
196 yield ft.read(nist.train_data, slice(i, i+batch_size)).astype(numpy.float32)/255 | 204 yield ft.read(nist.train_data, slice(i, i+batch_size)).astype(numpy.float32)/255 |
197 | 205 |
198 if not stop_after is None and i >= stop_after: | 206 |
199 break | 207 |
208 # Mostly for debugging, for the moment, just to see if we can | |
209 # reload the images and parameters. | |
210 def reload(output_file_path, params_output_file_path): | |
211 images_ft = open(output_file_path, 'rb') | |
212 images_ft_dim = tuple(ft._read_header(images_ft)[3]) | |
213 | |
214 print "Images dimensions: ", images_ft_dim | |
215 | |
216 params = numpy.load(params_output_file_path) | |
217 | |
218 print "Params dimensions: ", params.shape | |
219 print params | |
220 | |
200 | 221 |
201 ############################################################################## | 222 ############################################################################## |
202 # MAIN | 223 # MAIN |
203 | 224 |
204 def usage(): | 225 def usage(): |
223 max_complexity = 0.5 # default | 244 max_complexity = 0.5 # default |
224 probability_zero = 0.1 # default | 245 probability_zero = 0.1 # default |
225 output_file_path = None | 246 output_file_path = None |
226 params_output_file_path = None | 247 params_output_file_path = None |
227 stop_after = None | 248 stop_after = None |
228 | 249 reload_mode = False |
229 import sys | |
230 print "python version: ", sys.version | |
231 | 250 |
232 try: | 251 try: |
233 opts, args = getopt.getopt(get_argv(), "m:z:o:p:s:", ["max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "stop-after="]) | 252 opts, args = getopt.getopt(get_argv(), "rm:z:o:p:s:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "stop-after="]) |
234 except getopt.GetoptError, err: | 253 except getopt.GetoptError, err: |
235 # print help information and exit: | 254 # print help information and exit: |
236 print str(err) # will print something like "option -a not recognized" | 255 print str(err) # will print something like "option -a not recognized" |
237 usage() | 256 usage() |
238 sys.exit(2) | 257 sys.exit(2) |
239 output = None | 258 |
240 verbose = False | |
241 for o, a in opts: | 259 for o, a in opts: |
242 if o in ('-m', '--max-complexity'): | 260 if o in ('-m', '--max-complexity'): |
243 max_complexity = float(a) | 261 max_complexity = float(a) |
244 assert max_complexity >= 0.0 and max_complexity <= 1.0 | 262 assert max_complexity >= 0.0 and max_complexity <= 1.0 |
263 elif o in ('-r', '--reload'): | |
264 reload_mode = True | |
245 elif o in ("-z", "--probability-zero"): | 265 elif o in ("-z", "--probability-zero"): |
246 probability_zero = float(a) | 266 probability_zero = float(a) |
247 assert probability_zero >= 0.0 and probability_zero <= 1.0 | 267 assert probability_zero >= 0.0 and probability_zero <= 1.0 |
248 elif o in ("-o", "--output-file"): | 268 elif o in ("-o", "--output-file"): |
249 output_file_path = a | 269 output_file_path = a |
258 print "Must specify both output files." | 278 print "Must specify both output files." |
259 print | 279 print |
260 usage() | 280 usage() |
261 sys.exit(2) | 281 sys.exit(2) |
262 | 282 |
263 if DEBUG_IMAGES_PATH: | 283 if reload_mode: |
264 ''' | 284 reload(output_file_path, params_output_file_path) |
265 # This code is yet untested | |
266 debug_images = DebugImages(DEBUG_IMAGES_PATH) | |
267 num_batches = 1 | |
268 batch_size = len(debug_images.filelist) | |
269 pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches, batch_size=BATCH_SIZE, image_size=(32,32)) | |
270 batch_it = debug_images_iterator(debug_images) | |
271 ''' | |
272 else: | 285 else: |
273 nist = NistData() | 286 if DEBUG_IMAGES_PATH: |
274 num_batches = nist.dim[0]/BATCH_SIZE | 287 ''' |
275 if stop_after: | 288 # This code is yet untested |
276 num_batches = stop_after | 289 debug_images = DebugImages(DEBUG_IMAGES_PATH) |
277 pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches, batch_size=BATCH_SIZE, image_size=(32,32)) | 290 num_batches = 1 |
278 batch_it = just_nist_iterator(nist, BATCH_SIZE, stop_after) | 291 batch_size = len(debug_images.filelist) |
279 | 292 pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches, batch_size=BATCH_SIZE, image_size=(32,32)) |
280 cpx_it = range_complexity_iterator(probability_zero, max_complexity) | 293 batch_it = debug_images_iterator(debug_images) |
281 pl.run(batch_it, cpx_it) | 294 ''' |
282 pl.write_output(output_file_path, params_output_file_path) | 295 else: |
296 nist = NistData() | |
297 num_batches = nist.dim[0]/BATCH_SIZE | |
298 if stop_after: | |
299 num_batches = stop_after | |
300 pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches, batch_size=BATCH_SIZE, image_size=(32,32)) | |
301 batch_it = just_nist_iterator(nist, BATCH_SIZE, stop_after) | |
302 | |
303 cpx_it = range_complexity_iterator(probability_zero, max_complexity) | |
304 pl.run(batch_it, cpx_it) | |
305 pl.write_output(output_file_path, params_output_file_path) | |
283 | 306 |
284 _main() | 307 _main() |
285 | 308 |
286 if DEBUG_X: | 309 if DEBUG_X: |
287 pylab.ioff() | 310 pylab.ioff() |