annotate algorithms/logistic_regression.py @ 499:a419edf4e06c

removed unpicklable nested classes in logistic regression
author James Bergstra <bergstrj@iro.umontreal.ca>
date Tue, 28 Oct 2008 12:57:49 -0400
parents a272f4cbf004
children 4fb6f7320518
rev   line source
470
bd937e845bbb new stuff: algorithms/logistic_regression, datasets/MNIST
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
1 import theano
bd937e845bbb new stuff: algorithms/logistic_regression, datasets/MNIST
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
2 from theano import tensor as T
495
7560817a07e8 nnet_ops => nnet
Joseph Turian <turian@gmail.com>
parents: 491
diff changeset
3 from theano.tensor import nnet
470
bd937e845bbb new stuff: algorithms/logistic_regression, datasets/MNIST
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
4 from theano.compile import module
bd937e845bbb new stuff: algorithms/logistic_regression, datasets/MNIST
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
5 from theano import printing, pprint
bd937e845bbb new stuff: algorithms/logistic_regression, datasets/MNIST
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
6 from theano import compile
bd937e845bbb new stuff: algorithms/logistic_regression, datasets/MNIST
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
7
bd937e845bbb new stuff: algorithms/logistic_regression, datasets/MNIST
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
8 import numpy as N
bd937e845bbb new stuff: algorithms/logistic_regression, datasets/MNIST
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
9
499
a419edf4e06c removed unpicklable nested classes in logistic regression
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 497
diff changeset
class LogRegInstanceType(module.FancyModuleInstance):
    """Instance type used by the logistic-regression modules in this file.

    Both Module_Nclass and Module name this class as their InstanceType.
    """

    def initialize(self, n_in, n_out=1, rng=N.random):
        """Set zero-valued weights and biases and a learning rate of 0.01.

        n_in  -- number of rows of the weight matrix `w`
        n_out -- number of columns of `w` and length of the bias vector `b`
        rng   -- accepted for interface compatibility; not used here
        """
        #self.component is the LogisticRegressionTemplate instance that built this guy.

        weight_shape = (n_in, n_out)
        self.w = N.zeros(weight_shape)
        self.b = N.zeros(n_out)
        self.lr = 0.01
        # NOTE(review): presumably tells the FancyModule machinery to leave
        # 'params' out of its display -- confirm against theano.compile.module.
        self.__hide__ = ['params']
470
bd937e845bbb new stuff: algorithms/logistic_regression, datasets/MNIST
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
18
bd937e845bbb new stuff: algorithms/logistic_regression, datasets/MNIST
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
class Module_Nclass(module.FancyModule):
    """Symbolic N-class logistic regression (softmax + 1-hot cross-entropy).

    Constructor arguments left as None are replaced by fresh symbolic
    variables: `x` by T.matrix('input'), `targ` by T.lvector(), and
    `w`, `b`, `lr` by module.Member-wrapped dmatrix, dvector and dscalar.
    `regularize` is accepted but not used by this constructor.

    Attributes set by the constructor:
      output   -- second result of nnet.crossentropy_softmax_1hot
      sum_xent -- sum of the per-example cross-entropy; also exposed as `cost`
      input    -- alias of `x`
      pred     -- argmax over axis 1 of dot(input, w) + b
      apply    -- module.Method mapping [input] to pred
      update   -- module.Method mapping [input, targ] to sum_xent while
                  applying p <- p - lr * d(sum_xent)/dp to every entry of
                  params; only defined when params is non-empty
    """

    InstanceType = LogRegInstanceType

    def __init__(self, x=None, targ=None, w=None, b=None, lr=None, regularize=False):
        super(Module_Nclass, self).__init__() #boilerplate

        # Fall back to fresh symbolic variables for anything not supplied.
        if x is None:
            x = T.matrix('input')
        self.x = x
        if targ is None:
            targ = T.lvector()
        self.targ = targ

        if w is None:
            w = module.Member(T.dmatrix())
        self.w = w
        if b is None:
            b = module.Member(T.dvector())
        self.b = b
        if lr is None:
            lr = module.Member(T.dscalar())
        self.lr = lr

        # Only variables with no owner (not computed by another op) are
        # treated as trainable parameters of this module.
        trainable = []
        for candidate in [self.w, self.b]:
            if candidate.owner is None:
                trainable.append(candidate)
        self.params = trainable

        activation = T.dot(self.x, self.w) + self.b
        xent, output = nnet.crossentropy_softmax_1hot(activation, self.targ)
        sum_xent = T.sum(xent)

        self.output = output
        self.sum_xent = sum_xent

        #compatibility with current implementation of stacker/daa or something
        #TODO: remove this, make a wrapper
        self.cost = sum_xent
        self.input = self.x

        #define the apply method
        self.pred = T.argmax(T.dot(self.input, self.w) + self.b, axis=1)
        self.apply = module.Method([self.input], self.pred)

        #if this module has any internal parameters, define an update function for them
        if self.params:
            gparams = T.grad(sum_xent, self.params)

            # One plain gradient-descent step per trainable parameter.
            descent_step = dict(
                    (param, param - self.lr * gparam)
                    for param, gparam in zip(self.params, gparams))
            self.update = module.Method([self.input, self.targ], sum_xent,
                    updates = descent_step)
