Mercurial > pylearn
diff test_dataset.py @ 91:eee739fefdff
corrected test from discution about the syntax with Yoshua
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Mon, 05 May 2008 17:13:53 -0400 |
parents | 05dc4804357b |
children | 9c8f3c9c247b |
line wrap: on
line diff
--- a/test_dataset.py Mon May 05 17:13:07 2008 -0400 +++ b/test_dataset.py Mon May 05 17:13:53 2008 -0400 @@ -41,7 +41,7 @@ print "test_ArrayDataSet" a = numpy.random.rand(10,4) ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested - + ds = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested assert len(ds)==10 #assert ds==a? should this work? @@ -50,7 +50,8 @@ assert ds[example]['x'].all()==a[example][:2].all() assert ds[example]['y']==a[example][3] assert ds[example]['z'].all()==a[example][0:3:2].all() -# - for example in dataset:: + +# - for example in dataset: i=0 for example in ds: assert example['x'].all()==a[i][:2].all() @@ -59,6 +60,28 @@ assert numpy.append(example['x'],example['y']).all()==a[i].all() i+=1 assert i==len(ds) +# - for val1,val2,... in dataset: + i=0 + for x,y,z in ds: + assert x.all()==a[i][:2].all() + assert y==a[i][3] + assert z.all()==a[i][0:3:2].all() + assert numpy.append(example['x'],example['y']).all()==a[i].all() + i+=1 + assert i==len(ds) +# - for example in dataset(field1, field2,field3, ...): + i=0 + for example in ds('x','y','z'): + assert example['x'].all()==a[i][:2].all() + assert example['y']==a[i][3] + assert example['z'].all()==a[i][0:3:2].all() + assert numpy.append(example['x'],example['y']).all()==a[i].all() + i+=1 + assert i==len(ds) + +# - for val1,val2,val3 in dataset(field1, field2,field3): + +# - for example in dataset(field1, field2,field3, ...): def test_ds_iterator(iterator1,iterator2,iterator3): i=0 @@ -86,12 +109,6 @@ #not in doc!!! - for val1,val2,val3 in dataset(field1, field2,field3): test_ds_iterator(ds('x','y'),ds('y','z'),ds('x','y','z')) -#not in doc!!! - for val1,val2,val3 in dataset((field1, field2,field3)): - test_ds_iterator(ds(('x','y')),ds(('y','z')),ds(('x','y','z'))) - -# - for val1,val2,val3 in dataset([field1, field2,field3]): #was bugged - test_ds_iterator(ds(['x','y']),ds(['y','z']),ds(['x','y','z'])) - # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): for minibatch in ds.minibatches(['x','z'], minibatch_size=3): assert len(minibatch)==2