comparison dataset.py @ 73:69f97aad3faf

Coded untested ApplyFunctionDataSet and CachedDataSet
author bengioy@bengiomac.local
date Sat, 03 May 2008 14:29:56 -0400
parents 2b6656b2ef52
children b4159cbdc06b
comparison
equal deleted inserted replaced
72:2b6656b2ef52 73:69f97aad3faf
225 for name in self.fieldnames]) 225 for name in self.fieldnames])
226 self.next_row=upper 226 self.next_row=upper
227 self.n_batches_done+=1 227 self.n_batches_done+=1
228 if upper >= self.L and self.n_batches: 228 if upper >= self.L and self.n_batches:
229 self.next_row -= self.L 229 self.next_row -= self.L
230 return minibatch 230 return DataSetFields(MinibatchDataSet(minibatch,self.dataset.valuesVStack,
231 self.dataset.valuesHStack),
232 minibatch.keys()))
231 233
232 234
233 minibatches_fieldnames = None 235 minibatches_fieldnames = None
234 minibatches_minibatch_size = 1 236 minibatches_minibatch_size = 1
235 minibatches_n_batches = None 237 minibatches_n_batches = None
390 examples = [self[row] for row in rows] 392 examples = [self[row] for row in rows]
391 fields_values = zip(*examples) 393 fields_values = zip(*examples)
392 return MinibatchDataSet( 394 return MinibatchDataSet(
393 Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values) 395 Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values)
394 for fieldname,field_values 396 for fieldname,field_values
395 in zip(self.fieldNames(),fields_values)])) 397 in zip(self.fieldNames(),fields_values)]),
398 self.valuesVStack,self.valuesHStack)
396 # else check for a fieldname 399 # else check for a fieldname
397 if self.hasFields(i): 400 if self.hasFields(i):
398 return self.minibatches(fieldnames=[i],minibatch_size=len(self),n_batches=1,offset=0).next()[0] 401 return self.minibatches(fieldnames=[i],minibatch_size=len(self),n_batches=1,offset=0).next()[0]
399 # else we are trying to access a property of the dataset 402 # else we are trying to access a property of the dataset
400 assert i in self.__dict__ # else it means we are trying to access a non-existing property 403 assert i in self.__dict__ # else it means we are trying to access a non-existing property
607 610
def __len__(self):
    """Return the number of examples in this dataset."""
    return self.length
610 613
def __getitem__(self, i):
    """
    Index the dataset.

    An int returns a single Example (row i of every field); a slice or a
    list of indices returns a MinibatchDataSet over those rows; a field
    name returns that field's value; any other key is looked up as a
    dataset property.
    """
    if type(i) is int:
        # single example: row i of every field
        return Example(self._fields.keys(),
                       [field[i] for field in self._fields])
    if type(i) in (slice, list):
        # sub-dataset over the selected rows, keeping the stacking functions
        rows = Example(self._fields.keys(),
                       [field[i] for field in self._fields])
        return MinibatchDataSet(rows, self.valuesVStack, self.valuesHStack)
    # else check for a fieldname
    if self.hasFields(i):
        return self._fields[i]
    # else we are trying to access a property of the dataset
    assert i in self.__dict__ # else it means we are trying to access a non-existing property
    return self.__dict__[i]
619 625
641 assert upper<=self.ds.length 647 assert upper<=self.ds.length
642 minibatch = Example(self.ds._fields.keys(), 648 minibatch = Example(self.ds._fields.keys(),
643 [field[self.next_example:upper] 649 [field[self.next_example:upper]
644 for field in self.ds._fields]) 650 for field in self.ds._fields])
645 self.next_example+=minibatch_size 651 self.next_example+=minibatch_size
646 return DataSetFields(MinibatchDataSet(minibatch),*fieldnames) 652 return minibatch
647 653
648 return Iterator(self) 654 return Iterator(self)
649 655
650 def valuesVStack(self,fieldname,fieldvalues): 656 def valuesVStack(self,fieldname,fieldvalues):
651 return self.values_vstack(fieldname,fieldvalues) 657 return self.values_vstack(fieldname,fieldvalues)
714 self.iterators=iterators 720 self.iterators=iterators
def __iter__(self):
    return self

