Mercurial > pylearn
comparison test_dataset.py @ 145:933db7ece663
make some functions global so they can be reused to test other datasets
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Mon, 12 May 2008 15:35:18 -0400 |
parents | 0c6fec172ae1 |
children | a5329e719229 |
comparison
equal
deleted
inserted
replaced
144:ceae4de18981 | 145:933db7ece663 |
---|---|
1 #!/bin/env python | 1 #!/bin/env python |
2 from dataset import * | 2 from dataset import * |
3 from math import * | 3 from math import * |
4 import numpy | 4 import numpy |
5 | 5 |
6 def have_raised(to_eval): | 6 def have_raised(to_eval, **var): |
7 | |
7 have_thrown = False | 8 have_thrown = False |
8 try: | 9 try: |
9 eval(to_eval) | 10 eval(to_eval) |
10 except : | 11 except : |
11 have_thrown = True | 12 have_thrown = True |
30 print "minibatch=",minibatch | 31 print "minibatch=",minibatch |
31 for var in minibatch: | 32 for var in minibatch: |
32 print "var=",var | 33 print "var=",var |
33 print "take a slice and look at field y",ds[1:6:2]["y"] | 34 print "take a slice and look at field y",ds[1:6:2]["y"] |
34 | 35 |
36 def test_iterate_over_examples(array,ds): | |
37 #not in doc!!! | |
38 i=0 | |
39 for example in range(len(ds)): | |
40 assert (ds[example]['x']==array[example][:3]).all() | |
41 assert ds[example]['y']==array[example][3] | |
42 assert (ds[example]['z']==array[example][[0,2]]).all() | |
43 i+=1 | |
44 assert i==len(ds) | |
45 del example,i | |
46 | |
47 # - for example in dataset: | |
48 i=0 | |
49 for example in ds: | |
50 assert len(example)==3 | |
51 assert (example['x']==array[i][:3]).all() | |
52 assert example['y']==array[i][3] | |
53 assert (example['z']==array[i][0:3:2]).all() | |
54 assert (numpy.append(example['x'],example['y'])==array[i]).all() | |
55 i+=1 | |
56 assert i==len(ds) | |
57 del example,i | |
58 | |
59 # - for val1,val2,... in dataset: | |
60 i=0 | |
61 for x,y,z in ds: | |
62 assert (x==array[i][:3]).all() | |
63 assert y==array[i][3] | |
64 assert (z==array[i][0:3:2]).all() | |
65 assert (numpy.append(x,y)==array[i]).all() | |
66 i+=1 | |
67 assert i==len(ds) | |
68 del x,y,z,i | |
69 | |
70 # - for example in dataset(field1, field2,field3, ...): | |
71 i=0 | |
72 for example in ds('x','y','z'): | |
73 assert len(example)==3 | |
74 assert (example['x']==array[i][:3]).all() | |
75 assert example['y']==array[i][3] | |
76 assert (example['z']==array[i][0:3:2]).all() | |
77 assert (numpy.append(example['x'],example['y'])==array[i]).all() | |
78 i+=1 | |
79 assert i==len(ds) | |
80 del example,i | |
81 i=0 | |
82 for example in ds('y','x'): | |
83 assert len(example)==2 | |
84 assert (example['x']==array[i][:3]).all() | |
85 assert example['y']==array[i][3] | |
86 assert (numpy.append(example['x'],example['y'])==array[i]).all() | |
87 i+=1 | |
88 assert i==len(ds) | |
89 del example,i | |
90 | |
91 # - for val1,val2,val3 in dataset(field1, field2,field3): | |
92 i=0 | |
93 for x,y,z in ds('x','y','z'): | |
94 assert (x==array[i][:3]).all() | |
95 assert y==array[i][3] | |
96 assert (z==array[i][0:3:2]).all() | |
97 assert (numpy.append(x,y)==array[i]).all() | |
98 i+=1 | |
99 assert i==len(ds) | |
100 del x,y,z,i | |
101 i=0 | |
102 for y,x in ds('y','x',): | |
103 assert (x==array[i][:3]).all() | |
104 assert y==array[i][3] | |
105 assert (numpy.append(x,y)==array[i]).all() | |
106 i+=1 | |
107 assert i==len(ds) | |
108 del x,y,i | |
109 | |
110 def test_minibatch_size(minibatch,minibatch_size,len_ds,nb_field,nb_iter_finished): | |
111 ##full minibatch or the last minibatch | |
112 for idx in range(nb_field): | |
113 test_minibatch_field_size(minibatch[idx],minibatch_size,len_ds,nb_iter_finished) | |
114 del idx | |
115 def test_minibatch_field_size(minibatch_field,minibatch_size,len_ds,nb_iter_finished): | |
116 assert len(minibatch_field)==minibatch_size or ((nb_iter_finished*minibatch_size+len(minibatch_field))==len_ds and len(minibatch_field)<minibatch_size) | |
117 | |
118 # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): | |
119 i=0 | |
120 mi=0 | |
121 m=ds.minibatches(['x','z'], minibatch_size=3) | |
122 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) | |
123 for minibatch in m: | |
124 assert len(minibatch)==2 | |
125 test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) | |
126 assert (minibatch[0][:,0:3:2]==minibatch[1]).all() | |
127 mi+=1 | |
128 i+=len(minibatch[0]) | |
129 assert i==len(ds) | |
130 assert mi==4 | |
131 del minibatch,i,m,mi | |
132 | |
133 i=0 | |
134 mi=0 | |
135 m=ds.minibatches(['x','y'], minibatch_size=3) | |
136 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) | |
137 for minibatch in m: | |
138 assert len(minibatch)==2 | |
139 test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) | |
140 mi+=1 | |
141 for id in range(len(minibatch[0])): | |
142 assert (numpy.append(minibatch[0][id],minibatch[1][id])==array[i]).all() | |
143 i+=1 | |
144 assert i==len(ds) | |
145 assert mi==4 | |
146 del minibatch,i,id,m,mi | |
147 | |
148 # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): | |
149 i=0 | |
150 mi=0 | |
151 m=ds.minibatches(['x','z'], minibatch_size=3) | |
152 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) | |
153 for x,z in m: | |
154 test_minibatch_field_size(x,m.minibatch_size,len(ds),mi) | |
155 test_minibatch_field_size(z,m.minibatch_size,len(ds),mi) | |
156 assert (x[:,0:3:2]==z).all() | |
157 i+=len(x) | |
158 mi+=1 | |
159 assert i==len(ds) | |
160 assert mi==4 | |
161 del x,z,i,m,mi | |
162 i=0 | |
163 mi=0 | |
164 m=ds.minibatches(['x','y'], minibatch_size=3) | |
165 for x,y in m: | |
166 test_minibatch_field_size(x,m.minibatch_size,len(ds),mi) | |
167 test_minibatch_field_size(y,m.minibatch_size,len(ds),mi) | |
168 mi+=1 | |
169 for id in range(len(x)): | |
170 assert (numpy.append(x[id],y[id])==array[i]).all() | |
171 i+=1 | |
172 assert i==len(ds) | |
173 assert mi==4 | |
174 del x,y,i,id,m,mi | |
175 | |
176 #not in doc | |
177 i=0 | |
178 m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4) | |
179 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) | |
180 for x,y in m: | |
181 assert len(x)==3 | |
182 assert len(y)==3 | |
183 for id in range(3): | |
184 assert (numpy.append(x[id],y[id])==array[i+4]).all() | |
185 i+=1 | |
186 assert i==3 | |
187 del x,y,i,id,m | |
188 | |
189 i=0 | |
190 m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4) | |
191 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) | |
192 for x,y in m: | |
193 assert len(x)==3 | |
194 assert len(y)==3 | |
195 for id in range(3): | |
196 assert (numpy.append(x[id],y[id])==array[i+4]).all() | |
197 i+=1 | |
198 assert i==6 | |
199 del x,y,i,id,m | |
200 | |
201 i=0 | |
202 m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=3,offset=4) | |
203 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) | |
204 for x,y in m: | |
205 assert len(x)==3 | |
206 assert len(y)==3 | |
207 for id in range(3): | |
208 assert (numpy.append(x[id],y[id])==array[(i+4)%array.shape[0]]).all() | |
209 i+=1 | |
210 assert i==m.n_batches*m.minibatch_size | |
211 del x,y,i,id | |
212 | |
213 | |
214 def test_ds_iterator(array,iterator1,iterator2,iterator3): | |
215 l=len(iterator1) | |
216 i=0 | |
217 for x,y in iterator1: | |
218 assert (x==array[i][:3]).all() | |
219 assert y==array[i][3] | |
220 assert (numpy.append(x,y)==array[i]).all() | |
221 i+=1 | |
222 assert i==l | |
223 i=0 | |
224 for y,z in iterator2: | |
225 assert y==array[i][3] | |
226 assert (z==array[i][0:3:2]).all() | |
227 i+=1 | |
228 assert i==l | |
229 i=0 | |
230 for x,y,z in iterator3: | |
231 assert (x==array[i][:3]).all() | |
232 assert y==array[i][3] | |
233 assert (z==array[i][0:3:2]).all() | |
234 assert (numpy.append(x,y)==array[i]).all() | |
235 i+=1 | |
236 assert i==l | |
237 | |
238 def test_getitem(array,ds): | |
239 def test_ds(orig,ds,index): | |
240 i=0 | |
241 assert len(ds)==len(index) | |
242 for x,z,y in ds('x','z','y'): | |
243 assert (orig[index[i]]['x']==array[index[i]][:3]).all() | |
244 assert (orig[index[i]]['x']==x).all() | |
245 assert orig[index[i]]['y']==array[index[i]][3] | |
246 assert orig[index[i]]['y']==y | |
247 assert (orig[index[i]]['z']==array[index[i]][0:3:2]).all() | |
248 assert (orig[index[i]]['z']==z).all() | |
249 i+=1 | |
250 del i | |
251 ds[0] | |
252 if len(ds)>2: | |
253 ds[:1] | |
254 ds[1:1] | |
255 ds[1:1:1] | |
256 if len(ds)>5: | |
257 ds[[1,2,3]] | |
258 for x in ds: | |
259 pass | |
260 | |
261 #ds[:n] returns a dataset with the n first examples. | |
262 ds2=ds[:3] | |
263 assert isinstance(ds2,DataSet) | |
264 test_ds(ds,ds2,index=[0,1,2]) | |
265 del ds2 | |
266 | |
267 #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s. | |
268 ds2=ds[1:7:2] | |
269 assert isinstance(ds2,DataSet) | |
270 test_ds(ds,ds2,[1,3,5]) | |
271 del ds2 | |
272 | |
273 #ds[i] | |
274 ds2=ds[5] | |
275 assert isinstance(ds2,Example) | |
276 assert have_raised("var['ds']["+str(len(ds))+"]",ds=ds) # index not defined | |
277 assert not have_raised("var['ds']["+str(len(ds)-1)+"]",ds=ds) | |
278 del ds2 | |
279 | |
280 #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in. | |
281 ds2=ds[[4,7,2,8]] | |
282 assert isinstance(ds2,DataSet) | |
283 test_ds(ds,ds2,[4,7,2,8]) | |
284 del ds2 | |
285 | |
286 #ds[fieldname]# an iterable over the values of the field fieldname across | |
287 #the ds (the iterable is obtained by default by calling valuesVStack | |
288 #over the values for individual examples). | |
289 assert have_raised("ds['h']") # h is not defined... | |
290 assert have_raised("ds[['x']]") # bad syntax | |
291 assert not have_raised("var['ds']['x']",ds=ds) | |
292 isinstance(ds['x'],DataSetFields) | |
293 ds2=ds['x'] | |
294 assert len(ds['x'])==10 | |
295 assert len(ds['y'])==10 | |
296 assert len(ds['z'])==10 | |
297 i=0 | |
298 for example in ds['x']: | |
299 assert (example==array[i][:3]).all() | |
300 i+=1 | |
301 i=0 | |
302 for example in ds['y']: | |
303 assert (example==array[i][3]).all() | |
304 i+=1 | |
305 i=0 | |
306 for example in ds['z']: | |
307 assert (example==array[i,0:3:2]).all() | |
308 i+=1 | |
309 del ds2,i | |
310 | |
311 #ds.<property># returns the value of a property associated with | |
312 #the name <property>. The following properties should be supported: | |
313 # - 'description': a textual description or name for the ds | |
314 # - 'fieldtypes': a list of types (one per field) | |
315 | |
316 #* ds1 | ds2 | ds3 == ds.hstack([ds1,ds2,ds3])#???? | |
317 #hstack([ds('x','y'),ds('z')] | |
318 #hstack([ds('z','y'),ds('x')] | |
319 #assert have_thrown("hstack([ds('x'),ds('x')]") | |
320 #assert not have_thrown("hstack([ds('x'),ds('x')]") | |
321 #accept_nonunique_names | |
322 #assert have_thrown("hstack([ds('y','x'),ds('x')]") | |
323 # i=0 | |
324 # for example in hstack([ds('x'),ds('y'),ds('z')]): | |
325 # example==ds[i] | |
326 # i+=1 | |
327 # del i,example | |
328 #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])#???? | |
329 | |
330 | |
35 def test_ArrayDataSet(): | 331 def test_ArrayDataSet(): |
36 #don't test stream | 332 #don't test stream |
37 #tested only with float value | 333 #tested only with float value |
38 #don't always test with y | 334 #don't always test with y |
39 #don't test missing value | 335 #don't test missing value |
40 #don't test with tuple | 336 #don't test with tuple |
41 #don't test proterties | 337 #don't test proterties |
42 def test_iterate_over_examples(array,ds): | |
43 #not in doc!!! | |
44 i=0 | |
45 for example in range(len(ds)): | |
46 assert (ds[example]['x']==a[example][:3]).all() | |
47 assert ds[example]['y']==a[example][3] | |
48 assert (ds[example]['z']==a[example][[0,2]]).all() | |
49 i+=1 | |
50 assert i==len(ds) | |
51 del example,i | |
52 | |
53 # - for example in dataset: | |
54 i=0 | |
55 for example in ds: | |
56 assert len(example)==3 | |
57 assert (example['x']==array[i][:3]).all() | |
58 assert example['y']==array[i][3] | |
59 assert (example['z']==array[i][0:3:2]).all() | |
60 assert (numpy.append(example['x'],example['y'])==array[i]).all() | |
61 i+=1 | |
62 assert i==len(ds) | |
63 del example,i | |
64 | |
65 # - for val1,val2,... in dataset: | |
66 i=0 | |
67 for x,y,z in ds: | |
68 assert (x==array[i][:3]).all() | |
69 assert y==array[i][3] | |
70 assert (z==array[i][0:3:2]).all() | |
71 assert (numpy.append(x,y)==array[i]).all() | |
72 i+=1 | |
73 assert i==len(ds) | |
74 del x,y,z,i | |
75 | |
76 # - for example in dataset(field1, field2,field3, ...): | |
77 i=0 | |
78 for example in ds('x','y','z'): | |
79 assert len(example)==3 | |
80 assert (example['x']==array[i][:3]).all() | |
81 assert example['y']==array[i][3] | |
82 assert (example['z']==array[i][0:3:2]).all() | |
83 assert (numpy.append(example['x'],example['y'])==array[i]).all() | |
84 i+=1 | |
85 assert i==len(ds) | |
86 del example,i | |
87 i=0 | |
88 for example in ds('y','x'): | |
89 assert len(example)==2 | |
90 assert (example['x']==array[i][:3]).all() | |
91 assert example['y']==array[i][3] | |
92 assert (numpy.append(example['x'],example['y'])==array[i]).all() | |
93 i+=1 | |
94 assert i==len(ds) | |
95 del example,i | |
96 | |
97 # - for val1,val2,val3 in dataset(field1, field2,field3): | |
98 i=0 | |
99 for x,y,z in ds('x','y','z'): | |
100 assert (x==array[i][:3]).all() | |
101 assert y==array[i][3] | |
102 assert (z==array[i][0:3:2]).all() | |
103 assert (numpy.append(x,y)==array[i]).all() | |
104 i+=1 | |
105 assert i==len(ds) | |
106 del x,y,z,i | |
107 i=0 | |
108 for y,x in ds('y','x',): | |
109 assert (x==array[i][:3]).all() | |
110 assert y==array[i][3] | |
111 assert (numpy.append(x,y)==array[i]).all() | |
112 i+=1 | |
113 assert i==len(ds) | |
114 del x,y,i | |
115 | |
116 def test_minibatch_size(minibatch,minibatch_size,len_ds,nb_field,nb_iter_finished): | |
117 ##full minibatch or the last minibatch | |
118 for idx in range(nb_field): | |
119 test_minibatch_field_size(minibatch[idx],minibatch_size,len_ds,nb_iter_finished) | |
120 del idx | |
121 def test_minibatch_field_size(minibatch_field,minibatch_size,len_ds,nb_iter_finished): | |
122 assert len(minibatch_field)==minibatch_size or ((nb_iter_finished*minibatch_size+len(minibatch_field))==len_ds and len(minibatch_field)<minibatch_size) | |
123 | |
124 # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): | |
125 i=0 | |
126 mi=0 | |
127 m=ds.minibatches(['x','z'], minibatch_size=3) | |
128 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) | |
129 for minibatch in m: | |
130 assert len(minibatch)==2 | |
131 test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) | |
132 assert (minibatch[0][:,0:3:2]==minibatch[1]).all() | |
133 mi+=1 | |
134 i+=len(minibatch[0]) | |
135 assert i==len(ds) | |
136 assert mi==4 | |
137 del minibatch,i,m,mi | |
138 | |
139 i=0 | |
140 mi=0 | |
141 m=ds.minibatches(['x','y'], minibatch_size=3) | |
142 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) | |
143 for minibatch in m: | |
144 assert len(minibatch)==2 | |
145 test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) | |
146 mi+=1 | |
147 for id in range(len(minibatch[0])): | |
148 assert (numpy.append(minibatch[0][id],minibatch[1][id])==a[i]).all() | |
149 i+=1 | |
150 assert i==len(ds) | |
151 assert mi==4 | |
152 del minibatch,i,id,m,mi | |
153 | |
154 # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): | |
155 i=0 | |
156 mi=0 | |
157 m=ds.minibatches(['x','z'], minibatch_size=3) | |
158 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) | |
159 for x,z in m: | |
160 test_minibatch_field_size(x,m.minibatch_size,len(ds),mi) | |
161 test_minibatch_field_size(z,m.minibatch_size,len(ds),mi) | |
162 assert (x[:,0:3:2]==z).all() | |
163 i+=len(x) | |
164 mi+=1 | |
165 assert i==len(ds) | |
166 assert mi==4 | |
167 del x,z,i,m,mi | |
168 i=0 | |
169 mi=0 | |
170 m=ds.minibatches(['x','y'], minibatch_size=3) | |
171 for x,y in m: | |
172 test_minibatch_field_size(x,m.minibatch_size,len(ds),mi) | |
173 test_minibatch_field_size(y,m.minibatch_size,len(ds),mi) | |
174 mi+=1 | |
175 for id in range(len(x)): | |
176 assert (numpy.append(x[id],y[id])==a[i]).all() | |
177 i+=1 | |
178 assert i==len(ds) | |
179 assert mi==4 | |
180 del x,y,i,id,m,mi | |
181 | |
182 #not in doc | |
183 i=0 | |
184 m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4) | |
185 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) | |
186 for x,y in m: | |
187 assert len(x)==3 | |
188 assert len(y)==3 | |
189 for id in range(3): | |
190 assert (numpy.append(x[id],y[id])==a[i+4]).all() | |
191 i+=1 | |
192 assert i==3 | |
193 del x,y,i,id,m | |
194 | |
195 i=0 | |
196 m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4) | |
197 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) | |
198 for x,y in m: | |
199 assert len(x)==3 | |
200 assert len(y)==3 | |
201 for id in range(3): | |
202 assert (numpy.append(x[id],y[id])==a[i+4]).all() | |
203 i+=1 | |
204 assert i==6 | |
205 del x,y,i,id,m | |
206 | |
207 i=0 | |
208 m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=3,offset=4) | |
209 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) | |
210 for x,y in m: | |
211 assert len(x)==3 | |
212 assert len(y)==3 | |
213 for id in range(3): | |
214 assert (numpy.append(x[id],y[id])==a[(i+4)%a.shape[0]]).all() | |
215 i+=1 | |
216 assert i==m.n_batches*m.minibatch_size | |
217 del x,y,i,id | |
218 | |
219 | |
220 def test_ds_iterator(array,iterator1,iterator2,iterator3): | |
221 i=0 | |
222 for x,y in iterator1: | |
223 assert (x==array[i][:3]).all() | |
224 assert y==array[i][3] | |
225 assert (numpy.append(x,y)==array[i]).all() | |
226 i+=1 | |
227 assert i==len(ds) | |
228 i=0 | |
229 for y,z in iterator2: | |
230 assert y==array[i][3] | |
231 assert (z==array[i][0:3:2]).all() | |
232 i+=1 | |
233 assert i==len(ds) | |
234 i=0 | |
235 for x,y,z in iterator3: | |
236 assert (x==array[i][:3]).all() | |
237 assert y==array[i][3] | |
238 assert (z==array[i][0:3:2]).all() | |
239 assert (numpy.append(x,y)==array[i]).all() | |
240 i+=1 | |
241 assert i==len(ds) | |
242 | |
243 def test_getitem(array,ds): | |
244 | |
245 def test_ds(orig,ds,index): | |
246 i=0 | |
247 assert len(ds)==len(index) | |
248 for x,z,y in ds('x','z','y'): | |
249 assert (orig[index[i]]['x']==array[index[i]][:3]).all() | |
250 assert (orig[index[i]]['x']==x).all() | |
251 assert orig[index[i]]['y']==array[index[i]][3] | |
252 assert orig[index[i]]['y']==y | |
253 assert (orig[index[i]]['z']==array[index[i]][0:3:2]).all() | |
254 assert (orig[index[i]]['z']==z).all() | |
255 i+=1 | |
256 del i | |
257 ds[0] | |
258 if len(ds)>2: | |
259 ds[:1] | |
260 ds[1:1] | |
261 ds[1:1:1] | |
262 if len(ds)>5: | |
263 ds[[1,2,3]] | |
264 for x in ds: | |
265 pass | |
266 | |
267 #ds[:n] returns a dataset with the n first examples. | |
268 ds2=ds[:3] | |
269 assert isinstance(ds2,DataSet) | |
270 test_ds(ds,ds2,index=[0,1,2]) | |
271 del ds2 | |
272 | |
273 #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s. | |
274 ds2=ds[1:7:2] | |
275 assert isinstance(ds2,DataSet) | |
276 test_ds(ds,ds2,[1,3,5]) | |
277 del ds2 | |
278 | |
279 #ds[i] | |
280 ds2=ds[5] | |
281 assert isinstance(ds2,Example) | |
282 assert have_raised("ds["+str(len(ds))+"]") # index not defined | |
283 assert not have_raised("ds["+str(len(ds)-1)+"]") | |
284 del ds2 | |
285 | |
286 #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in. | |
287 ds2=ds[[4,7,2,8]] | |
288 assert isinstance(ds2,DataSet) | |
289 test_ds(ds,ds2,[4,7,2,8]) | |
290 del ds2 | |
291 | |
292 #ds[fieldname]# an iterable over the values of the field fieldname across | |
293 #the ds (the iterable is obtained by default by calling valuesVStack | |
294 #over the values for individual examples). | |
295 assert have_raised("ds['h']") # h is not defined... | |
296 assert have_raised("ds[['x']]") # bad syntax | |
297 assert not have_raised("ds['x']") | |
298 isinstance(ds['x'],DataSetFields) | |
299 ds2=ds['x'] | |
300 assert len(ds['x'])==10 | |
301 assert len(ds['y'])==10 | |
302 assert len(ds['z'])==10 | |
303 i=0 | |
304 for example in ds['x']: | |
305 assert (example==a[i][:3]).all() | |
306 i+=1 | |
307 i=0 | |
308 for example in ds['y']: | |
309 assert (example==a[i][3]).all() | |
310 i+=1 | |
311 i=0 | |
312 for example in ds['z']: | |
313 assert (example==a[i,0:3:2]).all() | |
314 i+=1 | |
315 del ds2,i | |
316 | |
317 #ds.<property># returns the value of a property associated with | |
318 #the name <property>. The following properties should be supported: | |
319 # - 'description': a textual description or name for the ds | |
320 # - 'fieldtypes': a list of types (one per field) | |
321 | |
322 #* ds1 | ds2 | ds3 == ds.hstack([ds1,ds2,ds3])#???? | |
323 #hstack([ds('x','y'),ds('z')] | |
324 #hstack([ds('z','y'),ds('x')] | |
325 #assert have_thrown("hstack([ds('x'),ds('x')]") | |
326 #assert not have_thrown("hstack([ds('x'),ds('x')]") | |
327 #accept_nonunique_names | |
328 #assert have_thrown("hstack([ds('y','x'),ds('x')]") | |
329 # i=0 | |
330 # for example in hstack([ds('x'),ds('y'),ds('z')]): | |
331 # example==ds[i] | |
332 # i+=1 | |
333 # del i,example | |
334 #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])#???? | |
335 | |
336 | |
337 print "test_ArrayDataSet" | 338 print "test_ArrayDataSet" |
338 a = numpy.random.rand(10,4) | 339 a2 = numpy.random.rand(10,4) |
339 ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested | 340 ds = ArrayDataSet(a2,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested |
340 ds = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested | 341 ds = ArrayDataSet(a2,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested |
341 assert len(ds)==10 | 342 assert len(ds)==10 |
342 #assert ds==a? should this work? | 343 #assert ds==a? should this work? |
343 | 344 |
344 test_iterate_over_examples(a, ds) | 345 test_iterate_over_examples(a2, ds) |
345 test_getitem(a, ds) | 346 test_getitem(a2, ds) |
346 | 347 |
347 # - for val1,val2,val3 in dataset(field1, field2,field3): | 348 # - for val1,val2,val3 in dataset(field1, field2,field3): |
348 test_ds_iterator(a,ds('x','y'),ds('y','z'),ds('x','y','z')) | 349 test_ds_iterator(a2,ds('x','y'),ds('y','z'),ds('x','y','z')) |
349 | 350 |
350 | 351 |
351 assert len(ds.fields())==3 | 352 assert len(ds.fields())==3 |
352 for field in ds.fields(): | 353 for field in ds.fields(): |
353 for field_value in field: # iterate over the values associated to that field for all the ds examples | 354 for field_value in field: # iterate over the values associated to that field for all the ds examples |
378 example.append_keyval('u',0) # adds item with name 'u' and value 0 | 379 example.append_keyval('u',0) # adds item with name 'u' and value 0 |
379 assert len(example)==4 # number of items = 4 here | 380 assert len(example)==4 # number of items = 4 here |
380 example2 = LookupList(['v','w'], ['a','b']) | 381 example2 = LookupList(['v','w'], ['a','b']) |
381 example3 = LookupList(['x','y','z','u','v','w'], [[1, 2, 3],2,3,0,'a','b']) | 382 example3 = LookupList(['x','y','z','u','v','w'], [[1, 2, 3],2,3,0,'a','b']) |
382 assert example+example2==example3 | 383 assert example+example2==example3 |
383 assert have_raised("example+example") | 384 assert have_raised("var['x']+var['x']",x=example) |
384 | 385 |
385 def test_ApplyFunctionDataSet(): | 386 def test_ApplyFunctionDataSet(): |
386 print "test_ApplyFunctionDataSet" | 387 print "test_ApplyFunctionDataSet" |
387 raise NotImplementedError() | 388 raise NotImplementedError() |
388 def test_CacheDataSet(): | 389 def test_CacheDataSet(): |