diff data_generation/transformations/testmod.py @ 167:1f5937e9e530

More moves - transformations into data_generation, added "deep" folder
author Dumitru Erhan <dumitru.erhan@gmail.com>
date Fri, 26 Feb 2010 14:15:38 -0500
parents transformations/testmod.py@af2f9252dd14
children a9af079892ce
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/data_generation/transformations/testmod.py	Fri Feb 26 14:15:38 2010 -0500
@@ -0,0 +1,130 @@
+# This script is to test your modules to see if they conform to the module API
+# defined on the wiki.
+import random, numpy, gc, time, math, sys
+
+# this is an example module that does stupid image value shifting
+
+class DummyModule(object):
+    def get_settings_names(self):
+        return ['value']
+    
+    def regenerate_parameters(self, complexity):
+        self._value = random.gauss(0, 0.5*complexity)
+        return [self._value]
+
+    def transform_image(self, image):
+        return numpy.clip(image+self._value, 0, 1)
+    
+#import <your module>
+
+# instanciate your class here (rather than DummyModule)
+mod = DummyModule()
+
+def error(msg):
+    print "ERROR:", msg
+    sys.exit(1)
+
+def warn(msg):
+    print "WARNING:", msg
+
+def timeit(f, lbl):
+
+    gc.disable()
+    t = time.time()
+    f()
+    est = time.time() - t
+    gc.enable()
+
+    loops = max(1, int(10**math.floor(math.log(10/est, 10))))
+
+    gc.disable()
+    t = time.time()
+    for _ in xrange(loops):
+        f()
+
+    print lbl, "(", loops, "loops ):", (time.time() - t)/loops, "s"
+    gc.enable()
+
+########################
+# get_settings_names() #
+########################
+
+print "Testing get_settings_names()"
+
+names = mod.get_settings_names()
+
+if type(names) is not list:
+    error("Must return a list")
+
+if not all(type(e) is str for e in names):
+    warn("The elements of the list should be strings")
+
+###########################
+# regenerate_parameters() #
+###########################
+
+print "Testing regenerate_parameters()"
+
+params = mod.regenerate_parameters(0.2)
+
+if type(params) is not list:
+    error("Must return a list")
+
+if len(params) != len(names):
+    error("the returned parameter list must have the same length as the number of parameters")
+
+params2 = mod.regenerate_parameters(0.2)
+if len(names) != 0 and params == params2:
+    error("the complexity parameter determines the distribution of the parameters, not their value")
+
+mod.regenerate_parameters(0.0)
+mod.regenerate_parameters(1.0)
+    
+mod.regenerate_parameters(0.5)
+
+#####################
+# transform_image() #
+#####################
+
+print "Testing transform_image()"
+
+imgr = numpy.random.random_sample((32, 32)).astype(numpy.float32)
+img1 = numpy.ones((32, 32), dtype=numpy.float32)
+img0 = numpy.zeros((32, 32), dtype=numpy.float32)
+
+resr = mod.transform_image(imgr)
+
+if type(resr) is not numpy.ndarray:
+    error("Must return an ndarray")
+
+if resr.shape != (32, 32):
+    error("Must return 32x32 array")
+
+if resr.dtype != numpy.float32:
+    error("Must return float32 array")
+
+res1 = mod.transform_image(img1)
+res0 = mod.transform_image(img0)
+
+if res1.max() > 1.0 or res0.max() > 1.0:
+    error("Must keep array values between 0 and 1")
+
+if res1.min() < 0.0 or res0.min() < 0.0:
+    error("Must keep array values between 0 and 1")
+
+mod.regenerate_parameters(0.0)
+mod.transform_image(imgr)
+mod.regenerate_parameters(1.0)
+mod.transform_image(imgr)
+
+print "Bonus Stage: timings"
+
+timeit(lambda: None, "empty")
+timeit(lambda: mod.regenerate_parameters(0.5), "regenerate_parameters()")
+timeit(lambda: mod.transform_image(imgr), "tranform_image()")
+
+def f():
+    mod.regenerate_parameters(0.2)
+    mod.transform_image(imgr)
+
+timeit(f, "regen and transform")