Mercurial > pylearn
diff test_dataset.py @ 94:9c8f3c9c247b
corrected the use of .all()
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Mon, 05 May 2008 17:44:49 -0400 |
parents | eee739fefdff |
children | 352910e0dbf5 |
line wrap: on
line diff
--- a/test_dataset.py Mon May 05 17:13:53 2008 -0400 +++ b/test_dataset.py Mon May 05 17:44:49 2008 -0400 @@ -47,35 +47,35 @@ #not in doc!!! for example in range(len(ds)): - assert ds[example]['x'].all()==a[example][:2].all() + assert (ds[example]['x']==a[example][:3]).all() assert ds[example]['y']==a[example][3] - assert ds[example]['z'].all()==a[example][0:3:2].all() + assert (ds[example]['z']==a[example][[0,2]]).all() # - for example in dataset: i=0 for example in ds: - assert example['x'].all()==a[i][:2].all() + assert (example['x']==a[i][:3]).all() assert example['y']==a[i][3] - assert example['z'].all()==a[i][0:3:2].all() - assert numpy.append(example['x'],example['y']).all()==a[i].all() + assert (example['z']==a[i][0:3:2]).all() + assert (numpy.append(example['x'],example['y'])==a[i]).all() i+=1 assert i==len(ds) # - for val1,val2,... in dataset: i=0 for x,y,z in ds: - assert x.all()==a[i][:2].all() + assert (x==a[i][:3]).all() assert y==a[i][3] - assert z.all()==a[i][0:3:2].all() - assert numpy.append(example['x'],example['y']).all()==a[i].all() + assert (z==a[i][0:3:2]).all() + assert (numpy.append(x,y)==a[i]).all() i+=1 assert i==len(ds) # - for example in dataset(field1, field2,field3, ...): i=0 for example in ds('x','y','z'): - assert example['x'].all()==a[i][:2].all() + assert (example['x']==a[i][:3]).all() assert example['y']==a[i][3] - assert example['z'].all()==a[i][0:3:2].all() - assert numpy.append(example['x'],example['y']).all()==a[i].all() + assert (example['z']==a[i][0:3:2]).all() + assert (numpy.append(example['x'],example['y'])==a[i]).all() i+=1 assert i==len(ds) @@ -86,23 +86,23 @@ def test_ds_iterator(iterator1,iterator2,iterator3): i=0 for x,y in iterator1: - assert x.all()==a[i][:2].all() + assert (x==a[i][:3]).all() assert y==a[i][3] - assert numpy.append(x,y).all()==a[i].all() + assert (numpy.append(x,y)==a[i]).all() i+=1 assert i==len(ds) i=0 for y,z in iterator2: assert y==a[i][3] - assert z.all()==a[i][0:3:2].all() + assert (z==a[i][0:3:2]).all() i+=1 assert i==len(ds) i=0 for x,y,z in iterator3: - assert x.all()==a[i][:2].all() + assert (x==a[i][:3]).all() assert y==a[i][3] - assert z.all()==a[i][0:3:2].all() - assert numpy.append(x,y).all()==a[i].all() + assert (z==a[i][0:3:2]).all() + assert (numpy.append(x,y)==a[i]).all() i+=1 assert i==len(ds) @@ -114,34 +114,34 @@ assert len(minibatch)==2 assert len(minibatch[0])==3 assert len(minibatch[1])==3 - assert minibatch[0][:,0:3:2].all()==minibatch[1].all() + assert (minibatch[0][:,0:3:2]==minibatch[1]).all() i=0 for minibatch in ds.minibatches(['x','y'], minibatch_size=3): assert len(minibatch)==2 assert len(minibatch[0])==3 assert len(minibatch[1])==3 for id in range(3): - assert numpy.append(minibatch[0][id],minibatch[1][id]).all()==a[i].all() + assert (numpy.append(minibatch[0][id],minibatch[1][id])==a[i]).all() i+=1 # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): for x,z in ds.minibatches(['x','z'], minibatch_size=3): assert len(x)==3 assert len(z)==3 - assert x[:,0:3:2].all()==z.all() + assert (x[:,0:3:2]==z).all() i=0 for x,y in ds.minibatches(['x','y'], minibatch_size=3): assert len(x)==3 assert len(y)==3 for id in range(3): - assert numpy.append(minibatch[0][id],minibatch[1][id]).all()==a[i].all() + assert (numpy.append(x[id],y[id])==a[i]).all() i+=1 # - for x,y,z in dataset: # fail x,y,z order not fixed as it is a dict. # for x,y,z in ds: -# assert x.all()==a[i][:2].all() +# assert (x==a[i][:2]).all() # assert y==a[i][3] -# assert z.all()==a[i][0:3:2].all() -# assert numpy.append(x,y).all()==a[i].all() +# assert (z==a[i][0:3:2]).all() +# assert (numpy.append(x,y)==a[i]).all() # i+=1 # for minibatch in ds.minibatches(['z','y'], minibatch_size=3): @@ -173,12 +173,12 @@ i=0 assert len(ds)==len(index) for x,z,y in ds('x','z','y'): - assert orig[index[i]]['x'].all()==a[index[i]][:3].all() - assert orig[index[i]]['x'].all()==x.all() + assert (orig[index[i]]['x']==a[index[i]][:3]).all() + assert (orig[index[i]]['x']==x).all() assert orig[index[i]]['y']==a[index[i]][3] assert orig[index[i]]['y']==y - assert orig[index[i]]['z'].all()==a[index[i]][0:3:2].all() - assert orig[index[i]]['z'].all()==z.all() + assert (orig[index[i]]['z']==a[index[i]][0:3:2]).all() + assert (orig[index[i]]['z']==z).all() i+=1 del i ds[0]