def next(self):
    """Concatenate the fields of the sub-iterators' minibatches."""
    # Each underlying iterator yields a LookupList minibatch;
    # LookupList.__add__ concatenates their fields, producing the
    # horizontally-stacked minibatch.
    parts = [it.next() for it in self.iterators]
    return reduce(LookupList.__add__, parts)
724 726
725 assert self.hasfields(fieldnames) 727 assert self.hasfields(fieldnames)
726 # find out which underlying datasets are necessary to service the required fields 728 # find out which underlying datasets are necessary to service the required fields
727 # and construct corresponding minibatch iterators 729 # and construct corresponding minibatch iterators
728 if fieldnames: 730 if fieldnames:
847 if self.n_left_in_mb: 849 if self.n_left_in_mb:
848 extra_mb = [] 850 extra_mb = []
849 while self.n_left_in_mb>0: 851 while self.n_left_in_mb>0:
850 self.move_to_next_dataset() 852 self.move_to_next_dataset()
851 extra_mb.append(self.next_iterator.next()) 853 extra_mb.append(self.next_iterator.next())
852 examples = Example(names, 854 mb = Example(fieldnames,
853 [dataset.valuesVStack(name, 855 [dataset.valuesVStack(name,
854 [mb[name]]+[b[name] for b in extra_mb]) 856 [mb[name]]+[b[name] for b in extra_mb])
855 for name in fieldnames]) 857 for name in fieldnames])
856 mb = DataSetFields(MinibatchDataSet(examples),fieldnames)
857 858
858 self.next_row+=minibatch_size 859 self.next_row+=minibatch_size
859 self.next_dataset_row+=minibatch_size 860 self.next_dataset_row+=minibatch_size
860 if self.next_row+minibatch_size>len(dataset): 861 if self.next_row+minibatch_size>len(dataset):
861 self.move_to_next_dataset() 862 self.move_to_next_dataset()
924 if type(i) is int: 925 if type(i) is int:
925 return Example(fieldnames, 926 return Example(fieldnames,
926 [self.data[i,self.fields_columns[f]] for f in fieldnames]) 927 [self.data[i,self.fields_columns[f]] for f in fieldnames])
927 if type(i) in (slice,list): 928 if type(i) in (slice,list):
928 return MinibatchDataSet(Example(fieldnames, 929 return MinibatchDataSet(Example(fieldnames,
929 [self.data[i,self.fields_columns[f]] for f in fieldnames])) 930 [self.data[i,self.fields_columns[f]] for f in fieldnames]),
931 self.valuesVStack,self.valuesHStack)
930 # else check for a fieldname 932 # else check for a fieldname
931 if self.hasFields(i): 933 if self.hasFields(i):
932 return Example([i],[self.data[self.fields_columns[i],:]]) 934 return Example([i],[self.data[self.fields_columns[i],:]])
933 # else we are trying to access a property of the dataset 935 # else we are trying to access a property of the dataset
934 assert i in self.__dict__ # else it means we are trying to access a non-existing property 936 assert i in self.__dict__ # else it means we are trying to access a non-existing property
965 by caching every example value that has been accessed at least once. 967 by caching every example value that has been accessed at least once.
966 968
    Optionally, for a finite-length dataset, all the values can be computed
    (and cached) upon construction of the CachedDataSet, rather than at the
    first access.