bd937e845bbb new stuff: algorithms/logistic_regression, datasets/MNIST
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
55
bd937e845bbb new stuff: algorithms/logistic_regression, datasets/MNIST
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
class Module(module.FancyModule):
    """Symbolic binary logistic regression built on a sigmoid output.

    Constructor arguments left as None are replaced by fresh symbolic
    variables: `input` by T.matrix('input'), `targ` by T.lcol(), and
    `w`, `b`, `lr` by module.Member-wrapped dmatrix, dvector and dscalar.
    `regularize` is accepted but not used by this constructor.

    Attributes set by the constructor:
      output   -- sigmoid(dot(input, w) + b)
      xent     -- element-wise binary cross-entropy between targ and output
      sum_xent -- sum of xent; also exposed as `cost`
      pred     -- (dot(input, w) + b) > 0.0
      apply    -- module.Method mapping [input] to pred
      update   -- module.Method mapping [input, targ] to sum_xent while
                  applying p <- p - lr * d(sum_xent)/dp to every entry of
                  params; only defined when params is non-empty
    """

    InstanceType = LogRegInstanceType

    def __init__(self, input=None, targ=None, w=None, b=None, lr=None, regularize=False):
        super(Module, self).__init__() #boilerplate

        self.input = input if input is not None else T.matrix('input')
        self.targ = targ if targ is not None else T.lcol()

        self.w = w if w is not None else module.Member(T.dmatrix())
        self.b = b if b is not None else module.Member(T.dvector())
        self.lr = lr if lr is not None else module.Member(T.dscalar())

        # Only variables with no owner (not computed by another op) are
        # treated as trainable parameters of this module.
        self.params = [p for p in [self.w, self.b] if p.owner is None]

        # This class stores its input as self.input (there is no self.x), and
        # the bias must be part of the output so that the cost being
        # differentiated matches the decision rule used for self.pred below.
        output = nnet.sigmoid(T.dot(self.input, self.w) + self.b)
        xent = -self.targ * T.log(output) - (1.0 - self.targ) * T.log(1.0 - output)
        sum_xent = T.sum(xent)

        self.output = output
        self.xent = xent
        self.sum_xent = sum_xent
        self.cost = sum_xent

        #define the apply method
        self.pred = (T.dot(self.input, self.w) + self.b) > 0.0
        self.apply = module.Method([self.input], self.pred)

        #if this module has any internal parameters, define an update function for them
        if self.params:
            gparams = T.grad(sum_xent, self.params)
            self.update = module.Method([self.input, self.targ], sum_xent,
                    updates = dict((p, p - self.lr * g) for p, g in zip(self.params, gparams)))
bd937e845bbb new stuff: algorithms/logistic_regression, datasets/MNIST
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
89
bd937e845bbb new stuff: algorithms/logistic_regression, datasets/MNIST
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
class Learner(object):
    """Placeholder with no behaviour yet.

    TODO: Encapsulate the algorithm for finding an optimal regularization
    coefficient.
    """
bd937e845bbb new stuff: algorithms/logistic_regression, datasets/MNIST
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
93