comparison transformations/local_elastic_distortions.py @ 24:010e826b41e8

Modifications to elastic distortions: fixed an important bug in the distortions themselves (the result now looks much nicer), made the interface conform to the Transformation standard, and added the ability to cache a number of distortion fields so they can be reused as long as the complexity doesn't change
author fsavard <francois.savard@polymtl.ca>
date Fri, 29 Jan 2010 13:37:52 -0500
parents 8d1c37190122
children b67d729ebfe3
comparison
equal deleted inserted replaced
21:afdd41db8152 24:010e826b41e8
21 import math 21 import math
22 import numpy 22 import numpy
23 import numpy.random 23 import numpy.random
24 import scipy.signal # convolve2d 24 import scipy.signal # convolve2d
25 25
# Directory holding the sample images read and written by the manual tests
# at the bottom of this file.
_TEST_DIR = "/home/francois/Desktop/dist_tests/"


def _raw_zeros(size):
    """Return a ``size[0]`` x ``size[1]`` nested list filled with integer 0.

    Plain Python lists (rather than a numpy array) are returned; every row
    is a distinct list object, so writing to one cell never affects another
    row.
    """
    n_rows, n_cols = size[0], size[1]
    return [[0] * n_cols for _ in range(n_rows)]
28 30
class ElasticDistortionParams(object):
    """Settings and precomputed lookup tables for one elastic distortion.

    Attributes:
        image_size: ``(height, width)`` pair; the tables below all have
            that many rows and columns.
        alpha: first distortion setting (returned first by ``alpha_sigma``).
        sigma: second distortion setting (returned second).
        matrix_XX_corners_rows, matrix_XX_corners_cols: for each of the
            four corners ``XX`` in ``tl``, ``tr``, ``bl``, ``br``, a nested
            list of row / column indices, initialised to 0.
        matrix_XX_multiply: for each corner, a float numpy array that will
            hold the precomputed ratios for bilinear interpolation,
            initialised to 0.
    """

    def __init__(self, image_size, alpha=0.0, sigma=0.0):
        self.image_size = image_size
        self.alpha = alpha
        self.sigma = sigma

        h, w = self.image_size
        shape = (h, w)

        # Index tables are kept as nested lists, one pair per corner.
        self.matrix_tl_corners_rows = _raw_zeros(shape)
        self.matrix_tl_corners_cols = _raw_zeros(shape)

        self.matrix_tr_corners_rows = _raw_zeros(shape)
        self.matrix_tr_corners_cols = _raw_zeros(shape)

        self.matrix_bl_corners_rows = _raw_zeros(shape)
        self.matrix_bl_corners_cols = _raw_zeros(shape)

        self.matrix_br_corners_rows = _raw_zeros(shape)
        self.matrix_br_corners_cols = _raw_zeros(shape)

        # Those will hold the precomputed ratios for bilinear interpolation.
        self.matrix_tl_multiply = numpy.zeros(shape)
        self.matrix_tr_multiply = numpy.zeros(shape)
        self.matrix_bl_multiply = numpy.zeros(shape)
        self.matrix_br_multiply = numpy.zeros(shape)

    def __repr__(self):
        return "%s(image_size=%r, alpha=%r, sigma=%r)" % (
            type(self).__name__, self.image_size, self.alpha, self.sigma)

    def alpha_sigma(self):
        """Return the two settings as the list ``[alpha, sigma]``."""
        return [self.alpha, self.sigma]
