pylearn: comparison of dataset.py @ 38:d637ad8f7352
Finished first untested version of VStackedDataset
author   | bengioy@esprit.iro.umontreal.ca
date     | Thu, 24 Apr 2008 13:25:39 -0400
parents  | 73c4212ba5b3
children | c682c6e9bf93
37:73c4212ba5b3 | 38:d637ad8f7352
---|---
147 """ | 147 """ |
148 An iterator for minibatches that handles the case where we need to wrap around the | 148 An iterator for minibatches that handles the case where we need to wrap around the |
149 dataset because n_batches*minibatch_size > len(dataset). It is constructed from | 149 dataset because n_batches*minibatch_size > len(dataset). It is constructed from |
150 a dataset that provides a minibatch iterator that does not need to handle that problem. | 150 a dataset that provides a minibatch iterator that does not need to handle that problem. |
151 This class is a utility for dataset subclass writers, so that they do not have to handle | 151 This class is a utility for dataset subclass writers, so that they do not have to handle |
152 this issue multiple times. | 152 this issue multiple times, nor check that fieldnames are valid, nor handle the |
| 153 empty fieldnames (meaning 'use all the fields'). |
153 """ | 154 """ |
154 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): | 155 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): |
155 self.dataset=dataset | 156 self.dataset=dataset |
156 self.fieldnames=fieldnames | 157 self.fieldnames=fieldnames |
157 self.minibatch_size=minibatch_size | 158 self.minibatch_size=minibatch_size |
161 self.L=len(dataset) | 162 self.L=len(dataset) |
162 assert offset+minibatch_size<=self.L | 163 assert offset+minibatch_size<=self.L |
163 ds_nbatches = (self.L-offset)/minibatch_size | 164 ds_nbatches = (self.L-offset)/minibatch_size |
164 if n_batches is not None: | 165 if n_batches is not None: |
165 ds_nbatches = max(n_batches,ds_nbatches) | 166 ds_nbatches = max(n_batches,ds_nbatches) |
| 167 if fieldnames: |
| 168 assert dataset.hasFields(*fieldnames) |
| 169 else: |
| 170 fieldnames=dataset.fieldNames() |
166 self.iterator = dataset.minibatches_nowrap(fieldnames,minibatch_size,ds_nbatches,offset) | 171 self.iterator = dataset.minibatches_nowrap(fieldnames,minibatch_size,ds_nbatches,offset) |
167 | 172 |
168 def __iter__(self): | 173 def __iter__(self): |
169 return self | 174 return self |
170 | 175 |
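As an aside for readers of this hunk: the wrap-around contract is easiest to see outside the DataSet machinery. Below is a minimal, hypothetical sketch in plain Python 3 (a list stands in for the dataset; fieldnames are omitted) of the behaviour MinibatchWrapAroundIterator guarantees when n_batches*minibatch_size > len(dataset).

```python
# Hypothetical stand-in for MinibatchWrapAroundIterator: `data` is any
# sequence, and every yielded minibatch has exactly `minibatch_size` rows,
# wrapping past the end of the data when necessary.
def wraparound_minibatches(data, minibatch_size, n_batches, offset=0):
    L = len(data)
    assert offset + minibatch_size <= L  # same precondition as the assert above
    for b in range(n_batches):
        start = (offset + b * minibatch_size) % L
        end = start + minibatch_size
        if end <= L:
            yield data[start:end]
        else:  # wrap around: stitch the tail of the data to its head
            yield data[start:] + data[:end - L]

print(list(wraparound_minibatches(list(range(5)), 2, 4)))
# [[0, 1], [2, 3], [4, 0], [1, 2]]
```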
627 self.datasets=datasets | 632 self.datasets=datasets |
628 self.length=0 | 633 self.length=0 |
629 self.index2dataset={} | 634 self.index2dataset={} |
630 assert len(datasets)>0 | 635 assert len(datasets)>0 |
631 fieldnames = datasets[-1].fieldNames() | 636 fieldnames = datasets[-1].fieldNames() |
| 637 self.datasets_start_row=[] |
632 # We use this map from row index to dataset index for constant-time random access of examples, | 638 # We use this map from row index to dataset index for constant-time random access of examples, |
633 # to avoid having to search for the appropriate dataset each time a slice is asked for. | 639 # to avoid having to search for the appropriate dataset each time a slice is asked for. |
634 for k,dataset in enumerate(datasets[0:-1]): | 640 for k,dataset in enumerate(datasets[0:-1]): |
635 L=len(dataset) | 641 L=len(dataset) |
636 assert L<DataSet.infinity | 642 assert L<DataSet.infinity |
637 for i in xrange(L): | 643 for i in xrange(L): |
638 self.index2dataset[self.length+i]=k | 644 self.index2dataset[self.length+i]=k |
| 645 self.datasets_start_row.append(self.length) |
639 self.length+=L | 646 self.length+=L |
640 assert dataset.fieldNames()==fieldnames | 647 assert dataset.fieldNames()==fieldnames |
641 self.last_start=self.length | 648 self.datasets_start_row.append(self.length) |
642 self.length+=len(datasets[-1]) | 649 self.length+=len(datasets[-1]) |
643 # If length is very large, we should use a more memory-efficient mechanism | 650 # If length is very large, we should use a more memory-efficient mechanism |
644 # that does not store all indices | 651 # that does not store all indices |
645 if self.length>1000000: | 652 if self.length>1000000: |
646 # 1 million entries would require about 60 meg for the index2dataset map | 653 # 1 million entries would require about 60 meg for the index2dataset map |
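To make the two index structures concrete: for stacked datasets of lengths 3 and 2, the loop above produces the values shown in this hypothetical sketch (plain lists stand in for DataSet objects, and the last dataset is indexed like the others for simplicity).

```python
# Hypothetical illustration of the two structures built above, with plain
# lists in place of DataSet objects (all datasets indexed uniformly here).
datasets = [["a0", "a1", "a2"], ["b0", "b1"]]

index2dataset = {}       # global row index -> dataset index
datasets_start_row = []  # dataset index -> first global row of that dataset
length = 0
for k, ds in enumerate(datasets):
    datasets_start_row.append(length)
    for i in range(len(ds)):
        index2dataset[length + i] = k
    length += len(ds)

print(index2dataset)       # {0: 0, 1: 0, 2: 0, 3: 1, 4: 1}
print(datasets_start_row)  # [0, 3]
```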
654 return self.datasets[0].fieldNames() | 661 return self.datasets[0].fieldNames() |
655 | 662 |
656 def hasFields(self,*fieldnames): | 663 def hasFields(self,*fieldnames): |
657 return self.datasets[0].hasFields(*fieldnames) | 664 return self.datasets[0].hasFields(*fieldnames) |
658 | 665 |
| 666 def locate_row(self,row): |
| 667 """Return (dataset_index, row_within_dataset) for global row number""" |
| 668 dataset_index = self.index2dataset[row] |
| 669 row_within_dataset = row - self.datasets_start_row[dataset_index] |
| 670 return dataset_index, row_within_dataset |
| 671 |
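The memory concern flagged in the constructor (one dict entry per row) has a standard index-free alternative: binary search over datasets_start_row, which stores one integer per dataset. A possible sketch, not pylearn API:

```python
import bisect

def locate_row(datasets_start_row, row):
    """Hypothetical O(log n_datasets) lookup: binary-search the sorted
    start-row table instead of storing one dict entry per row."""
    dataset_index = bisect.bisect_right(datasets_start_row, row) - 1
    return dataset_index, row - datasets_start_row[dataset_index]

print(locate_row([0, 3], 2))  # (0, 2): third row of the first dataset
print(locate_row([0, 3], 3))  # (1, 0): first row of the second dataset
```

This keeps memory proportional to the number of datasets rather than the number of rows, at the cost of a logarithmic rather than constant-time lookup.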
659 def minibatches_nowrap(self, | 672 def minibatches_nowrap(self, |
660 fieldnames = minibatches_fieldnames, | 673 fieldnames = minibatches_fieldnames, |
661 minibatch_size = minibatches_minibatch_size, | 674 minibatch_size = minibatches_minibatch_size, |
662 n_batches = minibatches_n_batches, | 675 n_batches = minibatches_n_batches, |
663 offset = 0): | 676 offset = 0): |
| 677 |
664 class Iterator(object): | 678 class Iterator(object): |
665 def __init__(self,vsds): | 679 def __init__(self,vsds): |
666 self.vsds=vsds | 680 self.vsds=vsds |
667 self.next_row=offset | 681 self.next_row=offset |
668 self.next_dataset_index=0 | 682 self.next_dataset_index,self.next_dataset_row=self.vsds.locate_row(offset) |
669 self.next_dataset_row=0 | |
670 self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \ | 683 self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \ |
671 self.next_iterator(vsds.datasets[0],offset,n_batches) | 684 self.next_iterator(vsds.datasets[0],offset,n_batches) |
672 | 685 |
673 def next_iterator(self,dataset,starting_offset,batches_left): | 686 def next_iterator(self,dataset,starting_offset,batches_left): |
674 L=len(dataset) | 687 L=len(dataset) |
675 ds_nbatches = (L-starting_offset)/minibatch_size | 688 ds_nbatches = (L-starting_offset)/minibatch_size |
676 if batches_left is not None: | 689 if batches_left is not None: |
677 ds_nbatches = max(batches_left,ds_nbatches) | 690 ds_nbatches = max(batches_left,ds_nbatches) |
678 if minibatch_size>L: | 691 if minibatch_size>L: |
679 ds_minibatch_size=L | 692 ds_minibatch_size=L |
680 n_left_in_mb=minibatch_size-L | 693 n_left_in_mb=minibatch_size-L |
681 else: n_left_in_mb=0 | 694 ds_nbatches=1 |
| 695 else: |
| 696 n_left_in_mb=0 |
682 return dataset.minibatches(fieldnames,minibatch_size,ds_nbatches,starting_offset), \ | 697 return dataset.minibatches(fieldnames,minibatch_size,ds_nbatches,starting_offset), \ |
683 L-(starting_offset+ds_nbatches*minibatch_size), n_left_in_mb | 698 L-(starting_offset+ds_nbatches*minibatch_size), n_left_in_mb |
684 | 699 |
685 def move_to_next_dataset(self): | 700 def move_to_next_dataset(self): |
686 self.next_dataset_index +=1 | 701 if self.n_left_at_the_end_of_ds>0: |
687 if self.next_dataset_index==len(self.vsds.datasets): | 702 self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \ |
688 self.next_dataset_index = 0 | 703 self.next_iterator(self.vsds.datasets[self.next_dataset_index], |
689 self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \ | 704 self.n_left_at_the_end_of_ds,1) |
690 self.next_iterator(self.vsds.datasets[self.next_dataset_index],starting_offset,n_batches) | 705 else: |
| 706 self.next_dataset_index +=1 |
| 707 if self.next_dataset_index==len(self.vsds.datasets): |
| 708 self.next_dataset_index = 0 |
| 709 self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \ |
| 710 self.next_iterator(self.vsds.datasets[self.next_dataset_index],starting_offset,n_batches) |
691 | 711 |
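Stripped of the minibatch bookkeeping, the handoff that move_to_next_dataset implements is just "when the current sub-dataset runs dry, continue from the next one, wrapping to the first". A hypothetical row-level sketch (plain Python 3, assumes at least one non-empty sequence):

```python
from itertools import islice

# Hypothetical sketch: yield rows from each sub-dataset in turn, wrapping
# back to the first when the last is exhausted (assumes at least one
# non-empty sequence, otherwise this would loop forever).
def rows_of_stack(datasets):
    k = 0
    it = iter(datasets[0])
    while True:
        try:
            yield next(it)
        except StopIteration:
            k = (k + 1) % len(datasets)  # move_to_next_dataset, in miniature
            it = iter(datasets[k])

print(list(islice(rows_of_stack([["a0", "a1"], ["b0"]]), 5)))
# ['a0', 'a1', 'b0', 'a0', 'a1']
```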
692 def __iter__(self): | 712 def __iter__(self): |
693 return self | 713 return self |
694 | 714 |
695 def next(self): | 715 def next(self): |
696 dataset=self.vsds.datasets[self.next_dataset_index] | 716 dataset=self.vsds.datasets[self.next_dataset_index] |
697 mb = self.current_iterator.next() | 717 mb = self.current_iterator.next() |
698 if self.n_left_in_mb: | 718 if self.n_left_in_mb: |
699 names=self.vsds.fieldNames() | |
700 extra_mb = [] | 719 extra_mb = [] |
701 while self.n_left_in_mb>0: | 720 while self.n_left_in_mb>0: |
702 self.move_to_next_dataset() | 721 self.move_to_next_dataset() |
703 extra_mb.append(self.current_iterator.next()) | 722 extra_mb.append(self.current_iterator.next()) |
704 mb = Example(names, | 723 mb = Example(fieldnames, |
705 [dataset.valuesVStack(name,[mb[name]]+[b[name] for b in extra_mb]) | 724 [dataset.valuesVStack(name,[mb[name]]+[b[name] for b in extra_mb]) |
706 for name in names]) | 725 for name in fieldnames]) |
707 self.next_row+=minibatch_size | 726 self.next_row+=minibatch_size |
| 727 self.next_dataset_row+=minibatch_size |
| 728 if self.next_dataset_row+minibatch_size>len(dataset): |
| 729 self.move_to_next_dataset() |
708 return mb | 730 return mb |
709 | 731 |
710 | 732 |
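Finally, the boundary case that next() handles above, a minibatch that straddles two stacked datasets, can be illustrated with NumPy arrays standing in for per-field values; np.vstack plays the role that valuesVStack plays in the code. Hypothetical values throughout:

```python
import numpy as np

# Hypothetical illustration of one minibatch straddling two stacked
# datasets: 2 rows remain in the first array, so the third row of the
# minibatch is pulled from the second array and stacked on vertically.
first = np.arange(8).reshape(4, 2)       # global rows 0..3
second = np.arange(8, 14).reshape(3, 2)  # global rows 4..6

minibatch_size = 3
tail = first[2:]                           # rows left at the end of dataset 0
n_left_in_mb = minibatch_size - len(tail)  # 1 row still missing
head = second[:n_left_in_mb]               # borrow it from dataset 1
mb = np.vstack([tail, head])               # the full 3-row minibatch

print(mb.shape)  # (3, 2)
```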
711 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): | 733 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): |
712 """ | 734 """ |