comparison transformations/pipeline.py @ 50:ff59670cd1f9

Ajouté l'enregistrement de la complexité, et un strict minimum pour reloader les fichiers d'images et de paramètres
author fsavard
date Thu, 04 Feb 2010 14:13:57 -0500
parents fabf910467b2
children c89defea1e65
comparison
equal deleted inserted replaced
49:8ce089f30463 50:ff59670cd1f9
12 import random 12 import random
13 13
14 # To debug locally, also call with -s 1 (to stop after 1 batch ~= 100) 14 # To debug locally, also call with -s 1 (to stop after 1 batch ~= 100)
15 # (otherwise we allocate all needed memory, might be loonnng and/or crash 15 # (otherwise we allocate all needed memory, might be loonnng and/or crash
16 # if, lucky like me, you have an age-old laptop creaking from everywhere) 16 # if, lucky like me, you have an age-old laptop creaking from everywhere)
17 DEBUG = True 17 DEBUG = False
18 DEBUG_X = False # Debug under X (pylab.show()) 18 DEBUG_X = False
19 if DEBUG:
20 DEBUG_X = False # Debug under X (pylab.show())
19 21
20 DEBUG_IMAGES_PATH = None 22 DEBUG_IMAGES_PATH = None
21 if DEBUG: 23 if DEBUG:
22 # UNTESTED YET 24 # UNTESTED YET
23 # To avoid loading NIST if you don't have it handy 25 # To avoid loading NIST if you don't have it handy
87 self.num_params_stored += len(m.regenerate_parameters(0.0)) 89 self.num_params_stored += len(m.regenerate_parameters(0.0))
88 90
89 def init_memory(self): 91 def init_memory(self):
90 self.init_num_params_stored() 92 self.init_num_params_stored()
91 93
92 total = (self.num_batches + 1) * self.batch_size 94 total = self.num_batches * self.batch_size
93 num_px = self.image_size[0] * self.image_size[1] 95 num_px = self.image_size[0] * self.image_size[1]
94 96
95 self.res_data = numpy.empty((total, num_px)) 97 self.res_data = numpy.empty((total, num_px))
96 self.params = numpy.empty((total, self.num_params_stored)) 98 # +1 to store complexity
99 self.params = numpy.empty((total, self.num_params_stored+1))
97 100
98 def run(self, batch_iterator, complexity_iterator): 101 def run(self, batch_iterator, complexity_iterator):
99 img_size = self.image_size 102 img_size = self.image_size
100 103
101 should_hook_after_each = len(AFTER_EACH_MODULE_HOOK) != 0 104 should_hook_after_each = len(AFTER_EACH_MODULE_HOOK) != 0
112 sys.stdout.flush() 115 sys.stdout.flush()
113 global_idx = batch_no*self.batch_size + img_no 116 global_idx = batch_no*self.batch_size + img_no
114 117
115 img = img.reshape(img_size) 118 img = img.reshape(img_size)
116 119
117 param_idx = 0 120 param_idx = 1
121 # store complexity along with other params
122 self.params[global_idx, 0] = complexity
118 for mod in self.modules: 123 for mod in self.modules:
119 # This used to be done _per batch_, 124 # This used to be done _per batch_,
120 # ie. out of the "for img" loop 125 # ie. out of the "for img" loop
121 p = mod.regenerate_parameters(complexity) 126 p = mod.regenerate_parameters(complexity)
122 self.params[global_idx, param_idx:param_idx+len(p)] = p 127 self.params[global_idx, param_idx:param_idx+len(p)] = p
190 self.train_data = open(nist_path, 'rb') 195 self.train_data = open(nist_path, 'rb')
191 self.dim = tuple(ft._read_header(self.train_data)[3]) 196 self.dim = tuple(ft._read_header(self.train_data)[3])
192 197
193 def just_nist_iterator(nist, batch_size, stop_after=None): 198 def just_nist_iterator(nist, batch_size, stop_after=None):
194 for i in xrange(0, nist.dim[0], batch_size): 199 for i in xrange(0, nist.dim[0], batch_size):
200 if not stop_after is None and i >= stop_after:
201 break
202
195 nist.train_data.seek(0) 203 nist.train_data.seek(0)
196 yield ft.read(nist.train_data, slice(i, i+batch_size)).astype(numpy.float32)/255 204 yield ft.read(nist.train_data, slice(i, i+batch_size)).astype(numpy.float32)/255
197 205
198 if not stop_after is None and i >= stop_after: 206
199 break 207
208 # Mostly for debugging, for the moment, just to see if we can
209 # reload the images and parameters.
210 def reload(output_file_path, params_output_file_path):
211 images_ft = open(output_file_path, 'rb')
212 images_ft_dim = tuple(ft._read_header(images_ft)[3])
213
214 print "Images dimensions: ", images_ft_dim
215
216 params = numpy.load(params_output_file_path)
217
218 print "Params dimensions: ", params.shape
219 print params
220
200 221
201 ############################################################################## 222 ##############################################################################
202 # MAIN 223 # MAIN
203 224
204 def usage(): 225 def usage():
223 max_complexity = 0.5 # default 244 max_complexity = 0.5 # default
224 probability_zero = 0.1 # default 245 probability_zero = 0.1 # default
225 output_file_path = None 246 output_file_path = None
226 params_output_file_path = None 247 params_output_file_path = None
227 stop_after = None 248 stop_after = None
228 249 reload_mode = False
229 import sys
230 print "python version: ", sys.version
231 250
232 try: 251 try:
233 opts, args = getopt.getopt(get_argv(), "m:z:o:p:s:", ["max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "stop-after="]) 252 opts, args = getopt.getopt(get_argv(), "rm:z:o:p:s:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "stop-after="])
234 except getopt.GetoptError, err: 253 except getopt.GetoptError, err:
235 # print help information and exit: 254 # print help information and exit:
236 print str(err) # will print something like "option -a not recognized" 255 print str(err) # will print something like "option -a not recognized"
237 usage() 256 usage()
238 sys.exit(2) 257 sys.exit(2)
239 output = None 258
240 verbose = False
241 for o, a in opts: 259 for o, a in opts:
242 if o in ('-m', '--max-complexity'): 260 if o in ('-m', '--max-complexity'):
243 max_complexity = float(a) 261 max_complexity = float(a)
244 assert max_complexity >= 0.0 and max_complexity <= 1.0 262 assert max_complexity >= 0.0 and max_complexity <= 1.0
263 elif o in ('-r', '--reload'):
264 reload_mode = True
245 elif o in ("-z", "--probability-zero"): 265 elif o in ("-z", "--probability-zero"):
246 probability_zero = float(a) 266 probability_zero = float(a)
247 assert probability_zero >= 0.0 and probability_zero <= 1.0 267 assert probability_zero >= 0.0 and probability_zero <= 1.0
248 elif o in ("-o", "--output-file"): 268 elif o in ("-o", "--output-file"):
249 output_file_path = a 269 output_file_path = a
258 print "Must specify both output files." 278 print "Must specify both output files."
259 print 279 print
260 usage() 280 usage()
261 sys.exit(2) 281 sys.exit(2)
262 282
263 if DEBUG_IMAGES_PATH: 283 if reload_mode:
264 ''' 284 reload(output_file_path, params_output_file_path)
265 # This code is yet untested
266 debug_images = DebugImages(DEBUG_IMAGES_PATH)
267 num_batches = 1
268 batch_size = len(debug_images.filelist)
269 pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches, batch_size=BATCH_SIZE, image_size=(32,32))
270 batch_it = debug_images_iterator(debug_images)
271 '''
272 else: 285 else:
273 nist = NistData() 286 if DEBUG_IMAGES_PATH:
274 num_batches = nist.dim[0]/BATCH_SIZE 287 '''
275 if stop_after: 288 # This code is yet untested
276 num_batches = stop_after 289 debug_images = DebugImages(DEBUG_IMAGES_PATH)
277 pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches, batch_size=BATCH_SIZE, image_size=(32,32)) 290 num_batches = 1
278 batch_it = just_nist_iterator(nist, BATCH_SIZE, stop_after) 291 batch_size = len(debug_images.filelist)
279 292 pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches, batch_size=BATCH_SIZE, image_size=(32,32))
280 cpx_it = range_complexity_iterator(probability_zero, max_complexity) 293 batch_it = debug_images_iterator(debug_images)
281 pl.run(batch_it, cpx_it) 294 '''
282 pl.write_output(output_file_path, params_output_file_path) 295 else:
296 nist = NistData()
297 num_batches = nist.dim[0]/BATCH_SIZE
298 if stop_after:
299 num_batches = stop_after
300 pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches, batch_size=BATCH_SIZE, image_size=(32,32))
301 batch_it = just_nist_iterator(nist, BATCH_SIZE, stop_after)
302
303 cpx_it = range_complexity_iterator(probability_zero, max_complexity)
304 pl.run(batch_it, cpx_it)
305 pl.write_output(output_file_path, params_output_file_path)
283 306
284 _main() 307 _main()
285 308
286 if DEBUG_X: 309 if DEBUG_X:
287 pylab.ioff() 310 pylab.ioff()