view test_dataset.py @ 99:a8da709eb6a9

In ArrayDataSet.__init__, if a column is given as a single index, we change it to a list that contains only this index. This way, we remove the special case where the column is an index for all subsequent calls. This was causing trouble with numpy.vstack() called by MinibatchWrapAroundIterator.next
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Tue, 06 May 2008 13:57:36 -0400
parents 352910e0dbf5
children 574f4db76022
line wrap: on
line source

#!/bin/env python
from dataset import *
from math import *
import numpy

def have_raised(to_eval):
    """Return True if evaluating the expression string `to_eval` raises.

    The expression is evaluated with the *caller's* globals and locals, so it
    may refer to variables that are local to the calling test.  Evaluating it
    in this helper's own scope would make any such reference fail with
    NameError, which would then be reported as "raised" whatever the
    expression really does.

    Only Exception subclasses count as "raised"; KeyboardInterrupt and
    SystemExit propagate to the caller.

    NOTE: this uses eval(), so it must only be given trusted, literal
    expressions written in the tests themselves.
    """
    import sys  # local import so the helper does not depend on file-level imports

    caller = sys._getframe(1)
    try:
        eval(to_eval, caller.f_globals, caller.f_locals)
    except Exception:
        return True
    return False

def test1():
    print "test1"
    global a,ds
    a = numpy.random.rand(10,4)
    print a
    ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]})
    print "len(ds)=",len(ds)
    assert(len(ds)==10)
    print "example 0 = ",ds[0]
#    assert
    print "x=",ds["x"]
    print "x|y"
    for x,y in ds("x","y"):
        print x,y
    minibatch_iterator = ds.minibatches(fieldnames=['z','y'],n_batches=1,minibatch_size=3,offset=4)
    minibatch = minibatch_iterator.__iter__().next()
    print "minibatch=",minibatch
    for var in minibatch:
        print "var=",var
    print "take a slice and look at field y",ds[1:6:2]["y"]

def test_ArrayDataSet():
    #don't test stream
    #tested only with float value
    #test with y too
    #test missing value

    def test_iterate_over_examples(array,ds):
#not in doc!!!
        i=0
        for example in range(len(ds)):
            assert (ds[example]['x']==a[example][:3]).all()
            assert ds[example]['y']==a[example][3]
            assert (ds[example]['z']==a[example][[0,2]]).all()
            i+=1
        assert i==len(ds)
        del example,i

#     - for example in dataset:
        i=0
        for example in ds:
            assert len(example)==3
            assert (example['x']==array[i][:3]).all()
            assert example['y']==array[i][3]
            assert (example['z']==array[i][0:3:2]).all()
            assert (numpy.append(example['x'],example['y'])==array[i]).all()
            i+=1
        assert i==len(ds)
        del example,i

#     - for val1,val2,... in dataset:
        i=0
        for x,y,z in ds:
            assert (x==array[i][:3]).all()
            assert y==array[i][3]
            assert (z==array[i][0:3:2]).all()
            assert (numpy.append(x,y)==array[i]).all()
            i+=1
        assert i==len(ds)
        del x,y,z,i

#     - for example in dataset(field1, field2,field3, ...):
        i=0
        for example in ds('x','y','z'):
            assert len(example)==3
            assert (example['x']==array[i][:3]).all()
            assert example['y']==array[i][3]
            assert (example['z']==array[i][0:3:2]).all()
            assert (numpy.append(example['x'],example['y'])==array[i]).all()
            i+=1
        assert i==len(ds)
        del example,i
        i=0
        for example in ds('y','x'):
            assert len(example)==2
            assert (example['x']==array[i][:3]).all()
            assert example['y']==array[i][3]
            assert (numpy.append(example['x'],example['y'])==array[i]).all()
            i+=1
        assert i==len(ds)
        del example,i

