view data_generation/transformations/testmod.py @ 266:1e4e60ddadb1

Merge. Ah, et dans le dernier commit, j'avais oublié de mentionner que j'ai ajouté du code pour gérer l'isolation de différents clones pour rouler des expériences et modifier le code en même temps.
author fsavard
date Fri, 19 Mar 2010 10:56:16 -0400
parents 1f5937e9e530
children a9af079892ce
line wrap: on
line source

# This script is to test your modules to see if they conform to the module API
# defined on the wiki.
import random, numpy, gc, time, math, sys

# this is an example module that does stupid image value shifting

class DummyModule(object):
    def get_settings_names(self):
        return ['value']
    
    def regenerate_parameters(self, complexity):
        self._value = random.gauss(0, 0.5*complexity)
        return [self._value]

    def transform_image(self, image):
        return numpy.clip(image+self._value, 0, 1)
    
#import <your module>

# instanciate your class here (rather than DummyModule)
mod = DummyModule()

def error(msg):
    print "ERROR:", msg
    sys.exit(1)

def warn(msg):
    print "WARNING:", msg

def timeit(f, lbl):

    gc.disable()
    t = time.time()
    f()
    est = time.time() - t
    gc.enable()

    loops = max(1, int(10**math.floor(math.log(10/est, 10))))

    gc.disable()
    t = time.time()
    for _ in xrange(loops):
        f()

    print lbl, "(", loops, "loops ):", (time.time() - t)/loops, "s"
    gc.enable()

########################
# get_settings_names() #
########################

print "Testing get_settings_names()"

names = mod.get_settings_names()

if type(names) is not list:
    error("Must return a list")

if not all(type(e) is str for e in names):
    warn("The elements of the list should be strings")

###########################
# regenerate_parameters() #
###########################

print "Testing regenerate_parameters()"

params = mod.regenerate_parameters(0.2)

if type(params) is not list:
    error("Must return a list")

if len(params) != len(names):
    error("the returned parameter list must have the same length as the number of parameters")

params2 = mod.regenerate_parameters(0.2)
if len(names) != 0 and params == params2:
    error("the complexity parameter determines the distribution of the parameters, not their value")

mod.regenerate_parameters(0.0)
mod.regenerate_parameters(1.0)
    
mod.regenerate_parameters(0.5)

#####################
# transform_image() #
#####################

print "Testing transform_image()"

imgr = numpy.random.random_sample((32, 32)).astype(numpy.float32)
img1 = numpy.ones((32, 32), dtype=numpy.float32)
img0 = numpy.zeros((32, 32), dtype=numpy.float32)

resr = mod.transform_image(imgr)

if type(resr) is not numpy.ndarray:
    error("Must return an ndarray")

if resr.shape != (32, 32):
    error("Must return 32x32 array")

if resr.dtype != numpy.float32:
    error("Must return float32 array")

res1 = mod.transform_image(img1)
res0 = mod.transform_image(img0)

if res1.max() > 1.0 or res0.max() > 1.0:
    error("Must keep array values between 0 and 1")

if res1.min() < 0.0 or res0.min() < 0.0:
    error("Must keep array values between 0 and 1")

mod.regenerate_parameters(0.0)
mod.transform_image(imgr)
mod.regenerate_parameters(1.0)
mod.transform_image(imgr)

print "Bonus Stage: timings"

timeit(lambda: None, "empty")
timeit(lambda: mod.regenerate_parameters(0.5), "regenerate_parameters()")
timeit(lambda: mod.transform_image(imgr), "tranform_image()")

def f():
    mod.regenerate_parameters(0.2)
    mod.transform_image(imgr)

timeit(f, "regen and transform")