Mercurial > ift6266
comparison data_generation/transformations/testmod.py @ 167:1f5937e9e530
More moves - transformations into data_generation, added "deep" folder
author | Dumitru Erhan <dumitru.erhan@gmail.com> |
---|---|
date | Fri, 26 Feb 2010 14:15:38 -0500 |
parents | transformations/testmod.py@af2f9252dd14 |
children | a9af079892ce |
comparison
equal
deleted
inserted
replaced
166:17ae5a1a4dd1 | 167:1f5937e9e530 |
---|---|
1 # This script is to test your modules to see if they conform to the module API | |
2 # defined on the wiki. | |
3 import random, numpy, gc, time, math, sys | |
4 | |
5 # this is an example module that does stupid image value shifting | |
6 | |
class DummyModule(object):
    """Example transformation module: shifts every pixel by one random offset.

    It exposes the three methods this script exercises:
    get_settings_names(), regenerate_parameters() and transform_image().
    """

    def __init__(self):
        # Start with a zero offset (identity transform) so that
        # transform_image() is usable even before the first call to
        # regenerate_parameters() instead of raising AttributeError.
        self._value = 0.0

    def get_settings_names(self):
        """Return the names of the parameters, one per generated value."""
        return ['value']

    def regenerate_parameters(self, complexity):
        """Draw a new offset and return the current parameters as a list.

        The offset is sampled from a zero-mean Gaussian whose standard
        deviation is 0.5 * complexity, so complexity 0 always yields 0.
        """
        self._value = random.gauss(0, 0.5 * complexity)
        return [self._value]

    def transform_image(self, image):
        """Return ``image`` shifted by the current offset, clipped to [0, 1]."""
        return numpy.clip(image + self._value, 0, 1)
17 | |
18 #import <your module> | |
19 | |
20 # instantiate your class here (rather than DummyModule) | |
21 mod = DummyModule() | |
22 | |
def error(msg):
    """Report a fatal API violation on stdout and exit with status 1."""
    # Single pre-formatted argument: prints the same text on Python 2 and 3.
    print("%s %s" % ("ERROR:", msg))
    sys.exit(1)
26 | |
def warn(msg):
    """Report a non-fatal API problem on stdout; execution continues."""
    # Single pre-formatted argument: prints the same text on Python 2 and 3.
    print("%s %s" % ("WARNING:", msg))
29 | |
def timeit(f, lbl, budget=10.0, timer=time.time):
    """Time repeated calls of ``f`` and print the mean duration per call.

    One calibration call estimates the cost of ``f``; the loop count is then
    the largest power of ten that should keep the measurement within
    ``budget`` seconds (never fewer than one loop).

    Parameters:
        f: zero-argument callable to measure.
        lbl: label printed in front of the result.
        budget: target duration of the measurement loop in seconds
            (defaults to the previously hard-coded 10 seconds).
        timer: zero-argument callable returning the current time in seconds.

    Returns:
        The mean time per call in seconds (also printed).

    Any exception raised by ``f`` propagates; the garbage collector is put
    back into the state it had on entry in every case.
    """
    from itertools import repeat  # constant-memory loop on Python 2 and 3

    gc_was_enabled = gc.isenabled()
    gc.disable()  # keep collector pauses out of the measurements
    try:
        start = timer()
        f()
        est = timer() - start
        if est <= 0:
            # The clock did not advance during the calibration call (coarse
            # timer resolution); assume one microsecond rather than dividing
            # by zero.
            est = 1e-6

        ratio = float(budget) / est
        if ratio < 1:
            loops = 1
        else:
            # log10 is exact on powers of ten, whereas log(x, 10) can land
            # just below an integer and lose a whole decade of loops.
            loops = int(10 ** math.floor(math.log10(ratio)))

        start = timer()
        for _ in repeat(None, loops):
            f()
        per_call = (timer() - start) / loops

        print("%s ( %s loops ): %s s" % (lbl, loops, per_call))
        return per_call
    finally:
        if gc_was_enabled:
            gc.enable()
47 | |
48 ######################## | |
49 # get_settings_names() # | |
50 ######################## | |
51 | |
print("Testing get_settings_names()")

names = mod.get_settings_names()

# The return value must be exactly a list; anything else is fatal.
if type(names) is not list:
    error("Must return a list")

# Entries that are not plain str only produce a warning.
if any(type(entry) is not str for entry in names):
    warn("The elements of the list should be strings")
61 | |
62 ########################### | |
63 # regenerate_parameters() # | |
64 ########################### | |
65 | |
print("Testing regenerate_parameters()")

params = mod.regenerate_parameters(0.2)

# The return value must be exactly a list, one entry per setting name.
if type(params) is not list:
    error("Must return a list")

if len(names) != len(params):
    error("the returned parameter list must have the same length as the number of parameters")

# Two draws at the same complexity are expected to differ whenever the
# module has at least one parameter.
params2 = mod.regenerate_parameters(0.2)
if names and params == params2:
    error("the complexity parameter determines the distribution of the parameters, not their value")

# Smoke-test both ends of the complexity range, then a mid-range value.
for complexity in (0.0, 1.0, 0.5):
    mod.regenerate_parameters(complexity)
84 | |
85 ##################### | |
86 # transform_image() # | |
87 ##################### | |
88 | |
print("Testing transform_image()")

# Three 32x32 float32 inputs: uniform random, all ones, all zeros.
imgr = numpy.random.random_sample((32, 32)).astype(numpy.float32)
img1 = numpy.ones((32, 32), dtype=numpy.float32)
img0 = numpy.zeros((32, 32), dtype=numpy.float32)

resr = mod.transform_image(imgr)

if type(resr) is not numpy.ndarray:
    error("Must return an ndarray")

if resr.shape != (32, 32):
    error("Must return 32x32 array")

if resr.dtype != numpy.float32:
    error("Must return float32 array")

res1 = mod.transform_image(img1)
res0 = mod.transform_image(img0)

# Range-check every output, including the one for the random image, which
# was previously never checked against the [0, 1] bounds.
for checked in (resr, res1, res0):
    if checked.max() > 1.0 or checked.min() < 0.0:
        error("Must keep array values between 0 and 1")

# Smoke-test the transform at both ends of the complexity range.
for complexity in (0.0, 1.0):
    mod.regenerate_parameters(complexity)
    mod.transform_image(imgr)
119 | |
print("Bonus Stage: timings")

# A no-op callable first, giving the per-call overhead to compare against.
timeit(lambda: None, "empty")
timeit(lambda: mod.regenerate_parameters(0.5), "regenerate_parameters()")
# Label fixed: it previously read "tranform_image()".
timeit(lambda: mod.transform_image(imgr), "transform_image()")


def f():
    """Regenerate parameters at complexity 0.2, then transform the random image."""
    mod.regenerate_parameters(0.2)
    mod.transform_image(imgr)


timeit(f, "regen and transform")