pylearn: changeset 6:d5738b79089a
Removed MinibatchIterator and instead made minibatch_size a field of all DataSets,
so that they can all optionally iterate over minibatches.
author    bengioy@bengiomac.local
date      Mon, 24 Mar 2008 09:04:06 -0400
parents   8039918516fe
children  6f8f338686db
files     _test_dataset.py dataset.py
diffstat  2 files changed, 57 insertions(+), 52 deletions(-)
--- a/_test_dataset.py	Sun Mar 23 22:44:43 2008 -0400
+++ b/_test_dataset.py	Mon Mar 24 09:04:06 2008 -0400
@@ -13,12 +13,16 @@
        numpy.random.seed(123456)

    def test0(self):
-        a=ArrayDataSet(data=numpy.random.rand(8,3),fields={"x":slice(2),"y":slice(1,3)})
+        a=ArrayDataSet(data=numpy.random.rand(8,3),fields={"x":slice(2),"y":slice(1,3)},minibatch_size=1)
        s=0
        for example in a:
            s+=_sum_all(example.x)
        print s
        self.failUnless(abs(s-11.4674133)<1e-6)
+        a.minibatch_size=2
+        for mb in a:
+            print mb
+

if __name__ == '__main__':
    unittest.main()
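For orientation, the pattern exercised by the updated test can be written out as follows. This is a usage sketch only (Python 2, as in this revision), not part of the changeset; it assumes dataset.py at the repository root is importable.

    import numpy
    from dataset import ArrayDataSet

    numpy.random.seed(123456)
    # An 8x3 array with two named (overlapping) column-slice fields.
    a = ArrayDataSet(data=numpy.random.rand(8,3),
                     fields={"x":slice(2),"y":slice(1,3)},
                     minibatch_size=1)

    # With minibatch_size=1 the iterator yields one example at a time;
    # each example exposes the same fields (example.x, example.y).
    for example in a:
        print example.x

    # Any such dataset can switch to minibatch iteration by changing the field,
    # exactly as the test above does.
    a.minibatch_size = 2
    for mb in a:
        print mb   # a minibatch covering 2 consecutive rows, with the same fields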
--- a/dataset.py	Sun Mar 23 22:44:43 2008 -0400
+++ b/dataset.py	Mon Mar 24 09:04:06 2008 -0400
@@ -6,20 +6,31 @@
    A dataset is basically an iterator over examples. It does not necessarily
    have a fixed length (this is useful for 'streams' which feed on-line learning).
    Datasets with fixed and known length are FiniteDataSet, a subclass of DataSet.
-    Examples and datasets have named fields.
+    Examples and datasets optionally have named fields.
    One can obtain a sub-dataset by taking dataset.field or dataset(field1,field2,field3,...).
    Fields are not mutually exclusive, i.e. two fields can overlap in their actual content.
-    The content of a field can be of any type, but often will be a numpy tensor.
+    The content of a field can be of any type, but often will be a numpy array.
+    The minibatch_size field, if different than 1, means that the iterator (next() method)
+    returns not a single example but an array of length minibatch_size, i.e., an indexable
+    object.
    """

-    def __init__(self):
-        pass
+    def __init__(self,minibatch_size=1):
+        assert minibatch_size>0
+        self.minibatch_size=minibatch_size

    def __iter__(self):
        return self

    def next(self):
-        """Return the next example in the dataset."""
+        """
+        Return the next example or the next minibatch in the dataset.
+        A minibatch (of length > 1) should be something one can iterate on again in order
+        to obtain the individual examples. If the dataset has fields,
+        then the example or the minibatch must have the same fields
+        (typically this is implemented by returning another (small) dataset, when
+        there are fields).
+        """
        raise NotImplementedError

    def __getattr__(self,fieldname):
@@ -41,8 +52,8 @@
    and a subdataset can be obtained by slicing.
    """

-    def __init__(self):
-        pass
+    def __init__(self,minibatch_size):
+        DataSet.__init__(self,minibatch_size)

    def __len__(self):
        """len(dataset) returns the number of examples in the dataset."""
@@ -56,49 +67,27 @@
        """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1."""
        raise NotImplementedError

-    def minibatches(self,minibatch_size):
-        """Return an iterator for the dataset that goes through minibatches of the given size."""
-        return MinibatchIterator(self,minibatch_size)
-
-class MinibatchIterator(object):
-    """
-    Iterator class for FiniteDataSet that can iterate by minibatches
-    (sub-dataset of consecutive examples).
-    """
-    def __init__(self,dataset,minibatch_size):
-        assert minibatch_size>0 and minibatch_size<len(dataset)
-        self.dataset=dataset
-        self.minibatch_size=minibatch_size
-        self.current=-minibatch_size
-    def __iter__(self):
-        return self
-    def next(self):
-        self.current+=self.minibatch_size
-        if self.current>=len(self.dataset):
-            self.current=-self.minibatchsize
-            raise StopIteration
-        return self.dataset[self.current:self.current+self.minibatchsize]
-
# we may want ArrayDataSet defined in another python file

import numpy

class ArrayDataSet(FiniteDataSet):
    """
-    A fixed-length and fixed-width dataset in which each element is a numpy.array
-    or a number, hence the whole dataset corresponds to a numpy.array. Fields
+    A fixed-length and fixed-width dataset in which each element is a numpy array
+    or a number, hence the whole dataset corresponds to a numpy array. Fields
    must correspond to a slice of columns. If the dataset has fields,
-    each 'example' is just a one-row ArrayDataSet, otherwise it is a numpy.array.
-    Any dataset can also be converted to a numpy.array (losing the notion of fields)
+    each 'example' is just a one-row ArrayDataSet, otherwise it is a numpy array.
+    Any dataset can also be converted to a numpy array (losing the notion of fields)
    by the asarray(dataset) call.
    """

-    def __init__(self,dataset=None,data=None,fields={}):
+    def __init__(self,dataset=None,data=None,fields={},minibatch_size=1):
        """
        Construct an ArrayDataSet, either from a DataSet, or from
-        a numpy.array plus an optional specification of fields (by
+        a numpy array plus an optional specification of fields (by
        a dictionary of column slices indexed by field names).
        """
+        FiniteDataSet.__init__(self,minibatch_size)
        self.current_row=-1 # used for view of this dataset as an iterator
        if dataset!=None:
            assert data==None and fields=={}
@@ -122,28 +111,40 @@
                    fieldslice = slice(start,fieldslice.stop,step)
                # and coherent with the data array
                assert fieldslice.start>=0 and fieldslice.stop<=self.width
+        assert minibatch_size<=len(self.data)

    def next(self):
        """
-        Return the next example in the dataset. If the dataset has fields,
-        the 'example' is just a one-row ArrayDataSet, otherwise it is a numpy.array.
+        Return the next example(s) in the dataset. If self.minibatch_size>1 return that
+        many examples. If the dataset has fields, the example or the minibatch of examples
+        is just a minibatch_size-rows ArrayDataSet (so that the fields can be accessed),
+        but that resulting mini-dataset has a minibatch_size of 1, so that one can iterate
+        example-wise on it. On the other hand, if the dataset has no fields (e.g. because
+        it is already the field of a bigger dataset), then the returned example or minibatch
+        is a numpy array. Following the array semantics of indexing and slicing,
+        if the minibatch_size is 1 (and there are no fields), then the result is an array
+        with one less dimension (e.g., a vector, if the dataset is a matrix), corresponding
+        to a row. Again, if the minibatch_size is >1, one can iterate on the result to
+        obtain individual examples (as rows).
        """
        if self.fields:
-            self.current_row+=1
-            if self.current_row==len(self.data):
-                self.current_row=-1
+            self.current_row+=self.minibatch_size
+            if self.current_row>=len(self.data):
+                self.current_row=-self.minibatch_size
                raise StopIteration
-            return self[self.current_row]
+            if self.minibatch_size==1:
+                return self[self.current_row]
+            else:
+                return self[self.current_row:self.current_row+self.minibatch_size]
        else:
-            return self.data[self.current_row]
+            if self.minibatch_size==1:
+                return self.data[self.current_row]
+            else:
+                return self.data[self.current_row:self.current_row+self.minibatch_size]

    def __getattr__(self,fieldname):
-        """Return a sub-dataset containing only the given fieldname as field."""
-        data=self.data[self.fields[fieldname]]
-        if len(data)==1:
-            return data
-        else:
-            return ArrayDataSet(data=data)
+        """Return a numpy array with the content associated with the given field name."""
+        return self.data[self.fields[fieldname]]

    def __call__(self,*fieldnames):
        """Return a sub-dataset containing only the given fieldnames as fields."""
@@ -155,7 +156,7 @@
        new_fields={}
        for field in self.fields:
            new_fields[field[0]]=slice(field[1].start-min_col,field[1].stop-min_col,field[1].step)
-        return ArrayDataSet(data=self.data[:,min_col:max_col],fields=new_fields)
+        return ArrayDataSet(data=self.data[:,min_col:max_col],fields=new_fields,minibatch_size=self.minibatch_size)

    def fieldNames(self):
        """Return the list of field names that are supported by getattr and getFields."""
@@ -179,7 +180,7 @@

    def __getslice__(self,*slice_args):
        """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1."""
-        return ArrayDataSet(data=self.data[slice(slice_args)],fields=self.fields)
+        return ArrayDataSet(data=self.data[apply(slice,slice_args)],fields=self.fields)

    def asarray(self):
        if self.fields: