comparison transformations/pipeline.py @ 61:cc4be6b25b8e

Data iterator alternating between NIST/font/captcha; removed the use of batches; keeps track of labels (not fully done yet)
author boulanni <nicolas_boulanger@hotmail.com>
date Mon, 08 Feb 2010 23:45:17 -0500
parents c89defea1e65
children 1afa95285b9c
comparison
equal deleted inserted replaced
60:d508f5a8acd0 61:cc4be6b25b8e
2 # coding: utf-8 2 # coding: utf-8
3 3
4 from __future__ import with_statement 4 from __future__ import with_statement
5 5
6 # This is intended to be run as a GIMP script 6 # This is intended to be run as a GIMP script
7 from gimpfu import * 7 #from gimpfu import *
8 8
9 import sys, os, getopt 9 import sys, os, getopt
10 import numpy 10 import numpy
11 import filetensor as ft 11 import filetensor as ft
12 import random 12 import random
13 13
14 # To debug locally, also call with -s 1 (to stop after 1 batch ~= 100) 14 # To debug locally, also call with -s 100 (to stop after ~100)
15 # (otherwise we allocate all needed memory, might be loonnng and/or crash 15 # (otherwise we allocate all needed memory, might be loonnng and/or crash
16 # if, lucky like me, you have an age-old laptop creaking from everywhere) 16 # if, lucky like me, you have an age-old laptop creaking from everywhere)
17 DEBUG = True 17 DEBUG = True
18 DEBUG_X = False 18 DEBUG_X = False
19 if DEBUG: 19 if DEBUG:
29 29
30 # Directory where to dump images to visualize results 30 # Directory where to dump images to visualize results
31 # (create it, otherwise it'll crash) 31 # (create it, otherwise it'll crash)
32 DEBUG_OUTPUT_DIR = 'debug_out' 32 DEBUG_OUTPUT_DIR = 'debug_out'
33 33
34 BATCH_SIZE = 100 34 DEFAULT_NIST_PATH = '/data/lisa/data/ift6266h10/train_data.ft'
35 DEFAULT_NIST_PATH = '/data/lisa/data/nist/by_class/all/all_train_data.ft' 35 DEFAULT_LABEL_PATH = '/data/lisa/data/ift6266h10/train_labels.ft'
36 ARGS_FILE = os.environ['PIPELINE_ARGS_TMPFILE'] 36 ARGS_FILE = os.environ['PIPELINE_ARGS_TMPFILE']
37 37
38 if DEBUG_X: 38 if DEBUG_X:
39 import pylab 39 import pylab
40 pylab.ion() 40 pylab.ion()
70 END_TRANSFORM_HOOK = [] 70 END_TRANSFORM_HOOK = []
71 if DEBUG: 71 if DEBUG:
72 END_TRANSFORM_HOOK = [VISUALIZER] 72 END_TRANSFORM_HOOK = [VISUALIZER]
73 73
74 class Pipeline(): 74 class Pipeline():
75 def __init__(self, modules, num_batches, batch_size, image_size=(32,32)): 75 def __init__(self, modules, num_img, image_size=(32,32)):
76 self.modules = modules 76 self.modules = modules
77 self.num_batches = num_batches 77 self.num_img = num_img
78 self.batch_size = batch_size
79 self.num_params_stored = 0 78 self.num_params_stored = 0
80 self.image_size = image_size 79 self.image_size = image_size
81 80
82 self.init_memory() 81 self.init_memory()
83 82
89 self.num_params_stored += len(m.regenerate_parameters(0.0)) 88 self.num_params_stored += len(m.regenerate_parameters(0.0))
90 89
91 def init_memory(self): 90 def init_memory(self):
92 self.init_num_params_stored() 91 self.init_num_params_stored()
93 92
94 total = self.num_batches * self.batch_size 93 total = self.num_img
95 num_px = self.image_size[0] * self.image_size[1] 94 num_px = self.image_size[0] * self.image_size[1]
96 95
97 self.res_data = numpy.empty((total, num_px)) 96 self.res_data = numpy.empty((total, num_px), dtype=numpy.uint8)
98 # +1 to store complexity 97 # +1 to store complexity
99 self.params = numpy.empty((total, self.num_params_stored+1)) 98 self.params = numpy.empty((total, self.num_params_stored+1))
100 99 self.res_labels = numpy.empty(total, dtype=numpy.int32)
101 def run(self, batch_iterator, complexity_iterator): 100
101 def run(self, img_iterator, complexity_iterator):
102 img_size = self.image_size 102 img_size = self.image_size
103 103
104 should_hook_after_each = len(AFTER_EACH_MODULE_HOOK) != 0 104 should_hook_after_each = len(AFTER_EACH_MODULE_HOOK) != 0
105 should_hook_at_the_end = len(END_TRANSFORM_HOOK) != 0 105 should_hook_at_the_end = len(END_TRANSFORM_HOOK) != 0
106 106
107 for batch_no, batch in enumerate(batch_iterator): 107 for img_no, (img, label) in enumerate(img_iterator):
108 sys.stdout.flush()
108 complexity = complexity_iterator.next() 109 complexity = complexity_iterator.next()
109 if DEBUG: 110
110 print "Complexity:", complexity 111 global_idx = img_no
111 112
112 assert len(batch) == self.batch_size 113 img = img.reshape(img_size)
113 114
114 for img_no, img in enumerate(batch): 115 param_idx = 1
115 sys.stdout.flush() 116 # store complexity along with other params
116 global_idx = batch_no*self.batch_size + img_no 117 self.params[global_idx, 0] = complexity
117 118 for mod in self.modules:
118 img = img.reshape(img_size) 119 # This used to be done _per batch_,
119 120 # ie. out of the "for img" loop
120 param_idx = 1 121 p = mod.regenerate_parameters(complexity)
121 # store complexity along with other params 122 self.params[global_idx, param_idx:param_idx+len(p)] = p
122 self.params[global_idx, 0] = complexity 123 param_idx += len(p)
123 for mod in self.modules: 124
124 # This used to be done _per batch_, 125 img = mod.transform_image(img)
125 # ie. out of the "for img" loop 126
126 p = mod.regenerate_parameters(complexity) 127 if should_hook_after_each:
127 self.params[global_idx, param_idx:param_idx+len(p)] = p 128 for hook in AFTER_EACH_MODULE_HOOK:
128 param_idx += len(p) 129 hook.after_transform_callback(img)
129 130
130 img = mod.transform_image(img) 131 self.res_data[global_idx] = \
131 132 img.reshape((img_size[0] * img_size[1],))*255
132 if should_hook_after_each: 133 self.res_labels[global_idx] = label
133 for hook in AFTER_EACH_MODULE_HOOK: 134
134 hook.after_transform_callback(img) 135 if should_hook_at_the_end:
135 136 for hook in END_TRANSFORM_HOOK:
136 self.res_data[global_idx] = \ 137 hook.end_transform_callback(img)
137 img.reshape((img_size[0] * img_size[1],))*255 138
138 139 def write_output(self, output_file_path, params_output_file_path, labels_output_file_path):
139
140 if should_hook_at_the_end:
141 for hook in END_TRANSFORM_HOOK:
142 hook.end_transform_callback(img)
143
144 def write_output(self, output_file_path, params_output_file_path):
145 with open(output_file_path, 'wb') as f: 140 with open(output_file_path, 'wb') as f:
146 ft.write(f, self.res_data) 141 ft.write(f, self.res_data)
147 142
148 numpy.save(params_output_file_path, self.params) 143 numpy.save(params_output_file_path, self.params)
144
145 with open(labels_output_file_path, 'wb') as f:
146 ft.write(f, self.res_labels)
149 147
150 148
151 ############################################################################## 149 ##############################################################################
152 # COMPLEXITY ITERATORS 150 # COMPLEXITY ITERATORS
153 # They're called once every batch, to get the complexity to use for that batch 151 # They're called once every img, to get the complexity to use for that img
154 # they must be infinite (should never throw StopIteration when calling next()) 152 # they must be infinite (should never throw StopIteration when calling next())
155 153
156 # probability of generating 0 complexity, otherwise 154 # probability of generating 0 complexity, otherwise
157 # uniform over 0.0-max_complexity 155 # uniform over 0.0-max_complexity
158 def range_complexity_iterator(probability_zero, max_complexity): 156 def range_complexity_iterator(probability_zero, max_complexity):
188 for path in debug_images.filelist: 186 for path in debug_images.filelist:
189 yield load_image(path) 187 yield load_image(path)
190 ''' 188 '''
191 189
192 class NistData(): 190 class NistData():
193 def __init__(self, ): 191 def __init__(self, nist_path, label_path):
194 nist_path = DEFAULT_NIST_PATH
195 self.train_data = open(nist_path, 'rb') 192 self.train_data = open(nist_path, 'rb')
193 self.train_labels = open(label_path, 'rb')
196 self.dim = tuple(ft._read_header(self.train_data)[3]) 194 self.dim = tuple(ft._read_header(self.train_data)[3])
197 195
198 def just_nist_iterator(nist, batch_size, stop_after=None): 196 def nist_supp_iterator(nist, prob_font, prob_captcha, num_img):
199 for i in xrange(0, nist.dim[0], batch_size): 197 subtensor = slice(0, num_img)
200 if not stop_after is None and i >= stop_after: 198 img = ft.read(nist.train_data, subtensor).astype(numpy.float32)/255
201 break 199 labels = ft.read(nist.train_labels, subtensor)
202 200
203 nist.train_data.seek(0) 201 for i in xrange(num_img):
204 yield ft.read(nist.train_data, slice(i, i+batch_size)).astype(numpy.float32)/255 202 r = numpy.random.rand()
205 203 if r<= prob_font:
204 pass #get font
205 elif r<= prob_font + prob_captcha:
206 pass #get captcha
207 else:
208 j = numpy.random.randint(num_img)
209 yield img[j], labels[j]
206 210
207 211
208 # Mostly for debugging, for the moment, just to see if we can 212 # Mostly for debugging, for the moment, just to see if we can
209 # reload the images and parameters. 213 # reload the images and parameters.
210 def reload(output_file_path, params_output_file_path): 214 def reload(output_file_path, params_output_file_path):
223 # MAIN 227 # MAIN
224 228
225 def usage(): 229 def usage():
226 print ''' 230 print '''
227 Usage: run_pipeline.sh [-m ...] [-z ...] [-o ...] [-p ...] 231 Usage: run_pipeline.sh [-m ...] [-z ...] [-o ...] [-p ...]
228 -m, --max-complexity: max complexity to generate for a batch 232 -m, --max-complexity: max complexity to generate for an image
229 -z, --probability-zero: probability of using complexity=0 for a batch 233 -z, --probability-zero: probability of using complexity=0 for an image
230 -o, --output-file: full path to file to use for output of images 234 -o, --output-file: full path to file to use for output of images
231 -p, --params-output-file: path to file to output params to 235 -p, --params-output-file: path to file to output params to
236 -r, --labels-output-file: path to file to output labels to
237 -f, --data-file: path to filetensor (.ft) data file (NIST)
238 -l, --label-file: path to filetensor (.ft) labels file (NIST labels)
239 -a, --prob-font: probability of using a raw font image
240 -b, --prob-captcha: probability of using a captcha image
232 ''' 241 '''
233 242
234 # See run_pipeline.py 243 # See run_pipeline.py
235 def get_argv(): 244 def get_argv():
236 with open(ARGS_FILE) as f: 245 with open(ARGS_FILE) as f:
243 def _main(): 252 def _main():
244 max_complexity = 0.5 # default 253 max_complexity = 0.5 # default
245 probability_zero = 0.1 # default 254 probability_zero = 0.1 # default
246 output_file_path = None 255 output_file_path = None
247 params_output_file_path = None 256 params_output_file_path = None
257 labels_output_file_path = None
258 nist_path = DEFAULT_NIST_PATH
259 label_path = DEFAULT_LABEL_PATH
260 prob_font = 0.0
261 prob_captcha = 0.0
248 stop_after = None 262 stop_after = None
249 reload_mode = False 263 reload_mode = False
250 264
251 try: 265 try:
252 opts, args = getopt.getopt(get_argv(), "rm:z:o:p:s:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "stop-after="]) 266 opts, args = getopt.getopt(get_argv(), "rm:z:o:p:r:s:f:l:a:b:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "labels-output-file=", "stop-after=", "data-file=", "label-file=", "prob-font=", "prob-captcha="])
253 except getopt.GetoptError, err: 267 except getopt.GetoptError, err:
254 # print help information and exit: 268 # print help information and exit:
255 print str(err) # will print something like "option -a not recognized" 269 print str(err) # will print something like "option -a not recognized"
256 usage() 270 usage()
257 sys.exit(2) 271 sys.exit(2)
267 assert probability_zero >= 0.0 and probability_zero <= 1.0 281 assert probability_zero >= 0.0 and probability_zero <= 1.0
268 elif o in ("-o", "--output-file"): 282 elif o in ("-o", "--output-file"):
269 output_file_path = a 283 output_file_path = a
270 elif o in ('-p', "--params-output-file"): 284 elif o in ('-p', "--params-output-file"):
271 params_output_file_path = a 285 params_output_file_path = a
286 elif o in ('-r', "--labels-output-file"):
287 labels_output_file_path = a
272 elif o in ('-s', "--stop-after"): 288 elif o in ('-s', "--stop-after"):
273 stop_after = int(a) 289 stop_after = int(a)
290 elif o in ('-f', "--data-file"):
291 nist_path = a
292 elif o in ('-l', "--label-file"):
293 label_path = a
294 elif o in ('-a', "--prob-font"):
295 prob_font = float(a)
296 elif o in ('-b', "--prob-captcha"):
297 prob_captcha = float(a)
274 else: 298 else:
275 assert False, "unhandled option" 299 assert False, "unhandled option"
276 300
277 if output_file_path == None or params_output_file_path == None: 301 if output_file_path == None or params_output_file_path == None or labels_output_file_path == None:
278 print "Must specify both output files." 302 print "Must specify the three output files."
279 print 303 print
280 usage() 304 usage()
281 sys.exit(2) 305 sys.exit(2)
282 306
283 if reload_mode: 307 if reload_mode:
285 else: 309 else:
286 if DEBUG_IMAGES_PATH: 310 if DEBUG_IMAGES_PATH:
287 ''' 311 '''
288 # This code is yet untested 312 # This code is yet untested
289 debug_images = DebugImages(DEBUG_IMAGES_PATH) 313 debug_images = DebugImages(DEBUG_IMAGES_PATH)
290 num_batches = 1 314 num_img = len(debug_images.filelist)
291 batch_size = len(debug_images.filelist) 315 pl = Pipeline(modules=MODULE_INSTANCES, num_img=num_img, image_size=(32,32))
292 pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches, batch_size=BATCH_SIZE, image_size=(32,32)) 316 img_it = debug_images_iterator(debug_images)
293 batch_it = debug_images_iterator(debug_images)
294 ''' 317 '''
295 else: 318 else:
296 nist = NistData() 319 nist = NistData(nist_path, label_path)
297 num_batches = nist.dim[0]/BATCH_SIZE 320 num_img = nist.dim[0]
298 if stop_after: 321 if stop_after:
299 num_batches = stop_after 322 num_img = stop_after
300 pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches, batch_size=BATCH_SIZE, image_size=(32,32)) 323 pl = Pipeline(modules=MODULE_INSTANCES, num_img=num_img, image_size=(32,32))
301 batch_it = just_nist_iterator(nist, BATCH_SIZE, stop_after) 324 img_it = nist_supp_iterator(nist, prob_font, prob_captcha, num_img)
302 325
303 cpx_it = range_complexity_iterator(probability_zero, max_complexity) 326 cpx_it = range_complexity_iterator(probability_zero, max_complexity)
304 pl.run(batch_it, cpx_it) 327 pl.run(img_it, cpx_it)
305 pl.write_output(output_file_path, params_output_file_path) 328 pl.write_output(output_file_path, params_output_file_path, labels_output_file_path)
306 329
307 _main() 330 _main()
308 331
309 if DEBUG_X: 332 if DEBUG_X:
310 pylab.ioff() 333 pylab.ioff()