comparison _test_dataset.py @ 8:d1c394486037

Replaced the asarray() method with an __array__ method, which NumPy calls automatically when the dataset is cast to an array or numpy array, or when numpy.asarray(dataset) is called.
author bengioy@bengiomac.local
date Mon, 24 Mar 2008 15:56:53 -0400
parents 6f8f338686db
children be128b9127c8
comparison
equal deleted inserted replaced
7:6f8f338686db 8:d1c394486037
14 14
15 def test0(self): 15 def test0(self):
16 a=ArrayDataSet(data=numpy.random.rand(8,3),fields={"x":slice(2),"y":slice(1,3)},minibatch_size=1) 16 a=ArrayDataSet(data=numpy.random.rand(8,3),fields={"x":slice(2),"y":slice(1,3)},minibatch_size=1)
17 s=0 17 s=0
18 for example in a: 18 for example in a:
19 print len(example), example.x
20 s+=_sum_all(example.x) 19 s+=_sum_all(example.x)
21 print s 20 #print s
22 self.failUnless(abs(s-7.25967597)<1e-6) 21 self.failUnless(abs(s-7.25967597)<1e-6)
23 22
24 def test1(self): 23 def test1(self):
25 a=ArrayDataSet(data=numpy.random.rand(10,4),fields={"x":slice(2),"y":slice(1,4)},minibatch_size=1) 24 a=ArrayDataSet(data=numpy.random.rand(10,4),fields={"x":slice(2),"y":slice(1,4)},minibatch_size=1)
26 a.minibatch_size=2 25 a.minibatch_size=2
27 print a.asarray() 26 s=0
28 for mb in a: 27 for mb in a:
29 print mb,mb.asarray() 28 s+=_sum_all(numpy.array(mb))
30 print "a.y=",a.y 29 s+=a[3:6].x[1,1]
31 for mb in ArrayDataSet(data=a.y,minibatch_size=2): 30 for mb in ArrayDataSet(data=a.y,minibatch_size=2):
32 print mb
33 for e in mb: 31 for e in mb:
34 print e 32 s+=sum(e)
35 self.failUnless(True) 33 #print numpy.array(a)
34 #print a.y[4:9:2]
35 s+= _sum_all(a.y[4:9:2])
36 #print s
37 self.failUnless(abs(s-39.0334797)<1e-6)
36 38
37 if __name__ == '__main__': 39 if __name__ == '__main__':
38 unittest.main() 40 unittest.main()
39 41