view utils/seriestables/test_series.py @ 266:1e4e60ddadb1

Merge. Ah, et dans le dernier commit, j'avais oublié de mentionner que j'ai ajouté du code pour gérer l'isolation de différents clones pour rouler des expériences et modifier le code en même temps.
author fsavard
date Fri, 19 Mar 2010 10:56:16 -0400
parents bfe20d63f88c
children
line wrap: on
line source

import tempfile

import numpy
import numpy.random

from jobman import DD

import tables

from series import *
import series

#################################################
# Utils

def compare_floats(f1,f2):
    if f1-f2 < 1e-3:
        return True
    return False

def compare_lists(it1, it2, floats=False):
    if len(it1) != len(it2):
        return False

    for el1,  el2 in zip(it1, it2):
        if floats:
            if not compare_floats(el1,el2):
                return False
        elif el1 != el2:
            return False

    return True

#################################################
# Basic Series class tests

def test_Series_types():
    pass

#################################################
# ErrorSeries tests

def test_ErrorSeries_common_case(h5f=None):
    if not h5f:
        h5f_path = tempfile.NamedTemporaryFile().name
        h5f = tables.openFile(h5f_path, "w")

    validation_error = series.ErrorSeries(error_name="validation_error", table_name="validation_error",
                                hdf5_file=h5f, index_names=('epoch','minibatch'),
                                title="Validation error indexed by epoch and minibatch")

    # (1,1), (1,2) etc. are (epoch, minibatch) index
    validation_error.append((1,1), 32.0)
    validation_error.append((1,2), 30.0)
    validation_error.append((2,1), 28.0)
    validation_error.append((2,2), 26.0)

    h5f.close()

    h5f = tables.openFile(h5f_path, "r")
    
    table = h5f.getNode('/', 'validation_error')

    assert compare_lists(table.cols.epoch[:], [1,1,2,2])
    assert compare_lists(table.cols.minibatch[:], [1,2,1,2])
    assert compare_lists(table.cols.validation_error[:], [32.0, 30.0, 28.0, 26.0])

def test_ErrorSeries_no_index(h5f=None):
    if not h5f:
        h5f_path = tempfile.NamedTemporaryFile().name
        h5f = tables.openFile(h5f_path, "w")

    validation_error = series.ErrorSeries(error_name="validation_error",
                                table_name="validation_error",
                                hdf5_file=h5f, 
                                # empty tuple
                                index_names=tuple(),
                                title="Validation error with no index")

    # (1,1), (1,2) etc. are (epoch, minibatch) index
    validation_error.append(tuple(), 32.0)
    validation_error.append(tuple(), 30.0)
    validation_error.append(tuple(), 28.0)
    validation_error.append(tuple(), 26.0)

    h5f.close()

    h5f = tables.openFile(h5f_path, "r")
    
    table = h5f.getNode('/', 'validation_error')

    assert compare_lists(table.cols.validation_error[:], [32.0, 30.0, 28.0, 26.0])
    assert not ("epoch" in dir(table.cols))

def test_ErrorSeries_notimestamp(h5f=None):
    if not h5f:
        h5f_path = tempfile.NamedTemporaryFile().name
        h5f = tables.openFile(h5f_path, "w")

    validation_error = series.ErrorSeries(error_name="validation_error", table_name="validation_error",
                                hdf5_file=h5f, index_names=('epoch','minibatch'),
                                title="Validation error indexed by epoch and minibatch", 
                                store_timestamp=False)

    # (1,1), (1,2) etc. are (epoch, minibatch) index
    validation_error.append((1,1), 32.0)

    h5f.close()

    h5f = tables.openFile(h5f_path, "r")
    
    table = h5f.getNode('/', 'validation_error')

    assert compare_lists(table.cols.epoch[:], [1])
    assert not ("timestamp" in dir(table.cols))
    assert "cpuclock" in dir(table.cols)

