view _test_dataset.py @ 26:672fe4b23032

Fixed dataset errors so that _test_dataset.py works again.
author bengioy@grenat.iro.umontreal.ca
date Fri, 11 Apr 2008 11:14:54 -0400
parents b6b36f65664f
children 541a273bc89f
line wrap: on
line source

from dataset import *
from math import *
import unittest

def _sum_all(a):
    s=a
    while isinstance(s,numpy.ndarray):
        s=sum(s)
    return s
    
class T_arraydataset(unittest.TestCase):
    def setUp(self):
        numpy.random.seed(123456)


    def test_ctor_len(self):
        n = numpy.random.rand(8,3)
        a=ArrayDataSet(n)
        self.failUnless(a.data is n)
        self.failUnless(a.fields is None)

        self.failUnless(len(a) == n.shape[0])
        self.failUnless(a[0].shape == (n.shape[1],))

    def test_iter(self):
        arr = numpy.random.rand(8,3)
        a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,3)})
        for i, example in enumerate(a):
            self.failUnless(numpy.all( example['x'] == arr[i,:2]))
            self.failUnless(numpy.all( example['y'] == arr[i,1:3]))

    def test_zip(self):
        arr = numpy.random.rand(8,3)
        a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,3)})
        for i, x in enumerate(a.zip("x")):
            self.failUnless(numpy.all( x == arr[i,:2]))

    def test_minibatch_basic(self):
        arr = numpy.random.rand(10,4)
        a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)})
        for i, mb in enumerate(a.minibatches(minibatch_size=2)): #all fields
            self.failUnless(numpy.all( mb['x'] == arr[i*2:i*2+2,0:2]))
            self.failUnless(numpy.all( mb['y'] == arr[i*2:i*2+2,1:4]))

    def test_getattr(self):
        arr = numpy.random.rand(10,4)
        a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)})
        a_y = a.y
        self.failUnless(numpy.all( a_y == arr[:,1:4]))

    def test_asarray(self):
        arr = numpy.random.rand(3,4)
        a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(2,4)})
        a_arr = numpy.asarray(a)
        self.failUnless(a_arr.shape[1] == 2 + 2)
        self.failUnless(numpy.sum(numpy.square(a_arr-a.data))==0)
        a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)})
        a_arr = numpy.asarray(a)
        self.failUnless(a_arr.shape[1] == 2 + 3)

    def test_minibatch_wraparound_even(self):
        arr = numpy.random.rand(10,4)
        arr2 = ArrayDataSet.Iterator.matcat(arr,arr)

        a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)})

        #print arr
        for i, x in enumerate(a.minibatches(["x"], minibatch_size=2, n_batches=8)):
            #print 'x' , x
            self.failUnless(numpy.all( x == arr2[i*2:i*2+2,0:2]))

    def test_minibatch_wraparound_odd(self):
        arr = numpy.random.rand(10,4)
        arr2 = ArrayDataSet.Iterator.matcat(arr,arr)

        a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)})

        for i, x in enumerate(a.minibatches(["x"], minibatch_size=3, n_batches=6)):
            self.failUnless(numpy.all( x == arr2[i*3:i*3+3,0:2]))
    

class T_renamingdataset(unittest.TestCase):
    def setUp(self):
        numpy.random.seed(123456)


    def test_hasfield(self):
        n = numpy.random.rand(3,8)
        a=ArrayDataSet(data=n,fields={"x":slice(2),"y":slice(1,4),"z":slice(4,6)})
        b=a.rename({'xx':'x','zz':'z'})
        self.failUnless(b.hasFields('xx','zz') and not b.hasFields('x') and not b.hasFields('y'))


if '__main__' == __name__:
    # Run directly as a script: execute every TestCase in this module.
    unittest.main()