60
class LocalElasticDistorter(object):
    """Apply random local elastic distortions to 2-D images of a fixed size.

    Usage: call ``regenerate_parameters(complexity)`` to pick (or build) a
    set of distortion tables, then ``transform_image(image)`` to apply them.

    Building the tables is the slow part, so up to ``to_precompute``
    parameter sets are cached per complexity value; once the cache is full
    a cached set is picked at random instead of generating a new one.
    """

    def __init__(self, image_size):
        """Create a distorter for images of ``image_size`` = (height, width)."""
        self.image_size = image_size

        self.current_complexity = 0.0

        # Number of precomputed parameter sets.  Principle: as complexity
        # doesn't change often, we can precompute a certain number of fields
        # for a given complexity, each with its own parameters.  That way we
        # keep good randomization but are much faster.
        self.to_precompute = 50

        # Both hold ElasticDistortionParams instances.
        self.current_params = None
        self.precomputed_params = []

        # Cached blur kernel, its (rows, cols) size, and the sigma it was
        # built for.
        self.kernel_size = None
        self.kernel = None
        self._kernel_sigma = None

        # set some defaults
        self.regenerate_parameters(0.0)

    def get_settings_names(self):
        """Return the names of the values given by regenerate_parameters."""
        return ['alpha', 'sigma']

    def regenerate_parameters(self, complexity):
        """Select the distortion to use for the next ``transform_image``.

        If ``complexity`` differs from the current one by more than 1e-4 the
        cache is emptied.  While the cache holds fewer than
        ``to_precompute`` entries a new parameter set is generated and
        cached; afterwards one cached set is drawn at random.

        Returns:
            ``[alpha, sigma]`` of the selected parameter set.
        """
        if abs(complexity - self.current_complexity) > 1e-4:
            self.current_complexity = complexity

            # complexity changed, fields must be regenerated
            self.precomputed_params = []

        # Strict "<" so exactly ``to_precompute`` sets are cached (the
        # previous "<=" produced one extra).  The emptiness check keeps a
        # non-positive ``to_precompute`` from leaving nothing to draw from.
        cache_full = len(self.precomputed_params) >= self.to_precompute
        if not cache_full or not self.precomputed_params:
            new_params = self._generate_fields(self._initialize_new_params())
            self.precomputed_params.append(new_params)
            self.current_params = new_params
        else:
            idx = numpy.random.randint(0, len(self.precomputed_params))
            self.current_params = self.precomputed_params[idx]

        return self.current_params.alpha_sigma()

    # adapted from http://blenderartists.org/forum/showthread.php?t=163361
    def _gen_gaussian_kernel(self, sigma):
        """Make ``self.kernel`` a normalised Gaussian-shaped kernel for ``sigma``.

        Sets ``self.kernel`` and ``self.kernel_size`` and returns None.  The
        kernel is rebuilt only when ``sigma`` differs from the one it was
        last built for.

        The kernel size can change DRAMATICALLY the time taken by the blur,
        so even though results are better with a bigger kernel we
        compromise on about 1.5 * sigma samples per side.
        """
        if self.kernel is not None and \
                getattr(self, '_kernel_sigma', None) == sigma:
            # kernel already matches this sigma, no need to regenerate
            return

        self._kernel_sigma = sigma

        if sigma <= 0:
            # Degenerate spread: a 1x1 identity kernel (no blur) avoids a
            # division by zero below.
            self.kernel_size = (1, 1)
            self.kernel = numpy.ones((1, 1))
            return

        # Integer size: it is used as an array shape and slice offset.
        size = max(1, int(math.ceil(1.5 * sigma)))
        self.kernel_size = (size, size)

        # Centre on the middle sample so the kernel is symmetric.
        centre = (size - 1) / 2.0
        y, x = numpy.ogrid[0:size, 0:size]
        gauss = (numpy.exp(-numpy.square((x - centre) / sigma)) *
                 numpy.exp(-numpy.square((y - centre) / sigma)))

        # Normalize so we don't reduce image intensity
        self.kernel = gauss / gauss.sum()

    def _gen_distortion_field(self, params):
        """Return one ``image_size``-shaped displacement field.

        Uniform noise in [-1, 1] is blurred with the kernel for
        ``params.sigma`` and multiplied by ``params.alpha``.
        """
        self._gen_gaussian_kernel(params.sigma)

        rows, cols = self.image_size

        # Generate a margin of half a kernel on every side so the part we
        # keep is not affected by the zero padding that mode='same' applies
        # at the borders, then crop the image-sized centre.
        pad_r = self.kernel_size[0] // 2
        pad_c = self.kernel_size[1] // 2

        noise = numpy.random.uniform(
            -1.0, 1.0, (rows + 2 * pad_r, cols + 2 * pad_c))
        blurred = scipy.signal.convolve2d(noise, self.kernel, mode='same')
        field = blurred[pad_r:pad_r + rows, pad_c:pad_c + cols]

        return params.alpha * field

    def _initialize_new_params(self):
        """Return a fresh ElasticDistortionParams for the current complexity."""
        params = ElasticDistortionParams(self.image_size)

        # Cube root: makes complexity progress a bit faster while keeping
        # the extremes of 0.0 and 1.0.
        cpx = self.current_complexity ** (1. / 3.)

        # The smaller the alpha, the closer to the target the source pixels
        # are fetched; a max of 10 is reasonable.
        params.alpha = cpx * 10.0

        # The bigger the sigma, the smoother the distortion -- and the
        # bigger the blur kernel, hence the slower the field generation.
        # Ranges from 10 (complexity 0) down to 3 (complexity 1).
        params.sigma = 10.0 - (7.0 * cpx)

        return params

    def _generate_fields(self, params):
        '''
        Fill the lookup tables of ``params`` and return ``params``.

        Here's how the code works:
        - Two displacement fields (x and y) are generated: uniform noise
          over [-1, 1], blurred with the kernel for params.sigma, scaled by
          params.alpha.
        - For every output pixel, the (float) coordinate to fetch from is
          the pixel's own coordinate plus the displacement.
        - The four integer-coordinate neighbours of that target and their
          bilinear ratios are stored in ``params`` so that transforming an
          image is only indexing, multiplying and summing.
        '''
        field_x = self._gen_distortion_field(params)
        field_y = self._gen_distortion_field(params)

        self._fill_interpolation_tables(params, field_x, field_y)

        # not really necessary, but anyway
        return params

    @staticmethod
    def _fill_interpolation_tables(params, field_x, field_y):
        """Store corner indices and bilinear ratios for the given fields.

        ``field_x`` / ``field_y`` are the per-pixel column / row
        displacements, shaped like ``params.image_size``.  Index tables are
        stored as nested lists ("_rows" are row indices, "_cols" are column
        indices, separated due to the way fancy indexing works in numpy);
        ratios are stored as float arrays.
        """
        h, w = params.image_size
        ys, xs = numpy.mgrid[0:h, 0:w]

        # The "target" is the coordinate we fetch color data from (in the
        # original image); top/left are the integer coordinates just
        # above / to the left of it.
        target_y = ys + field_y
        target_x = xs + field_x
        top = numpy.floor(target_y).astype(int)
        left = numpy.floor(target_x).astype(int)

        # x_ratio is the importance of the left pixels, y_ratio that of the
        # top pixels (in the bilinear combination).
        y_ratio = 1.0 - (target_y - top)
        x_ratio = 1.0 - (target_x - left)

        def corner_tables(row_offset, col_offset, weight):
            rows = top + row_offset
            cols = left + col_offset
            # A default background color of 0 is used for neighbours that
            # fall outside the image: index (0, 0) with a ratio of 0.
            inside = (rows >= 0) & (rows < h) & (cols >= 0) & (cols < w)
            return (numpy.where(inside, rows, 0).tolist(),
                    numpy.where(inside, cols, 0).tolist(),
                    numpy.where(inside, weight, 0.0))

        (params.matrix_tl_corners_rows,
         params.matrix_tl_corners_cols,
         params.matrix_tl_multiply) = corner_tables(
            0, 0, x_ratio * y_ratio)

        (params.matrix_tr_corners_rows,
         params.matrix_tr_corners_cols,
         params.matrix_tr_multiply) = corner_tables(
            0, 1, (1.0 - x_ratio) * y_ratio)

        (params.matrix_bl_corners_rows,
         params.matrix_bl_corners_cols,
         params.matrix_bl_multiply) = corner_tables(
            1, 0, x_ratio * (1.0 - y_ratio))

        (params.matrix_br_corners_rows,
         params.matrix_br_corners_cols,
         params.matrix_br_multiply) = corner_tables(
            1, 1, (1.0 - x_ratio) * (1.0 - y_ratio))

    def transform_image(self, image):
        """Return ``image`` distorted with the currently selected parameters.

        ``image`` is a 2-D array shaped like ``image_size``; the result is
        the bilinear combination of the four precomputed neighbours of each
        pixel.
        """
        p = self.current_params
        image = numpy.asarray(image)

        corners = (
            (p.matrix_tl_corners_rows, p.matrix_tl_corners_cols,
             p.matrix_tl_multiply),
            (p.matrix_tr_corners_rows, p.matrix_tr_corners_cols,
             p.matrix_tr_multiply),
            (p.matrix_bl_corners_rows, p.matrix_bl_corners_cols,
             p.matrix_bl_multiply),
            (p.matrix_br_corners_rows, p.matrix_br_corners_cols,
             p.matrix_br_multiply),
        )

        # Index pixels to get each corner, weight them by the bilinear
        # ratios (elementwise), then sum to finish the combination.
        weighted = [numpy.multiply(image[rows, cols], ratios)
                    for rows, cols, ratios in corners]
        return numpy.sum(weighted, axis=0)
