# HG changeset patch # User Yoshua Bengio # Date 1213727595 14400 # Node ID 9de4274ad5ba4643d70c8811d88995c711db3055 # Parent 2259f6fa4959ace15d3977b7c6a36d1ae4165e16# Parent 4efb503fd0da8fb11e51670228bb05f8f8f209aa Automated merge with ssh://projects@lgcm.iro.umontreal.ca/hg/pylearn diff -r 4efb503fd0da -r 9de4274ad5ba _test_dataset.py --- a/_test_dataset.py Tue Jun 17 14:32:54 2008 -0400 +++ b/_test_dataset.py Tue Jun 17 14:33:15 2008 -0400 @@ -134,12 +134,13 @@ # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): i=0 mi=0 - m=ds.minibatches(['x','z'], minibatch_size=3) - assert isinstance(m,DataSet.MinibatchWrapAroundIterator) + size=3 + m=ds.minibatches(['x','z'], minibatch_size=size) + assert hasattr(m,'__iter__') for minibatch in m: - assert isinstance(minibatch,DataSetFields) + assert isinstance(minibatch,LookupList) assert len(minibatch)==2 - test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) + test_minibatch_size(minibatch,size,len(ds),2,mi) if type(ds)==ArrayDataSet: assert (minibatch[0][:,::2]==minibatch[1]).all() else: @@ -147,92 +148,103 @@ (minibatch[0][j][::2]==minibatch[1][j]).all() mi+=1 i+=len(minibatch[0]) - assert i==len(ds) - assert mi==4 - del minibatch,i,m,mi + assert i==(len(ds)/size)*size + assert mi==(len(ds)/size) + del minibatch,i,m,mi,size i=0 mi=0 - m=ds.minibatches(['x','y'], minibatch_size=3) - assert isinstance(m,DataSet.MinibatchWrapAroundIterator) + size=3 + m=ds.minibatches(['x','y'], minibatch_size=size) + assert hasattr(m,'__iter__') for minibatch in m: + assert isinstance(minibatch,LookupList) assert len(minibatch)==2 - test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) + test_minibatch_size(minibatch,size,len(ds),2,mi) mi+=1 for id in range(len(minibatch[0])): assert (numpy.append(minibatch[0][id],minibatch[1][id])==array[i]).all() i+=1 - assert i==len(ds) - assert mi==4 - del minibatch,i,id,m,mi + assert i==(len(ds)/size)*size + assert mi==(len(ds)/size) + del minibatch,i,id,m,mi,size # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): i=0 mi=0 - m=ds.minibatches(['x','z'], minibatch_size=3) - assert isinstance(m,DataSet.MinibatchWrapAroundIterator) + size=3 + m=ds.minibatches(['x','z'], minibatch_size=size) + assert hasattr(m,'__iter__') for x,z in m: - test_minibatch_field_size(x,m.minibatch_size,len(ds),mi) - test_minibatch_field_size(z,m.minibatch_size,len(ds),mi) + test_minibatch_field_size(x,size,len(ds),mi) + test_minibatch_field_size(z,size,len(ds),mi) for id in range(len(x)): assert (x[id][::2]==z[id]).all() i+=1 mi+=1 - assert i==len(ds) - assert mi==4 - del x,z,i,m,mi + assert i==(len(ds)/size)*size + assert mi==(len(ds)/size) + del x,z,i,m,mi,size + i=0 mi=0 + size=3 m=ds.minibatches(['x','y'], minibatch_size=3) + assert hasattr(m,'__iter__') for x,y in m: - test_minibatch_field_size(x,m.minibatch_size,len(ds),mi) - test_minibatch_field_size(y,m.minibatch_size,len(ds),mi) + assert len(x)==size + assert len(y)==size + test_minibatch_field_size(x,size,len(ds),mi) + test_minibatch_field_size(y,size,len(ds),mi) mi+=1 for id in range(len(x)): assert (numpy.append(x[id],y[id])==array[i]).all() i+=1 - assert i==len(ds) - assert mi==4 - del x,y,i,id,m,mi + assert i==(len(ds)/size)*size + assert mi==(len(ds)/size) + del x,y,i,id,m,mi,size #not in doc i=0 - m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4) - assert isinstance(m,DataSet.MinibatchWrapAroundIterator) + size=3 + m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=size,offset=4) + assert hasattr(m,'__iter__') for x,y in m: - assert len(x)==m.minibatch_size - assert len(y)==m.minibatch_size - for id in range(m.minibatch_size): + assert len(x)==size + assert len(y)==size + for id in range(size): assert (numpy.append(x[id],y[id])==array[i+4]).all() i+=1 - assert i==m.n_batches*m.minibatch_size - del x,y,i,id,m + assert i==size + del x,y,i,id,m,size i=0 - m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4) - assert isinstance(m,DataSet.MinibatchWrapAroundIterator) + size=3 + m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=size,offset=4) + assert hasattr(m,'__iter__') for x,y in m: - assert len(x)==m.minibatch_size - assert len(y)==m.minibatch_size - for id in range(m.minibatch_size): + assert len(x)==size + assert len(y)==size + for id in range(size): assert (numpy.append(x[id],y[id])==array[i+4]).all() i+=1 - assert i==m.n_batches*m.minibatch_size - del x,y,i,id,m + assert i==2*size + del x,y,i,id,m,size i=0 - m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=3,offset=4) - assert isinstance(m,DataSet.MinibatchWrapAroundIterator) + size=3 + m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=size,offset=4) + assert hasattr(m,'__iter__') for x,y in m: - assert len(x)==m.minibatch_size - assert len(y)==m.minibatch_size - for id in range(m.minibatch_size): + assert len(x)==size + assert len(y)==size + for id in range(size): assert (numpy.append(x[id],y[id])==array[(i+4)%array.shape[0]]).all() i+=1 - assert i==m.n_batches*m.minibatch_size - del x,y,i,id + assert i==2*size # should not wrap + del x,y,i,id,size - assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array)+1,offset=0) + assert have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array)+1,offset=0) assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array),offset=0) def test_ds_iterator(array,iterator1,iterator2,iterator3): @@ -262,14 +274,17 @@ def test_getitem(array,ds): def test_ds(orig,ds,index): i=0 - assert len(ds)==len(index) - for x,z,y in ds('x','z','y'): - assert (orig[index[i]]['x']==array[index[i]][:3]).all() - assert (orig[index[i]]['x']==x).all() - assert orig[index[i]]['y']==array[index[i]][3] - assert (orig[index[i]]['y']==y).all() # why does it crash sometimes? - assert (orig[index[i]]['z']==array[index[i]][0:3:2]).all() - assert (orig[index[i]]['z']==z).all() + assert isinstance(ds,LookupList) + assert len(ds)==3 + assert len(ds[0])==len(index) +# for x,z,y in ds('x','z','y'): + for idx in index: + assert (orig[idx]['x']==array[idx][:3]).all() + assert (orig[idx]['x']==ds['x'][i]).all() + assert orig[idx]['y']==array[idx][3] + assert (orig[idx]['y']==ds['y'][i]).all() # why does it crash sometimes? + assert (orig[idx]['z']==array[idx][0:3:2]).all() + assert (orig[idx]['z']==ds['z'][i]).all() i+=1 del i ds[0] @@ -282,19 +297,22 @@ for x in ds: pass -#ds[:n] returns a dataset with the n first examples. +#ds[:n] returns a LookupList with the n first examples. ds2=ds[:3] - assert isinstance(ds2,LookupList) test_ds(ds,ds2,index=[0,1,2]) del ds2 -#ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s. - ds2=ds.subset[1:7:2] - assert isinstance(ds2,DataSet) +#ds[i:j] returns a LookupList with examples i,i+1,...,j-1. + ds2=ds[1:3] + test_ds(ds,ds2,index=[1,2]) + del ds2 + +#ds[i1:i2:s] returns a LookupList with the examples i1,i1+s,...i2-s. + ds2=ds[1:7:2] test_ds(ds,ds2,[1,3,5]) del ds2 -#ds[i] +#ds[i] returns the (i+1)-th example of the dataset. ds2=ds[5] assert isinstance(ds2,Example) assert have_raised("var['ds']["+str(len(ds))+"]",ds=ds) # index not defined @@ -302,8 +320,8 @@ del ds2 #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in. - ds2=ds.subset[[4,7,2,8]] - assert isinstance(ds2,DataSet) + ds2=ds[[4,7,2,8]] +# assert isinstance(ds2,DataSet) test_ds(ds,ds2,[4,7,2,8]) del ds2 diff -r 4efb503fd0da -r 9de4274ad5ba dataset.py --- a/dataset.py Tue Jun 17 14:32:54 2008 -0400 +++ b/dataset.py Tue Jun 17 14:33:15 2008 -0400 @@ -721,7 +721,12 @@ assert self.hasFields(*fieldnames) return self.src.minibatches_nowrap(fieldnames,minibatch_size,n_batches,offset) def __getitem__(self,i): - return FieldsSubsetDataSet(self.src[i],self.new_fieldnames) +# return FieldsSubsetDataSet(self.src[i],self.new_fieldnames) + complete_example = self.src[i] + return Example(self.new_fieldnames, + [complete_example[field] + for field in self.src_fieldnames]) + class DataSetFields(Example):