comparison dataset.py @ 134:3f4e5c9bdc5e

Fixes to ApplyFunctionDataSet and other things to make learner and mlp work
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Fri, 09 May 2008 17:38:57 -0400
parents f6505ec32dc3
children 0d8e721cc63c ad144fa72bf5
comparison
equal deleted inserted replaced
133:b4657441dd65 134:3f4e5c9bdc5e
14 14
15 def attributeNames(self): 15 def attributeNames(self):
16 raise AbstractFunction() 16 raise AbstractFunction()
17 17
18 def setAttributes(self,attribute_names,attribute_values,make_copies=False): 18 def setAttributes(self,attribute_names,attribute_values,make_copies=False):
19 """
20 Allow the attribute_values to not be a list (but a single value) if the attribute_names is of length 1.
21 """
22 if len(attribute_names)==1 and not (isinstance(attribute_values,list) or isinstance(attribute_values,tuple) ):
23 attribute_values = [attribute_values]
19 if make_copies: 24 if make_copies:
20 for name,value in zip(attribute_names,attribute_values): 25 for name,value in zip(attribute_names,attribute_values):
21 self.__setattr__(name,copy.deepcopy(value)) 26 self.__setattr__(name,copy.deepcopy(value))
22 else: 27 else:
23 for name,value in zip(attribute_names,attribute_values): 28 for name,value in zip(attribute_names,attribute_values):
1111 """ 1116 """
1112 self.input_dataset=input_dataset 1117 self.input_dataset=input_dataset
1113 self.function=function 1118 self.function=function
1114 self.output_names=output_names 1119 self.output_names=output_names
1115 self.minibatch_mode=minibatch_mode 1120 self.minibatch_mode=minibatch_mode
1116 DataSet.__init__(description,fieldtypes) 1121 DataSet.__init__(self,description,fieldtypes)
1117 self.valuesHStack = values_hstack if values_hstack else input_dataset.valuesHStack 1122 self.valuesHStack = values_hstack if values_hstack else input_dataset.valuesHStack
1118 self.valuesVStack = values_vstack if values_vstack else input_dataset.valuesVStack 1123 self.valuesVStack = values_vstack if values_vstack else input_dataset.valuesVStack
1119 1124
1120 def __len__(self): 1125 def __len__(self):
1121 return len(self.input_dataset) 1126 return len(self.input_dataset)
1122 1127
1123 def fieldnames(self): 1128 def fieldNames(self):
1124 return self.output_names 1129 return self.output_names
1125 1130
1126 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): 1131 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
1127 class ApplyFunctionIterator(object): 1132 class ApplyFunctionIterator(object):
1128 def __init__(self,output_dataset): 1133 def __init__(self,output_dataset):
1129 self.input_dataset=output_dataset.input_dataset 1134 self.input_dataset=output_dataset.input_dataset
1130 self.output_dataset=output_dataset 1135 self.output_dataset=output_dataset
1131 self.input_iterator=input_dataset.minibatches(minibatch_size=minibatch_size, 1136 self.input_iterator=self.input_dataset.minibatches(minibatch_size=minibatch_size,
1132 n_batches=n_batches,offset=offset).__iter__() 1137 n_batches=n_batches,offset=offset).__iter__()
1133 1138
1134 def __iter__(self): return self 1139 def __iter__(self): return self
1135 1140
1136 def next(self): 1141 def next(self):
1137 function_inputs = self.input_iterator.next() 1142 function_inputs = self.input_iterator.next()
1138 all_output_names = self.output_dataset.output_names 1143 all_output_names = self.output_dataset.output_names
1139 if self.output_dataset.minibatch_mode: 1144 if self.output_dataset.minibatch_mode:
1140 function_outputs = self.output_dataset.function(function_inputs) 1145 function_outputs = self.output_dataset.function(*function_inputs)
1141 else: 1146 else:
1142 input_examples = zip(*function_inputs) 1147 input_examples = zip(*function_inputs)
1143 output_examples = [self.output_dataset.function(input_example) 1148 output_examples = [self.output_dataset.function(input_example)
1144 for input_example in input_examples] 1149 for input_example in input_examples]
1145 function_outputs = [self.output_dataset.valuesVStack(name,values) 1150 function_outputs = [self.output_dataset.valuesVStack(name,values)
1148 all_outputs = Example(all_output_names,function_outputs) 1153 all_outputs = Example(all_output_names,function_outputs)
1149 if fieldnames==all_output_names: 1154 if fieldnames==all_output_names:
1150 return all_outputs 1155 return all_outputs
1151 return Example(fieldnames,[all_outputs[name] for name in fieldnames]) 1156 return Example(fieldnames,[all_outputs[name] for name in fieldnames])
1152 1157
1153 return ApplyFunctionIterator(self.input_dataset,self) 1158 return ApplyFunctionIterator(self)
1154 1159
1155 def __iter__(self): # only implemented for increased efficiency 1160 def __iter__(self): # only implemented for increased efficiency
1156 class ApplyFunctionSingleExampleIterator(object): 1161 class ApplyFunctionSingleExampleIterator(object):
1157 def __init__(self,output_dataset): 1162 def __init__(self,output_dataset):
1158 self.current=0 1163 self.current=0