972
973 TODO: add disk-buffering capability, so that when the cache becomes too
974 big for memory, we cache things on disk, trying to keep in memory only
975 the record most likely to be accessed next.
970 """ 976 """
def __init__(self, source_dataset, cache_all_upon_construction=False):
    """
    Wrap source_dataset with an in-memory example cache.

    If cache_all_upon_construction is True, every example of the (finite)
    source dataset is computed and cached immediately; otherwise the cache
    starts empty and is filled lazily on access.
    """
    self.source_dataset = source_dataset
    self.cache_all_upon_construction = cache_all_upon_construction
    if cache_all_upon_construction:
        # One whole-dataset minibatch yields the fields as columns;
        # zip(*columns) transposes them into a list of per-example rows.
        all_fields = source_dataset.minibatches(
            minibatch_size=len(source_dataset)).__iter__().next()
        self.cached_examples = zip(*all_fields)
    else:
        self.cached_examples = []
    # Delegate the read-only dataset interface straight to the source.
    self.fieldNames = source_dataset.fieldNames
    self.hasFields = source_dataset.hasFields
    self.valuesHStack = source_dataset.valuesHStack
    self.valuesVStack = source_dataset.valuesVStack
989
def __len__(self):
    """The cached dataset has exactly as many examples as its source."""
    return len(self.source_dataset)
992
def minibatches_nowrap(self, fieldnames, minibatch_size, n_batches, offset):
    """
    Return an iterator over minibatches, filling the example cache on demand.

    Each call to the iterator's next() yields an Example mapping the
    requested field names to the minibatch's values, pulling any examples
    not yet cached from source_dataset first.
    """
    class CacheIterator(object):
        def __init__(self, dataset):
            self.dataset = dataset
            self.current = offset
        def __iter__(self):
            return self
        def next(self):
            upper = self.current + minibatch_size
            cache_len = len(self.dataset.cached_examples)
            # BUG FIX: was `upper >= cache_len`, which fetched an empty
            # slice from the source even when the minibatch was already
            # fully cached (indices current..upper-1 all < cache_len).
            if upper > cache_len: # whole minibatch is not already in cache
                # cache everything from current length to upper
                for example in self.dataset.source_dataset[cache_len:upper]:
                    self.dataset.cached_examples.append(example)
            # NOTE(review): the cached slice is a list of per-example rows,
            # passed as the field values of the Example -- presumably the
            # caller transposes/stacks them; confirm against Example usage.
            all_fields_minibatch = Example(self.dataset.fieldNames(),
                                           self.dataset.cached_examples[self.current:upper])
            # BUG FIX: the cursor was never advanced, so every next() call
            # returned the same first minibatch forever.
            self.current = upper
            if self.dataset.fieldNames() == fieldnames:
                return all_fields_minibatch
            return Example(fieldnames,
                           [all_fields_minibatch[name] for name in fieldnames])
    return CacheIterator(self)
