changeset 135:0d8e721cc63c

Fixed bugs in dataset to make test_mlp.py work
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Mon, 12 May 2008 14:30:21 -0400
parents 3f4e5c9bdc5e
children ceae4de18981
files dataset.py learner.py test_mlp.py
diffstat 3 files changed, 19 insertions(+), 12 deletions(-)
--- a/dataset.py	Fri May 09 17:38:57 2008 -0400
+++ b/dataset.py	Mon May 12 14:30:21 2008 -0400
@@ -429,8 +429,8 @@
         rows=None
         # or a slice
         if type(i) is slice:
-            if not i.start: i.start=0
-            if not i.step: i.step=1
+            if not i.start: i=slice(0,i.stop,i.step)
+            if not i.step: i=slice(i.start,i.stop,1)
             if i.step is 1:
                 return self.minibatches(minibatch_size=i.stop-i.start,n_batches=1,offset=i.start).next().examples()
             rows = range(i.start,i.stop,i.step)
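
Note on the hunk above: slice attributes (start, stop, step) are read-only in Python, so the original in-place assignments could never have worked; the patch instead rebuilds the slice with defaults filled in. A minimal standalone sketch of the same normalization (the helper name and the length parameter are hypothetical, not part of dataset.py):

    def normalize_slice(i, length):
        # slice attributes are read-only, so build a fresh slice with
        # defaults filled in instead of mutating the old one
        start = i.start or 0
        step = i.step or 1
        stop = i.stop if i.stop is not None else length
        return slice(start, stop, step)

    assert normalize_slice(slice(None, 5, None), 10) == slice(0, 5, 1)
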
@@ -497,7 +497,7 @@
         dataset1 | dataset2 returns a dataset whose list of fields is the concatenation of the list of
         fields of the argument datasets. This only works if they all have the same length.
         """
-        return HStackedDataSet(self,other)
+        return HStackedDataSet([self,other])
 
     def __and__(self,other):
         """
@@ -505,7 +505,7 @@
         (and whose length is the sum of the length of the argument datasets). This only
         works if they all have the same fields.
         """
-        return VStackedDataSet(self,other)
+        return VStackedDataSet([self,other])
 
 def hstack(datasets):
     """
@@ -1068,7 +1068,7 @@
           def next(self):
               upper = self.current+minibatch_size
               cache_len = len(self.dataset.cached_examples)
-              if upper>=cache_len: # whole minibatch is not already in cache
+              if upper>cache_len: # whole minibatch is not already in cache
                   # cache everything from current length to upper
                   for example in self.dataset.source_dataset[cache_len:upper]:
                       self.dataset.cached_examples.append(example)
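
The comparison change above is an off-by-one fix: when upper == cache_len the cache already holds every example the minibatch needs, so the old >= test sent an exactly-full cache through the fetch path for nothing. A small sketch of the corrected boundary logic (a standalone function whose names only mirror the diff):

    def missing_indices(upper, cache_len):
        # fetch only when the requested bound exceeds what is cached
        if upper > cache_len:
            return range(cache_len, upper)
        return []

    assert missing_indices(10, 10) == []        # boundary case: cache suffices
    assert missing_indices(12, 10) == [10, 11]  # two examples still to cache
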
--- a/learner.py	Fri May 09 17:38:57 2008 -0400
+++ b/learner.py	Mon May 12 14:30:21 2008 -0400
@@ -11,8 +11,9 @@
     algorithms.
 
     A L{Learner} can be seen as a learning algorithm, a function that when
-    applied to training data returns a learned function, an object that
-    can be applied to other data and return some output data.
+    applied to training data returns a learned function (which is an object that
+    can be applied to other data to produce some output data).
+    
     """
     
     def __init__(self):
@@ -51,7 +52,7 @@
         return self.update(training_set,train_stats_collector)
 
     def use(self,input_dataset,output_fieldnames=None,
-            test_stats_collector=None,copy_inputs=True,
+            test_stats_collector=None,copy_inputs=False,
             put_stats_in_output_dataset=True,
             output_attributes=[]):
         """
@@ -85,11 +86,16 @@
         If a test_stats_collector is provided, then its attributes (test_stats_collector.AttributeNames())
         are also copied into the output dataset attributes.
         """
-        minibatchwise_use_function = self.minibatchwiseUseFunction(input_dataset.fieldNames(),
+        input_fieldnames = input_dataset.fieldNames()
+        if not output_fieldnames:
+            output_fieldnames = self.defaultOutputFields(input_fieldnames)
+
+        minibatchwise_use_function = self.minibatchwiseUseFunction(input_fieldnames,
                                                                    output_fieldnames,
                                                                    test_stats_collector)
         virtual_output_dataset = ApplyFunctionDataSet(input_dataset,
                                                       minibatchwise_use_function,
+                                                      output_fieldnames,
                                                       True,DataSet.numpy_vstack,
                                                       DataSet.numpy_hstack)
         # actually force the computation
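
This hunk hoists the output_fieldnames defaulting out of minibatchwiseUseFunction (the duplicate is deleted in the hunk below) and into use(), so the concrete field names are resolved once and can also be handed to ApplyFunctionDataSet. A minimal sketch of the resolve-defaults-once pattern (a hypothetical standalone helper, not the actual Learner API):

    def resolve_output_fieldnames(learner, input_dataset, output_fieldnames=None):
        # resolve the default once, at the public entry point,
        # so every downstream consumer sees concrete field names
        if not output_fieldnames:
            output_fieldnames = learner.defaultOutputFields(input_dataset.fieldNames())
        return output_fieldnames
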
@@ -212,8 +218,6 @@
         Implement minibatchwiseUseFunction by exploiting Theano compilation
         and the expression graph defined by a sub-class constructor.
         """
-        if not output_fields:
-            output_fields = self.defaultOutputFields(input_fields)
         if stats_collector:
             stats_collector_inputs = stats_collector.input2UpdateAttributes()
             for attribute in stats_collector_inputs:
--- a/test_mlp.py	Fri May 09 17:38:57 2008 -0400
+++ b/test_mlp.py	Mon May 12 14:30:21 2008 -0400
@@ -11,7 +11,10 @@
                                         {'input':slice(2),'target':2})
     fprop=nnet(training_set)
 
-    print fprop(training_set)
+    output_ds = fprop(training_set)
+
+    for fieldname in output_ds.fieldNames():
+        print fieldname+"=",output_ds[fieldname]
 
 test0()
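
The test now prints the output dataset field by field instead of relying on printing the dataset object itself. The same loop generalizes to any DataSet; a small hypothetical helper in the style of the code above:

    def dump_fields(ds):
        # print each field of a dataset on its own line
        for fieldname in ds.fieldNames():
            print fieldname+"=", ds[fieldname]

    dump_fields(output_ds)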