Mercurial > pylearn
comparison _test_dataset.py @ 315:b48cf8dce2bf
test to compare an overridden __getitem__ implementation against DataSet.__getitem__, tested on ArrayDataSet.__getitem__
author | Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca> |
---|---|
date | Wed, 11 Jun 2008 16:26:41 -0400 |
parents | 96cca78de3ed |
children | 9c08e3af975e 4efb503fd0da |
comparison
equal
deleted
inserted
replaced
314:105b54ac8260 | 315:b48cf8dce2bf |
---|---|
265 assert len(ds)==len(index) | 265 assert len(ds)==len(index) |
266 for x,z,y in ds('x','z','y'): | 266 for x,z,y in ds('x','z','y'): |
267 assert (orig[index[i]]['x']==array[index[i]][:3]).all() | 267 assert (orig[index[i]]['x']==array[index[i]][:3]).all() |
268 assert (orig[index[i]]['x']==x).all() | 268 assert (orig[index[i]]['x']==x).all() |
269 assert orig[index[i]]['y']==array[index[i]][3] | 269 assert orig[index[i]]['y']==array[index[i]][3] |
270 assert (orig[index[i]]['y']==y).all() | 270 assert (orig[index[i]]['y']==y).all() # why does it crash sometimes? |
271 assert (orig[index[i]]['z']==array[index[i]][0:3:2]).all() | 271 assert (orig[index[i]]['z']==array[index[i]][0:3:2]).all() |
272 assert (orig[index[i]]['z']==z).all() | 272 assert (orig[index[i]]['z']==z).all() |
273 i+=1 | 273 i+=1 |
274 del i | 274 del i |
275 ds[0] | 275 ds[0] |
373 assert len(ds('x','y').fields()) == 2 | 373 assert len(ds('x','y').fields()) == 2 |
374 assert len(ds('x','z').fields()) == 2 | 374 assert len(ds('x','z').fields()) == 2 |
375 assert len(ds('y').fields()) == 1 | 375 assert len(ds('y').fields()) == 1 |
376 | 376 |
377 del field | 377 del field |
378 | |
def test_overrides(ds):
    """Check that an overridden ``__getitem__`` agrees with ``DataSet.__getitem__``.

    ``ds`` is wrapped (see ``mask``) in a subclass of its own type whose
    ``__getitem__`` first computes ``ds[key]``, the possibly overridden
    implementation of ``type(ds)``. It then asserts that this equals
    ``DataSet.__getitem__(ds, key)``, the base implementation, and returns
    the first result. The wrapper is indexed with every integer index of
    ``ds`` and with a strided slice. Any disagreement raises
    ``AssertionError``.
    """
    def ndarray_list_equal(nda, l):
        """Compare the ndarray ``nda`` with a non-ndarray container ``l``.

        ``l`` is converted with ``numpy.asarray``. ``asmatrix`` is not used
        because it forces two dimensions, so a 1-D ``nda`` could never
        match. Returns False when ``l`` cannot be converted.
        """
        try:
            l = numpy.asarray(l)
        except (TypeError, ValueError):
            # e.g. ragged nested lists, which recent NumPy refuses to convert
            return False
        return smart_equal(nda, l)

    def smart_equal(a1, a2):
        """Return True if ``a1`` and ``a2`` hold equal data.

        Handles the following, in order:
        - numpy.ndarray
        - LookupList
        - str/bytes
        - basic containers (anything with ``__len__`` and integer indexing)
        - anything else, compared with ``==``

        Objects of unrelated types compare unequal. The one exception is
        that an ndarray may be compared with any other container, which is
        converted to an array first.
        """
        if not isinstance(a1, type(a2)) and not isinstance(a2, type(a1)):
            # Special case: ndarray vs. list of arrays (or of scalars).
            if isinstance(a1, numpy.ndarray):
                return ndarray_list_equal(a1, a2)
            if isinstance(a2, numpy.ndarray):
                return ndarray_list_equal(a2, a1)
            return False
        # Compare two numpy.ndarray objects.
        if isinstance(a1, numpy.ndarray):
            # Require identical shapes so that broadcasting cannot hide a
            # mismatch.
            return a1.shape == a2.shape and bool((a1 == a2).all())
        # Compare two LookupList objects field by field.
        if isinstance(a1, LookupList):
            if len(a1._names) != len(a2._names):
                return False
            for k in a1._names:
                if k not in a2._names or not smart_equal(a1[k], a2[k]):
                    return False
            return True
        # This check must precede the container branch. Every character of
        # a string is itself a length-1 string, so element-wise recursion
        # would never terminate.
        if isinstance(a1, (str, bytes)):
            return a1 == a2
        # Compare two basic containers element-wise, by position.
        if hasattr(a1, '__len__'):
            if len(a1) != len(a2):
                return False
            for k in range(len(a1)):
                if not smart_equal(a1[k], a2[k]):
                    return False
            return True
        # Scalars are compared by value. An identity test (`is`) would
        # wrongly reject equal values that are distinct objects, e.g. two
        # numpy scalars produced by separate indexing operations.
        return bool(a1 == a2)

    def mask(ds):
        """Wrap ``ds`` so that every indexing also checks the base implementation."""
        class TestOverride(type(ds)):
            # Deliberately does not call the parent __init__: this wrapper
            # only forwards indexing to the wrapped dataset.
            def __init__(self, ds):
                self.ds = ds

            def __getitem__(self, key):
                res1 = self.ds[key]                        # overridden version
                res2 = DataSet.__getitem__(self.ds, key)   # reference version
                assert smart_equal(res1, res2), (
                    "overridden __getitem__ differs from DataSet.__getitem__ "
                    "for key %r" % (key,))
                return res1
        return TestOverride(ds)

    # Exercise __getitem__ with every integer index and with a strided slice.
    ds2 = mask(ds)
    for k in range(len(ds)):
        ds2[k]
    ds2[1:len(ds):3]
447 | |
448 | |
449 | |
450 | |
451 | |
452 | |
def test_all(array, ds):
    """Run each dataset check below on ``ds``.

    ``ds`` must contain exactly 10 examples. ``array`` is forwarded to the
    checks that take it.
    """
    assert len(ds) == 10

    test_iterate_over_examples(array, ds)
    test_overrides(ds)
    test_getitem(array, ds)

    # Field-restricted views handed to the iterator check.
    sub_xy = ds('x', 'y')
    sub_yz = ds('y', 'z')
    sub_xyz = ds('x', 'y', 'z')
    test_ds_iterator(array, sub_xy, sub_yz, sub_xyz)

    test_fields_fct(ds)
460 | |
385 | 461 |
386 class T_DataSet(unittest.TestCase): | 462 class T_DataSet(unittest.TestCase): |
387 def test_ArrayDataSet(self): | 463 def test_ArrayDataSet(self): |
388 #don't test stream | 464 #don't test stream |
389 #tested only with float value | 465 #tested only with float value |