Mercurial > pylearn
changeset 108:695b729027d4
Merges?
author | Yoshua Bengio <bengioy@iro.umontreal.ca> |
---|---|
date | Tue, 06 May 2008 19:55:29 -0400 |
parents | c4916445e025 (current diff) cf9bdb1d9656 (diff) |
children | d97f6fe6bdf9 |
files | |
diffstat | 2 files changed, 180 insertions(+), 103 deletions(-) [+] |
line wrap: on
line diff
--- a/dataset.py Tue May 06 19:54:43 2008 -0400 +++ b/dataset.py Tue May 06 19:55:29 2008 -0400 @@ -209,16 +209,18 @@ self.n_batches=n_batches self.n_batches_done=0 self.next_row=offset + self.offset=offset self.L=len(dataset) assert offset+minibatch_size<=self.L - ds_nbatches = (self.L-offset)/minibatch_size + ds_nbatches = (self.L-self.next_row)/self.minibatch_size if n_batches is not None: - ds_nbatches = max(n_batches,ds_nbatches) + ds_nbatches = min(n_batches,ds_nbatches) if fieldnames: assert dataset.hasFields(*fieldnames) else: - fieldnames=dataset.fieldNames() - self.iterator = dataset.minibatches_nowrap(fieldnames,minibatch_size,ds_nbatches,offset) + self.fieldnames=dataset.fieldNames() + self.iterator = self.dataset.minibatches_nowrap(self.fieldnames,self.minibatch_size, + ds_nbatches,self.next_row) def __iter__(self): return self @@ -229,23 +231,35 @@ def next(self): if self.n_batches and self.n_batches_done==self.n_batches: raise StopIteration + elif not self.n_batches and self.next_row ==self.L: + raise StopIteration upper = self.next_row+self.minibatch_size if upper <=self.L: minibatch = self.iterator.next() else: if not self.n_batches: - raise StopIteration - # we must concatenate (vstack) the bottom and top parts of our minibatch - # first get the beginning of our minibatch (top of dataset) - first_part = self.dataset.minibatches_nowrap(fieldnames,self.L-self.next_row,1,self.next_row).next() - second_part = self.dataset.minibatches_nowrap(fieldnames,upper-self.L,1,0).next() - minibatch = Example(self.fieldnames, - [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) - for name in self.fieldnames]) + upper=min(upper, self.L) + # if their is not a fixed number of batch, we continue to the end of the dataset. + # this can create a minibatch that is smaller then the minibatch_size + assert (self.L-self.next_row)<=self.minibatch_size + minibatch = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next() + else: + # we must concatenate (vstack) the bottom and top parts of our minibatch + # first get the beginning of our minibatch (top of dataset) + first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next() + second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next() + minibatch = Example(self.fieldnames, + [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) + for name in self.fieldnames]) self.next_row=upper self.n_batches_done+=1 if upper >= self.L and self.n_batches: self.next_row -= self.L + ds_nbatches = (self.L-self.next_row)/self.minibatch_size + if self.n_batches is not None: + ds_nbatches = min(self.n_batches,ds_nbatches) + self.iterator = self.dataset.minibatches_nowrap(self.fieldnames,self.minibatch_size, + ds_nbatches,self.next_row) return DataSetFields(MinibatchDataSet(minibatch,self.dataset.valuesVStack, self.dataset.valuesHStack), minibatch.keys()) @@ -919,6 +933,7 @@ for fieldname, fieldcolumns in self.fields_columns.items(): if type(fieldcolumns) is int: assert fieldcolumns>=0 and fieldcolumns<data_array.shape[1] + self.fields_columns[fieldname]=[fieldcolumns] elif type(fieldcolumns) is slice: start,step=None,None if not fieldcolumns.start: @@ -962,7 +977,7 @@ # else check for a fieldname if self.hasFields(key): - return self.data[self.fields_columns[key],:] + return self.data[:,self.fields_columns[key]] # else we are trying to access a property of the dataset assert key in self.__dict__ # else it means we are trying to access a non-existing property return self.__dict__[key]
--- a/test_dataset.py Tue May 06 19:54:43 2008 -0400 +++ b/test_dataset.py Tue May 06 19:55:29 2008 -0400 @@ -35,9 +35,10 @@ def test_ArrayDataSet(): #don't test stream #tested only with float value - #test with y too - #test missing value - + #don't always test with y + #don't test missing value + #don't test with tuple + #don't test proterties def test_iterate_over_examples(array,ds): #not in doc!!! i=0 @@ -112,80 +113,107 @@ assert i==len(ds) del x,y,i + def test_minibatch_size(minibatch,minibatch_size,len_ds,nb_field,nb_iter_finished): + ##full minibatch or the last minibatch + for idx in range(nb_field): + test_minibatch_field_size(minibatch[idx],minibatch_size,len_ds,nb_iter_finished) + del idx + def test_minibatch_field_size(minibatch_field,minibatch_size,len_ds,nb_iter_finished): + assert len(minibatch_field)==minibatch_size or ((nb_iter_finished*minibatch_size+len(minibatch_field))==len_ds and len(minibatch_field)<minibatch_size) # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): i=0 - for minibatch in ds.minibatches(['x','z'], minibatch_size=3): + mi=0 + m=ds.minibatches(['x','z'], minibatch_size=3) + assert isinstance(m,DataSet.MinibatchWrapAroundIterator) + for minibatch in m: assert len(minibatch)==2 - assert len(minibatch[0])==3 - assert len(minibatch[1])==3 + test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) assert (minibatch[0][:,0:3:2]==minibatch[1]).all() - i+=1 - #assert i==#??? What shoud be the value? - print i - del minibatch,i + mi+=1 + i+=len(minibatch[0]) + assert i==len(ds) + assert mi==4 + del minibatch,i,m,mi + i=0 - for minibatch in ds.minibatches(['x','y'], minibatch_size=3): + mi=0 + m=ds.minibatches(['x','y'], minibatch_size=3) + assert isinstance(m,DataSet.MinibatchWrapAroundIterator) + for minibatch in m: assert len(minibatch)==2 - assert len(minibatch[0])==3 - assert len(minibatch[1])==3 - for id in range(3): + test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) + mi+=1 + for id in range(len(minibatch[0])): assert (numpy.append(minibatch[0][id],minibatch[1][id])==a[i]).all() i+=1 - #assert i==#??? What shoud be the value? - print i - del minibatch,i,id + assert i==len(ds) + assert mi==4 + del minibatch,i,id,m,mi # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): i=0 - for x,z in ds.minibatches(['x','z'], minibatch_size=3): - assert len(x)==3 - assert len(z)==3 + mi=0 + m=ds.minibatches(['x','z'], minibatch_size=3) + assert isinstance(m,DataSet.MinibatchWrapAroundIterator) + for x,z in m: + test_minibatch_field_size(x,m.minibatch_size,len(ds),mi) + test_minibatch_field_size(z,m.minibatch_size,len(ds),mi) assert (x[:,0:3:2]==z).all() - i+=1 - #assert i==#??? What shoud be the value? - print i - del x,z,i + i+=len(x) + mi+=1 + assert i==len(ds) + assert mi==4 + del x,z,i,m,mi i=0 - for x,y in ds.minibatches(['x','y'], minibatch_size=3): - assert len(x)==3 - assert len(y)==3 - for id in range(3): + mi=0 + m=ds.minibatches(['x','y'], minibatch_size=3) + for x,y in m: + test_minibatch_field_size(x,m.minibatch_size,len(ds),mi) + test_minibatch_field_size(y,m.minibatch_size,len(ds),mi) + mi+=1 + for id in range(len(x)): assert (numpy.append(x[id],y[id])==a[i]).all() i+=1 - #assert i==#??? What shoud be the value? - print i - del x,y,i,id + assert i==len(ds) + assert mi==4 + del x,y,i,id,m,mi #not in doc i=0 - for x,y in ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4): + m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4) + assert isinstance(m,DataSet.MinibatchWrapAroundIterator) + for x,y in m: assert len(x)==3 assert len(y)==3 for id in range(3): assert (numpy.append(x[id],y[id])==a[i+4]).all() i+=1 assert i==3 - del x,y,i,id + del x,y,i,id,m i=0 - for x,y in ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4): + m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4) + assert isinstance(m,DataSet.MinibatchWrapAroundIterator) + for x,y in m: assert len(x)==3 assert len(y)==3 for id in range(3): assert (numpy.append(x[id],y[id])==a[i+4]).all() i+=1 assert i==6 - del x,y,i,id + del x,y,i,id,m i=0 - for x,y in ds.minibatches(['x','y'],n_batches=10,minibatch_size=3,offset=4): + m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=3,offset=4) + assert isinstance(m,DataSet.MinibatchWrapAroundIterator) + for x,y in m: assert len(x)==3 assert len(y)==3 for id in range(3): - assert (numpy.append(x[id],y[id])==a[i+4]).all() + assert (numpy.append(x[id],y[id])==a[(i+4)%a.shape[0]]).all() i+=1 - assert i==6 + assert i==m.n_batches*m.minibatch_size del x,y,i,id @@ -212,21 +240,102 @@ i+=1 assert i==len(ds) + def test_getitem(array,ds): + + def test_ds(orig,ds,index): + i=0 + assert len(ds)==len(index) + for x,z,y in ds('x','z','y'): + assert (orig[index[i]]['x']==array[index[i]][:3]).all() + assert (orig[index[i]]['x']==x).all() + assert orig[index[i]]['y']==array[index[i]][3] + assert orig[index[i]]['y']==y + assert (orig[index[i]]['z']==array[index[i]][0:3:2]).all() + assert (orig[index[i]]['z']==z).all() + i+=1 + del i + ds[0] + if len(ds)>2: + ds[:1] + ds[1:1] + ds[1:1:1] + if len(ds)>5: + ds[[1,2,3]] + for x in ds: + pass + + #ds[:n] returns a dataset with the n first examples. + ds2=ds[:3] + assert isinstance(ds2,DataSet) + test_ds(ds,ds2,index=[0,1,2]) + del ds2 + + #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s. + ds2=ds[1:7:2] + assert isinstance(ds2,DataSet) + test_ds(ds,ds2,[1,3,5]) + del ds2 + + #ds[i] + ds2=ds[5] + assert isinstance(ds2,Example) + assert have_raised("ds["+str(len(ds))+"]") # index not defined + assert not have_raised("ds["+str(len(ds)-1)+"]") + del ds2 + + #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in. + ds2=ds[[4,7,2,8]] + assert isinstance(ds2,DataSet) + test_ds(ds,ds2,[4,7,2,8]) + del ds2 + + #ds[fieldname]# an iterable over the values of the field fieldname across + #the ds (the iterable is obtained by default by calling valuesVStack + #over the values for individual examples). + assert have_raised("ds['h']") # h is not defined... + assert have_raised("ds[['x']]") # bad syntax + assert not have_raised("ds['x']") + isinstance(ds['x'],DataSetFields) + ds2=ds['x'] + assert len(ds['x'])==10 + assert len(ds['y'])==10 + assert len(ds['z'])==10 + i=0 + for example in ds['x']: + assert (example==a[i][:3]).all() + i+=1 + i=0 + for example in ds['y']: + assert (example==a[i][3]).all() + i+=1 + i=0 + for example in ds['z']: + assert (example==a[i,0:3:2]).all() + i+=1 + del ds2,i + + #ds.<property># returns the value of a property associated with + #the name <property>. The following properties should be supported: + # - 'description': a textual description or name for the ds + # - 'fieldtypes': a list of types (one per field) + + #* ds1 | ds2 | ds3 == ds.hstack([ds1,ds2,ds3])#???? + #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])#???? + + print "test_ArrayDataSet" a = numpy.random.rand(10,4) ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested ds = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested assert len(ds)==10 #assert ds==a? should this work? - + test_iterate_over_examples(a, ds) - + test_getitem(a, ds) # - for val1,val2,val3 in dataset(field1, field2,field3): test_ds_iterator(a,ds('x','y'),ds('y','z'),ds('x','y','z')) - assert have_raised("ds['h']") # h is not defined... - assert have_raised("ds[['h']]") # h is not defined... assert len(ds.fields())==3 for field in ds.fields(): @@ -241,54 +350,7 @@ pass assert ds == ds.fields().examples() - - def test_ds(orig,ds,index): - i=0 - assert len(ds)==len(index) - for x,z,y in ds('x','z','y'): - assert (orig[index[i]]['x']==a[index[i]][:3]).all() - assert (orig[index[i]]['x']==x).all() - assert orig[index[i]]['y']==a[index[i]][3] - assert orig[index[i]]['y']==y - assert (orig[index[i]]['z']==a[index[i]][0:3:2]).all() - assert (orig[index[i]]['z']==z).all() - i+=1 - del i - ds[0] - if len(ds)>2: - ds[:1] - ds[1:1] - ds[1:1:1] - if len(ds)>5: - ds[[1,2,3]] - for x in ds: - pass - - #ds[:n] returns a dataset with the n first examples. - ds2=ds[:3] - test_ds(ds,ds2,index=[0,1,2]) - - #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s. - ds2=ds[1:7:2] - test_ds(ds,ds2,[1,3,5]) - - #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in. - ds2=ds[[4,7,2,8]] - test_ds(ds,ds2,[4,7,2,8]) - #ds[i1,i2,...]# should we accept???? - - #ds[fieldname]# an iterable over the values of the field fieldname across - #the ds (the iterable is obtained by default by calling valuesVStack - #over the values for individual examples). - - #ds.<property># returns the value of a property associated with - #the name <property>. The following properties should be supported: - # - 'description': a textual description or name for the ds - # - 'fieldtypes': a list of types (one per field) - #* ds1 | ds2 | ds3 == ds.hstack([ds1,ds2,ds3]) - #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3]) - -# for (x,y) in (ds('x','y'),a): #???don't work # haven't found a variant that work. +# for ((x,y),a_v) in (ds('x','y'),a): #???don't work # haven't found a variant that work.# will not work # assert numpy.append(x,y)==z def test_LookupList():