comparison transformations/pipeline.py @ 41:fdb0e0870fb4

Many changes to pipeline.py to generalize it and start on visualization, and created a wrapper (run_pipeline.py) to call it through GIMP.
- Changes to pipeline.py
  - Wrapped the pipeline loop in a class
  - Factored the iteration over batches and over complexities out into iterators
  - This makes it possible to have complicated batch orderings (several sources) and complexity schedules (an illustrative sketch follows the diff below)
  - regenerate_parameters() is now called for each image
  - Command-line arguments via getopt(); more options can be added this way
- run_pipeline.py
  - The goal is to make it possible to pass arguments. Not easy (no simple way found) on the command line when calling GIMP in batch mode, so this is a hack for now (a sketch of such a wrapper follows the changeset header below)
  - The ultimate goal is to launch the jobs on the clusters with dbidispatch, specifying the options (different for each job) on the command line
author fsavard
date Wed, 03 Feb 2010 17:08:27 -0500
parents f6b6c74bb82f
children fabf910467b2
comparison of 40:0f1337994716 and 41:fdb0e0870fb4 (deleted lines are prefixed with '-', inserted lines with '+')
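The run_pipeline.py wrapper described in the commit message is not part of this changeset, which only touches pipeline.py. Below is a minimal sketch of what such a wrapper might look like: only the PIPELINE_ARGS_TMPFILE environment variable and the one-argument-per-line file format are implied by get_argv() in the diff, while the wrapper's structure, the use of tempfile, and the exact GIMP batch-mode flags are assumptions.

#!/usr/bin/python
# coding: utf-8
# Hypothetical wrapper sketch (not part of this changeset): write the
# command-line arguments to a temporary file, point pipeline.py at it via
# the PIPELINE_ARGS_TMPFILE environment variable, then run pipeline.py
# inside GIMP in batch mode so its filters are available.
import os, sys, tempfile

def launch(argv):
    # One argument per line, as get_argv() in pipeline.py expects.
    fd, path = tempfile.mkstemp(prefix='pipeline_args_')
    f = os.fdopen(fd, 'w')
    for arg in argv:
        f.write(arg + '\n')
    f.close()
    os.environ['PIPELINE_ARGS_TMPFILE'] = path

    # Batch-mode GIMP invocation; the exact flags may differ depending on
    # the GIMP version. pipeline.py itself calls pdb.gimp_quit(0) when done.
    os.system("gimp -i --batch-interpreter python-fu-eval "
              "-b 'execfile(\"pipeline.py\")'")

if __name__ == '__main__':
    launch(sys.argv[1:])

Launched locally, e.g. as run_pipeline.py -m 0.5 -z 0.1 -o /tmp/out.ft -p /tmp/params (paths illustrative), or through dbidispatch with a different option set per job, this would reach main() below through get_argv().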
+#!/usr/bin/python
+# coding: utf-8
+
 from __future__ import with_statement
 
-import sys, os
+import sys, os, getopt
 import numpy
 import filetensor as ft
 import random
 
+# This is intended to be run as a GIMP script
+from gimpfu import *
+
+DEBUG = True
 BATCH_SIZE = 100
-
-#import <modules> and stuff them in mods below
-
-mods = []
-
-# DANGER: HIGH VOLTAGE -- DO NOT EDIT BELOW THIS LINE
-# -----------------------------------------------------------
-
-outf = sys.argv[1]
-paramsf = sys.argv[2]
-dataf = '/data/lisa/data/nist/by_class/all/all_train_data.ft'
-if len(sys.argv) >= 4:
-    dataf = sys.argv[3]
-
-train_data = open(dataf, 'rb')
-
-dim = tuple(ft._read_header(train_data)[3])
-
-res_data = numpy.empty(dim, dtype=numpy.int8)
-
-all_settings = ['complexity']
-
-for mod in mods:
-    all_settings += mod.get_settings_names()
-
-params = numpy.empty(((dim[0]/BATCH_SIZE)+1, len(all_settings)))
-
-for i in xrange(0, dim[0], BATCH_SIZE):
-    train_data.seek(0)
-    imgs = ft.read(train_data, slice(i, i+BATCH_SIZE)).astype(numpy.float32)/255
-
-    complexity = random.random()
-    p = i/BATCH_SIZE
-    j = 1
-    for mod in mods:
-        par = mod.regenerate_parameters(complexity)
-        params[p, j:j+len(par)] = par
-        j += len(par)
-
-    for k in range(imgs.shape[0]):
-        c = imgs[k].reshape((32, 32))
-        for mod in mods:
-            c = mod.transform_image(c)
-        res_data[i+k] = c.reshape((1024,))*255
-
-with open(outf, 'wb') as f:
-    ft.write(f, res_data)
-
-numpy.save(paramsf, params)
+DEFAULT_NIST_PATH = '/data/lisa/data/nist/by_class/all/all_train_data.ft'
+ARGS_FILE = os.environ['PIPELINE_ARGS_TMPFILE']
+
+if DEBUG:
+    import pylab
+    pylab.ion()
+
+#from add_background_image import AddBackground
+#from affine_transform import AffineTransformation
+#from PoivreSel import PoivreSel
+from thick import Thick
+#from BruitGauss import BruitGauss
+#from gimp_script import GIMPTransformation
+#from Rature import Rature
+#from contrast import Contrast
+from local_elastic_distortions import LocalElasticDistorter
+from slant import Slant
+
+MODULE_INSTANCES = [Thick(), LocalElasticDistorter(), Slant()]
+
+class Pipeline():
+    def __init__(self, modules, num_batches, batch_size, image_size=(32,32)):
+        self.modules = modules
+        self.num_batches = num_batches
+        self.batch_size = batch_size
+        self.num_params_stored = 0
+        self.image_size = image_size
+
+        self.init_memory()
+
+    def init_num_params_stored(self):
+        # just a dummy call to regenerate_parameters() to get the
+        # real number of params (only those which are stored)
+        self.num_params_stored = 0
+        for m in self.modules:
+            self.num_params_stored += len(m.regenerate_parameters(0.0))
+
+    def init_memory(self):
+        self.init_num_params_stored()
+
+        total = (self.num_batches + 1) * self.batch_size
+        num_px = self.image_size[0] * self.image_size[1]
+
+        self.res_data = numpy.empty((total, num_px))
+        self.params = numpy.empty((total, self.num_params_stored))
+
+    def run(self, batch_iterator, complexity_iterator):
+        img_size = self.image_size
+
+        for batch_no, batch in enumerate(batch_iterator):
+            complexity = complexity_iterator.next()
+
+            assert len(batch) == self.batch_size
+
+            for img_no, img in enumerate(batch):
+                sys.stdout.flush()
+                global_idx = batch_no*self.batch_size + img_no
+
+                img = img.reshape(img_size)
+
+                param_idx = 0
+                for mod in self.modules:
+                    # This used to be done _per batch_,
+                    # ie. out of the "for img" loop
+                    p = mod.regenerate_parameters(complexity)
+                    self.params[global_idx, param_idx:param_idx+len(p)] = p
+                    param_idx += len(p)
+
+                    img = mod.transform_image(img)
+
+                self.res_data[global_idx] = \
+                    img.reshape((img_size[0] * img_size[1],))*255
+
+                pylab.imshow(img)
+                pylab.draw()
+
+    def write_output(self, output_file_path, params_output_file_path):
+        with open(output_file_path, 'wb') as f:
+            ft.write(f, self.res_data)
+
+        numpy.save(params_output_file_path, self.params)
+
+
+##############################################################################
+# COMPLEXITY ITERATORS
+# They're called once every batch, to get the complexity to use for that batch
+# they must be infinite (should never throw StopIteration when calling next())
+
+# probability of generating 0 complexity, otherwise
+# uniform over 0.0-max_complexity
+def range_complexity_iterator(probability_zero, max_complexity):
+    assert max_complexity <= 1.0
+    while True:
+        # draw once per batch so each batch independently has a
+        # probability_zero chance of getting complexity 0.0
+        n = numpy.random.uniform(0.0, 1.0)
+        if n < probability_zero:
+            yield 0.0
+        else:
+            yield numpy.random.uniform(0.0, max_complexity)
+
+##############################################################################
+# DATA ITERATORS
+# They can be used to interleave different data sources etc.
+
+class NistData():
+    def __init__(self):
+        nist_path = DEFAULT_NIST_PATH
+        self.train_data = open(nist_path, 'rb')
+        self.dim = tuple(ft._read_header(self.train_data)[3])
+
+def just_nist_iterator(nist, batch_size, stop_after=None):
+    for i in xrange(0, nist.dim[0], batch_size):
+        nist.train_data.seek(0)
+        yield ft.read(nist.train_data, slice(i, i+batch_size)).astype(numpy.float32)/255
+
+        if stop_after is not None and i >= stop_after:
+            break
+
+##############################################################################
+# MAIN
+
+def usage():
+    print '''
+Usage: run_pipeline.sh [-m ...] [-z ...] [-o ...] [-p ...] [-s ...]
+    -m, --max-complexity: max complexity to generate for a batch
+    -z, --probability-zero: probability of using complexity=0 for a batch
+    -o, --output-file: full path to file to use for output of images
+    -p, --params-output-file: path to file to output params to
+    -s, --stop-after: stop early, after processing this number of batches
+    '''
+
+# See run_pipeline.py
+def get_argv():
+    with open(ARGS_FILE) as f:
+        args = [l.rstrip() for l in f.readlines()]
+    return args
+
+# Might be called locally or through dbidispatch. In all cases it should be
+# passed to the GIMP executable to be able to use GIMP filters.
+# Ex:
+def main():
+    max_complexity = 0.5 # default
+    probability_zero = 0.1 # default
+    output_file_path = None
+    params_output_file_path = None
+    stop_after = None
+
+    try:
+        opts, args = getopt.getopt(get_argv(), "m:z:o:p:s:", ["max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "stop-after="])
+    except getopt.GetoptError, err:
+        # print help information and exit
+        print str(err) # will print something like "option -a not recognized"
+        usage()
+        sys.exit(2)
+    output = None
+    verbose = False
+    for o, a in opts:
+        if o in ('-m', '--max-complexity'):
+            max_complexity = float(a)
+            assert max_complexity >= 0.0 and max_complexity <= 1.0
+        elif o in ("-z", "--probability-zero"):
+            probability_zero = float(a)
+            assert probability_zero >= 0.0 and probability_zero <= 1.0
+        elif o in ("-o", "--output-file"):
+            output_file_path = a
+        elif o in ('-p', "--params-output-file"):
+            params_output_file_path = a
+        elif o in ('-s', "--stop-after"):
+            stop_after = int(a)
+        else:
+            assert False, "unhandled option"
+
+    if output_file_path == None or params_output_file_path == None:
+        print "Must specify both output files."
+        print
+        usage()
+        sys.exit(2)
+
+    nist = NistData()
+    num_batches = nist.dim[0]/BATCH_SIZE
+    if stop_after:
+        num_batches = stop_after
+    pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches, batch_size=BATCH_SIZE, image_size=(32,32))
+    cpx_it = range_complexity_iterator(probability_zero, max_complexity)
+    batch_it = just_nist_iterator(nist, BATCH_SIZE, stop_after)
+
+    pl.run(batch_it, cpx_it)
+    pl.write_output(output_file_path, params_output_file_path)
+
+main()
+
+pdb.gimp_quit(0)
+pylab.ioff()
+pylab.show()
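Two details the diff leaves implicit: Pipeline only assumes that each transformation module provides regenerate_parameters(complexity), returning the list of stored parameter values, and transform_image(image), returning the transformed image; and, as the commit message notes, the batch-iterator abstraction is meant to allow mixing several data sources. Neither snippet below exists in this changeset; they are illustrative sketches only, and the names IdentityModule and interleaved_iterator are hypothetical.

# Hypothetical examples, not part of the changeset.
import itertools

class IdentityModule():
    # Minimal module satisfying the interface Pipeline relies on:
    # regenerate_parameters() returns the parameter values to store (their
    # count sets the width of Pipeline.params), transform_image() returns
    # the transformed image.
    def regenerate_parameters(self, complexity):
        self.complexity = complexity
        return [complexity]

    def transform_image(self, image):
        return image

def interleaved_iterator(iter_a, iter_b):
    # Alternate batches from two sources, e.g. two just_nist_iterator()
    # generators over different files; stops when either source is exhausted.
    for batch_a, batch_b in itertools.izip(iter_a, iter_b):
        yield batch_a
        yield batch_b

Pipeline.run() could then be given interleaved_iterator(it_a, it_b) in place of batch_it in main(), provided num_batches and batch_size are adjusted to match the combined source.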