diff mlp_factory_approach.py @ 218:df3fae88ab46

small debugging
author Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
date Fri, 23 May 2008 12:22:54 -0400
parents 6fa8fbb0c3f6
children 3595ba2610f7
line wrap: on
line diff
--- a/mlp_factory_approach.py	Thu May 22 19:08:46 2008 -0400
+++ b/mlp_factory_approach.py	Fri May 23 12:22:54 2008 -0400
@@ -4,7 +4,7 @@
 import theano
 from theano import tensor as t
 
-import dataset, nnet_ops, stopper
+from pylearn import dataset, nnet_ops, stopper
 
 
 def _randshape(*shape): 
@@ -31,18 +31,19 @@
         """Update this model from more training data."""
         params = self.params
         #TODO: why should we have to unpack target like this?
+        # tbm : creates problem...
         for input, target in input_target:
             rval= self.update_fn(input, target[:,0], *params)
             #print rval[0]
 
-    def __call__(self, testset, fieldnames=['output_class']):
+    def __call__(self, testset, fieldnames=['output_class'],input='input',target='target'):
         """Apply this model (as a function) to new data"""
         #TODO: cache fn between calls
-        assert 'input' == testset.fieldNames()[0]
+        assert input == testset.fieldNames()[0] # why first one???
         assert len(testset.fieldNames()) <= 2
         v = self.algo.v
         outputs = [getattr(v, name) for name in fieldnames]
-        inputs = [v.input] + ([v.target] if 'target' in testset else [])
+        inputs = [v.input] + ([v.target] if target in testset else [])
         inputs.extend(v.params)
         theano_fn = _cache(self._fn_cache, (tuple(inputs), tuple(outputs)),
                 lambda: self.algo._fn(inputs, outputs))