Mercurial > ift6266
comparison transformations/pipeline.py @ 108:a7cd8dd3221c
pipeline.py: placé les modules dans le bon ordre + store NIST en bytes plutôt qu'en float32 dans la RAM
author | boulanni <nicolas_boulanger@hotmail.com> |
---|---|
date | Mon, 15 Feb 2010 16:17:48 -0500 |
parents | 95c491bb5662 |
children | 9c45e0071b52 |
comparison
equal
deleted
inserted
replaced
107:a9b87b68101d | 108:a7cd8dd3221c |
---|---|
39 | 39 |
40 if DEBUG_X: | 40 if DEBUG_X: |
41 import pylab | 41 import pylab |
42 pylab.ion() | 42 pylab.ion() |
43 | 43 |
44 #from add_background_image import AddBackground | |
45 from affine_transform import AffineTransformation | |
46 from PoivreSel import PoivreSel | 44 from PoivreSel import PoivreSel |
47 from thick import Thick | 45 from thick import Thick |
48 from BruitGauss import BruitGauss | 46 from BruitGauss import BruitGauss |
| | 47 from DistorsionGauss import DistorsionGauss |
| | 48 from PermutPixel import PermutPixel |
49 from gimp_script import GIMP1 | 49 from gimp_script import GIMP1 |
50 from Rature import Rature | 50 from Rature import Rature |
51 from contrast import Contrast | 51 from contrast import Contrast |
52 from Occlusion import Occlusion | |
53 from local_elastic_distortions import LocalElasticDistorter | 52 from local_elastic_distortions import LocalElasticDistorter |
54 from slant import Slant | 53 from slant import Slant |
| | 54 from Occlusion import Occlusion |
| | 55 from add_background_image import AddBackground |
| | 56 from affine_transform import AffineTransformation |
55 | 57 |
56 if DEBUG: | 58 if DEBUG: |
57 from visualizer import Visualizer | 59 from visualizer import Visualizer |
58 # Either put the visualizer as in the MODULES_INSTANCES list | 60 # Either put the visualizer as in the MODULES_INSTANCES list |
59 # after each module you want to visualize, or in the | 61 # after each module you want to visualize, or in the |
60 # AFTER_EACH_MODULE_HOOK list (but not both, it's redundant) | 62 # AFTER_EACH_MODULE_HOOK list (but not both, it's redundant) |
61 VISUALIZER = Visualizer(to_dir=DEBUG_OUTPUT_DIR, on_screen=False) | 63 VISUALIZER = Visualizer(to_dir=DEBUG_OUTPUT_DIR, on_screen=False) |
62 | 64 |
63 MODULE_INSTANCES = [Thick(),Slant(),GIMP1(),AffineTransformation(),LocalElasticDistorter(),Occlusion(),Rature(),Contrast()] | 65 ###---------------------order of transformation module |
| | 66 MODULE_INSTANCES = [Slant(),Thick(),AffineTransformation(),LocalElasticDistorter(),GIMP1(),Rature(),Occlusion(), PermutPixel(),DistorsionGauss(),AddBackground(), PoivreSel(), BruitGauss(), Contrast()] |
64 | 67 |
65 # These should have a "after_transform_callback(self, image)" method | 68 # These should have a "after_transform_callback(self, image)" method |
66 # (called after each call to transform_image in a module) | 69 # (called after each call to transform_image in a module) |
67 AFTER_EACH_MODULE_HOOK = [] | 70 AFTER_EACH_MODULE_HOOK = [] |
68 if DEBUG: | 71 if DEBUG: |
199 self.train_data.close() | 202 self.train_data.close() |
200 self.train_data = open(nist_path, 'rb') | 203 self.train_data = open(nist_path, 'rb') |
201 self.ocr_data = open(ocr_path, 'rb') | 204 self.ocr_data = open(ocr_path, 'rb') |
202 self.ocr_labels = open(ocrlabel_path, 'rb') | 205 self.ocr_labels = open(ocrlabel_path, 'rb') |
203 | 206 |
| | 207 # cet iterator load tout en ram |
204 def nist_supp_iterator(nist, prob_font, prob_captcha, prob_ocr, num_img): | 208 def nist_supp_iterator(nist, prob_font, prob_captcha, prob_ocr, num_img): |
205 img = ft.read(nist.train_data).astype(numpy.float32)/255 | 209 img = ft.read(nist.train_data) |
206 labels = ft.read(nist.train_labels) | 210 labels = ft.read(nist.train_labels) |
207 if prob_ocr: | 211 if prob_ocr: |
208 ocr_img = ft.read(nist.ocr_data).astype(numpy.float32)/255 | 212 ocr_img = ft.read(nist.ocr_data) |
209 ocr_labels = ft.read(nist.ocr_labels) | 213 ocr_labels = ft.read(nist.ocr_labels) |
210 | 214 |
211 for i in xrange(num_img): | 215 for i in xrange(num_img): |
212 r = numpy.random.rand() | 216 r = numpy.random.rand() |
213 if r <= prob_font: | 217 if r <= prob_font: |
214 pass #get font | 218 pass #get font |
215 elif r <= prob_font + prob_captcha: | 219 elif r <= prob_font + prob_captcha: |
216 pass #get captcha | 220 pass #get captcha |
217 elif r <= prob_font + prob_captcha + prob_ocr: | 221 elif r <= prob_font + prob_captcha + prob_ocr: |
218 j = numpy.random.randint(len(ocr_labels)) | 222 j = numpy.random.randint(len(ocr_labels)) |
219 yield ocr_img[j], ocr_labels[j] | 223 yield ocr_img[j].astype(numpy.float32)/255, ocr_labels[j] |
220 else: | 224 else: |
221 j = numpy.random.randint(len(labels)) | 225 j = numpy.random.randint(len(labels)) |
222 yield img[j], labels[j] | 226 yield img[j].astype(numpy.float32)/255, labels[j] |
223 | 227 |
224 | 228 |
225 # Mostly for debugging, for the moment, just to see if we can | 229 # Mostly for debugging, for the moment, just to see if we can |
226 # reload the images and parameters. | 230 # reload the images and parameters. |
227 def reload(output_file_path, params_output_file_path): | 231 def reload(output_file_path, params_output_file_path): |