comparison _test_dataset.py @ 16:813723310d75

commenting
author bergstrj@iro.umontreal.ca
date Wed, 26 Mar 2008 18:23:44 -0400
parents be128b9127c8
children 759d17112b23
comparison
equal deleted inserted replaced
15:88168361a5ab 16:813723310d75
11 class T_arraydataset(unittest.TestCase): 11 class T_arraydataset(unittest.TestCase):
12 def setUp(self): 12 def setUp(self):
13 numpy.random.seed(123456) 13 numpy.random.seed(123456)
14 14
15 def test0(self): 15 def test0(self):
16 a=ArrayDataSet(data=numpy.random.rand(8,3),fields={"x":slice(2),"y":slice(1,3)},minibatch_size=1) 16 a=ArrayDataSet(data=numpy.random.rand(8,3),fields={"x":slice(2),"y":slice(1,3)})
17 s=0 17 s=0
18 for example in a: 18 for example in a:
19 s+=_sum_all(example.x) 19 s+=_sum_all(example.x)
20 #print s 20 #print s
21 self.failUnless(abs(s-7.25967597)<1e-6) 21 self.failUnless(abs(s-7.25967597)<1e-6)
22 22
23 def test1(self): 23 def test1(self):
24 a=ArrayDataSet(data=numpy.random.rand(10,4),fields={"x":slice(2),"y":slice(1,4)},minibatch_size=1) 24 a=ArrayDataSet(data=numpy.random.rand(10,4),fields={"x":slice(2),"y":slice(1,4)})
25 a.minibatch_size=2
26 s=0 25 s=0
27 for mb in a: 26 for mb in a.minibatches(2):
28 s+=_sum_all(numpy.array(mb)) 27 s+=_sum_all(numpy.array(mb))
29 s+=a[3:6].x[1,1] 28 s+=a[3:6].x[1,1]
30 for mb in ArrayDataSet(data=a.y,minibatch_size=2): 29 for mb in ArrayDataSet(data=a.y).minibatches(2):
31 for e in mb: 30 for e in mb:
32 s+=sum(e) 31 s+=sum(e)
33 #print numpy.array(a) 32 #print numpy.array(a)
34 #print a.y[4:9:2] 33 #print a.y[4:9:2]
35 s+= _sum_all(a.y[4:9:2]) 34 s+= _sum_all(a.y[4:9:2])