184 295
185 # TESTS ---------------------------------------------------------------------- 296 # TESTS ----------------------------------------------------------------------
191 if len(img.shape) > 2: 302 if len(img.shape) > 2:
192 img = (img * _RGB_TO_GRAYSCALE).sum(axis=2) 303 img = (img * _RGB_TO_GRAYSCALE).sum(axis=2)
193 return (img / 255.0).astype('float') 304 return (img / 255.0).astype('float')
194 305
def _specific_test():
    """Distort the sample image "d.png" once at complexity 0.5 and show it.

    Manual test: reads from _TEST_DIR and opens a pylab window.
    """
    # Imported here so the function also works when this module is imported
    # rather than run as a script.
    import os.path
    import pylab

    imgpath = os.path.join(_TEST_DIR, "d.png")
    img = _load_image(imgpath)
    dist = LocalElasticDistorter((32, 32))
    print(dist.regenerate_parameters(0.5))
    # The method is called transform_image; distort_image no longer exists.
    img = dist.transform_image(img)
    pylab.imshow(img)
    pylab.show()
314
def _complexity_tests():
    """Write 10 distorted samples per complexity level plus an HTML index.

    Manual test: for complexities 0.0, 0.1, ..., 1.0 the sample image
    "d.png" from _TEST_DIR is distorted 10 times; each result is saved in
    _TEST_DIR as "complexity_<c>_<i>.png" and "complexity.html" lists them
    all.  The time taken by each regenerate_parameters call is printed.
    """
    # Imported here so the function also works when this module is imported
    # rather than run as a script.
    import os.path
    import time

    imgpath = os.path.join(_TEST_DIR, "d.png")
    dist = LocalElasticDistorter((32, 32))
    orig_img = _load_image(imgpath)

    html_parts = ['''<html><body>Original:<br/><img src='d.png'>''']
    for step in range(11):
        complexity = step / 10.0
        # Fixed one-decimal label: keeps file names stable (no
        # "0.30000000000000004") whatever the float formatting rules are.
        label = "%.1f" % complexity
        html_parts.append('<br/>Complexity: ' + label + '<br/>')
        for i in range(10):
            t1 = time.time()
            dist.regenerate_parameters(complexity)
            t2 = time.time()
            print("diff %s" % (t2 - t1))
            img = dist.transform_image(orig_img)
            filename = "complexity_" + label + "_" + str(i) + ".png"
            new_path = os.path.join(_TEST_DIR, filename)
            _save_image(img, new_path)
            html_parts.append('<img src="' + filename + '">')
    html_parts.append("</body></html>")

    # "with" closes the file even if the write fails.
    with open(os.path.join(_TEST_DIR, "complexity.html"), "w") as html_file:
        html_file.write("".join(html_parts))
