# HG changeset patch # User Yoshua Bengio # Date 1210118129 14400 # Node ID 695b729027d4da956dbeeb9a4979c72fa996ecd7 # Parent c4916445e0258bd897c7d7d736b73f7e24f2b0bc# Parent cf9bdb1d96561de4b07e0163d4ba64ca1d9db9a2 Merges? diff -r c4916445e025 -r 695b729027d4 dataset.py --- a/dataset.py Tue May 06 19:54:43 2008 -0400 +++ b/dataset.py Tue May 06 19:55:29 2008 -0400 @@ -209,16 +209,18 @@ self.n_batches=n_batches self.n_batches_done=0 self.next_row=offset + self.offset=offset self.L=len(dataset) assert offset+minibatch_size<=self.L - ds_nbatches = (self.L-offset)/minibatch_size + ds_nbatches = (self.L-self.next_row)/self.minibatch_size if n_batches is not None: - ds_nbatches = max(n_batches,ds_nbatches) + ds_nbatches = min(n_batches,ds_nbatches) if fieldnames: assert dataset.hasFields(*fieldnames) else: - fieldnames=dataset.fieldNames() - self.iterator = dataset.minibatches_nowrap(fieldnames,minibatch_size,ds_nbatches,offset) + self.fieldnames=dataset.fieldNames() + self.iterator = self.dataset.minibatches_nowrap(self.fieldnames,self.minibatch_size, + ds_nbatches,self.next_row) def __iter__(self): return self @@ -229,23 +231,35 @@ def next(self): if self.n_batches and self.n_batches_done==self.n_batches: raise StopIteration + elif not self.n_batches and self.next_row ==self.L: + raise StopIteration upper = self.next_row+self.minibatch_size if upper <=self.L: minibatch = self.iterator.next() else: if not self.n_batches: - raise StopIteration - # we must concatenate (vstack) the bottom and top parts of our minibatch - # first get the beginning of our minibatch (top of dataset) - first_part = self.dataset.minibatches_nowrap(fieldnames,self.L-self.next_row,1,self.next_row).next() - second_part = self.dataset.minibatches_nowrap(fieldnames,upper-self.L,1,0).next() - minibatch = Example(self.fieldnames, - [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) - for name in self.fieldnames]) + upper=min(upper, self.L) + # if their is not a fixed number of batch, we continue to the end of the dataset. + # this can create a minibatch that is smaller then the minibatch_size + assert (self.L-self.next_row)<=self.minibatch_size + minibatch = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next() + else: + # we must concatenate (vstack) the bottom and top parts of our minibatch + # first get the beginning of our minibatch (top of dataset) + first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next() + second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next() + minibatch = Example(self.fieldnames, + [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) + for name in self.fieldnames]) self.next_row=upper self.n_batches_done+=1 if upper >= self.L and self.n_batches: self.next_row -= self.L + ds_nbatches = (self.L-self.next_row)/self.minibatch_size + if self.n_batches is not None: + ds_nbatches = min(self.n_batches,ds_nbatches) + self.iterator = self.dataset.minibatches_nowrap(self.fieldnames,self.minibatch_size, + ds_nbatches,self.next_row) return DataSetFields(MinibatchDataSet(minibatch,self.dataset.valuesVStack, self.dataset.valuesHStack), minibatch.keys()) @@ -919,6 +933,7 @@ for fieldname, fieldcolumns in self.fields_columns.items(): if type(fieldcolumns) is int: assert fieldcolumns>=0 and fieldcolumns2: + ds[:1] + ds[1:1] + ds[1:1:1] + if len(ds)>5: + ds[[1,2,3]] + for x in ds: + pass + + #ds[:n] returns a dataset with the n first examples. + ds2=ds[:3] + assert isinstance(ds2,DataSet) + test_ds(ds,ds2,index=[0,1,2]) + del ds2 + + #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s. + ds2=ds[1:7:2] + assert isinstance(ds2,DataSet) + test_ds(ds,ds2,[1,3,5]) + del ds2 + + #ds[i] + ds2=ds[5] + assert isinstance(ds2,Example) + assert have_raised("ds["+str(len(ds))+"]") # index not defined + assert not have_raised("ds["+str(len(ds)-1)+"]") + del ds2 + + #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in. + ds2=ds[[4,7,2,8]] + assert isinstance(ds2,DataSet) + test_ds(ds,ds2,[4,7,2,8]) + del ds2 + + #ds[fieldname]# an iterable over the values of the field fieldname across + #the ds (the iterable is obtained by default by calling valuesVStack + #over the values for individual examples). + assert have_raised("ds['h']") # h is not defined... + assert have_raised("ds[['x']]") # bad syntax + assert not have_raised("ds['x']") + isinstance(ds['x'],DataSetFields) + ds2=ds['x'] + assert len(ds['x'])==10 + assert len(ds['y'])==10 + assert len(ds['z'])==10 + i=0 + for example in ds['x']: + assert (example==a[i][:3]).all() + i+=1 + i=0 + for example in ds['y']: + assert (example==a[i][3]).all() + i+=1 + i=0 + for example in ds['z']: + assert (example==a[i,0:3:2]).all() + i+=1 + del ds2,i + + #ds.# returns the value of a property associated with + #the name . The following properties should be supported: + # - 'description': a textual description or name for the ds + # - 'fieldtypes': a list of types (one per field) + + #* ds1 | ds2 | ds3 == ds.hstack([ds1,ds2,ds3])#???? + #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])#???? + + print "test_ArrayDataSet" a = numpy.random.rand(10,4) ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested ds = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested assert len(ds)==10 #assert ds==a? should this work? - + test_iterate_over_examples(a, ds) - + test_getitem(a, ds) # - for val1,val2,val3 in dataset(field1, field2,field3): test_ds_iterator(a,ds('x','y'),ds('y','z'),ds('x','y','z')) - assert have_raised("ds['h']") # h is not defined... - assert have_raised("ds[['h']]") # h is not defined... assert len(ds.fields())==3 for field in ds.fields(): @@ -241,54 +350,7 @@ pass assert ds == ds.fields().examples() - - def test_ds(orig,ds,index): - i=0 - assert len(ds)==len(index) - for x,z,y in ds('x','z','y'): - assert (orig[index[i]]['x']==a[index[i]][:3]).all() - assert (orig[index[i]]['x']==x).all() - assert orig[index[i]]['y']==a[index[i]][3] - assert orig[index[i]]['y']==y - assert (orig[index[i]]['z']==a[index[i]][0:3:2]).all() - assert (orig[index[i]]['z']==z).all() - i+=1 - del i - ds[0] - if len(ds)>2: - ds[:1] - ds[1:1] - ds[1:1:1] - if len(ds)>5: - ds[[1,2,3]] - for x in ds: - pass - - #ds[:n] returns a dataset with the n first examples. - ds2=ds[:3] - test_ds(ds,ds2,index=[0,1,2]) - - #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s. - ds2=ds[1:7:2] - test_ds(ds,ds2,[1,3,5]) - - #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in. - ds2=ds[[4,7,2,8]] - test_ds(ds,ds2,[4,7,2,8]) - #ds[i1,i2,...]# should we accept???? - - #ds[fieldname]# an iterable over the values of the field fieldname across - #the ds (the iterable is obtained by default by calling valuesVStack - #over the values for individual examples). - - #ds.# returns the value of a property associated with - #the name . The following properties should be supported: - # - 'description': a textual description or name for the ds - # - 'fieldtypes': a list of types (one per field) - #* ds1 | ds2 | ds3 == ds.hstack([ds1,ds2,ds3]) - #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3]) - -# for (x,y) in (ds('x','y'),a): #???don't work # haven't found a variant that work. +# for ((x,y),a_v) in (ds('x','y'),a): #???don't work # haven't found a variant that work.# will not work # assert numpy.append(x,y)==z def test_LookupList():