view _test_dataset.py @ 10:80bf5492e571

Rewrote learner.py according to the specs in the wiki for learners.
author bengioy@esprit.iro.umontreal.ca
date Tue, 25 Mar 2008 11:39:02 -0400
parents d1c394486037
children be128b9127c8
line wrap: on
line source

from dataset import *
from math import *
import unittest

def _sum_all(a):
    """Return the sum of every element of ``a``.

    A ``numpy.ndarray`` of any rank, including 0-d arrays, is reduced to a
    single scalar.  Any other value is returned unchanged, so plain scalars
    pass straight through.
    """
    if isinstance(a, numpy.ndarray):
        # One reduction over all axes.  The former loop of builtin ``sum``
        # calls iterated in Python, raised TypeError on 0-d arrays, and never
        # terminated on numpy.matrix input (iterating a (1, N) matrix yields
        # (1, N) matrices, so the result never stopped being an ndarray).
        return numpy.sum(a)
    return a
    
class T_arraydataset(unittest.TestCase):
    """Regression tests for ``ArrayDataSet`` iteration, field access and slicing.

    Each test builds its data from numpy's global RNG (re-seeded in ``setUp``)
    and compares an accumulated sum against a previously recorded checksum.
    """

    def setUp(self):
        # Re-seed before every test so each one draws the same random data
        # regardless of which tests ran before it.
        numpy.random.seed(123456)

    def _check_close(self, actual, expected, tol=1e-6):
        # assertTrue replaces the deprecated failUnless alias (removed in
        # Python 3.12); the message reports the actual value on failure.
        self.assertTrue(abs(actual - expected) < tol,
                        "got %r, expected %r (tolerance %g)"
                        % (actual, expected, tol))

    def test0(self):
        # Iterate one example at a time and accumulate every element of
        # the "x" field of each example.
        a = ArrayDataSet(data=numpy.random.rand(8, 3),
                         fields={"x": slice(2), "y": slice(1, 3)},
                         minibatch_size=1)
        s = 0
        for example in a:
            s += _sum_all(example.x)
        self._check_close(s, 7.25967597)

    def test1(self):
        a = ArrayDataSet(data=numpy.random.rand(10, 4),
                         fields={"x": slice(2), "y": slice(1, 4)},
                         minibatch_size=1)
        # minibatch_size is reassigned after construction; iteration below
        # is expected to honour the new value.
        a.minibatch_size = 2
        s = 0
        for mb in a:
            s += _sum_all(numpy.array(mb))
        # Range slicing of the dataset followed by field access.
        s += a[3:6].x[1, 1]
        # Wrap the "y" field in a fresh ArrayDataSet and iterate it in
        # minibatches, summing each element of each minibatch.
        for mb in ArrayDataSet(data=a.y, minibatch_size=2):
            for e in mb:
                s += sum(e)
        # Strided slice of the "y" field.
        s += _sum_all(a.y[4:9:2])
        self._check_close(s, 39.0334797)
        
if __name__ == '__main__':
    # Runs only when this file is executed as a script, not on import:
    # hand control to unittest's command-line test runner.
    unittest.main()