#     - for val1,val2,val3 in dataset(field1, field2,field3):
        i=0
        for x,y,z in ds('x','y','z'):
            assert (x==array[i][:3]).all()
            assert y==array[i][3]
            assert (z==array[i][0:3:2]).all()
            assert (numpy.append(x,y)==array[i]).all()
            i+=1
        assert i==len(ds)
        del x,y,z,i
        i=0
        for y,x in ds('y','x',):
            assert (x==array[i][:3]).all()
            assert y==array[i][3]
            assert (numpy.append(x,y)==array[i]).all()
            i+=1
        assert i==len(ds)
        del x,y,i


#     - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N):
        i=0
        for minibatch in ds.minibatches(['x','z'], minibatch_size=3):
            assert len(minibatch)==2
            assert len(minibatch[0])==3
            assert len(minibatch[1])==3
            assert (minibatch[0][:,0:3:2]==minibatch[1]).all()
            i+=1
        #assert i==#??? What shoud be the value?
        print i
        del minibatch,i
        i=0
        for minibatch in ds.minibatches(['x','y'], minibatch_size=3):
            assert len(minibatch)==2
            assert len(minibatch[0])==3
            assert len(minibatch[1])==3
            for id in range(3):
                assert (numpy.append(minibatch[0][id],minibatch[1][id])==a[i]).all()
                i+=1
        #assert i==#??? What shoud be the value?
        print i
        del minibatch,i,id

#     - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N):
        i=0
        for x,z in ds.minibatches(['x','z'], minibatch_size=3):
            assert len(x)==3
            assert len(z)==3
            assert (x[:,0:3:2]==z).all()
            i+=1
        #assert i==#??? What shoud be the value?
        print i 
        del x,z,i
        i=0
        for x,y in ds.minibatches(['x','y'], minibatch_size=3):
            assert len(x)==3
            assert len(y)==3
            for id in range(3):
                assert (numpy.append(x[id],y[id])==a[i]).all()
                i+=1
        #assert i==#??? What shoud be the value?
        print i 
        del x,y,i,id

#not in doc
        i=0
        for x,y in ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4):
            assert len(x)==3
            assert len(y)==3
            for id in range(3):
                assert (numpy.append(x[id],y[id])==a[i+4]).all()
                i+=1
        assert i==3
        del x,y,i,id

        i=0
        for x,y in ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4):
            assert len(x)==3
            assert len(y)==3
            for id in range(3):
                assert (numpy.append(x[id],y[id])==a[i+4]).all()
                i+=1
        assert i==6
        del x,y,i,id

        i=0
        for x,y in ds.minibatches(['x','y'],n_batches=10,minibatch_size=3,offset=4):
            assert len(x)==3
            assert len(y)==3
            for id in range(3):
                assert (numpy.append(x[id],y[id])==a[i+4]).all()
                i+=1
        assert i==6
        del x,y,i,id


    def test_ds_iterator(array,iterator1,iterator2,iterator3):
        i=0
        for x,y in iterator1:
            assert (x==array[i][:3]).all()
            assert y==array[i][3]
            assert (numpy.append(x,y)==array[i]).all()
            i+=1
        assert i==len(ds)
        i=0
        for y,z in iterator2:
            assert y==array[i][3]
            assert (z==array[i][0:3:2]).all()
            i+=1
        assert i==len(ds)
        i=0
        for x,y,z in iterator3:
            assert (x==array[i][:3]).all()
            assert y==array[i][3]
            assert (z==array[i][0:3:2]).all()
            assert (numpy.append(x,y)==array[i]).all()
            i+=1
        assert i==len(ds)

    print "test_ArrayDataSet"
    a = numpy.random.rand(10,4)
    ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested
    ds = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested
    assert len(ds)==10
    #assert ds==a? should this work?
    
    test_iterate_over_examples(a, ds)


