# HG changeset patch
# User bengioy@esprit.iro.umontreal.ca
# Date 1209057939 14400
# Node ID d637ad8f735261312cb82963feb8344a8bb2bc2f
# Parent 73c4212ba5b3b3b1d48f64ba8a1fb75a383b82ca
Finished first untested version of VStackedDataset

diff -r 73c4212ba5b3 -r d637ad8f7352 dataset.py
--- a/dataset.py	Thu Apr 24 12:03:06 2008 -0400
+++ b/dataset.py	Thu Apr 24 13:25:39 2008 -0400
@@ -149,7 +149,8 @@
     dataset because n_batches*minibatch_size > len(dataset). It is constructed from
     a dataset that provides a minibatch iterator that does not need to handle that problem.
     This class is a utility for dataset subclass writers, so that they do not have to handle
-    this issue multiple times.
+    this issue multiple times, nor check that fieldnames are valid, nor handle the
+    empty fieldnames (meaning 'use all the fields').
     """
     def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
         self.dataset=dataset
@@ -163,6 +164,10 @@
         ds_nbatches = (self.L-offset)/minibatch_size
         if n_batches is not None:
             ds_nbatches = max(n_batches,ds_nbatches)
+        if fieldnames:
+            assert dataset.hasFields(*fieldnames)
+        else:
+            fieldnames=dataset.fieldNames()
         self.iterator = dataset.minibatches_nowrap(fieldnames,minibatch_size,ds_nbatches,offset)

     def __iter__(self):
@@ -629,6 +634,7 @@
         self.index2dataset={}
         assert len(datasets)>0
         fieldnames = datasets[-1].fieldNames()
+        self.datasets_start_row=[]
         # We use this map from row index to dataset index for constant-time random access of examples,
         # to avoid having to search for the appropriate dataset each time and slice is asked for.
         for dataset,k in enumerate(datasets[0:-1]):
@@ -636,9 +642,10 @@
             L=len(dataset)
             assert L
             for i in xrange(L):
                 self.index2dataset[self.length+i]=k
+            self.datasets_start_row.append(self.length)
             self.length+=L
             assert dataset.fieldNames()==fieldnames
@@ -668,21 +675,28 @@
            def next_iterator(self,dataset,starting_offset,n_batches):
                L=len(dataset)
                ds_nbatches = (L-starting_offset)/minibatch_size
                if n_batches is not None:
                    ds_nbatches = max(n_batches,ds_nbatches)
                if minibatch_size>L:
                    ds_minibatch_size=L
                    n_left_in_mb=minibatch_size-L
-                else: n_left_in_mb=0
+                    ds_nbatches=1
+                else:
+                    n_left_in_mb=0
                return dataset.minibatches(fieldnames,minibatch_size,ds_nbatches,starting_offset), \
                       L-(starting_offset+ds_nbatches*minibatch_size), n_left_in_mb

            def move_to_next_dataset(self):
-                self.next_dataset_index +=1
-                if self.next_dataset_index==len(self.vsds.datasets):
-                    self.next_dataset_index = 0
-                self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \
-                  self.next_iterator(vsds.datasets[self.next_dataset_index],starting_offset,n_batches)
+                if self.n_left_at_the_end_of_ds>0:
+                    self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \
+                      self.next_iterator(self.vsds.datasets[self.next_dataset_index],
+                                         self.n_left_at_the_end_of_ds,1)
+                else:
+                    self.next_dataset_index +=1
+                    if self.next_dataset_index==len(self.vsds.datasets):
+                        self.next_dataset_index = 0
+                    self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \
+                      self.next_iterator(self.vsds.datasets[self.next_dataset_index],starting_offset,n_batches)

            def __iter__(self):
                return self
@@ -696,15 +716,17 @@

            def next(self):
                dataset=self.vsds.datasets[self.next_dataset_index]
                mb = self.next_iterator.next()
                if self.n_left_in_mb:
-                    names=self.vsds.fieldNames()
                    extra_mb = []
                    while self.n_left_in_mb>0:
                        self.move_to_next_dataset()
                        extra_mb.append(self.next_iterator.next())
-                    mb = Example(names,
+                    mb = Example(fieldnames,
                                 [dataset.valuesVStack(name,[mb[name]]+[b[name] for b in extra_mb])
-                                  for name in names])
+                                  for name in fieldnames])
                self.next_row+=minibatch_size
+                self.next_dataset_row+=minibatch_size
+                if self.next_row+minibatch_size>len(dataset):
+                    self.move_to_next_dataset()
                return mb
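
The first two hunks concern wrapping minibatches around the end of a finite
dataset. Below is a minimal self-contained sketch of that idea in current
Python, not the patched pylearn API; wraparound_minibatches is an illustrative
name, not from the source.

def wraparound_minibatches(rows, minibatch_size, n_batches, offset=0):
    """Yield n_batches lists of minibatch_size rows, wrapping back to the
    start of rows when n_batches*minibatch_size exceeds len(rows)."""
    n = len(rows)
    assert 0 <= offset < n
    assert minibatch_size <= n, "one minibatch cannot exceed the whole dataset"
    pos = offset
    for _ in range(n_batches):
        end = pos + minibatch_size
        if end <= n:
            batch = rows[pos:end]
        else:
            # wrap: stitch the tail of the dataset to its head
            batch = rows[pos:] + rows[:end - n]
        yield batch
        pos = end % n

With 10 rows and 4 batches of 3, the last batch wraps past the end:
list(wraparound_minibatches(list(range(10)), 3, 4)) yields
[0, 1, 2], [3, 4, 5], [6, 7, 8], and finally [9, 0, 1].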
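
The remaining hunks flesh out VStackedDataset: several datasets with identical
fields are viewed as one, a row-index-to-dataset map gives fast random access,
and a minibatch that straddles a dataset boundary is stitched together field
by field (the role valuesVStack plays in the patch). A hedged sketch of the
same design, with a dict of field -> column standing in for the DataSet
interface; VStacked, start_row, and _locate are illustrative names, not the
pylearn API, and binary search replaces the patch's precomputed index2dataset.

from bisect import bisect_right

class VStacked:
    def __init__(self, datasets):
        assert datasets
        self.fieldnames = list(datasets[0].keys())
        assert all(list(d.keys()) == self.fieldnames for d in datasets)
        self.datasets = datasets
        # start row of each dataset in the stacked view (cf. datasets_start_row)
        self.start_row = []
        total = 0
        for d in datasets:
            self.start_row.append(total)
            total += len(d[self.fieldnames[0]])
        self.length = total

    def __len__(self):
        return self.length

    def _locate(self, row):
        # O(log n) variant of the patch's constant-time index2dataset map
        k = bisect_right(self.start_row, row) - 1
        return k, row - self.start_row[k]

    def _value(self, name, row):
        k, local = self._locate(row)
        return self.datasets[k][name][local]

    def minibatches(self, minibatch_size):
        for start in range(0, len(self), minibatch_size):
            rows = range(start, min(start + minibatch_size, len(self)))
            # gather each field, possibly across several underlying datasets
            # (what valuesVStack does for the library's column types)
            yield {name: [self._value(name, r) for r in rows]
                   for name in self.fieldnames}

d1 = {"x": [0, 1, 2], "y": ["a", "b", "c"]}
d2 = {"x": [3, 4], "y": ["d", "e"]}
for mb in VStacked([d1, d2]).minibatches(2):
    print(mb)

Here the second minibatch, {'x': [2, 3], 'y': ['c', 'd']}, straddles the
boundary between d1 and d2, which is exactly the case the patched next()
method handles by accumulating extra_mb across datasets.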