Mercurial > pylearn
changeset 38:d637ad8f7352
Finished first untested version of VStackedDataset
author | bengioy@esprit.iro.umontreal.ca |
---|---|
date | Thu, 24 Apr 2008 13:25:39 -0400 |
parents | 73c4212ba5b3 |
children | c682c6e9bf93 |
files | dataset.py |
diffstat | 1 files changed, 34 insertions(+), 12 deletions(-) [+] |
line wrap: on
line diff
--- a/dataset.py Thu Apr 24 12:03:06 2008 -0400 +++ b/dataset.py Thu Apr 24 13:25:39 2008 -0400 @@ -149,7 +149,8 @@ dataset because n_batches*minibatch_size > len(dataset). It is constructed from a dataset that provides a minibatch iterator that does not need to handle that problem. This class is a utility for dataset subclass writers, so that they do not have to handle - this issue multiple times. + this issue multiple times, nor check that fieldnames are valid, nor handle the + empty fieldnames (meaning 'use all the fields'). """ def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): self.dataset=dataset @@ -163,6 +164,10 @@ ds_nbatches = (self.L-offset)/minibatch_size if n_batches is not None: ds_nbatches = max(n_batches,ds_nbatches) + if fieldnames: + assert dataset.hasFields(*fieldnames) + else: + fieldnames=dataset.fieldNames() self.iterator = dataset.minibatches_nowrap(fieldnames,minibatch_size,ds_nbatches,offset) def __iter__(self): @@ -629,6 +634,7 @@ self.index2dataset={} assert len(datasets)>0 fieldnames = datasets[-1].fieldNames() + self.datasets_start_row=[] # We use this map from row index to dataset index for constant-time random access of examples, # to avoid having to search for the appropriate dataset each time and slice is asked for. for dataset,k in enumerate(datasets[0:-1]): @@ -636,9 +642,10 @@ assert L<DataSet.infinity for i in xrange(L): self.index2dataset[self.length+i]=k + self.datasets_start_row.append(self.length) self.length+=L assert dataset.fieldNames()==fieldnames - self.last_start=self.length + self.datasets_start_row.append(self.length) self.length+=len(datasets[-1]) # If length is very large, we should use a more memory-efficient mechanism # that does not store all indices @@ -656,17 +663,23 @@ def hasFields(self,*fieldnames): return self.datasets[0].hasFields(*fieldnames) + def locate_row(self,row): + """Return (dataset_index, row_within_dataset) for global row number""" + dataset_index = self.index2dataset[row] + row_within_dataset = self.datasets_start_row[dataset_index] + return dataset_index, row_within_dataset + def minibatches_nowrap(self, fieldnames = minibatches_fieldnames, minibatch_size = minibatches_minibatch_size, n_batches = minibatches_n_batches, offset = 0): + class Iterator(object): def __init__(self,vsds): self.vsds=vsds self.next_row=offset - self.next_dataset_index=0 - self.next_dataset_row=0 + self.next_dataset_index,self.next_dataset_row=self.vsds.locate_row(offset) self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \ self.next_iterator(vsds.datasets[0],offset,n_batches) @@ -678,16 +691,23 @@ if minibatch_size>L: ds_minibatch_size=L n_left_in_mb=minibatch_size-L - else: n_left_in_mb=0 + ds_nbatches=1 + else: + n_left_in_mb=0 return dataset.minibatches(fieldnames,minibatch_size,ds_nbatches,starting_offset), \ L-(starting_offset+ds_nbatches*minibatch_size), n_left_in_mb def move_to_next_dataset(self): - self.next_dataset_index +=1 - if self.next_dataset_index==len(self.vsds.datasets): - self.next_dataset_index = 0 - self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \ - self.next_iterator(vsds.datasets[self.next_dataset_index],starting_offset,n_batches) + if self.n_left_at_the_end_of_ds>0: + self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \ + self.next_iterator(vsds.datasets[self.next_dataset_index], + self.n_left_at_the_end_of_ds,1) + else: + self.next_dataset_index +=1 + if self.next_dataset_index==len(self.vsds.datasets): + self.next_dataset_index = 0 + self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \ + self.next_iterator(vsds.datasets[self.next_dataset_index],starting_offset,n_batches) def __iter__(self): return self @@ -696,15 +716,17 @@ dataset=self.vsds.datasets[self.next_dataset_index] mb = self.next_iterator.next() if self.n_left_in_mb: - names=self.vsds.fieldNames() extra_mb = [] while self.n_left_in_mb>0: self.move_to_next_dataset() extra_mb.append(self.next_iterator.next()) mb = Example(names, [dataset.valuesVStack(name,[mb[name]]+[b[name] for b in extra_mb]) - for name in names]) + for name in fieldnames]) self.next_row+=minibatch_size + self.next_dataset_row+=minibatch_size + if self.next_row+minibatch_size>len(dataset): + self.move_to_next_dataset() return mb