Mercurial > pylearn
comparison mlp_factory_approach.py @ 218:df3fae88ab46
small debugging
author | Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca> |
---|---|
date | Fri, 23 May 2008 12:22:54 -0400 |
parents | 6fa8fbb0c3f6 |
children | 3595ba2610f7 |
comparison legend:
- equal
- deleted
- inserted
- replaced
217:44dd9b6448c5 | 218:df3fae88ab46 |
---|---|
2 import numpy | 2 import numpy |
3 | 3 |
4 import theano | 4 import theano |
5 from theano import tensor as t | 5 from theano import tensor as t |
6 | 6 |
7 import dataset, nnet_ops, stopper | 7 from pylearn import dataset, nnet_ops, stopper |
8 | 8 |
9 | 9 |
10 def _randshape(*shape): | 10 def _randshape(*shape): |
11 return (numpy.random.rand(*shape) -0.5) * 0.001 | 11 return (numpy.random.rand(*shape) -0.5) * 0.001 |
12 | 12 |
29 | 29 |
30 def update(self, input_target): | 30 def update(self, input_target): |
31 """Update this model from more training data.""" | 31 """Update this model from more training data.""" |
32 params = self.params | 32 params = self.params |
33 #TODO: why should we have to unpack target like this? | 33 #TODO: why should we have to unpack target like this? |
34 # tbm : creates problem... | |
34 for input, target in input_target: | 35 for input, target in input_target: |
35 rval= self.update_fn(input, target[:,0], *params) | 36 rval= self.update_fn(input, target[:,0], *params) |
36 #print rval[0] | 37 #print rval[0] |
37 | 38 |
38 def __call__(self, testset, fieldnames=['output_class']): | 39 def __call__(self, testset, fieldnames=['output_class'],input='input',target='target'): |
39 """Apply this model (as a function) to new data""" | 40 """Apply this model (as a function) to new data""" |
40 #TODO: cache fn between calls | 41 #TODO: cache fn between calls |
41 assert 'input' == testset.fieldNames()[0] | 42 assert input == testset.fieldNames()[0] # why first one??? |
42 assert len(testset.fieldNames()) <= 2 | 43 assert len(testset.fieldNames()) <= 2 |
43 v = self.algo.v | 44 v = self.algo.v |
44 outputs = [getattr(v, name) for name in fieldnames] | 45 outputs = [getattr(v, name) for name in fieldnames] |
45 inputs = [v.input] + ([v.target] if 'target' in testset else []) | 46 inputs = [v.input] + ([v.target] if target in testset else []) |
46 inputs.extend(v.params) | 47 inputs.extend(v.params) |
47 theano_fn = _cache(self._fn_cache, (tuple(inputs), tuple(outputs)), | 48 theano_fn = _cache(self._fn_cache, (tuple(inputs), tuple(outputs)), |
48 lambda: self.algo._fn(inputs, outputs)) | 49 lambda: self.algo._fn(inputs, outputs)) |
49 lambda_fn = lambda *args: theano_fn(*(list(args) + self.params)) | 50 lambda_fn = lambda *args: theano_fn(*(list(args) + self.params)) |
50 return dataset.ApplyFunctionDataSet(testset, lambda_fn, fieldnames) | 51 return dataset.ApplyFunctionDataSet(testset, lambda_fn, fieldnames) |