comparison dataset.py @ 38:d637ad8f7352

Finished first untested version of VStackedDataset
author bengioy@esprit.iro.umontreal.ca
date Thu, 24 Apr 2008 13:25:39 -0400
parents 73c4212ba5b3
children c682c6e9bf93
comparison
equal deleted inserted replaced
37:73c4212ba5b3 38:d637ad8f7352
147 """ 147 """
148 An iterator for minibatches that handles the case where we need to wrap around the 148 An iterator for minibatches that handles the case where we need to wrap around the
149 dataset because n_batches*minibatch_size > len(dataset). It is constructed from 149 dataset because n_batches*minibatch_size > len(dataset). It is constructed from
150 a dataset that provides a minibatch iterator that does not need to handle that problem. 150 a dataset that provides a minibatch iterator that does not need to handle that problem.
151 This class is a utility for dataset subclass writers, so that they do not have to handle 151 This class is a utility for dataset subclass writers, so that they do not have to handle
152 this issue multiple times. 152 this issue multiple times, nor check that fieldnames are valid, nor handle the
153 empty fieldnames (meaning 'use all the fields').
153 """ 154 """
154 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): 155 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
155 self.dataset=dataset 156 self.dataset=dataset
156 self.fieldnames=fieldnames 157 self.fieldnames=fieldnames
157 self.minibatch_size=minibatch_size 158 self.minibatch_size=minibatch_size
161 self.L=len(dataset) 162 self.L=len(dataset)
162 assert offset+minibatch_size<=self.L 163 assert offset+minibatch_size<=self.L
163 ds_nbatches = (self.L-offset)/minibatch_size 164 ds_nbatches = (self.L-offset)/minibatch_size
164 if n_batches is not None: 165 if n_batches is not None:
165 ds_nbatches = max(n_batches,ds_nbatches) 166 ds_nbatches = max(n_batches,ds_nbatches)
167 if fieldnames:
168 assert dataset.hasFields(*fieldnames)
169 else:
170 fieldnames=dataset.fieldNames()
166 self.iterator = dataset.minibatches_nowrap(fieldnames,minibatch_size,ds_nbatches,offset) 171 self.iterator = dataset.minibatches_nowrap(fieldnames,minibatch_size,ds_nbatches,offset)
167 172
168 def __iter__(self): 173 def __iter__(self):
169 return self 174 return self
170 175
627 self.datasets=datasets 632 self.datasets=datasets
628 self.length=0 633 self.length=0
629 self.index2dataset={} 634 self.index2dataset={}
630 assert len(datasets)>0 635 assert len(datasets)>0
631 fieldnames = datasets[-1].fieldNames() 636 fieldnames = datasets[-1].fieldNames()
637 self.datasets_start_row=[]
632 # We use this map from row index to dataset index for constant-time random access of examples, 638 # We use this map from row index to dataset index for constant-time random access of examples,
633 # to avoid having to search for the appropriate dataset each time a slice is asked for. 639 # to avoid having to search for the appropriate dataset each time a slice is asked for.
634 for dataset,k in enumerate(datasets[0:-1]): 640 for dataset,k in enumerate(datasets[0:-1]):
635 L=len(dataset) 641 L=len(dataset)
636 assert L<DataSet.infinity 642 assert L<DataSet.infinity
637 for i in xrange(L): 643 for i in xrange(L):
638 self.index2dataset[self.length+i]=k 644 self.index2dataset[self.length+i]=k
645 self.datasets_start_row.append(self.length)
639 self.length+=L 646 self.length+=L
640 assert dataset.fieldNames()==fieldnames 647 assert dataset.fieldNames()==fieldnames
641 self.last_start=self.length 648 self.datasets_start_row.append(self.length)
642 self.length+=len(datasets[-1]) 649 self.length+=len(datasets[-1])
643 # If length is very large, we should use a more memory-efficient mechanism 650 # If length is very large, we should use a more memory-efficient mechanism
644 # that does not store all indices 651 # that does not store all indices
645 if self.length>1000000: 652 if self.length>1000000:
646 # 1 million entries would require about 60 meg for the index2dataset map 653 # 1 million entries would require about 60 meg for the index2dataset map
654 return self.datasets[0].fieldNames() 661 return self.datasets[0].fieldNames()
655 662
656 def hasFields(self,*fieldnames): 663 def hasFields(self,*fieldnames):
657 return self.datasets[0].hasFields(*fieldnames) 664 return self.datasets[0].hasFields(*fieldnames)
658 665
666 def locate_row(self,row):
667 """Return (dataset_index, row_within_dataset) for global row number"""
668 dataset_index = self.index2dataset[row]
669 row_within_dataset = self.datasets_start_row[dataset_index]
670 return dataset_index, row_within_dataset
671
659 def minibatches_nowrap(self, 672 def minibatches_nowrap(self,
660 fieldnames = minibatches_fieldnames, 673 fieldnames = minibatches_fieldnames,
661 minibatch_size = minibatches_minibatch_size, 674 minibatch_size = minibatches_minibatch_size,
662 n_batches = minibatches_n_batches, 675 n_batches = minibatches_n_batches,
663 offset = 0): 676 offset = 0):
677
664 class Iterator(object): 678 class Iterator(object):
665 def __init__(self,vsds): 679 def __init__(self,vsds):
666 self.vsds=vsds 680 self.vsds=vsds
667 self.next_row=offset 681 self.next_row=offset
668 self.next_dataset_index=0 682 self.next_dataset_index,self.next_dataset_row=self.vsds.locate_row(offset)
669 self.next_dataset_row=0
670 self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \ 683 self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \
671 self.next_iterator(vsds.datasets[0],offset,n_batches) 684 self.next_iterator(vsds.datasets[0],offset,n_batches)
672 685
673 def next_iterator(self,dataset,starting_offset,batches_left): 686 def next_iterator(self,dataset,starting_offset,batches_left):
674 L=len(dataset) 687 L=len(dataset)
676 if batches_left is not None: 689 if batches_left is not None:
677 ds_nbatches = max(batches_left,ds_nbatches) 690 ds_nbatches = max(batches_left,ds_nbatches)
678 if minibatch_size>L: 691 if minibatch_size>L:
679 ds_minibatch_size=L 692 ds_minibatch_size=L
680 n_left_in_mb=minibatch_size-L 693 n_left_in_mb=minibatch_size-L
681 else: n_left_in_mb=0 694 ds_nbatches=1
695 else:
696 n_left_in_mb=0
682 return dataset.minibatches(fieldnames,minibatch_size,ds_nbatches,starting_offset), \ 697 return dataset.minibatches(fieldnames,minibatch_size,ds_nbatches,starting_offset), \
683 L-(starting_offset+ds_nbatches*minibatch_size), n_left_in_mb 698 L-(starting_offset+ds_nbatches*minibatch_size), n_left_in_mb
684 699
685 def move_to_next_dataset(self): 700 def move_to_next_dataset(self):
686 self.next_dataset_index +=1 701 if self.n_left_at_the_end_of_ds>0:
687 if self.next_dataset_index==len(self.vsds.datasets): 702 self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \
688 self.next_dataset_index = 0 703 self.next_iterator(vsds.datasets[self.next_dataset_index],
689 self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \ 704 self.n_left_at_the_end_of_ds,1)
690 self.next_iterator(vsds.datasets[self.next_dataset_index],starting_offset,n_batches) 705 else:
706 self.next_dataset_index +=1
707 if self.next_dataset_index==len(self.vsds.datasets):
708 self.next_dataset_index = 0
709 self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \
710 self.next_iterator(vsds.datasets[self.next_dataset_index],starting_offset,n_batches)
691 711
692 def __iter__(self): 712 def __iter__(self):
693 return self 713 return self
694 714
695 def next(self): 715 def next(self):
696 dataset=self.vsds.datasets[self.next_dataset_index] 716 dataset=self.vsds.datasets[self.next_dataset_index]
697 mb = self.next_iterator.next() 717 mb = self.next_iterator.next()
698 if self.n_left_in_mb: 718 if self.n_left_in_mb:
699 names=self.vsds.fieldNames()
700 extra_mb = [] 719 extra_mb = []
701 while self.n_left_in_mb>0: 720 while self.n_left_in_mb>0:
702 self.move_to_next_dataset() 721 self.move_to_next_dataset()
703 extra_mb.append(self.next_iterator.next()) 722 extra_mb.append(self.next_iterator.next())
704 mb = Example(names, 723 mb = Example(names,
705 [dataset.valuesVStack(name,[mb[name]]+[b[name] for b in extra_mb]) 724 [dataset.valuesVStack(name,[mb[name]]+[b[name] for b in extra_mb])
706 for name in names]) 725 for name in fieldnames])
707 self.next_row+=minibatch_size 726 self.next_row+=minibatch_size
727 self.next_dataset_row+=minibatch_size
728 if self.next_row+minibatch_size>len(dataset):
729 self.move_to_next_dataset()
708 return mb 730 return mb
709 731
710 732
711 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): 733 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None):
712 """ 734 """