view _test_dataset.py @ 22:b6b36f65664f

Created virtual sub-classes of DataSet: {Finite{Length,Width},Sliceable}DataSet; removed .field attribute access from LookupList (because of setattr problems); removed fieldNames() from DataSet (it remains in FiniteWidthDataSet, where it makes sense) and added hasFields() in its place. Fixed problems in asarray, and tested the previously existing functionality in _test_dataset.py; the new functionality is not yet tested.
author bengioy@esprit.iro.umontreal.ca
date Mon, 07 Apr 2008 20:44:37 -0400
parents fdf0abc490f7
children 672fe4b23032
line wrap: on
line source

from dataset import *
from math import *
import unittest

def _sum_all(a):
    s=a
    while isinstance(s,numpy.ndarray):
        s=sum(s)
    return s
    
class T_arraydataset(unittest.TestCase):
    """Unit tests for ArrayDataSet: construction, iteration, field access,
    minibatching (including wrap-around) and conversion to a numpy array.

    Assertions use ``assertTrue``/``assertEqual`` rather than the
    ``failUnless`` alias, which was deprecated in Python 2.7 and removed in
    Python 3.12.
    """

    def setUp(self):
        # Fixed seed so every test sees the same "random" data.
        numpy.random.seed(123456)

    def test_ctor_len(self):
        n = numpy.random.rand(8, 3)
        a = ArrayDataSet(n)
        # Built from a bare array: the array is stored as-is, no fields.
        self.assertTrue(a.data is n)
        self.assertTrue(a.fields is None)

        self.assertEqual(len(a), n.shape[0])
        self.assertEqual(a[0].shape, (n.shape[1],))

    def test_iter(self):
        arr = numpy.random.rand(8, 3)
        a = ArrayDataSet(data=arr, fields={"x": slice(2), "y": slice(1, 3)})
        n_seen = 0
        for i, example in enumerate(a):
            self.assertTrue(numpy.all(example['x'] == arr[i, :2]))
            self.assertTrue(numpy.all(example['y'] == arr[i, 1:3]))
            n_seen = i + 1
        # Guard against a vacuous pass if iteration yields nothing.
        self.assertEqual(n_seen, arr.shape[0])

    def test_zip(self):
        arr = numpy.random.rand(8, 3)
        a = ArrayDataSet(data=arr, fields={"x": slice(2), "y": slice(1, 3)})
        n_seen = 0
        for i, x in enumerate(a.zip("x")):
            self.assertTrue(numpy.all(x == arr[i, :2]))
            n_seen = i + 1
        # Guard against a vacuous pass if zip yields nothing.
        self.assertEqual(n_seen, arr.shape[0])

    def test_minibatch_basic(self):
        arr = numpy.random.rand(10, 4)
        a = ArrayDataSet(data=arr, fields={"x": slice(2), "y": slice(1, 4)})
        for i, mb in enumerate(a.minibatches(minibatch_size=2)):  # all fields
            self.assertTrue(numpy.all(mb['x'] == arr[i * 2:i * 2 + 2, 0:2]))
            self.assertTrue(numpy.all(mb['y'] == arr[i * 2:i * 2 + 2, 1:4]))

    def test_getattr(self):
        arr = numpy.random.rand(10, 4)
        a = ArrayDataSet(data=arr, fields={"x": slice(2), "y": slice(1, 4)})
        a_y = a.y
        self.assertTrue(numpy.all(a_y == arr[:, 1:4]))

    def test_asarray(self):
        arr = numpy.random.rand(3, 4)
        # Non-overlapping fields covering every column: asarray should
        # reproduce the underlying data exactly.
        a = ArrayDataSet(data=arr, fields={"x": slice(2), "y": slice(2, 4)})
        a_arr = numpy.asarray(a)
        self.assertEqual(a_arr.shape[1], 2 + 2)
        self.assertEqual(numpy.sum(numpy.square(a_arr - a.data)), 0)
        # Overlapping fields: the shared column appears once per field.
        a = ArrayDataSet(data=arr, fields={"x": slice(2), "y": slice(1, 4)})
        a_arr = numpy.asarray(a)
        self.assertEqual(a_arr.shape[1], 2 + 3)

    def test_minibatch_wraparound_even(self):
        arr = numpy.random.rand(10, 4)
        # Two stacked copies of arr: the expected rows once iteration wraps.
        arr2 = ArrayDataSet.Iterator.matcat(arr, arr)

        a = ArrayDataSet(data=arr, fields={"x": slice(2), "y": slice(1, 4)})

        n_seen = 0
        for i, x in enumerate(a.minibatches(["x"], minibatch_size=2, n_batches=8)):
            self.assertTrue(numpy.all(x == arr2[i * 2:i * 2 + 2, 0:2]))
            n_seen = i + 1
        self.assertEqual(n_seen, 8)

    def test_minibatch_wraparound_odd(self):
        arr = numpy.random.rand(10, 4)
        # Two stacked copies of arr: the expected rows once iteration wraps.
        arr2 = ArrayDataSet.Iterator.matcat(arr, arr)

        a = ArrayDataSet(data=arr, fields={"x": slice(2), "y": slice(1, 4)})

        n_seen = 0
        for i, x in enumerate(a.minibatches(["x"], minibatch_size=3, n_batches=6)):
            self.assertTrue(numpy.all(x == arr2[i * 3:i * 3 + 3, 0:2]))
            n_seen = i + 1
        self.assertEqual(n_seen, 6)
    
if __name__ == "__main__":
    # Executed as a script: discover and run every TestCase in this module.
    unittest.main()