Mercurial > pylearn
comparison _test_dataset.py @ 17:759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
author | bergstrj@iro.umontreal.ca |
---|---|
date | Wed, 26 Mar 2008 21:05:14 -0400 |
parents | be128b9127c8 |
children | 57f4015e2e09 |
comparison legend: equal | deleted | inserted | replaced
16:813723310d75 | 17:759d17112b23 |
---|---|
10 | 10 |
11 class T_arraydataset(unittest.TestCase): | 11 class T_arraydataset(unittest.TestCase): |
12 def setUp(self): | 12 def setUp(self): |
13 numpy.random.seed(123456) | 13 numpy.random.seed(123456) |
14 | 14 |
15 def test0(self): | |
16 a=ArrayDataSet(data=numpy.random.rand(8,3),fields={"x":slice(2),"y":slice(1,3)}) | |
17 s=0 | |
18 for example in a: | |
19 s+=_sum_all(example.x) | |
20 #print s | |
21 self.failUnless(abs(s-7.25967597)<1e-6) | |
22 | 15 |
23 def test1(self): | 16 def test_ctor_len(self): |
24 a=ArrayDataSet(data=numpy.random.rand(10,4),fields={"x":slice(2),"y":slice(1,4)}) | 17 n = numpy.random.rand(8,3) |
25 s=0 | 18 a=ArrayDataSet(n) |
26 for mb in a.minibatches(2): | 19 self.failUnless(a.data is n) |
27 s+=_sum_all(numpy.array(mb)) | 20 self.failUnless(a.fields is None) |
28 s+=a[3:6].x[1,1] | 21 |
29 for mb in ArrayDataSet(data=a.y).minibatches(2): | 22 self.failUnless(len(a) == n.shape[0]) |
30 for e in mb: | 23 self.failUnless(a[0].shape == (n.shape[1],)) |
31 s+=sum(e) | 24 |
32 #print numpy.array(a) | 25 def test_iter(self): |
33 #print a.y[4:9:2] | 26 arr = numpy.random.rand(8,3) |
34 s+= _sum_all(a.y[4:9:2]) | 27 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,3)}) |
35 #print s | 28 for i, example in enumerate(a): |
36 self.failUnless(abs(s-39.0334797)<1e-6) | 29 self.failUnless(numpy.all( example.x == arr[i,:2])) |
30 self.failUnless(numpy.all( example.y == arr[i,1:3])) | |
31 | |
32 def test_zip(self): | |
33 arr = numpy.random.rand(8,3) | |
34 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,3)}) | |
35 for i, x in enumerate(a.zip("x")): | |
36 self.failUnless(numpy.all( x == arr[i,:2])) | |
37 | |
38 def test_minibatch_basic(self): | |
39 arr = numpy.random.rand(10,4) | |
40 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)}) | |
41 for i, mb in enumerate(a.minibatches(minibatch_size=2)): #all fields | |
42 self.failUnless(numpy.all( mb.x == arr[i*2:i*2+2,0:2])) | |
43 self.failUnless(numpy.all( mb.y == arr[i*2:i*2+2,1:4])) | |
44 | |
45 def test_getattr(self): | |
46 arr = numpy.random.rand(10,4) | |
47 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)}) | |
48 a_y = a.y | |
49 self.failUnless(numpy.all( a_y == arr[:,1:4])) | |
50 | |
51 def test_asarray(self): | |
52 arr = numpy.random.rand(3,4) | |
53 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)}) | |
54 a_arr = numpy.asarray(a) | |
55 self.failUnless(a_arr.shape[1] == 2 + 3) | |
56 | |
57 def test_minibatch_wraparound_even(self): | |
58 arr = numpy.random.rand(10,4) | |
59 arr2 = ArrayDataSet.Iterator.matcat(arr,arr) | |
60 | |
61 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)}) | |
62 | |
63 #print arr | |
64 for i, x in enumerate(a.minibatches(["x"], minibatch_size=2, n_batches=8)): | |
65 #print 'x' , x | |
66 self.failUnless(numpy.all( x == arr2[i*2:i*2+2,0:2])) | |
67 | |
68 def test_minibatch_wraparound_odd(self): | |
69 arr = numpy.random.rand(10,4) | |
70 arr2 = ArrayDataSet.Iterator.matcat(arr,arr) | |
71 | |
72 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)}) | |
73 | |
74 for i, x in enumerate(a.minibatches(["x"], minibatch_size=3, n_batches=6)): | |
75 self.failUnless(numpy.all( x == arr2[i*3:i*3+3,0:2])) | |
37 | 76 |
38 if __name__ == '__main__': | 77 if __name__ == '__main__': |
39 unittest.main() | 78 unittest.main() |
40 | 79 |