comparison mlp_factory_approach.py @ 218:df3fae88ab46

small debugging
author Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
date Fri, 23 May 2008 12:22:54 -0400
parents 6fa8fbb0c3f6
children 3595ba2610f7
comparison
equal deleted inserted replaced
217:44dd9b6448c5 218:df3fae88ab46
2 import numpy 2 import numpy
3 3
4 import theano 4 import theano
5 from theano import tensor as t 5 from theano import tensor as t
6 6
7 import dataset, nnet_ops, stopper 7 from pylearn import dataset, nnet_ops, stopper
8 8
9 9
10 def _randshape(*shape): 10 def _randshape(*shape):
11 return (numpy.random.rand(*shape) -0.5) * 0.001 11 return (numpy.random.rand(*shape) -0.5) * 0.001
12 12
29 29
def update(self, input_target):
    """Update this model from more training data.

    Each item of *input_target* is unpacked as an ``(input, target)`` pair.
    ``self.update_fn`` is then called with three things:

    - the input;
    - the first column of the target (``target[:, 0]``);
    - the parameters held in ``self.params``, read once before the loop.
    """
    params = self.params
    # TODO: why should we have to unpack target like this?
    # tbm : creates problem...
    # NOTE(review): target[:, 0] assumes target is at least 2-D (e.g. a single
    # column of labels per example) -- confirm against the producing dataset.
    for x, target in input_target:
        # The return value of update_fn is not needed here; only its effect is.
        self.update_fn(x, target[:, 0], *params)
def __call__(self, testset, fieldnames=None, input='input', target='target'):
    """Apply this model (as a function) to new data.

    :param testset: dataset to evaluate on. Its first field must be named
        *input*, and it may have at most two fields in total.
    :param fieldnames: names of attributes of ``self.algo.v`` to compute for
        each example. Defaults to ``['output_class']``.
    :param input: name of the input field of *testset*. (The name shadows the
        builtin but is part of the public keyword interface.)
    :param target: name of the target field. When *testset* contains it,
        ``v.target`` is also passed to the compiled function.
    :returns: ``dataset.ApplyFunctionDataSet(testset, <fn>, fieldnames)``.
    :raises AssertionError: if *testset* does not satisfy the field layout
        described above.
    """
    if fieldnames is None:
        # Build a fresh list on every call instead of sharing one mutable
        # default list object across all calls.
        fieldnames = ['output_class']
    field_names = testset.fieldNames()
    # NOTE(review): the input field presumably must be first because the
    # dataset's field values are passed positionally to the compiled function,
    # in the same order as `inputs` below ([v.input, v.target, *params]).
    # Confirm against dataset.ApplyFunctionDataSet.
    assert input == field_names[0], \
        'first field of testset must be %r, got %r' % (input, field_names[0])
    assert len(field_names) <= 2, \
        'testset must have at most 2 fields, got %r' % (field_names,)
    v = self.algo.v
    outputs = [getattr(v, name) for name in fieldnames]
    inputs = [v.input] + ([v.target] if target in testset else [])
    inputs.extend(v.params)
    # The compiled function is looked up in self._fn_cache by the
    # (inputs, outputs) variable tuples. The thunk presumably only runs on a
    # cache miss -- see _cache.
    theano_fn = _cache(self._fn_cache, (tuple(inputs), tuple(outputs)),
                       lambda: self.algo._fn(inputs, outputs))

    def apply_fn(*args):
        # Append the model's current parameter values after the field values.
        return theano_fn(*(list(args) + self.params))

    return dataset.ApplyFunctionDataSet(testset, apply_fn, fieldnames)