#     - for val1,val2,val3 in dataset(field1, field2,field3):
    test_ds_iterator(a,ds('x','y'),ds('y','z'),ds('x','y','z'))

    assert have_raised("ds['h']")  # h is not defined...
    assert have_raised("ds[['h']]")  # h is not defined...

    assert len(ds.fields())==3
    for field in ds.fields():
        for field_value in field: # iterate over the values associated to that field for all the ds examples
            pass
    for field in ds('x','z').fields():
        pass
    for field in ds.fields('x','y'):
        pass
    for field_examples in ds.fields():
        for example_value in field_examples:
            pass

    assert ds == ds.fields().examples()

    def test_ds(orig,ds,index):
        i=0
        assert len(ds)==len(index)
        for x,z,y in ds('x','z','y'):
            assert (orig[index[i]]['x']==a[index[i]][:3]).all()
            assert (orig[index[i]]['x']==x).all()
            assert orig[index[i]]['y']==a[index[i]][3]
            assert orig[index[i]]['y']==y
            assert (orig[index[i]]['z']==a[index[i]][0:3:2]).all()
            assert (orig[index[i]]['z']==z).all()
            i+=1
        del i
        ds[0]
        if len(ds)>2:
            ds[:1]
            ds[1:1]
            ds[1:1:1]
        if len(ds)>5:
            ds[[1,2,3]]
        for x in ds:
            pass

    #ds[:n] returns a dataset with the n first examples.
    ds2=ds[:3]
    test_ds(ds,ds2,index=[0,1,2])

    #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s.
    ds2=ds[1:7:2]
    test_ds(ds,ds2,[1,3,5])

    #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in.
    ds2=ds[[4,7,2,8]]
    test_ds(ds,ds2,[4,7,2,8])
    #ds[i1,i2,...]# should we accept????

    #ds[fieldname]# an iterable over the values of the field fieldname across
      #the ds (the iterable is obtained by default by calling valuesVStack
      #over the values for individual examples).

    #ds.<property># returns the value of a property associated with
      #the name <property>. The following properties should be supported:
      #    - 'description': a textual description or name for the ds
      #    - 'fieldtypes': a list of types (one per field)
    #* ds1 | ds2 | ds3 == ds.hstack([ds1,ds2,ds3])
    #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])

#    for (x,y) in (ds('x','y'),a): #???don't work # haven't found a variant that work.
#        assert numpy.append(x,y)==z

def test_LookupList():
    #test only the example in the doc???
    print "test_LookupList"
    example = LookupList(['x','y','z'],[1,2,3])
    example['x'] = [1, 2, 3] # set or change a field
    x, y, z = example
    x = example[0]
    x = example["x"]
    assert example.keys()==['x','y','z']
    assert example.values()==[[1,2,3],2,3]
    assert example.items()==[('x',[1,2,3]),('y',2),('z',3)]
    example.append_keyval('u',0) # adds item with name 'u' and value 0
    assert len(example)==4 # number of items = 4 here
    example2 = LookupList(['v','w'], ['a','b'])
    example3 = LookupList(['x','y','z','u','v','w'], [[1, 2, 3],2,3,0,'a','b'])
    assert example+example2==example3
    assert have_raised("example+example")

def test_ApplyFunctionDataSet():
    print "test_ApplyFunctionDataSet"
    raise NotImplementedError()
def test_CacheDataSet():
    print "test_CacheDataSet"
    raise NotImplementedError()
def test_FieldsSubsetDataSet():
    print "test_FieldsSubsetDataSet"
    raise NotImplementedError()
def test_DataSetFields():
    print "test_DataSetFields"
    raise NotImplementedError()
def test_MinibatchDataSet():
    print "test_MinibatchDataSet"
    raise NotImplementedError()
def test_HStackedDataSet():
    print "test_HStackedDataSet"
    raise NotImplementedError()
def test_VStackedDataSet():
    print "test_VStackedDataSet"
    raise NotImplementedError()
def test_ArrayFieldsDataSet():
    print "test_ArrayFieldsDataSet"
    raise NotImplementedError()

# Run the implemented tests.  There is no __main__ guard, so these calls
# execute whenever the module is loaded, whether run as a script or imported.
# test1 goes first and, through its `global a,ds` statement, leaves `a` and
# `ds` bound at module level for whatever runs afterwards.
# The placeholder tests that raise NotImplementedError are not invoked here.
test1()
test_LookupList()
test_ArrayDataSet()