336
def _complexity_benchmark():
    """Time regenerate_parameters + transform_image at complexity 0.2.

    Manual benchmark: four consecutive batches (10, 40, 50 and 1000
    iterations) are run on the same distorter and the total and average
    time of each batch is printed.
    """
    # Imported here so the function also works when this module is imported
    # rather than run as a script.
    import os.path
    import time

    imgpath = os.path.join(_TEST_DIR, "d.png")
    dist = LocalElasticDistorter((32, 32))
    orig_img = _load_image(imgpath)

    batches = (
        ("first 10", 10),
        ("next 40", 40),
        ("next 50", 50),
        ("next 1000", 1000),
    )
    for label, count in batches:
        t1 = time.time()
        for _ in range(count):
            dist.regenerate_parameters(0.2)
            dist.transform_image(orig_img)
        t2 = time.time()

        elapsed = t2 - t1
        # Spacing reproduces the output of the former
        # ``print "...", total = ", x, ", avg=", y`` statements.
        print("%s, total =  %s , avg= %s" % (label, elapsed, elapsed / count))
377
378
379
def _save_image(img, path):
    """Save the float image ``img`` as an 8-bit grayscale file at ``path``.

    ``img`` is scaled by 255 and converted to uint8; values are expected
    in [0, 1] and anything outside is clipped so it cannot wrap around in
    the uint8 conversion.
    """
    # Imported here so the function does not depend on the imports made in
    # the ``__main__`` block; Pillow only provides the ``PIL`` package,
    # older PIL installs expose a top-level ``Image`` module.
    try:
        from PIL import Image
    except ImportError:
        import Image

    pixels = (numpy.clip(img, 0.0, 1.0) * 255).astype('uint8')
    img2 = Image.fromarray(pixels, "L")
    img2.save(path)
