view _test_dataset.py @ 212:9b57ea8c767f

The previous commit was supposed to touch only one file, dataset.py; this commit tries to undo my other changes (nothing was broken, though; they were just useless debugging prints).
author Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
date Wed, 21 May 2008 17:42:20 -0400
parents 46c5c90019c2
children 5b3afda2f1ad
line wrap: on
line source

from dataset import *
from math import *
import unittest

def _sum_all(a):
    s=a
    while isinstance(s,numpy.ndarray):
        s=sum(s)
    return s
    
class T_arraydataset(unittest.TestCase):
    """Tests for ArrayDataSet: construction, iteration, field access and minibatches."""

    def setUp(self):
        # Fixed seed so the random fixtures are identical on every run.
        numpy.random.seed(123456)

    def test_ctor_len(self):
        n = numpy.random.rand(8,3)
        a = ArrayDataSet(n)
        # Without a fields mapping the array itself is kept (same object).
        self.assertTrue(a.data is n)
        self.assertTrue(a.fields is None)

        # One example per row; each example is one full row.
        self.assertEqual(len(a), n.shape[0])
        self.assertEqual(a[0].shape, (n.shape[1],))

    def test_iter(self):
        arr = numpy.random.rand(8,3)
        # Overlapping fields: 'x' is columns 0-1, 'y' is columns 1-2.
        a = ArrayDataSet(data=arr, fields={"x":slice(2),"y":slice(1,3)})
        for i, example in enumerate(a):
            self.assertTrue(numpy.all(example['x'] == arr[i,:2]))
            self.assertTrue(numpy.all(example['y'] == arr[i,1:3]))

    def test_zip(self):
        arr = numpy.random.rand(8,3)
        a = ArrayDataSet(data=arr, fields={"x":slice(2),"y":slice(1,3)})
        # zip("x") yields only the 'x' field of each example.
        for i, x in enumerate(a.zip("x")):
            self.assertTrue(numpy.all(x == arr[i,:2]))

    def test_minibatch_basic(self):
        arr = numpy.random.rand(10,4)
        a = ArrayDataSet(data=arr, fields={"x":slice(2),"y":slice(1,4)})
        # No field list given: every field is present in each minibatch.
        for i, mb in enumerate(a.minibatches(minibatch_size=2)):
            self.assertTrue(numpy.all(mb['x'] == arr[i*2:i*2+2,0:2]))
            self.assertTrue(numpy.all(mb['y'] == arr[i*2:i*2+2,1:4]))

    def test_getattr(self):
        arr = numpy.random.rand(10,4)
        a = ArrayDataSet(data=arr, fields={"x":slice(2),"y":slice(1,4)})
        # Fields are also reachable as attributes covering all rows.
        a_y = a.y
        self.assertTrue(numpy.all(a_y == arr[:,1:4]))

    def test_minibatch_wraparound_even(self):
        arr = numpy.random.rand(10,4)
        # Expected rows once the iterator wraps past the last example;
        # presumably arr stacked on itself row-wise (TODO confirm matcat).
        arr2 = ArrayDataSet.Iterator.matcat(arr,arr)

        a = ArrayDataSet(data=arr, fields={"x":slice(2),"y":slice(1,4)})

        # 8 batches of 2 rows = 16 rows > 10 examples, so wraparound is required.
        n_batches = 0
        for i, x in enumerate(a.minibatches(["x"], minibatch_size=2, n_batches=8)):
            self.assertTrue(numpy.all(x == arr2[i*2:i*2+2,0:2]))
            n_batches += 1
        # Without this check an empty iterator would make the test pass vacuously.
        self.assertEqual(n_batches, 8)

    def test_minibatch_wraparound_odd(self):
        arr = numpy.random.rand(10,4)
        arr2 = ArrayDataSet.Iterator.matcat(arr,arr)

        a = ArrayDataSet(data=arr, fields={"x":slice(2),"y":slice(1,4)})

        # 6 batches of 3 rows = 18 rows; 10 is not a multiple of 3, so one
        # batch straddles the end of the data.
        n_batches = 0
        for i, x in enumerate(a.minibatches(["x"], minibatch_size=3, n_batches=6)):
            self.assertTrue(numpy.all(x == arr2[i*3:i*3+3,0:2]))
            n_batches += 1
        self.assertEqual(n_batches, 6)
    

class T_renamingdataset(unittest.TestCase):
    """Tests for the dataset rename() view."""

    def setUp(self):
        # Fixed seed so the random fixture is identical on every run.
        numpy.random.seed(123456)

    def test_hasfield(self):
        n = numpy.random.rand(3,8)
        a = ArrayDataSet(data=n, fields={"x":slice(2),"y":slice(1,4),"z":slice(4,6)})
        # The mapping is {new_name: old_name}.
        b = a.rename({'xx':'x','zz':'z'})
        # Separate assertions so a failure pinpoints which expectation broke.
        # Renamed fields are visible under their new names...
        self.assertTrue(b.hasFields('xx','zz'))
        # ...the old name is gone...
        self.assertFalse(b.hasFields('x'))
        # ...and 'y', absent from the mapping, is not carried over.
        self.assertFalse(b.hasFields('y'))

class T_applyfunctiondataset(unittest.TestCase):
    """Tests for the dataset apply_function() view."""

    def setUp(self):
        # Fixed seed so the random fixture is identical on every run.
        numpy.random.seed(123456)

    def test_function(self):
        n = numpy.random.rand(3,8)
        # 'x' and 'y' are both two columns wide so that x+y is well-defined
        # (with 'y' = slice(1,4) the shapes (2,) and (3,) cannot broadcast).
        a = ArrayDataSet(data=n, fields={"x":slice(2),"y":slice(2,4),"z":slice(4,6)})
        # The lambda must return a tuple. Written as `lambda x,y: x+y,x+1`, the
        # `x+1` was parsed as an extra positional argument and raised NameError
        # on the undefined `x`.
        # NOTE(review): the three False flags are presumably copy_inputs,
        # accept_minibatches and cache; confirm against dataset.py.
        b = a.apply_function(lambda x,y: (x+y, x+1), ['x','y'], ['x+y','x+1'], False, False, False)
        # Assert on the output fields instead of printing them.
        names = b.fieldNames()
        self.assertTrue('x+y' in names)
        self.assertTrue('x+1' in names)
        

if __name__ == '__main__':
    # Run as a script: collect and run every TestCase defined in this module.
    unittest.main(module='__main__')