comparison: dataset.py @ 266:6e69fb91f3c0

initial commit of amat

author:   James Bergstra <bergstrj@iro.umontreal.ca>
date:     Wed, 04 Jun 2008 17:49:09 -0400
parents:  19b14afe04b7
children: 3f1cd8897fda
263:5614b186c5f4 (before) | 266:6e69fb91f3c0 (after) |
107 | 107 |
108 - dataset[i] returns an Example. | 108 - dataset[i] returns an Example. |
109 | 109 |
110 - dataset[[i1,i2,...in]] returns a dataset with examples i1,i2,...in. | 110 - dataset[[i1,i2,...in]] returns a dataset with examples i1,i2,...in. |
111 | 111 |
112 - dataset[fieldname] an iterable over the values of the field fieldname across | |
113 the dataset (the iterable is obtained by default by calling valuesVStack | |
114 over the values for individual examples). | |
115 | |
116 - dataset.<property> returns the value of a property associated with | 112 - dataset.<property> returns the value of a property associated with |
117 the name <property>. The following properties should be supported: | 113 the name <property>. The following properties should be supported: |
118 - 'description': a textual description or name for the dataset | 114 - 'description': a textual description or name for the dataset |
119 - 'fieldtypes': a list of types (one per field) | 115 - 'fieldtypes': a list of types (one per field) |
120 A DataSet may have other attributes that it makes visible to other objects. These are | 116 A DataSet may have other attributes that it makes visible to other objects. These are |
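[Editor's note: the indexing and property semantics documented in the hunk above, in use. A minimal sketch; the ArrayDataSet constructor signature is assumed from the fields_columns attribute visible further down in this diff, and is not part of the change itself.]

    import numpy
    # hypothetical: map field names to columns of a 2-D array
    data = numpy.random.random_sample((5, 3))
    ds = ArrayDataSet(data, {'input': slice(0, 2), 'target': 2})
    ex = ds[0]              # an Example with 'input' and 'target' values
    sub = ds[[1, 4]]        # a dataset restricted to examples 1 and 4
    print ds.description    # textual description of the dataset
    print ds.fieldtypes     # a list of types, one per field
    # Note: ds['fieldname'] is removed by this changeset; see the
    # __getitem__ hunk below, which now raises TypeError for field names.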
160 A sub-class should also append attributes to self._attribute_names | 156 A sub-class should also append attributes to self._attribute_names |
161 (the default value returned by attributeNames()). | 157 (the default value returned by attributeNames()). |
162 By convention, attributes not in attributeNames() should have a name | 158 By convention, attributes not in attributeNames() should have a name |
163 starting with an underscore. | 159 starting with an underscore. |
164 @todo enforce/test that convention! | 160 @todo enforce/test that convention! |
165 """ | 161 |
166 | 162 """ |
167 numpy_vstack = lambda fieldname,values: numpy.vstack(values) | 163 |
168 numpy_hstack = lambda fieldnames,values: numpy.hstack(values) | 164 if 0: |
165 # removed by James June 4... these aren't used anywhere according to | |
166 # grep | |
167 numpy_vstack = lambda fieldname,values: numpy.vstack(values) | |
168 numpy_hstack = lambda fieldnames,values: numpy.hstack(values) | |
169 | 169 |
170 def __init__(self,description=None,fieldtypes=None): | 170 def __init__(self,description=None,fieldtypes=None): |
171 if description is None: | 171 if description is None: |
172 # by default return "<DataSetType>(<SuperClass1>,<SuperClass2>,...)" | 172 # by default return "<DataSetType>(<SuperClass1>,<SuperClass2>,...)" |
173 description = type(self).__name__ + " ( " + join([x.__name__ for x in type(self).__bases__]) + " )" | 173 description = type(self).__name__ + " ( " + join([x.__name__ for x in type(self).__bases__]) + " )" |
275 else: | 275 else: |
276 # we must concatenate (vstack) the bottom and top parts of our minibatch | 276 # we must concatenate (vstack) the bottom and top parts of our minibatch |
277 # first get the beginning of our minibatch (top of dataset) | 277 # first get the beginning of our minibatch (top of dataset) |
278 first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next() | 278 first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next() |
279 second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next() | 279 second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next() |
280 minibatch = Example(self.fieldnames, | 280 |
281 [self.dataset.valuesAppend(name,[first_part[name],second_part[name]]) | 281 blah = [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) |
282 for name in self.fieldnames]) | 282 for name in self.fieldnames] |
283 print type(self.dataset), blah | |
284 minibatch = Example(self.fieldnames,blah) | |
283 self.next_row=upper | 285 self.next_row=upper |
284 self.n_batches_done+=1 | 286 self.n_batches_done+=1 |
285 if upper >= self.L and self.n_batches: | 287 if upper >= self.L and self.n_batches: |
286 self.next_row -= self.L | 288 self.next_row -= self.L |
287 ds_nbatches = (self.L-self.next_row)/self.minibatch_size | 289 ds_nbatches = (self.L-self.next_row)/self.minibatch_size |
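[Editor's note: the hunk above changes how a minibatch that wraps past the end of the dataset is assembled: the tail (first_part) and the head (second_part) are fetched separately and joined with valuesVStack instead of the removed valuesAppend. The core idea in isolation, as a hedged sketch over plain numpy arrays:]

    import numpy

    def wrapped_minibatch(field_data, start, size):
        # field_data: one field's values for the whole dataset (2-D array)
        L = len(field_data)
        upper = start + size
        if upper <= L:
            return field_data[start:upper]
        # minibatch crosses the end: stack the dataset's tail and head
        first_part = field_data[start:L]
        second_part = field_data[0:upper - L]
        return numpy.vstack([first_part, second_part])

The real iterator delegates the stacking to self.dataset.valuesVStack so that field types other than numpy arrays can define their own concatenation.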
458 return MinibatchDataSet( | 460 return MinibatchDataSet( |
459 Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values) | 461 Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values) |
460 for fieldname,field_values | 462 for fieldname,field_values |
461 in zip(self.fieldNames(),fields_values)]), | 463 in zip(self.fieldNames(),fields_values)]), |
462 self.valuesVStack,self.valuesHStack) | 464 self.valuesVStack,self.valuesHStack) |
463 # else check for a fieldname | 465 |
464 if self.hasFields(i): | 466 raise TypeError(i) |
465 return self.minibatches(fieldnames=[i],minibatch_size=len(self),n_batches=1,offset=0).next()[0] | 467 if 0: |
466 # else we are trying to access a property of the dataset | 468 # else check for a fieldname |
467 assert i in self.__dict__ # else it means we are trying to access a non-existing property | 469 #after talk with Yoshua June 4, this is disabled. |
468 return self.__dict__[i] | 470 if self.hasFields(i): |
471 return self.minibatches(fieldnames=[i],minibatch_size=len(self),n_batches=1,offset=0).next()[0] | |
472 # else we are trying to access a property of the dataset | |
473 assert i in self.__dict__ # else it means we are trying to access a non-existing property | |
474 return self.__dict__[i] | |
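[Editor's note: after this change (disabled "after talk with Yoshua June 4"), __getitem__ accepts only example indices; looking up a field name or a dataset property through [] now raises TypeError. Illustrative behaviour, for a dataset ds with a field named 'target':]

    ds[3]              # ok: an Example
    ds[[0, 2]]         # ok: a MinibatchDataSet over examples 0 and 2
    ds['target']       # now: TypeError('target')
    # the supported spelling for whole-field access is via minibatches,
    # exactly as in the disabled branch above:
    targets = ds.minibatches(fieldnames=['target'],
                             minibatch_size=len(ds),
                             n_batches=1, offset=0).next()[0]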
469 | 475 |
470 def valuesHStack(self,fieldnames,fieldvalues): | 476 def valuesHStack(self,fieldnames,fieldvalues): |
471 """ | 477 """ |
472 Return a value that corresponds to concatenating (horizontally) several field values. | 478 Return a value that corresponds to concatenating (horizontally) several field values. |
473 This can be useful to merge some fields. The implementation of this operation is likely | 479 This can be useful to merge some fields. The implementation of this operation is likely |
489 return fieldvalues | 495 return fieldvalues |
490 | 496 |
491 | 497 |
492 def valuesVStack(self,fieldname,values): | 498 def valuesVStack(self,fieldname,values): |
493 """ | 499 """ |
494 Return a value that corresponds to concatenating (vertically) several values of the | 500 @param fieldname: the name of the field from which the values were taken |
495 same field. This can be important to build a minibatch out of individual examples. This | 501 @type fieldname: any type |
496 is likely to involve a copy of the original values. When the values are numpy arrays, the | 502 |
497 result should be numpy.vstack(values). | 503 @param values: bits near the beginning or end of the dataset |
498 The default is to use numpy.vstack for numpy.ndarray values, and a list | 504 @type values: list of minibatches (returned by minibatch_nowrap) |
499 pointing to the original values for other data types. | 505 |
500 """ | 506 @return: the concatenation (stacking) of the values |
501 all_numpy=True | 507 @rtype: something suitable as a minibatch field |
502 for value in values: | 508 |
503 if not type(value) is numpy.ndarray: | 509 """ |
504 all_numpy=False | 510 rval = [] |
505 if all_numpy: | 511 for sub_batch in values: |
506 return numpy.vstack(values) | 512 rval.extend(sub_batch) |
507 # the default implementation of vertical stacking is to put values in a list | 513 return rval |
508 return values | |
509 | 514 |
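[Editor's note: the behavioural change above is that the base-class default no longer special-cases numpy arrays; it simply chains the sub-batches into one list, and numpy-aware stacking moves to ArrayDataSet further down. A small illustration of the difference, with made-up values:]

    import numpy
    sub_batches = [numpy.ones((2, 3)), numpy.zeros((1, 3))]
    # old default: one stacked array
    old = numpy.vstack(sub_batches)      # shape (3, 3)
    # new default: a flat list of rows, left unstacked
    new = []
    for sub_batch in sub_batches:
        new.extend(sub_batch)            # three 1-D rows of length 3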
510 def __or__(self,other): | 515 def __or__(self,other): |
511 """ | 516 """ |
512 dataset1 | dataset2 returns a dataset whose list of fields is the concatenation of the list of | 517 dataset1 | dataset2 returns a dataset whose list of fields is the concatenation of the list of |
513 fields of the argument datasets. This only works if they all have the same length. | 518 fields of the argument datasets. This only works if they all have the same length. |
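[Editor's note: a hedged example of the | operator described here, building the operands with MinibatchDataSet as it is used in the __getitem__ hunk above; the field contents are made up.]

    import numpy
    x = MinibatchDataSet(Example(['x'], [numpy.ones((4, 2))]))
    y = MinibatchDataSet(Example(['y'], [numpy.zeros((4, 1))]))
    xy = x | y     # one dataset with fields 'x' and 'y', length 4;
                   # operands of unequal length are not allowed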
960 become a longer vector, two matrices become a wider matrix, etc.""" | 965 become a longer vector, two matrices become a wider matrix, etc.""" |
961 return numpy.hstack(fieldvalues) | 966 return numpy.hstack(fieldvalues) |
962 def valuesVStack(self, fieldname, values): | 967 def valuesVStack(self, fieldname, values): |
963 """Concatenate field values vertically, e.g. two vectors | 968 """Concatenate field values vertically, e.g. two vectors |
964 become a two-row matrix, two matrices become a longer matrix, etc.""" | 969 become a two-row matrix, two matrices become a longer matrix, etc.""" |
965 return numpy.vstack(values) | 970 #print len(values) |
966 def valuesAppend(self, fieldname, values): | 971 for v in values: |
972 if not isinstance(v, numpy.ndarray): | |
973 raise TypeError(v, type(v)) | |
974 | |
967 s0 = sum([v.shape[0] for v in values]) | 975 s0 = sum([v.shape[0] for v in values]) |
968 #TODO: there's gotta be a better way to do this! | 976 #TODO: there's gotta be a better way to do this! |
969 rval = numpy.ndarray([s0] + values[0].shape[1:],dtype=values[0].dtype) | 977 dtype = values[0].dtype |
978 rval = numpy.ndarray([s0] + list(values[0].shape[1:]), dtype=dtype) | |
970 cur_row = 0 | 979 cur_row = 0 |
971 for v in values: | 980 for v in values: |
972 rval[cur_row:cur_row+v.shape[0]] = v | 981 rval[cur_row:cur_row+v.shape[0]] = v |
973 cur_row += v.shape[0] | 982 cur_row += v.shape[0] |
974 return rval | 983 return rval |
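[Editor's note: on the TODO above: for the numpy-array-only inputs this method now enforces, the preallocate-and-copy loop computes the same thing as numpy.concatenate along axis 0 (the rewrite also fixes old line 969, which tried to add a list to a shape tuple). A minimal check of that claim:]

    import numpy
    values = [numpy.ones((2, 3)), numpy.zeros((1, 3))]
    s0 = sum([v.shape[0] for v in values])
    rval = numpy.ndarray([s0] + list(values[0].shape[1:]),
                         dtype=values[0].dtype)
    cur_row = 0
    for v in values:
        rval[cur_row:cur_row + v.shape[0]] = v
        cur_row += v.shape[0]
    assert (rval == numpy.concatenate(values, axis=0)).all()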
1026 """More efficient implementation than the default __getitem__""" | 1035 """More efficient implementation than the default __getitem__""" |
1027 fieldnames=self.fields_columns.keys() | 1036 fieldnames=self.fields_columns.keys() |
1028 values=self.fields_columns.values() | 1037 values=self.fields_columns.values() |
1029 if type(key) is int: | 1038 if type(key) is int: |
1030 return Example(fieldnames, | 1039 return Example(fieldnames, |
1031 [self.data[key,col] for col in values]) | 1040 [numpy.asarray(self.data[key,col]) for col in values]) |
1032 if type(key) is slice: | 1041 if type(key) is slice: |
1033 return MinibatchDataSet(Example(fieldnames, | 1042 return MinibatchDataSet(Example(fieldnames, |
1034 [self.data[key,col] for col in values])) | 1043 [self.data[key,col] for col in values])) |
1035 if type(key) is list: | 1044 if type(key) is list: |
1036 for i in range(len(key)): | 1045 for i in range(len(key)): |
1104 | 1113 |
1105 return ArrayDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset) | 1114 return ArrayDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset) |
1106 | 1115 |
1107 | 1116 |
1108 class CachedDataSet(DataSet): | 1117 class CachedDataSet(DataSet): |
1109 """ | 1118 """ |
1110 Wrap a L{DataSet} whose values are computationally expensive to obtain | 1119 Wrap a L{DataSet} whose values are computationally expensive to obtain |
1111 (e.g. because they involve some computation, or disk access), | 1120 (e.g. because they involve some computation, or disk access), |
1112 so that repeated accesses to the same example are done cheaply, | 1121 so that repeated accesses to the same example are done cheaply, |
1113 by caching every example value that has been accessed at least once. | 1122 by caching every example value that has been accessed at least once. |
1114 | 1123 |
1115 Optionally, for a finite-length dataset, all the values can be computed | 1124 Optionally, for a finite-length dataset, all the values can be computed |
1116 (and cached) upon construction of the CachedDataSet, rather than at the | 1125 (and cached) upon construction of the CachedDataSet, rather than at the |
1117 first access. | 1126 first access. |
1118 | 1127 |
1119 @todo: when cache_all_upon_construction create mini-batches that are as | 1128 @todo: when cache_all_upon_construction create mini-batches that are as |
1120 large as possible but not so large as to fill up memory. | 1129 large as possible but not so large as to fill up memory. |
1121 | 1130 |
1122 @todo: add disk-buffering capability, so that when the cache becomes too | 1131 @todo: add disk-buffering capability, so that when the cache becomes too |
1123 big for memory, we cache things on disk, trying to keep in memory only | 1132 big for memory, we cache things on disk, trying to keep in memory only |
1124 the record most likely to be accessed next. | 1133 the record most likely to be accessed next. |
1125 """ | 1134 """ |
1126 def __init__(self,source_dataset,cache_all_upon_construction=False): | 1135 def __init__(self,source_dataset,cache_all_upon_construction=False): |
1127 self.source_dataset=source_dataset | 1136 self.source_dataset=source_dataset |
1128 self.cache_all_upon_construction=cache_all_upon_construction | 1137 self.cache_all_upon_construction=cache_all_upon_construction |
1129 self.cached_examples = [] | 1138 self.cached_examples = [] #a list of LookupList (copies) |
1130 if cache_all_upon_construction: | 1139 if cache_all_upon_construction: |
1131 # this potentially brings all the source examples | 1140 # this potentially brings all the source examples |
1132 # into memory at once, which may be too much | 1141 # into memory at once, which may be too much |
1133 # the work could possibly be done by minibatches | 1142 # the work could possibly be done by minibatches |
1134 # that are as large as possible but no more than what memory allows. | 1143 # that are as large as possible but no more than what memory allows. |
1135 fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next() | 1144 fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next() |
1136 assert all([len(self)==len(field_values) for field_values in fields_values]) | 1145 assert all([len(self)==len(fval) for fval in fields_values]) |
1137 for example in fields_values.examples(): | 1146 for example in fields_values.examples(): |
1138 self.cached_examples.append(copy.copy(example)) | 1147 dup = copy.copy(example) |
1139 | 1148 self.cached_examples.append(dup) |
1140 self.fieldNames = source_dataset.fieldNames | 1149 |
1141 self.hasFields = source_dataset.hasFields | 1150 self.fieldNames = source_dataset.fieldNames |
1142 self.valuesHStack = source_dataset.valuesHStack | 1151 self.hasFields = source_dataset.hasFields |
1143 self.valuesVStack = source_dataset.valuesVStack | 1152 self.valuesHStack = source_dataset.valuesHStack |
1153 self.valuesVStack = source_dataset.valuesVStack | |
1144 | 1154 |
1145 def __len__(self): | 1155 def __len__(self): |
1146 return len(self.source_dataset) | 1156 return len(self.source_dataset) |
1147 | 1157 |
1148 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): | 1158 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): |
1149 class CacheIterator(object): | 1159 class CacheIterator(object): |
1150 def __init__(self,dataset): | 1160 def __init__(self,dataset): |
1151 self.dataset=dataset | 1161 self.dataset=dataset |
1152 self.current=offset | 1162 self.current=offset |
1153 self.all_fields = self.dataset.fieldNames()==fieldnames | 1163 self.all_fields = self.dataset.fieldNames()==fieldnames |
1154 def __iter__(self): return self | 1164 def __iter__(self): return self |
1155 def next(self): | 1165 def next(self): |
1156 upper = self.current+minibatch_size | 1166 upper = self.current+minibatch_size |
1157 cache_len = len(self.dataset.cached_examples) | 1167 cache_len = len(self.dataset.cached_examples) |
1158 if upper>cache_len: # whole minibatch is not already in cache | 1168 if upper>cache_len: |
1159 # cache everything from current length to upper | 1169 # whole minibatch is not already in cache |
1160 for example in self.dataset.source_dataset[cache_len:upper]: | 1170 # cache everything from current length to upper |
1161 self.dataset.cached_examples.append(example) | 1171 for example in self.dataset.source_dataset[cache_len:upper]: |
1162 all_fields_minibatch = Example(self.dataset.fieldNames(), | 1172 self.dataset.cached_examples.append(example) |
1163 zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size])) | 1173 |
1164 self.current+=minibatch_size | 1174 next_range = slice(self.current, self.current+minibatch_size) |
1165 if self.all_fields: | 1175 blah = self.dataset.cached_examples[next_range] |
1166 return all_fields_minibatch | 1176 all_fields_minibatch = Example(self.dataset.fieldNames(), zip(*blah)) |
1167 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames]) | 1177 self.current+=minibatch_size |
1168 return CacheIterator(self) | 1178 |
1169 | 1179 #little optimization to avoid second Example computation if |
1170 def __getitem__(self,i): | 1180 #possible. |
1171 if type(i)==int and len(self.cached_examples)>i: | 1181 if self.all_fields: |
1172 return self.cached_examples[i] | 1182 return all_fields_minibatch |
1173 else: | 1183 |
1174 return self.source_dataset[i] | 1184 rval = Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames]) |
1175 | 1185 return rval |
1176 def __iter__(self): | 1186 return CacheIterator(self) |
1177 class CacheIteratorIter(object): | 1187 |
1178 def __init__(self,dataset): | 1188 def __getitem__(self,i): |
1179 self.dataset=dataset | 1189 if type(i)==int and len(self.cached_examples)>i: |
1180 self.l = len(dataset) | 1190 return self.cached_examples[i] |
1181 self.current = 0 | 1191 else: |
1182 self.fieldnames = self.dataset.fieldNames() | 1192 return self.source_dataset[i] |
1183 self.example = LookupList(self.fieldnames,[0]*len(self.fieldnames)) | 1193 |
1184 def __iter__(self): return self | 1194 def __iter__(self): |
1185 def next(self): | 1195 class CacheIteratorIter(object): |
1186 if self.current>=self.l: | 1196 def __init__(self,dataset): |
1187 raise StopIteration | 1197 self.dataset=dataset |
1188 cache_len = len(self.dataset.cached_examples) | 1198 self.l = len(dataset) |
1189 if self.current>=cache_len: # whole minibatch is not already in cache | 1199 self.current = 0 |
1190 # cache everything from current length to upper | 1200 self.fieldnames = self.dataset.fieldNames() |
1191 self.dataset.cached_examples.append( | 1201 self.example = LookupList(self.fieldnames,[0]*len(self.fieldnames)) |
1192 self.dataset.source_dataset[self.current]) | 1202 def __iter__(self): return self |
1193 self.example._values = self.dataset.cached_examples[self.current] | 1203 def next(self): |
1194 self.current+=1 | 1204 if self.current>=self.l: |
1195 return self.example | 1205 raise StopIteration |
1196 | 1206 cache_len = len(self.dataset.cached_examples) |
1197 return CacheIteratorIter(self) | 1207 if self.current>=cache_len: # whole minibatch is not already in cache |
1208 # cache everything from current length to upper | |
1209 self.dataset.cached_examples.append( | |
1210 self.dataset.source_dataset[self.current]) | |
1211 self.example._values = self.dataset.cached_examples[self.current] | |
1212 self.current+=1 | |
1213 return self.example | |
1214 | |
1215 return CacheIteratorIter(self) | |
1198 | 1216 |
1199 class ApplyFunctionDataSet(DataSet): | 1217 class ApplyFunctionDataSet(DataSet): |
1200 """ | 1218 """ |
1201 A L{DataSet} that contains as fields the results of applying a | 1219 A L{DataSet} that contains as fields the results of applying a |
1202 given function example-wise or minibatch-wise to all the fields of | 1220 given function example-wise or minibatch-wise to all the fields of |
1203 an input dataset. The output of the function should be an iterable | 1221 an input dataset. The output of the function should be an iterable |
1204 (e.g. a list or a LookupList) over the resulting values. | 1222 (e.g. a list or a LookupList) over the resulting values. |
1205 | 1223 |
1206 The function takes as input the fields of the dataset, not the examples. | 1224 The function takes as input the fields of the dataset, not the examples. |
1207 | 1225 |
1208 In minibatch mode, the function is expected to work on minibatches | 1226 In minibatch mode, the function is expected to work on minibatches |
1209 (takes a minibatch in input and returns a minibatch in output). More | 1227 (takes a minibatch in input and returns a minibatch in output). More |
1210 precisely, it means that each element of the input or output list | 1228 precisely, it means that each element of the input or output list |
1211 should be iterable and indexable over the individual example values | 1229 should be iterable and indexable over the individual example values |
1212 (typically these elements will be numpy arrays). All of the elements | 1230 (typically these elements will be numpy arrays). All of the elements |
1213 in the input and output lists should have the same length, which is | 1231 in the input and output lists should have the same length, which is |
1214 the length of the minibatch. | 1232 the length of the minibatch. |
1215 | 1233 |
1216 The function is applied each time an example or a minibatch is accessed. | 1234 The function is applied each time an example or a minibatch is accessed. |
1217 To avoid re-doing computation, wrap this dataset inside a CachedDataSet. | 1235 To avoid re-doing computation, wrap this dataset inside a CachedDataSet. |
1218 | 1236 |
1219 If the values_{h,v}stack functions are not provided, then | 1237 If the values_{h,v}stack functions are not provided, then |
1220 the input_dataset.values{H,V}Stack functions are used by default. | 1238 the input_dataset.values{H,V}Stack functions are used by default. |
1221 """ | 1239 """ |
1222 def __init__(self,input_dataset,function,output_names,minibatch_mode=True, | 1240 def __init__(self,input_dataset,function,output_names,minibatch_mode=True, |
1223 values_hstack=None,values_vstack=None, | 1241 values_hstack=None,values_vstack=None, |
1224 description=None,fieldtypes=None): | 1242 description=None,fieldtypes=None): |
1225 """ | 1243 """ |
1226 Constructor takes an input dataset that has as many fields as the function | 1244 Constructor takes an input dataset that has as many fields as the function |
1227 expects as inputs. The resulting dataset has as many fields as the function | 1245 expects as inputs. The resulting dataset has as many fields as the function |
1228 produces as outputs, which should match the number of output names | 1246 produces as outputs, which should match the number of output names |
1229 (provided in a list). | 1247 (provided in a list). |
1230 | 1248 |
1231 Note that the expected semantics of the function differs in minibatch mode | 1249 Note that the expected semantics of the function differs in minibatch mode |
1232 (it takes minibatches of inputs and produces minibatches of outputs, as | 1250 (it takes minibatches of inputs and produces minibatches of outputs, as |
1233 documented in the class comment). | 1251 documented in the class comment). |
1234 | 1252 |
1235 TBM: are fieldtypes the old field types (from input_dataset) or the new ones | 1253 TBM: are fieldtypes the old field types (from input_dataset) or the new ones |
1236 (for the new dataset created)? | 1254 (for the new dataset created)? |
1237 """ | 1255 """ |
1238 self.input_dataset=input_dataset | 1256 self.input_dataset=input_dataset |
1239 self.function=function | 1257 self.function=function |
1240 self.output_names=output_names | 1258 self.output_names=output_names |
1241 self.minibatch_mode=minibatch_mode | 1259 self.minibatch_mode=minibatch_mode |
1242 DataSet.__init__(self,description,fieldtypes) | 1260 DataSet.__init__(self,description,fieldtypes) |
1243 self.valuesHStack = values_hstack if values_hstack else input_dataset.valuesHStack | 1261 self.valuesHStack = values_hstack if values_hstack else input_dataset.valuesHStack |
1244 self.valuesVStack = values_vstack if values_vstack else input_dataset.valuesVStack | 1262 self.valuesVStack = values_vstack if values_vstack else input_dataset.valuesVStack |
1245 | 1263 |
1246 def __len__(self): | 1264 def __len__(self): |
1247 return len(self.input_dataset) | 1265 return len(self.input_dataset) |
1248 | 1266 |
1249 def fieldNames(self): | 1267 def fieldNames(self): |
1250 return self.output_names | 1268 return self.output_names |
1251 | 1269 |
1252 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): | 1270 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): |
1253 class ApplyFunctionIterator(object): | 1271 class ApplyFunctionIterator(object): |
1254 def __init__(self,output_dataset): | 1272 def __init__(self,output_dataset): |
1255 self.input_dataset=output_dataset.input_dataset | 1273 self.input_dataset=output_dataset.input_dataset |
1256 self.output_dataset=output_dataset | 1274 self.output_dataset=output_dataset |
1257 self.input_iterator=self.input_dataset.minibatches(minibatch_size=minibatch_size, | 1275 self.input_iterator=self.input_dataset.minibatches(minibatch_size=minibatch_size, |
1258 n_batches=n_batches,offset=offset).__iter__() | 1276 n_batches=n_batches,offset=offset).__iter__() |
1259 | 1277 |
1260 def __iter__(self): return self | 1278 def __iter__(self): return self |
1261 | 1279 |
1262 def next(self): | 1280 def next(self): |
1263 function_inputs = self.input_iterator.next() | 1281 function_inputs = self.input_iterator.next() |
1264 all_output_names = self.output_dataset.output_names | 1282 all_output_names = self.output_dataset.output_names |
1265 if self.output_dataset.minibatch_mode: | 1283 if self.output_dataset.minibatch_mode: |
1266 function_outputs = self.output_dataset.function(*function_inputs) | 1284 function_outputs = self.output_dataset.function(*function_inputs) |
1267 else: | 1285 else: |
1268 input_examples = zip(*function_inputs) | 1286 input_examples = zip(*function_inputs) |
1269 output_examples = [self.output_dataset.function(*input_example) | 1287 output_examples = [self.output_dataset.function(*input_example) |
1270 for input_example in input_examples] | 1288 for input_example in input_examples] |
1271 function_outputs = [self.output_dataset.valuesVStack(name,values) | 1289 function_outputs = [self.output_dataset.valuesVStack(name,values) |
1272 for name,values in zip(all_output_names, | 1290 for name,values in zip(all_output_names, |
1273 zip(*output_examples))] | 1291 zip(*output_examples))] |
1274 all_outputs = Example(all_output_names,function_outputs) | 1292 all_outputs = Example(all_output_names,function_outputs) |
1275 if fieldnames==all_output_names: | 1293 if fieldnames==all_output_names: |
1276 return all_outputs | 1294 return all_outputs |
1277 return Example(fieldnames,[all_outputs[name] for name in fieldnames]) | 1295 return Example(fieldnames,[all_outputs[name] for name in fieldnames]) |
1278 | 1296 |
1279 | 1297 |
1280 return ApplyFunctionIterator(self) | 1298 return ApplyFunctionIterator(self) |
1281 | 1299 |
1282 def __iter__(self): # only implemented for increased efficiency | 1300 def __iter__(self): # only implemented for increased efficiency |
1283 class ApplyFunctionSingleExampleIterator(object): | 1301 class ApplyFunctionSingleExampleIterator(object): |
1284 def __init__(self,output_dataset): | 1302 def __init__(self,output_dataset): |
1285 self.current=0 | 1303 self.current=0 |
1286 self.output_dataset=output_dataset | 1304 self.output_dataset=output_dataset |
1287 self.input_iterator=output_dataset.input_dataset.__iter__() | 1305 self.input_iterator=output_dataset.input_dataset.__iter__() |
1288 def __iter__(self): return self | 1306 def __iter__(self): return self |
1289 def next(self): | 1307 def next(self): |
1290 if self.output_dataset.minibatch_mode: | 1308 if self.output_dataset.minibatch_mode: |
1291 function_inputs = [[input] for input in self.input_iterator.next()] | 1309 function_inputs = [[input] for input in self.input_iterator.next()] |
1292 outputs = self.output_dataset.function(*function_inputs) | 1310 outputs = self.output_dataset.function(*function_inputs) |
1293 assert all([hasattr(output,'__iter__') for output in outputs]) | 1311 assert all([hasattr(output,'__iter__') for output in outputs]) |
1294 function_outputs = [output[0] for output in outputs] | 1312 function_outputs = [output[0] for output in outputs] |
1295 else: | 1313 else: |
1296 function_inputs = self.input_iterator.next() | 1314 function_inputs = self.input_iterator.next() |
1297 function_outputs = self.output_dataset.function(*function_inputs) | 1315 function_outputs = self.output_dataset.function(*function_inputs) |
1298 return Example(self.output_dataset.output_names,function_outputs) | 1316 return Example(self.output_dataset.output_names,function_outputs) |
1299 return ApplyFunctionSingleExampleIterator(self) | 1317 return ApplyFunctionSingleExampleIterator(self) |
1300 | 1318 |
1301 | 1319 |
1302 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): | 1320 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): |
1303 """ | 1321 """ |
1304 Wraps an arbitrary L{DataSet} into one for supervised learning tasks | 1322 Wraps an arbitrary L{DataSet} into one for supervised learning tasks |
1305 by forcing the user to define a set of fields as the 'input' field | 1323 by forcing the user to define a set of fields as the 'input' field |