# Source: Mercurial web view of _test_dataset.py @ 16:813723310d75
#
# Changeset description: commenting
# Author:   bergstrj@iro.umontreal.ca
# Date:     Wed, 26 Mar 2008 18:23:44 -0400
# Parents:  be128b9127c8
# Children: 759d17112b23

from dataset import *
from math import *
import unittest

def _sum_all(a):
    s=a
    while isinstance(s,numpy.ndarray):
        s=sum(s)
    return s
    
class T_arraydataset(unittest.TestCase):
    """Regression tests for ArrayDataSet iteration, minibatching and slicing.

    Each test builds an ArrayDataSet from seeded random data and exercises
    some of its access paths. It accumulates the values it reads into one
    float and compares that total with a hard-coded reference value
    (presumably recorded from a known-good run).
    """

    def setUp(self):
        # Re-seed before every test so each test sees the same random data
        # regardless of execution order; the reference totals depend on it.
        numpy.random.seed(123456)

    def _check_total(self, total, expected, tol=1e-6):
        """Assert that `total` lies strictly within `tol` of `expected`."""
        self.assertTrue(abs(total - expected) < tol,
                        "total %r differs from expected %r by >= %r"
                        % (total, expected, tol))

    def test0(self):
        """Sum the 'x' field (declared as slice(2)) over every example."""
        a = ArrayDataSet(data=numpy.random.rand(8, 3),
                         fields={"x": slice(2), "y": slice(1, 3)})
        s = 0
        for example in a:
            s += _sum_all(example.x)
        self._check_total(s, 7.25967597)

    def test1(self):
        """Combine minibatch iteration, range slicing and field slicing."""
        a = ArrayDataSet(data=numpy.random.rand(10, 4),
                         fields={"x": slice(2), "y": slice(1, 4)})
        s = 0
        # minibatches(2): presumably batches of 2 examples -- confirm
        # against ArrayDataSet.
        for mb in a.minibatches(2):
            s += _sum_all(numpy.array(mb))
        # A single element of the 'x' field of a 3-example slice.
        s += a[3:6].x[1, 1]
        # Wrap the 'y' field in a new dataset and iterate its minibatches.
        for mb in ArrayDataSet(data=a.y).minibatches(2):
            for e in mb:
                s += sum(e)
        # Strided slice of a field.
        s += _sum_all(a.y[4:9:2])
        self._check_total(s, 39.0334797)
        
# Run the test cases defined above when this file is executed directly.
if __name__ == "__main__":
    unittest.main()