view _test_dataset.py @ 6:d5738b79089a

Removed MinibatchIterator and instead made minibatch_size a field of all DataSets, so that they can all iterate over minibatches, optionally.
author bengioy@bengiomac.local
date Mon, 24 Mar 2008 09:04:06 -0400
parents f7dcfb5f9d5b
children 6f8f338686db
line wrap: on
line source

from dataset import *
from math import *
import unittest

def _sum_all(a):
    s=a
    while isinstance(s,numpy.ndarray):
        s=sum(s)
    return s
    
class T_arraydataset(unittest.TestCase):
    """Tests iterating an ArrayDataSet one example and two examples at a time."""

    def setUp(self):
        # Seed the global numpy RNG so the random data built in test0, and
        # hence the expected sum checked there, is reproducible.
        numpy.random.seed(123456)

    def test0(self):
        """Sum field "x" over all examples, then iterate with minibatch_size=2."""
        # 8x3 random data; the "x" and "y" fields are given as slices
        # (presumably column ranges [0, 2) and [1, 3) -- confirm against
        # ArrayDataSet).
        a = ArrayDataSet(data=numpy.random.rand(8, 3),
                         fields={"x": slice(2), "y": slice(1, 3)},
                         minibatch_size=1)
        s = 0
        for example in a:
            s += _sum_all(example.x)
        print(s)
        # The expected value only holds for the seed set in setUp.
        # assertTrue replaces the deprecated failUnless alias, which was
        # removed in Python 3.12.
        self.assertTrue(abs(s - 11.4674133) < 1e-6)
        # Re-iterate the same dataset with minibatch_size=2.  This part
        # currently only prints each minibatch; it asserts nothing.
        a.minibatch_size = 2
        for mb in a:
            print(mb)

if __name__ == '__main__':
    # Only when run as a script (not on import): collect and run every
    # TestCase defined in this module.
    unittest.main(module='__main__')