comparison transformations/pipeline.py @ 108:a7cd8dd3221c

pipeline.py: placé les modules dans le bon ordre + store NIST en bytes plutôt qu'en float32 dans la RAM
author boulanni <nicolas_boulanger@hotmail.com>
date Mon, 15 Feb 2010 16:17:48 -0500
parents 95c491bb5662
children 9c45e0071b52
comparison
equal deleted inserted replaced
107:a9b87b68101d 108:a7cd8dd3221c
39 39
40 if DEBUG_X: 40 if DEBUG_X:
41 import pylab 41 import pylab
42 pylab.ion() 42 pylab.ion()
43 43
44 #from add_background_image import AddBackground
45 from affine_transform import AffineTransformation
46 from PoivreSel import PoivreSel 44 from PoivreSel import PoivreSel
47 from thick import Thick 45 from thick import Thick
48 from BruitGauss import BruitGauss 46 from BruitGauss import BruitGauss
47 from DistorsionGauss import DistorsionGauss
48 from PermutPixel import PermutPixel
49 from gimp_script import GIMP1 49 from gimp_script import GIMP1
50 from Rature import Rature 50 from Rature import Rature
51 from contrast import Contrast 51 from contrast import Contrast
52 from Occlusion import Occlusion
53 from local_elastic_distortions import LocalElasticDistorter 52 from local_elastic_distortions import LocalElasticDistorter
54 from slant import Slant 53 from slant import Slant
54 from Occlusion import Occlusion
55 from add_background_image import AddBackground
56 from affine_transform import AffineTransformation
55 57
56 if DEBUG: 58 if DEBUG:
57 from visualizer import Visualizer 59 from visualizer import Visualizer
58 # Either put the visualizer as in the MODULES_INSTANCES list 60 # Either put the visualizer as in the MODULES_INSTANCES list
59 # after each module you want to visualize, or in the 61 # after each module you want to visualize, or in the
60 # AFTER_EACH_MODULE_HOOK list (but not both, it's redundant) 62 # AFTER_EACH_MODULE_HOOK list (but not both, it's redundant)
61 VISUALIZER = Visualizer(to_dir=DEBUG_OUTPUT_DIR, on_screen=False) 63 VISUALIZER = Visualizer(to_dir=DEBUG_OUTPUT_DIR, on_screen=False)
62 64
63 MODULE_INSTANCES = [Thick(),Slant(),GIMP1(),AffineTransformation(),LocalElasticDistorter(),Occlusion(),Rature(),Contrast()] 65 ###---------------------order of transformation module
66 MODULE_INSTANCES = [Slant(),Thick(),AffineTransformation(),LocalElasticDistorter(),GIMP1(),Rature(),Occlusion(), PermutPixel(),DistorsionGauss(),AddBackground(), PoivreSel(), BruitGauss(), Contrast()]
64 67
65 # These should have a "after_transform_callback(self, image)" method 68 # These should have a "after_transform_callback(self, image)" method
66 # (called after each call to transform_image in a module) 69 # (called after each call to transform_image in a module)
67 AFTER_EACH_MODULE_HOOK = [] 70 AFTER_EACH_MODULE_HOOK = []
68 if DEBUG: 71 if DEBUG:
199 self.train_data.close() 202 self.train_data.close()
200 self.train_data = open(nist_path, 'rb') 203 self.train_data = open(nist_path, 'rb')
201 self.ocr_data = open(ocr_path, 'rb') 204 self.ocr_data = open(ocr_path, 'rb')
202 self.ocr_labels = open(ocrlabel_path, 'rb') 205 self.ocr_labels = open(ocrlabel_path, 'rb')
203 206
207 # cet iterator load tout en ram
204 def nist_supp_iterator(nist, prob_font, prob_captcha, prob_ocr, num_img): 208 def nist_supp_iterator(nist, prob_font, prob_captcha, prob_ocr, num_img):
205 img = ft.read(nist.train_data).astype(numpy.float32)/255 209 img = ft.read(nist.train_data)
206 labels = ft.read(nist.train_labels) 210 labels = ft.read(nist.train_labels)
207 if prob_ocr: 211 if prob_ocr:
208 ocr_img = ft.read(nist.ocr_data).astype(numpy.float32)/255 212 ocr_img = ft.read(nist.ocr_data)
209 ocr_labels = ft.read(nist.ocr_labels) 213 ocr_labels = ft.read(nist.ocr_labels)
210 214
211 for i in xrange(num_img): 215 for i in xrange(num_img):
212 r = numpy.random.rand() 216 r = numpy.random.rand()
213 if r <= prob_font: 217 if r <= prob_font:
214 pass #get font 218 pass #get font
215 elif r <= prob_font + prob_captcha: 219 elif r <= prob_font + prob_captcha:
216 pass #get captcha 220 pass #get captcha
217 elif r <= prob_font + prob_captcha + prob_ocr: 221 elif r <= prob_font + prob_captcha + prob_ocr:
218 j = numpy.random.randint(len(ocr_labels)) 222 j = numpy.random.randint(len(ocr_labels))
219 yield ocr_img[j], ocr_labels[j] 223 yield ocr_img[j].astype(numpy.float32)/255, ocr_labels[j]
220 else: 224 else:
221 j = numpy.random.randint(len(labels)) 225 j = numpy.random.randint(len(labels))
222 yield img[j], labels[j] 226 yield img[j].astype(numpy.float32)/255, labels[j]
223 227
224 228
225 # Mostly for debugging, for the moment, just to see if we can 229 # Mostly for debugging, for the moment, just to see if we can
226 # reload the images and parameters. 230 # reload the images and parameters.
227 def reload(output_file_path, params_output_file_path): 231 def reload(output_file_path, params_output_file_path):