Mercurial > pylearn
comparison dataset.py @ 37:73c4212ba5b3
Factored the minibatch-writing code into an iterator class inside DataSet
author | bengioy@esprit.iro.umontreal.ca |
---|---|
date | Thu, 24 Apr 2008 12:03:06 -0400 |
parents | 438440ba0627 |
children | d637ad8f7352 |
36:438440ba0627 | 37:73c4212ba5b3 |
---|---|
96 | 96 |
97 A DataSet sub-class should always redefine the following methods: | 97 A DataSet sub-class should always redefine the following methods: |
98 * __len__ if it is not a stream | 98 * __len__ if it is not a stream |
99 * __getitem__ may not be feasible with some streams | 99 * __getitem__ may not be feasible with some streams |
100 * fieldNames | 100 * fieldNames |
101 * minibatches | 101 * minibatches_nowrap (called by DataSet.minibatches()) |
102 * valuesHStack | 102 * valuesHStack |
103 * valuesVStack | 103 * valuesVStack |
104 For efficiency of implementation, a sub-class might also want to redefine | 104 For efficiency of implementation, a sub-class might also want to redefine |
105 * hasFields | 105 * hasFields |
106 """ | 106 """ |
140 | 140 |
141 The default implementation calls the minibatches iterator and extracts the first example of each field. | 141 The default implementation calls the minibatches iterator and extracts the first example of each field. |
142 """ | 142 """ |
143 return DataSet.MinibatchToSingleExampleIterator(self.minibatches(None, minibatch_size = 1)) | 143 return DataSet.MinibatchToSingleExampleIterator(self.minibatches(None, minibatch_size = 1)) |
144 | 144 |
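(A hedged sketch of what the default implementation above buys a caller, assuming a hypothetical `dataset` instance with fields 'input' and 'target'; the method shown is presumably DataSet.__iter__, so plain example-level iteration rides on top of 1-example minibatches:)

    # each step pulls a 1-example minibatch and unwraps it into a single Example
    for example in dataset:
        x = example['input']
        y = example['target']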
145 | |
146 class MinibatchWrapAroundIterator(object): | |
147 """ | |
148 An iterator for minibatches that handles the case where we need to wrap around the | |
149 dataset because n_batches*minibatch_size > len(dataset). It is constructed from | |
150 a dataset whose minibatches_nowrap iterator does not need to handle that problem. | 
151 This class is a utility for dataset subclass writers, so that they do not have to handle | |
152 this issue multiple times. | |
153 """ | |
154 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): | |
155 self.dataset=dataset | |
156 self.fieldnames=fieldnames | |
157 self.minibatch_size=minibatch_size | |
158 self.n_batches=n_batches | |
159 self.n_batches_done=0 | |
160 self.next_row=offset | |
161 self.L=len(dataset) | |
162 assert offset+minibatch_size<=self.L | |
163 ds_nbatches = (self.L-offset)/minibatch_size | |
164 if n_batches is not None: | |
165 ds_nbatches = min(n_batches,ds_nbatches) # never ask the inner iterator for more in-range batches than the dataset holds | 
166 self.iterator = dataset.minibatches_nowrap(fieldnames,minibatch_size,ds_nbatches,offset) | |
167 | |
168 def __iter__(self): | |
169 return self | |
170 | |
171 def next_index(self): | |
172 return self.next_row | |
173 | |
174 def next(self): | |
175 if self.n_batches and self.n_batches_done==self.n_batches: | |
176 raise StopIteration | |
177 upper = self.next_row+self.minibatch_size | 
178 if upper <=self.L: | |
179 minibatch = self.iterator.next() | 
180 else: | |
181 if not self.n_batches: | |
182 raise StopIteration | |
183 # we must concatenate (vstack) the bottom and top parts of our minibatch | |
184 # first get the beginning of our minibatch (the last rows of the dataset) | 
185 first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next() | 
186 second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next() | 
187 minibatch = Example(self.fieldnames, | |
188 [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) | |
189 for name in self.fieldnames]) | |
190 self.next_row=upper | |
191 self.n_batches_done+=1 | |
192 if upper >= self.L: | 
193 self.next_row -= self.L | 
194 return minibatch | |
195 | |
196 | |
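To make the wrap-around arithmetic in next() concrete, here is a worked trace with hypothetical sizes (not from the source):

    # L = len(dataset) = 10, minibatch_size = 4, offset = 0, n_batches = 3
    # batch 1: next_row=0,  upper=4  <= L  -> rows 0..3 from the inner nowrap iterator
    # batch 2: next_row=4,  upper=8  <= L  -> rows 4..7
    # batch 3: next_row=8,  upper=12 >  L  -> rows 8..9 vstacked with rows 0..1,
    #          then next_row wraps around to 12-10 = 2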
145 minibatches_fieldnames = None | 197 minibatches_fieldnames = None |
146 minibatches_minibatch_size = 1 | 198 minibatches_minibatch_size = 1 |
147 minibatches_n_batches = None | 199 minibatches_n_batches = None |
148 def minibatches(self, | 200 def minibatches(self, |
149 fieldnames = minibatches_fieldnames, | 201 fieldnames = minibatches_fieldnames, |
150 minibatch_size = minibatches_minibatch_size, | 202 minibatch_size = minibatches_minibatch_size, |
151 n_batches = minibatches_n_batches): | 203 n_batches = minibatches_n_batches, |
204 offset = 0): | |
152 """ | 205 """ |
153 Return an iterator that supports three forms of syntax: | 206 Return an iterator that supports three forms of syntax: |
154 | 207 |
155 for i in dataset.minibatches(None,**kwargs): ... | 208 for i in dataset.minibatches(None,**kwargs): ... |
156 | 209 |
191 - n_batches (integer, default None) | 244 - n_batches (integer, default None) |
192 The iterator will loop exactly this many times, and then stop. If None, | 245 The iterator will loop exactly this many times, and then stop. If None, |
193 the derived class can choose a default. If (-1), then the returned | 246 the derived class can choose a default. If (-1), then the returned |
194 iterator should support looping indefinitely. | 247 iterator should support looping indefinitely. |
195 | 248 |
249 - offset (integer, default 0) | |
250 The iterator will start at example 'offset' in the dataset, rather than the default. | |
251 | |
196 Note: A list-like container is something like a tuple, list, numpy.ndarray or | 252 Note: A list-like container is something like a tuple, list, numpy.ndarray or |
197 any other object that supports integer indexing and slicing. | 253 any other object that supports integer indexing and slicing. |
198 | 254 |
199 """ | 255 """ |
256 return MinibatchWrapAroundIterator(self,fieldnames,minibatch_size,n_batches,offset) | |
257 | |
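A usage sketch for this public entry point (hypothetical dataset with field names 'x' and 'y'):

    # 16 examples per minibatch, exactly 100 minibatches, starting at example 5;
    # the returned MinibatchWrapAroundIterator wraps past the end as needed
    for batch in dataset.minibatches(['x','y'], minibatch_size=16, n_batches=100, offset=5):
        inputs, targets = batch['x'], batch['y']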
258 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): | |
259 """ | |
260 This is the minibatch-iterator factory that sub-classes must define. | 
261 It does not need to worry about wrapping around multiple times across the dataset, | |
262 as this is handled by MinibatchWrapAroundIterator when DataSet.minibatches() is called. | |
263 The next() method of the returned iterator does not even need to worry about | |
264 the termination condition (as StopIteration will be raised by DataSet.minibatches | |
265 before an improper call to minibatches_nowrap's next() is made). | |
266 That next() method can assert that its next row will always be within [0,len(dataset)). | |
267 The iterator returned by minibatches_nowrap does not need to implement | |
268 a next_index() method either, as this will be provided by MinibatchWrapAroundIterator. | |
269 """ | |
200 raise AbstractFunction() | 270 raise AbstractFunction() |
201 | |
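To make the sub-class contract concrete, here is a minimal sketch of a hypothetical array-backed subclass (the name ArrayDataSet and its storage layout are illustrative, not part of this module):

    class ArrayDataSet(DataSet):
        def __init__(self, fields):
            # fields: dict mapping field name -> array-like (e.g. numpy) columns,
            # all of equal length
            self._fields = fields
            self._len = len(fields.values()[0])
        def __len__(self):
            return self._len
        def fieldNames(self):
            return self._fields.keys()
        def minibatches_nowrap(self, fieldnames, minibatch_size, n_batches, offset):
            names = fieldnames if fieldnames else self.fieldNames()
            fields = self._fields
            class Iterator(object):
                # per the contract above: no wrap-around, no StopIteration,
                # no next_index(); MinibatchWrapAroundIterator supplies all three
                def __init__(self, row):
                    self.row = row
                def __iter__(self):
                    return self
                def next(self):
                    rows = slice(self.row, self.row + minibatch_size)
                    self.row += minibatch_size
                    return Example(names, [fields[n][rows] for n in names])
            return Iterator(offset)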
202 | 271 |
203 def __len__(self): | 272 def __len__(self): |
204 """ | 273 """ |
205 len(dataset) returns the number of examples in the dataset. | 274 len(dataset) returns the number of examples in the dataset. |
206 By default, a DataSet is a 'stream', i.e. it has an unbounded (infinite) length. | 275 By default, a DataSet is a 'stream', i.e. it has an unbounded (infinite) length. |
356 dataset2 = dataset1.fields().examples() | 425 dataset2 = dataset1.fields().examples() |
357 and dataset2 should behave exactly like dataset1 (in fact by default dataset2==dataset1). | 426 and dataset2 should behave exactly like dataset1 (in fact by default dataset2==dataset1). |
358 """ | 427 """ |
359 def __init__(self,dataset,*fieldnames): | 428 def __init__(self,dataset,*fieldnames): |
360 self.dataset=dataset | 429 self.dataset=dataset |
361 assert dataset.hasField(*fieldnames) | 430 assert dataset.hasFields(*fieldnames) |
362 LookupList.__init__(self,dataset.fieldNames(), | 431 LookupList.__init__(self,dataset.fieldNames(), |
363 dataset.minibatches(fieldnames if len(fieldnames)>0 else self.fieldNames(),minibatch_size=len(dataset)).next() | 432 dataset.minibatches(fieldnames if len(fieldnames)>0 else self.fieldNames(), |
433 minibatch_size=len(dataset)).next()) | 
364 def examples(self): | 434 def examples(self): |
365 return self.dataset | 435 return self.dataset |
366 | 436 |
367 def __or__(self,other): | 437 def __or__(self,other): |
368 """ | 438 """ |
376 fields1 + fields2 is a DataSetFields whose list of fields is the concatenation | 446 fields1 + fields2 is a DataSetFields whose list of fields is the concatenation |
377 of the fields of DataSetFields fields1 and fields2. | 447 of the fields of DataSetFields fields1 and fields2. |
378 """ | 448 """ |
379 return (self.examples() | other.examples()).fields() | 449 return (self.examples() | other.examples()).fields() |
380 | 450 |
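A small sketch of these field-level operators (hypothetical ds1/ds2 with disjoint field names; `+` is the concatenation described above, examples() transposes back):

    fields1 = ds1.fields()
    fields2 = ds2.fields()
    both = fields1 + fields2      # one DataSetFields with the fields of ds1 then ds2
    ds3 = both.examples()         # back to an example-iterable dataset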
451 | |
381 class MinibatchDataSet(DataSet): | 452 class MinibatchDataSet(DataSet): |
382 """ | 453 """ |
383 Turn a LookupList of same-length fields into an example-iterable dataset. | 454 Turn a LookupList of same-length fields into an example-iterable dataset. |
384 Each element of the lookup-list should be an iterable and sliceable, all of the same length. | 455 Each element of the lookup-list should be an iterable and sliceable, all of the same length. |
385 """ | 456 """ |
405 return Example(self.fields.keys(),[field[i] for field in self.fields]) | 476 return Example(self.fields.keys(),[field[i] for field in self.fields]) |
406 | 477 |
407 def fieldNames(self): | 478 def fieldNames(self): |
408 return self.fields.keys() | 479 return self.fields.keys() |
409 | 480 |
410 def hasField(self,*fieldnames): | 481 def hasFields(self,*fieldnames): |
411 for fieldname in fieldnames: | 482 for fieldname in fieldnames: |
412 if fieldname not in self.fields: | 483 if fieldname not in self.fields: |
413 return False | 484 return False |
414 return True | 485 return True |
415 | 486 |
416 def minibatches(self, | 487 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): |
417 fieldnames = minibatches_fieldnames, | |
418 minibatch_size = minibatches_minibatch_size, | |
419 n_batches = minibatches_n_batches): | |
420 class Iterator(object): | 488 class Iterator(object): |
421 def __init__(self,ds): | 489 def __init__(self,ds): |
422 self.ds=ds | 490 self.ds=ds |
423 self.next_example=0 | 491 self.next_example=offset |
424 self.n_batches_done=0 | |
425 assert minibatch_size > 0 | 492 assert minibatch_size > 0 |
426 if minibatch_size > ds.length: | 493 if offset+minibatch_size > ds.length: |
427 raise NotImplementedError() | 494 raise NotImplementedError() |
428 def __iter__(self): | 495 def __iter__(self): |
429 return self | 496 return self |
430 def next_index(self): | |
431 return self.next_example | |
432 def next(self): | 497 def next(self): |
433 upper = self.next_example+minibatch_size | 498 upper = self.next_example+minibatch_size |
434 if upper<=self.ds.length: | 499 assert upper<=self.ds.length |
435 minibatch = Example(self.ds.fields.keys(), | 500 minibatch = Example(self.ds.fields.keys(), |
436 [field[self.next_example:upper] | 501 [field[self.next_example:upper] |
437 for field in self.ds.fields]) | 502 for field in self.ds.fields]) |
438 else: # we must concatenate (vstack) the bottom and top parts of our minibatch | |
439 minibatch = Example(self.ds.fields.keys(), | |
440 [self.ds.valuesVStack(name,[value[next_example:], | |
441 value[0:upper-self.ds.length]]) | |
442 for name,value in self.ds.fields.items()]) | |
443 self.next_example+=minibatch_size | 503 self.next_example+=minibatch_size |
444 self.n_batches_done+=1 | |
445 if n_batches: | |
446 if self.n_batches_done==n_batches: | |
447 raise StopIteration | |
448 if self.next_example>=self.ds.length: | |
449 self.next_example-=self.ds.length | |
450 else: | |
451 if self.next_example>=self.ds.length: | |
452 raise StopIteration | |
453 return DataSetFields(MinibatchDataSet(minibatch),fieldnames) | 504 return DataSetFields(MinibatchDataSet(minibatch),fieldnames) |
454 | 505 |
455 return Iterator(self) | 506 return Iterator(self) # minibatches_nowrap must return the plain iterator; DataSet.minibatches already wraps it |
456 | 507 |
457 def valuesVStack(self,fieldname,fieldvalues): | 508 def valuesVStack(self,fieldname,fieldvalues): |
458 return self.values_vstack(fieldname,fieldvalues) | 509 return self.values_vstack(fieldname,fieldvalues) |
459 | 510 |
460 def valuesHStack(self,fieldnames,fieldvalues): | 511 def valuesHStack(self,fieldnames,fieldvalues): |
502 self.fieldname2dataset[fieldname]=i | 553 self.fieldname2dataset[fieldname]=i |
503 for fieldname,i in names_to_change: | 554 for fieldname,i in names_to_change: |
504 del self.fieldname2dataset[fieldname] | 555 del self.fieldname2dataset[fieldname] |
505 self.fieldname2dataset[rename_field(fieldname,self.datasets[i],i)]=i | 556 self.fieldname2dataset[rename_field(fieldname,self.datasets[i],i)]=i |
506 | 557 |
507 def hasField(self,*fieldnames): | 558 def hasFields(self,*fieldnames): |
508 for fieldname in fieldnames: | 559 for fieldname in fieldnames: |
509 if not fieldname in self.fieldname2dataset: | 560 if not fieldname in self.fieldname2dataset: |
510 return False | 561 return False |
511 return True | 562 return True |
512 | 563 |
513 def fieldNames(self): | 564 def fieldNames(self): |
514 return self.fieldname2dataset.keys() | 565 return self.fieldname2dataset.keys() |
515 | 566 |
516 def minibatches(self, | 567 def minibatches_nowrap(self, |
517 fieldnames = minibatches_fieldnames, | 568 fieldnames, |
518 minibatch_size = minibatches_minibatch_size, | 569 minibatch_size, |
519 n_batches = minibatches_n_batches): | 570 n_batches, |
571 offset): |
520 | 572 |
521 class Iterator(object): | 573 class Iterator(object): |
522 def __init__(self,hsds,iterators): | 574 def __init__(self,hsds,iterators): |
523 self.hsds=hsds | 575 self.hsds=hsds |
524 self.iterators=iterators | 576 self.iterators=iterators |
525 def __iter__(self): | 577 def __iter__(self): |
526 return self | 578 return self |
527 def next_index(self): | |
528 return self.iterators[0].next_index() | |
529 def next(self): | 579 def next(self): |
530 # concatenate all the fields of the minibatches | 580 # concatenate all the fields of the minibatches |
531 minibatch = reduce(LookupList.__add__,[iterator.next() for iterator in self.iterators]) | 581 minibatch = reduce(LookupList.__add__,[iterator.next() for iterator in self.iterators]) |
532 # and return a DataSetFields whose dataset is the transpose (=examples()) of this minibatch | 582 # and return a DataSetFields whose dataset is the transpose (=examples()) of this minibatch |
533 return DataSetFields(MinibatchDataSet(minibatch,self.hsds.valuesVStack, | 583 return DataSetFields(MinibatchDataSet(minibatch,self.hsds.valuesVStack, |
543 for fieldname in fieldnames: | 593 for fieldname in fieldnames: |
544 dataset=self.datasets[self.fieldname2dataset[fieldname]] | 594 dataset=self.datasets[self.fieldname2dataset[fieldname]] |
545 datasets.add(dataset) | 595 datasets.add(dataset) |
546 fields_in_dataset[dataset].append(fieldname) | 596 fields_in_dataset[dataset].append(fieldname) |
547 datasets=list(datasets) | 597 datasets=list(datasets) |
548 iterators=[dataset.minibatches(fields_in_dataset[dataset],minibatch_size,n_batches) | 598 iterators=[dataset.minibatches(fields_in_dataset[dataset],minibatch_size,n_batches,offset) |
549 for dataset in datasets] | 599 for dataset in datasets] |
550 else: | 600 else: |
551 datasets=self.datasets | 601 datasets=self.datasets |
552 iterators=[dataset.minibatches(None,minibatch_size,n_batches) for dataset in datasets] | 602 iterators=[dataset.minibatches(None,minibatch_size,n_batches,offset) for dataset in datasets] |
553 return Iterator(self,iterators) | 603 return Iterator(self,iterators) |
554 | 604 |
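A usage sketch for horizontal stacking (hypothetical input_ds/target_ds with distinct field names; assuming DataSet.__or__ constructs an HStackedDataSet, as DataSetFields.__or__ earlier suggests):

    combined = input_ds | target_ds    # exposes the fields of both datasets
    for batch in combined.minibatches(['input','target'], minibatch_size=8, n_batches=10):
        pass   # each field is served by the constituent dataset that owns it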
555 | 605 |
556 def valuesVStack(self,fieldname,fieldvalues): | 606 def valuesVStack(self,fieldname,fieldvalues): |
557 return self.datasets[self.fieldname2dataset[fieldname]].valuesVStack(fieldname,fieldvalues) | 607 return self.datasets[self.fieldname2dataset[fieldname]].valuesVStack(fieldname,fieldvalues) |
575 """ | 625 """ |
576 def __init__(self,datasets): | 626 def __init__(self,datasets): |
577 self.datasets=datasets | 627 self.datasets=datasets |
578 self.length=0 | 628 self.length=0 |
579 self.index2dataset={} | 629 self.index2dataset={} |
580 # we use this map from row index to dataset index for constant-time random access of examples, | 630 assert len(datasets)>0 |
581 # to avoid having to search for the appropriate dataset each time and slice is asked for | 631 fieldnames = datasets[-1].fieldNames() |
632 # We use this map from row index to dataset index for constant-time random access of examples, | |
633 # to avoid having to search for the appropriate dataset each time and slice is asked for. | |
582 for k,dataset in enumerate(datasets[0:-1]): | 634 for k,dataset in enumerate(datasets[0:-1]): |
583 L=len(dataset) | 635 L=len(dataset) |
584 assert L<DataSet.infinity | 636 assert L<DataSet.infinity |
585 for i in xrange(L): | 637 for i in xrange(L): |
586 self.index2dataset[self.length+i]=k | 638 self.index2dataset[self.length+i]=k |
587 self.length+=L | 639 self.length+=L |
640 assert dataset.fieldNames()==fieldnames | |
588 self.last_start=self.length | 641 self.last_start=self.length |
589 self.length+=len(datasets[-1]) | 642 self.length+=len(datasets[-1]) |
590 | 643 # If length is very large, we should use a more memory-efficient mechanism |
591 | 644 # that does not store all indices |
645 if self.length>1000000: | |
646 # 1 million entries would require about 60 meg for the index2dataset map | |
647 # TODO | |
648 print "A more efficient mechanism for index2dataset should be implemented" | |
649 | |
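One plausible replacement for the TODO above, sketched with the standard bisect module: keep only the cumulative start row of each constituent dataset and binary-search it, which costs O(log n) per lookup and O(number of datasets) memory instead of one dict entry per example:

    import bisect

    # starts[k] = global row index where datasets[k] begins, built once in __init__;
    # e.g. constituent lengths 5,3,4 give starts = [0,5,8]
    def dataset_index(starts, row):
        return bisect.bisect_right(starts, row) - 1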
650 def __len__(self): | |
651 return self.length | |
652 | |
653 def fieldNames(self): | |
654 return self.datasets[0].fieldNames() | |
655 | |
656 def hasFields(self,*fieldnames): | |
657 return self.datasets[0].hasFields(*fieldnames) | |
658 | |
659 def minibatches_nowrap(self, | 
660 fieldnames, | 
661 minibatch_size, | 
662 n_batches, | 
663 offset): | 
664 class Iterator(object): | |
665 def __init__(self,vsds): | |
666 self.vsds=vsds | |
667 self.next_row=offset | |
668 self.next_dataset_index=0 | |
669 self.next_dataset_row=0 | |
670 self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \ | |
671 self.next_iterator(vsds.datasets[0],offset,n_batches) | |
672 | |
673 def next_iterator(self,dataset,starting_offset,batches_left): | |
674 L=len(dataset) | |
675 ds_nbatches = (L-starting_offset)/minibatch_size | |
676 if batches_left is not None: | |
677 ds_nbatches = min(batches_left,ds_nbatches) | 
678 if minibatch_size>L: | |
679 ds_minibatch_size=L | |
680 n_left_in_mb=minibatch_size-L | |
681 else: ds_minibatch_size=minibatch_size; n_left_in_mb=0 | 
682 return dataset.minibatches(fieldnames,ds_minibatch_size,ds_nbatches,starting_offset), \ | 
683 L-(starting_offset+ds_nbatches*minibatch_size), n_left_in_mb | |
684 | |
685 def move_to_next_dataset(self): | |
686 self.next_dataset_index +=1 | |
687 if self.next_dataset_index==len(self.vsds.datasets): | |
688 self.next_dataset_index = 0 | |
689 self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \ | |
690 self.next_iterator(self.vsds.datasets[self.next_dataset_index],0,n_batches) # subsequent datasets start at their first row | 
691 | |
692 def __iter__(self): | |
693 return self | |
694 | |
695 def next(self): | |
696 dataset=self.vsds.datasets[self.next_dataset_index] | |
697 mb = self.current_iterator.next() | 
698 if self.n_left_in_mb: | |
699 names=self.vsds.fieldNames() | |
700 extra_mb = [] | |
701 while self.n_left_in_mb>0: | |
702 self.move_to_next_dataset() | |
703 extra_mb.append(self.current_iterator.next()) | 
704 mb = Example(names, | |
705 [dataset.valuesVStack(name,[mb[name]]+[b[name] for b in extra_mb]) | |
706 for name in names]) | |
707 self.next_row+=minibatch_size | |
708 return mb | |
709 return Iterator(self) | 
710 | |
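Finally, a usage sketch for vertical stacking (hypothetical train1/train2 with identical field names, per the fieldNames assertion in __init__ above):

    train = VStackedDataSet([train1, train2])
    assert len(train) == len(train1) + len(train2)
    for batch in train.minibatches(None, minibatch_size=32, n_batches=50):
        pass   # minibatches straddling a dataset boundary are vstacked in next()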
592 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): | 711 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): |
593 """ | 712 """ |
594 Wraps an arbitrary DataSet into one for supervised learning tasks by forcing the | 713 Wraps an arbitrary DataSet into one for supervised learning tasks by forcing the |
595 user to define a set of fields as the 'input' field and a set of fields | 714 user to define a set of fields as the 'input' field and a set of fields |
596 as the 'target' field. Optionally, a single weight_field can also be defined. | 715 as the 'target' field. Optionally, a single weight_field can also be defined. |