comparison dataset.py @ 268:3f1cd8897fda

reverting dataset
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 04 Jun 2008 18:48:50 -0400
parents 6e69fb91f3c0
children fdce496c3b56
comparing 267:4dad41215967 to 268:3f1cd8897fda
@@ -107,10 +107,14 @@
 
     - dataset[i] returns an Example.
 
     - dataset[[i1,i2,...in]] returns a dataset with examples i1,i2,...in.
 
+    - dataset[fieldname] returns an iterable over the values of the field fieldname across
+      the dataset (the iterable is obtained by default by calling valuesVStack
+      over the values for individual examples).
+
     - dataset.<property> returns the value of a property associated with
       the name <property>. The following properties should be supported:
         - 'description': a textual description or name for the dataset
         - 'fieldtypes': a list of types (one per field)
     A DataSet may have other attributes that it makes visible to other objects. These are
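For context, a minimal sketch of the indexing semantics documented above;
'ds' is a hypothetical DataSet with fields 'x' and 'y':

    ex   = ds[3]            # an Example holding the field values of example 3
    sub  = ds[[0,2,5]]      # a dataset restricted to examples 0, 2 and 5
    xs   = ds['x']          # iterable over all values of field 'x',
                            # vstacked across examples by default
    desc = ds.description   # a dataset property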
@@ -156,18 +160,14 @@
     A sub-class should also append attributes to self._attribute_names
     (the default value returned by attributeNames()).
     By convention, attributes not in attributeNames() should have a name
     starting with an underscore.
     @todo enforce/test that convention!
-
     """
 
-    if 0:
-        # removed by James June 4... these aren't used anywhere according to
-        # grep
-        numpy_vstack = lambda fieldname,values: numpy.vstack(values)
-        numpy_hstack = lambda fieldnames,values: numpy.hstack(values)
+    numpy_vstack = lambda fieldname,values: numpy.vstack(values)
+    numpy_hstack = lambda fieldnames,values: numpy.hstack(values)
 
     def __init__(self,description=None,fieldtypes=None):
         if description is None:
             # by default return "<DataSetType>(<SuperClass1>,<SuperClass2>,...)"
             description = type(self).__name__ + " ( " + join([x.__name__ for x in type(self).__bases__]) + " )"
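For reference, the two restored class-level helpers simply delegate to numpy
and ignore the field-name argument; a small sketch of what they compute:

    import numpy
    # vertical stacking: per-example values become the rows of a minibatch
    numpy.vstack([numpy.array([1.,2.]), numpy.array([3.,4.])])  # shape (2,2)
    # horizontal stacking: field values are concatenated side by side
    numpy.hstack([numpy.ones((2,3)), numpy.zeros((2,1))])       # shape (2,4)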
@@ -275,15 +275,13 @@
             else:
                 # we must concatenate (vstack) the bottom and top parts of our minibatch
                 # first get the beginning of our minibatch (top of dataset)
                 first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next()
                 second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next()
-
-                blah = [self.dataset.valuesVStack(name,[first_part[name],second_part[name]])
-                        for name in self.fieldnames]
-                print type(self.dataset), blah
-                minibatch = Example(self.fieldnames,blah)
+                minibatch = Example(self.fieldnames,
+                                    [self.dataset.valuesVStack(name,[first_part[name],second_part[name]])
+                                     for name in self.fieldnames])
             self.next_row=upper
             self.n_batches_done+=1
             if upper >= self.L and self.n_batches:
                 self.next_row -= self.L
                 ds_nbatches = (self.L-self.next_row)/self.minibatch_size
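To make the wrap-around case above concrete, a small worked example (the
numbers are hypothetical):

    # dataset length L=5, minibatch_size=4, starting at next_row=3
    L, next_row, minibatch_size = 5, 3, 4
    upper = next_row + minibatch_size    # 7: runs past the end of the dataset
    n_first  = L - next_row              # 2 examples from the end (rows 3,4)
    n_second = upper - L                 # 2 examples from the start (rows 0,1)
    # valuesVStack then stitches first_part and second_part into one
    # minibatch per field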
@@ -460,20 +458,16 @@
         return MinibatchDataSet(
                     Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values)
                                                 for fieldname,field_values
                                                 in zip(self.fieldNames(),fields_values)]),
                     self.valuesVStack,self.valuesHStack)
-
-        raise TypeError(i)
-        if 0:
-            # else check for a fieldname
-            #after talk with Yoshua June 4, this is disabled.
-            if self.hasFields(i):
-                return self.minibatches(fieldnames=[i],minibatch_size=len(self),n_batches=1,offset=0).next()[0]
-            # else we are trying to access a property of the dataset
-            assert i in self.__dict__ # else it means we are trying to access a non-existing property
-            return self.__dict__[i]
+        # else check for a fieldname
+        if self.hasFields(i):
+            return self.minibatches(fieldnames=[i],minibatch_size=len(self),n_batches=1,offset=0).next()[0]
+        # else we are trying to access a property of the dataset
+        assert i in self.__dict__ # else it means we are trying to access a non-existing property
+        return self.__dict__[i]
 
     def valuesHStack(self,fieldnames,fieldvalues):
         """
         Return a value that corresponds to concatenating (horizontally) several field values.
         This can be useful to merge some fields. The implementation of this operation is likely
@@ -495,24 +489,25 @@
         return fieldvalues
 
 
     def valuesVStack(self,fieldname,values):
         """
-        @param fieldname: the name of the field from which the values were taken
-        @type fieldname: any type
-
-        @param values: bits near the beginning or end of the dataset
-        @type values: list of minibatches (returned by minibatch_nowrap)
-
-        @return: the concatenation (stacking) of the values
-        @rtype: something suitable as a minibatch field
-
-        """
-        rval = []
-        for sub_batch in values:
-            rval.extend(sub_batch)
-        return rval
+        Return a value that corresponds to concatenating (vertically) several values of the
+        same field. This can be important to build a minibatch out of individual examples. This
+        is likely to involve a copy of the original values. When the values are numpy arrays, the
+        result should be numpy.vstack(values).
+        The default is to use numpy.vstack for numpy.ndarray values, and a list
+        pointing to the original values for other data types.
+        """
+        all_numpy=True
+        for value in values:
+            if not type(value) is numpy.ndarray:
+                all_numpy=False
+        if all_numpy:
+            return numpy.vstack(values)
+        # the default implementation of vertical stacking is to put values in a list
+        return values
 
     def __or__(self,other):
         """
         dataset1 | dataset2 returns a dataset whose list of fields is the concatenation of the list of
         fields of the argument datasets. This only works if they all have the same length.
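The effect of the new default valuesVStack above, sketched on two kinds of
field values ('ds' is a hypothetical DataSet instance):

    import numpy
    # all values are ndarrays: they get stacked into a single array
    ds.valuesVStack('x', [numpy.ones((1,3)), numpy.zeros((1,3))])  # 2x3 array
    # any non-numpy value present: the list of values is returned unchanged
    ds.valuesVStack('label', ['cat', 'dog'])                       # ['cat', 'dog']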
@@ -956,33 +951,20 @@
 class ArrayFieldsDataSet(DataSet):
     """
     Virtual super-class of datasets whose field values are numpy arrays,
     thus defining valuesHStack and valuesVStack for sub-classes.
     """
-    def __init__(self, description=None, field_types=None):
-        DataSet.__init__(self, description, field_types)
-    def valuesHStack(self, fieldnames, fieldvalues):
+    def __init__(self,description=None,field_types=None):
+        DataSet.__init__(self,description,field_types)
+    def valuesHStack(self,fieldnames,fieldvalues):
         """Concatenate field values horizontally, e.g. two vectors
         become a longer vector, two matrices become a wider matrix, etc."""
         return numpy.hstack(fieldvalues)
-    def valuesVStack(self, fieldname, values):
+    def valuesVStack(self,fieldname,values):
         """Concatenate field values vertically, e.g. two vectors
         become a two-row matrix, two matrices become a longer matrix, etc."""
-        #print len(values)
-        for v in values:
-            if not isinstance(v, numpy.ndarray):
-                raise TypeError(v, type(v))
-
-        s0 = sum([v.shape[0] for v in values])
-        #TODO: there's gotta be a better way to do this!
-        dtype = values[0].dtype
-        rval = numpy.ndarray([s0] + list(values[0].shape[1:]), dtype=dtype)
-        cur_row = 0
-        for v in values:
-            rval[cur_row:cur_row+v.shape[0]] = v
-            cur_row += v.shape[0]
-        return rval
+        return numpy.vstack(values)
 
 class ArrayDataSet(ArrayFieldsDataSet):
     """
     An ArrayDataSet stores the fields as groups of columns in a numpy tensor,
     whose first axis iterates over examples, second axis determines fields.
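A construction sketch, assuming ArrayDataSet takes a 2-D array plus the
fields_columns mapping that the consistency checks in the next hunk operate
on (the field layout here is hypothetical):

    import numpy
    data = numpy.random.random_sample((4,3))
    # columns 0-1 form the 'input' field, column 2 the 'target' field
    ds = ArrayDataSet(data, {'input': slice(0,2), 'target': 2})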
@@ -1003,11 +985,11 @@
 
         # check consistency and complete slices definitions
         for fieldname, fieldcolumns in self.fields_columns.items():
             if type(fieldcolumns) is int:
                 assert fieldcolumns>=0 and fieldcolumns<data_array.shape[1]
-                if 0:
+                if 1:
                     #I changed this because it didn't make sense to me,
                     # and it made it more difficult to write my learner.
                     # If it breaks stuff, let's talk about it.
                     # - James 22/05/2008
                     self.fields_columns[fieldname]=[fieldcolumns]
1035 """More efficient implementation than the default __getitem__""" 1017 """More efficient implementation than the default __getitem__"""
1036 fieldnames=self.fields_columns.keys() 1018 fieldnames=self.fields_columns.keys()
1037 values=self.fields_columns.values() 1019 values=self.fields_columns.values()
1038 if type(key) is int: 1020 if type(key) is int:
1039 return Example(fieldnames, 1021 return Example(fieldnames,
1040 [numpy.asarray(self.data[key,col]) for col in values]) 1022 [self.data[key,col] for col in values])
1041 if type(key) is slice: 1023 if type(key) is slice:
1042 return MinibatchDataSet(Example(fieldnames, 1024 return MinibatchDataSet(Example(fieldnames,
1043 [self.data[key,col] for col in values])) 1025 [self.data[key,col] for col in values]))
1044 if type(key) is list: 1026 if type(key) is list:
1045 for i in range(len(key)): 1027 for i in range(len(key)):
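Continuing that sketch, the three key types this __getitem__ handles:

    ex   = ds[0]       # int   -> a single Example
    sub  = ds[1:3]     # slice -> a MinibatchDataSet over those rows
    pick = ds[[0,3]]   # list  -> the selected examples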
@@ -1113,104 +1095,95 @@
 
         return ArrayDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset)
 
 
 class CachedDataSet(DataSet):
     """
     Wrap a L{DataSet} whose values are computationally expensive to obtain
     (e.g. because they involve some computation, or disk access),
     so that repeated accesses to the same example are done cheaply,
     by caching every example value that has been accessed at least once.
 
     Optionally, for finite-length datasets, all the values can be computed
     (and cached) upon construction of the CachedDataSet, rather than at the
     first access.
 
     @todo: when cache_all_upon_construction create mini-batches that are as
     large as possible but not so large as to fill up memory.
 
     @todo: add disk-buffering capability, so that when the cache becomes too
     big for memory, we cache things on disk, trying to keep in memory only
     the record most likely to be accessed next.
     """
     def __init__(self,source_dataset,cache_all_upon_construction=False):
         self.source_dataset=source_dataset
         self.cache_all_upon_construction=cache_all_upon_construction
-        self.cached_examples = [] #a list of LookupList (copies)
+        self.cached_examples = []
         if cache_all_upon_construction:
             # this potentially brings all the source examples
             # into memory at once, which may be too much
             # the work could possibly be done by minibatches
             # that are as large as possible but no more than what memory allows.
             fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next()
-            assert all([len(self)==len(fval) for fval in fields_values])
+            assert all([len(self)==len(field_values) for field_values in fields_values])
             for example in fields_values.examples():
-                dup = copy.copy(example)
-                self.cached_examples.append(dup)
+                self.cached_examples.append(copy.copy(example))
 
         self.fieldNames = source_dataset.fieldNames
         self.hasFields = source_dataset.hasFields
         self.valuesHStack = source_dataset.valuesHStack
         self.valuesVStack = source_dataset.valuesVStack
 
     def __len__(self):
         return len(self.source_dataset)
 
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
         class CacheIterator(object):
             def __init__(self,dataset):
                 self.dataset=dataset
                 self.current=offset
                 self.all_fields = self.dataset.fieldNames()==fieldnames
             def __iter__(self): return self
             def next(self):
                 upper = self.current+minibatch_size
                 cache_len = len(self.dataset.cached_examples)
-                if upper>cache_len:
-                    # whole minibatch is not already in cache
+                if upper>cache_len: # whole minibatch is not already in cache
                     # cache everything from current length to upper
                     for example in self.dataset.source_dataset[cache_len:upper]:
                         self.dataset.cached_examples.append(example)
-
-                next_range = slice(self.current, self.current+minibatch_size)
-                blah = self.dataset.cached_examples[next_range]
-                all_fields_minibatch = Example(self.dataset.fieldNames(), zip(*blah))
+                all_fields_minibatch = Example(self.dataset.fieldNames(),
+                        zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size]))
                 self.current+=minibatch_size
-
-                #little optimization to avoid second Example computation if
-                #possible.
                 if self.all_fields:
                     return all_fields_minibatch
-
-                rval = Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames])
-                return rval
+                return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames])
         return CacheIterator(self)
 
     def __getitem__(self,i):
         if type(i)==int and len(self.cached_examples)>i:
             return self.cached_examples[i]
         else:
             return self.source_dataset[i]
 
     def __iter__(self):
         class CacheIteratorIter(object):
             def __init__(self,dataset):
                 self.dataset=dataset
                 self.l = len(dataset)
                 self.current = 0
                 self.fieldnames = self.dataset.fieldNames()
                 self.example = LookupList(self.fieldnames,[0]*len(self.fieldnames))
             def __iter__(self): return self
             def next(self):
                 if self.current>=self.l:
                     raise StopIteration
                 cache_len = len(self.dataset.cached_examples)
                 if self.current>=cache_len: # whole minibatch is not already in cache
                     # cache everything from current length to upper
                     self.dataset.cached_examples.append(
                         self.dataset.source_dataset[self.current])
                 self.example._values = self.dataset.cached_examples[self.current]
                 self.current+=1
                 return self.example
 
         return CacheIteratorIter(self)
 
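A usage sketch for CachedDataSet ('expensive_ds' is a hypothetical DataSet
whose examples are costly to compute):

    cached = CachedDataSet(expensive_ds)  # cache filled lazily, on access
    eager  = CachedDataSet(expensive_ds, cache_all_upon_construction=True)
    for example in cached:   # first pass computes and caches each example
        pass
    for example in cached:   # second pass is served from the cache
        pass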
@@ -1217,103 +1190,103 @@
 class ApplyFunctionDataSet(DataSet):
     """
     A L{DataSet} that contains as fields the results of applying a
     given function example-wise or minibatch-wise to all the fields of
     an input dataset. The output of the function should be an iterable
     (e.g. a list or a LookupList) over the resulting values.
 
     The function takes as input the fields of the dataset, not the examples.
 
     In minibatch mode, the function is expected to work on minibatches
     (takes a minibatch in input and returns a minibatch in output). More
     precisely, it means that each element of the input or output list
     should be iterable and indexable over the individual example values
     (typically these elements will be numpy arrays). All of the elements
     in the input and output lists should have the same length, which is
     the length of the minibatch.
 
     The function is applied each time an example or a minibatch is accessed.
     To avoid re-doing computation, wrap this dataset inside a CachedDataSet.
 
     If the values_{h,v}stack functions are not provided, then
     the input_dataset.values{H,V}Stack functions are used by default.
     """
     def __init__(self,input_dataset,function,output_names,minibatch_mode=True,
                  values_hstack=None,values_vstack=None,
                  description=None,fieldtypes=None):
         """
         Constructor takes an input dataset that has as many fields as the function
         expects as inputs. The resulting dataset has as many fields as the function
         produces as outputs, and these should correspond to the number of output names
         (provided in a list).
 
         Note that the expected semantics of the function differs in minibatch mode
         (it takes minibatches of inputs and produces minibatches of outputs, as
         documented in the class comment).
 
         TBM: are fieldtypes the old field types (from input_dataset) or the new ones
         (for the new dataset created)?
         """
         self.input_dataset=input_dataset
         self.function=function
         self.output_names=output_names
         self.minibatch_mode=minibatch_mode
         DataSet.__init__(self,description,fieldtypes)
         self.valuesHStack = values_hstack if values_hstack else input_dataset.valuesHStack
         self.valuesVStack = values_vstack if values_vstack else input_dataset.valuesVStack
 
     def __len__(self):
         return len(self.input_dataset)
 
     def fieldNames(self):
         return self.output_names
 
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
         class ApplyFunctionIterator(object):
             def __init__(self,output_dataset):
                 self.input_dataset=output_dataset.input_dataset
                 self.output_dataset=output_dataset
                 self.input_iterator=self.input_dataset.minibatches(minibatch_size=minibatch_size,
                                                                    n_batches=n_batches,offset=offset).__iter__()
 
             def __iter__(self): return self
 
             def next(self):
                 function_inputs = self.input_iterator.next()
                 all_output_names = self.output_dataset.output_names
                 if self.output_dataset.minibatch_mode:
                     function_outputs = self.output_dataset.function(*function_inputs)
                 else:
                     input_examples = zip(*function_inputs)
                     output_examples = [self.output_dataset.function(*input_example)
                                        for input_example in input_examples]
                     function_outputs = [self.output_dataset.valuesVStack(name,values)
                                         for name,values in zip(all_output_names,
                                                                zip(*output_examples))]
                 all_outputs = Example(all_output_names,function_outputs)
                 if fieldnames==all_output_names:
                     return all_outputs
                 return Example(fieldnames,[all_outputs[name] for name in fieldnames])
 
 
         return ApplyFunctionIterator(self)
 
     def __iter__(self): # only implemented for increased efficiency
         class ApplyFunctionSingleExampleIterator(object):
             def __init__(self,output_dataset):
                 self.current=0
                 self.output_dataset=output_dataset
                 self.input_iterator=output_dataset.input_dataset.__iter__()
             def __iter__(self): return self
             def next(self):
                 if self.output_dataset.minibatch_mode:
                     function_inputs = [[input] for input in self.input_iterator.next()]
                     outputs = self.output_dataset.function(*function_inputs)
                     assert all([hasattr(output,'__iter__') for output in outputs])
                     function_outputs = [output[0] for output in outputs]
                 else:
                     function_inputs = self.input_iterator.next()
                     function_outputs = self.output_dataset.function(*function_inputs)
                 return Example(self.output_dataset.output_names,function_outputs)
         return ApplyFunctionSingleExampleIterator(self)
 
 
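A usage sketch for ApplyFunctionDataSet in minibatch mode ('ds' is a
hypothetical dataset with fields 'x' and 'y'); the function receives whole
fields and must return an iterable of output fields:

    f = lambda x, y: (x + y,)   # one output field, computed minibatch-wise
    sums = ApplyFunctionDataSet(ds, f, ['x_plus_y'], minibatch_mode=True)
    sums = CachedDataSet(sums)  # as the docstring advises, cache to avoid
                                # recomputing on repeated access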
@@ -1320,4 +1293,4 @@
 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None):
     """
     Wraps an arbitrary L{DataSet} into one for supervised learning tasks
     by forcing the user to define a set of fields as the 'input' field