comparison _test_dataset.py @ 21:fdf0abc490f7

Adapted _test_dataset.py to changes in LookupList
author bengioy@bengiomac.local
date Mon, 07 Apr 2008 19:32:52 -0400
parents 57f4015e2e09
children b6b36f65664f
comparison
equal deleted inserted replaced
20:266c68cb6136 21:fdf0abc490f7
24 24
25 def test_iter(self): 25 def test_iter(self):
26 arr = numpy.random.rand(8,3) 26 arr = numpy.random.rand(8,3)
27 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,3)}) 27 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,3)})
28 for i, example in enumerate(a): 28 for i, example in enumerate(a):
29 self.failUnless(numpy.all( example.x == arr[i,:2])) 29 self.failUnless(numpy.all( example['x'] == arr[i,:2]))
30 self.failUnless(numpy.all( example.y == arr[i,1:3])) 30 self.failUnless(numpy.all( example['y'] == arr[i,1:3]))
31 31
32 def test_zip(self): 32 def test_zip(self):
33 arr = numpy.random.rand(8,3) 33 arr = numpy.random.rand(8,3)
34 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,3)}) 34 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,3)})
35 for i, x in enumerate(a.zip("x")): 35 for i, x in enumerate(a.zip("x")):
37 37
38 def test_minibatch_basic(self): 38 def test_minibatch_basic(self):
39 arr = numpy.random.rand(10,4) 39 arr = numpy.random.rand(10,4)
40 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)}) 40 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)})
41 for i, mb in enumerate(a.minibatches(minibatch_size=2)): #all fields 41 for i, mb in enumerate(a.minibatches(minibatch_size=2)): #all fields
42 self.failUnless(numpy.all( mb.x == arr[i*2:i*2+2,0:2])) 42 self.failUnless(numpy.all( mb['x'] == arr[i*2:i*2+2,0:2]))
43 self.failUnless(numpy.all( mb.y == arr[i*2:i*2+2,1:4])) 43 self.failUnless(numpy.all( mb['y'] == arr[i*2:i*2+2,1:4]))
44 44
45 def test_getattr(self): 45 def test_getattr(self):
46 arr = numpy.random.rand(10,4) 46 arr = numpy.random.rand(10,4)
47 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)}) 47 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)})
48 a_y = a.y 48 a_y = a.y
51 def test_asarray(self): 51 def test_asarray(self):
52 arr = numpy.random.rand(3,4) 52 arr = numpy.random.rand(3,4)
53 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)}) 53 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)})
54 a_arr = numpy.asarray(a) 54 a_arr = numpy.asarray(a)
55 self.failUnless(a_arr.shape[1] == 2 + 3) 55 self.failUnless(a_arr.shape[1] == 2 + 3)
56 self.failUnless(a_arr == arr)
56 57
57 def test_minibatch_wraparound_even(self): 58 def test_minibatch_wraparound_even(self):
58 arr = numpy.random.rand(10,4) 59 arr = numpy.random.rand(10,4)
59 arr2 = ArrayDataSet.Iterator.matcat(arr,arr) 60 arr2 = ArrayDataSet.Iterator.matcat(arr,arr)
60 61