Mercurial > pylearn
view _test_dataset.py @ 15:88168361a5ab
comment re: ArrayDataSet.__array__
author | bergstrj@iro.umontreal.ca |
---|---|
date | Tue, 25 Mar 2008 13:38:51 -0400 |
parents | d1c394486037 |
children | be128b9127c8 |
line wrap: on
line source
"""Regression tests for ArrayDataSet from the project-local ``dataset`` module.

Each test seeds numpy's global RNG in setUp, builds an ArrayDataSet over
random data, accumulates a checksum by iterating/slicing it, and compares
the checksum against a hard-coded reference value with an absolute
tolerance of 1e-6.
"""
from dataset import *
from math import *

import unittest
# Explicit import: previously this module relied on ``from dataset import *``
# re-exporting numpy. Imported after the star imports so the name is
# guaranteed to be the real numpy module.
import numpy


def _sum_all(a):
    """Reduce a (possibly multi-dimensional) ndarray to a single scalar sum.

    ``sum`` is applied repeatedly until the result is no longer an ndarray.
    Inputs that are not ndarrays are returned unchanged.
    """
    s = a
    while isinstance(s, numpy.ndarray):
        if s.ndim == 0:
            # A 0-d array cannot be iterated, so the builtin sum would raise
            # TypeError here; unwrap it to its scalar value instead.
            return s[()]
        s = sum(s)
    return s


class T_arraydataset(unittest.TestCase):
    """Checksum-style regression tests for ArrayDataSet iteration and slicing."""

    def setUp(self):
        # Fixed seed so the random data, and therefore the reference
        # checksums below, are reproducible.
        numpy.random.seed(123456)

    def _assert_close(self, value, expected, tol=1e-6):
        """Fail unless ``abs(value - expected) < tol``.

        Replaces the deprecated ``failUnless`` alias (removed in Python 3.12)
        and includes the actual value in the failure message.
        """
        self.assertTrue(abs(value - expected) < tol,
                        "got %r, expected %r (tol %r)" % (value, expected, tol))

    def test0(self):
        """Iterate one example at a time and sum every ``x`` field."""
        a = ArrayDataSet(data=numpy.random.rand(8, 3),
                         fields={"x": slice(2), "y": slice(1, 3)},
                         minibatch_size=1)
        s = 0
        for example in a:
            s += _sum_all(example.x)
        # Reference value presumably recorded from a known-good run with the
        # seed set in setUp.
        self._assert_close(s, 7.25967597)

    def test1(self):
        """Exercise minibatch iteration, slicing, field access, and building
        an ArrayDataSet from another dataset's field."""
        a = ArrayDataSet(data=numpy.random.rand(10, 4),
                         fields={"x": slice(2), "y": slice(1, 4)},
                         minibatch_size=1)
        # Minibatch size is changed after construction on purpose.
        a.minibatch_size = 2
        s = 0
        for mb in a:
            s += _sum_all(numpy.array(mb))
        s += a[3:6].x[1, 1]
        for mb in ArrayDataSet(data=a.y, minibatch_size=2):
            for e in mb:
                s += sum(e)
        # Strided slice of a field.
        s += _sum_all(a.y[4:9:2])
        self._assert_close(s, 39.0334797)


if __name__ == '__main__':
    unittest.main()