view test_dataset.py @ 212:9b57ea8c767f

previous commit was supposed to concern only one file, dataset.py, try to undo my other changes with this commit (nothing was broken though, just useless debugging prints)
author Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
date Wed, 21 May 2008 17:42:20 -0400
parents b9950ae5e54b
children 3c96d23b36ac d7250ee86f72
line wrap: on
line source

#!/bin/env python
from dataset import *
from math import *
import numpy

def have_raised(to_eval, **var):
    have_thrown = False
    try:
        eval(to_eval)
    except :
        have_thrown = True
    return have_thrown

def have_raised2(f, *args, **kwargs):
    """Tell whether calling f(*args, **kwargs) raises.

    Returns True if the call raised an Exception, False otherwise; the
    return value of `f` is discarded.  Only Exception subclasses are caught,
    so KeyboardInterrupt and SystemExit are not mistaken for a failure of
    the tested call.
    """
    try:
        f(*args, **kwargs)
    except Exception:
        return True
    return False

def test1():
    print "test1"
    global a,ds
    a = numpy.random.rand(10,4)
    print a
    ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]})
    print "len(ds)=",len(ds)
    assert(len(ds)==10)
    print "example 0 = ",ds[0]
#    assert
    print "x=",ds["x"]
    print "x|y"
    for x,y in ds("x","y"):
        print x,y
    minibatch_iterator = ds.minibatches(fieldnames=['z','y'],n_batches=1,minibatch_size=3,offset=4)
    minibatch = minibatch_iterator.__iter__().next()
    print "minibatch=",minibatch
    for var in minibatch:
        print "var=",var
    print "take a slice and look at field y",ds[1:6:2]["y"]

    del a,ds,x,y,minibatch_iterator,minibatch,var

