Mercurial > ift6266
view data_generation/transformations/testmod.py @ 266:1e4e60ddadb1
Merge. Ah, et dans le dernier commit, j'avais oublié de mentionner que j'ai ajouté du code pour gérer l'isolation de différents clones pour rouler des expériences et modifier le code en même temps.
author | fsavard |
---|---|
date | Fri, 19 Mar 2010 10:56:16 -0400 |
parents | 1f5937e9e530 |
children | a9af079892ce |
line wrap: on
line source
# This script tests your modules to see if they conform to the module API
# defined on the wiki.
#
# To test your own module, import it below and instantiate your class in
# place of DummyModule(), then run this file directly.  Hard API violations
# print "ERROR: ..." and exit with status 1; soft issues print
# "WARNING: ...".  The run ends with rough per-call timings ("bonus stage").
#
# Only single-argument print() calls are used (and itertools.repeat instead
# of xrange), so the script runs unchanged under Python 2 and Python 3.
import gc
import itertools
import math
import random
import sys
import time
from timeit import default_timer

import numpy

# Shortest measured run accepted when estimating the cost of one call;
# shorter runs are dominated by the clock's resolution (and could even read
# as zero, which used to crash the loop-count computation).
_MIN_CALIBRATION_SECONDS = 0.01

# Shape of the test images fed to transform_image().
_IMAGE_SHAPE = (32, 32)


class DummyModule(object):
    """Example module that does stupid image value shifting.

    It adds a single random offset to every pixel and clips to [0, 1].
    """

    def get_settings_names(self):
        return ['value']

    def regenerate_parameters(self, complexity):
        # Offset ~ N(0, (0.5 * complexity)^2): complexity scales the spread
        # of the distribution, not the offset itself.
        self._value = random.gauss(0, 0.5 * complexity)
        return [self._value]

    def transform_image(self, image):
        return numpy.clip(image + self._value, 0, 1)


#import <your module>

# instantiate your class here (rather than DummyModule)
mod = DummyModule()


def error(msg):
    """Report a hard API violation and abort with exit status 1."""
    print("ERROR: %s" % (msg,))
    sys.exit(1)


def warn(msg):
    """Report a soft API violation and keep going."""
    print("WARNING: %s" % (msg,))


def _time_calls(f, n):
    """Return the seconds taken by n calls of f, with garbage collection paused.

    The collector's previous state is restored even if f raises.
    """
    gc_was_enabled = gc.isenabled()
    gc.disable()
    try:
        start = default_timer()
        for _ in itertools.repeat(None, n):
            f()
        return default_timer() - start
    finally:
        if gc_was_enabled:
            gc.enable()


def timeit(f, lbl, target=10.0):
    """Time f() and print the mean seconds per call, labelled with lbl.

    The repetition count is a power of ten chosen so the timed run lasts
    roughly between target/10 and target seconds (target must be > 0).

    Returns:
        The measured mean seconds per call.
    """
    # Calibrate: grow the repeat count tenfold until the clock can resolve
    # the run, so a coarse timer never yields a zero estimate.
    n = 1
    while True:
        elapsed = _time_calls(f, n)
        if elapsed >= _MIN_CALIBRATION_SECONDS:
            break
        n *= 10
    est = elapsed / n
    loops = max(1, int(10 ** math.floor(math.log10(target / est))))
    per_call = _time_calls(f, loops) / loops
    print("%s ( %d loops ): %s s" % (lbl, loops, per_call))
    return per_call


def _random_image():
    """Return a float32 test image with uniform values in [0, 1)."""
    return numpy.random.random_sample(_IMAGE_SHAPE).astype(numpy.float32)


def _check_image_output(res):
    """Validate one transform_image() result; error() on the first violation."""
    if type(res) is not numpy.ndarray:
        error("Must return an ndarray")
    if res.shape != _IMAGE_SHAPE:
        error("Must return 32x32 array")
    if res.dtype != numpy.float32:
        error("Must return float32 array")
    # Phrased as "every value lies in [0, 1]" so NaN (which fails every
    # comparison) is rejected; a max()/min() test lets NaN through.
    if not numpy.all((res >= 0.0) & (res <= 1.0)):
        error("Must keep array values between 0 and 1")


def _check_get_settings_names(module):
    """Check get_settings_names() and return the names it reports."""
    print("Testing get_settings_names()")
    names = module.get_settings_names()
    if type(names) is not list:
        error("Must return a list")
    if not all(type(e) is str for e in names):
        warn("The elements of the list should be strings")
    return names


def _check_regenerate_parameters(module, names):
    """Check regenerate_parameters() against the reported setting names."""
    print("Testing regenerate_parameters()")
    params = module.regenerate_parameters(0.2)
    if type(params) is not list:
        error("Must return a list")
    if len(params) != len(names):
        error("the returned parameter list must have the same length as the number of parameters")
    # Parameters are sampled, so two draws at the same complexity should differ.
    params2 = module.regenerate_parameters(0.2)
    if len(names) != 0 and params == params2:
        error("the complexity parameter determines the distribution of the parameters, not their value")
    # Exercise both ends and the middle of the complexity range
    # (presumably [0, 1]).
    for complexity in (0.0, 1.0, 0.5):
        module.regenerate_parameters(complexity)


def _check_transform_image(module):
    """Check transform_image() on random, all-ones and all-zeros images."""
    print("Testing transform_image()")
    imgr = _random_image()
    img1 = numpy.ones(_IMAGE_SHAPE, dtype=numpy.float32)
    img0 = numpy.zeros(_IMAGE_SHAPE, dtype=numpy.float32)
    for img in (imgr, img1, img0):
        _check_image_output(module.transform_image(img))
    # Outputs must stay valid at the extremes of complexity too.
    for complexity in (0.0, 1.0):
        module.regenerate_parameters(complexity)
        _check_image_output(module.transform_image(imgr))


def check_module(module):
    """Run every API conformance check on module; error() exits on failure."""
    names = _check_get_settings_names(module)
    _check_regenerate_parameters(module, names)
    _check_transform_image(module)


def run_timings(module, target=10.0):
    """Print rough per-call timings of module's API ("bonus stage").

    target is forwarded to timeit() and bounds each timed run in seconds.
    """
    print("Bonus Stage: timings")
    imgr = _random_image()

    def regen_and_transform():
        module.regenerate_parameters(0.2)
        module.transform_image(imgr)

    timeit(lambda: None, "empty", target)
    # Runs before transform_image is timed, so parameters are set by then.
    timeit(lambda: module.regenerate_parameters(0.5), "regenerate_parameters()", target)
    timeit(lambda: module.transform_image(imgr), "transform_image()", target)
    timeit(regen_and_transform, "regen and transform", target)


if __name__ == "__main__":
    check_module(mod)
    run_timings(mod)