comparison mlp_factory_approach.py @ 226:3595ba2610f7

merged
author James Bergstra <bergstrj@iro.umontreal.ca>
date Fri, 23 May 2008 17:12:12 -0400
parents 8bc16220b29a df3fae88ab46
children c047238e5b3f
comparison
equal deleted inserted replaced
225:8bc16220b29a 226:3595ba2610f7
15 import numpy 15 import numpy
16 16
17 import theano 17 import theano
18 from theano import tensor as t 18 from theano import tensor as t
19 19
20 import dataset, nnet_ops, stopper 20 from pylearn import dataset, nnet_ops, stopper
21 21
22 22
23 def _randshape(*shape): 23 def _randshape(*shape):
24 return (numpy.random.rand(*shape) -0.5) * 0.001 24 return (numpy.random.rand(*shape) -0.5) * 0.001
25 25
42 42
43 def update(self, input_target): 43 def update(self, input_target):
44 """Update this model from more training data.""" 44 """Update this model from more training data."""
45 params = self.params 45 params = self.params
46 #TODO: why should we have to unpack target like this? 46 #TODO: why should we have to unpack target like this?
47 # tbm : creates problem...
47 for input, target in input_target: 48 for input, target in input_target:
48 rval= self.update_fn(input, target[:,0], *params) 49 rval= self.update_fn(input, target[:,0], *params)
49 #print rval[0] 50 #print rval[0]
50 51
51 def __call__(self, testset, fieldnames=['output_class']): 52 def __call__(self, testset, fieldnames=['output_class'],input='input',target='target'):
52 """Apply this model (as a function) to new data""" 53 """Apply this model (as a function) to new data"""
53 #TODO: cache fn between calls 54 #TODO: cache fn between calls
54 assert 'input' == testset.fieldNames()[0] 55 assert input == testset.fieldNames()[0] # why first one???
55 assert len(testset.fieldNames()) <= 2 56 assert len(testset.fieldNames()) <= 2
56 v = self.algo.v 57 v = self.algo.v
57 outputs = [getattr(v, name) for name in fieldnames] 58 outputs = [getattr(v, name) for name in fieldnames]
58 inputs = [v.input] + ([v.target] if 'target' in testset else []) 59 inputs = [v.input] + ([v.target] if target in testset else [])
59 inputs.extend(v.params) 60 inputs.extend(v.params)
60 theano_fn = _cache(self._fn_cache, (tuple(inputs), tuple(outputs)), 61 theano_fn = _cache(self._fn_cache, (tuple(inputs), tuple(outputs)),
61 lambda: self.algo._fn(inputs, outputs)) 62 lambda: self.algo._fn(inputs, outputs))
62 lambda_fn = lambda *args: theano_fn(*(list(args) + self.params)) 63 lambda_fn = lambda *args: theano_fn(*(list(args) + self.params))
63 return dataset.ApplyFunctionDataSet(testset, lambda_fn, fieldnames) 64 return dataset.ApplyFunctionDataSet(testset, lambda_fn, fieldnames)