def test_iterate_over_examples(array,ds):
    """Check the supported ways of iterating over `ds` against `array`.

    For example i, field 'x' is compared with array[i][:3], field 'y' with
    array[i][3] and field 'z' with array[i][[0,2]].  The following access
    patterns are exercised: integer indexing, iteration over examples,
    tuple unpacking of examples, iteration over a field selection
    ds(field, ...), and ds.minibatches(...) with and without
    n_batches/offset.

    The minibatch sections expect exactly 4 minibatches of size 3, which
    matches the 10-example datasets checked by test_all.
    """
    # not in doc!!!
    i=0
    for example in range(len(ds)):
        assert (ds[example]['x']==array[example][:3]).all()
        assert ds[example]['y']==array[example][3]
        assert (ds[example]['z']==array[example][[0,2]]).all()
        i+=1
    assert i==len(ds)
    del example,i

    #     - for example in dataset:
    i=0
    for example in ds:
        assert len(example)==3
        assert (example['x']==array[i][:3]).all()
        assert example['y']==array[i][3]
        assert (example['z']==array[i][0:3:2]).all()
        assert (numpy.append(example['x'],example['y'])==array[i]).all()
        i+=1
    assert i==len(ds)
    del example,i

    #     - for val1,val2,... in dataset:
    i=0
    for x,y,z in ds:
        assert (x==array[i][:3]).all()
        assert y==array[i][3]
        assert (z==array[i][0:3:2]).all()
        assert (numpy.append(x,y)==array[i]).all()
        i+=1
    assert i==len(ds)
    del x,y,z,i

    #     - for example in dataset(field1, field2,field3, ...):
    i=0
    for example in ds('x','y','z'):
        assert len(example)==3
        assert (example['x']==array[i][:3]).all()
        assert example['y']==array[i][3]
        assert (example['z']==array[i][0:3:2]).all()
        assert (numpy.append(example['x'],example['y'])==array[i]).all()
        i+=1
    assert i==len(ds)
    del example,i
    i=0
    for example in ds('y','x'):
        assert len(example)==2
        assert (example['x']==array[i][:3]).all()
        assert example['y']==array[i][3]
        assert (numpy.append(example['x'],example['y'])==array[i]).all()
        i+=1
    assert i==len(ds)
    del example,i

    #     - for val1,val2,val3 in dataset(field1, field2,field3):
    i=0
    for x,y,z in ds('x','y','z'):
        assert (x==array[i][:3]).all()
        assert y==array[i][3]
        assert (z==array[i][0:3:2]).all()
        assert (numpy.append(x,y)==array[i]).all()
        i+=1
    assert i==len(ds)
    del x,y,z,i
    i=0
    for y,x in ds('y','x',):
        assert (x==array[i][:3]).all()
        assert y==array[i][3]
        assert (numpy.append(x,y)==array[i]).all()
        i+=1
    assert i==len(ds)
    del x,y,i

    def test_minibatch_size(minibatch,minibatch_size,len_ds,nb_field,nb_iter_finished):
        """Apply test_minibatch_field_size to the first nb_field fields of minibatch."""
        ##full minibatch or the last minibatch
        for idx in range(nb_field):
            test_minibatch_field_size(minibatch[idx],minibatch_size,len_ds,nb_iter_finished)
        del idx
    def test_minibatch_field_size(minibatch_field,minibatch_size,len_ds,nb_iter_finished):
        """Accept a full minibatch, or a shorter one that exactly completes the dataset."""
        assert len(minibatch_field)==minibatch_size or ((nb_iter_finished*minibatch_size+len(minibatch_field))==len_ds and len(minibatch_field)<minibatch_size)

    #     - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N):
    i=0
    mi=0
    m=ds.minibatches(['x','z'], minibatch_size=3)
    assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
    for minibatch in m:
        assert isinstance(minibatch,DataSetFields)
        assert len(minibatch)==2
        test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi)
        # 'z' holds columns 0 and 2, i.e. every other value of 'x'
        if type(ds)==ArrayDataSet:
            assert (minibatch[0][:,::2]==minibatch[1]).all()
        else:
            for j in xrange(len(minibatch[0])):
                # this comparison used to be computed without being
                # asserted, so a mismatch could never fail the test
                assert (minibatch[0][j][::2]==minibatch[1][j]).all()
        mi+=1
        i+=len(minibatch[0])
    assert i==len(ds)
    assert mi==4
    del minibatch,i,m,mi

    i=0
    mi=0
    m=ds.minibatches(['x','y'], minibatch_size=3)
    assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
    for minibatch in m:
        assert len(minibatch)==2
        test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi)
        mi+=1
        for k in range(len(minibatch[0])):
            assert (numpy.append(minibatch[0][k],minibatch[1][k])==array[i]).all()
            i+=1
    assert i==len(ds)
    assert mi==4
    del minibatch,i,k,m,mi

    #     - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N):
    i=0
    mi=0
    m=ds.minibatches(['x','z'], minibatch_size=3)
    assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
    for x,z in m:
        test_minibatch_field_size(x,m.minibatch_size,len(ds),mi)
        test_minibatch_field_size(z,m.minibatch_size,len(ds),mi)
        for k in range(len(x)):
            assert (x[k][::2]==z[k]).all()
            i+=1
        mi+=1
    assert i==len(ds)
    assert mi==4
    del x,z,i,k,m,mi
    i=0
    mi=0
    m=ds.minibatches(['x','y'], minibatch_size=3)
    for x,y in m:
        test_minibatch_field_size(x,m.minibatch_size,len(ds),mi)
        test_minibatch_field_size(y,m.minibatch_size,len(ds),mi)
        mi+=1
        for k in range(len(x)):
            assert (numpy.append(x[k],y[k])==array[i]).all()
            i+=1
    assert i==len(ds)
    assert mi==4
    del x,y,i,k,m,mi

    # not in doc: explicit n_batches and offset
    i=0
    m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4)
    assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
    for x,y in m:
        assert len(x)==3
        assert len(y)==3
        for k in range(3):
            assert (numpy.append(x[k],y[k])==array[i+4]).all()
            i+=1
    assert i==3
    del x,y,i,k,m

    i=0
    m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4)
    assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
    for x,y in m:
        assert len(x)==3
        assert len(y)==3
        for k in range(3):
            assert (numpy.append(x[k],y[k])==array[i+4]).all()
            i+=1
    assert i==6
    del x,y,i,k,m

    # 20 batches of 3 ask for more examples than the dataset holds, so the
    # expected row index wraps around modulo the number of rows
    i=0
    m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=3,offset=4)
    assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
    for x,y in m:
        assert len(x)==3
        assert len(y)==3
        for k in range(3):
            assert (numpy.append(x[k],y[k])==array[(i+4)%array.shape[0]]).all()
            i+=1
    assert i==m.n_batches*m.minibatch_size
    del x,y,i,k


def test_ds_iterator(array,iterator1,iterator2,iterator3):
    l=len(iterator1)
    i=0
    for x,y in iterator1:
        assert (x==array[i][:3]).all()
        assert y==array[i][3]
        assert (numpy.append(x,y)==array[i]).all()
        i+=1
    assert i==l
    i=0
    for y,z in iterator2:
        assert y==array[i][3]
        assert (z==array[i][0:3:2]).all()
        i+=1
    assert i==l
    i=0
    for x,y,z in iterator3:
        assert (x==array[i][:3]).all()
        assert y==array[i][3]
        assert (z==array[i][0:3:2]).all()
        assert (numpy.append(x,y)==array[i]).all()
        i+=1
    assert i==l

