Mercurial > pylearn
comparison mlp_factory_approach.py @ 226:3595ba2610f7
merged
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Fri, 23 May 2008 17:12:12 -0400 |
parents | 8bc16220b29a df3fae88ab46 |
children | c047238e5b3f |
comparison
equal
deleted
inserted
replaced
225:8bc16220b29a | 226:3595ba2610f7 |
---|---|
15 import numpy | 15 import numpy |
16 | 16 |
17 import theano | 17 import theano |
18 from theano import tensor as t | 18 from theano import tensor as t |
19 | 19 |
20 import dataset, nnet_ops, stopper | 20 from pylearn import dataset, nnet_ops, stopper |
21 | 21 |
22 | 22 |
23 def _randshape(*shape): | 23 def _randshape(*shape): |
24 return (numpy.random.rand(*shape) -0.5) * 0.001 | 24 return (numpy.random.rand(*shape) -0.5) * 0.001 |
25 | 25 |
42 | 42 |
43 def update(self, input_target): | 43 def update(self, input_target): |
44 """Update this model from more training data.""" | 44 """Update this model from more training data.""" |
45 params = self.params | 45 params = self.params |
46 #TODO: why should we have to unpack target like this? | 46 #TODO: why should we have to unpack target like this? |
47 # tbm : creates problem... | |
47 for input, target in input_target: | 48 for input, target in input_target: |
48 rval= self.update_fn(input, target[:,0], *params) | 49 rval= self.update_fn(input, target[:,0], *params) |
49 #print rval[0] | 50 #print rval[0] |
50 | 51 |
51 def __call__(self, testset, fieldnames=['output_class']): | 52 def __call__(self, testset, fieldnames=['output_class'],input='input',target='target'): |
52 """Apply this model (as a function) to new data""" | 53 """Apply this model (as a function) to new data""" |
53 #TODO: cache fn between calls | 54 #TODO: cache fn between calls |
54 assert 'input' == testset.fieldNames()[0] | 55 assert input == testset.fieldNames()[0] # why first one??? |
55 assert len(testset.fieldNames()) <= 2 | 56 assert len(testset.fieldNames()) <= 2 |
56 v = self.algo.v | 57 v = self.algo.v |
57 outputs = [getattr(v, name) for name in fieldnames] | 58 outputs = [getattr(v, name) for name in fieldnames] |
58 inputs = [v.input] + ([v.target] if 'target' in testset else []) | 59 inputs = [v.input] + ([v.target] if target in testset else []) |
59 inputs.extend(v.params) | 60 inputs.extend(v.params) |
60 theano_fn = _cache(self._fn_cache, (tuple(inputs), tuple(outputs)), | 61 theano_fn = _cache(self._fn_cache, (tuple(inputs), tuple(outputs)), |
61 lambda: self.algo._fn(inputs, outputs)) | 62 lambda: self.algo._fn(inputs, outputs)) |
62 lambda_fn = lambda *args: theano_fn(*(list(args) + self.params)) | 63 lambda_fn = lambda *args: theano_fn(*(list(args) + self.params)) |
63 return dataset.ApplyFunctionDataSet(testset, lambda_fn, fieldnames) | 64 return dataset.ApplyFunctionDataSet(testset, lambda_fn, fieldnames) |