Mercurial > pylearn
view test_dataset.py @ 99:a8da709eb6a9
In ArrayDataSet.__init__, if a column is given as a single index, it is changed into a list containing only that index. This removes the special case of a column being a single index from all subsequent calls.
This was causing trouble with numpy.vstack() called by MinibatchWrapAroundIterator.next.
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Tue, 06 May 2008 13:57:36 -0400 |
parents | 352910e0dbf5 |
children | 574f4db76022 |
line wrap: on
line source
#!/bin/env python from dataset import * from math import * import numpy def have_raised(to_eval): have_thrown = False try: eval(to_eval) except : have_thrown = True return have_thrown def test1(): print "test1" global a,ds a = numpy.random.rand(10,4) print a ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]}) print "len(ds)=",len(ds) assert(len(ds)==10) print "example 0 = ",ds[0] # assert print "x=",ds["x"] print "x|y" for x,y in ds("x","y"): print x,y minibatch_iterator = ds.minibatches(fieldnames=['z','y'],n_batches=1,minibatch_size=3,offset=4) minibatch = minibatch_iterator.__iter__().next() print "minibatch=",minibatch for var in minibatch: print "var=",var print "take a slice and look at field y",ds[1:6:2]["y"] def test_ArrayDataSet(): #don't test stream #tested only with float value #test with y too #test missing value def test_iterate_over_examples(array,ds): #not in doc!!! i=0 for example in range(len(ds)): assert (ds[example]['x']==a[example][:3]).all() assert ds[example]['y']==a[example][3] assert (ds[example]['z']==a[example][[0,2]]).all() i+=1 assert i==len(ds) del example,i # - for example in dataset: i=0 for example in ds: assert len(example)==3 assert (example['x']==array[i][:3]).all() assert example['y']==array[i][3] assert (example['z']==array[i][0:3:2]).all() assert (numpy.append(example['x'],example['y'])==array[i]).all() i+=1 assert i==len(ds) del example,i # - for val1,val2,... 
in dataset: i=0 for x,y,z in ds: assert (x==array[i][:3]).all() assert y==array[i][3] assert (z==array[i][0:3:2]).all() assert (numpy.append(x,y)==array[i]).all() i+=1 assert i==len(ds) del x,y,z,i # - for example in dataset(field1, field2,field3, ...): i=0 for example in ds('x','y','z'): assert len(example)==3 assert (example['x']==array[i][:3]).all() assert example['y']==array[i][3] assert (example['z']==array[i][0:3:2]).all() assert (numpy.append(example['x'],example['y'])==array[i]).all() i+=1 assert i==len(ds) del example,i i=0 for example in ds('y','x'): assert len(example)==2 assert (example['x']==array[i][:3]).all() assert example['y']==array[i][3] assert (numpy.append(example['x'],example['y'])==array[i]).all() i+=1 assert i==len(ds) del example,i # - for val1,val2,val3 in dataset(field1, field2,field3): i=0 for x,y,z in ds('x','y','z'): assert (x==array[i][:3]).all() assert y==array[i][3] assert (z==array[i][0:3:2]).all() assert (numpy.append(x,y)==array[i]).all() i+=1 assert i==len(ds) del x,y,z,i i=0 for y,x in ds('y','x',): assert (x==array[i][:3]).all() assert y==array[i][3] assert (numpy.append(x,y)==array[i]).all() i+=1 assert i==len(ds) del x,y,i # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): i=0 for minibatch in ds.minibatches(['x','z'], minibatch_size=3): assert len(minibatch)==2 assert len(minibatch[0])==3 assert len(minibatch[1])==3 assert (minibatch[0][:,0:3:2]==minibatch[1]).all() i+=1 #assert i==#??? What shoud be the value? print i del minibatch,i i=0 for minibatch in ds.minibatches(['x','y'], minibatch_size=3): assert len(minibatch)==2 assert len(minibatch[0])==3 assert len(minibatch[1])==3 for id in range(3): assert (numpy.append(minibatch[0][id],minibatch[1][id])==a[i]).all() i+=1 #assert i==#??? What shoud be the value? 
print i del minibatch,i,id # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): i=0 for x,z in ds.minibatches(['x','z'], minibatch_size=3): assert len(x)==3 assert len(z)==3 assert (x[:,0:3:2]==z).all() i+=1 #assert i==#??? What shoud be the value? print i del x,z,i i=0 for x,y in ds.minibatches(['x','y'], minibatch_size=3): assert len(x)==3 assert len(y)==3 for id in range(3): assert (numpy.append(x[id],y[id])==a[i]).all() i+=1 #assert i==#??? What shoud be the value? print i del x,y,i,id #not in doc i=0 for x,y in ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4): assert len(x)==3 assert len(y)==3 for id in range(3): assert (numpy.append(x[id],y[id])==a[i+4]).all() i+=1 assert i==3 del x,y,i,id i=0 for x,y in ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4): assert len(x)==3 assert len(y)==3 for id in range(3): assert (numpy.append(x[id],y[id])==a[i+4]).all() i+=1 assert i==6 del x,y,i,id i=0 for x,y in ds.minibatches(['x','y'],n_batches=10,minibatch_size=3,offset=4): assert len(x)==3 assert len(y)==3 for id in range(3): assert (numpy.append(x[id],y[id])==a[i+4]).all() i+=1 assert i==6 del x,y,i,id def test_ds_iterator(array,iterator1,iterator2,iterator3): i=0 for x,y in iterator1: assert (x==array[i][:3]).all() assert y==array[i][3] assert (numpy.append(x,y)==array[i]).all() i+=1 assert i==len(ds) i=0 for y,z in iterator2: assert y==array[i][3] assert (z==array[i][0:3:2]).all() i+=1 assert i==len(ds) i=0 for x,y,z in iterator3: assert (x==array[i][:3]).all() assert y==array[i][3] assert (z==array[i][0:3:2]).all() assert (numpy.append(x,y)==array[i]).all() i+=1 assert i==len(ds) print "test_ArrayDataSet" a = numpy.random.rand(10,4) ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested ds = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested assert len(ds)==10 #assert ds==a? should this work? 
test_iterate_over_examples(a, ds) # - for val1,val2,val3 in dataset(field1, field2,field3): test_ds_iterator(a,ds('x','y'),ds('y','z'),ds('x','y','z')) assert have_raised("ds['h']") # h is not defined... assert have_raised("ds[['h']]") # h is not defined... assert len(ds.fields())==3 for field in ds.fields(): for field_value in field: # iterate over the values associated to that field for all the ds examples pass for field in ds('x','z').fields(): pass for field in ds.fields('x','y'): pass for field_examples in ds.fields(): for example_value in field_examples: pass assert ds == ds.fields().examples() def test_ds(orig,ds,index): i=0 assert len(ds)==len(index) for x,z,y in ds('x','z','y'): assert (orig[index[i]]['x']==a[index[i]][:3]).all() assert (orig[index[i]]['x']==x).all() assert orig[index[i]]['y']==a[index[i]][3] assert orig[index[i]]['y']==y assert (orig[index[i]]['z']==a[index[i]][0:3:2]).all() assert (orig[index[i]]['z']==z).all() i+=1 del i ds[0] if len(ds)>2: ds[:1] ds[1:1] ds[1:1:1] if len(ds)>5: ds[[1,2,3]] for x in ds: pass #ds[:n] returns a dataset with the n first examples. ds2=ds[:3] test_ds(ds,ds2,index=[0,1,2]) #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s. ds2=ds[1:7:2] test_ds(ds,ds2,[1,3,5]) #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in. ds2=ds[[4,7,2,8]] test_ds(ds,ds2,[4,7,2,8]) #ds[i1,i2,...]# should we accept???? #ds[fieldname]# an iterable over the values of the field fieldname across #the ds (the iterable is obtained by default by calling valuesVStack #over the values for individual examples). #ds.<property># returns the value of a property associated with #the name <property>. The following properties should be supported: # - 'description': a textual description or name for the ds # - 'fieldtypes': a list of types (one per field) #* ds1 | ds2 | ds3 == ds.hstack([ds1,ds2,ds3]) #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3]) # for (x,y) in (ds('x','y'),a): #???don't work # haven't found a variant that work. 
# assert numpy.append(x,y)==z def test_LookupList(): #test only the example in the doc??? print "test_LookupList" example = LookupList(['x','y','z'],[1,2,3]) example['x'] = [1, 2, 3] # set or change a field x, y, z = example x = example[0] x = example["x"] assert example.keys()==['x','y','z'] assert example.values()==[[1,2,3],2,3] assert example.items()==[('x',[1,2,3]),('y',2),('z',3)] example.append_keyval('u',0) # adds item with name 'u' and value 0 assert len(example)==4 # number of items = 4 here example2 = LookupList(['v','w'], ['a','b']) example3 = LookupList(['x','y','z','u','v','w'], [[1, 2, 3],2,3,0,'a','b']) assert example+example2==example3 assert have_raised("example+example") def test_ApplyFunctionDataSet(): print "test_ApplyFunctionDataSet" raise NotImplementedError() def test_CacheDataSet(): print "test_CacheDataSet" raise NotImplementedError() def test_FieldsSubsetDataSet(): print "test_FieldsSubsetDataSet" raise NotImplementedError() def test_DataSetFields(): print "test_DataSetFields" raise NotImplementedError() def test_MinibatchDataSet(): print "test_MinibatchDataSet" raise NotImplementedError() def test_HStackedDataSet(): print "test_HStackedDataSet" raise NotImplementedError() def test_VStackedDataSet(): print "test_VStackedDataSet" raise NotImplementedError() def test_ArrayFieldsDataSet(): print "test_ArrayFieldsDataSet" raise NotImplementedError() test1() test_LookupList() test_ArrayDataSet()