Mercurial > pylearn
diff dataset.py @ 37:73c4212ba5b3
Factored the minibatch-writing code into an iterator class inside DataSet
author | bengioy@esprit.iro.umontreal.ca |
---|---|
date | Thu, 24 Apr 2008 12:03:06 -0400 |
parents | 438440ba0627 |
children | d637ad8f7352 |
line wrap: on
line diff
--- a/dataset.py Tue Apr 22 18:03:11 2008 -0400 +++ b/dataset.py Thu Apr 24 12:03:06 2008 -0400 @@ -98,7 +98,7 @@ * __len__ if it is not a stream * __getitem__ may not be feasible with some streams * fieldNames - * minibatches + * minibatches_nowrap (called by DataSet.minibatches()) * valuesHStack * valuesVStack For efficiency of implementation, a sub-class might also want to redefine @@ -142,13 +142,66 @@ """ return DataSet.MinibatchToSingleExampleIterator(self.minibatches(None, minibatch_size = 1)) + + class MinibatchWrapAroundIterator(object): + """ + An iterator for minibatches that handles the case where we need to wrap around the + dataset because n_batches*minibatch_size > len(dataset). It is constructed from + a dataset that provides a minibatch iterator that does not need to handle that problem. + This class is a utility for dataset subclass writers, so that they do not have to handle + this issue multiple times. + """ + def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): + self.dataset=dataset + self.fieldnames=fieldnames + self.minibatch_size=minibatch_size + self.n_batches=n_batches + self.n_batches_done=0 + self.next_row=offset + self.L=len(dataset) + assert offset+minibatch_size<=self.L + ds_nbatches = (self.L-offset)/minibatch_size + if n_batches is not None: + ds_nbatches = max(n_batches,ds_nbatches) + self.iterator = dataset.minibatches_nowrap(fieldnames,minibatch_size,ds_nbatches,offset) + + def __iter__(self): + return self + + def next_index(self): + return self.next_row + + def next(self): + if self.n_batches and self.n_batches_done==self.n_batches: + raise StopIteration + upper = self.next_row+minibatch_size + if upper <=self.L: + minibatch = self.minibatch_iterator.next() + else: + if not self.n_batches: + raise StopIteration + # we must concatenate (vstack) the bottom and top parts of our minibatch + # first get the beginning of our minibatch (top of dataset) + first_part = self.dataset.minibatches_nowrap(fieldnames,self.L-self.next_row,1,self.next_row).next() + second_part = self.dataset.minibatches_nowrap(fieldnames,upper-self.L,1,0).next() + minibatch = Example(self.fieldnames, + [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) + for name in self.fieldnames]) + self.next_row=upper + self.n_batches_done+=1 + if upper >= L: + self.next_row -= L + return minibatch + + minibatches_fieldnames = None minibatches_minibatch_size = 1 minibatches_n_batches = None def minibatches(self, - fieldnames = minibatches_fieldnames, - minibatch_size = minibatches_minibatch_size, - n_batches = minibatches_n_batches): + fieldnames = minibatches_fieldnames, + minibatch_size = minibatches_minibatch_size, + n_batches = minibatches_n_batches, + offset = 0): """ Return an iterator that supports three forms of syntax: @@ -193,13 +246,29 @@ the derived class can choose a default. If (-1), then the returned iterator should support looping indefinitely. + - offset (integer, default 0) + The iterator will start at example 'offset' in the dataset, rather than the default. + Note: A list-like container is something like a tuple, list, numpy.ndarray or any other object that supports integer indexing and slicing. """ + return MinibatchWrapAroundIterator(self,fieldnames,minibatch_size,n_batches,offset) + + def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): + """ + This is the minibatches iterator generator that sub-classes must define. + It does not need to worry about wrapping around multiple times across the dataset, + as this is handled by MinibatchWrapAroundIterator when DataSet.minibatches() is called. + The next() method of the returned iterator does not even need to worry about + the termination condition (as StopIteration will be raised by DataSet.minibatches + before an improper call to minibatches_nowrap's next() is made). + That next() method can assert that its next row will always be within [0,len(dataset)). + The iterator returned by minibatches_nowrap does not need to implement + a next_index() method either, as this will be provided by MinibatchWrapAroundIterator. + """ raise AbstractFunction() - def __len__(self): """ len(dataset) returns the number of examples in the dataset. @@ -358,9 +427,10 @@ """ def __init__(self,dataset,*fieldnames): self.dataset=dataset - assert dataset.hasField(*fieldnames) + assert dataset.hasFields(*fieldnames) LookupList.__init__(self,dataset.fieldNames(), - dataset.minibatches(fieldnames if len(fieldnames)>0 else self.fieldNames(),minibatch_size=len(dataset)).next() + dataset.minibatches(fieldnames if len(fieldnames)>0 else self.fieldNames(), + minibatch_size=len(dataset)).next() def examples(self): return self.dataset @@ -378,6 +448,7 @@ """ return (self.examples() | other.examples()).fields() + class MinibatchDataSet(DataSet): """ Turn a LookupList of same-length fields into an example-iterable dataset. @@ -407,52 +478,32 @@ def fieldNames(self): return self.fields.keys() - def hasField(self,*fieldnames): + def hasFields(self,*fieldnames): for fieldname in fieldnames: if fieldname not in self.fields: return False return True - def minibatches(self, - fieldnames = minibatches_fieldnames, - minibatch_size = minibatches_minibatch_size, - n_batches = minibatches_n_batches): + def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): class Iterator(object): def __init__(self,ds): self.ds=ds - self.next_example=0 - self.n_batches_done=0 + self.next_example=offset assert minibatch_size > 0 - if minibatch_size > ds.length + if offset+minibatch_size > ds.length raise NotImplementedError() def __iter__(self): return self - def next_index(self): - return self.next_example def next(self): upper = next_example+minibatch_size - if upper<=self.ds.length: - minibatch = Example(self.ds.fields.keys(), - [field[next_example:upper] - for field in self.ds.fields]) - else: # we must concatenate (vstack) the bottom and top parts of our minibatch - minibatch = Example(self.ds.fields.keys(), - [self.ds.valuesVStack(name,[value[next_example:], - value[0:upper-self.ds.length]]) - for name,value in self.ds.fields.items()]) + assert upper<=self.ds.length + minibatch = Example(self.ds.fields.keys(), + [field[next_example:upper] + for field in self.ds.fields]) self.next_example+=minibatch_size - self.n_batches_done+=1 - if n_batches: - if self.n_batches_done==n_batches: - raise StopIteration - if self.next_example>=self.ds.length: - self.next_example-=self.ds.length - else: - if self.next_example>=self.ds.length: - raise StopIteration return DataSetFields(MinibatchDataSet(minibatch),fieldnames) - return Iterator(self) + return MinibatchWrapAroundIterator(self,fieldnames,minibatch_size,n_batches,offset) def valuesVStack(self,fieldname,fieldvalues): return self.values_vstack(fieldname,fieldvalues) @@ -504,7 +555,7 @@ del self.fieldname2dataset[fieldname] self.fieldname2dataset[rename_field(fieldname,self.datasets[i],i)]=i - def hasField(self,*fieldnames): + def hasFields(self,*fieldnames): for fieldname in fieldnames: if not fieldname in self.fieldname2dataset: return False @@ -513,10 +564,11 @@ def fieldNames(self): return self.fieldname2dataset.keys() - def minibatches(self, - fieldnames = minibatches_fieldnames, - minibatch_size = minibatches_minibatch_size, - n_batches = minibatches_n_batches): + def minibatches_nowrap(self, + fieldnames = minibatches_fieldnames, + minibatch_size = minibatches_minibatch_size, + n_batches = minibatches_n_batches, + offset = 0): class Iterator(object): def __init__(self,hsds,iterators): @@ -524,8 +576,6 @@ self.iterators=iterators def __iter__(self): return self - def next_index(self): - return self.iterators[0].next_index() def next(self): # concatenate all the fields of the minibatches minibatch = reduce(LookupList.__add__,[iterator.next() for iterator in self.iterators]) @@ -545,11 +595,11 @@ datasets.add(dataset) fields_in_dataset[dataset].append(fieldname) datasets=list(datasets) - iterators=[dataset.minibatches(fields_in_dataset[dataset],minibatch_size,n_batches) + iterators=[dataset.minibatches(fields_in_dataset[dataset],minibatch_size,n_batches,offset) for dataset in datasets] else: datasets=self.datasets - iterators=[dataset.minibatches(None,minibatch_size,n_batches) for dataset in datasets] + iterators=[dataset.minibatches(None,minibatch_size,n_batches,offset) for dataset in datasets] return Iterator(self,iterators) @@ -577,18 +627,87 @@ self.datasets=datasets self.length=0 self.index2dataset={} - # we use this map from row index to dataset index for constant-time random access of examples, - # to avoid having to search for the appropriate dataset each time and slice is asked for + assert len(datasets)>0 + fieldnames = datasets[-1].fieldNames() + # We use this map from row index to dataset index for constant-time random access of examples, + # to avoid having to search for the appropriate dataset each time and slice is asked for. for dataset,k in enumerate(datasets[0:-1]): L=len(dataset) assert L<DataSet.infinity for i in xrange(L): self.index2dataset[self.length+i]=k self.length+=L + assert dataset.fieldNames()==fieldnames self.last_start=self.length self.length+=len(datasets[-1]) - - + # If length is very large, we should use a more memory-efficient mechanism + # that does not store all indices + if self.length>1000000: + # 1 million entries would require about 60 meg for the index2dataset map + # TODO + print "A more efficient mechanism for index2dataset should be implemented" + + def __len__(self): + return self.length + + def fieldNames(self): + return self.datasets[0].fieldNames() + + def hasFields(self,*fieldnames): + return self.datasets[0].hasFields(*fieldnames) + + def minibatches_nowrap(self, + fieldnames = minibatches_fieldnames, + minibatch_size = minibatches_minibatch_size, + n_batches = minibatches_n_batches, + offset = 0): + class Iterator(object): + def __init__(self,vsds): + self.vsds=vsds + self.next_row=offset + self.next_dataset_index=0 + self.next_dataset_row=0 + self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \ + self.next_iterator(vsds.datasets[0],offset,n_batches) + + def next_iterator(self,dataset,starting_offset,batches_left): + L=len(dataset) + ds_nbatches = (L-starting_offset)/minibatch_size + if batches_left is not None: + ds_nbatches = max(batches_left,ds_nbatches) + if minibatch_size>L: + ds_minibatch_size=L + n_left_in_mb=minibatch_size-L + else: n_left_in_mb=0 + return dataset.minibatches(fieldnames,minibatch_size,ds_nbatches,starting_offset), \ + L-(starting_offset+ds_nbatches*minibatch_size), n_left_in_mb + + def move_to_next_dataset(self): + self.next_dataset_index +=1 + if self.next_dataset_index==len(self.vsds.datasets): + self.next_dataset_index = 0 + self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \ + self.next_iterator(vsds.datasets[self.next_dataset_index],starting_offset,n_batches) + + def __iter__(self): + return self + + def next(self): + dataset=self.vsds.datasets[self.next_dataset_index] + mb = self.next_iterator.next() + if self.n_left_in_mb: + names=self.vsds.fieldNames() + extra_mb = [] + while self.n_left_in_mb>0: + self.move_to_next_dataset() + extra_mb.append(self.next_iterator.next()) + mb = Example(names, + [dataset.valuesVStack(name,[mb[name]]+[b[name] for b in extra_mb]) + for name in names]) + self.next_row+=minibatch_size + return mb + + def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): """ Wraps an arbitrary DataSet into one for supervised learning tasks by forcing the