def test_ErrorSeries_nocpuclock(h5f=None):
    if not h5f:
        h5f_path = tempfile.NamedTemporaryFile().name
        h5f = tables.openFile(h5f_path, "w")

    validation_error = series.ErrorSeries(error_name="validation_error", table_name="validation_error",
                                hdf5_file=h5f, index_names=('epoch','minibatch'),
                                title="Validation error indexed by epoch and minibatch", 
                                store_cpuclock=False)

    # (1,1), (1,2) etc. are (epoch, minibatch) index
    validation_error.append((1,1), 32.0)

    h5f.close()

    h5f = tables.openFile(h5f_path, "r")
    
    table = h5f.getNode('/', 'validation_error')

    assert compare_lists(table.cols.epoch[:], [1])
    assert not ("cpuclock" in dir(table.cols))
    assert "timestamp" in dir(table.cols)

def test_AccumulatorSeriesWrapper_common_case(h5f=None):
    if not h5f:
        h5f_path = tempfile.NamedTemporaryFile().name
        h5f = tables.openFile(h5f_path, "w")

    validation_error = ErrorSeries(error_name="accumulated_validation_error",
                                table_name="accumulated_validation_error",
                                hdf5_file=h5f,
                                index_names=('epoch','minibatch'),
                                title="Validation error, summed every 3 minibatches, indexed by epoch and minibatch")

    accumulator = AccumulatorSeriesWrapper(base_series=validation_error,
                                    reduce_every=3, reduce_function=numpy.sum)

    # (1,1), (1,2) etc. are (epoch, minibatch) index
    accumulator.append((1,1), 32.0)
    accumulator.append((1,2), 30.0)
    accumulator.append((2,1), 28.0)
    accumulator.append((2,2), 26.0)
    accumulator.append((3,1), 24.0)
    accumulator.append((3,2), 22.0)

    h5f.close()

    h5f = tables.openFile(h5f_path, "r")
    
    table = h5f.getNode('/', 'accumulated_validation_error')

    assert compare_lists(table.cols.epoch[:], [2,3])
    assert compare_lists(table.cols.minibatch[:], [1,2])
    assert compare_lists(table.cols.accumulated_validation_error[:], [90.0,72.0], floats=True)

def test_BasicStatisticsSeries_common_case(h5f=None):
    if not h5f:
        h5f_path = tempfile.NamedTemporaryFile().name
        h5f = tables.openFile(h5f_path, "w")

    stats_series = BasicStatisticsSeries(table_name="b_vector_statistics",
                                hdf5_file=h5f, index_names=('epoch','minibatch'),
                                title="Basic statistics for b vector indexed by epoch and minibatch")

    # (1,1), (1,2) etc. are (epoch, minibatch) index
    stats_series.append((1,1), [0.15, 0.20, 0.30])
    stats_series.append((1,2), [-0.18, 0.30, 0.58])
    stats_series.append((2,1), [0.18, -0.38, -0.68])
    stats_series.append((2,2), [0.15, 0.02, 1.9])

    h5f.close()

    h5f = tables.openFile(h5f_path, "r")
    
    table = h5f.getNode('/', 'b_vector_statistics')

    assert compare_lists(table.cols.epoch[:], [1,1,2,2])
    assert compare_lists(table.cols.minibatch[:], [1,2,1,2])
    assert compare_lists(table.cols.mean[:], [0.21666667,  0.23333333, -0.29333332,  0.69], floats=True)
    assert compare_lists(table.cols.min[:], [0.15000001, -0.18000001, -0.68000001,  0.02], floats=True)
    assert compare_lists(table.cols.max[:], [0.30, 0.58, 0.18, 1.9], floats=True)
    assert compare_lists(table.cols.std[:], [0.06236095, 0.31382939,  0.35640177, 0.85724366], floats=True)

def test_SharedParamsStatisticsWrapper_commoncase(h5f=None):
    import numpy.random

    if not h5f:
        h5f_path = tempfile.NamedTemporaryFile().name
        h5f = tables.openFile(h5f_path, "w")

    stats = SharedParamsStatisticsWrapper(new_group_name="params", base_group="/",
                                arrays_names=('b1','b2','b3'), hdf5_file=h5f,
                                index_names=('epoch','minibatch'))

    b1 = DD({'value':numpy.random.rand(5)})
    b2 = DD({'value':numpy.random.rand(5)})
    b3 = DD({'value':numpy.random.rand(5)})
    stats.append((1,1), [b1,b2,b3])

    h5f.close()

    h5f = tables.openFile(h5f_path, "r")

    b1_table = h5f.getNode('/params', 'b1')
    b3_table = h5f.getNode('/params', 'b3')

    assert b1_table.cols.mean[0] - numpy.mean(b1.value) < 1e-3
    assert b3_table.cols.mean[0] - numpy.mean(b3.value) < 1e-3
    assert b1_table.cols.min[0] - numpy.min(b1.value) < 1e-3
    assert b3_table.cols.min[0] - numpy.min(b3.value) < 1e-3

def test_SharedParamsStatisticsWrapper_notimestamp(h5f=None):
    import numpy.random

    if not h5f:
        h5f_path = tempfile.NamedTemporaryFile().name
        h5f = tables.openFile(h5f_path, "w")

    stats = SharedParamsStatisticsWrapper(new_group_name="params", base_group="/",
                                arrays_names=('b1','b2','b3'), hdf5_file=h5f,
                                index_names=('epoch','minibatch'),
                                store_timestamp=False)

    b1 = DD({'value':numpy.random.rand(5)})
    b2 = DD({'value':numpy.random.rand(5)})
    b3 = DD({'value':numpy.random.rand(5)})
    stats.append((1,1), [b1,b2,b3])

    h5f.close()

    h5f = tables.openFile(h5f_path, "r")

    b1_table = h5f.getNode('/params', 'b1')
    b3_table = h5f.getNode('/params', 'b3')

    assert b1_table.cols.mean[0] - numpy.mean(b1.value) < 1e-3
    assert b3_table.cols.mean[0] - numpy.mean(b3.value) < 1e-3
    assert b1_table.cols.min[0] - numpy.min(b1.value) < 1e-3
    assert b3_table.cols.min[0] - numpy.min(b3.value) < 1e-3

    assert not ('timestamp' in dir(b1_table.cols))

def test_get_desc():
    h5f_path = tempfile.NamedTemporaryFile().name
    h5f = tables.openFile(h5f_path, "w")

    desc = series._get_description_with_n_ints_n_floats(("col1","col2"), ("col3","col4"))

    mytable = h5f.createTable('/', 'mytable', desc)

    # just make sure the columns are there... otherwise this will throw an exception
    mytable.cols.col1
    mytable.cols.col2
    mytable.cols.col3
    mytable.cols.col4

    try:
        # this should fail... LocalDescription must be local to get_desc_etc
        test = LocalDescription
        assert False
    except:
        assert True

    assert True

def test_index_to_tuple_floaterror():
    try:
        series._index_to_tuple(5.1)
        assert False
    except TypeError:
        assert True

def test_index_to_tuple_arrayok():
    tpl = series._index_to_tuple([1,2,3])
    assert type(tpl) == tuple and tpl[1] == 2 and tpl[2] == 3

def test_index_to_tuple_intbecomestuple():
    tpl = series._index_to_tuple(32)

    assert type(tpl) == tuple and tpl == (32,)

def test_index_to_tuple_longbecomestuple():
    tpl = series._index_to_tuple(928374928374928L)

    assert type(tpl) == tuple and tpl == (928374928374928L,)

if __name__ == '__main__':
    import tempfile
    test_get_desc()
    test_ErrorSeries_common_case()
    test_BasicStatisticsSeries_common_case()
    test_AccumulatorSeriesWrapper_common_case()
    test_SharedParamsStatisticsWrapper_commoncase()