comparison dataset.py @ 290:9b533cc7874a

trying to get default implementations to work
author James Bergstra <bergstrj@iro.umontreal.ca>
date Thu, 05 Jun 2008 18:38:42 -0400
parents ed70580f2324
children 174374d59405
276:271a16d42072 290:9b533cc7874a
1 1
2 from lookup_list import LookupList 2 from lookup_list import LookupList as Example
3 Example = LookupList
4 from misc import unique_elements_list_intersection 3 from misc import unique_elements_list_intersection
5 from string import join 4 from string import join
6 from sys import maxint 5 from sys import maxint
7 import numpy, copy 6 import numpy, copy
8 7
35 attribute_names = self.attributeNames() 34 attribute_names = self.attributeNames()
36 if return_copy: 35 if return_copy:
37 return [copy.copy(self.__getattribute__(name)) for name in attribute_names] 36 return [copy.copy(self.__getattribute__(name)) for name in attribute_names]
38 else: 37 else:
39 return [self.__getattribute__(name) for name in attribute_names] 38 return [self.__getattribute__(name) for name in attribute_names]
40
41 39
42 class DataSet(AttributesHolder): 40 class DataSet(AttributesHolder):
43 """A virtual base class for datasets. 41 """A virtual base class for datasets.
44 42
45 A DataSet can be seen as a generalization of a matrix, meant to be used in conjunction 43 A DataSet can be seen as a generalization of a matrix, meant to be used in conjunction
232 self.fieldnames=fieldnames 230 self.fieldnames=fieldnames
233 self.minibatch_size=minibatch_size 231 self.minibatch_size=minibatch_size
234 self.n_batches=n_batches 232 self.n_batches=n_batches
235 self.n_batches_done=0 233 self.n_batches_done=0
236 self.next_row=offset 234 self.next_row=offset
237 self.offset=offset
238 self.L=len(dataset) 235 self.L=len(dataset)
239 assert offset+minibatch_size<=self.L 236 self.offset=offset % self.L
240 ds_nbatches = (self.L-self.next_row)/self.minibatch_size 237 ds_nbatches = (self.L-self.next_row)/self.minibatch_size
241 if n_batches is not None: 238 if n_batches is not None:
242 ds_nbatches = min(n_batches,ds_nbatches) 239 ds_nbatches = min(n_batches,ds_nbatches)
243 if fieldnames: 240 if fieldnames:
244 assert dataset.hasFields(*fieldnames) 241 assert dataset.hasFields(*fieldnames)
245 else: 242 else:
246 self.fieldnames=dataset.fieldNames() 243 self.fieldnames=dataset.fieldNames()
247 self.iterator = self.dataset.minibatches_nowrap(self.fieldnames,self.minibatch_size, 244 self.iterator = self.dataset.minibatches_nowrap(self.fieldnames,self.minibatch_size, ds_nbatches,self.next_row)
248 ds_nbatches,self.next_row)
249 245
250 def __iter__(self): 246 def __iter__(self):
251 return self 247 return self
252 248
253 def next_index(self): 249 def next_index(self):
316 Using the first syntax, all the fields will be returned in "i". 312 Using the first syntax, all the fields will be returned in "i".
317 Using the third syntax, i1, i2, i3 will be list-like containers of the 313 Using the third syntax, i1, i2, i3 will be list-like containers of the
318 f1, f2, and f3 fields of a batch of examples on each loop iteration. 314 f1, f2, and f3 fields of a batch of examples on each loop iteration.
319 315
320 The minibatches iterator is expected to return upon each call to next() 316 The minibatches iterator is expected to return upon each call to next()
321 a DataSetFields object, which is a LookupList (indexed by the field names) whose 317 a DataSetFields object, which is an Example (indexed by the field names) whose
322 elements are iterable and indexable over the minibatch examples, and which keeps a pointer to 318 elements are iterable and indexable over the minibatch examples, and which keeps a pointer to
323 a sub-dataset that can be used to iterate over the individual examples 319 a sub-dataset that can be used to iterate over the individual examples
324 in the minibatch. Hence a minibatch can be converted back to a regular 320 in the minibatch. Hence a minibatch can be converted back to a regular
325 dataset or its fields can be looked at individually (and possibly iterated over). 321 dataset or its fields can be looked at individually (and possibly iterated over).
326 322
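
For concreteness, a minimal sketch of the iteration styles described above; `my_dataset` and the field names f1, f2, f3 are hypothetical:

    # first syntax: every field of the dataset comes back in one object
    for i in my_dataset.minibatches(minibatch_size=10):
        print i['f1']                  # fields are looked up by name
    # third syntax: named fields unpack directly into loop variables
    for i1, i2, i3 in my_dataset.minibatches(['f1', 'f2', 'f3'], minibatch_size=10):
        print len(i1)                  # each holds one minibatch of values
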
422 assert i in self.__dict__ # else it means we are trying to access a non-existing property 418 assert i in self.__dict__ # else it means we are trying to access a non-existing property
423 return self.__dict__[i] 419 return self.__dict__[i]
424 420
425 def __getitem__(self,i): 421 def __getitem__(self,i):
426 """ 422 """
427 dataset[i] returns the (i+1)-th example of the dataset. 423 @rtype: Example
428 dataset[i:j] returns the subdataset with examples i,i+1,...,j-1. 424 @returns: single or multiple examples
429 dataset[i:j:s] returns the subdataset with examples i,i+s,i+2*s,... (up to but excluding j). 425
430 dataset[[i1,i2,..,in]] returns the subdataset with examples i1,i2,...,in. 426 @type i: integer or slice or <iterable> of integers
431 dataset['key'] returns a property associated with the given 'key' string. 427 @param i:
432 If 'key' is a fieldname, then the VStacked field values (iterable over 428 dataset[i] returns the (i+1)-th example of the dataset.
433 field values) for that field is returned. Other keys may be supported 429 dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.
434 by different dataset subclasses. The following key names are encouraged: 430 dataset[i:j:s] returns the subdataset with examples i,i+s,i+2*s,... (up to but excluding j).
435 - 'description': a textual description or name for the dataset 431 dataset[[i1,i2,..,in]] returns the subdataset with examples i1,i2,...,in.
436 - '<fieldname>.type': a type name or value for a given <fieldname> 432
437 433 @note:
438 Note that some stream datasets may be unable to implement random access, i.e. 434 Some stream datasets may be unable to implement random access, i.e.
439 arbitrary slicing/indexing 435 arbitrary slicing/indexing because they can only iterate through
440 because they can only iterate through examples one or a minibatch at a time 436 examples one or a minibatch at a time and do not actually store or keep
441 and do not actually store or keep past (or future) examples. 437 past (or future) examples.
442 438
443 The default implementation of getitem uses the minibatches iterator 439 The default implementation of getitem uses the minibatches iterator
444 to obtain one example, one slice, or a list of examples. It may not 440 to obtain one example, one slice, or a list of examples. It may not
445 always be the most efficient way to obtain the result, especially if 441 always be the most efficient way to obtain the result, especially if
446 the data are actually stored in a memory array. 442 the data are actually stored in a memory array.
447 """ 443 """
448 # check for an index 444
449 if type(i) is int: 445 if type(i) is int:
450 return DataSet.MinibatchToSingleExampleIterator( 446 #TODO: consider asserting that i >= 0
451 self.minibatches(minibatch_size=1,n_batches=1,offset=i)).next() 447 i_batch = self.minibatches_nowrap(self.fieldNames(),
452 rows=None 448 minibatch_size=1, n_batches=1, offset=i % len(self))
453 # or a slice 449 return DataSet.MinibatchToSingleExampleIterator(i_batch).next()
450
451 #if i is a contiguous slice
452 if type(i) is slice and (i.step in (None, 1)):
453 offset = 0 if i.start is None else i.start
454 upper_bound = len(self) if i.stop is None else i.stop
455 return MinibatchDataSet(self.minibatches_nowrap(self.fieldNames(),
456 minibatch_size=upper_bound - offset,
457 n_batches=1,
458 offset=offset).next())
459
460 # if slice has a step param, convert it to list and handle it with the
461 # list code
454 if type(i) is slice: 462 if type(i) is slice:
455 #print 'i=',i 463 offset = 0 if i.start is None else i.start
456 if not i.start: i=slice(0,i.stop,i.step) 464 upper_bound = len(self) if i.stop is None else i.stop
457 if not i.stop: i=slice(i.start,len(self),i.step) 465 i = list(range(offset, upper_bound, i.step))
458 if not i.step: i=slice(i.start,i.stop,1) 466
459 if i.step is 1: 467 # handle tuples, arrays, lists
460 return self.minibatches(minibatch_size=i.stop-i.start,n_batches=1,offset=i.start).next().examples() 468 if hasattr(i, '__getitem__'):
461 rows = range(i.start,i.stop,i.step) 469 for idx in i:
462 # or a list of indices 470 #dis-allow nested slices
463 elif type(i) is list: 471 if not isinstance(idx, int):
464 rows = i 472 raise TypeError(idx)
465 if rows is not None: 473 # fetch each requested example via a size-1 minibatch
466 examples = [self[row] for row in rows] 474 examples = [self.minibatches_nowrap(self.fieldNames(),
467 fields_values = zip(*examples) 475 minibatch_size=1, n_batches=1, offset=ii%len(self)).next()
468 return MinibatchDataSet( 476 for ii in i]
469 Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values) 477 # re-index the fields in each example by field instead of by example
470 for fieldname,field_values 478 field_values = [[] for blah in self.fieldNames()]
471 in zip(self.fieldNames(),fields_values)]), 479 for e in examples:
472 self.valuesVStack,self.valuesHStack) 480 for f,v in zip(field_values, e):
481 f.append(v)
482 #build them into a LookupList (a.k.a. Example)
483 zz = zip(self.fieldNames(),field_values)
484 vst = [self.valuesVStack(fieldname,field_values) for fieldname,field_values in zz]
485 example = Example(self.fieldNames(), vst)
486 return MinibatchDataSet(example, self.valuesVStack, self.valuesHStack)
473 raise TypeError(i, type(i)) 487 raise TypeError(i, type(i))
474 488
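
A quick sketch of the indexing forms handled above (`ds` is a hypothetical dataset with at least ten examples):

    ex = ds[3]              # int: a single Example (the 4th example)
    sub = ds[2:8]           # contiguous slice: examples 2..7 as a sub-dataset
    strided = ds[2:8:2]     # stepped slice: converted to the list [2, 4, 6]
    picked = ds[[1, 5, 9]]  # explicit list of example indices
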
475 def valuesHStack(self,fieldnames,fieldvalues): 489 def valuesHStack(self,fieldnames,fieldvalues):
476 """ 490 """
477 Return a value that corresponds to concatenating (horizontally) several field values. 491 Return a value that corresponds to concatenating (horizontally) several field values.
491 if all_numpy: 505 if all_numpy:
492 return numpy.hstack(fieldvalues) 506 return numpy.hstack(fieldvalues)
493 # the default implementation of horizontal stacking is to put values in a list 507 # the default implementation of horizontal stacking is to put values in a list
494 return fieldvalues 508 return fieldvalues
495 509
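
A sketch of the two outcomes of the default valuesHStack (values are illustrative):

    import numpy
    a = numpy.ones((5, 2))
    b = numpy.zeros((5, 3))
    numpy.hstack([a, b]).shape   # (5, 5): all-numpy values widen the matrix
    # any non-numpy value makes the method return the plain list of values
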
496
497 def valuesVStack(self,fieldname,values): 510 def valuesVStack(self,fieldname,values):
498 """ 511 """
499 Return a value that corresponds to concatenating (vertically) several values of the 512 @param fieldname: the name of the field from which the values were taken
500 same field. This can be important to build a minibatch out of individual examples. This 513 @type fieldname: any type
501 is likely to involve a copy of the original values. When the values are numpy arrays, the 514
502 result should be numpy.vstack(values). 515 @param values: the values of that field, taken from several examples or minibatches
503 The default is to use numpy.vstack for numpy.ndarray values, and a list 516 @type values: list of minibatches (returned by minibatches_nowrap)
504 pointing to the original values for other data types. 517
505 """ 518 @return: the concatenation (stacking) of the values
506 all_numpy=True 519 @rtype: something suitable as a minibatch field
507 for value in values: 520 """
508 if not type(value) is numpy.ndarray: 521 rval = []
509 all_numpy=False 522 for v in values:
510 if all_numpy: 523 rval.extend(v)
511 return numpy.vstack(values) 524 return rval
512 # the default implementation of vertical stacking is to put values in a list
513 return values
514 525
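
The new default simply flattens the per-minibatch values into one list; a sketch with illustrative values:

    values = [[1, 2], [3, 4, 5]]   # one field's values from two minibatches
    rval = []
    for v in values:
        rval.extend(v)
    assert rval == [1, 2, 3, 4, 5]

Note that, unlike the deleted numpy.vstack branch, extending a list with a 2-D array yields a list of 1-D row arrays rather than a stacked array.
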
515 def __or__(self,other): 526 def __or__(self,other):
516 """ 527 """
517 dataset1 | dataset2 returns a dataset whose list of fields is the concatenation of the list of 528 dataset1 | dataset2 returns a dataset whose list of fields is the concatenation of the list of
518 fields of the argument datasets. This only works if they all have the same length. 529 fields of the argument datasets. This only works if they all have the same length.
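
A hedged usage sketch (ds1 and ds2 are hypothetical equal-length datasets with disjoint field names):

    combined = ds1 | ds2
    # combined.fieldNames() is the concatenation of both field-name lists
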
584 return FieldsSubsetIterator(self) 595 return FieldsSubsetIterator(self)
585 596
586 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): 597 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
587 assert self.hasFields(*fieldnames) 598 assert self.hasFields(*fieldnames)
588 return self.src.minibatches_nowrap(fieldnames,minibatch_size,n_batches,offset) 599 return self.src.minibatches_nowrap(fieldnames,minibatch_size,n_batches,offset)
589 def __getitem__(self,i): 600 def dontuse__getitem__(self,i):
590 return FieldsSubsetDataSet(self.src[i],self.fieldnames) 601 return FieldsSubsetDataSet(self.src[i],self.fieldnames)
591 602
592 603
593 class DataSetFields(LookupList): 604 class DataSetFields(Example):
594 """ 605 """
595 Although a L{DataSet} iterates over examples (like rows of a matrix), an associated 606 Although a L{DataSet} iterates over examples (like rows of a matrix), an associated
596 DataSetFields iterates over fields (like columns of a matrix), and can be understood 607 DataSetFields iterates over fields (like columns of a matrix), and can be understood
597 as a transpose of the associated dataset. 608 as a transpose of the associated dataset.
598 609
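
A sketch of the transpose relationship (`ds` is hypothetical):

    fields = DataSetFields(ds, ds.fieldNames())   # column-wise view
    ds_again = fields.examples()                  # back to the row-wise view
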
626 dataset = FieldsSubsetDataSet(dataset,fieldnames) 637 dataset = FieldsSubsetDataSet(dataset,fieldnames)
627 assert dataset.hasFields(*fieldnames) 638 assert dataset.hasFields(*fieldnames)
628 self.dataset=dataset 639 self.dataset=dataset
629 640
630 if isinstance(dataset,MinibatchDataSet): 641 if isinstance(dataset,MinibatchDataSet):
631 LookupList.__init__(self,fieldnames,list(dataset._fields)) 642 Example.__init__(self,fieldnames,list(dataset._fields))
632 elif isinstance(original_dataset,MinibatchDataSet): 643 elif isinstance(original_dataset,MinibatchDataSet):
633 LookupList.__init__(self,fieldnames, 644 Example.__init__(self,fieldnames,
634 [original_dataset._fields[field] 645 [original_dataset._fields[field]
635 for field in fieldnames]) 646 for field in fieldnames])
636 else: 647 else:
637 minibatch_iterator = dataset.minibatches(fieldnames, 648 minibatch_iterator = dataset.minibatches(fieldnames,
638 minibatch_size=len(dataset), 649 minibatch_size=len(dataset),
639 n_batches=1) 650 n_batches=1)
640 minibatch=minibatch_iterator.next() 651 minibatch=minibatch_iterator.next()
641 LookupList.__init__(self,fieldnames,minibatch) 652 Example.__init__(self,fieldnames,minibatch)
642 653
643 def examples(self): 654 def examples(self):
644 return self.dataset 655 return self.dataset
645 656
646 def __or__(self,other): 657 def __or__(self,other):
658 return (self.examples() | other.examples()).fields() 669 return (self.examples() | other.examples()).fields()
659 670
660 671
661 class MinibatchDataSet(DataSet): 672 class MinibatchDataSet(DataSet):
662 """ 673 """
663 Turn a L{LookupList} of same-length (iterable) fields into an example-iterable dataset. 674 Turn an L{Example} of same-length (iterable) fields into an example-iterable dataset.
664 Each element of the lookup-list should be an iterable and sliceable, all of the same length. 675 Each element of the lookup-list should be an iterable and sliceable, all of the same length.
665 """ 676 """
666 def __init__(self,fields_lookuplist,values_vstack=DataSet().valuesVStack, 677 def __init__(self,fields_lookuplist,values_vstack=DataSet().valuesVStack,
667 values_hstack=DataSet().valuesHStack): 678 values_hstack=DataSet().valuesHStack):
668 """ 679 """
678 if self.length != len(field) : 689 if self.length != len(field) :
679 print 'self.length = ',self.length 690 print 'self.length = ',self.length
680 print 'len(field) = ', len(field) 691 print 'len(field) = ', len(field)
681 print 'self._fields.keys() = ', self._fields.keys() 692 print 'self._fields.keys() = ', self._fields.keys()
682 print 'field=',field 693 print 'field=',field
694 print 'fields_lookuplist=', fields_lookuplist
683 assert self.length==len(field) 695 assert self.length==len(field)
684 self.values_vstack=values_vstack 696 self.valuesVStack=values_vstack
685 self.values_hstack=values_hstack 697 self.valuesHStack=values_hstack
686 698
687 def __len__(self): 699 def __len__(self):
688 return self.length 700 return self.length
689 701
690 def __getitem__(self,i): 702 def dontuse__getitem__(self,i):
691 if type(i) in (slice,list): 703 if type(i) in (slice,list):
692 return DataSetFields(MinibatchDataSet( 704 return DataSetFields(MinibatchDataSet(
693 Example(self._fields.keys(),[field[i] for field in self._fields])),self.fieldNames()) 705 Example(self._fields.keys(),[field[i] for field in self._fields])),self.fieldNames())
694 if type(i) is int: 706 if type(i) is int:
695 return Example(self._fields.keys(),[field[i] for field in self._fields]) 707 return Example(self._fields.keys(),[field[i] for field in self._fields])
715 if fieldnames is None: fieldnames = ds._fields.keys() 727 if fieldnames is None: fieldnames = ds._fields.keys()
716 self.fieldnames = fieldnames 728 self.fieldnames = fieldnames
717 729
718 self.ds=ds 730 self.ds=ds
719 self.next_example=offset 731 self.next_example=offset
720 assert minibatch_size > 0 732 assert minibatch_size >= 0
721 if offset+minibatch_size > ds.length: 733 if offset+minibatch_size > ds.length:
722 raise NotImplementedError() 734 raise NotImplementedError()
723 def __iter__(self): 735 def __iter__(self):
724 return self 736 return self
725 def next(self): 737 def next(self):
739 return minibatch 751 return minibatch
740 752
741 # tbm: added fieldnames to handle subset of fieldnames 753 # tbm: added fieldnames to handle subset of fieldnames
742 return Iterator(self,fieldnames) 754 return Iterator(self,fieldnames)
743 755
744 def valuesVStack(self,fieldname,fieldvalues):
745 return self.values_vstack(fieldname,fieldvalues)
746
747 def valuesHStack(self,fieldnames,fieldvalues):
748 return self.values_hstack(fieldnames,fieldvalues)
749
750 class HStackedDataSet(DataSet): 756 class HStackedDataSet(DataSet):
751 """ 757 """
752 A L{DataSet} that wraps several datasets and shows a view that includes all their fields, 758 A L{DataSet} that wraps several datasets and shows a view that includes all their fields,
753 i.e. whose list of fields is the concatenation of their lists of fields. 759 i.e. whose list of fields is the concatenation of their lists of fields.
754 760
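
A hedged usage sketch; the constructor is assumed to take the list of datasets to stack (its definition falls outside this hunk):

    hs = HStackedDataSet([inputs_ds, targets_ds])   # hypothetical datasets
    # hs.fieldNames() spans both datasets' fields; lengths must match
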
808 self.iterators=iterators 814 self.iterators=iterators
809 def __iter__(self): 815 def __iter__(self):
810 return self 816 return self
811 def next(self): 817 def next(self):
812 # concatenate all the fields of the minibatches 818 # concatenate all the fields of the minibatches
813 l=LookupList() 819 l=Example()
814 for iter in self.iterators: 820 for iter in self.iterators:
815 l.append_lookuplist(iter.next()) 821 l.append_lookuplist(iter.next())
816 return l 822 return l
817 823
818 assert self.hasFields(*fieldnames) 824 assert self.hasFields(*fieldnames)
832 datasets=self.datasets 838 datasets=self.datasets
833 iterators=[dataset.minibatches(None,minibatch_size,n_batches,offset) for dataset in datasets] 839 iterators=[dataset.minibatches(None,minibatch_size,n_batches,offset) for dataset in datasets]
834 return HStackedIterator(self,iterators) 840 return HStackedIterator(self,iterators)
835 841
836 842
837 def valuesVStack(self,fieldname,fieldvalues): 843 def untested_valuesVStack(self,fieldname,fieldvalues):
838 return self.datasets[self.fieldname2dataset[fieldname]].valuesVStack(fieldname,fieldvalues) 844 return self.datasets[self.fieldname2dataset[fieldname]].valuesVStack(fieldname,fieldvalues)
839 845
840 def valuesHStack(self,fieldnames,fieldvalues): 846 def untested_valuesHStack(self,fieldnames,fieldvalues):
841 """ 847 """
842 We will use the sub-dataset associated with the first fieldname in the fieldnames list 848 We will use the sub-dataset associated with the first fieldname in the fieldnames list
843 to do the work, hoping that it can cope with the other values (i.e. won't care 849 to do the work, hoping that it can cope with the other values (i.e. won't care
844 about the incompatible fieldnames). Hence this heuristic will always work if 850 about the incompatible fieldnames). Hence this heuristic will always work if
845 all the fieldnames are of the same sub-dataset. 851 all the fieldnames are of the same sub-dataset.
959 Virtual super-class of datasets whose field values are numpy arrays, 965 Virtual super-class of datasets whose field values are numpy arrays,
960 thus defining valuesHStack and valuesVStack for sub-classes. 966 thus defining valuesHStack and valuesVStack for sub-classes.
961 """ 967 """
962 def __init__(self,description=None,field_types=None): 968 def __init__(self,description=None,field_types=None):
963 DataSet.__init__(self,description,field_types) 969 DataSet.__init__(self,description,field_types)
964 def valuesHStack(self,fieldnames,fieldvalues): 970 def untested_valuesHStack(self,fieldnames,fieldvalues):
965 """Concatenate field values horizontally, e.g. two vectors 971 """Concatenate field values horizontally, e.g. two vectors
966 become a longer vector, two matrices become a wider matrix, etc.""" 972 become a longer vector, two matrices become a wider matrix, etc."""
967 return numpy.hstack(fieldvalues) 973 return numpy.hstack(fieldvalues)
968 def valuesVStack(self,fieldname,values): 974 def untested_valuesVStack(self,fieldname,values):
969 """Concatenate field values vertically, e.g. two vectors 975 """Concatenate field values vertically, e.g. two vectors
970 become a two-row matrix, two matrices become a longer matrix, etc.""" 976 become a two-row matrix, two matrices become a longer matrix, etc."""
971 return numpy.vstack(values) 977 return numpy.vstack(values)
972 978
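
The numpy semantics these helpers rely on, shown concretely:

    import numpy
    a = numpy.ones((3, 2))
    numpy.hstack([a, numpy.zeros((3, 4))]).shape   # (3, 6): wider matrix
    numpy.vstack([a, a]).shape                     # (6, 2): longer matrix
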
973 class ArrayDataSet(ArrayFieldsDataSet): 979 class ArrayDataSet(ArrayFieldsDataSet):
1017 return self.fields_columns.keys() 1023 return self.fields_columns.keys()
1018 1024
1019 def __len__(self): 1025 def __len__(self):
1020 return len(self.data) 1026 return len(self.data)
1021 1027
1022 def __getitem__(self,key): 1028 def dontuse__getitem__(self,key):
1023 """More efficient implementation than the default __getitem__""" 1029 """More efficient implementation than the default __getitem__"""
1024 fieldnames=self.fields_columns.keys() 1030 fieldnames=self.fields_columns.keys()
1025 values=self.fields_columns.values() 1031 values=self.fields_columns.values()
1026 if type(key) is int: 1032 if type(key) is int:
1027 return Example(fieldnames, 1033 return Example(fieldnames,
1049 return self.data[:,self.fields_columns[key]] 1055 return self.data[:,self.fields_columns[key]]
1050 # else we are trying to access a property of the dataset 1056 # else we are trying to access a property of the dataset
1051 assert key in self.__dict__ # else it means we are trying to access a non-existing property 1057 assert key in self.__dict__ # else it means we are trying to access a non-existing property
1052 return self.__dict__[key] 1058 return self.__dict__[key]
1053 1059
1054 def __iter__(self): 1060 def dontuse__iter__(self):
1055 class ArrayDataSetIteratorIter(object): 1061 class ArrayDataSetIteratorIter(object):
1056 def __init__(self,dataset,fieldnames): 1062 def __init__(self,dataset,fieldnames):
1057 if fieldnames is None: fieldnames = dataset.fieldNames() 1063 if fieldnames is None: fieldnames = dataset.fieldNames()
1058 # store the resulting minibatch in a lookup-list of values 1064 # store the resulting minibatch in a lookup-list of values
1059 self.minibatch = LookupList(fieldnames,[0]*len(fieldnames)) 1065 self.minibatch = Example(fieldnames,[0]*len(fieldnames))
1060 self.dataset=dataset 1066 self.dataset=dataset
1061 self.current=0 1067 self.current=0
1062 self.columns = [self.dataset.fields_columns[f] 1068 self.columns = [self.dataset.fields_columns[f]
1063 for f in self.minibatch._names] 1069 for f in self.minibatch._names]
1064 self.l = self.dataset.data.shape[0] 1070 self.l = self.dataset.data.shape[0]
1076 return self.minibatch 1082 return self.minibatch
1077 1083
1078 return ArrayDataSetIteratorIter(self,self.fieldNames()) 1084 return ArrayDataSetIteratorIter(self,self.fieldNames())
1079 1085
1080 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): 1086 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
1081 class ArrayDataSetIterator(object): 1087 fieldnames = self.fieldNames() if fieldnames is None else fieldnames
1082 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): 1088 cursor = Example(fieldnames,[0]*len(fieldnames))
1083 if fieldnames is None: fieldnames = dataset.fieldNames() 1089 for n in xrange(n_batches):
1084 # store the resulting minibatch in a lookup-list of values 1090 if offset == len(self):
1085 self.minibatch = LookupList(fieldnames,[0]*len(fieldnames)) 1091 break
1086 self.dataset=dataset 1092 sub_data = self.data[offset : offset+minibatch_size]
1087 self.minibatch_size=minibatch_size 1093 offset += len(sub_data) #can be less than minibatch_size at end
1088 assert offset>=0 and offset<len(dataset.data) 1094 cursor._values = [sub_data[:,self.fields_columns[f]] for f in cursor._names]
1089 assert offset+minibatch_size<=len(dataset.data) 1095 yield cursor
1090 self.current=offset 1096
1091 def __iter__(self): 1097 #return ArrayDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset)
1092 return self
1093 def next(self):
1094 #@todo: we suppose that MinibatchWrapAroundIterator stop the iterator
1095 sub_data = self.dataset.data[self.current:self.current+self.minibatch_size]
1096 self.minibatch._values = [sub_data[:,self.dataset.fields_columns[f]] for f in self.minibatch._names]
1097 self.current+=self.minibatch_size
1098 return self.minibatch
1099
1100 return ArrayDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset)
1101 1098
1102 1099
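
Driving the generator form of minibatches_nowrap directly might look like this (`ads` and `train_step` are hypothetical). Note that the same `cursor` Example is re-yielded on every iteration, so callers must copy any values they want to keep:

    for mb in ads.minibatches_nowrap(['x', 'y'], minibatch_size=8,
                                     n_batches=4, offset=0):
        train_step(mb['x'], mb['y'])
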
1103 class CachedDataSet(DataSet): 1100 class CachedDataSet(DataSet):
1104 """ 1101 """
1105 Wrap a L{DataSet} whose values are computationally expensive to obtain 1102 Wrap a L{DataSet} whose values are computationally expensive to obtain
1160 if self.all_fields: 1157 if self.all_fields:
1161 return all_fields_minibatch 1158 return all_fields_minibatch
1162 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames]) 1159 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames])
1163 return CacheIterator(self) 1160 return CacheIterator(self)
1164 1161
1165 def __getitem__(self,i): 1162 def dontuse__getitem__(self,i):
1166 if type(i)==int and len(self.cached_examples)>i: 1163 if type(i)==int and len(self.cached_examples)>i:
1167 return self.cached_examples[i] 1164 return self.cached_examples[i]
1168 else: 1165 else:
1169 return self.source_dataset[i] 1166 return self.source_dataset[i]
1170 1167
1173 def __init__(self,dataset): 1170 def __init__(self,dataset):
1174 self.dataset=dataset 1171 self.dataset=dataset
1175 self.l = len(dataset) 1172 self.l = len(dataset)
1176 self.current = 0 1173 self.current = 0
1177 self.fieldnames = self.dataset.fieldNames() 1174 self.fieldnames = self.dataset.fieldNames()
1178 self.example = LookupList(self.fieldnames,[0]*len(self.fieldnames)) 1175 self.example = Example(self.fieldnames,[0]*len(self.fieldnames))
1179 def __iter__(self): return self 1176 def __iter__(self): return self
1180 def next(self): 1177 def next(self):
1181 if self.current>=self.l: 1178 if self.current>=self.l:
1182 raise StopIteration 1179 raise StopIteration
1183 cache_len = len(self.dataset.cached_examples) 1180 cache_len = len(self.dataset.cached_examples)
1190 return self.example 1187 return self.example
1191 1188
1192 return CacheIteratorIter(self) 1189 return CacheIteratorIter(self)
1193 1190
1194 class ApplyFunctionDataSet(DataSet): 1191 class ApplyFunctionDataSet(DataSet):
1195 """ 1192 """
1196 A L{DataSet} that contains as fields the results of applying a 1193 A L{DataSet} that contains as fields the results of applying a
1197 given function example-wise or minibatch-wise to all the fields of 1194 given function example-wise or minibatch-wise to all the fields of
1198 an input dataset. The output of the function should be an iterable 1195 an input dataset. The output of the function should be an iterable
1199 (e.g. a list or a LookupList) over the resulting values. 1196 (e.g. a list or an Example) over the resulting values.
1200 1197
1201 The function takes as input the fields of the dataset, not the examples. 1198 The function takes as input the fields of the dataset, not the examples.
1202 1199
1203 In minibatch mode, the function is expected to work on minibatches 1200 In minibatch mode, the function is expected to work on minibatches
1204 (takes a minibatch in input and returns a minibatch in output). More 1201 (takes a minibatch in input and returns a minibatch in output). More
1205 precisely, it means that each element of the input or output list 1202 precisely, it means that each element of the input or output list
1206 should be iterable and indexable over the individual example values 1203 should be iterable and indexable over the individual example values
1207 (typically these elements will be numpy arrays). All of the elements 1204 (typically these elements will be numpy arrays). All of the elements
1208 in the input and output lists should have the same length, which is 1205 in the input and output lists should have the same length, which is
1209 the length of the minibatch. 1206 the length of the minibatch.
1210 1207
1211 The function is applied each time an example or a minibatch is accessed. 1208 The function is applied each time an example or a minibatch is accessed.
1212 To avoid re-doing computation, wrap this dataset inside a CachedDataSet. 1209 To avoid re-doing computation, wrap this dataset inside a CachedDataSet.
1213 1210
1214 If the values_{h,v}stack functions are not provided, then 1211 If the values_{h,v}stack functions are not provided, then
1215 the input_dataset.values{H,V}Stack functions are used by default. 1212 the input_dataset.values{H,V}Stack functions are used by default.
1216 """ 1213 """
1217 def __init__(self,input_dataset,function,output_names,minibatch_mode=True, 1214 def __init__(self,input_dataset,function,output_names,minibatch_mode=True,
1218 values_hstack=None,values_vstack=None, 1215 values_hstack=None,values_vstack=None,
1219 description=None,fieldtypes=None): 1216 description=None,fieldtypes=None):
1220 """ 1217 """
1221 Constructor takes an input dataset that has as many fields as the function 1218 Constructor takes an input dataset that has as many fields as the function
1222 expects as inputs. The resulting dataset has as many fields as the function 1219 expects as inputs. The resulting dataset has as many fields as the function
1223 produces as outputs, which should correspond to the number of output names 1220 produces as outputs, which should correspond to the number of output names
1224 (provided in a list). 1221 (provided in a list).
1225 1222
1226 Note that the expected semantics of the function differs in minibatch mode 1223 Note that the expected semantics of the function differs in minibatch mode
1227 (it takes minibatches of inputs and produces minibatches of outputs, as 1224 (it takes minibatches of inputs and produces minibatches of outputs, as
1228 documented in the class comment). 1225 documented in the class comment).
1229 1226
1230 TBM: are fieldtypes the old field types (from input_dataset) or the new ones 1227 TBM: are fieldtypes the old field types (from input_dataset) or the new ones
1231 (for the new dataset created)? 1228 (for the new dataset created)?
1232 """ 1229 """
1233 self.input_dataset=input_dataset 1230 self.input_dataset=input_dataset
1234 self.function=function 1231 self.function=function
1235 self.output_names=output_names 1232 self.output_names=output_names
1236 self.minibatch_mode=minibatch_mode 1233 self.minibatch_mode=minibatch_mode
1237 DataSet.__init__(self,description,fieldtypes) 1234 DataSet.__init__(self,description,fieldtypes)
1238 self.valuesHStack = values_hstack if values_hstack else input_dataset.valuesHStack 1235 self.valuesHStack = values_hstack if values_hstack else input_dataset.valuesHStack
1239 self.valuesVStack = values_vstack if values_vstack else input_dataset.valuesVStack 1236 self.valuesVStack = values_vstack if values_vstack else input_dataset.valuesVStack
1240 1237
1241 def __len__(self): 1238 def __len__(self):
1242 return len(self.input_dataset) 1239 return len(self.input_dataset)
1243 1240
1244 def fieldNames(self): 1241 def fieldNames(self):
1245 return self.output_names 1242 return self.output_names
1246 1243
1247 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): 1244 def minibatches_nowrap(self, fieldnames, *args, **kwargs):
1248 class ApplyFunctionIterator(object): 1245 for fields in self.input_dataset.minibatches_nowrap(fieldnames, *args, **kwargs):
1249 def __init__(self,output_dataset): 1246
1250 self.input_dataset=output_dataset.input_dataset 1247 #function_inputs = self.input_iterator.next()
1251 self.output_dataset=output_dataset 1248 if self.minibatch_mode:
1252 self.input_iterator=self.input_dataset.minibatches(minibatch_size=minibatch_size, 1249 function_outputs = self.function(*fields)
1253 n_batches=n_batches,offset=offset).__iter__() 1250 else:
1254 1251 input_examples = zip(*fields)
1255 def __iter__(self): return self 1252 output_examples = [self.function(*input_example)
1256 1253 for input_example in input_examples]
1257 def next(self): 1254 function_outputs = [self.valuesVStack(name,values)
1258 function_inputs = self.input_iterator.next() 1255 for name,values in zip(self.output_names,
1259 all_output_names = self.output_dataset.output_names 1256 zip(*output_examples))]
1260 if self.output_dataset.minibatch_mode: 1257 all_outputs = Example(self.output_names, function_outputs)
1261 function_outputs = self.output_dataset.function(*function_inputs) 1258 print fields
1262 else: 1259 print all_outputs
1263 input_examples = zip(*function_inputs) 1260 print '--------'
1264 output_examples = [self.output_dataset.function(*input_example) 1261 if fieldnames==self.output_names:
1265 for input_example in input_examples] 1262 yield all_outputs
1266 function_outputs = [self.output_dataset.valuesVStack(name,values) 1263 else:
1267 for name,values in zip(all_output_names, 1264 yield Example(fieldnames,[all_outputs[name] for name in fieldnames])
1268 zip(*output_examples))] 1265
1269 all_outputs = Example(all_output_names,function_outputs) 1266 def untested__iter__(self): # only implemented for increased efficiency
1270 if fieldnames==all_output_names: 1267 class ApplyFunctionSingleExampleIterator(object):
1271 return all_outputs 1268 def __init__(self,output_dataset):
1272 return Example(fieldnames,[all_outputs[name] for name in fieldnames]) 1269 self.current=0
1273 1270 self.output_dataset=output_dataset
1274 1271 self.input_iterator=output_dataset.input_dataset.__iter__()
1275 return ApplyFunctionIterator(self) 1272 def __iter__(self): return self
1276 1273 def next(self):
1277 def __iter__(self): # only implemented for increased efficiency 1274 if self.output_dataset.minibatch_mode:
1278 class ApplyFunctionSingleExampleIterator(object): 1275 function_inputs = [[input] for input in self.input_iterator.next()]
1279 def __init__(self,output_dataset): 1276 outputs = self.output_dataset.function(*function_inputs)
1280 self.current=0 1277 assert all([hasattr(output,'__iter__') for output in outputs])
1281 self.output_dataset=output_dataset 1278 function_outputs = [output[0] for output in outputs]
1282 self.input_iterator=output_dataset.input_dataset.__iter__() 1279 else:
1283 def __iter__(self): return self 1280 function_inputs = self.input_iterator.next()
1284 def next(self): 1281 function_outputs = self.output_dataset.function(*function_inputs)
1285 if self.output_dataset.minibatch_mode: 1282 return Example(self.output_dataset.output_names,function_outputs)
1286 function_inputs = [[input] for input in self.input_iterator.next()] 1283 return ApplyFunctionSingleExampleIterator(self)
1287 outputs = self.output_dataset.function(*function_inputs) 1284
1288 assert all([hasattr(output,'__iter__') for output in outputs])
1289 function_outputs = [output[0] for output in outputs]
1290 else:
1291 function_inputs = self.input_iterator.next()
1292 function_outputs = self.output_dataset.function(*function_inputs)
1293 return Example(self.output_dataset.output_names,function_outputs)
1294 return ApplyFunctionSingleExampleIterator(self)
1295
1296 1285
1297 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): 1286 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None):
1298 """ 1287 """
1299 Wraps an arbitrary L{DataSet} into one for supervised learning tasks 1288 Wraps an arbitrary L{DataSet} into one for supervised learning tasks
1300 by forcing the user to define a set of fields as the 'input' field 1289 by forcing the user to define a set of fields as the 'input' field