383
384 # TODO: reformat to follow the new class... it is a function of complexity now
385 '''
200 def _distorter_tests(): 386 def _distorter_tests():
201 #import pylab 387 #import pylab
202 #pylab.imshow(img) 388 #pylab.imshow(img)
203 #pylab.show() 389 #pylab.show()
204 390
205 for letter in ("d", "a", "n", "o"): 391 for letter in ("d", "a", "n", "o"):
206 img = _load_image("tests/" + letter + ".png") 392 img = _load_image("tests/" + letter + ".png")
207 for alpha in (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0): 393 for alpha in (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0):
208 for sigma in (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0): 394 for sigma in (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0):
209 id = LocalElasticDistorter((32,32), (15,15), sigma, alpha) 395 id = LocalElasticDistorter((32,32))
210 img2 = id.distort_image(img) 396 img2 = id.distort_image(img)
211 img2 = Image.fromarray((img2 * 255).astype('uint8'), "L") 397 img2 = Image.fromarray((img2 * 255).astype('uint8'), "L")
212 img2.save("tests/"+letter+"_alpha"+str(alpha)+"_sigma"+str(sigma)+".png") 398 img2.save("tests/"+letter+"_alpha"+str(alpha)+"_sigma"+str(sigma)+".png")
399 '''
213 400
214 def _benchmark(): 401 def _benchmark():
215 img = _load_image("tests/d.png") 402 img = _load_image("tests/d.png")
216 dist = LocalElasticDistorter((32,32), (10,10), 5.0, 5.0) 403 dist = LocalElasticDistorter((32,32))
404 dist.regenerate_parameters(0.0)
217 import time 405 import time
218 t1 = time.time() 406 t1 = time.time()
219 for i in range(10000): 407 for i in range(10000):
220 if i % 1000 == 0: 408 if i % 1000 == 0:
221 print "-" 409 print "-"
223 t2 = time.time() 411 t2 = time.time()
224 print "t2-t1", t2-t1 412 print "t2-t1", t2-t1
225 print "avg", 10000/(t2-t1) 413 print "avg", 10000/(t2-t1)
226 414
if __name__ == '__main__':
    # These imports are made at module level on purpose: the manual test
    # helpers in this file use time, pylab, Image and os as globals.
    import time
    import pylab
    try:
        # Pillow only provides the PIL package.
        from PIL import Image
    except ImportError:
        # Legacy PIL top-level module.
        import Image
    import os.path

    # Pick the manual test to run by (un)commenting one of these.
    #_distorter_tests()
    #_benchmark()
    #_specific_test()
    #_complexity_tests()
    _complexity_benchmark()
425
426