def test_getitem(array,ds):
    """Check the forms of ds[...] indexing against `array`.

    Covers ds[:n], ds[i1:i2:s], ds[i], ds[[i1,...,in]] and ds[fieldname],
    then checks which hstack() field combinations raise.  Field 'x' is
    expected to match array[i][:3], 'y' array[i][3] and 'z' array[i][0:3:2].
    """
    def test_ds(orig,ds,index):
        """Check that sub-dataset `ds` holds the examples orig[index[k]], in order."""
        i=0
        assert len(ds)==len(index)
        for x,z,y in ds('x','z','y'):
            assert (orig[index[i]]['x']==array[index[i]][:3]).all()
            assert (orig[index[i]]['x']==x).all()
            assert orig[index[i]]['y']==array[index[i]][3]
            assert orig[index[i]]['y']==y
            assert (orig[index[i]]['z']==array[index[i]][0:3:2]).all()
            assert (orig[index[i]]['z']==z).all()
            i+=1
        del i
        # the accesses below are not compared with anything: they only have
        # to complete without raising
        ds[0]
        if len(ds)>2:
            ds[:1]
            ds[1:1]
            ds[1:1:1]
        if len(ds)>5:
            ds[[1,2,3]]
        for x in ds:
            pass

    #ds[:n] returns a dataset with the n first examples.
    ds2=ds[:3]
    assert isinstance(ds2,DataSet)
    test_ds(ds,ds2,index=[0,1,2])
    del ds2

    #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s.
    ds2=ds[1:7:2]
    assert isinstance(ds2,DataSet)
    test_ds(ds,ds2,[1,3,5])
    del ds2

    #ds[i]
    ds2=ds[5]
    assert isinstance(ds2,Example)
    assert have_raised("var['ds']["+str(len(ds))+"]",ds=ds)  # index not defined
    assert not have_raised("var['ds']["+str(len(ds)-1)+"]",ds=ds)
    del ds2

    #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in.
    ds2=ds[[4,7,2,8]]
    assert isinstance(ds2,DataSet)
    test_ds(ds,ds2,[4,7,2,8])
    del ds2

    #ds[fieldname]# an iterable over the values of the field fieldname across
    #the ds (the iterable is obtained by default by calling valuesVStack
    #over the values for individual examples).
    # The dataset has to be reached through var[...]: have_raised evaluates
    # the string in its own scope, where a bare `ds` is not defined, so
    # "ds['h']" raised NameError and the assertion passed without ever
    # indexing the dataset.
    assert have_raised("var['ds']['h']",ds=ds)  # h is not defined...
    assert have_raised("var['ds'][['x']]",ds=ds)  # bad syntax
    assert not have_raised("var['ds']['x']",ds=ds)
    # NOTE(review): the result of this isinstance is discarded, so nothing is
    # checked here; confirm the intended type of ds['x'] before asserting it.
    isinstance(ds['x'],DataSetFields)
    ds2=ds['x']
    assert len(ds['x'])==10
    assert len(ds['y'])==10
    assert len(ds['z'])==10
    i=0
    for example in ds['x']:
        assert (example==array[i][:3]).all()
        i+=1
    assert i==len(ds)
    i=0
    for example in ds['y']:
        assert (example==array[i][3]).all()
        i+=1
    assert i==len(ds)
    i=0
    for example in ds['z']:
        assert (example==array[i,0:3:2]).all()
        i+=1
    assert i==len(ds)
    del ds2,i

    #ds.<property># returns the value of a property associated with
    #the name <property>. The following properties should be supported:
    #    - 'description': a textual description or name for the ds
    #    - 'fieldtypes': a list of types (one per field)

    #* ds1 | ds2 | ds3 == ds.hstack([ds1,ds2,ds3])#????
    #assert hstack([ds('x','y'),ds('z')])==ds
    #hstack([ds('z','y'),ds('x')])==ds
    # stacking datasets that share a field name is expected to raise
    assert have_raised2(hstack,[ds('x'),ds('x')])
    assert have_raised2(hstack,[ds('y','x'),ds('x')])
    assert not have_raised2(hstack,[ds('x'),ds('y')])
    
#        i=0
#        for example in hstack([ds('x'),ds('y'),ds('z')]):
#            example==ds[i]
#            i+=1 
#        del i,example
#* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])#????

def test_fields_fct(ds):
    """Check the number of fields, and of values per field, given by fields().

    Every field is expected to iterate over 10 values.  The full dataset is
    expected to expose 3 fields, and a two-field selection 2 fields, whether
    the selection is made with ds('x','z').fields() or ds.fields('x','y').
    """
    #@todo, fill correctly
    def tally(fields):
        """Return (number of fields, total number of values over all fields)."""
        n_fields = 0
        n_values = 0
        for one_field in fields:
            # walk the values associated to that field for all the ds examples
            for _ in one_field:
                n_values += 1
            n_fields += 1
        return n_fields, n_values

    assert len(ds.fields())==3

    n_fields, n_values = tally(ds.fields())
    assert n_fields==3
    assert n_values==3*10

    n_fields, n_values = tally(ds('x','z').fields())
    assert n_fields==2
    assert n_values==2*10

    n_fields, n_values = tally(ds.fields('x','y'))
    assert n_fields==2
    assert n_values==2*10

    n_fields, n_values = tally(ds.fields())
    assert n_fields==3
    assert n_values==3*10

    assert ds == ds.fields().examples()
    for selection, expected in ((('x','y'), 2), (('x','z'), 2), (('y',), 1)):
        assert len(ds(*selection).fields()) == expected
