diff dataset.py @ 134:3f4e5c9bdc5e

Fixes to ApplyFunctionDataSet and other things to make learner and mlp work
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Fri, 09 May 2008 17:38:57 -0400
parents f6505ec32dc3
children 0d8e721cc63c ad144fa72bf5
--- a/dataset.py	Fri May 09 13:38:54 2008 -0400
+++ b/dataset.py	Fri May 09 17:38:57 2008 -0400
@@ -16,6 +16,11 @@
         raise AbstractFunction()
 
     def setAttributes(self,attribute_names,attribute_values,make_copies=False):
+        """
+        Allow the attribute_values to not be a list (but a single value) if the attribute_names is of length 1.
+        """
+        if len(attribute_names)==1 and not (isinstance(attribute_values,list) or isinstance(attribute_values,tuple) ):
+            attribute_values = [attribute_values]
         if make_copies:
             for name,value in zip(attribute_names,attribute_values):
                 self.__setattr__(name,copy.deepcopy(value))
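
Note: the guard added at the top of setAttributes lets callers pass a bare value instead of a one-element list when only a single attribute name is given; the value is wrapped in a list before names and values are zipped. A minimal usage sketch (the dataset ds and the attribute name 'mean' are hypothetical):

    ds.setAttributes(['mean'], [0.5])   # list form, as before
    ds.setAttributes(['mean'], 0.5)     # bare value, now wrapped into [0.5] by the new check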
@@ -1113,14 +1118,14 @@
       self.function=function
       self.output_names=output_names
       self.minibatch_mode=minibatch_mode
-      DataSet.__init__(description,fieldtypes)
+      DataSet.__init__(self,description,fieldtypes)
       self.valuesHStack = values_hstack if values_hstack else input_dataset.valuesHStack
       self.valuesVStack = values_vstack if values_vstack else input_dataset.valuesVStack
 
   def __len__(self):
       return len(self.input_dataset)
 
-  def fieldnames(self):
+  def fieldNames(self):
       return self.output_names
 
   def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
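
Note: this hunk carries two fixes. DataSet.__init__ is called here as an unbound method, so the instance must be passed explicitly or Python 2 raises a TypeError; and the accessor is renamed to fieldNames, presumably the spelling the DataSet interface expects elsewhere in dataset.py. A sketch of the corrected base-class call (the cooperative form in the comment is an alternative that assumes DataSet is a new-style class):

    DataSet.__init__(self, description, fieldtypes)
    # equivalently: super(ApplyFunctionDataSet, self).__init__(description, fieldtypes)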
@@ -1128,8 +1133,8 @@
           def __init__(self,output_dataset):
               self.input_dataset=output_dataset.input_dataset
               self.output_dataset=output_dataset
-              self.input_iterator=input_dataset.minibatches(minibatch_size=minibatch_size,
-                                                            n_batches=n_batches,offset=offset).__iter__()
+              self.input_iterator=self.input_dataset.minibatches(minibatch_size=minibatch_size,
+                                                                 n_batches=n_batches,offset=offset).__iter__()
 
           def __iter__(self): return self
 
@@ -1137,7 +1142,7 @@
               function_inputs = self.input_iterator.next()
               all_output_names = self.output_dataset.output_names
               if self.output_dataset.minibatch_mode:
-                  function_outputs = self.output_dataset.function(function_inputs)
+                  function_outputs = self.output_dataset.function(*function_inputs)
               else:
                   input_examples = zip(*function_inputs)
                   output_examples = [self.output_dataset.function(input_example)
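
Note: with minibatch_mode on, self.input_iterator.next() yields one column per requested field (the else branch transposes those columns into examples with zip), so the star-unpacking calls the user function with one positional argument per field rather than with the whole minibatch object. A hedged sketch of a function written for that convention, with hypothetical field names:

    # two input fields -> two positional arguments, each a column of values
    def add_fields(x_column, y_column):
        return [x + y for x, y in zip(x_column, y_column)]
    # with minibatch_mode=True the dataset calls add_fields(*function_inputs)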
@@ -1150,7 +1155,7 @@
                   return all_outputs
               return Example(fieldnames,[all_outputs[name] for name in fieldnames])
 
-      return ApplyFunctionIterator(self.input_dataset,self)
+      return ApplyFunctionIterator(self)
 
   def __iter__(self): # only implemented for increased efficiency
       class ApplyFunctionSingleExampleIterator(object):
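
Note: ApplyFunctionIterator.__init__ (previous hunk) takes only the wrapping dataset and derives input_dataset from it, so the old two-argument construction no longer matched the signature; passing just the ApplyFunctionDataSet is all that is needed:

    # inside minibatches_nowrap: the iterator reaches everything it needs
    # through output_dataset (input_dataset, function, output_names,
    # minibatch_mode) plus the closed-over minibatch_size/n_batches/offset
    return ApplyFunctionIterator(self)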