Mercurial > pylearn
view _test_dataset.py @ 22:b6b36f65664f
Created virtual sub-classes of DataSet: {Finite{Length,Width},Sliceable}DataSet,
removed .field ability from LookupList (because of setattr problems), removed
fieldNames() from DataSet (but is in FiniteWidthDataSet, where it makes sense),
and added hasFields() instead. Fixed problems in asarray, and tested
previous functionality in _test_dataset.py, but not yet new functionality.
author | bengioy@esprit.iro.umontreal.ca |
---|---|
date | Mon, 07 Apr 2008 20:44:37 -0400 |
parents | fdf0abc490f7 |
children | 672fe4b23032 |
line wrap: on
line source
"""Unit tests for ``ArrayDataSet`` from the ``dataset`` module."""

import unittest

# Imported explicitly so the tests do not rely on ``dataset`` re-exporting
# ``numpy`` through its wildcard import.
import numpy

# NOTE(review): both wildcard imports are kept, in their original order,
# because the names exported by ``dataset`` are not visible from this file.
# Nothing below appears to need ``math``; confirm before removing it.
from dataset import *
from math import *


def _sum_all(a):
    """Reduce ``a`` by summing it repeatedly.

    While the running value is a ``numpy.ndarray`` it is replaced by ``sum``
    of itself; any other value is returned unchanged.
    """
    s = a
    while isinstance(s, numpy.ndarray):
        s = sum(s)
    return s


class T_arraydataset(unittest.TestCase):
    """Tests for ArrayDataSet construction, iteration and minibatching."""

    def setUp(self):
        # Fixed seed so the random fixtures are the same on every run.
        numpy.random.seed(123456)

    def test_ctor_len(self):
        """A dataset built without fields keeps the array and its length."""
        n = numpy.random.rand(8, 3)
        a = ArrayDataSet(n)
        self.assertTrue(a.data is n)
        self.assertTrue(a.fields is None)
        self.assertEqual(len(a), n.shape[0])
        self.assertEqual(a[0].shape, (n.shape[1],))

    def test_iter(self):
        """Each iterated example exposes its fields as column slices."""
        arr = numpy.random.rand(8, 3)
        a = ArrayDataSet(data=arr, fields={"x": slice(2), "y": slice(1, 3)})
        for i, example in enumerate(a):
            self.assertTrue(numpy.all(example['x'] == arr[i, :2]))
            self.assertTrue(numpy.all(example['y'] == arr[i, 1:3]))

    def test_zip(self):
        """``zip("x")`` yields the ``x`` columns of each row in order."""
        arr = numpy.random.rand(8, 3)
        a = ArrayDataSet(data=arr, fields={"x": slice(2), "y": slice(1, 3)})
        for i, x in enumerate(a.zip("x")):
            self.assertTrue(numpy.all(x == arr[i, :2]))

    def test_minibatch_basic(self):
        """Minibatches of size 2 expose every field for consecutive rows."""
        arr = numpy.random.rand(10, 4)
        a = ArrayDataSet(data=arr, fields={"x": slice(2), "y": slice(1, 4)})
        # No field list is passed, so each minibatch carries all fields.
        for i, mb in enumerate(a.minibatches(minibatch_size=2)):
            self.assertTrue(numpy.all(mb['x'] == arr[i*2:i*2+2, 0:2]))
            self.assertTrue(numpy.all(mb['y'] == arr[i*2:i*2+2, 1:4]))

    def test_getattr(self):
        """Attribute access on a field name returns that field's columns."""
        arr = numpy.random.rand(10, 4)
        a = ArrayDataSet(data=arr, fields={"x": slice(2), "y": slice(1, 4)})
        a_y = a.y
        self.assertTrue(numpy.all(a_y == arr[:, 1:4]))

    def test_asarray(self):
        """``numpy.asarray`` works for disjoint and for overlapping fields."""
        arr = numpy.random.rand(3, 4)

        # Disjoint fields covering all four columns: same values as the data.
        a = ArrayDataSet(data=arr, fields={"x": slice(2), "y": slice(2, 4)})
        a_arr = numpy.asarray(a)
        self.assertEqual(a_arr.shape[1], 2 + 2)
        self.assertTrue(numpy.sum(numpy.square(a_arr - a.data)) == 0)

        # Overlapping fields: the expected width is the sum of field widths.
        a = ArrayDataSet(data=arr, fields={"x": slice(2), "y": slice(1, 4)})
        a_arr = numpy.asarray(a)
        self.assertEqual(a_arr.shape[1], 2 + 3)

    def test_minibatch_wraparound_even(self):
        """8 batches of 2 over 10 rows match rows of ``matcat(arr, arr)``."""
        arr = numpy.random.rand(10, 4)
        arr2 = ArrayDataSet.Iterator.matcat(arr, arr)
        a = ArrayDataSet(data=arr, fields={"x": slice(2), "y": slice(1, 4)})
        batches = a.minibatches(["x"], minibatch_size=2, n_batches=8)
        for i, x in enumerate(batches):
            self.assertTrue(numpy.all(x == arr2[i*2:i*2+2, 0:2]))

    def test_minibatch_wraparound_odd(self):
        """6 batches of 3 over 10 rows match rows of ``matcat(arr, arr)``."""
        arr = numpy.random.rand(10, 4)
        arr2 = ArrayDataSet.Iterator.matcat(arr, arr)
        a = ArrayDataSet(data=arr, fields={"x": slice(2), "y": slice(1, 4)})
        batches = a.minibatches(["x"], minibatch_size=3, n_batches=6)
        for i, x in enumerate(batches):
            self.assertTrue(numpy.all(x == arr2[i*3:i*3+3, 0:2]))


if __name__ == '__main__':
    unittest.main()