1012
1013
class ApplyFunctionDataSet(DataSet):
    """
    A dataset that contains as fields the results of applying a given function
    example-wise or minibatch-wise to all the fields of an input dataset.
    The output of the function should be an iterable (e.g. a list or a
    LookupList) over the resulting values.

    In minibatch mode, the function is expected to work on minibatches (takes
    a minibatch in input and returns a minibatch in output). More precisely,
    it means that each element of the input or output list should be iterable
    and indexable over the individual example values (typically these
    elements will be numpy arrays). All of the elements in the input and
    output lists should have the same length, which is the length of the
    minibatch.

    The function is applied each time an example or a minibatch is accessed.
    To avoid re-doing computation, wrap this dataset inside a CachedDataSet.

    If the values_{h,v}stack functions are not provided, then
    the input_dataset.values{H,V}Stack functions are used by default.
    """
    def __init__(self, input_dataset, function, output_names, minibatch_mode=True,
                 values_hstack=None, values_vstack=None,
                 description=None, fieldtypes=None):
        """
        Constructor takes an input dataset that has as many fields as the function
        expects as inputs. The resulting dataset has as many fields as the function
        produces as outputs, and that should correspond to the number of output names
        (provided in a list).

        Note that the expected semantics of the function differs in minibatch mode
        (it takes minibatches of inputs and produces minibatches of outputs, as
        documented in the class comment).
        """
        # BUG FIX: values_vstack had no default value, which is a SyntaxError
        # after the defaulted values_hstack parameter.
        self.input_dataset = input_dataset
        self.function = function
        self.output_names = output_names
        self.minibatch_mode = minibatch_mode
        # BUG FIX: the unbound base-class constructor needs the instance
        # as its first argument.
        DataSet.__init__(self, description, fieldtypes)
        self.valuesHStack = values_hstack if values_hstack else input_dataset.valuesHStack
        self.valuesVStack = values_vstack if values_vstack else input_dataset.valuesVStack

    def __len__(self):
        """Same number of examples as the wrapped input dataset."""
        return len(self.input_dataset)

    def fieldNames(self):
        # Renamed from fieldnames() for consistency with the DataSet API
        # used throughout this file (callers invoke self.fieldNames()).
        return self.output_names

    # Backward-compatible alias for any caller using the old lowercase name.
    fieldnames = fieldNames

    def minibatches_nowrap(self, fieldnames, minibatch_size, n_batches, offset):
        """Iterate over minibatches of the function's outputs."""
        class ApplyFunctionIterator(object):
            def __init__(self, output_dataset):
                self.input_dataset = output_dataset.input_dataset
                self.output_dataset = output_dataset
                # BUG FIX: 'input_dataset' was an unresolved name inside this
                # nested class; use the attribute assigned just above.
                self.input_iterator = self.input_dataset.minibatches(
                    minibatch_size=minibatch_size,
                    n_batches=n_batches, offset=offset).__iter__()

            def __iter__(self):
                return self

            def next(self):
                function_inputs = self.input_iterator.next()
                all_output_names = self.output_dataset.output_names
                if self.output_dataset.minibatch_mode:
                    # the function maps a whole minibatch to a whole minibatch
                    function_outputs = self.output_dataset.function(function_inputs)
                else:
                    # apply the function example by example, then stack the
                    # per-example outputs back into minibatch columns
                    input_examples = zip(*function_inputs)
                    output_examples = [self.output_dataset.function(input_example)
                                       for input_example in input_examples]
                    function_outputs = [self.output_dataset.valuesVStack(name, values)
                                        for name, values in zip(all_output_names,
                                                                zip(*output_examples))]
                all_outputs = Example(all_output_names, function_outputs)
                if fieldnames == all_output_names:
                    return all_outputs
                return Example(fieldnames,
                               [all_outputs[name] for name in fieldnames])

        # BUG FIX: the iterator constructor takes only the output dataset,
        # not (input_dataset, output_dataset).
        return ApplyFunctionIterator(self)

    def __iter__(self): # only implemented for increased efficiency
        """Iterate example by example over the function's outputs."""
        class ApplyFunctionSingleExampleIterator(object):
            def __init__(self, output_dataset):
                self.current = 0
                self.output_dataset = output_dataset
                self.input_iterator = output_dataset.input_dataset.__iter__()
            def __iter__(self):
                return self
            def next(self):
                function_inputs = self.input_iterator.next()
                if self.output_dataset.minibatch_mode:
                    # NOTE(review): this feeds a single example to a
                    # minibatch-mode function and unwraps element 0 of each
                    # output -- presumably the function tolerates length-1
                    # minibatches; confirm against actual function usage.
                    function_outputs = [output[0] for output
                                        in self.output_dataset.function(function_inputs)]
                else:
                    function_outputs = self.output_dataset.function(function_inputs)
                return Example(self.output_dataset.output_names, function_outputs)
        return ApplyFunctionSingleExampleIterator(self)
984 1106
985 1107
986 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): 1108 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None):
987 """ 1109 """
988 Wraps an arbitrary DataSet into one for supervised learning tasks by forcing the 1110 Wraps an arbitrary DataSet into one for supervised learning tasks by forcing the