Mercurial > pylearn
comparison test_dataset.py @ 84:aa9e786ee849
added function have_raised, which evaluates the string passed as its parameter and returns True if evaluating it raised an exception
code cleanup and renaming
added a few tests
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Mon, 05 May 2008 11:49:40 -0400 |
parents | 158653a9bc7c |
children | fdf72ea4f2bc |
comparison
equal
deleted
inserted
replaced
83:c0f211213a58 | 84:aa9e786ee849 |
---|---|
1 #!/bin/env python | 1 #!/bin/env python |
2 from dataset import * | 2 from dataset import * |
3 from math import * | 3 from math import * |
4 import numpy | 4 import numpy |
5 | 5 |
6 def have_raised(to_eval): | |
7 have_thrown = False | |
8 try: | |
9 eval(to_eval) | |
10 except : | |
11 have_thrown = True | |
12 return have_thrown | |
13 | |
6 def test1(): | 14 def test1(): |
15 print "test1" | |
7 global a,ds | 16 global a,ds |
8 a = numpy.random.rand(10,4) | 17 a = numpy.random.rand(10,4) |
9 print a | 18 print a |
10 ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]}) | 19 ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]}) |
11 print "len(ds)=",len(ds) | 20 print "len(ds)=",len(ds) |
27 #don't test stream | 36 #don't test stream |
28 #tested only with float value | 37 #tested only with float value |
29 #test with y too | 38 #test with y too |
30 #test missing value | 39 #test missing value |
31 | 40 |
41 print "test_ArrayDataSet" | |
32 a = numpy.random.rand(10,4) | 42 a = numpy.random.rand(10,4) |
33 print a | |
34 ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]}) | 43 ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]}) |
35 assert len(ds)==10 | 44 assert len(ds)==10 |
36 #assert ds==a? should this work? | 45 #assert ds==a? should this work? |
37 for i in range(len(ds)): | 46 for i in range(len(ds)): |
38 assert ds[i]['x'].all()==a[i][:2].all() | 47 assert ds[i]['x'].all()==a[i][:2].all() |
39 assert ds[i]['y']==a[i][3] | 48 assert ds[i]['y']==a[i][3] |
40 assert ds[i]['z'].all()==a[i][0:3:2].all() | 49 assert ds[i]['z'].all()==a[i][0:3:2].all() |
41 print "x=",ds["x"] | |
42 i=0 | 50 i=0 |
43 for x in ds('x','y'): | 51 for x in ds('x','y'): |
44 assert numpy.append(x['x'],x['y']).all()==a[i].all() | 52 assert numpy.append(x['x'],x['y']).all()==a[i].all() |
45 i+=1 | 53 i+=1 |
46 | 54 |
59 # minibatch = minibatch_iterator.__iter__().next() | 67 # minibatch = minibatch_iterator.__iter__().next() |
60 # print "minibatch=",minibatch | 68 # print "minibatch=",minibatch |
61 # for var in minibatch: | 69 # for var in minibatch: |
62 # print "var=",var | 70 # print "var=",var |
63 # print "take a slice and look at field y",ds[1:6:2]["y"] | 71 # print "take a slice and look at field y",ds[1:6:2]["y"] |
64 have_thrown = False | 72 assert have_raised("ds['h']") # h is not defined... |
65 try: | 73 assert have_raised("ds[['h']]") # h is not defined... |
66 ds['h'] # h is not defined... | |
67 except : | |
68 have_thrown = True | |
69 assert have_thrown == True | |
70 | |
71 have_thrown = False | |
72 try: | |
73 ds[['h']] # h is not defined... | |
74 except : | |
75 have_thrown = True | |
76 assert have_thrown == True | |
77 | 74 |
78 assert len(ds.fields())==3 | 75 assert len(ds.fields())==3 |
79 for field in ds.fields(): | 76 for field in ds.fields(): |
80 for field_value in field: # iterate over the values associated to that field for all the ds examples | 77 for field_value in field: # iterate over the values associated to that field for all the ds examples |
81 pass | 78 pass |
91 | 88 |
92 | 89 |
93 #ds[:n] returns a dataset with the n first examples. | 90 #ds[:n] returns a dataset with the n first examples. |
94 assert len(ds[:3])==3 | 91 assert len(ds[:3])==3 |
95 i=0 | 92 i=0 |
93 for x,z,y in ds[:3]('x','z','y'): | |
94 assert ds[i]['x'].all()==a[i][:3].all() | |
95 assert ds[i]['x'].all()==x.all() | |
96 assert ds[i]['y']==a[i][3] | |
97 assert ds[i]['y']==y | |
98 assert ds[i]['z'].all()==a[i][0:3:2].all() | |
99 assert ds[i]['z'].all()==z.all() | |
100 i+=1 | |
101 i=0 | |
96 for x,z in ds[:3]('x','z'): | 102 for x,z in ds[:3]('x','z'): |
97 assert ds[i]['z'].all()==a[i][0:3:2].all() | 103 assert ds[i]['z'].all()==a[i][0:3:2].all() |
98 i+=1 | 104 i+=1 |
99 | 105 |
100 #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s. | 106 #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s. |
101 ds[1:7:2][1] #fail??? | 107 ds[1:7:2][1] |
102 assert len(ds[1:7:2])==3 # should be number example 1,3 and 5 | 108 assert len(ds[1:7:2])==3 # should be number example 1,3 and 5 |
103 i=0 | 109 i=0 |
104 for x,z in ds[1:7:2]('x','z'): | 110 index=[1,3,5] |
105 assert ds[i]['z'].all()==a[i][0:3:2].all() | 111 for z,y,x in ds[1:7:2]('z','y','x'): |
112 assert ds[index[i]]['x'].all()==a[index[i]][:3].all() | |
113 assert ds[index[i]]['x'].all()==x.all() | |
114 assert ds[index[i]]['y']==a[index[i]][3] | |
115 assert ds[index[i]]['y']==y | |
116 assert ds[index[i]]['z'].all()==a[index[i]][0:3:2].all() | |
117 assert ds[index[i]]['z'].all()==z.all() | |
106 i+=1 | 118 i+=1 |
107 ds2=ds[1:7:2] | 119 |
108 for i in range(len(ds2)): | |
109 print ds2[i] | |
110 #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in. | 120 #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in. |
111 i=0 | 121 i=0 |
112 for x in ds[[1,2]]: | 122 for x in ds[[1,2]]: |
113 assert numpy.append(x['x'],x['y']).all()==a[i].all() | 123 assert numpy.append(x['x'],x['y']).all()==a[i].all() |
114 i+=1 | 124 i+=1 |
127 # for (x,y) in (ds('x','y'),a): #don't work # haven't found a variant that work. | 137 # for (x,y) in (ds('x','y'),a): #don't work # haven't found a variant that work. |
128 # assert numpy.append(x,y)==z | 138 # assert numpy.append(x,y)==z |
129 | 139 |
130 def test_LookupList(): | 140 def test_LookupList(): |
131 #test only the example in the doc??? | 141 #test only the example in the doc??? |
132 example = LookupList(['x','y','z'],[1,2,3]) | 142 print "test_LookupList" |
133 example['x'] = [1, 2, 3] # set or change a field | 143 example = LookupList(['x','y','z'],[1,2,3]) |
134 x, y, z = example | 144 example['x'] = [1, 2, 3] # set or change a field |
135 x = example[0] | 145 x, y, z = example |
136 x = example["x"] | 146 x = example[0] |
137 assert example.keys()==['x','y','z'] | 147 x = example["x"] |
138 assert example.values()==[[1,2,3],2,3] | 148 assert example.keys()==['x','y','z'] |
139 assert example.items()==[('x',[1,2,3]),('y',2),('z',3)] | 149 assert example.values()==[[1,2,3],2,3] |
140 example.append_keyval('u',0) # adds item with name 'u' and value 0 | 150 assert example.items()==[('x',[1,2,3]),('y',2),('z',3)] |
141 assert len(example)==4 # number of items = 4 here | 151 example.append_keyval('u',0) # adds item with name 'u' and value 0 |
142 example2 = LookupList(['v','w'], ['a','b']) | 152 assert len(example)==4 # number of items = 4 here |
143 example3 = LookupList(['x','y','z','u','v','w'], [[1, 2, 3],2,3,0,'a','b']) | 153 example2 = LookupList(['v','w'], ['a','b']) |
144 print example3 | 154 example3 = LookupList(['x','y','z','u','v','w'], [[1, 2, 3],2,3,0,'a','b']) |
145 print example+example2 | 155 assert example+example2==example3 |
146 print example+example2 | 156 assert have_raised("example+example") |
147 assert example+example2==example3 | |
148 have_throw=False | |
149 try: | |
150 example+example | |
151 except: | |
152 have_throw=True | |
153 assert have_throw | |
154 | 157 |
155 def ApplyFunctionDataSet(): | 158 def test_ApplyFunctionDataSet(): |
159 print "test_ApplyFunctionDataSet" | |
156 raise NotImplementedError() | 160 raise NotImplementedError() |
157 def CacheDataSet(): | 161 def test_CacheDataSet(): |
162 print "test_CacheDataSet" | |
158 raise NotImplementedError() | 163 raise NotImplementedError() |
159 def FieldsSubsetDataSet(): | 164 def test_FieldsSubsetDataSet(): |
165 print "test_FieldsSubsetDataSet" | |
160 raise NotImplementedError() | 166 raise NotImplementedError() |
161 def DataSetFields(): | 167 def test_DataSetFields(): |
168 print "test_DataSetFields" | |
162 raise NotImplementedError() | 169 raise NotImplementedError() |
163 def MinibatchDataSet(): | 170 def test_MinibatchDataSet(): |
171 print "test_MinibatchDataSet" | |
164 raise NotImplementedError() | 172 raise NotImplementedError() |
165 def HStackedDataSet(): | 173 def test_HStackedDataSet(): |
174 print "test_HStackedDataSet" | |
166 raise NotImplementedError() | 175 raise NotImplementedError() |
167 def VStackedDataSet(): | 176 def test_VStackedDataSet(): |
177 print "test_VStackedDataSet" | |
168 raise NotImplementedError() | 178 raise NotImplementedError() |
169 def ArrayFieldsDataSet(): | 179 def test_ArrayFieldsDataSet(): |
180 print "test_ArrayFieldsDataSet" | |
170 raise NotImplementedError() | 181 raise NotImplementedError() |
171 | 182 |
172 test1() | 183 test1() |
173 test_LookupList() | 184 test_LookupList() |
174 test_ArrayDataSet() | 185 test_ArrayDataSet() |