Mercurial > pylearn
comparison dataset.py @ 73:69f97aad3faf
Coded untested ApplyFunctionDataSet and CacheDataSet
author | bengioy@bengiomac.local |
---|---|
date | Sat, 03 May 2008 14:29:56 -0400 |
parents | 2b6656b2ef52 |
children | b4159cbdc06b |
comparison
equal
deleted
inserted
replaced
72:2b6656b2ef52 | 73:69f97aad3faf |
---|---|
225 for name in self.fieldnames]) | 225 for name in self.fieldnames]) |
226 self.next_row=upper | 226 self.next_row=upper |
227 self.n_batches_done+=1 | 227 self.n_batches_done+=1 |
228 if upper >= self.L and self.n_batches: | 228 if upper >= self.L and self.n_batches: |
229 self.next_row -= self.L | 229 self.next_row -= self.L |
230 return minibatch | 230 return DataSetFields(MinibatchDataSet(minibatch,self.dataset.valuesVStack, |
231 self.dataset.valuesHStack), | |
232 minibatch.keys()) | |
231 | 233 |
232 | 234 |
233 minibatches_fieldnames = None | 235 minibatches_fieldnames = None |
234 minibatches_minibatch_size = 1 | 236 minibatches_minibatch_size = 1 |
235 minibatches_n_batches = None | 237 minibatches_n_batches = None |
390 examples = [self[row] for row in rows] | 392 examples = [self[row] for row in rows] |
391 fields_values = zip(*examples) | 393 fields_values = zip(*examples) |
392 return MinibatchDataSet( | 394 return MinibatchDataSet( |
393 Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values) | 395 Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values) |
394 for fieldname,field_values | 396 for fieldname,field_values |
395 in zip(self.fieldNames(),fields_values)])) | 397 in zip(self.fieldNames(),fields_values)]), |
398 self.valuesVStack,self.valuesHStack) | |
396 # else check for a fieldname | 399 # else check for a fieldname |
397 if self.hasFields(i): | 400 if self.hasFields(i): |
398 return self.minibatches(fieldnames=[i],minibatch_size=len(self),n_batches=1,offset=0).next()[0] | 401 return self.minibatches(fieldnames=[i],minibatch_size=len(self),n_batches=1,offset=0).next()[0] |
399 # else we are trying to access a property of the dataset | 402 # else we are trying to access a property of the dataset |
400 assert i in self.__dict__ # else it means we are trying to access a non-existing property | 403 assert i in self.__dict__ # else it means we are trying to access a non-existing property |
607 | 610 |
608 def __len__(self): | 611 def __len__(self): |
609 return self.length | 612 return self.length |
610 | 613 |
611 def __getitem__(self,i): | 614 def __getitem__(self,i): |
612 if type(i) in (int,slice,list): | 615 if type(i) is int: |
613 return DataSetFields(MinibatchDataSet( | 616 return Example(self._fields.keys(),[field[i] for field in self._fields]) |
614 Example(self._fields.keys(),[field[i] for field in self._fields])),self._fields) | 617 if type(i) in (slice,list): |
618 return MinibatchDataSet(Example(self._fields.keys(), | |
619 [field[i] for field in self._fields]), | |
620 self.valuesVStack,self.valuesHStack) | |
615 if self.hasFields(i): | 621 if self.hasFields(i): |
616 return self._fields[i] | 622 return self._fields[i] |
617 assert i in self.__dict__ # else it means we are trying to access a non-existing property | 623 assert i in self.__dict__ # else it means we are trying to access a non-existing property |
618 return self.__dict__[i] | 624 return self.__dict__[i] |
619 | 625 |
641 assert upper<=self.ds.length | 647 assert upper<=self.ds.length |
642 minibatch = Example(self.ds._fields.keys(), | 648 minibatch = Example(self.ds._fields.keys(), |
643 [field[self.next_example:upper] | 649 [field[self.next_example:upper] |
644 for field in self.ds._fields]) | 650 for field in self.ds._fields]) |
645 self.next_example+=minibatch_size | 651 self.next_example+=minibatch_size |
646 return DataSetFields(MinibatchDataSet(minibatch),*fieldnames) | 652 return minibatch |
647 | 653 |
648 return Iterator(self) | 654 return Iterator(self) |
649 | 655 |
650 def valuesVStack(self,fieldname,fieldvalues): | 656 def valuesVStack(self,fieldname,fieldvalues): |
651 return self.values_vstack(fieldname,fieldvalues) | 657 return self.values_vstack(fieldname,fieldvalues) |
714 self.iterators=iterators | 720 self.iterators=iterators |
715 def __iter__(self): | 721 def __iter__(self): |
716 return self | 722 return self |
717 def next(self): | 723 def next(self): |
718 # concatenate all the fields of the minibatches | 724 # concatenate all the fields of the minibatches |
719 minibatch = reduce(LookupList.__add__,[iterator.next() for iterator in self.iterators]) | 725 return reduce(LookupList.__add__,[iterator.next() for iterator in self.iterators]) |
720 # and return a DataSetFields whose dataset is the transpose (=examples()) of this minibatch | |
721 return DataSetFields(MinibatchDataSet(minibatch,self.hsds.valuesVStack, | |
722 self.hsds.valuesHStack), | |
723 fieldnames if fieldnames else hsds.fieldNames()) | |
724 | 726 |
725 assert self.hasfields(fieldnames) | 727 assert self.hasfields(fieldnames) |
726 # find out which underlying datasets are necessary to service the required fields | 728 # find out which underlying datasets are necessary to service the required fields |
727 # and construct corresponding minibatch iterators | 729 # and construct corresponding minibatch iterators |
728 if fieldnames: | 730 if fieldnames: |
847 if self.n_left_in_mb: | 849 if self.n_left_in_mb: |
848 extra_mb = [] | 850 extra_mb = [] |
849 while self.n_left_in_mb>0: | 851 while self.n_left_in_mb>0: |
850 self.move_to_next_dataset() | 852 self.move_to_next_dataset() |
851 extra_mb.append(self.next_iterator.next()) | 853 extra_mb.append(self.next_iterator.next()) |
852 examples = Example(names, | 854 mb = Example(fieldnames, |
853 [dataset.valuesVStack(name, | 855 [dataset.valuesVStack(name, |
854 [mb[name]]+[b[name] for b in extra_mb]) | 856 [mb[name]]+[b[name] for b in extra_mb]) |
855 for name in fieldnames]) | 857 for name in fieldnames]) |
856 mb = DataSetFields(MinibatchDataSet(examples),fieldnames) | |
857 | 858 |
858 self.next_row+=minibatch_size | 859 self.next_row+=minibatch_size |
859 self.next_dataset_row+=minibatch_size | 860 self.next_dataset_row+=minibatch_size |
860 if self.next_row+minibatch_size>len(dataset): | 861 if self.next_row+minibatch_size>len(dataset): |
861 self.move_to_next_dataset() | 862 self.move_to_next_dataset() |
924 if type(i) is int: | 925 if type(i) is int: |
925 return Example(fieldnames, | 926 return Example(fieldnames, |
926 [self.data[i,self.fields_columns[f]] for f in fieldnames]) | 927 [self.data[i,self.fields_columns[f]] for f in fieldnames]) |
927 if type(i) in (slice,list): | 928 if type(i) in (slice,list): |
928 return MinibatchDataSet(Example(fieldnames, | 929 return MinibatchDataSet(Example(fieldnames, |
929 [self.data[i,self.fields_columns[f]] for f in fieldnames])) | 930 [self.data[i,self.fields_columns[f]] for f in fieldnames]), |
931 self.valuesVStack,self.valuesHStack) | |
930 # else check for a fieldname | 932 # else check for a fieldname |
931 if self.hasFields(i): | 933 if self.hasFields(i): |
932 return Example([i],[self.data[self.fields_columns[i],:]]) | 934 return Example([i],[self.data[self.fields_columns[i],:]]) |
933 # else we are trying to access a property of the dataset | 935 # else we are trying to access a property of the dataset |
934 assert i in self.__dict__ # else it means we are trying to access a non-existing property | 936 assert i in self.__dict__ # else it means we are trying to access a non-existing property |
965 by caching every example value that has been accessed at least once. | 967 by caching every example value that has been accessed at least once. |
966 | 968 |
967 Optionally, for finite-length dataset, all the values can be computed | 969 Optionally, for finite-length dataset, all the values can be computed |
968 (and cached) upon construction of the CachedDataSet, rather at the | 970 (and cached) upon construction of the CachedDataSet, rather at the |
969 first access. | 971 first access. |
972 | |
973 TODO: add disk-buffering capability, so that when the cache becomes too | |
974 big for memory, we cache things on disk, trying to keep in memory only | |
975 the record most likely to be accessed next. | |
970 """ | 976 """ |
971 | 977 def __init__(self,source_dataset,cache_all_upon_construction=False): |
978 self.source_dataset=source_dataset | |
979 self.cache_all_upon_construction=cache_all_upon_construction | |
980 if cache_all_upon_construction: | |
981 self.cached_examples = zip(*source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next()) | |
982 else: | |
983 self.cached_examples = [] | |
984 | |
985 self.fieldNames = source_dataset.fieldNames | |
986 self.hasFields = source_dataset.hasFields | |
987 self.valuesHStack = source_dataset.valuesHStack | |
988 self.valuesVStack = source_dataset.valuesVStack | |
989 | |
990 def __len__(self): | |
991 return len(self.source_dataset) | |
992 | |
993 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): | |
994 class CacheIterator(object): | |
995 def __init__(self,dataset): | |
996 self.dataset=dataset | |
997 self.current=offset | |
998 def __iter__(self): return self | |
999 def next(self): | |
1000 upper = self.current+minibatch_size | |
1001 cache_len = len(self.dataset.cached_examples) | |
1002 if upper>=cache_len: # whole minibatch is not already in cache | |
1003 # cache everything from current length to upper | |
1004 for example in self.dataset.source_dataset[cache_len:upper]: | |
1005 self.dataset.cached_examples.append(example) | |
1006 all_fields_minibatch = Example(self.dataset.fieldNames(), | |
1007 self.dataset.cached_examples[self.current:self.current+minibatch_size]) | |
1008 if self.dataset.fieldNames()==fieldnames: | |
1009 return all_fields_minibatch | |
1010 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames]) | |
1011 return CacheIterator(self) | |
1012 | |
1013 | |
972 class ApplyFunctionDataSet(DataSet): | 1014 class ApplyFunctionDataSet(DataSet): |
973 """ | 1015 """ |
974 A dataset that contains as fields the results of applying a given function | 1016 A dataset that contains as fields the results of applying a given function |
975 example-wise or minibatch-wise to all the fields of an input dataset. | 1017 example-wise or minibatch-wise to all the fields of an input dataset. |
976 The output of the function should be an iterable (e.g. a list or a LookupList) | 1018 The output of the function should be an iterable (e.g. a list or a LookupList) |
977 over the resulting values. In minibatch mode, the function is expected | 1019 over the resulting values. |
978 to work on minibatches (takes a minibatch in input and returns a minibatch | 1020 |
979 in output). | 1021 In minibatch mode, the function is expected to work on minibatches (takes |
1022 a minibatch in input and returns a minibatch in output). More precisely, | |
1023 it means that each element of the input or output list should be iterable | |
1024 and indexable over the individual example values (typically these | |
1025 elements will be numpy arrays). All of the elements in the input and | |
1026 output lists should have the same length, which is the length of the | |
1027 minibatch. | |
980 | 1028 |
981 The function is applied each time an example or a minibatch is accessed. | 1029 The function is applied each time an example or a minibatch is accessed. |
982 To avoid re-doing computation, wrap this dataset inside a CachedDataSet. | 1030 To avoid re-doing computation, wrap this dataset inside a CachedDataSet. |
1031 | |
1032 If the values_{h,v}stack functions are not provided, then | |
1033 the input_dataset.values{H,V}Stack functions are used by default. | |
983 """ | 1034 """ |
1035 def __init__(self,input_dataset,function,output_names,minibatch_mode=True, | |
1036 values_hstack=None,values_vstack=None, | |
1037 description=None,fieldtypes=None): | |
1038 """ | |
1039 Constructor takes an input dataset that has as many fields as the function | |
1040 expects as inputs. The resulting dataset has as many fields as the function | |
1041 produces as outputs, and that should correspond to the number of output names | |
1042 (provided in a list). | |
1043 | |
1044 Note that the expected semantics of the function differs in minibatch mode | |
1045 (it takes minibatches of inputs and produces minibatches of outputs, as | |
1046 documented in the class comment). | |
1047 """ | |
1048 self.input_dataset=input_dataset | |
1049 self.function=function | |
1050 self.output_names=output_names | |
1051 self.minibatch_mode=minibatch_mode | |
1052 DataSet.__init__(self,description,fieldtypes) | |
1053 self.valuesHStack = values_hstack if values_hstack else input_dataset.valuesHStack | |
1054 self.valuesVStack = values_vstack if values_vstack else input_dataset.valuesVStack | |
1055 | |
1056 def __len__(self): | |
1057 return len(self.input_dataset) | |
1058 | |
1059 def fieldNames(self): | |
1060 return self.output_names | |
1061 | |
1062 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): | |
1063 class ApplyFunctionIterator(object): | |
1064 def __init__(self,output_dataset): | |
1065 self.input_dataset=output_dataset.input_dataset | |
1066 self.output_dataset=output_dataset | |
1067 self.input_iterator=output_dataset.input_dataset.minibatches(minibatch_size=minibatch_size, | |
1068 n_batches=n_batches,offset=offset).__iter__() | |
1069 | |
1070 def __iter__(self): return self | |
1071 | |
1072 def next(self): | |
1073 function_inputs = self.input_iterator.next() | |
1074 all_output_names = self.output_dataset.output_names | |
1075 if self.output_dataset.minibatch_mode: | |
1076 function_outputs = self.output_dataset.function(function_inputs) | |
1077 else: | |
1078 input_examples = zip(*function_inputs) | |
1079 output_examples = [self.output_dataset.function(input_example) | |
1080 for input_example in input_examples] | |
1081 function_outputs = [self.output_dataset.valuesVStack(name,values) | |
1082 for name,values in zip(all_output_names, | |
1083 zip(*output_examples))] | |
1084 all_outputs = Example(all_output_names,function_outputs) | |
1085 if fieldnames==all_output_names: | |
1086 return all_outputs | |
1087 return Example(fieldnames,[all_outputs[name] for name in fieldnames]) | |
1088 | |
1089 return ApplyFunctionIterator(self) | |
1090 | |
1091 def __iter__(self): # only implemented for increased efficiency | |
1092 class ApplyFunctionSingleExampleIterator(object): | |
1093 def __init__(self,output_dataset): | |
1094 self.current=0 | |
1095 self.output_dataset=output_dataset | |
1096 self.input_iterator=output_dataset.input_dataset.__iter__() | |
1097 def __iter__(self): return self | |
1098 def next(self): | |
1099 function_inputs = self.input_iterator.next() | |
1100 if self.output_dataset.minibatch_mode: | |
1101 function_outputs = [output[0] for output in self.output_dataset.function(function_inputs)] | |
1102 else: | |
1103 function_outputs = self.output_dataset.function(function_inputs) | |
1104 return Example(self.output_dataset.output_names,function_outputs) | |
1105 return ApplyFunctionSingleExampleIterator(self) | |
984 | 1106 |
985 | 1107 |
986 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): | 1108 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): |
987 """ | 1109 """ |
988 Wraps an arbitrary DataSet into one for supervised learning tasks by forcing the | 1110 Wraps an arbitrary DataSet into one for supervised learning tasks by forcing the |