comparison dataset.py @ 37:73c4212ba5b3

Factored the minibatch-writing code into an iterator class inside DataSet
author bengioy@esprit.iro.umontreal.ca
date Thu, 24 Apr 2008 12:03:06 -0400
parents 438440ba0627
children d637ad8f7352
comparing 36:438440ba0627 with 37:73c4212ba5b3
@@ -96,11 +96,11 @@
 
     A DataSet sub-class should always redefine the following methods:
        * __len__ if it is not a stream
        * __getitem__ may not be feasible with some streams
        * fieldNames
-       * minibatches
+       * minibatches_nowrap (called by DataSet.minibatches())
        * valuesHStack
        * valuesVStack
     For efficiency of implementation, a sub-class might also want to redefine
        * hasFields
     """
@@ -140,5 +140,57 @@
 
         The default implementation calls the minibatches iterator and extracts the first example of each field.
         """
         return DataSet.MinibatchToSingleExampleIterator(self.minibatches(None, minibatch_size = 1))
 
+    class MinibatchWrapAroundIterator(object):
+        """
+        An iterator for minibatches that handles the case where we need to wrap around the
+        dataset because n_batches*minibatch_size > len(dataset). It is constructed from
+        a dataset that provides a minibatch iterator that does not need to handle that problem.
+        This class is a utility for dataset subclass writers, so that they do not have to handle
+        this issue multiple times.
+        """
+        def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
+            self.dataset=dataset
+            self.fieldnames=fieldnames
+            self.minibatch_size=minibatch_size
+            self.n_batches=n_batches
+            self.n_batches_done=0
+            self.next_row=offset
+            self.L=len(dataset)
+            assert offset+minibatch_size<=self.L
+            ds_nbatches = (self.L-offset)/minibatch_size
+            if n_batches is not None:
+                ds_nbatches = min(n_batches,ds_nbatches)
+            self.iterator = dataset.minibatches_nowrap(fieldnames,minibatch_size,ds_nbatches,offset)
+
+        def __iter__(self):
+            return self
+
+        def next_index(self):
+            return self.next_row
+
+        def next(self):
+            if self.n_batches and self.n_batches_done==self.n_batches:
+                raise StopIteration
+            upper = self.next_row+self.minibatch_size
+            if upper <= self.L:
+                minibatch = self.iterator.next()
+            else:
+                if not self.n_batches:
+                    raise StopIteration
+                # we must concatenate (vstack) the bottom and top parts of our minibatch:
+                # first the remaining rows at the end of the dataset, then the rows
+                # needed from its beginning
+                first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next()
+                second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next()
+                minibatch = Example(self.fieldnames,
+                                    [self.dataset.valuesVStack(name,[first_part[name],second_part[name]])
+                                     for name in self.fieldnames])
+            self.next_row=upper
+            self.n_batches_done+=1
+            if upper >= self.L:
+                self.next_row -= self.L
+            return minibatch
+
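
The splice performed in the else branch of next() above can be checked with plain numpy; this is a standalone sketch with invented values, not part of the changeset:

    import numpy

    data = numpy.arange(10).reshape(5,2)      # pretend: 5 examples of one field
    L = len(data)                             # 5
    next_row, minibatch_size = 4, 3
    upper = next_row + minibatch_size         # 7 > L, so the minibatch wraps
    first_part = data[next_row:L]             # last row of the dataset
    second_part = data[0:upper-L]             # first two rows of the dataset
    mb = numpy.vstack([first_part, second_part])   # the 3-row minibatch
    next_row = upper - L                      # == 2, where the next batch starts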
@@ -145,12 +197,13 @@
     minibatches_fieldnames = None
     minibatches_minibatch_size = 1
     minibatches_n_batches = None
     def minibatches(self,
                     fieldnames = minibatches_fieldnames,
                     minibatch_size = minibatches_minibatch_size,
-                    n_batches = minibatches_n_batches):
+                    n_batches = minibatches_n_batches,
+                    offset = 0):
         """
         Return an iterator that supports three forms of syntax:
 
             for i in dataset.minibatches(None,**kwargs): ...
 
@@ -191,16 +244,32 @@
         - n_batches (integer, default None)
           The iterator will loop exactly this many times, and then stop. If None,
           the derived class can choose a default. If (-1), then the returned
           iterator should support looping indefinitely.
 
+        - offset (integer, default 0)
+          The iterator will start at example 'offset' in the dataset, rather than the default.
+
         Note: A list-like container is something like a tuple, list, numpy.ndarray or
         any other object that supports integer indexing and slicing.
 
         """
+        return DataSet.MinibatchWrapAroundIterator(self,fieldnames,minibatch_size,n_batches,offset)
+
+    def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
+        """
+        This is the minibatch iterator generator that sub-classes must define.
+        It does not need to worry about wrapping around multiple times across the dataset,
+        as this is handled by MinibatchWrapAroundIterator when DataSet.minibatches() is called.
+        The next() method of the returned iterator does not even need to worry about
+        the termination condition (as StopIteration will be raised by DataSet.minibatches
+        before an improper call to minibatches_nowrap's next() is made).
+        That next() method can assert that its next row will always be within [0,len(dataset)).
+        The iterator returned by minibatches_nowrap does not need to implement
+        a next_index() method either, as this will be provided by MinibatchWrapAroundIterator.
+        """
         raise AbstractFunction()
-
 
     def __len__(self):
         """
         len(dataset) returns the number of examples in the dataset.
         By default, a DataSet is a 'stream', i.e. it has an unbounded (infinite) length.
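
With this split, a sub-class only has to supply the no-wrap iterator. The following is a minimal, hypothetical sketch of such a sub-class, written against the API above (Example and DataSet as defined in this module; for brevity it assumes fieldnames is an actual list); it is not part of the changeset:

    class DictDataSet(DataSet):
        # Toy subclass: fields stored as a dict of equal-length lists.
        def __init__(self, fields):
            self.fields = fields
        def __len__(self):
            return len(self.fields.values()[0])
        def fieldNames(self):
            return self.fields.keys()
        def minibatches_nowrap(self, fieldnames, minibatch_size, n_batches, offset):
            # No wrap-around, termination or next_index() logic needed here:
            # MinibatchWrapAroundIterator provides all of it.
            for b in xrange(n_batches):
                start = offset + b*minibatch_size
                yield Example(fieldnames,
                              [self.fields[name][start:start+minibatch_size]
                               for name in fieldnames])

The generator may assume start+minibatch_size never exceeds len(self): the wrapper asks it for whole batches only, and splits the wrapping batch itself.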
@@ -356,13 +425,14 @@
         dataset2 = dataset1.fields().examples()
     and dataset2 should behave exactly like dataset1 (in fact by default dataset2==dataset1).
     """
     def __init__(self,dataset,*fieldnames):
         self.dataset=dataset
-        assert dataset.hasField(*fieldnames)
+        assert dataset.hasFields(*fieldnames)
         LookupList.__init__(self,dataset.fieldNames(),
-                            dataset.minibatches(fieldnames if len(fieldnames)>0 else self.fieldNames(),minibatch_size=len(dataset)).next())
+                            dataset.minibatches(fieldnames if len(fieldnames)>0 else self.fieldNames(),
+                                                minibatch_size=len(dataset)).next())
     def examples(self):
         return self.dataset
 
     def __or__(self,other):
         """
@@ -376,10 +446,11 @@
     fields1 + fields2 is a DataSetFields whose list of fields is the concatenation
     of the fields of DataSetFields fields1 and fields2.
     """
     return (self.examples() | other.examples()).fields()
 
+
 class MinibatchDataSet(DataSet):
     """
     Turn a LookupList of same-length fields into an example-iterable dataset.
     Each element of the lookup-list should be an iterable and sliceable, all of the same length.
     """
@@ -405,56 +476,36 @@
         return Example(self.fields.keys(),[field[i] for field in self.fields])
 
     def fieldNames(self):
         return self.fields.keys()
 
-    def hasField(self,*fieldnames):
+    def hasFields(self,*fieldnames):
         for fieldname in fieldnames:
             if fieldname not in self.fields:
                 return False
         return True
 
-    def minibatches(self,
-                    fieldnames = minibatches_fieldnames,
-                    minibatch_size = minibatches_minibatch_size,
-                    n_batches = minibatches_n_batches):
+    def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
         class Iterator(object):
             def __init__(self,ds):
                 self.ds=ds
-                self.next_example=0
-                self.n_batches_done=0
+                self.next_example=offset
                 assert minibatch_size > 0
-                if minibatch_size > ds.length:
+                if offset+minibatch_size > ds.length:
                     raise NotImplementedError()
             def __iter__(self):
                 return self
-            def next_index(self):
-                return self.next_example
             def next(self):
                 upper = self.next_example+minibatch_size
-                if upper<=self.ds.length:
-                    minibatch = Example(self.ds.fields.keys(),
-                                        [field[self.next_example:upper]
-                                         for field in self.ds.fields])
-                else: # we must concatenate (vstack) the bottom and top parts of our minibatch
-                    minibatch = Example(self.ds.fields.keys(),
-                                        [self.ds.valuesVStack(name,[value[self.next_example:],
-                                                                    value[0:upper-self.ds.length]])
-                                         for name,value in self.ds.fields.items()])
+                assert upper<=self.ds.length
+                minibatch = Example(self.ds.fields.keys(),
+                                    [field[self.next_example:upper]
+                                     for field in self.ds.fields])
                 self.next_example+=minibatch_size
-                self.n_batches_done+=1
-                if n_batches:
-                    if self.n_batches_done==n_batches:
-                        raise StopIteration
-                    if self.next_example>=self.ds.length:
-                        self.next_example-=self.ds.length
-                else:
-                    if self.next_example>=self.ds.length:
-                        raise StopIteration
                 return DataSetFields(MinibatchDataSet(minibatch),fieldnames)
 
         return Iterator(self)
 
     def valuesVStack(self,fieldname,fieldvalues):
         return self.values_vstack(fieldname,fieldvalues)
 
     def valuesHStack(self,fieldnames,fieldvalues):
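
Putting MinibatchDataSet together with the wrap-around machinery, a usage sketch (toy values; it assumes the MinibatchDataSet constructor takes the LookupList of fields, as its uses elsewhere in this file suggest):

    import numpy

    fields = Example(['x','y'],
                     [numpy.arange(10).reshape(5,2), numpy.arange(5)])
    ds = MinibatchDataSet(fields)
    # 3 batches of 2 examples from a 5-example dataset: the last batch wraps,
    # so its rows are vstacked from the bottom and the top of the dataset.
    for mb in ds.minibatches(['x','y'], minibatch_size=2, n_batches=3):
        print mb['x'], mb['y']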
@@ -502,32 +553,31 @@
                 self.fieldname2dataset[fieldname]=i
         for fieldname,i in names_to_change:
             del self.fieldname2dataset[fieldname]
             self.fieldname2dataset[rename_field(fieldname,self.datasets[i],i)]=i
 
-    def hasField(self,*fieldnames):
+    def hasFields(self,*fieldnames):
         for fieldname in fieldnames:
             if not fieldname in self.fieldname2dataset:
                 return False
         return True
 
     def fieldNames(self):
         return self.fieldname2dataset.keys()
 
-    def minibatches(self,
-                    fieldnames = minibatches_fieldnames,
-                    minibatch_size = minibatches_minibatch_size,
-                    n_batches = minibatches_n_batches):
+    def minibatches_nowrap(self,
+                           fieldnames = DataSet.minibatches_fieldnames,
+                           minibatch_size = DataSet.minibatches_minibatch_size,
+                           n_batches = DataSet.minibatches_n_batches,
+                           offset = 0):
 
         class Iterator(object):
             def __init__(self,hsds,iterators):
                 self.hsds=hsds
                 self.iterators=iterators
             def __iter__(self):
                 return self
-            def next_index(self):
-                return self.iterators[0].next_index()
             def next(self):
                 # concatenate all the fields of the minibatches
                 minibatch = reduce(LookupList.__add__,[iterator.next() for iterator in self.iterators])
                 # and return a DataSetFields whose dataset is the transpose (=examples()) of this minibatch
                 return DataSetFields(MinibatchDataSet(minibatch,self.hsds.valuesVStack,
@@ -543,15 +593,15 @@
             for fieldname in fieldnames:
                 dataset=self.datasets[self.fieldname2dataset[fieldname]]
                 datasets.add(dataset)
                 fields_in_dataset[dataset].append(fieldname)
             datasets=list(datasets)
-            iterators=[dataset.minibatches(fields_in_dataset[dataset],minibatch_size,n_batches)
+            iterators=[dataset.minibatches(fields_in_dataset[dataset],minibatch_size,n_batches,offset)
                        for dataset in datasets]
         else:
             datasets=self.datasets
-            iterators=[dataset.minibatches(None,minibatch_size,n_batches) for dataset in datasets]
+            iterators=[dataset.minibatches(None,minibatch_size,n_batches,offset) for dataset in datasets]
         return Iterator(self,iterators)
 
 
     def valuesVStack(self,fieldname,fieldvalues):
         return self.datasets[self.fieldname2dataset[fieldname]].valuesVStack(fieldname,fieldvalues)
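
For this horizontally-stacked dataset, each requested field is fetched from the dataset that owns it, and the per-dataset minibatches are then concatenated field-wise. A usage sketch (the class name HStackedDataSet and its constructor form are assumptions here, since the class header falls outside this hunk):

    # ds_inputs provides field 'x'; ds_targets provides field 'y'
    hs = HStackedDataSet([ds_inputs, ds_targets])
    for mb in hs.minibatches(['x','y'], minibatch_size=4, n_batches=2):
        # 'x' came from ds_inputs and 'y' from ds_targets,
        # merged into a single LookupList of fields
        print mb['x'], mb['y']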
575 """ 625 """
576 def __init__(self,datasets): 626 def __init__(self,datasets):
577 self.datasets=datasets 627 self.datasets=datasets
578 self.length=0 628 self.length=0
579 self.index2dataset={} 629 self.index2dataset={}
580 # we use this map from row index to dataset index for constant-time random access of examples, 630 assert len(datasets)>0
581 # to avoid having to search for the appropriate dataset each time and slice is asked for 631 fieldnames = datasets[-1].fieldNames()
632 # We use this map from row index to dataset index for constant-time random access of examples,
633 # to avoid having to search for the appropriate dataset each time and slice is asked for.
582 for dataset,k in enumerate(datasets[0:-1]): 634 for dataset,k in enumerate(datasets[0:-1]):
583 L=len(dataset) 635 L=len(dataset)
584 assert L<DataSet.infinity 636 assert L<DataSet.infinity
585 for i in xrange(L): 637 for i in xrange(L):
586 self.index2dataset[self.length+i]=k 638 self.index2dataset[self.length+i]=k
587 self.length+=L 639 self.length+=L
640 assert dataset.fieldNames()==fieldnames
588 self.last_start=self.length 641 self.last_start=self.length
589 self.length+=len(datasets[-1]) 642 self.length+=len(datasets[-1])
590 643 # If length is very large, we should use a more memory-efficient mechanism
591 644 # that does not store all indices
645 if self.length>1000000:
646 # 1 million entries would require about 60 meg for the index2dataset map
647 # TODO
648 print "A more efficient mechanism for index2dataset should be implemented"
649
650 def __len__(self):
651 return self.length
652
653 def fieldNames(self):
654 return self.datasets[0].fieldNames()
655
656 def hasFields(self,*fieldnames):
657 return self.datasets[0].hasFields(*fieldnames)
658
659 def minibatches_nowrap(self,
660 fieldnames = minibatches_fieldnames,
661 minibatch_size = minibatches_minibatch_size,
662 n_batches = minibatches_n_batches,
663 offset = 0):
664 class Iterator(object):
665 def __init__(self,vsds):
666 self.vsds=vsds
667 self.next_row=offset
668 self.next_dataset_index=0
669 self.next_dataset_row=0
670 self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \
671 self.next_iterator(vsds.datasets[0],offset,n_batches)
672
673 def next_iterator(self,dataset,starting_offset,batches_left):
674 L=len(dataset)
675 ds_nbatches = (L-starting_offset)/minibatch_size
676 if batches_left is not None:
677 ds_nbatches = max(batches_left,ds_nbatches)
678 if minibatch_size>L:
679 ds_minibatch_size=L
680 n_left_in_mb=minibatch_size-L
681 else: n_left_in_mb=0
682 return dataset.minibatches(fieldnames,minibatch_size,ds_nbatches,starting_offset), \
683 L-(starting_offset+ds_nbatches*minibatch_size), n_left_in_mb
684
685 def move_to_next_dataset(self):
686 self.next_dataset_index +=1
687 if self.next_dataset_index==len(self.vsds.datasets):
688 self.next_dataset_index = 0
689 self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \
690 self.next_iterator(vsds.datasets[self.next_dataset_index],starting_offset,n_batches)
691
692 def __iter__(self):
693 return self
694
695 def next(self):
696 dataset=self.vsds.datasets[self.next_dataset_index]
697 mb = self.next_iterator.next()
698 if self.n_left_in_mb:
699 names=self.vsds.fieldNames()
700 extra_mb = []
701 while self.n_left_in_mb>0:
702 self.move_to_next_dataset()
703 extra_mb.append(self.next_iterator.next())
704 mb = Example(names,
705 [dataset.valuesVStack(name,[mb[name]]+[b[name] for b in extra_mb])
706 for name in names])
707 self.next_row+=minibatch_size
708 return mb
709
710
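Two notes on the vertically-stacked iterator above, as standalone sketches rather than parts of the changeset. First, the n_left_in_mb bookkeeping, by hand:

    # a stacked sub-dataset of length 3, with minibatch_size=4:
    L, minibatch_size = 3, 4
    n_left_in_mb = minibatch_size - L   # 1 row must come from the next dataset
    # next() then vstacks the 3-row minibatch with rows drawn from the
    # following dataset(s), field by field, until the batch is complete.

Second, the TODO about a more memory-efficient index2dataset could plausibly be met with a binary search over cumulative lengths instead of a dict holding one entry per row:

    import bisect

    starts = []          # starts[k] = row index at which dataset k begins
    total = 0
    for d in datasets:
        starts.append(total)
        total += len(d)

    def row_to_dataset_index(row):
        # O(log n_datasets) per lookup, O(n_datasets) memory
        return bisect.bisect_right(starts, row) - 1
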
@@ -592,5 +711,5 @@
 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None):
     """
     Wraps an arbitrary DataSet into one for supervised learning tasks by forcing the
     user to define a set of fields as the 'input' field and a set of fields
     as the 'target' field. Optionally, a single weight_field can also be defined.
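
A usage sketch of this wrapper (field names invented; it assumes src_dataset actually provides them):

    train = supervised_learning_dataset(raw_dataset,
                                        input_fields=['x0','x1'],
                                        target_fields=['y'])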