comparison dataset.py @ 266:6e69fb91f3c0

initial commit of amat
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 04 Jun 2008 17:49:09 -0400
parents 19b14afe04b7
children 3f1cd8897fda
--- a/dataset.py	263:5614b186c5f4
+++ b/dataset.py	266:6e69fb91f3c0
@@ -107,14 +107,10 @@
 
     - dataset[i] returns an Example.
 
     - dataset[[i1,i2,...in]] returns a dataset with examples i1,i2,...in.
 
-    - dataset[fieldname] an iterable over the values of the field fieldname across
-      the dataset (the iterable is obtained by default by calling valuesVStack
-      over the values for individual examples).
-
     - dataset.<property> returns the value of a property associated with
       the name <property>. The following properties should be supported:
          - 'description': a textual description or name for the dataset
          - 'fieldtypes': a list of types (one per field)
    A DataSet may have other attributes that it makes visible to other objects. These are
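Note: this hunk removes `dataset[fieldname]` from the documented indexing contract (the matching `__getitem__` branch is disabled in a later hunk). The contract that remains can be sketched with a stand-alone toy class; `ToyDataSet` below is a made-up stand-in, not the real `DataSet`:

    # Minimal sketch of the surviving indexing contract, assuming only what
    # the docstring above states.  ToyDataSet is hypothetical.
    class ToyDataSet(object):
        def __init__(self, rows, description="toy"):
            self.rows = list(rows)          # each row stands in for an Example
            self.description = description  # the 'description' property
        def __getitem__(self, i):
            if isinstance(i, int):
                return self.rows[i]         # dataset[i] -> one example
            if isinstance(i, list):         # dataset[[i1,...]] -> sub-dataset
                return ToyDataSet([self.rows[j] for j in i])
            raise TypeError(i)

    ds = ToyDataSet([(0, 'a'), (1, 'b'), (2, 'c')])
    assert ds[1] == (1, 'b')
    assert ds[[0, 2]].rows == [(0, 'a'), (2, 'c')]
    assert ds.description == "toy"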
@@ -160,14 +156,18 @@
     A sub-class should also append attributes to self._attribute_names
     (the default value returned by attributeNames()).
     By convention, attributes not in attributeNames() should have a name
     starting with an underscore.
     @todo enforce/test that convention!
+
     """
 
-    numpy_vstack = lambda fieldname,values: numpy.vstack(values)
-    numpy_hstack = lambda fieldnames,values: numpy.hstack(values)
+    if 0:
+        # removed by James June 4... these aren't used anywhere according to
+        # grep
+        numpy_vstack = lambda fieldname,values: numpy.vstack(values)
+        numpy_hstack = lambda fieldnames,values: numpy.hstack(values)
 
     def __init__(self,description=None,fieldtypes=None):
         if description is None:
             # by default return "<DataSetType>(<SuperClass1>,<SuperClass2>,...)"
             description = type(self).__name__ + " ( " + join([x.__name__ for x in type(self).__bases__]) + " )"
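Note: the default-description line kept above calls a bare `join()`, which presumably comes from Python 2's `string` module (`from string import join`); otherwise it would raise `NameError`. A stand-alone sketch of the same logic using `str.join` (the class names here are invented):

    # What the default description evaluates to, under the assumption that
    # join(seq) behaves like " ".join(seq).
    class Base1(object): pass
    class Base2(object): pass
    class MyDataSet(Base1, Base2): pass

    obj = MyDataSet()
    description = type(obj).__name__ + " ( " + \
        " ".join([x.__name__ for x in type(obj).__bases__]) + " )"
    assert description == "MyDataSet ( Base1 Base2 )"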
@@ -275,13 +275,15 @@
         else:
             # we must concatenate (vstack) the bottom and top parts of our minibatch
             # first get the beginning of our minibatch (top of dataset)
             first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next()
             second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next()
-            minibatch = Example(self.fieldnames,
-                                [self.dataset.valuesAppend(name,[first_part[name],second_part[name]])
-                                 for name in self.fieldnames])
+
+            blah = [self.dataset.valuesVStack(name,[first_part[name],second_part[name]])
+                    for name in self.fieldnames]
+            print type(self.dataset), blah
+            minibatch = Example(self.fieldnames,blah)
         self.next_row=upper
         self.n_batches_done+=1
         if upper >= self.L and self.n_batches:
             self.next_row -= self.L
         ds_nbatches = (self.L-self.next_row)/self.minibatch_size
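Note: the replacement block above switches from the removed `valuesAppend` to `valuesVStack`, and it leaves a debugging `print` statement in place. The wrap-around arithmetic it implements can be checked stand-alone, with `numpy.vstack` standing in for the dataset's `valuesVStack`:

    # Stand-alone sketch: a minibatch that runs past the end of a length-L
    # dataset is built by stacking the tail rows with rows from the top.
    import numpy

    data = numpy.arange(10).reshape(5, 2)  # L = 5 examples of one 2-wide field
    L, next_row, minibatch_size = 5, 4, 3
    upper = next_row + minibatch_size      # 7: past the end, so we wrap
    first_part = data[next_row:L]          # bottom of the dataset (1 row)
    second_part = data[0:upper - L]        # wrapped around to the top (2 rows)
    minibatch = numpy.vstack([first_part, second_part])
    assert minibatch.shape == (3, 2)
    assert (minibatch[0] == data[4]).all() and (minibatch[1] == data[0]).all()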
@@ -458,16 +460,20 @@
             return MinibatchDataSet(
                 Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values)
                                             for fieldname,field_values
                                             in zip(self.fieldNames(),fields_values)]),
                 self.valuesVStack,self.valuesHStack)
-        # else check for a fieldname
-        if self.hasFields(i):
-            return self.minibatches(fieldnames=[i],minibatch_size=len(self),n_batches=1,offset=0).next()[0]
-        # else we are trying to access a property of the dataset
-        assert i in self.__dict__ # else it means we are trying to access a non-existing property
-        return self.__dict__[i]
+
+        raise TypeError(i)
+        if 0:
+            # else check for a fieldname
+            #after talk with Yoshua June 4, this is disabled.
+            if self.hasFields(i):
+                return self.minibatches(fieldnames=[i],minibatch_size=len(self),n_batches=1,offset=0).next()[0]
+            # else we are trying to access a property of the dataset
+            assert i in self.__dict__ # else it means we are trying to access a non-existing property
+            return self.__dict__[i]
 
     def valuesHStack(self,fieldnames,fieldvalues):
         """
         Return a value that corresponds to concatenating (horizontally) several field values.
         This can be useful to merge some fields. The implementation of this operation is likely
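Note: after this hunk, `__getitem__` raises `TypeError` for anything that is not an int, a slice, or a list of ints; access by fieldname is disabled (per the comment, after a discussion with Yoshua on June 4). The `if 0:` block after the `raise` is unreachable and kept only for reference. A toy stand-in (not the real `DataSet`) for the stricter contract:

    # Hypothetical stand-in showing the new __getitem__ behaviour.
    class StrictIndexDataSet(object):
        def __init__(self, rows):
            self.rows = rows
        def __getitem__(self, i):
            if isinstance(i, (int, slice)):
                return self.rows[i]
            if isinstance(i, list):
                return [self.rows[j] for j in i]
            raise TypeError(i)  # e.g. a fieldname string now lands here

    ds = StrictIndexDataSet([1, 2, 3])
    try:
        ds['input']
        assert False, "expected TypeError"
    except TypeError:
        pass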
@@ -489,25 +495,24 @@
         return fieldvalues
 
 
     def valuesVStack(self,fieldname,values):
         """
-        Return a value that corresponds to concatenating (vertically) several values of the
-        same field. This can be important to build a minibatch out of individual examples. This
-        is likely to involve a copy of the original values. When the values are numpy arrays, the
-        result should be numpy.vstack(values).
-        The default is to use numpy.vstack for numpy.ndarray values, and a list
-        pointing to the original values for other data types.
-        """
-        all_numpy=True
-        for value in values:
-            if not type(value) is numpy.ndarray:
-                all_numpy=False
-        if all_numpy:
-            return numpy.vstack(values)
-        # the default implementation of vertical stacking is to put values in a list
-        return values
+        @param fieldname: the name of the field from which the values were taken
+        @type fieldname: any type
+
+        @param values: bits near the beginning or end of the dataset
+        @type values: list of minibatches (returned by minibatch_nowrap)
+
+        @return: the concatenation (stacking) of the values
+        @rtype: something suitable as a minibatch field
+
+        """
+        rval = []
+        for sub_batch in values:
+            rval.extend(sub_batch)
+        return rval
 
     def __or__(self,other):
         """
         dataset1 | dataset2 returns a dataset whose list of fields is the concatenation of the list of
         fields of the argument datasets. This only works if they all have the same length.
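Note: the base-class `valuesVStack` above no longer special-cases numpy arrays; it now just flattens a list of sub-batches into one list. A stand-alone copy of the new logic:

    # Local copy of the rewritten default valuesVStack, checked on plain lists.
    def values_vstack(fieldname, values):
        rval = []
        for sub_batch in values:
            rval.extend(sub_batch)
        return rval

    first_part = [1, 2, 3]   # e.g. values from the bottom of the dataset
    second_part = [4, 5]     # values wrapped around from the top
    assert values_vstack('x', [first_part, second_part]) == [1, 2, 3, 4, 5]

For 2-d numpy inputs this returns a Python list of row arrays rather than a stacked array, which is presumably why the numpy-backed subclass overrides it (next hunk).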
@@ -960,15 +965,19 @@
         become a longer vector, two matrices become a wider matrix, etc."""
         return numpy.hstack(fieldvalues)
     def valuesVStack(self, fieldname, values):
         """Concatenate field values vertically, e.g. two vectors
         become a two-row matrix, two matrices become a longer matrix, etc."""
-        return numpy.vstack(values)
-    def valuesAppend(self, fieldname, values):
+        #print len(values)
+        for v in values:
+            if not isinstance(v, numpy.ndarray):
+                raise TypeError(v, type(v))
+
         s0 = sum([v.shape[0] for v in values])
         #TODO: there's gotta be a better way to do this!
-        rval = numpy.ndarray([s0] + values[0].shape[1:],dtype=values[0].dtype)
+        dtype = values[0].dtype
+        rval = numpy.ndarray([s0] + list(values[0].shape[1:]), dtype=dtype)
         cur_row = 0
         for v in values:
             rval[cur_row:cur_row+v.shape[0]] = v
             cur_row += v.shape[0]
         return rval
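Note: this hunk folds the old `valuesAppend` body into `valuesVStack`, adds a type check, and fixes a real bug: the old `[s0] + values[0].shape[1:]` added a list to a tuple, which raises `TypeError`; the new `list(...)` wrapper avoids that. For same-width inputs the preallocate-and-copy loop matches numpy's own stacking, and the `#TODO` about "a better way" is presumably answered by `numpy.concatenate`:

    # Stand-alone check that the preallocate-and-copy loop above agrees with
    # numpy's own stacking functions.
    import numpy

    values = [numpy.ones((2, 3)), numpy.zeros((1, 3))]
    s0 = sum([v.shape[0] for v in values])
    rval = numpy.ndarray([s0] + list(values[0].shape[1:]), dtype=values[0].dtype)
    cur_row = 0
    for v in values:
        rval[cur_row:cur_row + v.shape[0]] = v
        cur_row += v.shape[0]
    assert (rval == numpy.vstack(values)).all()
    assert (rval == numpy.concatenate(values, axis=0)).all()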
1026 """More efficient implementation than the default __getitem__""" 1035 """More efficient implementation than the default __getitem__"""
1027 fieldnames=self.fields_columns.keys() 1036 fieldnames=self.fields_columns.keys()
1028 values=self.fields_columns.values() 1037 values=self.fields_columns.values()
1029 if type(key) is int: 1038 if type(key) is int:
1030 return Example(fieldnames, 1039 return Example(fieldnames,
1031 [self.data[key,col] for col in values]) 1040 [numpy.asarray(self.data[key,col]) for col in values])
1032 if type(key) is slice: 1041 if type(key) is slice:
1033 return MinibatchDataSet(Example(fieldnames, 1042 return MinibatchDataSet(Example(fieldnames,
1034 [self.data[key,col] for col in values])) 1043 [self.data[key,col] for col in values]))
1035 if type(key) is list: 1044 if type(key) is list:
1036 for i in range(len(key)): 1045 for i in range(len(key)):
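Note: wrapping the single-row lookup in `numpy.asarray` presumably normalizes scalar cells: indexing one row and one column of an array yields a numpy scalar, not an `ndarray`, which would trip the `isinstance(v, numpy.ndarray)` check added to `valuesVStack` in the previous hunk:

    # Why asarray matters for scalar cells.
    import numpy

    data = numpy.arange(6).reshape(3, 2)
    cell = data[1, 0]                    # numpy scalar, not an ndarray
    assert not isinstance(cell, numpy.ndarray)
    assert isinstance(numpy.asarray(cell), numpy.ndarray)  # 0-d array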
@@ -1104,40 +1113,41 @@
 
         return ArrayDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset)
 
 
 class CachedDataSet(DataSet):
     """
     Wrap a L{DataSet} whose values are computationally expensive to obtain
     (e.g. because they involve some computation, or disk access),
     so that repeated accesses to the same example are done cheaply,
     by caching every example value that has been accessed at least once.
 
     Optionally, for finite-length dataset, all the values can be computed
     (and cached) upon construction of the CachedDataSet, rather at the
     first access.
 
     @todo: when cache_all_upon_construction create mini-batches that are as
     large as possible but not so large as to fill up memory.
 
     @todo: add disk-buffering capability, so that when the cache becomes too
     big for memory, we cache things on disk, trying to keep in memory only
     the record most likely to be accessed next.
     """
     def __init__(self,source_dataset,cache_all_upon_construction=False):
         self.source_dataset=source_dataset
         self.cache_all_upon_construction=cache_all_upon_construction
-        self.cached_examples = []
+        self.cached_examples = [] #a list of LookupList (copies)
         if cache_all_upon_construction:
             # this potentially brings all the source examples
             # into memory at once, which may be too much
             # the work could possibly be done by minibatches
             # that are as large as possible but no more than what memory allows.
             fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next()
-            assert all([len(self)==len(field_values) for field_values in fields_values])
+            assert all([len(self)==len(fval) for fval in fields_values])
             for example in fields_values.examples():
-                self.cached_examples.append(copy.copy(example))
+                dup = copy.copy(example)
+                self.cached_examples.append(dup)
 
         self.fieldNames = source_dataset.fieldNames
         self.hasFields = source_dataset.hasFields
         self.valuesHStack = source_dataset.valuesHStack
         self.valuesVStack = source_dataset.valuesVStack
@@ -1144,54 +1154,62 @@
 
     def __len__(self):
         return len(self.source_dataset)
 
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
         class CacheIterator(object):
             def __init__(self,dataset):
                 self.dataset=dataset
                 self.current=offset
                 self.all_fields = self.dataset.fieldNames()==fieldnames
             def __iter__(self): return self
             def next(self):
                 upper = self.current+minibatch_size
                 cache_len = len(self.dataset.cached_examples)
-                if upper>cache_len: # whole minibatch is not already in cache
+                if upper>cache_len:
+                    # whole minibatch is not already in cache
                     # cache everything from current length to upper
                     for example in self.dataset.source_dataset[cache_len:upper]:
                         self.dataset.cached_examples.append(example)
-                all_fields_minibatch = Example(self.dataset.fieldNames(),
-                                               zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size]))
+
+                next_range = slice(self.current, self.current+minibatch_size)
+                blah = self.dataset.cached_examples[next_range]
+                all_fields_minibatch = Example(self.dataset.fieldNames(), zip(*blah))
                 self.current+=minibatch_size
+
+                #little optimization to avoid second Example computation if
+                #possible.
                 if self.all_fields:
                     return all_fields_minibatch
-                return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames])
+
+                rval = Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames])
+                return rval
         return CacheIterator(self)
 
     def __getitem__(self,i):
         if type(i)==int and len(self.cached_examples)>i:
             return self.cached_examples[i]
         else:
             return self.source_dataset[i]
 
     def __iter__(self):
         class CacheIteratorIter(object):
             def __init__(self,dataset):
                 self.dataset=dataset
                 self.l = len(dataset)
                 self.current = 0
                 self.fieldnames = self.dataset.fieldNames()
                 self.example = LookupList(self.fieldnames,[0]*len(self.fieldnames))
             def __iter__(self): return self
             def next(self):
                 if self.current>=self.l:
                     raise StopIteration
                 cache_len = len(self.dataset.cached_examples)
                 if self.current>=cache_len: # whole minibatch is not already in cache
                     # cache everything from current length to upper
                     self.dataset.cached_examples.append(
                         self.dataset.source_dataset[self.current])
                 self.example._values = self.dataset.cached_examples[self.current]
                 self.current+=1
                 return self.example
 
         return CacheIteratorIter(self)
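Note: the CachedDataSet changes in this hunk are cosmetic (named temporaries, split comments) plus the new `#a list of LookupList (copies)` annotation; the caching protocol itself is unchanged. The idea reduces to memoizing per-example computation, sketched here with a made-up stand-in (not the real class, which wraps a `DataSet` and also caches by minibatch):

    # Toy memoizing wrapper: each example of an expensive source is computed
    # once on first access and replayed from the cache afterwards.
    class ToyCachedSource(object):
        def __init__(self, source_fn, length):
            self.source_fn = source_fn   # expensive per-example function
            self.length = length
            self.cached = {}             # example index -> cached value
            self.computations = 0
        def __len__(self):
            return self.length
        def __getitem__(self, i):
            if i not in self.cached:
                self.computations += 1
                self.cached[i] = self.source_fn(i)
            return self.cached[i]

    src = ToyCachedSource(lambda i: i * i, 5)
    assert src[3] == 9 and src[3] == 9
    assert src.computations == 1         # second access hit the cache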
@@ -1198,104 +1216,104 @@
 
 class ApplyFunctionDataSet(DataSet):
     """
     A L{DataSet} that contains as fields the results of applying a
     given function example-wise or minibatch-wise to all the fields of
     an input dataset. The output of the function should be an iterable
     (e.g. a list or a LookupList) over the resulting values.
 
     The function take as input the fields of the dataset, not the examples.
 
     In minibatch mode, the function is expected to work on minibatches
     (takes a minibatch in input and returns a minibatch in output). More
     precisely, it means that each element of the input or output list
     should be iterable and indexable over the individual example values
     (typically these elements will be numpy arrays). All of the elements
     in the input and output lists should have the same length, which is
     the length of the minibatch.
 
     The function is applied each time an example or a minibatch is accessed.
     To avoid re-doing computation, wrap this dataset inside a CachedDataSet.
 
     If the values_{h,v}stack functions are not provided, then
     the input_dataset.values{H,V}Stack functions are used by default.
     """
     def __init__(self,input_dataset,function,output_names,minibatch_mode=True,
                  values_hstack=None,values_vstack=None,
                  description=None,fieldtypes=None):
         """
         Constructor takes an input dataset that has as many fields as the function
         expects as inputs. The resulting dataset has as many fields as the function
         produces as outputs, and that should correspond to the number of output names
         (provided in a list).
 
         Note that the expected semantics of the function differs in minibatch mode
         (it takes minibatches of inputs and produces minibatches of outputs, as
         documented in the class comment).
 
         TBM: are filedtypes the old field types (from input_dataset) or the new ones
         (for the new dataset created)?
         """
         self.input_dataset=input_dataset
         self.function=function
         self.output_names=output_names
         self.minibatch_mode=minibatch_mode
         DataSet.__init__(self,description,fieldtypes)
         self.valuesHStack = values_hstack if values_hstack else input_dataset.valuesHStack
         self.valuesVStack = values_vstack if values_vstack else input_dataset.valuesVStack
 
     def __len__(self):
         return len(self.input_dataset)
 
     def fieldNames(self):
         return self.output_names
 
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
         class ApplyFunctionIterator(object):
             def __init__(self,output_dataset):
                 self.input_dataset=output_dataset.input_dataset
                 self.output_dataset=output_dataset
                 self.input_iterator=self.input_dataset.minibatches(minibatch_size=minibatch_size,
                                                                    n_batches=n_batches,offset=offset).__iter__()
 
             def __iter__(self): return self
 
             def next(self):
                 function_inputs = self.input_iterator.next()
                 all_output_names = self.output_dataset.output_names
                 if self.output_dataset.minibatch_mode:
                     function_outputs = self.output_dataset.function(*function_inputs)
                 else:
                     input_examples = zip(*function_inputs)
                     output_examples = [self.output_dataset.function(*input_example)
                                        for input_example in input_examples]
                     function_outputs = [self.output_dataset.valuesVStack(name,values)
                                         for name,values in zip(all_output_names,
                                                                zip(*output_examples))]
                 all_outputs = Example(all_output_names,function_outputs)
                 if fieldnames==all_output_names:
                     return all_outputs
                 return Example(fieldnames,[all_outputs[name] for name in fieldnames])
 
 
         return ApplyFunctionIterator(self)
 
     def __iter__(self): # only implemented for increased efficiency
         class ApplyFunctionSingleExampleIterator(object):
             def __init__(self,output_dataset):
                 self.current=0
                 self.output_dataset=output_dataset
                 self.input_iterator=output_dataset.input_dataset.__iter__()
             def __iter__(self): return self
             def next(self):
                 if self.output_dataset.minibatch_mode:
                     function_inputs = [[input] for input in self.input_iterator.next()]
                     outputs = self.output_dataset.function(*function_inputs)
                     assert all([hasattr(output,'__iter__') for output in outputs])
                     function_outputs = [output[0] for output in outputs]
                 else:
                     function_inputs = self.input_iterator.next()
                     function_outputs = self.output_dataset.function(*function_inputs)
                 return Example(self.output_dataset.output_names,function_outputs)
         return ApplyFunctionSingleExampleIterator(self)
 
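Note: ApplyFunctionDataSet is unchanged in this revision and appears above only as context. Its minibatch-mode contract (one array-like per input field in, one iterable per output name out) can be illustrated stand-alone; the function and field names below are invented, not part of dataset.py:

    # A function obeying the minibatch-mode contract: it receives whole
    # fields (arrays whose length is the minibatch size) and returns one
    # iterable per output name.
    import numpy

    def square_and_shift(x_field, y_field):
        return [x_field ** 2, y_field + 1]   # outputs for, say, ['x2', 'y1']

    x = numpy.array([1.0, 2.0, 3.0])  # field 'x' of a 3-example minibatch
    y = numpy.array([0.0, 1.0, 2.0])  # field 'y'
    out = square_and_shift(x, y)
    assert (out[0] == numpy.array([1.0, 4.0, 9.0])).all()
    assert (out[1] == numpy.array([1.0, 2.0, 3.0])).all()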
@@ -1302,4 +1320,4 @@
 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None):
     """
     Wraps an arbitrary L{DataSet} into one for supervised learning tasks
     by forcing the user to define a set of fields as the 'input' field