pylearn: comparison _test_dataset.py @ 376:c9a89be5cb0a
Redesigning linear_regression
author | Yoshua Bengio <bengioy@iro.umontreal.ca> |
date | Mon, 07 Jul 2008 10:08:35 -0400 |
parents | 18702ceb2096 |
children | 82da179d95b2 |
375:12ce29abf27d | 376:c9a89be5cb0a |
---|---|
1 #!/bin/env python | 1 #!/bin/env python |
2 from dataset import * | 2 from dataset import * |
3 from math import * | 3 from math import * |
4 import numpy, unittest, sys | 4 import numpy, unittest, sys |
5 from misc import * | 5 #from misc import * |
6 from lookup_list import LookupList | 6 from lookup_list import LookupList |
7 | 7 |
8 def have_raised(to_eval, **var): | 8 def have_raised(to_eval, **var): |
9 have_thrown = False | 9 have_thrown = False |
10 try: | 10 try: |
132 assert len(minibatch_field)==minibatch_size or ((nb_iter_finished*minibatch_size+len(minibatch_field))==len_ds and len(minibatch_field)<minibatch_size) | 132 assert len(minibatch_field)==minibatch_size or ((nb_iter_finished*minibatch_size+len(minibatch_field))==len_ds and len(minibatch_field)<minibatch_size) |
133 | 133 |
134 # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): | 134 # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): |
135 i=0 | 135 i=0 |
136 mi=0 | 136 mi=0 |
137 m=ds.minibatches(['x','z'], minibatch_size=3) | 137 size=3 |
138 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) | 138 m=ds.minibatches(['x','z'], minibatch_size=size) |
| 139 assert hasattr(m,'__iter__') |
139 for minibatch in m: | 140 for minibatch in m: |
140 assert isinstance(minibatch,DataSetFields) | 141 assert isinstance(minibatch,LookupList) |
141 assert len(minibatch)==2 | 142 assert len(minibatch)==2 |
142 test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) | 143 test_minibatch_size(minibatch,size,len(ds),2,mi) |
143 if type(ds)==ArrayDataSet: | 144 if type(ds)==ArrayDataSet: |
144 assert (minibatch[0][:,::2]==minibatch[1]).all() | 145 assert (minibatch[0][:,::2]==minibatch[1]).all() |
145 else: | 146 else: |
146 for j in xrange(len(minibatch[0])): | 147 for j in xrange(len(minibatch[0])): |
147 (minibatch[0][j][::2]==minibatch[1][j]).all() | 148 (minibatch[0][j][::2]==minibatch[1][j]).all() |
148 mi+=1 | 149 mi+=1 |
149 i+=len(minibatch[0]) | 150 i+=len(minibatch[0]) |
150 assert i==len(ds) | 151 assert i==(len(ds)/size)*size |
151 assert mi==4 | 152 assert mi==(len(ds)/size) |
152 del minibatch,i,m,mi | 153 del minibatch,i,m,mi,size |
153 | 154 |
154 i=0 | 155 i=0 |
155 mi=0 | 156 mi=0 |
156 m=ds.minibatches(['x','y'], minibatch_size=3) | 157 size=3 |
157 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) | 158 m=ds.minibatches(['x','y'], minibatch_size=size) |
| 159 assert hasattr(m,'__iter__') |
158 for minibatch in m: | 160 for minibatch in m: |
| 161 assert isinstance(minibatch,LookupList) |
159 assert len(minibatch)==2 | 162 assert len(minibatch)==2 |
160 test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) | 163 test_minibatch_size(minibatch,size,len(ds),2,mi) |
161 mi+=1 | 164 mi+=1 |
162 for id in range(len(minibatch[0])): | 165 for id in range(len(minibatch[0])): |
163 assert (numpy.append(minibatch[0][id],minibatch[1][id])==array[i]).all() | 166 assert (numpy.append(minibatch[0][id],minibatch[1][id])==array[i]).all() |
164 i+=1 | 167 i+=1 |
165 assert i==len(ds) | 168 assert i==(len(ds)/size)*size |
166 assert mi==4 | 169 assert mi==(len(ds)/size) |
167 del minibatch,i,id,m,mi | 170 del minibatch,i,id,m,mi,size |
168 | 171 |
169 # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): | 172 # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): |
170 i=0 | 173 i=0 |
171 mi=0 | 174 mi=0 |
172 m=ds.minibatches(['x','z'], minibatch_size=3) | 175 size=3 |
173 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) | 176 m=ds.minibatches(['x','z'], minibatch_size=size) |
| 177 assert hasattr(m,'__iter__') |
174 for x,z in m: | 178 for x,z in m: |
175 test_minibatch_field_size(x,m.minibatch_size,len(ds),mi) | 179 test_minibatch_field_size(x,size,len(ds),mi) |
176 test_minibatch_field_size(z,m.minibatch_size,len(ds),mi) | 180 test_minibatch_field_size(z,size,len(ds),mi) |
177 for id in range(len(x)): | 181 for id in range(len(x)): |
178 assert (x[id][::2]==z[id]).all() | 182 assert (x[id][::2]==z[id]).all() |
179 i+=1 | 183 i+=1 |
180 mi+=1 | 184 mi+=1 |
181 assert i==len(ds) | 185 assert i==(len(ds)/size)*size |
182 assert mi==4 | 186 assert mi==(len(ds)/size) |
183 del x,z,i,m,mi | 187 del x,z,i,m,mi,size |
| 188 |
184 i=0 | 189 i=0 |
185 mi=0 | 190 mi=0 |
| 191 size=3 |
186 m=ds.minibatches(['x','y'], minibatch_size=3) | 192 m=ds.minibatches(['x','y'], minibatch_size=3) |
| 193 assert hasattr(m,'__iter__') |
187 for x,y in m: | 194 for x,y in m: |
188 test_minibatch_field_size(x,m.minibatch_size,len(ds),mi) | 195 assert len(x)==size |
189 test_minibatch_field_size(y,m.minibatch_size,len(ds),mi) | 196 assert len(y)==size |
| 197 test_minibatch_field_size(x,size,len(ds),mi) |
| 198 test_minibatch_field_size(y,size,len(ds),mi) |
190 mi+=1 | 199 mi+=1 |
191 for id in range(len(x)): | 200 for id in range(len(x)): |
192 assert (numpy.append(x[id],y[id])==array[i]).all() | 201 assert (numpy.append(x[id],y[id])==array[i]).all() |
193 i+=1 | 202 i+=1 |
194 assert i==len(ds) | 203 assert i==(len(ds)/size)*size |
195 assert mi==4 | 204 assert mi==(len(ds)/size) |
196 del x,y,i,id,m,mi | 205 del x,y,i,id,m,mi,size |
197 | 206 |
198 #not in doc | 207 #not in doc |
199 i=0 | 208 i=0 |
200 m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4) | 209 size=3 |
201 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) | 210 m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=size,offset=4) |
| 211 assert hasattr(m,'__iter__') |
202 for x,y in m: | 212 for x,y in m: |
203 assert len(x)==m.minibatch_size | 213 assert len(x)==size |
204 assert len(y)==m.minibatch_size | 214 assert len(y)==size |
205 for id in range(m.minibatch_size): | 215 for id in range(size): |
206 assert (numpy.append(x[id],y[id])==array[i+4]).all() | 216 assert (numpy.append(x[id],y[id])==array[i+4]).all() |
207 i+=1 | 217 i+=1 |
208 assert i==m.n_batches*m.minibatch_size | 218 assert i==size |
209 del x,y,i,id,m | 219 del x,y,i,id,m,size |
210 | 220 |
211 i=0 | 221 i=0 |
212 m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4) | 222 size=3 |
213 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) | 223 m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=size,offset=4) |
| 224 assert hasattr(m,'__iter__') |
214 for x,y in m: | 225 for x,y in m: |
215 assert len(x)==m.minibatch_size | 226 assert len(x)==size |
216 assert len(y)==m.minibatch_size | 227 assert len(y)==size |
217 for id in range(m.minibatch_size): | 228 for id in range(size): |
218 assert (numpy.append(x[id],y[id])==array[i+4]).all() | 229 assert (numpy.append(x[id],y[id])==array[i+4]).all() |
219 i+=1 | 230 i+=1 |
220 assert i==m.n_batches*m.minibatch_size | 231 assert i==2*size |
221 del x,y,i,id,m | 232 del x,y,i,id,m,size |
222 | 233 |
223 i=0 | 234 i=0 |
224 m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=3,offset=4) | 235 size=3 |
225 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) | 236 m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=size,offset=4) |
| 237 assert hasattr(m,'__iter__') |
226 for x,y in m: | 238 for x,y in m: |
227 assert len(x)==m.minibatch_size | 239 assert len(x)==size |
228 assert len(y)==m.minibatch_size | 240 assert len(y)==size |
229 for id in range(m.minibatch_size): | 241 for id in range(size): |
230 assert (numpy.append(x[id],y[id])==array[(i+4)%array.shape[0]]).all() | 242 assert (numpy.append(x[id],y[id])==array[(i+4)%array.shape[0]]).all() |
231 i+=1 | 243 i+=1 |
232 assert i==m.n_batches*m.minibatch_size | 244 assert i==2*size # should not wrap |
233 del x,y,i,id | 245 del x,y,i,id,size |
234 | 246 |
235 assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array)+1,offset=0) | 247 assert have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array)+1,offset=0) |
236 assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array),offset=0) | 248 assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array),offset=0) |
237 | 249 |
238 def test_ds_iterator(array,iterator1,iterator2,iterator3): | 250 def test_ds_iterator(array,iterator1,iterator2,iterator3): |
239 l=len(iterator1) | 251 l=len(iterator1) |
240 i=0 | 252 i=0 |
260 assert i==l | 272 assert i==l |
261 | 273 |
262 def test_getitem(array,ds): | 274 def test_getitem(array,ds): |
263 def test_ds(orig,ds,index): | 275 def test_ds(orig,ds,index): |
264 i=0 | 276 i=0 |
265 assert len(ds)==len(index) | 277 assert isinstance(ds,LookupList) |
266 for x,z,y in ds('x','z','y'): | 278 assert len(ds)==3 |
267 assert (orig[index[i]]['x']==array[index[i]][:3]).all() | 279 assert len(ds[0])==len(index) |
268 assert (orig[index[i]]['x']==x).all() | 280 # for x,z,y in ds('x','z','y'): |
269 assert orig[index[i]]['y']==array[index[i]][3] | 281 for idx in index: |
270 assert (orig[index[i]]['y']==y).all() # why does it crash sometimes? | 282 assert (orig[idx]['x']==array[idx][:3]).all() |
271 assert (orig[index[i]]['z']==array[index[i]][0:3:2]).all() | 283 assert (orig[idx]['x']==ds['x'][i]).all() |
272 assert (orig[index[i]]['z']==z).all() | 284 assert orig[idx]['y']==array[idx][3] |
| 285 assert (orig[idx]['y']==ds['y'][i]).all() # why does it crash sometimes? |
| 286 assert (orig[idx]['z']==array[idx][0:3:2]).all() |
| 287 assert (orig[idx]['z']==ds['z'][i]).all() |
273 i+=1 | 288 i+=1 |
274 del i | 289 del i |
275 ds[0] | 290 ds[0] |
276 if len(ds)>2: | 291 if len(ds)>2: |
277 ds[:1] | 292 ds[:1] |
280 if len(ds)>5: | 295 if len(ds)>5: |
281 ds[[1,2,3]] | 296 ds[[1,2,3]] |
282 for x in ds: | 297 for x in ds: |
283 pass | 298 pass |
284 | 299 |
285 #ds[:n] returns a dataset with the n first examples. | 300 #ds[:n] returns a LookupList with the n first examples. |
286 ds2=ds[:3] | 301 ds2=ds[:3] |
287 assert isinstance(ds2,LookupList) | |
288 test_ds(ds,ds2,index=[0,1,2]) | 302 test_ds(ds,ds2,index=[0,1,2]) |
289 del ds2 | 303 del ds2 |
290 | 304 |
291 #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s. | 305 #ds[i:j] returns a LookupList with examples i,i+1,...,j-1. |
292 ds2=ds.subset[1:7:2] | 306 ds2=ds[1:3] |
293 assert isinstance(ds2,DataSet) | 307 test_ds(ds,ds2,index=[1,2]) |
| 308 del ds2 |
| 309 |
| 310 #ds[i1:i2:s] returns a LookupList with the examples i1,i1+s,...i2-s. |
| 311 ds2=ds[1:7:2] |
294 test_ds(ds,ds2,[1,3,5]) | 312 test_ds(ds,ds2,[1,3,5]) |
295 del ds2 | 313 del ds2 |
296 | 314 |
297 #ds[i] | 315 #ds[i] returns the (i+1)-th example of the dataset. |
298 ds2=ds[5] | 316 ds2=ds[5] |
299 assert isinstance(ds2,Example) | 317 assert isinstance(ds2,Example) |
300 assert have_raised("var['ds']["+str(len(ds))+"]",ds=ds) # index not defined | 318 assert have_raised("var['ds']["+str(len(ds))+"]",ds=ds) # index not defined |
301 assert not have_raised("var['ds']["+str(len(ds)-1)+"]",ds=ds) | 319 assert not have_raised("var['ds']["+str(len(ds)-1)+"]",ds=ds) |
302 del ds2 | 320 del ds2 |
303 | 321 |
304 #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in. | 322 #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in. |
305 ds2=ds.subset[[4,7,2,8]] | 323 ds2=ds[[4,7,2,8]] |
306 assert isinstance(ds2,DataSet) | 324 # assert isinstance(ds2,DataSet) |
307 test_ds(ds,ds2,[4,7,2,8]) | 325 test_ds(ds,ds2,[4,7,2,8]) |
308 del ds2 | 326 del ds2 |
309 | 327 |
310 #ds.<property># returns the value of a property associated with | 328 #ds.<property># returns the value of a property associated with |
311 #the name <property>. The following properties should be supported: | 329 #the name <property>. The following properties should be supported: |
323 # for example in hstack([ds('x'),ds('y'),ds('z')]): | 341 # for example in hstack([ds('x'),ds('y'),ds('z')]): |
324 # example==ds[i] | 342 # example==ds[i] |
325 # i+=1 | 343 # i+=1 |
326 # del i,example | 344 # del i,example |
327 #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])#???? | 345 #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])#???? |
| 346 |
| 347 def test_subset(array,ds): |
| 348 def test_ds(orig,ds,index): |
| 349 i=0 |
| 350 assert isinstance(ds2,DataSet) |
| 351 assert len(ds)==len(index) |
| 352 for x,z,y in ds('x','z','y'): |
| 353 assert (orig[index[i]]['x']==array[index[i]][:3]).all() |
| 354 assert (orig[index[i]]['x']==x).all() |
| 355 assert orig[index[i]]['y']==array[index[i]][3] |
| 356 assert orig[index[i]]['y']==y |
| 357 assert (orig[index[i]]['z']==array[index[i]][0:3:2]).all() |
| 358 assert (orig[index[i]]['z']==z).all() |
| 359 i+=1 |
| 360 del i |
| 361 ds[0] |
| 362 if len(ds)>2: |
| 363 ds[:1] |
| 364 ds[1:1] |
| 365 ds[1:1:1] |
| 366 if len(ds)>5: |
| 367 ds[[1,2,3]] |
| 368 for x in ds: |
| 369 pass |
| 370 |
| 371 #ds[:n] returns a dataset with the n first examples. |
| 372 ds2=ds.subset[:3] |
| 373 test_ds(ds,ds2,index=[0,1,2]) |
| 374 # del ds2 |
| 375 |
| 376 #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s. |
| 377 ds2=ds.subset[1:7:2] |
| 378 test_ds(ds,ds2,[1,3,5]) |
| 379 # del ds2 |
| 380 |
| 381 # #ds[i] |
| 382 # ds2=ds.subset[5] |
| 383 # assert isinstance(ds2,Example) |
| 384 # assert have_raised("var['ds']["+str(len(ds))+"]",ds=ds) # index not defined |
| 385 # assert not have_raised("var['ds']["+str(len(ds)-1)+"]",ds=ds) |
| 386 # del ds2 |
| 387 |
| 388 #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in. |
| 389 ds2=ds.subset[[4,7,2,8]] |
| 390 test_ds(ds,ds2,[4,7,2,8]) |
| 391 # del ds2 |
| 392 |
| 393 #ds.<property># returns the value of a property associated with |
| 394 #the name <property>. The following properties should be supported: |
| 395 # - 'description': a textual description or name for the ds |
| 396 # - 'fieldtypes': a list of types (one per field) |
| 397 |
| 398 #* ds1 | ds2 | ds3 == ds.hstack([ds1,ds2,ds3])#???? |
| 399 #assert hstack([ds('x','y'),ds('z')])==ds |
| 400 #hstack([ds('z','y'),ds('x')])==ds |
| 401 assert have_raised2(hstack,[ds('x'),ds('x')]) |
| 402 assert have_raised2(hstack,[ds('y','x'),ds('x')]) |
| 403 assert not have_raised2(hstack,[ds('x'),ds('y')]) |
| 404 |
| 405 # i=0 |
| 406 # for example in hstack([ds('x'),ds('y'),ds('z')]): |
| 407 # example==ds[i] |
| 408 # i+=1 |
| 409 # del i,example |
| 410 #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])#???? |
328 | 411 |
329 def test_fields_fct(ds): | 412 def test_fields_fct(ds): |
330 #@todo, fill correctly | 413 #@todo, fill correctly |
331 assert len(ds.fields())==3 | 414 assert len(ds.fields())==3 |
332 i=0 | 415 i=0 |
453 def test_all(array,ds): | 536 def test_all(array,ds): |
454 assert len(ds)==10 | 537 assert len(ds)==10 |
455 test_iterate_over_examples(array, ds) | 538 test_iterate_over_examples(array, ds) |
456 test_overrides(ds) | 539 test_overrides(ds) |
457 test_getitem(array, ds) | 540 test_getitem(array, ds) |
| 541 test_subset(array, ds) |
458 test_ds_iterator(array,ds('x','y'),ds('y','z'),ds('x','y','z')) | 542 test_ds_iterator(array,ds('x','y'),ds('y','z'),ds('x','y','z')) |
459 test_fields_fct(ds) | 543 test_fields_fct(ds) |
460 | 544 |
461 | 545 |
462 class T_DataSet(unittest.TestCase): | 546 class T_DataSet(unittest.TestCase): |
508 | 592 |
509 def test_FieldsSubsetDataSet(self): | 593 def test_FieldsSubsetDataSet(self): |
510 a = numpy.random.rand(10,4) | 594 a = numpy.random.rand(10,4) |
511 ds = ArrayDataSet(a,Example(['x','y','z','w'],[slice(3),3,[0,2],0])) | 595 ds = ArrayDataSet(a,Example(['x','y','z','w'],[slice(3),3,[0,2],0])) |
512 ds = FieldsSubsetDataSet(ds,['x','y','z']) | 596 ds = FieldsSubsetDataSet(ds,['x','y','z']) |
| 597 |
| 598 test_all(a,ds) |
| 599 |
| 600 del a, ds |
| 601 |
| 602 def test_RenamedFieldsDataSet(self): |
| 603 a = numpy.random.rand(10,4) |
| 604 ds = ArrayDataSet(a,Example(['x1','y1','z1','w1'],[slice(3),3,[0,2],0])) |
| 605 ds = RenamedFieldsDataSet(ds,['x1','y1','z1'],['x','y','z']) |
513 | 606 |
514 test_all(a,ds) | 607 test_all(a,ds) |
515 | 608 |
516 del a, ds | 609 del a, ds |
517 | 610 |
568 for k in range(len(dsc)) : | 661 for k in range(len(dsc)) : |
569 self.failUnless(numpy.all( dsc[k]('input')[0] == ds[k]('input')[0] ) , (dsc[k],ds[k]) ) | 662 self.failUnless(numpy.all( dsc[k]('input')[0] == ds[k]('input')[0] ) , (dsc[k],ds[k]) ) |
570 res = dsc[:] | 663 res = dsc[:] |
571 | 664 |
572 if __name__=='__main__': | 665 if __name__=='__main__': |
573 if len(sys.argv)==2: | 666 tests = [] |
574 if sys.argv[1]=="--debug": | 667 debug=False |
| 668 if len(sys.argv)==1: |
| 669 unittest.main() |
| 670 else: |
| 671 assert sys.argv[1]=="--debug" |
| 672 for arg in sys.argv[2:]: |
| 673 tests.append(arg) |
| 674 if tests: |
| 675 unittest.TestSuite(map(T_DataSet, tests)).debug() |
| 676 else: |
575 module = __import__("_test_dataset") | 677 module = __import__("_test_dataset") |
576 tests = unittest.TestLoader().loadTestsFromModule(module) | 678 tests = unittest.TestLoader().loadTestsFromModule(module) |
577 tests.debug() | 679 tests.debug() |
578 print "bad argument: only --debug is accepted" | |
579 elif len(sys.argv)==1: | |
580 unittest.main() | |
581 else: | |
582 print "bad argument: only --debug is accepted" | |
583 | |
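
Most of the assertion changes above encode a revised iteration contract for `DataSet.minibatches`: when `n_batches` is not given, iteration no longer wraps around the end of the dataset and a trailing partial minibatch is dropped, so the 10-example datasets built in `T_DataSet` are expected to yield exactly `len(ds)/3 == 3` full minibatches of size 3. The sketch below is not part of the changeset; it only restates that contract against an `ArrayDataSet` built the same way as in `test_FieldsSubsetDataSet`, with names and behaviour assumed from the test code above.

```python
# Sketch only: restates the minibatch iteration contract that the updated
# assertions encode (full minibatches only, no wrap-around).
import numpy
from dataset import ArrayDataSet, Example  # names assumed from the test module's imports

a = numpy.random.rand(10, 4)
ds = ArrayDataSet(a, Example(['x', 'y', 'z', 'w'], [slice(3), 3, [0, 2], 0]))

size = 3
n_batches = 0
n_examples = 0
for x, y in ds.minibatches(['x', 'y'], minibatch_size=size):
    assert len(x) == size and len(y) == size   # every yielded minibatch is full
    n_batches += 1
    n_examples += len(x)

assert n_batches == len(ds) // size            # 10 // 3 == 3 minibatches
assert n_examples == (len(ds) // size) * size  # 9 examples; the partial batch is dropped
```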