comparison data_generation/transformations/testmod.py @ 167:1f5937e9e530

More moves - transformations into data_generation, added "deep" folder
author Dumitru Erhan <dumitru.erhan@gmail.com>
date Fri, 26 Feb 2010 14:15:38 -0500
parents transformations/testmod.py@af2f9252dd14
children a9af079892ce
comparison
equal deleted inserted replaced
166:17ae5a1a4dd1 167:1f5937e9e530
1 # This script is to test your modules to see if they conform to the module API
2 # defined on the wiki.
3 import random, numpy, gc, time, math, sys
4
# Example module: deliberately trivial ("stupid") image value shifting.
class DummyModule(object):
    """Stand-in transformation used to exercise the checks in this script.

    Every call to regenerate_parameters() draws one random offset; the
    transformation adds that offset to every pixel and clips back into [0, 1].
    """

    def get_settings_names(self):
        """Return the names of the parameters regenerate_parameters() draws."""
        setting_names = ['value']
        return setting_names

    def regenerate_parameters(self, complexity):
        """Draw a new offset and return it as a one-element list.

        The offset is sampled from a Gaussian with mean 0 and standard
        deviation 0.5 * complexity.
        """
        sigma = 0.5 * complexity
        offset = random.gauss(0, sigma)
        self._value = offset
        return [offset]

    def transform_image(self, image):
        """Return *image* shifted by the current offset, clipped to [0, 1].

        Requires regenerate_parameters() to have been called first (it is
        what sets the offset).
        """
        shifted = image + self._value
        return numpy.clip(shifted, 0, 1)

# To test your own module, import it here, for example:
#   from mymodule import MyTransform
#import <your module>

# Instantiate your class here (instead of DummyModule); every check below
# runs against this object.
mod = DummyModule()

def error(msg):
    """Print a fatal conformance failure to stdout and exit with status 1."""
    print("%s %s" % ("ERROR:", msg))
    raise SystemExit(1)

def warn(msg):
    """Print a non-fatal conformance problem to stdout; execution continues."""
    print("%s %s" % ("WARNING:", msg))

def timeit(f, lbl, target=10.0):
    """Benchmark the zero-argument callable *f* and print its mean time per call.

    A calibration run estimates the cost of one call. The timed run then uses
    the largest power of ten of loops whose expected total time does not
    exceed *target* seconds (at least one loop). The garbage collector is
    paused while timing and restored to its previous state afterwards, even
    if *f* raises.

    Parameters
    ----------
    f : callable
        Called with no arguments.
    lbl : str
        Label printed in front of the result.
    target : float, optional
        Positive, finite time budget in seconds for the timed loops
        (default 10.0, the value that used to be hard-coded).

    Returns
    -------
    float
        Mean seconds per call -- the same value that is printed.
    """
    # Imported locally so this helper needs nothing beyond the file's own
    # imports; `from timeit import ...` resolves to the stdlib module, not to
    # this function.
    import itertools
    from timeit import default_timer as clock

    def run(n):
        # Time n back-to-back calls with the collector paused; itertools.repeat
        # is a counted loop that builds no list on Python 2 or 3.
        was_enabled = gc.isenabled()
        gc.disable()
        try:
            start = clock()
            for _ in itertools.repeat(None, n):
                f()
            return clock() - start
        finally:
            if was_enabled:
                gc.enable()

    # Calibrate. A single very fast call can read as 0 s on a coarse clock,
    # which used to crash on 10/est; grow the batch until it registers.
    batch = 1
    elapsed = run(batch)
    while elapsed <= 0.0:
        batch *= 10
        elapsed = run(batch)
    est = elapsed / batch

    # log10 instead of log(x, 10): the latter is inexact at powers of ten
    # (math.log(1000, 10) < 3), which could pick ten times fewer loops.
    ratio = target / est
    if ratio > 0:
        loops = max(1, int(10 ** math.floor(math.log10(ratio))))
    else:
        loops = 1

    per_loop = run(loops) / loops
    print("%s ( %d loops ): %s s" % (lbl, loops, per_loop))
    return per_loop

########################
# get_settings_names() #
########################

print("Testing get_settings_names()")

names = mod.get_settings_names()

# isinstance rather than an exact type() comparison, so list subclasses
# are accepted too.
if not isinstance(names, list):
    error("Must return a list")

if not all(isinstance(e, str) for e in names):
    warn("The elements of the list should be strings")

###########################
# regenerate_parameters() #
###########################

print("Testing regenerate_parameters()")


def _check_params(result):
    """Abort via error() unless *result* is a list with one entry per setting name."""
    if not isinstance(result, list):
        error("Must return a list")
    if len(result) != len(names):
        error("the returned parameter list must have the same length as the number of parameters")


params = mod.regenerate_parameters(0.2)
_check_params(params)

# Two draws at the same complexity must differ: complexity selects the
# distribution the parameters are sampled from, not their exact values.
params2 = mod.regenerate_parameters(0.2)
_check_params(params2)
if len(names) != 0 and params == params2:
    error("the complexity parameter determines the distribution of the parameters, not their value")

# Boundary complexities must honour the same contract. The last call (0.5)
# leaves the parameters in place for the transform_image() checks that follow.
for complexity in (0.0, 1.0, 0.5):
    _check_params(mod.regenerate_parameters(complexity))

#####################
# transform_image() #
#####################

print("Testing transform_image()")


def _check_image(result):
    """Abort via error() unless *result* is a 32x32 float32 ndarray in [0, 1]."""
    if not isinstance(result, numpy.ndarray):
        error("Must return an ndarray")
    if result.shape != (32, 32):
        error("Must return 32x32 array")
    if result.dtype != numpy.float32:
        error("Must return float32 array")
    # Phrased as "every value lies inside [0, 1]" rather than
    # "max > 1 or min < 0": NaN fails every comparison, so it is rejected too.
    if not numpy.all((result >= 0.0) & (result <= 1.0)):
        error("Must keep array values between 0 and 1")


imgr = numpy.random.random_sample((32, 32)).astype(numpy.float32)
img1 = numpy.ones((32, 32), dtype=numpy.float32)
img0 = numpy.zeros((32, 32), dtype=numpy.float32)

# Every result is validated, including the random-image one; the all-ones and
# all-zeros inputs sit exactly on the [0, 1] bounds.
resr = mod.transform_image(imgr)
_check_image(resr)
res1 = mod.transform_image(img1)
_check_image(res1)
res0 = mod.transform_image(img0)
_check_image(res0)

# Results at the extreme complexities must be valid images as well.
for complexity in (0.0, 1.0):
    mod.regenerate_parameters(complexity)
    _check_image(mod.transform_image(imgr))

print("Bonus Stage: timings")

# "empty" times a do-nothing callable, as a baseline for the call overhead.
timeit(lambda: None, "empty")
timeit(lambda: mod.regenerate_parameters(0.5), "regenerate_parameters()")
timeit(lambda: mod.transform_image(imgr), "transform_image()")


def f():
    """One full sample: draw fresh parameters, then transform the test image."""
    mod.regenerate_parameters(0.2)
    mod.transform_image(imgr)


timeit(f, "regen and transform")