Mercurial > ift6266
comparison transformations/pipeline.py @ 61:cc4be6b25b8e
Data iterator alternating between NIST/font/captcha; removed the use of batches; keep track of labels (not fully done yet)
author | boulanni <nicolas_boulanger@hotmail.com> |
---|---|
date | Mon, 08 Feb 2010 23:45:17 -0500 |
parents | c89defea1e65 |
children | 1afa95285b9c |
comparison
equal
deleted
inserted
replaced
60:d508f5a8acd0 | 61:cc4be6b25b8e |
---|---|
2 # coding: utf-8 | 2 # coding: utf-8 |
3 | 3 |
4 from __future__ import with_statement | 4 from __future__ import with_statement |
5 | 5 |
6 # This is intended to be run as a GIMP script | 6 # This is intended to be run as a GIMP script |
7 from gimpfu import * | 7 #from gimpfu import * |
8 | 8 |
9 import sys, os, getopt | 9 import sys, os, getopt |
10 import numpy | 10 import numpy |
11 import filetensor as ft | 11 import filetensor as ft |
12 import random | 12 import random |
13 | 13 |
14 # To debug locally, also call with -s 1 (to stop after 1 batch ~= 100) | 14 # To debug locally, also call with -s 100 (to stop after ~100) |
15 # (otherwise we allocate all needed memory, might be loonnng and/or crash | 15 # (otherwise we allocate all needed memory, might be loonnng and/or crash |
16 # if, lucky like me, you have an age-old laptop creaking from everywhere) | 16 # if, lucky like me, you have an age-old laptop creaking from everywhere) |
17 DEBUG = True | 17 DEBUG = True |
18 DEBUG_X = False | 18 DEBUG_X = False |
19 if DEBUG: | 19 if DEBUG: |
29 | 29 |
30 # Directory where to dump images to visualize results | 30 # Directory where to dump images to visualize results |
31 # (create it, otherwise it'll crash) | 31 # (create it, otherwise it'll crash) |
32 DEBUG_OUTPUT_DIR = 'debug_out' | 32 DEBUG_OUTPUT_DIR = 'debug_out' |
33 | 33 |
34 BATCH_SIZE = 100 | 34 DEFAULT_NIST_PATH = '/data/lisa/data/ift6266h10/train_data.ft' |
35 DEFAULT_NIST_PATH = '/data/lisa/data/nist/by_class/all/all_train_data.ft' | 35 DEFAULT_LABEL_PATH = '/data/lisa/data/ift6266h10/train_labels.ft' |
36 ARGS_FILE = os.environ['PIPELINE_ARGS_TMPFILE'] | 36 ARGS_FILE = os.environ['PIPELINE_ARGS_TMPFILE'] |
37 | 37 |
38 if DEBUG_X: | 38 if DEBUG_X: |
39 import pylab | 39 import pylab |
40 pylab.ion() | 40 pylab.ion() |
70 END_TRANSFORM_HOOK = [] | 70 END_TRANSFORM_HOOK = [] |
71 if DEBUG: | 71 if DEBUG: |
72 END_TRANSFORM_HOOK = [VISUALIZER] | 72 END_TRANSFORM_HOOK = [VISUALIZER] |
73 | 73 |
74 class Pipeline(): | 74 class Pipeline(): |
75 def __init__(self, modules, num_batches, batch_size, image_size=(32,32)): | 75 def __init__(self, modules, num_img, image_size=(32,32)): |
76 self.modules = modules | 76 self.modules = modules |
77 self.num_batches = num_batches | 77 self.num_img = num_img |
78 self.batch_size = batch_size | |
79 self.num_params_stored = 0 | 78 self.num_params_stored = 0 |
80 self.image_size = image_size | 79 self.image_size = image_size |
81 | 80 |
82 self.init_memory() | 81 self.init_memory() |
83 | 82 |
89 self.num_params_stored += len(m.regenerate_parameters(0.0)) | 88 self.num_params_stored += len(m.regenerate_parameters(0.0)) |
90 | 89 |
91 def init_memory(self): | 90 def init_memory(self): |
92 self.init_num_params_stored() | 91 self.init_num_params_stored() |
93 | 92 |
94 total = self.num_batches * self.batch_size | 93 total = self.num_img |
95 num_px = self.image_size[0] * self.image_size[1] | 94 num_px = self.image_size[0] * self.image_size[1] |
96 | 95 |
97 self.res_data = numpy.empty((total, num_px)) | 96 self.res_data = numpy.empty((total, num_px), dtype=numpy.uint8) |
98 # +1 to store complexity | 97 # +1 to store complexity |
99 self.params = numpy.empty((total, self.num_params_stored+1)) | 98 self.params = numpy.empty((total, self.num_params_stored+1)) |
100 | 99 self.res_labels = numpy.empty(total, dtype=numpy.int32) |
101 def run(self, batch_iterator, complexity_iterator): | 100 |
101 def run(self, img_iterator, complexity_iterator): | |
102 img_size = self.image_size | 102 img_size = self.image_size |
103 | 103 |
104 should_hook_after_each = len(AFTER_EACH_MODULE_HOOK) != 0 | 104 should_hook_after_each = len(AFTER_EACH_MODULE_HOOK) != 0 |
105 should_hook_at_the_end = len(END_TRANSFORM_HOOK) != 0 | 105 should_hook_at_the_end = len(END_TRANSFORM_HOOK) != 0 |
106 | 106 |
107 for batch_no, batch in enumerate(batch_iterator): | 107 for img_no, (img, label) in enumerate(img_iterator): |
108 sys.stdout.flush() | |
108 complexity = complexity_iterator.next() | 109 complexity = complexity_iterator.next() |
109 if DEBUG: | 110 |
110 print "Complexity:", complexity | 111 global_idx = img_no |
111 | 112 |
112 assert len(batch) == self.batch_size | 113 img = img.reshape(img_size) |
113 | 114 |
114 for img_no, img in enumerate(batch): | 115 param_idx = 1 |
115 sys.stdout.flush() | 116 # store complexity along with other params |
116 global_idx = batch_no*self.batch_size + img_no | 117 self.params[global_idx, 0] = complexity |
117 | 118 for mod in self.modules: |
118 img = img.reshape(img_size) | 119 # This used to be done _per batch_, |
119 | 120 # ie. out of the "for img" loop |
120 param_idx = 1 | 121 p = mod.regenerate_parameters(complexity) |
121 # store complexity along with other params | 122 self.params[global_idx, param_idx:param_idx+len(p)] = p |
122 self.params[global_idx, 0] = complexity | 123 param_idx += len(p) |
123 for mod in self.modules: | 124 |
124 # This used to be done _per batch_, | 125 img = mod.transform_image(img) |
125 # ie. out of the "for img" loop | 126 |
126 p = mod.regenerate_parameters(complexity) | 127 if should_hook_after_each: |
127 self.params[global_idx, param_idx:param_idx+len(p)] = p | 128 for hook in AFTER_EACH_MODULE_HOOK: |
128 param_idx += len(p) | 129 hook.after_transform_callback(img) |
129 | 130 |
130 img = mod.transform_image(img) | 131 self.res_data[global_idx] = \ |
131 | 132 img.reshape((img_size[0] * img_size[1],))*255 |
132 if should_hook_after_each: | 133 self.res_labels[global_idx] = label |
133 for hook in AFTER_EACH_MODULE_HOOK: | 134 |
134 hook.after_transform_callback(img) | 135 if should_hook_at_the_end: |
135 | 136 for hook in END_TRANSFORM_HOOK: |
136 self.res_data[global_idx] = \ | 137 hook.end_transform_callback(img) |
137 img.reshape((img_size[0] * img_size[1],))*255 | 138 |
138 | 139 def write_output(self, output_file_path, params_output_file_path, labels_output_file_path): |
139 | |
140 if should_hook_at_the_end: | |
141 for hook in END_TRANSFORM_HOOK: | |
142 hook.end_transform_callback(img) | |
143 | |
144 def write_output(self, output_file_path, params_output_file_path): | |
145 with open(output_file_path, 'wb') as f: | 140 with open(output_file_path, 'wb') as f: |
146 ft.write(f, self.res_data) | 141 ft.write(f, self.res_data) |
147 | 142 |
148 numpy.save(params_output_file_path, self.params) | 143 numpy.save(params_output_file_path, self.params) |
144 | |
145 with open(labels_output_file_path, 'wb') as f: | |
146 ft.write(f, self.res_labels) | |
149 | 147 |
150 | 148 |
151 ############################################################################## | 149 ############################################################################## |
152 # COMPLEXITY ITERATORS | 150 # COMPLEXITY ITERATORS |
153 # They're called once every batch, to get the complexity to use for that batch | 151 # They're called once every img, to get the complexity to use for that img |
154 # they must be infinite (should never throw StopIteration when calling next()) | 152 # they must be infinite (should never throw StopIteration when calling next()) |
155 | 153 |
156 # probability of generating 0 complexity, otherwise | 154 # probability of generating 0 complexity, otherwise |
157 # uniform over 0.0-max_complexity | 155 # uniform over 0.0-max_complexity |
158 def range_complexity_iterator(probability_zero, max_complexity): | 156 def range_complexity_iterator(probability_zero, max_complexity): |
188 for path in debug_images.filelist: | 186 for path in debug_images.filelist: |
189 yield load_image(path) | 187 yield load_image(path) |
190 ''' | 188 ''' |
191 | 189 |
192 class NistData(): | 190 class NistData(): |
193 def __init__(self, ): | 191 def __init__(self, nist_path, label_path): |
194 nist_path = DEFAULT_NIST_PATH | |
195 self.train_data = open(nist_path, 'rb') | 192 self.train_data = open(nist_path, 'rb') |
193 self.train_labels = open(label_path, 'rb') | |
196 self.dim = tuple(ft._read_header(self.train_data)[3]) | 194 self.dim = tuple(ft._read_header(self.train_data)[3]) |
197 | 195 |
198 def just_nist_iterator(nist, batch_size, stop_after=None): | 196 def nist_supp_iterator(nist, prob_font, prob_captcha, num_img): |
199 for i in xrange(0, nist.dim[0], batch_size): | 197 subtensor = slice(0, num_img) |
200 if not stop_after is None and i >= stop_after: | 198 img = ft.read(nist.train_data, subtensor).astype(numpy.float32)/255 |
201 break | 199 labels = ft.read(nist.train_labels, subtensor) |
202 | 200 |
203 nist.train_data.seek(0) | 201 for i in xrange(num_img): |
204 yield ft.read(nist.train_data, slice(i, i+batch_size)).astype(numpy.float32)/255 | 202 r = numpy.random.rand() |
205 | 203 if r<= prob_font: |
204 pass #get font | |
205 elif r<= prob_font + prob_captcha: | |
206 pass #get captcha | |
207 else: | |
208 j = numpy.random.randint(num_img) | |
209 yield img[j], labels[j] | |
206 | 210 |
207 | 211 |
208 # Mostly for debugging, for the moment, just to see if we can | 212 # Mostly for debugging, for the moment, just to see if we can |
209 # reload the images and parameters. | 213 # reload the images and parameters. |
210 def reload(output_file_path, params_output_file_path): | 214 def reload(output_file_path, params_output_file_path): |
223 # MAIN | 227 # MAIN |
224 | 228 |
225 def usage(): | 229 def usage(): |
226 print ''' | 230 print ''' |
227 Usage: run_pipeline.sh [-m ...] [-z ...] [-o ...] [-p ...] | 231 Usage: run_pipeline.sh [-m ...] [-z ...] [-o ...] [-p ...] |
228 -m, --max-complexity: max complexity to generate for a batch | 232 -m, --max-complexity: max complexity to generate for an image |
229 -z, --probability-zero: probability of using complexity=0 for a batch | 233 -z, --probability-zero: probability of using complexity=0 for an image |
230 -o, --output-file: full path to file to use for output of images | 234 -o, --output-file: full path to file to use for output of images |
231 -p, --params-output-file: path to file to output params to | 235 -p, --params-output-file: path to file to output params to |
236 -r, --labels-output-file: path to file to output labels to | |
237 -f, --data-file: path to filetensor (.ft) data file (NIST) | |
238 -l, --label-file: path to filetensor (.ft) labels file (NIST labels) | |
239 -a, --prob-font: probability of using a raw font image | |
240 -b, --prob-captcha: probability of using a captcha image | |
232 ''' | 241 ''' |
233 | 242 |
234 # See run_pipeline.py | 243 # See run_pipeline.py |
235 def get_argv(): | 244 def get_argv(): |
236 with open(ARGS_FILE) as f: | 245 with open(ARGS_FILE) as f: |
243 def _main(): | 252 def _main(): |
244 max_complexity = 0.5 # default | 253 max_complexity = 0.5 # default |
245 probability_zero = 0.1 # default | 254 probability_zero = 0.1 # default |
246 output_file_path = None | 255 output_file_path = None |
247 params_output_file_path = None | 256 params_output_file_path = None |
257 labels_output_file_path = None | |
258 nist_path = DEFAULT_NIST_PATH | |
259 label_path = DEFAULT_LABEL_PATH | |
260 prob_font = 0.0 | |
261 prob_captcha = 0.0 | |
248 stop_after = None | 262 stop_after = None |
249 reload_mode = False | 263 reload_mode = False |
250 | 264 |
251 try: | 265 try: |
252 opts, args = getopt.getopt(get_argv(), "rm:z:o:p:s:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "stop-after="]) | 266 opts, args = getopt.getopt(get_argv(), "rm:z:o:p:r:s:f:l:a:b:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "labels-output-file=", "stop-after=", "data-file=", "label-file=", "prob-font=", "prob-captcha="]) |
253 except getopt.GetoptError, err: | 267 except getopt.GetoptError, err: |
254 # print help information and exit: | 268 # print help information and exit: |
255 print str(err) # will print something like "option -a not recognized" | 269 print str(err) # will print something like "option -a not recognized" |
256 usage() | 270 usage() |
257 sys.exit(2) | 271 sys.exit(2) |
267 assert probability_zero >= 0.0 and probability_zero <= 1.0 | 281 assert probability_zero >= 0.0 and probability_zero <= 1.0 |
268 elif o in ("-o", "--output-file"): | 282 elif o in ("-o", "--output-file"): |
269 output_file_path = a | 283 output_file_path = a |
270 elif o in ('-p', "--params-output-file"): | 284 elif o in ('-p', "--params-output-file"): |
271 params_output_file_path = a | 285 params_output_file_path = a |
286 elif o in ('-r', "--labels-output-file"): | |
287 labels_output_file_path = a | |
272 elif o in ('-s', "--stop-after"): | 288 elif o in ('-s', "--stop-after"): |
273 stop_after = int(a) | 289 stop_after = int(a) |
290 elif o in ('-f', "--data-file"): | |
291 nist_path = a | |
292 elif o in ('-l', "--label-file"): | |
293 label_path = a | |
294 elif o in ('-a', "--prob-font"): | |
295 prob_font = float(a) | |
296 elif o in ('-b', "--prob-captcha"): | |
297 prob_captcha = float(a) | |
274 else: | 298 else: |
275 assert False, "unhandled option" | 299 assert False, "unhandled option" |
276 | 300 |
277 if output_file_path == None or params_output_file_path == None: | 301 if output_file_path == None or params_output_file_path == None or labels_output_file_path == None: |
278 print "Must specify both output files." | 302 print "Must specify the three output files." |
279 print | 303 print |
280 usage() | 304 usage() |
281 sys.exit(2) | 305 sys.exit(2) |
282 | 306 |
283 if reload_mode: | 307 if reload_mode: |
285 else: | 309 else: |
286 if DEBUG_IMAGES_PATH: | 310 if DEBUG_IMAGES_PATH: |
287 ''' | 311 ''' |
288 # This code is yet untested | 312 # This code is yet untested |
289 debug_images = DebugImages(DEBUG_IMAGES_PATH) | 313 debug_images = DebugImages(DEBUG_IMAGES_PATH) |
290 num_batches = 1 | 314 num_img = len(debug_images.filelist) |
291 batch_size = len(debug_images.filelist) | 315 pl = Pipeline(modules=MODULE_INSTANCES, num_img=num_img, image_size=(32,32)) |
292 pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches, batch_size=BATCH_SIZE, image_size=(32,32)) | 316 img_it = debug_images_iterator(debug_images) |
293 batch_it = debug_images_iterator(debug_images) | |
294 ''' | 317 ''' |
295 else: | 318 else: |
296 nist = NistData() | 319 nist = NistData(nist_path, label_path) |
297 num_batches = nist.dim[0]/BATCH_SIZE | 320 num_img = nist.dim[0] |
298 if stop_after: | 321 if stop_after: |
299 num_batches = stop_after | 322 num_img = stop_after |
300 pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches, batch_size=BATCH_SIZE, image_size=(32,32)) | 323 pl = Pipeline(modules=MODULE_INSTANCES, num_img=num_img, image_size=(32,32)) |
301 batch_it = just_nist_iterator(nist, BATCH_SIZE, stop_after) | 324 img_it = nist_supp_iterator(nist, prob_font, prob_captcha, num_img) |
302 | 325 |
303 cpx_it = range_complexity_iterator(probability_zero, max_complexity) | 326 cpx_it = range_complexity_iterator(probability_zero, max_complexity) |
304 pl.run(batch_it, cpx_it) | 327 pl.run(img_it, cpx_it) |
305 pl.write_output(output_file_path, params_output_file_path) | 328 pl.write_output(output_file_path, params_output_file_path, labels_output_file_path) |
306 | 329 |
307 _main() | 330 _main() |
308 | 331 |
309 if DEBUG_X: | 332 if DEBUG_X: |
310 pylab.ioff() | 333 pylab.ioff() |