view data_generation/transformations/testmod.py @ 204:e1f5f66dd7dd

Changé le coût de reconstruction pour stabilité numérique, en ajoutant une petite constante dans le log.
author fsavard
date Thu, 04 Mar 2010 08:18:42 -0500
parents 1f5937e9e530
children a9af079892ce
line wrap: on
line source

# This script tests your module to check that it conforms to the module API
# defined on the wiki.
import random, numpy, gc, time, math, sys

# Example module used as a template: it shifts every pixel of the image by a
# single random offset and clips the result back into [0, 1].

class DummyModule(object):
    """Toy transformation adding one random constant to the whole image."""

    def get_settings_names(self):
        """Return the names of the values produced by regenerate_parameters()."""
        return ['value']

    def regenerate_parameters(self, complexity):
        """Draw a new offset and return it wrapped in a one-element list.

        The offset is Gaussian with mean 0 and standard deviation
        0.5 * complexity, so complexity 0 always yields no shift.
        """
        offset = random.gauss(0, complexity * 0.5)
        self._value = offset
        return [offset]

    def transform_image(self, image):
        """Return image plus the current offset, clipped to the range [0, 1]."""
        shifted = image + self._value
        return numpy.clip(shifted, 0, 1)
    
# To check your own module, import it here:
#import <your module>

# ...and instantiate your class here (instead of DummyModule).
mod = DummyModule()

def error(msg):
    """Print a fatal API-conformance error and exit the script with status 1."""
    # A single parenthesised argument is a valid print on both Python 2
    # (statement) and Python 3 (function) and prints the same text as the
    # former `print "ERROR:", msg`.
    print("ERROR: %s" % (msg,))
    sys.exit(1)

def warn(msg):
    """Print a non-fatal API-conformance warning; execution continues."""
    # Parenthesised single argument: same output as the former
    # `print "WARNING:", msg`, but valid on Python 2 and Python 3.
    print("WARNING: %s" % (msg,))

def timeit(f, lbl, budget=10.0):
    """Benchmark f() and print the mean wall-clock time of a single call.

    f() is repeated `loops` times, where `loops` is the largest power of ten
    whose estimated total duration does not exceed `budget` seconds (never
    fewer than 1). The result is printed as "<lbl> ( <loops> loops ): <t> s".
    Garbage collection is disabled while timing and restored to its previous
    state afterwards, even if f() raises.

    f      -- zero-argument callable to benchmark
    lbl    -- label printed in front of the result
    budget -- approximate total time to spend timing, in seconds (default 10,
              the value this helper used to hard-code)
    """
    import itertools

    def run(count):
        # itertools.repeat iterates without building a list (unlike Python 2's
        # range) and exists on both Python 2 and 3 (unlike xrange).
        start = time.time()
        for _ in itertools.repeat(None, count):
            f()
        return time.time() - start

    gc_was_enabled = gc.isenabled()
    gc.disable()
    try:
        # Estimate the cost of one call. A single fast call can measure as
        # exactly 0 s at time.time() resolution (which made the old
        # `10/est` divide by zero), so time ever larger batches until the
        # measurement reaches at least a millisecond.
        batch = 1
        elapsed = run(batch)
        while elapsed < 1e-3:
            batch *= 10
            elapsed = run(batch)
        per_call = elapsed / batch

        # math.log10 is exact on powers of ten, whereas math.log(x, 10) is
        # not (math.log(1000, 10) == 2.9999999999999996), which could round
        # the loop count down by a whole decade.
        ratio = budget / per_call
        loops = 10 ** int(math.floor(math.log10(ratio))) if ratio >= 1 else 1
        total = run(loops)
    finally:
        if gc_was_enabled:
            gc.enable()

    print("%s ( %d loops ): %s s" % (lbl, loops, total / loops))

########################
# get_settings_names() #
########################

# Parenthesised single-argument print: identical output on Python 2 and 3.
print("Testing get_settings_names()")

# One name is expected per value returned by regenerate_parameters() (checked
# in the next section). isinstance also accepts list subclasses, which are
# still lists as far as the API is concerned.
names = mod.get_settings_names()

if not isinstance(names, list):
    error("Must return a list")

# Non-string names are tolerated, but reported.
if not all(isinstance(e, str) for e in names):
    warn("The elements of the list should be strings")

###########################
# regenerate_parameters() #
###########################

print("Testing regenerate_parameters()")

params = mod.regenerate_parameters(0.2)

# isinstance also accepts list subclasses.
if not isinstance(params, list):
    error("Must return a list")

if len(params) != len(names):
    error("the returned parameter list must have the same length as the number of parameters")

# The complexity selects a distribution from which the parameters are drawn,
# so repeated calls should not keep returning the same values. Two draws can
# legitimately coincide (e.g. a boolean or small discrete parameter), so only
# fail when a whole series of further draws repeats the first one exactly.
params2 = mod.regenerate_parameters(0.2)
if len(names) != 0 and params == params2:
    for _ in range(20):
        params2 = mod.regenerate_parameters(0.2)
        if params2 != params:
            break
    else:
        error("the complexity parameter determines the distribution of the parameters, not their value")

# The extreme complexities must be accepted as well and still return one
# value per setting name. 0.5 is drawn last, as before, so the module is left
# with mid-complexity parameters for the tests that follow.
for complexity in (0.0, 1.0, 0.5):
    drawn = mod.regenerate_parameters(complexity)
    if not isinstance(drawn, list) or len(drawn) != len(names):
        error("the returned parameter list must have the same length as the number of parameters")

#####################
# transform_image() #
#####################

print("Testing transform_image()")

def _check_transform_output(res):
    """Abort via error() unless res is a 32x32 float32 ndarray within [0, 1].

    Applied to every transform_image() result below (random input, constant
    inputs, extreme complexities), not only to some of them.
    """
    # Exact type check kept: ndarray subclasses such as numpy.matrix change
    # operator semantics and are not plain images.
    if type(res) is not numpy.ndarray:
        error("Must return an ndarray")
    if res.shape != (32, 32):
        error("Must return 32x32 array")
    if res.dtype != numpy.float32:
        error("Must return float32 array")
    # Phrased as "every value lies inside [0, 1]" instead of "max > 1 or
    # min < 0" so that NaN, for which every comparison is False, is rejected.
    if not numpy.all((res >= 0.0) & (res <= 1.0)):
        error("Must keep array values between 0 and 1")

imgr = numpy.random.random_sample((32, 32)).astype(numpy.float32)
img1 = numpy.ones((32, 32), dtype=numpy.float32)
img0 = numpy.zeros((32, 32), dtype=numpy.float32)

resr = mod.transform_image(imgr)
_check_transform_output(resr)

# All-ones and all-zeros inputs sit exactly on the edges of the valid range.
res1 = mod.transform_image(img1)
_check_transform_output(res1)
res0 = mod.transform_image(img0)
_check_transform_output(res0)

# The same output contract must hold at the extreme complexities.
mod.regenerate_parameters(0.0)
_check_transform_output(mod.transform_image(imgr))
mod.regenerate_parameters(1.0)
_check_transform_output(mod.transform_image(imgr))

print("Bonus Stage: timings")

# Baseline (call overhead only), then each API method on its own.
timeit(lambda: None, "empty")
timeit(lambda: mod.regenerate_parameters(0.5), "regenerate_parameters()")
timeit(lambda: mod.transform_image(imgr), "transform_image()")

def f():
    """Draw fresh parameters, then transform the test image."""
    mod.regenerate_parameters(0.2)
    mod.transform_image(imgr)

timeit(f, "regen and transform")