Mercurial > pylearn
comparison dataset.py @ 134:3f4e5c9bdc5e
Fixes to ApplyFunctionDataSet and other things to make learner and mlp work
author | Yoshua Bengio <bengioy@iro.umontreal.ca> |
---|---|
date | Fri, 09 May 2008 17:38:57 -0400 |
parents | f6505ec32dc3 |
children | 0d8e721cc63c ad144fa72bf5 |
comparison
equal
deleted
inserted
replaced
133:b4657441dd65 | 134:3f4e5c9bdc5e |
---|---|
14 | 14 |
15 def attributeNames(self): | 15 def attributeNames(self): |
16 raise AbstractFunction() | 16 raise AbstractFunction() |
17 | 17 |
18 def setAttributes(self,attribute_names,attribute_values,make_copies=False): | 18 def setAttributes(self,attribute_names,attribute_values,make_copies=False): |
19 """ | |
20 Allow the attribute_values to not be a list (but a single value) if the attribute_names is of length 1. | |
21 """ | |
22 if len(attribute_names)==1 and not (isinstance(attribute_values,list) or isinstance(attribute_values,tuple) ): | |
23 attribute_values = [attribute_values] | |
19 if make_copies: | 24 if make_copies: |
20 for name,value in zip(attribute_names,attribute_values): | 25 for name,value in zip(attribute_names,attribute_values): |
21 self.__setattr__(name,copy.deepcopy(value)) | 26 self.__setattr__(name,copy.deepcopy(value)) |
22 else: | 27 else: |
23 for name,value in zip(attribute_names,attribute_values): | 28 for name,value in zip(attribute_names,attribute_values): |
1111 """ | 1116 """ |
1112 self.input_dataset=input_dataset | 1117 self.input_dataset=input_dataset |
1113 self.function=function | 1118 self.function=function |
1114 self.output_names=output_names | 1119 self.output_names=output_names |
1115 self.minibatch_mode=minibatch_mode | 1120 self.minibatch_mode=minibatch_mode |
1116 DataSet.__init__(description,fieldtypes) | 1121 DataSet.__init__(self,description,fieldtypes) |
1117 self.valuesHStack = values_hstack if values_hstack else input_dataset.valuesHStack | 1122 self.valuesHStack = values_hstack if values_hstack else input_dataset.valuesHStack |
1118 self.valuesVStack = values_vstack if values_vstack else input_dataset.valuesVStack | 1123 self.valuesVStack = values_vstack if values_vstack else input_dataset.valuesVStack |
1119 | 1124 |
1120 def __len__(self): | 1125 def __len__(self): |
1121 return len(self.input_dataset) | 1126 return len(self.input_dataset) |
1122 | 1127 |
1123 def fieldnames(self): | 1128 def fieldNames(self): |
1124 return self.output_names | 1129 return self.output_names |
1125 | 1130 |
1126 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): | 1131 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): |
1127 class ApplyFunctionIterator(object): | 1132 class ApplyFunctionIterator(object): |
1128 def __init__(self,output_dataset): | 1133 def __init__(self,output_dataset): |
1129 self.input_dataset=output_dataset.input_dataset | 1134 self.input_dataset=output_dataset.input_dataset |
1130 self.output_dataset=output_dataset | 1135 self.output_dataset=output_dataset |
1131 self.input_iterator=input_dataset.minibatches(minibatch_size=minibatch_size, | 1136 self.input_iterator=self.input_dataset.minibatches(minibatch_size=minibatch_size, |
1132 n_batches=n_batches,offset=offset).__iter__() | 1137 n_batches=n_batches,offset=offset).__iter__() |
1133 | 1138 |
1134 def __iter__(self): return self | 1139 def __iter__(self): return self |
1135 | 1140 |
1136 def next(self): | 1141 def next(self): |
1137 function_inputs = self.input_iterator.next() | 1142 function_inputs = self.input_iterator.next() |
1138 all_output_names = self.output_dataset.output_names | 1143 all_output_names = self.output_dataset.output_names |
1139 if self.output_dataset.minibatch_mode: | 1144 if self.output_dataset.minibatch_mode: |
1140 function_outputs = self.output_dataset.function(function_inputs) | 1145 function_outputs = self.output_dataset.function(*function_inputs) |
1141 else: | 1146 else: |
1142 input_examples = zip(*function_inputs) | 1147 input_examples = zip(*function_inputs) |
1143 output_examples = [self.output_dataset.function(input_example) | 1148 output_examples = [self.output_dataset.function(input_example) |
1144 for input_example in input_examples] | 1149 for input_example in input_examples] |
1145 function_outputs = [self.output_dataset.valuesVStack(name,values) | 1150 function_outputs = [self.output_dataset.valuesVStack(name,values) |
1148 all_outputs = Example(all_output_names,function_outputs) | 1153 all_outputs = Example(all_output_names,function_outputs) |
1149 if fieldnames==all_output_names: | 1154 if fieldnames==all_output_names: |
1150 return all_outputs | 1155 return all_outputs |
1151 return Example(fieldnames,[all_outputs[name] for name in fieldnames]) | 1156 return Example(fieldnames,[all_outputs[name] for name in fieldnames]) |
1152 | 1157 |
1153 return ApplyFunctionIterator(self.input_dataset,self) | 1158 return ApplyFunctionIterator(self) |
1154 | 1159 |
1155 def __iter__(self): # only implemented for increased efficiency | 1160 def __iter__(self): # only implemented for increased efficiency |
1156 class ApplyFunctionSingleExampleIterator(object): | 1161 class ApplyFunctionSingleExampleIterator(object): |
1157 def __init__(self,output_dataset): | 1162 def __init__(self,output_dataset): |
1158 self.current=0 | 1163 self.current=0 |