Mercurial > pylearn
annotate examples/linear_classifier.py @ 440:18dbc1c11647
Work on softmax operators
author | Pascal Lamblin <lamblinp@iro.umontreal.ca> |
---|---|
date | Thu, 21 Aug 2008 13:55:16 -0400 |
parents | 52b4908d8971 |
children | 4060812caa22 |
rev | line source |
---|---|
428
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
1 #! /usr/bin/env python |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
2 """ |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
3 T. Bertin-Mahieux (2008) University of Montreal |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
4 bertinmt@iro.umontreal.ca |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
5 |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
6 linear_classifier.py |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
7 Simple script that creates a linear_classifier, and |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
8 learns the parameters using backpropagation. |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
9 |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
10 This is to illustrate how to use theano/pylearn. |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
11 Anyone who knows how to make this script simpler/clearer is welcome to |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
12 make the modifications. |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
13 """ |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
14 |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
15 |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
16 import os |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
17 import sys |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
18 import time |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
19 import copy |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
20 import pickle |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
21 import numpy |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
22 import numpy as N |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
23 import numpy.random as NR |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
24 from pylearn import cost |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
25 import theano |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
26 from theano import tensor as T |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
27 |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
28 |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
def cost_function(*args, **kwargs):
    """
    Default cost function: the quadratic cost from pylearn.cost.

    Every positional and keyword argument is handed through untouched to
    cost.quadratic, and whatever it returns is returned unchanged.
    """
    quadratic = cost.quadratic
    return quadratic(*args, **kwargs)
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
32 |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
33 |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
class modelgraph() :
    """
    Class that contains the (symbolic, theano) graph of the model.

    Every member is a class attribute, so the graph is built exactly once,
    when this class body executes, and all modelgraph() instances share the
    same symbolic variables. The model class feeds numeric values for
    lr, inputs, true_outputs, W and b into functions compiled from this graph.
    """
    lr = T.scalar() # learning rate
    inputs = T.matrix() # inputs (one example per line)
    true_outputs = T.matrix() # outputs (one example per line)
    W = T.matrix() # weights input * W + b= output
    b = T.vector() # bias
    outputs = T.dot(inputs,W) + b # output, one per line
    # NOTE(review): T.grad below needs a scalar cost; presumably
    # cost.quadratic reduces over the examples -- verify in pylearn.cost.
    costs = cost_function(true_outputs,outputs) # costs
    g_W = T.grad(costs,W) # gradient of W
    g_b = T.grad(costs,b) # gradient of b
    # One gradient-descent step: parameter - lr * gradient.
    # NOTE(review): the "inplace" op may overwrite the buffer of its first
    # argument; whether the numpy array the caller passes for W/b is really
    # mutated depends on theano's input-copy policy -- confirm before
    # relying on it (callers should use the returned values).
    new_W = T.sub_inplace(W, lr * g_W) # update inplace of W
    new_b = T.sub_inplace(b, lr * g_b) # update inplace of b
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
47 |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
48 |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
class model():
    """
    The model!
    Contains needed matrices, needed functions, and a link to the model graph.

    Attributes:
      graph       -- a modelgraph holding the symbolic theano expressions
      params      -- [W, b]: numpy weight matrix and bias vector
      func_dict   -- cache of compiled theano functions, keyed by name
      init_params -- [input_size, output_size] given to the constructor

    Predictions are dot(inputs, W) + b, one example per row of inputs.
    """

    def __init__(self, input_size, output_size, seed=666):
        """
        Init weight matrix and bias, create the graph and an empty cache of
        compiled functions.

        @param input_size  number of input features (rows of W)
        @param output_size number of outputs (columns of W, length of b)
        @param seed        seed of the RandomState used to draw W; defaults to
                           666, the value that used to be hard-coded
        """
        self.graph = modelgraph()
        # W ~ uniform(-1/sqrt(input_size), 1/sqrt(input_size)); b starts at 0
        rng = NR.RandomState(seed)
        bound = 1 / N.sqrt(input_size)
        W = rng.uniform(size=[input_size, output_size], low=-bound, high=bound)
        b = numpy.zeros((output_size, ))
        self.params = [W, b]
        # dictionary of compiled functions, filled lazily by _compiled()
        self.func_dict = dict()
        # keep the constructor arguments (used by __str__ and pickling)
        self.init_params = [input_size, output_size]

    def _compiled(self, name, inputs, outputs):
        """
        Return the theano function cached under `name`, compiling it from the
        given symbolic inputs/outputs the first time it is requested.
        """
        func = self.func_dict.get(name)
        if func is None:
            func = theano.function(inputs, outputs)
            self.func_dict[name] = func
        return func

    def update(self, lr, true_inputs, true_outputs):
        """
        Do one gradient-descent update of the parameters on the given examples.

        @param lr           learning rate
        @param true_inputs  input matrix, one example per row
        @param true_outputs target matrix, one example per row
        """
        g = self.graph
        func = self._compiled('update_func',
                              [g.lr, g.inputs, g.true_outputs, g.W, g.b],
                              [g.new_W, g.new_b])
        new_W, new_b = func(lr, true_inputs, true_outputs,
                            self.params[0], self.params[1])
        # Store the results explicitly: theano may copy its (non-mutable)
        # inputs before the in-place subtraction, in which case relying on
        # mutation alone would leave the parameters unchanged. The list object
        # itself is kept so outside references to self.params stay valid.
        self.params[0], self.params[1] = new_W, new_b

    def costs(self, true_inputs, true_outputs):
        """
        Get the costs for the given examples without updating the model.

        @return the list of outputs of the compiled function ([costs])
        """
        g = self.graph
        func = self._compiled('costs_func',
                              [g.inputs, g.true_outputs, g.W, g.b],
                              [g.costs])
        return func(true_inputs, true_outputs, self.params[0], self.params[1])

    def outputs(self, true_inputs):
        """
        Get the output for a set of examples (could be called 'predict').

        @return the list of outputs of the compiled function ([outputs])
        """
        g = self.graph
        func = self._compiled('outputs_func',
                              [g.inputs, g.W, g.b],
                              [g.outputs])
        return func(true_inputs, self.params[0], self.params[1])

    def __getitem__(self, inputs):
        """ for simplicity, we can use the model this way: predictions = model[inputs] """
        return self.outputs(inputs)

    def __getstate__(self):
        """
        To save/copy the model, used by pickle.dump() and by copy.deepcopy().
        The graph and compiled functions are not saved.
        @return a dictionary with the params (matrix + bias) and init params
        """
        d = dict()
        d['params'] = self.params
        d['init_params'] = self.init_params
        return d

    def __setstate__(self, d):
        """
        Get the dictionary created by __getstate__(), use it to recreate the model.
        The graph is rebuilt and the compiled-function cache starts empty.
        """
        self.params = d['params']
        self.init_params = d['init_params']
        self.graph = modelgraph()  # we did not save the model graph
        # compiled functions were not saved either; without this attribute
        # update/costs/outputs would fail after unpickling or deepcopy
        self.func_dict = dict()

    def __str__(self):
        """ returns a string representing the model """
        return "Linear regressor, input size = %s, output size = %s" % (
            self.init_params[0], self.init_params[1])

    def __eq__(self, other):
        """
        Compares the models based on the params.
        @return True if other is a model with identical W and b (same shapes
                and values), False otherwise
        NOTE: equality is value-based over mutable arrays, so instances should
        not be used as dict keys or set members.
        """
        if not isinstance(other, model):
            return False
        return bool(numpy.array_equal(self.params[0], other.params[0]) and
                    numpy.array_equal(self.params[1], other.params[1]))

    def __ne__(self, other):
        """ Negation of __eq__ (Python 2 does not derive it automatically). """
        return not self.__eq__(other)

    # backward-compatible alias for the old, non-dunder name
    __equal__ = __eq__
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
165 |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
166 |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
def die_with_usage():
    """
    Help menu: print how to launch the script, then terminate the process
    with exit status 0.
    """
    usage_lines = (
        'simple script to illustrate how to use theano/pylearn',
        'to launch:',
        ' python linear_classifier.py -launch',
    )
    for line in usage_lines:
        print(line)
    sys.exit(0)
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
173 |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
174 |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
175 |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
176 #************************************************************ |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
177 # main |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
178 |
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
if __name__ == '__main__':

    if len(sys.argv) < 2:
        die_with_usage()

    # create a tiny toy dataset: 4 examples with 2 features, 1 binary target
    inputs = numpy.array([[.1, .2],
                          [.2, .8],
                          [.9, .3],
                          [.6, .5]])
    outputs = numpy.array([[0],
                           [0],
                           [1],
                           [1]])
    assert inputs.shape[0] == outputs.shape[0]

    # create model: 2 inputs -> 1 output
    m = model(2, 1)

    # predict
    print('prediction before training:')
    print(m[inputs])

    # gradient descent (the old comment claimed 100 iterations; the loop ran 50)
    n_iterations = 50
    learning_rate = .1
    for _ in range(n_iterations):
        m.update(learning_rate, inputs, outputs)

    # predict
    print('prediction after training:')
    print(m[inputs])

    # show points: class 0 as red crosses, class 1 as blue crosses
    import pylab as P
    x = inputs[:, 0]
    y = inputs[:, 1]
    labels = outputs.flatten()
    P.plot(x[labels == 0], y[labels == 0], 'r+')
    P.plot(x[labels == 1], y[labels == 1], 'b+')

    # decision line: W[0,0] * x + W[1,0] * y + b = .5
    W, b = m.params[0], m.params[1]
    y_at_x0 = (.5 - b[0]) / W[1, 0]  # where the line crosses the ordinate (x = 0)
    x_at_y0 = (.5 - b[0]) / W[0, 0]  # where the line crosses the abscissa (y = 0)
    P.plot((0, x_at_y0, 2 * x_at_y0), (y_at_x0, 0, -y_at_x0), 'g-')

    # show
    P.axis([-1, 2, -1, 2])
    P.show()
52b4908d8971
simple example of theano
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
diff
changeset
|
224 |