Mercurial > pylearn
changeset 89:05dc4804357b
more test and refactoring
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Mon, 05 May 2008 16:54:16 -0400 |
parents | 6749d18e11c8 |
children | a289b8bed64c |
files | test_dataset.py |
diffstat | 1 files changed, 89 insertions(+), 26 deletions(-) [+] |
line wrap: on
line diff
--- a/test_dataset.py Mon May 05 14:51:41 2008 -0400 +++ b/test_dataset.py Mon May 05 16:54:16 2008 -0400 @@ -44,23 +44,88 @@ assert len(ds)==10 #assert ds==a? should this work? - for i in range(len(ds)): - assert ds[i]['x'].all()==a[i][:2].all() - assert ds[i]['y']==a[i][3] - assert ds[i]['z'].all()==a[i][0:3:2].all() + +#not in doc!!! + for example in range(len(ds)): + assert ds[example]['x'].all()==a[example][:2].all() + assert ds[example]['y']==a[example][3] + assert ds[example]['z'].all()==a[example][0:3:2].all() +# - for example in dataset:: i=0 - for x in ds('x','y'): - assert numpy.append(x['x'],x['y']).all()==a[i].all() + for example in ds: + assert example['x'].all()==a[i][:2].all() + assert example['y']==a[i][3] + assert example['z'].all()==a[i][0:3:2].all() + assert numpy.append(example['x'],example['y']).all()==a[i].all() i+=1 + assert i==len(ds) + + def test_ds_iterator(iterator1,iterator2,iterator3): + i=0 + for x,y in iterator1: + assert x.all()==a[i][:2].all() + assert y==a[i][3] + assert numpy.append(x,y).all()==a[i].all() + i+=1 + assert i==len(ds) + i=0 + for y,z in iterator2: + assert y==a[i][3] + assert z.all()==a[i][0:3:2].all() + i+=1 + assert i==len(ds) + i=0 + for x,y,z in iterator3: + assert x.all()==a[i][:2].all() + assert y==a[i][3] + assert z.all()==a[i][0:3:2].all() + assert numpy.append(x,y).all()==a[i].all() + i+=1 + assert i==len(ds) +#not in doc!!! - for val1,val2,val3 in dataset(field1, field2,field3): + test_ds_iterator(ds('x','y'),ds('y','z'),ds('x','y','z')) + +#not in doc!!! - for val1,val2,val3 in dataset((field1, field2,field3)): + test_ds_iterator(ds(('x','y')),ds(('y','z')),ds(('x','y','z'))) + +# - for val1,val2,val3 in dataset([field1, field2,field3]): #was bugged + test_ds_iterator(ds(['x','y']),ds(['y','z']),ds(['x','y','z'])) + +# - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): + for minibatch in ds.minibatches(['x','z'], minibatch_size=3): + assert len(minibatch)==2 + assert len(minibatch[0])==3 + assert len(minibatch[1])==3 + assert minibatch[0][:,0:3:2].all()==minibatch[1].all() i=0 - for x,y in ds('x','y'): - assert numpy.append(x,y).all()==a[i].all() - i+=1 - for minibatch in ds.minibatches(['x','z'], minibatch_size=3): - assert minibatch[0][:,0:3:2].all()==minibatch[1].all() + for minibatch in ds.minibatches(['x','y'], minibatch_size=3): + assert len(minibatch)==2 + assert len(minibatch[0])==3 + assert len(minibatch[1])==3 + for id in range(3): + assert numpy.append(minibatch[0][id],minibatch[1][id]).all()==a[i].all() + i+=1 + +# - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): for x,z in ds.minibatches(['x','z'], minibatch_size=3): + assert len(x)==3 + assert len(z)==3 assert x[:,0:3:2].all()==z.all() + i=0 + for x,y in ds.minibatches(['x','y'], minibatch_size=3): + assert len(x)==3 + assert len(y)==3 + for id in range(3): + assert numpy.append(minibatch[0][id],minibatch[1][id]).all()==a[i].all() + i+=1 +# - for x,y,z in dataset: # fail x,y,z order not fixed as it is a dict. +# for x,y,z in ds: +# assert x.all()==a[i][:2].all() +# assert y==a[i][3] +# assert z.all()==a[i][0:3:2].all() +# assert numpy.append(x,y).all()==a[i].all() +# i+=1 # for minibatch in ds.minibatches(['z','y'], minibatch_size=3): # print minibatch @@ -99,6 +164,15 @@ assert orig[index[i]]['z'].all()==z.all() i+=1 del i + ds[0] + if len(ds)>2: + ds[:1] + ds[1:1] + ds[1:1:1] + if len(ds)>5: + ds[[1,2,3]] + for x in ds: + pass #ds[:n] returns a dataset with the n first examples. ds2=ds[:3] @@ -106,22 +180,13 @@ #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s. ds2=ds[1:7:2] - ds2[1] test_ds(ds,ds2,[1,3,5]) + #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in. -# ds2=ds[[4,7,2,8]]# fail??? -# assert len(ds2)==4 -# i=0 -# index=[4,7,2,8] -# for x in ds2: -# assert ds[index[i]]['x'].all()==a[index[i]][:3].all() -# assert ds[index[i]]['x'].all()==x.all() -# assert ds[index[i]]['y']==a[index[i]][3] -# assert ds[index[i]]['y']==y -# assert ds[index[i]]['z'].all()==a[index[i]][0:3:2].all() -# assert ds[index[i]]['z'].all()==z.all() -# i+=1 + ds2=ds[[4,7,2,8]] + test_ds(ds,ds2,[4,7,2,8]) #ds[i1,i2,...]# should we accept???? + #ds[fieldname]# an iterable over the values of the field fieldname across #the ds (the iterable is obtained by default by calling valuesVStack #over the values for individual examples). @@ -183,5 +248,3 @@ test_LookupList() test_ArrayDataSet() - -