comparison transformations/pipeline.py @ 48:fabf910467b2

Ajouté des hooks pour visualisation à différentes étapes. On peut dumper la grille d'images pour chaque image transformée ou visualiser live avec pylab.imshow() (pas encore essayé cette façon... j'ai un problème avec GIMP+python sur mon laptop).
author fsavard
date Thu, 04 Feb 2010 13:39:46 -0500
parents fdb0e0870fb4
children ff59670cd1f9
comparison
equal deleted inserted replaced
47:3bc75139654a 48:fabf910467b2
1 #!/usr/bin/python 1 #!/usr/bin/python
2 # coding: utf-8 2 # coding: utf-8
3 3
4 from __future__ import with_statement 4 from __future__ import with_statement
5
6 # This is intended to be run as a GIMP script
7 from gimpfu import *
5 8
6 import sys, os, getopt 9 import sys, os, getopt
7 import numpy 10 import numpy
8 import filetensor as ft 11 import filetensor as ft
9 import random 12 import random
10 13
11 # This is intended to be run as a GIMP script 14 # To debug locally, also call with -s 1 (to stop after 1 batch ~= 100)
12 from gimpfu import * 15 # (otherwise we allocate all needed memory, might be loonnng and/or crash
13 16 # if, lucky like me, you have an age-old laptop creaking from everywhere)
14 DEBUG = True 17 DEBUG = True
18 DEBUG_X = False # Debug under X (pylab.show())
19
20 DEBUG_IMAGES_PATH = None
21 if DEBUG:
22 # UNTESTED YET
23 # To avoid loading NIST if you don't have it handy
24 # (use with debug_images_iterator(), see main())
25 # To use NIST, leave as = None
26 DEBUG_IMAGES_PATH = None#'/home/francois/Desktop/debug_images'
27
28 # Directory where to dump images to visualize results
29 # (create it, otherwise it'll crash)
30 DEBUG_OUTPUT_DIR = 'debug_out'
31
15 BATCH_SIZE = 100 32 BATCH_SIZE = 100
16 DEFAULT_NIST_PATH = '/data/lisa/data/nist/by_class/all/all_train_data.ft' 33 DEFAULT_NIST_PATH = '/data/lisa/data/nist/by_class/all/all_train_data.ft'
17 ARGS_FILE = os.environ['PIPELINE_ARGS_TMPFILE'] 34 ARGS_FILE = os.environ['PIPELINE_ARGS_TMPFILE']
18 35
19 if DEBUG: 36 if DEBUG_X:
20 import pylab 37 import pylab
21 pylab.ion() 38 pylab.ion()
22 39
23 #from add_background_image import AddBackground 40 #from add_background_image import AddBackground
24 #from affine_transform import AffineTransformation 41 #from affine_transform import AffineTransformation
25 #from PoivreSel import PoivreSel 42 from PoivreSel import PoivreSel
26 from thick import Thick 43 from thick import Thick
27 #from BruitGauss import BruitGauss 44 #from BruitGauss import BruitGauss
28 #from gimp_script import GIMPTransformation 45 #from gimp_script import GIMPTransformation
29 #from Rature import Rature 46 #from Rature import Rature
30 #from contrast Contrast 47 from contrast import Contrast
31 from local_elastic_distortions import LocalElasticDistorter 48 from local_elastic_distortions import LocalElasticDistorter
32 from slant import Slant 49 from slant import Slant
33 50
34 MODULE_INSTANCES = [Thick(), LocalElasticDistorter(), Slant()] 51 if DEBUG:
52 from visualizer import Visualizer
53 # Either put the visualizer in the MODULE_INSTANCES list
54 # after each module you want to visualize, or in the
55 # AFTER_EACH_MODULE_HOOK list (but not both, it's redundant)
56 VISUALIZER = Visualizer(to_dir=DEBUG_OUTPUT_DIR, on_screen=False)
57
58 MODULE_INSTANCES = [Thick(), LocalElasticDistorter(), PoivreSel(), Contrast()]
59
60 # These should have an "after_transform_callback(self, image)" method
61 # (called after each call to transform_image in a module)
62 AFTER_EACH_MODULE_HOOK = []
63 if DEBUG:
64 AFTER_EACH_MODULE_HOOK = [VISUALIZER]
65
66 # These should have an "end_transform_callback(self, final_image)" method
67 # (called after all modules have been called)
68 END_TRANSFORM_HOOK = []
69 if DEBUG:
70 END_TRANSFORM_HOOK = [VISUALIZER]
35 71
36 class Pipeline(): 72 class Pipeline():
37 def __init__(self, modules, num_batches, batch_size, image_size=(32,32)): 73 def __init__(self, modules, num_batches, batch_size, image_size=(32,32)):
38 self.modules = modules 74 self.modules = modules
39 self.num_batches = num_batches 75 self.num_batches = num_batches
60 self.params = numpy.empty((total, self.num_params_stored)) 96 self.params = numpy.empty((total, self.num_params_stored))
61 97
62 def run(self, batch_iterator, complexity_iterator): 98 def run(self, batch_iterator, complexity_iterator):
63 img_size = self.image_size 99 img_size = self.image_size
64 100
101 should_hook_after_each = len(AFTER_EACH_MODULE_HOOK) != 0
102 should_hook_at_the_end = len(END_TRANSFORM_HOOK) != 0
103
65 for batch_no, batch in enumerate(batch_iterator): 104 for batch_no, batch in enumerate(batch_iterator):
66 complexity = complexity_iterator.next() 105 complexity = complexity_iterator.next()
106 if DEBUG:
107 print "Complexity:", complexity
67 108
68 assert len(batch) == self.batch_size 109 assert len(batch) == self.batch_size
69 110
70 for img_no, img in enumerate(batch): 111 for img_no, img in enumerate(batch):
71 sys.stdout.flush() 112 sys.stdout.flush()
81 self.params[global_idx, param_idx:param_idx+len(p)] = p 122 self.params[global_idx, param_idx:param_idx+len(p)] = p
82 param_idx += len(p) 123 param_idx += len(p)
83 124
84 img = mod.transform_image(img) 125 img = mod.transform_image(img)
85 126
127 if should_hook_after_each:
128 for hook in AFTER_EACH_MODULE_HOOK:
129 hook.after_transform_callback(img)
130
86 self.res_data[global_idx] = \ 131 self.res_data[global_idx] = \
87 img.reshape((img_size[0] * img_size[1],))*255 132 img.reshape((img_size[0] * img_size[1],))*255
88 133
89 pylab.imshow(img) 134
90 pylab.draw() 135 if should_hook_at_the_end:
136 for hook in END_TRANSFORM_HOOK:
137 hook.end_transform_callback(img)
91 138
92 def write_output(self, output_file_path, params_output_file_path): 139 def write_output(self, output_file_path, params_output_file_path):
93 with open(output_file_path, 'wb') as f: 140 with open(output_file_path, 'wb') as f:
94 ft.write(f, self.res_data) 141 ft.write(f, self.res_data)
95 142
114 161
115 ############################################################################## 162 ##############################################################################
116 # DATA ITERATORS 163 # DATA ITERATORS
117 # They can be used to interleave different data sources etc. 164 # They can be used to interleave different data sources etc.
118 165
166 '''
167 # Following code (DebugImages and iterator) is untested
168
169 def load_image(filepath):
170 _RGB_TO_GRAYSCALE = [0.3, 0.59, 0.11, 0.0]
171 img = Image.open(filepath)
172 img = numpy.asarray(img)
173 if len(img.shape) > 2:
174 img = (img * _RGB_TO_GRAYSCALE).sum(axis=2)
175 return (img / 255.0).astype('float')
176
177 class DebugImages():
178 def __init__(self, images_dir_path):
179 import glob, os.path
180 self.filelist = glob.glob(os.path.join(images_dir_path, "*.png"))
181
182 def debug_images_iterator(debug_images):
183 for path in debug_images.filelist:
184 yield load_image(path)
185 '''
186
119 class NistData(): 187 class NistData():
120 def __init__(self, ): 188 def __init__(self, ):
121 nist_path = DEFAULT_NIST_PATH 189 nist_path = DEFAULT_NIST_PATH
122 self.train_data = open(nist_path, 'rb') 190 self.train_data = open(nist_path, 'rb')
123 self.dim = tuple(ft._read_header(self.train_data)[3]) 191 self.dim = tuple(ft._read_header(self.train_data)[3])
149 return args 217 return args
150 218
151 # Might be called locally or through dbidispatch. In all cases it should be 219 # Might be called locally or through dbidispatch. In all cases it should be
152 # passed to the GIMP executable to be able to use GIMP filters. 220 # passed to the GIMP executable to be able to use GIMP filters.
153 # Ex: 221 # Ex:
154 def main(): 222 def _main():
155 max_complexity = 0.5 # default 223 max_complexity = 0.5 # default
156 probability_zero = 0.1 # default 224 probability_zero = 0.1 # default
157 output_file_path = None 225 output_file_path = None
158 params_output_file_path = None 226 params_output_file_path = None
159 stop_after = None 227 stop_after = None
228
229 import sys
230 print "python version: ", sys.version
160 231
161 try: 232 try:
162 opts, args = getopt.getopt(get_argv(), "m:z:o:p:s:", ["max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "stop-after="]) 233 opts, args = getopt.getopt(get_argv(), "m:z:o:p:s:", ["max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "stop-after="])
163 except getopt.GetoptError, err: 234 except getopt.GetoptError, err:
164 # print help information and exit: 235 # print help information and exit:
187 print "Must specify both output files." 258 print "Must specify both output files."
188 print 259 print
189 usage() 260 usage()
190 sys.exit(2) 261 sys.exit(2)
191 262
192 nist = NistData() 263 if DEBUG_IMAGES_PATH:
193 num_batches = nist.dim[0]/BATCH_SIZE 264 '''
194 if stop_after: 265 # This code is yet untested
195 num_batches = stop_after 266 debug_images = DebugImages(DEBUG_IMAGES_PATH)
196 pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches, batch_size=BATCH_SIZE, image_size=(32,32)) 267 num_batches = 1
268 batch_size = len(debug_images.filelist)
269 pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches, batch_size=BATCH_SIZE, image_size=(32,32))
270 batch_it = debug_images_iterator(debug_images)
271 '''
272 else:
273 nist = NistData()
274 num_batches = nist.dim[0]/BATCH_SIZE
275 if stop_after:
276 num_batches = stop_after
277 pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches, batch_size=BATCH_SIZE, image_size=(32,32))
278 batch_it = just_nist_iterator(nist, BATCH_SIZE, stop_after)
279
197 cpx_it = range_complexity_iterator(probability_zero, max_complexity) 280 cpx_it = range_complexity_iterator(probability_zero, max_complexity)
198 batch_it = just_nist_iterator(nist, BATCH_SIZE, stop_after)
199
200 pl.run(batch_it, cpx_it) 281 pl.run(batch_it, cpx_it)
201 pl.write_output(output_file_path, params_output_file_path) 282 pl.write_output(output_file_path, params_output_file_path)
202 283
203 main() 284 _main()
285
286 if DEBUG_X:
287 pylab.ioff()
288 pylab.show()
204 289
205 pdb.gimp_quit(0) 290 pdb.gimp_quit(0)
206 pylab.ioff() 291
207 pylab.show()
208