comparison _test_dataset.py @ 17:759d17112b23

more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
author bergstrj@iro.umontreal.ca
date Wed, 26 Mar 2008 21:05:14 -0400
parents be128b9127c8
children 57f4015e2e09
comparison
equal deleted inserted replaced
16:813723310d75 17:759d17112b23
10 10
11 class T_arraydataset(unittest.TestCase): 11 class T_arraydataset(unittest.TestCase):
12 def setUp(self): 12 def setUp(self):
13 numpy.random.seed(123456) 13 numpy.random.seed(123456)
14 14
15 def test0(self):
16 a=ArrayDataSet(data=numpy.random.rand(8,3),fields={"x":slice(2),"y":slice(1,3)})
17 s=0
18 for example in a:
19 s+=_sum_all(example.x)
20 #print s
21 self.failUnless(abs(s-7.25967597)<1e-6)
22 15
23 def test1(self): 16 def test_ctor_len(self):
24 a=ArrayDataSet(data=numpy.random.rand(10,4),fields={"x":slice(2),"y":slice(1,4)}) 17 n = numpy.random.rand(8,3)
25 s=0 18 a=ArrayDataSet(n)
26 for mb in a.minibatches(2): 19 self.failUnless(a.data is n)
27 s+=_sum_all(numpy.array(mb)) 20 self.failUnless(a.fields is None)
28 s+=a[3:6].x[1,1] 21
29 for mb in ArrayDataSet(data=a.y).minibatches(2): 22 self.failUnless(len(a) == n.shape[0])
30 for e in mb: 23 self.failUnless(a[0].shape == (n.shape[1],))
31 s+=sum(e) 24
32 #print numpy.array(a) 25 def test_iter(self):
33 #print a.y[4:9:2] 26 arr = numpy.random.rand(8,3)
34 s+= _sum_all(a.y[4:9:2]) 27 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,3)})
35 #print s 28 for i, example in enumerate(a):
36 self.failUnless(abs(s-39.0334797)<1e-6) 29 self.failUnless(numpy.all( example.x == arr[i,:2]))
30 self.failUnless(numpy.all( example.y == arr[i,1:3]))
31
32 def test_zip(self):
33 arr = numpy.random.rand(8,3)
34 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,3)})
35 for i, x in enumerate(a.zip("x")):
36 self.failUnless(numpy.all( x == arr[i,:2]))
37
38 def test_minibatch_basic(self):
39 arr = numpy.random.rand(10,4)
40 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)})
41 for i, mb in enumerate(a.minibatches(minibatch_size=2)): #all fields
42 self.failUnless(numpy.all( mb.x == arr[i*2:i*2+2,0:2]))
43 self.failUnless(numpy.all( mb.y == arr[i*2:i*2+2,1:4]))
44
45 def test_getattr(self):
46 arr = numpy.random.rand(10,4)
47 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)})
48 a_y = a.y
49 self.failUnless(numpy.all( a_y == arr[:,1:4]))
50
51 def test_asarray(self):
52 arr = numpy.random.rand(3,4)
53 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)})
54 a_arr = numpy.asarray(a)
55 self.failUnless(a_arr.shape[1] == 2 + 3)
56
57 def test_minibatch_wraparound_even(self):
58 arr = numpy.random.rand(10,4)
59 arr2 = ArrayDataSet.Iterator.matcat(arr,arr)
60
61 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)})
62
63 #print arr
64 for i, x in enumerate(a.minibatches(["x"], minibatch_size=2, n_batches=8)):
65 #print 'x' , x
66 self.failUnless(numpy.all( x == arr2[i*2:i*2+2,0:2]))
67
68 def test_minibatch_wraparound_odd(self):
69 arr = numpy.random.rand(10,4)
70 arr2 = ArrayDataSet.Iterator.matcat(arr,arr)
71
72 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)})
73
74 for i, x in enumerate(a.minibatches(["x"], minibatch_size=3, n_batches=6)):
75 self.failUnless(numpy.all( x == arr2[i*3:i*3+3,0:2]))
37 76
38 if __name__ == '__main__': 77 if __name__ == '__main__':
39 unittest.main() 78 unittest.main()
40 79