comparison _test_linear_regression.py @ 432:8e4d2ebd816a

added a test for LinearRegression
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Tue, 29 Jul 2008 11:16:05 -0400
parents
children 317a052f9b14
comparison
equal deleted inserted replaced
431:0f8c81b0776d 432:8e4d2ebd816a
1
2 import unittest
3 from linear_regression import *
4 from make_test_datasets import *
5 import numpy
6
class test_linear_regression(unittest.TestCase):
    """Tests for LinearRegression trained on synthetic linear data."""

    def test1(self):
        """Fit LinearRegression on artificial linear data and check the test MSE.

        Builds train/test datasets with make_artificial_datasets_from_function
        (using linear_predictor as the generating function), checks the shapes
        of the resulting fields, trains a LinearRegression with an L2
        regularizer of 0.1 on the training set, then computes, prints and
        sanity-checks the mean squared error on the test set.
        """
        n_inputs, n_targets, n_examples = 3, 2, 100
        # The third return value (presumably the generating parameters) is
        # not used by this test.
        trainset, testset, _theta = make_artificial_datasets_from_function(
            n_inputs=n_inputs,
            n_targets=n_targets,
            n_examples=n_examples,
            f=linear_predictor)

        # Assumes the helper splits n_examples evenly between train and test
        # (the original test expected 50 rows each for 100 examples)
        # -- TODO confirm against make_test_datasets.
        n_half = n_examples // 2
        # assertEqual rather than a bare assert: it reports both shapes on
        # failure and is not stripped when Python runs with -O.
        self.assertEqual(trainset.fields()['input'].shape, (n_half, n_inputs))
        self.assertEqual(testset.fields()['target'].shape, (n_half, n_targets))

        regressor = LinearRegression(L2_regularizer=0.1)
        predictor = regressor(trainset)
        test_data = testset.fields()
        mse = predictor.compute_mse(test_data['input'], test_data['target'])
        # A single pre-formatted argument prints the same text under both
        # Python 2 and Python 3.
        print('mse = %s' % (mse,))
        # The original only printed the error. At minimum, the fitted model
        # must produce a finite MSE. numpy.all covers scalar and per-target
        # array results alike.
        self.assertTrue(numpy.all(numpy.isfinite(mse)))
22
if '__main__' == __name__:
    # Executed directly as a script: hand control to the unittest runner,
    # which discovers and runs the TestCase classes defined in this module.
    unittest.main()
25