# HG changeset patch # User Frederic Bastien # Date 1213728375 14400 # Node ID 3aa9e5a5802abb003774084cc5c135eb2f307558 # Parent 9de4274ad5ba4643d70c8811d88995c711db3055# Parent a22ea54a19edda679d7e65aea7df65f6b27bed0e Automated merge with ssh://projects@lgcm.iro.umontreal.ca/hg/pylearn diff -r 9de4274ad5ba -r 3aa9e5a5802a _test_dataset.py --- a/_test_dataset.py Tue Jun 17 14:33:15 2008 -0400 +++ b/_test_dataset.py Tue Jun 17 14:46:15 2008 -0400 @@ -344,6 +344,71 @@ # del i,example #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])#???? +def test_subset(array,ds): + def test_ds(orig,ds,index): + i=0 + assert isinstance(ds2,DataSet) + assert len(ds)==len(index) + for x,z,y in ds('x','z','y'): + assert (orig[index[i]]['x']==array[index[i]][:3]).all() + assert (orig[index[i]]['x']==x).all() + assert orig[index[i]]['y']==array[index[i]][3] + assert orig[index[i]]['y']==y + assert (orig[index[i]]['z']==array[index[i]][0:3:2]).all() + assert (orig[index[i]]['z']==z).all() + i+=1 + del i + ds[0] + if len(ds)>2: + ds[:1] + ds[1:1] + ds[1:1:1] + if len(ds)>5: + ds[[1,2,3]] + for x in ds: + pass + +#ds[:n] returns a dataset with the n first examples. + ds2=ds.subset[:3] + test_ds(ds,ds2,index=[0,1,2]) +# del ds2 + +# #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s. +# ds2=ds.subset[1:7:2] +# test_ds(ds,ds2,[1,3,5]) +# del ds2 + +# #ds[i] +# ds2=ds.subset[5] +# assert isinstance(ds2,Example) +# assert have_raised("var['ds']["+str(len(ds))+"]",ds=ds) # index not defined +# assert not have_raised("var['ds']["+str(len(ds)-1)+"]",ds=ds) +# del ds2 + +# #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in. +# ds2=ds.subset[[4,7,2,8]] +# test_ds(ds,ds2,[4,7,2,8]) +# del ds2 + +# #ds.# returns the value of a property associated with +# #the name . The following properties should be supported: +# # - 'description': a textual description or name for the ds +# # - 'fieldtypes': a list of types (one per field) + +# #* ds1 | ds2 | ds3 == ds.hstack([ds1,ds2,ds3])#???? +# #assert hstack([ds('x','y'),ds('z')])==ds +# #hstack([ds('z','y'),ds('x')])==ds +# assert have_raised2(hstack,[ds('x'),ds('x')]) +# assert have_raised2(hstack,[ds('y','x'),ds('x')]) +# assert not have_raised2(hstack,[ds('x'),ds('y')]) + +# # i=0 +# # for example in hstack([ds('x'),ds('y'),ds('z')]): +# # example==ds[i] +# # i+=1 +# # del i,example +# #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])#???? + def test_fields_fct(ds): #@todo, fill correctly assert len(ds.fields())==3 @@ -473,6 +538,7 @@ test_iterate_over_examples(array, ds) test_overrides(ds) test_getitem(array, ds) + test_subset(array, ds) test_ds_iterator(array,ds('x','y'),ds('y','z'),ds('x','y','z')) test_fields_fct(ds) diff -r 9de4274ad5ba -r 3aa9e5a5802a dataset.py --- a/dataset.py Tue Jun 17 14:33:15 2008 -0400 +++ b/dataset.py Tue Jun 17 14:46:15 2008 -0400 @@ -864,7 +864,8 @@ return self def next(self): upper = self.next_example+minibatch_size - assert upper<=self.ds.length + if upper>self.ds.length: + raise StopIteration #minibatch = Example(self.ds._fields.keys(), # [field[self.next_example:upper] # for field in self.ds._fields])