Mercurial > pylearn
diff _test_dataset.py @ 376:c9a89be5cb0a
Redesigning linear_regression
author | Yoshua Bengio <bengioy@iro.umontreal.ca> |
---|---|
date | Mon, 07 Jul 2008 10:08:35 -0400 |
parents | 18702ceb2096 |
children | 82da179d95b2 |
line wrap: on
line diff
--- a/_test_dataset.py Mon Jun 16 17:47:36 2008 -0400 +++ b/_test_dataset.py Mon Jul 07 10:08:35 2008 -0400 @@ -2,7 +2,7 @@ from dataset import * from math import * import numpy, unittest, sys -from misc import * +#from misc import * from lookup_list import LookupList def have_raised(to_eval, **var): @@ -134,12 +134,13 @@ # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): i=0 mi=0 - m=ds.minibatches(['x','z'], minibatch_size=3) - assert isinstance(m,DataSet.MinibatchWrapAroundIterator) + size=3 + m=ds.minibatches(['x','z'], minibatch_size=size) + assert hasattr(m,'__iter__') for minibatch in m: - assert isinstance(minibatch,DataSetFields) + assert isinstance(minibatch,LookupList) assert len(minibatch)==2 - test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) + test_minibatch_size(minibatch,size,len(ds),2,mi) if type(ds)==ArrayDataSet: assert (minibatch[0][:,::2]==minibatch[1]).all() else: @@ -147,92 +148,103 @@ (minibatch[0][j][::2]==minibatch[1][j]).all() mi+=1 i+=len(minibatch[0]) - assert i==len(ds) - assert mi==4 - del minibatch,i,m,mi + assert i==(len(ds)/size)*size + assert mi==(len(ds)/size) + del minibatch,i,m,mi,size i=0 mi=0 - m=ds.minibatches(['x','y'], minibatch_size=3) - assert isinstance(m,DataSet.MinibatchWrapAroundIterator) + size=3 + m=ds.minibatches(['x','y'], minibatch_size=size) + assert hasattr(m,'__iter__') for minibatch in m: + assert isinstance(minibatch,LookupList) assert len(minibatch)==2 - test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) + test_minibatch_size(minibatch,size,len(ds),2,mi) mi+=1 for id in range(len(minibatch[0])): assert (numpy.append(minibatch[0][id],minibatch[1][id])==array[i]).all() i+=1 - assert i==len(ds) - assert mi==4 - del minibatch,i,id,m,mi + assert i==(len(ds)/size)*size + assert mi==(len(ds)/size) + del minibatch,i,id,m,mi,size # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): i=0 mi=0 - m=ds.minibatches(['x','z'], minibatch_size=3) - assert isinstance(m,DataSet.MinibatchWrapAroundIterator) + size=3 + m=ds.minibatches(['x','z'], minibatch_size=size) + assert hasattr(m,'__iter__') for x,z in m: - test_minibatch_field_size(x,m.minibatch_size,len(ds),mi) - test_minibatch_field_size(z,m.minibatch_size,len(ds),mi) + test_minibatch_field_size(x,size,len(ds),mi) + test_minibatch_field_size(z,size,len(ds),mi) for id in range(len(x)): assert (x[id][::2]==z[id]).all() i+=1 mi+=1 - assert i==len(ds) - assert mi==4 - del x,z,i,m,mi + assert i==(len(ds)/size)*size + assert mi==(len(ds)/size) + del x,z,i,m,mi,size + i=0 mi=0 + size=3 m=ds.minibatches(['x','y'], minibatch_size=3) + assert hasattr(m,'__iter__') for x,y in m: - test_minibatch_field_size(x,m.minibatch_size,len(ds),mi) - test_minibatch_field_size(y,m.minibatch_size,len(ds),mi) + assert len(x)==size + assert len(y)==size + test_minibatch_field_size(x,size,len(ds),mi) + test_minibatch_field_size(y,size,len(ds),mi) mi+=1 for id in range(len(x)): assert (numpy.append(x[id],y[id])==array[i]).all() i+=1 - assert i==len(ds) - assert mi==4 - del x,y,i,id,m,mi + assert i==(len(ds)/size)*size + assert mi==(len(ds)/size) + del x,y,i,id,m,mi,size #not in doc i=0 - m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4) - assert isinstance(m,DataSet.MinibatchWrapAroundIterator) + size=3 + m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=size,offset=4) + assert hasattr(m,'__iter__') for x,y in m: - assert len(x)==m.minibatch_size - assert len(y)==m.minibatch_size - for id in range(m.minibatch_size): + assert len(x)==size + assert len(y)==size + for id in range(size): assert (numpy.append(x[id],y[id])==array[i+4]).all() i+=1 - assert i==m.n_batches*m.minibatch_size - del x,y,i,id,m + assert i==size + del x,y,i,id,m,size i=0 - m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4) - assert isinstance(m,DataSet.MinibatchWrapAroundIterator) + size=3 + m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=size,offset=4) + assert hasattr(m,'__iter__') for x,y in m: - assert len(x)==m.minibatch_size - assert len(y)==m.minibatch_size - for id in range(m.minibatch_size): + assert len(x)==size + assert len(y)==size + for id in range(size): assert (numpy.append(x[id],y[id])==array[i+4]).all() i+=1 - assert i==m.n_batches*m.minibatch_size - del x,y,i,id,m + assert i==2*size + del x,y,i,id,m,size i=0 - m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=3,offset=4) - assert isinstance(m,DataSet.MinibatchWrapAroundIterator) + size=3 + m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=size,offset=4) + assert hasattr(m,'__iter__') for x,y in m: - assert len(x)==m.minibatch_size - assert len(y)==m.minibatch_size - for id in range(m.minibatch_size): + assert len(x)==size + assert len(y)==size + for id in range(size): assert (numpy.append(x[id],y[id])==array[(i+4)%array.shape[0]]).all() i+=1 - assert i==m.n_batches*m.minibatch_size - del x,y,i,id + assert i==2*size # should not wrap + del x,y,i,id,size - assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array)+1,offset=0) + assert have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array)+1,offset=0) assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array),offset=0) def test_ds_iterator(array,iterator1,iterator2,iterator3): @@ -262,14 +274,17 @@ def test_getitem(array,ds): def test_ds(orig,ds,index): i=0 - assert len(ds)==len(index) - for x,z,y in ds('x','z','y'): - assert (orig[index[i]]['x']==array[index[i]][:3]).all() - assert (orig[index[i]]['x']==x).all() - assert orig[index[i]]['y']==array[index[i]][3] - assert (orig[index[i]]['y']==y).all() # why does it crash sometimes? - assert (orig[index[i]]['z']==array[index[i]][0:3:2]).all() - assert (orig[index[i]]['z']==z).all() + assert isinstance(ds,LookupList) + assert len(ds)==3 + assert len(ds[0])==len(index) +# for x,z,y in ds('x','z','y'): + for idx in index: + assert (orig[idx]['x']==array[idx][:3]).all() + assert (orig[idx]['x']==ds['x'][i]).all() + assert orig[idx]['y']==array[idx][3] + assert (orig[idx]['y']==ds['y'][i]).all() # why does it crash sometimes? + assert (orig[idx]['z']==array[idx][0:3:2]).all() + assert (orig[idx]['z']==ds['z'][i]).all() i+=1 del i ds[0] @@ -282,19 +297,22 @@ for x in ds: pass -#ds[:n] returns a dataset with the n first examples. +#ds[:n] returns a LookupList with the n first examples. ds2=ds[:3] - assert isinstance(ds2,LookupList) test_ds(ds,ds2,index=[0,1,2]) del ds2 -#ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s. - ds2=ds.subset[1:7:2] - assert isinstance(ds2,DataSet) +#ds[i:j] returns a LookupList with examples i,i+1,...,j-1. + ds2=ds[1:3] + test_ds(ds,ds2,index=[1,2]) + del ds2 + +#ds[i1:i2:s] returns a LookupList with the examples i1,i1+s,...i2-s. + ds2=ds[1:7:2] test_ds(ds,ds2,[1,3,5]) del ds2 -#ds[i] +#ds[i] returns the (i+1)-th example of the dataset. ds2=ds[5] assert isinstance(ds2,Example) assert have_raised("var['ds']["+str(len(ds))+"]",ds=ds) # index not defined @@ -302,8 +320,8 @@ del ds2 #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in. - ds2=ds.subset[[4,7,2,8]] - assert isinstance(ds2,DataSet) + ds2=ds[[4,7,2,8]] +# assert isinstance(ds2,DataSet) test_ds(ds,ds2,[4,7,2,8]) del ds2 @@ -326,6 +344,71 @@ # del i,example #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])#???? +def test_subset(array,ds): + def test_ds(orig,ds,index): + i=0 + assert isinstance(ds2,DataSet) + assert len(ds)==len(index) + for x,z,y in ds('x','z','y'): + assert (orig[index[i]]['x']==array[index[i]][:3]).all() + assert (orig[index[i]]['x']==x).all() + assert orig[index[i]]['y']==array[index[i]][3] + assert orig[index[i]]['y']==y + assert (orig[index[i]]['z']==array[index[i]][0:3:2]).all() + assert (orig[index[i]]['z']==z).all() + i+=1 + del i + ds[0] + if len(ds)>2: + ds[:1] + ds[1:1] + ds[1:1:1] + if len(ds)>5: + ds[[1,2,3]] + for x in ds: + pass + +#ds[:n] returns a dataset with the n first examples. + ds2=ds.subset[:3] + test_ds(ds,ds2,index=[0,1,2]) +# del ds2 + +#ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s. + ds2=ds.subset[1:7:2] + test_ds(ds,ds2,[1,3,5]) +# del ds2 + +# #ds[i] +# ds2=ds.subset[5] +# assert isinstance(ds2,Example) +# assert have_raised("var['ds']["+str(len(ds))+"]",ds=ds) # index not defined +# assert not have_raised("var['ds']["+str(len(ds)-1)+"]",ds=ds) +# del ds2 + +#ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in. + ds2=ds.subset[[4,7,2,8]] + test_ds(ds,ds2,[4,7,2,8]) +# del ds2 + +#ds.<property># returns the value of a property associated with + #the name <property>. The following properties should be supported: + # - 'description': a textual description or name for the ds + # - 'fieldtypes': a list of types (one per field) + +#* ds1 | ds2 | ds3 == ds.hstack([ds1,ds2,ds3])#???? + #assert hstack([ds('x','y'),ds('z')])==ds + #hstack([ds('z','y'),ds('x')])==ds + assert have_raised2(hstack,[ds('x'),ds('x')]) + assert have_raised2(hstack,[ds('y','x'),ds('x')]) + assert not have_raised2(hstack,[ds('x'),ds('y')]) + +# i=0 +# for example in hstack([ds('x'),ds('y'),ds('z')]): +# example==ds[i] +# i+=1 +# del i,example +#* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])#???? + def test_fields_fct(ds): #@todo, fill correctly assert len(ds.fields())==3 @@ -455,6 +538,7 @@ test_iterate_over_examples(array, ds) test_overrides(ds) test_getitem(array, ds) + test_subset(array, ds) test_ds_iterator(array,ds('x','y'),ds('y','z'),ds('x','y','z')) test_fields_fct(ds) @@ -515,6 +599,15 @@ del a, ds + def test_RenamedFieldsDataSet(self): + a = numpy.random.rand(10,4) + ds = ArrayDataSet(a,Example(['x1','y1','z1','w1'],[slice(3),3,[0,2],0])) + ds = RenamedFieldsDataSet(ds,['x1','y1','z1'],['x','y','z']) + + test_all(a,ds) + + del a, ds + def test_MinibatchDataSet(self): raise NotImplementedError() def test_HStackedDataSet(self): @@ -570,14 +663,17 @@ res = dsc[:] if __name__=='__main__': - if len(sys.argv)==2: - if sys.argv[1]=="--debug": + tests = [] + debug=False + if len(sys.argv)==1: + unittest.main() + else: + assert sys.argv[1]=="--debug" + for arg in sys.argv[2:]: + tests.append(arg) + if tests: + unittest.TestSuite(map(T_DataSet, tests)).debug() + else: module = __import__("_test_dataset") tests = unittest.TestLoader().loadTestsFromModule(module) tests.debug() - print "bad argument: only --debug is accepted" - elif len(sys.argv)==1: - unittest.main() - else: - print "bad argument: only --debug is accepted" -