Mercurial > pylearn
diff dataset.py @ 134:3f4e5c9bdc5e
Fixes to ApplyFunctionDataSet and other things to make learner and mlp work
author | Yoshua Bengio <bengioy@iro.umontreal.ca> |
---|---|
date | Fri, 09 May 2008 17:38:57 -0400 |
parents | f6505ec32dc3 |
children | 0d8e721cc63c ad144fa72bf5 |
line wrap: on
line diff
--- a/dataset.py Fri May 09 13:38:54 2008 -0400 +++ b/dataset.py Fri May 09 17:38:57 2008 -0400 @@ -16,6 +16,11 @@ raise AbstractFunction() def setAttributes(self,attribute_names,attribute_values,make_copies=False): + """ + Allow the attribute_values to not be a list (but a single value) if the attribute_names is of length 1. + """ + if len(attribute_names)==1 and not (isinstance(attribute_values,list) or isinstance(attribute_values,tuple) ): + attribute_values = [attribute_values] if make_copies: for name,value in zip(attribute_names,attribute_values): self.__setattr__(name,copy.deepcopy(value)) @@ -1113,14 +1118,14 @@ self.function=function self.output_names=output_names self.minibatch_mode=minibatch_mode - DataSet.__init__(description,fieldtypes) + DataSet.__init__(self,description,fieldtypes) self.valuesHStack = values_hstack if values_hstack else input_dataset.valuesHStack self.valuesVStack = values_vstack if values_vstack else input_dataset.valuesVStack def __len__(self): return len(self.input_dataset) - def fieldnames(self): + def fieldNames(self): return self.output_names def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): @@ -1128,8 +1133,8 @@ def __init__(self,output_dataset): self.input_dataset=output_dataset.input_dataset self.output_dataset=output_dataset - self.input_iterator=input_dataset.minibatches(minibatch_size=minibatch_size, - n_batches=n_batches,offset=offset).__iter__() + self.input_iterator=self.input_dataset.minibatches(minibatch_size=minibatch_size, + n_batches=n_batches,offset=offset).__iter__() def __iter__(self): return self @@ -1137,7 +1142,7 @@ function_inputs = self.input_iterator.next() all_output_names = self.output_dataset.output_names if self.output_dataset.minibatch_mode: - function_outputs = self.output_dataset.function(function_inputs) + function_outputs = self.output_dataset.function(*function_inputs) else: input_examples = zip(*function_inputs) output_examples = [self.output_dataset.function(input_example) @@ -1150,7 +1155,7 @@ return all_outputs return Example(fieldnames,[all_outputs[name] for name in fieldnames]) - return ApplyFunctionIterator(self.input_dataset,self) + return ApplyFunctionIterator(self) def __iter__(self): # only implemented for increased efficiency class ApplyFunctionSingleExampleIterator(object):