def test_all(array,ds):
    """Run every generic dataset check on `ds`, which must mirror `array`.

    The dataset must contain exactly 10 examples.
    """
    n_examples = len(ds)
    assert n_examples==10

    test_iterate_over_examples(array, ds)
    test_getitem(array, ds)

    xy_view = ds('x','y')
    yz_view = ds('y','z')
    xyz_view = ds('x','y','z')
    test_ds_iterator(array, xy_view, yz_view, xyz_view)

    test_fields_fct(ds)

def test_ArrayDataSet():
    #don't test stream
    #tested only with float value
    #don't always test with y
    #don't test missing value
    #don't test with tuple
    #don't test proterties
    print "test_ArrayDataSet"
    a2 = numpy.random.rand(10,4)
    ds = ArrayDataSet(a2,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested
    ds = ArrayDataSet(a2,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested
    #assert ds==a? should this work?

    test_all(a2,ds)

    del a2, ds

def test_LookupList():
    #test only the example in the doc???
    print "test_LookupList"
    example = LookupList(['x','y','z'],[1,2,3])
    example['x'] = [1, 2, 3] # set or change a field
    x, y, z = example
    x = example[0]
    x = example["x"]
    assert example.keys()==['x','y','z']
    assert example.values()==[[1,2,3],2,3]
    assert example.items()==[('x',[1,2,3]),('y',2),('z',3)]
    example.append_keyval('u',0) # adds item with name 'u' and value 0
    assert len(example)==4 # number of items = 4 here
    example2 = LookupList(['v','w'], ['a','b'])
    example3 = LookupList(['x','y','z','u','v','w'], [[1, 2, 3],2,3,0,'a','b'])
    assert example+example2==example3
    assert have_raised("var['x']+var['x']",x=example)

    del example, example2, example3, x, y ,z

def test_CachedDataSet():
    print "test_CacheDataSet"
    a = numpy.random.rand(10,4)
    ds1 = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested
    ds2 = CachedDataSet(ds1)
    ds3 = CachedDataSet(ds1,cache_all_upon_construction=True)

    test_all(a,ds2)
    test_all(a,ds3)

    del a,ds1,ds2,ds3


def test_DataSetFields():
    """Placeholder: prints its name, then always raises NotImplementedError."""
    print "test_DataSetFields"
    raise NotImplementedError()

def test_ApplyFunctionDataSet():
    print "test_ApplyFunctionDataSet"
    a = numpy.random.rand(10,4)
    a2 = a+1
    ds1 = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested

    ds2 = ApplyFunctionDataSet(ds1,lambda x,y,z: (x+1,y+1,z+1), ['x','y','z'],minibatch_mode=False)
    ds3 = ApplyFunctionDataSet(ds1,lambda x,y,z: (numpy.array(x)+1,numpy.array(y)+1,numpy.array(z)+1), ['x','y','z'],minibatch_mode=True)

    test_all(a2,ds2)
    test_all(a2,ds3)

    del a,ds1,ds2,ds3

def test_FieldsSubsetDataSet():
    """Placeholder: prints its name, then always raises NotImplementedError."""
    print "test_FieldsSubsetDataSet"
    raise NotImplementedError()
def test_MinibatchDataSet():
    """Placeholder: prints its name, then always raises NotImplementedError."""
    print "test_MinibatchDataSet"
    raise NotImplementedError()
def test_HStackedDataSet():
    """Placeholder: prints its name, then always raises NotImplementedError."""
    print "test_HStackedDataSet"
    raise NotImplementedError()
def test_VStackedDataSet():
    """Placeholder: prints its name, then always raises NotImplementedError."""
    print "test_VStackedDataSet"
    raise NotImplementedError()
def test_ArrayFieldsDataSet():
    """Placeholder: prints its name, then always raises NotImplementedError."""
    print "test_ArrayFieldsDataSet"
    raise NotImplementedError()
if __name__=='__main__':
    # Run the implemented tests in order; the placeholder tests that raise
    # NotImplementedError are not part of the run.
    for run_test in (test1,
                     test_LookupList,
                     test_ArrayDataSet,
                     test_CachedDataSet,
                     test_ApplyFunctionDataSet):
        run_test()
#test pmat.py