diff mlp_factory_approach.py @ 226:3595ba2610f7

merged
author James Bergstra <bergstrj@iro.umontreal.ca>
date Fri, 23 May 2008 17:12:12 -0400
parents 8bc16220b29a df3fae88ab46
children c047238e5b3f
line wrap: on
line diff
--- a/mlp_factory_approach.py	Fri May 23 17:11:39 2008 -0400
+++ b/mlp_factory_approach.py	Fri May 23 17:12:12 2008 -0400
@@ -17,7 +17,7 @@
 import theano
 from theano import tensor as t
 
-import dataset, nnet_ops, stopper
+from pylearn import dataset, nnet_ops, stopper
 
 
 def _randshape(*shape): 
@@ -44,18 +44,19 @@
         """Update this model from more training data."""
         params = self.params
         #TODO: why should we have to unpack target like this?
+        # tbm : creates problem...
         for input, target in input_target:
             rval= self.update_fn(input, target[:,0], *params)
             #print rval[0]
 
-    def __call__(self, testset, fieldnames=['output_class']):
+    def __call__(self, testset, fieldnames=['output_class'],input='input',target='target'):
         """Apply this model (as a function) to new data"""
         #TODO: cache fn between calls
-        assert 'input' == testset.fieldNames()[0]
+        assert input == testset.fieldNames()[0] # why first one???
         assert len(testset.fieldNames()) <= 2
         v = self.algo.v
         outputs = [getattr(v, name) for name in fieldnames]
-        inputs = [v.input] + ([v.target] if 'target' in testset else [])
+        inputs = [v.input] + ([v.target] if target in testset else [])
         inputs.extend(v.params)
         theano_fn = _cache(self._fn_cache, (tuple(inputs), tuple(outputs)),
                 lambda: self.algo._fn(inputs, outputs))