pylearn: dataset.py @ 268:3f1cd8897fda

reverting dataset

author    James Bergstra <bergstrj@iro.umontreal.ca>
date      Wed, 04 Jun 2008 18:48:50 -0400
parents   6e69fb91f3c0
children  fdce496c3b56
267:4dad41215967 (old) | 268:3f1cd8897fda (new)
107 | 107 |
108 - dataset[i] returns an Example. | 108 - dataset[i] returns an Example. |
109 | 109 |
110 - dataset[[i1,i2,...in]] returns a dataset with examples i1,i2,...in. | 110 - dataset[[i1,i2,...in]] returns a dataset with examples i1,i2,...in. |
111 | 111 |
112 - dataset[fieldname] returns an iterable over the values of the field fieldname across | |
113 the dataset (the iterable is obtained by default by calling valuesVStack | |
114 over the values for individual examples). | |
115 | |
112 - dataset.<property> returns the value of a property associated with | 116 - dataset.<property> returns the value of a property associated with |
113 the name <property>. The following properties should be supported: | 117 the name <property>. The following properties should be supported: |
114 - 'description': a textual description or name for the dataset | 118 - 'description': a textual description or name for the dataset |
115 - 'fieldtypes': a list of types (one per field) | 119 - 'fieldtypes': a list of types (one per field) |
116 A DataSet may have other attributes that it makes visible to other objects. These are | 120 A DataSet may have other attributes that it makes visible to other objects. These are |
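A hedged sketch of the indexing contract this docstring describes (the dataset ds and its fields 'x' and 'y' are hypothetical, not part of this changeset):

    ex = ds[3]               # an Example; access fields as ex['x'], ex['y']
    sub = ds[[0, 2, 5]]      # a dataset containing examples 0, 2 and 5
    desc = ds.description    # a supported property: textual description
    types = ds.fieldtypes    # a supported property: one type per field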
156 A sub-class should also append attributes to self._attribute_names | 160 A sub-class should also append attributes to self._attribute_names |
157 (the default value returned by attributeNames()). | 161 (the default value returned by attributeNames()). |
158 By convention, attributes not in attributeNames() should have a name | 162 By convention, attributes not in attributeNames() should have a name |
159 starting with an underscore. | 163 starting with an underscore. |
160 @todo enforce/test that convention! | 164 @todo enforce/test that convention! |
161 | 165 """ |
162 """ | 166 |
163 | 167 numpy_vstack = lambda fieldname,values: numpy.vstack(values) |
164 if 0: | 168 numpy_hstack = lambda fieldnames,values: numpy.hstack(values) |
165 # removed by James June 4... these aren't used anywhere according to | |
166 # grep | |
167 numpy_vstack = lambda fieldname,values: numpy.vstack(values) | |
168 numpy_hstack = lambda fieldnames,values: numpy.hstack(values) | |
169 | 169 |
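The restored numpy_vstack and numpy_hstack helpers ignore the field name(s) and defer entirely to numpy. A minimal sketch of what they produce, assuming 1x2 row arrays:

    import numpy
    rows = [numpy.array([[1., 2.]]), numpy.array([[3., 4.]])]
    numpy.vstack(rows)   # array([[1., 2.], [3., 4.]]) -- stack examples as rows
    numpy.hstack(rows)   # array([[1., 2., 3., 4.]])   -- widen fields side by side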
170 def __init__(self,description=None,fieldtypes=None): | 170 def __init__(self,description=None,fieldtypes=None): |
171 if description is None: | 171 if description is None: |
172 # by default return "<DataSetType>(<SuperClass1>,<SuperClass2>,...)" | 172 # by default return "<DataSetType>(<SuperClass1>,<SuperClass2>,...)" |
173 description = type(self).__name__ + " ( " + join([x.__name__ for x in type(self).__bases__]) + " )" | 173 description = type(self).__name__ + " ( " + join([x.__name__ for x in type(self).__bases__]) + " )" |
275 else: | 275 else: |
276 # we must concatenate (vstack) the bottom and top parts of our minibatch | 276 # we must concatenate (vstack) the bottom and top parts of our minibatch |
277 # first get the beginning of our minibatch (top of dataset) | 277 # first get the beginning of our minibatch (top of dataset) |
278 first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next() | 278 first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next() |
279 second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next() | 279 second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next() |
280 | 280 minibatch = Example(self.fieldnames, |
281 blah = [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) | 281 [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) |
282 for name in self.fieldnames] | 282 for name in self.fieldnames]) |
283 print type(self.dataset), blah | |
284 minibatch = Example(self.fieldnames,blah) | |
285 self.next_row=upper | 283 self.next_row=upper |
286 self.n_batches_done+=1 | 284 self.n_batches_done+=1 |
287 if upper >= self.L and self.n_batches: | 285 if upper >= self.L and self.n_batches: |
288 self.next_row -= self.L | 286 self.next_row -= self.L |
289 ds_nbatches = (self.L-self.next_row)/self.minibatch_size | 287 ds_nbatches = (self.L-self.next_row)/self.minibatch_size |
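When a minibatch runs past the end of the dataset, the iterator above fetches the tail of the dataset and the wrapped-around head separately, then vstacks them field by field. A numpy-only sketch of that arithmetic (sizes are hypothetical):

    import numpy
    data = numpy.arange(10).reshape(5, 2)      # 5 examples of one 2-column field
    L, next_row, minibatch_size = 5, 3, 4
    upper = next_row + minibatch_size          # 7, past the end of the dataset
    first_part = data[next_row:L]              # examples 3 and 4 (top of dataset)
    second_part = data[0:upper - L]            # examples 0 and 1 (wrapped around)
    minibatch = numpy.vstack([first_part, second_part])   # 4 examples total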
460 return MinibatchDataSet( | 458 return MinibatchDataSet( |
461 Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values) | 459 Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values) |
462 for fieldname,field_values | 460 for fieldname,field_values |
463 in zip(self.fieldNames(),fields_values)]), | 461 in zip(self.fieldNames(),fields_values)]), |
464 self.valuesVStack,self.valuesHStack) | 462 self.valuesVStack,self.valuesHStack) |
465 | 463 # else check for a fieldname |
466 raise TypeError(i) | 464 if self.hasFields(i): |
467 if 0: | 465 return self.minibatches(fieldnames=[i],minibatch_size=len(self),n_batches=1,offset=0).next()[0] |
468 # else check for a fieldname | 466 # else we are trying to access a property of the dataset |
469 #after talk with Yoshua June 4, this is disabled. | 467 assert i in self.__dict__ # else it means we are trying to access a non-existing property |
470 if self.hasFields(i): | 468 return self.__dict__[i] |
471 return self.minibatches(fieldnames=[i],minibatch_size=len(self),n_batches=1,offset=0).next()[0] | |
472 # else we are trying to access a property of the dataset | |
473 assert i in self.__dict__ # else it means we are trying to access a non-existing property | |
474 return self.__dict__[i] | |
475 | 469 |
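With the fieldname branch re-enabled (right column), dataset[fieldname] returns every value of that field, stacked with valuesVStack. A hedged equivalence sketch, using a hypothetical dataset ds with a field 'x':

    all_x = ds['x']
    # behaves like one full-length, single-field minibatch:
    all_x = ds.minibatches(fieldnames=['x'], minibatch_size=len(ds),
                           n_batches=1, offset=0).next()[0]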
476 def valuesHStack(self,fieldnames,fieldvalues): | 470 def valuesHStack(self,fieldnames,fieldvalues): |
477 """ | 471 """ |
478 Return a value that corresponds to concatenating (horizontally) several field values. | 472 Return a value that corresponds to concatenating (horizontally) several field values. |
479 This can be useful to merge some fields. The implementation of this operation is likely | 473 This can be useful to merge some fields. The implementation of this operation is likely |
495 return fieldvalues | 489 return fieldvalues |
496 | 490 |
497 | 491 |
498 def valuesVStack(self,fieldname,values): | 492 def valuesVStack(self,fieldname,values): |
499 """ | 493 """ |
500 @param fieldname: the name of the field from which the values were taken | 494 Return a value that corresponds to concatenating (vertically) several values of the |
501 @type fieldname: any type | 495 same field. This can be important to build a minibatch out of individual examples. This |
502 | 496 is likely to involve a copy of the original values. When the values are numpy arrays, the |
503 @param values: bits near the beginning or end of the dataset | 497 result should be numpy.vstack(values). |
504 @type values: list of minibatches (returned by minibatch_nowrap) | 498 The default is to use numpy.vstack for numpy.ndarray values, and a list |
505 | 499 pointing to the original values for other data types. |
506 @return: the concatenation (stacking) of the values | 500 """ |
507 @rtype: something suitable as a minibatch field | 501 all_numpy=True |
508 | 502 for value in values: |
509 """ | 503 if not type(value) is numpy.ndarray: |
510 rval = [] | 504 all_numpy=False |
511 for sub_batch in values: | 505 if all_numpy: |
512 rval.extend(sub_batch) | 506 return numpy.vstack(values) |
513 return rval | 507 # the default implementation of vertical stacking is to put values in a list |
508 return values | |
514 | 509 |
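A standalone mirror of the new default valuesVStack (right column), showing both branches; the function name below is ours, not the module's:

    import numpy
    def values_vstack_default(values):
        if all(type(v) is numpy.ndarray for v in values):
            return numpy.vstack(values)    # all-numpy values become one array
        return values                      # mixed types: keep the list of values
    values_vstack_default([numpy.ones((1, 2)), numpy.zeros((1, 2))])
    # -> array([[1., 1.], [0., 0.]])
    values_vstack_default(['cat', 'dog'])  # -> ['cat', 'dog']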
515 def __or__(self,other): | 510 def __or__(self,other): |
516 """ | 511 """ |
517 dataset1 | dataset2 returns a dataset whose list of fields is the concatenation of the list of | 512 dataset1 | dataset2 returns a dataset whose list of fields is the concatenation of the list of |
518 fields of the argument datasets. This only works if they all have the same length. | 513 fields of the argument datasets. This only works if they all have the same length. |
956 class ArrayFieldsDataSet(DataSet): | 951 class ArrayFieldsDataSet(DataSet): |
957 """ | 952 """ |
958 Virtual super-class of datasets whose field values are numpy arrays, | 953 Virtual super-class of datasets whose field values are numpy arrays, |
959 thus defining valuesHStack and valuesVStack for sub-classes. | 954 thus defining valuesHStack and valuesVStack for sub-classes. |
960 """ | 955 """ |
961 def __init__(self, description=None, field_types=None): | 956 def __init__(self,description=None,field_types=None): |
962 DataSet.__init__(self, description, field_types) | 957 DataSet.__init__(self,description,field_types) |
963 def valuesHStack(self, fieldnames, fieldvalues): | 958 def valuesHStack(self,fieldnames,fieldvalues): |
964 """Concatenate field values horizontally, e.g. two vectors | 959 """Concatenate field values horizontally, e.g. two vectors |
965 become a longer vector, two matrices become a wider matrix, etc.""" | 960 become a longer vector, two matrices become a wider matrix, etc.""" |
966 return numpy.hstack(fieldvalues) | 961 return numpy.hstack(fieldvalues) |
967 def valuesVStack(self, fieldname, values): | 962 def valuesVStack(self,fieldname,values): |
968 """Concatenate field values vertically, e.g. two vectors | 963 """Concatenate field values vertically, e.g. two vectors |
969 become a two-row matrix, two matrices become a longer matrix, etc.""" | 964 become a two-row matrix, two matrices become a longer matrix, etc.""" |
970 #print len(values) | 965 return numpy.vstack(values) |
971 for v in values: | |
972 if not isinstance(v, numpy.ndarray): | |
973 raise TypeError(v, type(v)) | |
974 | |
975 s0 = sum([v.shape[0] for v in values]) | |
976 #TODO: there's gotta be a better way to do this! | |
977 dtype = values[0].dtype | |
978 rval = numpy.ndarray([s0] + list(values[0].shape[1:]), dtype=dtype) | |
979 cur_row = 0 | |
980 for v in values: | |
981 rval[cur_row:cur_row+v.shape[0]] = v | |
982 cur_row += v.shape[0] | |
983 return rval | |
984 | 966 |
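The manual preallocate-and-copy loop (left column) and numpy.vstack (right column) agree whenever all values are arrays with matching trailing dimensions; a quick check under that assumption:

    import numpy
    values = [numpy.arange(4.).reshape(2, 2), numpy.arange(4., 8.).reshape(2, 2)]
    manual = numpy.ndarray((4, 2), dtype=values[0].dtype)
    manual[0:2] = values[0]
    manual[2:4] = values[1]
    assert (manual == numpy.vstack(values)).all()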
985 class ArrayDataSet(ArrayFieldsDataSet): | 967 class ArrayDataSet(ArrayFieldsDataSet): |
986 """ | 968 """ |
987 An ArrayDataSet stores the fields as groups of columns in a numpy tensor, | 969 An ArrayDataSet stores the fields as groups of columns in a numpy tensor, |
988 whose first axis indexes examples and whose second axis selects fields. | 970 whose first axis indexes examples and whose second axis selects fields. |
1003 | 985 |
1004 # check consistency and complete slices definitions | 986 # check consistency and complete slices definitions |
1005 for fieldname, fieldcolumns in self.fields_columns.items(): | 987 for fieldname, fieldcolumns in self.fields_columns.items(): |
1006 if type(fieldcolumns) is int: | 988 if type(fieldcolumns) is int: |
1007 assert fieldcolumns>=0 and fieldcolumns<data_array.shape[1] | 989 assert fieldcolumns>=0 and fieldcolumns<data_array.shape[1] |
1008 if 0: | 990 if 1: |
1009 #I changed this because it didn't make sense to me, | 991 #I changed this because it didn't make sense to me, |
1010 # and it made it more difficult to write my learner. | 992 # and it made it more difficult to write my learner. |
1011 # If it breaks stuff, let's talk about it. | 993 # If it breaks stuff, let's talk about it. |
1012 # - James 22/05/2008 | 994 # - James 22/05/2008 |
1013 self.fields_columns[fieldname]=[fieldcolumns] | 995 self.fields_columns[fieldname]=[fieldcolumns] |
1035 """More efficient implementation than the default __getitem__""" | 1017 """More efficient implementation than the default __getitem__""" |
1036 fieldnames=self.fields_columns.keys() | 1018 fieldnames=self.fields_columns.keys() |
1037 values=self.fields_columns.values() | 1019 values=self.fields_columns.values() |
1038 if type(key) is int: | 1020 if type(key) is int: |
1039 return Example(fieldnames, | 1021 return Example(fieldnames, |
1040 [numpy.asarray(self.data[key,col]) for col in values]) | 1022 [self.data[key,col] for col in values]) |
1041 if type(key) is slice: | 1023 if type(key) is slice: |
1042 return MinibatchDataSet(Example(fieldnames, | 1024 return MinibatchDataSet(Example(fieldnames, |
1043 [self.data[key,col] for col in values])) | 1025 [self.data[key,col] for col in values])) |
1044 if type(key) is list: | 1026 if type(key) is list: |
1045 for i in range(len(key)): | 1027 for i in range(len(key)): |
1113 | 1095 |
1114 return ArrayDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset) | 1096 return ArrayDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset) |
1115 | 1097 |
1116 | 1098 |
1117 class CachedDataSet(DataSet): | 1099 class CachedDataSet(DataSet): |
1118 """ | 1100 """ |
1119 Wrap a L{DataSet} whose values are computationally expensive to obtain | 1101 Wrap a L{DataSet} whose values are computationally expensive to obtain |
1120 (e.g. because they involve some computation, or disk access), | 1102 (e.g. because they involve some computation, or disk access), |
1121 so that repeated accesses to the same example are done cheaply, | 1103 so that repeated accesses to the same example are done cheaply, |
1122 by caching every example value that has been accessed at least once. | 1104 by caching every example value that has been accessed at least once. |
1123 | 1105 |
1124 Optionally, for a finite-length dataset, all the values can be computed | 1106 Optionally, for a finite-length dataset, all the values can be computed |
1125 (and cached) upon construction of the CachedDataSet, rather than at the | 1107 (and cached) upon construction of the CachedDataSet, rather than at the |
1126 first access. | 1108 first access. |
1127 | 1109 |
1128 @todo: when cache_all_upon_construction create mini-batches that are as | 1110 @todo: when cache_all_upon_construction create mini-batches that are as |
1129 large as possible but not so large as to fill up memory. | 1111 large as possible but not so large as to fill up memory. |
1130 | 1112 |
1131 @todo: add disk-buffering capability, so that when the cache becomes too | 1113 @todo: add disk-buffering capability, so that when the cache becomes too |
1132 big for memory, we cache things on disk, trying to keep in memory only | 1114 big for memory, we cache things on disk, trying to keep in memory only |
1133 the record most likely to be accessed next. | 1115 the record most likely to be accessed next. |
1134 """ | 1116 """ |
1135 def __init__(self,source_dataset,cache_all_upon_construction=False): | 1117 def __init__(self,source_dataset,cache_all_upon_construction=False): |
1136 self.source_dataset=source_dataset | 1118 self.source_dataset=source_dataset |
1137 self.cache_all_upon_construction=cache_all_upon_construction | 1119 self.cache_all_upon_construction=cache_all_upon_construction |
1138 self.cached_examples = [] #a list of LookupList (copies) | 1120 self.cached_examples = [] |
1139 if cache_all_upon_construction: | 1121 if cache_all_upon_construction: |
1140 # this potentially brings all the source examples | 1122 # this potentially brings all the source examples |
1141 # into memory at once, which may be too much | 1123 # into memory at once, which may be too much |
1142 # the work could possibly be done by minibatches | 1124 # the work could possibly be done by minibatches |
1143 # that are as large as possible but no more than what memory allows. | 1125 # that are as large as possible but no more than what memory allows. |
1144 fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next() | 1126 fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next() |
1145 assert all([len(self)==len(fval) for fval in fields_values]) | 1127 assert all([len(self)==len(field_values) for field_values in fields_values]) |
1146 for example in fields_values.examples(): | 1128 for example in fields_values.examples(): |
1147 dup = copy.copy(example) | 1129 self.cached_examples.append(copy.copy(example)) |
1148 self.cached_examples.append(dup) | 1130 |
1149 | 1131 self.fieldNames = source_dataset.fieldNames |
1150 self.fieldNames = source_dataset.fieldNames | 1132 self.hasFields = source_dataset.hasFields |
1151 self.hasFields = source_dataset.hasFields | 1133 self.valuesHStack = source_dataset.valuesHStack |
1152 self.valuesHStack = source_dataset.valuesHStack | 1134 self.valuesVStack = source_dataset.valuesVStack |
1153 self.valuesVStack = source_dataset.valuesVStack | |
1154 | 1135 |
1155 def __len__(self): | 1136 def __len__(self): |
1156 return len(self.source_dataset) | 1137 return len(self.source_dataset) |
1157 | 1138 |
1158 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): | 1139 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): |
1159 class CacheIterator(object): | 1140 class CacheIterator(object): |
1160 def __init__(self,dataset): | 1141 def __init__(self,dataset): |
1161 self.dataset=dataset | 1142 self.dataset=dataset |
1162 self.current=offset | 1143 self.current=offset |
1163 self.all_fields = self.dataset.fieldNames()==fieldnames | 1144 self.all_fields = self.dataset.fieldNames()==fieldnames |
1164 def __iter__(self): return self | 1145 def __iter__(self): return self |
1165 def next(self): | 1146 def next(self): |
1166 upper = self.current+minibatch_size | 1147 upper = self.current+minibatch_size |
1167 cache_len = len(self.dataset.cached_examples) | 1148 cache_len = len(self.dataset.cached_examples) |
1168 if upper>cache_len: | 1149 if upper>cache_len: # whole minibatch is not already in cache |
1169 # whole minibatch is not already in cache | 1150 # cache everything from current length to upper |
1170 # cache everything from current length to upper | 1151 for example in self.dataset.source_dataset[cache_len:upper]: |
1171 for example in self.dataset.source_dataset[cache_len:upper]: | 1152 self.dataset.cached_examples.append(example) |
1172 self.dataset.cached_examples.append(example) | 1153 all_fields_minibatch = Example(self.dataset.fieldNames(), |
1173 | 1154 zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size])) |
1174 next_range = slice(self.current, self.current+minibatch_size) | 1155 self.current+=minibatch_size |
1175 blah = self.dataset.cached_examples[next_range] | 1156 if self.all_fields: |
1176 all_fields_minibatch = Example(self.dataset.fieldNames(), zip(*blah)) | 1157 return all_fields_minibatch |
1177 self.current+=minibatch_size | 1158 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames]) |
1178 | 1159 return CacheIterator(self) |
1179 #little optimization to avoid second Example computation if | 1160 |
1180 #possible. | 1161 def __getitem__(self,i): |
1181 if self.all_fields: | 1162 if type(i)==int and len(self.cached_examples)>i: |
1182 return all_fields_minibatch | 1163 return self.cached_examples[i] |
1183 | 1164 else: |
1184 rval = Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames]) | 1165 return self.source_dataset[i] |
1185 return rval | 1166 |
1186 return CacheIterator(self) | 1167 def __iter__(self): |
1187 | 1168 class CacheIteratorIter(object): |
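The minibatch above is assembled by transposing a slice of cached examples (one tuple per example) into per-field columns with zip(*...); in miniature:

    rows = [(1, 'a'), (2, 'b'), (3, 'c')]   # three cached (x, y) examples
    zip(*rows)                               # [(1, 2, 3), ('a', 'b', 'c')]
    # one sequence per field, ready for Example(fieldNames(), ...)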
1188 def __getitem__(self,i): | 1169 def __init__(self,dataset): |
1189 if type(i)==int and len(self.cached_examples)>i: | 1170 self.dataset=dataset |
1190 return self.cached_examples[i] | 1171 self.l = len(dataset) |
1191 else: | 1172 self.current = 0 |
1192 return self.source_dataset[i] | 1173 self.fieldnames = self.dataset.fieldNames() |
1193 | 1174 self.example = LookupList(self.fieldnames,[0]*len(self.fieldnames)) |
1194 def __iter__(self): | 1175 def __iter__(self): return self |
1195 class CacheIteratorIter(object): | 1176 def next(self): |
1196 def __init__(self,dataset): | 1177 if self.current>=self.l: |
1197 self.dataset=dataset | 1178 raise StopIteration |
1198 self.l = len(dataset) | 1179 cache_len = len(self.dataset.cached_examples) |
1199 self.current = 0 | 1180 if self.current>=cache_len: # whole minibatch is not already in cache |
1200 self.fieldnames = self.dataset.fieldNames() | 1181 # cache everything from current length to upper |
1201 self.example = LookupList(self.fieldnames,[0]*len(self.fieldnames)) | 1182 self.dataset.cached_examples.append( |
1202 def __iter__(self): return self | 1183 self.dataset.source_dataset[self.current]) |
1203 def next(self): | 1184 self.example._values = self.dataset.cached_examples[self.current] |
1204 if self.current>=self.l: | 1185 self.current+=1 |
1205 raise StopIteration | 1186 return self.example |
1206 cache_len = len(self.dataset.cached_examples) | 1187 |
1207 if self.current>=cache_len: # whole minibatch is not already in cache | 1188 return CacheIteratorIter(self) |
1208 # cache everything from current length to upper | |
1209 self.dataset.cached_examples.append( | |
1210 self.dataset.source_dataset[self.current]) | |
1211 self.example._values = self.dataset.cached_examples[self.current] | |
1212 self.current+=1 | |
1213 return self.example | |
1214 | |
1215 return CacheIteratorIter(self) | |
1216 | 1189 |
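Both columns' per-example iterator recycles a single LookupList, overwriting example._values at each step, so callers who keep yielded examples must copy them; a hedged sketch with a hypothetical dataset ds:

    import copy
    kept = [copy.copy(ex) for ex in CachedDataSet(ds)]   # safe: one copy per example
    aliased = list(CachedDataSet(ds))   # every entry is the same recycled object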
1217 class ApplyFunctionDataSet(DataSet): | 1190 class ApplyFunctionDataSet(DataSet): |
1218 """ | 1191 """ |
1219 A L{DataSet} that contains as fields the results of applying a | 1192 A L{DataSet} that contains as fields the results of applying a |
1220 given function example-wise or minibatch-wise to all the fields of | 1193 given function example-wise or minibatch-wise to all the fields of |
1221 an input dataset. The output of the function should be an iterable | 1194 an input dataset. The output of the function should be an iterable |
1222 (e.g. a list or a LookupList) over the resulting values. | 1195 (e.g. a list or a LookupList) over the resulting values. |
1223 | 1196 |
1224 The function takes as input the fields of the dataset, not the examples. | 1197 The function takes as input the fields of the dataset, not the examples. |
1225 | 1198 |
1226 In minibatch mode, the function is expected to work on minibatches | 1199 In minibatch mode, the function is expected to work on minibatches |
1227 (it takes a minibatch as input and returns a minibatch as output). More | 1200 (it takes a minibatch as input and returns a minibatch as output). More |
1228 precisely, it means that each element of the input or output list | 1201 precisely, it means that each element of the input or output list |
1229 should be iterable and indexable over the individual example values | 1202 should be iterable and indexable over the individual example values |
1230 (typically these elements will be numpy arrays). All of the elements | 1203 (typically these elements will be numpy arrays). All of the elements |
1231 in the input and output lists should have the same length, which is | 1204 in the input and output lists should have the same length, which is |
1232 the length of the minibatch. | 1205 the length of the minibatch. |
1233 | 1206 |
1234 The function is applied each time an example or a minibatch is accessed. | 1207 The function is applied each time an example or a minibatch is accessed. |
1235 To avoid re-doing computation, wrap this dataset inside a CachedDataSet. | 1208 To avoid re-doing computation, wrap this dataset inside a CachedDataSet. |
1236 | 1209 |
1237 If the values_{h,v}stack functions are not provided, then | 1210 If the values_{h,v}stack functions are not provided, then |
1238 the input_dataset.values{H,V}Stack functions are used by default. | 1211 the input_dataset.values{H,V}Stack functions are used by default. |
1239 """ | 1212 """ |
1240 def __init__(self,input_dataset,function,output_names,minibatch_mode=True, | 1213 def __init__(self,input_dataset,function,output_names,minibatch_mode=True, |
1241 values_hstack=None,values_vstack=None, | 1214 values_hstack=None,values_vstack=None, |
1242 description=None,fieldtypes=None): | 1215 description=None,fieldtypes=None): |
1243 """ | 1216 """ |
1244 Constructor takes an input dataset that has as many fields as the function | 1217 Constructor takes an input dataset that has as many fields as the function |
1245 expects as inputs. The resulting dataset has as many fields as the function | 1218 expects as inputs. The resulting dataset has as many fields as the function |
1246 produces as outputs, which should correspond to the number of output names | 1219 produces as outputs, which should correspond to the number of output names |
1247 (provided in a list). | 1220 (provided in a list). |
1248 | 1221 |
1249 Note that the expected semantics of the function differ in minibatch mode | 1222 Note that the expected semantics of the function differ in minibatch mode |
1250 (it takes minibatches of inputs and produces minibatches of outputs, as | 1223 (it takes minibatches of inputs and produces minibatches of outputs, as |
1251 documented in the class comment). | 1224 documented in the class comment). |
1252 | 1225 |
1253 TBM: are fieldtypes the old field types (from input_dataset) or the new ones | 1226 TBM: are fieldtypes the old field types (from input_dataset) or the new ones |
1254 (for the new dataset created)? | 1227 (for the new dataset created)? |
1255 """ | 1228 """ |
1256 self.input_dataset=input_dataset | 1229 self.input_dataset=input_dataset |
1257 self.function=function | 1230 self.function=function |
1258 self.output_names=output_names | 1231 self.output_names=output_names |
1259 self.minibatch_mode=minibatch_mode | 1232 self.minibatch_mode=minibatch_mode |
1260 DataSet.__init__(self,description,fieldtypes) | 1233 DataSet.__init__(self,description,fieldtypes) |
1261 self.valuesHStack = values_hstack if values_hstack else input_dataset.valuesHStack | 1234 self.valuesHStack = values_hstack if values_hstack else input_dataset.valuesHStack |
1262 self.valuesVStack = values_vstack if values_vstack else input_dataset.valuesVStack | 1235 self.valuesVStack = values_vstack if values_vstack else input_dataset.valuesVStack |
1263 | 1236 |
1264 def __len__(self): | 1237 def __len__(self): |
1265 return len(self.input_dataset) | 1238 return len(self.input_dataset) |
1266 | 1239 |
1267 def fieldNames(self): | 1240 def fieldNames(self): |
1268 return self.output_names | 1241 return self.output_names |
1269 | 1242 |
1270 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): | 1243 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): |
1271 class ApplyFunctionIterator(object): | 1244 class ApplyFunctionIterator(object): |
1272 def __init__(self,output_dataset): | 1245 def __init__(self,output_dataset): |
1273 self.input_dataset=output_dataset.input_dataset | 1246 self.input_dataset=output_dataset.input_dataset |
1274 self.output_dataset=output_dataset | 1247 self.output_dataset=output_dataset |
1275 self.input_iterator=self.input_dataset.minibatches(minibatch_size=minibatch_size, | 1248 self.input_iterator=self.input_dataset.minibatches(minibatch_size=minibatch_size, |
1276 n_batches=n_batches,offset=offset).__iter__() | 1249 n_batches=n_batches,offset=offset).__iter__() |
1277 | 1250 |
1278 def __iter__(self): return self | 1251 def __iter__(self): return self |
1279 | 1252 |
1280 def next(self): | 1253 def next(self): |
1281 function_inputs = self.input_iterator.next() | 1254 function_inputs = self.input_iterator.next() |
1282 all_output_names = self.output_dataset.output_names | 1255 all_output_names = self.output_dataset.output_names |
1283 if self.output_dataset.minibatch_mode: | 1256 if self.output_dataset.minibatch_mode: |
1284 function_outputs = self.output_dataset.function(*function_inputs) | 1257 function_outputs = self.output_dataset.function(*function_inputs) |
1285 else: | 1258 else: |
1286 input_examples = zip(*function_inputs) | 1259 input_examples = zip(*function_inputs) |
1287 output_examples = [self.output_dataset.function(*input_example) | 1260 output_examples = [self.output_dataset.function(*input_example) |
1288 for input_example in input_examples] | 1261 for input_example in input_examples] |
1289 function_outputs = [self.output_dataset.valuesVStack(name,values) | 1262 function_outputs = [self.output_dataset.valuesVStack(name,values) |
1290 for name,values in zip(all_output_names, | 1263 for name,values in zip(all_output_names, |
1291 zip(*output_examples))] | 1264 zip(*output_examples))] |
1292 all_outputs = Example(all_output_names,function_outputs) | 1265 all_outputs = Example(all_output_names,function_outputs) |
1293 if fieldnames==all_output_names: | 1266 if fieldnames==all_output_names: |
1294 return all_outputs | 1267 return all_outputs |
1295 return Example(fieldnames,[all_outputs[name] for name in fieldnames]) | 1268 return Example(fieldnames,[all_outputs[name] for name in fieldnames]) |
1296 | 1269 |
1297 | 1270 |
1298 return ApplyFunctionIterator(self) | 1271 return ApplyFunctionIterator(self) |
1299 | 1272 |
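In example-wise mode (minibatch_mode=False), the iterator above unzips the input minibatch into examples, applies the function per example, and restacks the outputs per field with valuesVStack; the two transpositions in miniature:

    function_inputs = [[1, 2], ['a', 'b']]         # two fields, two examples
    input_examples = zip(*function_inputs)         # [(1, 'a'), (2, 'b')]
    output_examples = [(x * 10, y) for x, y in input_examples]
    zip(*output_examples)                          # [(10, 20), ('a', 'b')] -- per-field again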
1300 def __iter__(self): # only implemented for increased efficiency | 1273 def __iter__(self): # only implemented for increased efficiency |
1301 class ApplyFunctionSingleExampleIterator(object): | 1274 class ApplyFunctionSingleExampleIterator(object): |
1302 def __init__(self,output_dataset): | 1275 def __init__(self,output_dataset): |
1303 self.current=0 | 1276 self.current=0 |
1304 self.output_dataset=output_dataset | 1277 self.output_dataset=output_dataset |
1305 self.input_iterator=output_dataset.input_dataset.__iter__() | 1278 self.input_iterator=output_dataset.input_dataset.__iter__() |
1306 def __iter__(self): return self | 1279 def __iter__(self): return self |
1307 def next(self): | 1280 def next(self): |
1308 if self.output_dataset.minibatch_mode: | 1281 if self.output_dataset.minibatch_mode: |
1309 function_inputs = [[input] for input in self.input_iterator.next()] | 1282 function_inputs = [[input] for input in self.input_iterator.next()] |
1310 outputs = self.output_dataset.function(*function_inputs) | 1283 outputs = self.output_dataset.function(*function_inputs) |
1311 assert all([hasattr(output,'__iter__') for output in outputs]) | 1284 assert all([hasattr(output,'__iter__') for output in outputs]) |
1312 function_outputs = [output[0] for output in outputs] | 1285 function_outputs = [output[0] for output in outputs] |
1313 else: | 1286 else: |
1314 function_inputs = self.input_iterator.next() | 1287 function_inputs = self.input_iterator.next() |
1315 function_outputs = self.output_dataset.function(*function_inputs) | 1288 function_outputs = self.output_dataset.function(*function_inputs) |
1316 return Example(self.output_dataset.output_names,function_outputs) | 1289 return Example(self.output_dataset.output_names,function_outputs) |
1317 return ApplyFunctionSingleExampleIterator(self) | 1290 return ApplyFunctionSingleExampleIterator(self) |
1318 | 1291 |
1319 | 1292 |
1320 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): | 1293 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): |
1321 """ | 1294 """ |
1322 Wraps an arbitrary L{DataSet} into one for supervised learning tasks | 1295 Wraps an arbitrary L{DataSet} into one for supervised learning tasks |
1323 by forcing the user to define a set of fields as the 'input' field | 1296 by forcing the user to define a set of fields as the 'input' field |