Mercurial > pylearn
annotate _test_dataset.py @ 17:759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
author | bergstrj@iro.umontreal.ca |
---|---|
date | Wed, 26 Mar 2008 21:05:14 -0400 |
parents | be128b9127c8 |
children | 57f4015e2e09 |
rev | line source |
---|---|
4 | 1 from dataset import * |
2 from math import * | |
3 import unittest | |
4 | |
def _sum_all(a):
    """Return the sum of every element of ``a``.

    A ``numpy.ndarray`` is reduced one leading axis at a time with the builtin
    ``sum`` until a value that is no longer an ndarray (normally a numpy
    scalar) remains.  Any non-ndarray input is returned unchanged.

    Parameters
    ----------
    a : numpy.ndarray or scalar
        Value to reduce.

    Returns
    -------
    The total of all elements (for an empty array, the builtin ``sum``
    start value ``0``).
    """
    s = a
    while isinstance(s, numpy.ndarray):
        if s.ndim == 0:
            # The builtin sum() cannot iterate a 0-d array (it raises
            # TypeError); extract its single value as a numpy scalar instead.
            s = s[()]
        else:
            s = sum(s)
    return s
10 | |
11 class T_arraydataset(unittest.TestCase): | |
12 def setUp(self): | |
13 numpy.random.seed(123456) | |
14 | |
17
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
15 |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
16 def test_ctor_len(self): |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
17 n = numpy.random.rand(8,3) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
18 a=ArrayDataSet(n) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
19 self.failUnless(a.data is n) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
20 self.failUnless(a.fields is None) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
21 |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
22 self.failUnless(len(a) == n.shape[0]) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
23 self.failUnless(a[0].shape == (n.shape[1],)) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
24 |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
25 def test_iter(self): |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
26 arr = numpy.random.rand(8,3) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
27 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,3)}) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
28 for i, example in enumerate(a): |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
29 self.failUnless(numpy.all( example.x == arr[i,:2])) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
30 self.failUnless(numpy.all( example.y == arr[i,1:3])) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
31 |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
32 def test_zip(self): |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
33 arr = numpy.random.rand(8,3) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
34 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,3)}) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
35 for i, x in enumerate(a.zip("x")): |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
36 self.failUnless(numpy.all( x == arr[i,:2])) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
37 |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
38 def test_minibatch_basic(self): |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
39 arr = numpy.random.rand(10,4) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
40 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)}) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
41 for i, mb in enumerate(a.minibatches(minibatch_size=2)): #all fields |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
42 self.failUnless(numpy.all( mb.x == arr[i*2:i*2+2,0:2])) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
43 self.failUnless(numpy.all( mb.y == arr[i*2:i*2+2,1:4])) |
7
6f8f338686db
Moved iterating counter into a FiniteDataSetIterator to allow embedded iterations and multiple threads iterating at the same time on a dataset.
bengioy@bengiomac.local
parents:
6
diff
changeset
|
44 |
17
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
45 def test_getattr(self): |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
46 arr = numpy.random.rand(10,4) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
47 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)}) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
48 a_y = a.y |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
49 self.failUnless(numpy.all( a_y == arr[:,1:4])) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
50 |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
51 def test_asarray(self): |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
52 arr = numpy.random.rand(3,4) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
53 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)}) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
54 a_arr = numpy.asarray(a) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
55 self.failUnless(a_arr.shape[1] == 2 + 3) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
56 |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
57 def test_minibatch_wraparound_even(self): |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
58 arr = numpy.random.rand(10,4) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
59 arr2 = ArrayDataSet.Iterator.matcat(arr,arr) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
60 |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
61 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)}) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
62 |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
63 #print arr |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
64 for i, x in enumerate(a.minibatches(["x"], minibatch_size=2, n_batches=8)): |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
65 #print 'x' , x |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
66 self.failUnless(numpy.all( x == arr2[i*2:i*2+2,0:2])) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
67 |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
68 def test_minibatch_wraparound_odd(self): |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
69 arr = numpy.random.rand(10,4) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
70 arr2 = ArrayDataSet.Iterator.matcat(arr,arr) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
71 |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
72 a=ArrayDataSet(data=arr,fields={"x":slice(2),"y":slice(1,4)}) |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
73 |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
74 for i, x in enumerate(a.minibatches(["x"], minibatch_size=3, n_batches=6)): |
759d17112b23
more comments, looping ArrayDataSet iterator, bugfixes to lookup_list, more tests
bergstrj@iro.umontreal.ca
parents:
11
diff
changeset
|
75 self.failUnless(numpy.all( x == arr2[i*3:i*3+3,0:2])) |
7
6f8f338686db
Moved iterating counter into a FiniteDataSetIterator to allow embedded iterations and multiple threads iterating at the same time on a dataset.
bengioy@bengiomac.local
parents:
6
diff
changeset
|
76 |
if __name__ == "__main__":
    # Discover and run every TestCase defined in this module.
    unittest.main(module="__main__")
79 |