comparison test_dataset.py @ 165:2a12e7437c56

small refactoring
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Mon, 12 May 2008 17:51:28 -0400
parents 3518710e16ec
children c704a66706fe
comparison
equal deleted inserted replaced
164:3518710e16ec 165:2a12e7437c56
31 print "minibatch=",minibatch 31 print "minibatch=",minibatch
32 for var in minibatch: 32 for var in minibatch:
33 print "var=",var 33 print "var=",var
34 print "take a slice and look at field y",ds[1:6:2]["y"] 34 print "take a slice and look at field y",ds[1:6:2]["y"]
35 35
36 del a,ds,x,y,minibatch_iterator,minibatch,var
37
36 def test_iterate_over_examples(array,ds): 38 def test_iterate_over_examples(array,ds):
37 #not in doc!!! 39 #not in doc!!!
38 i=0 40 i=0
39 for example in range(len(ds)): 41 for example in range(len(ds)):
40 assert (ds[example]['x']==array[example][:3]).all() 42 assert (ds[example]['x']==array[example][:3]).all()
331 # example==ds[i] 333 # example==ds[i]
332 # i+=1 334 # i+=1
333 # del i,example 335 # del i,example
334 #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])#???? 336 #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])#????
335 337
338 def test_fields_fct(ds):
339 #@todo, fill correctly
340 assert len(ds.fields())==3
341 for field in ds.fields():
342 for field_value in field: # iterate over the values associated to that field for all the ds examples
343 pass
344 for field in ds('x','z').fields():
345 pass
346 for field in ds.fields('x','y'):
347 pass
348 for field_examples in ds.fields():
349 for example_value in field_examples:
350 pass
351
352 assert ds == ds.fields().examples()
353
354
336 355
337 def test_ArrayDataSet(): 356 def test_ArrayDataSet():
338 #don't test stream 357 #don't test stream
339 #tested only with float value 358 #tested only with float value
340 #don't always test with y 359 #don't always test with y
351 test_iterate_over_examples(a2, ds) 370 test_iterate_over_examples(a2, ds)
352 test_getitem(a2, ds) 371 test_getitem(a2, ds)
353 372
354 # - for val1,val2,val3 in dataset(field1, field2,field3): 373 # - for val1,val2,val3 in dataset(field1, field2,field3):
355 test_ds_iterator(a2,ds('x','y'),ds('y','z'),ds('x','y','z')) 374 test_ds_iterator(a2,ds('x','y'),ds('y','z'),ds('x','y','z'))
356 375 test_fields_fct(ds)
357 376 del a2, ds
358 assert len(ds.fields())==3
359 for field in ds.fields():
360 for field_value in field: # iterate over the values associated to that field for all the ds examples
361 pass
362 for field in ds('x','z').fields():
363 pass
364 for field in ds.fields('x','y'):
365 pass
366 for field_examples in ds.fields():
367 for example_value in field_examples:
368 pass
369
370 assert ds == ds.fields().examples()
371 # for ((x,y),a_v) in (ds('x','y'),a): #???don't work # haven't found a variant that work.# will not work
372 # assert numpy.append(x,y)==z
373 377
374 def test_LookupList(): 378 def test_LookupList():
375 #test only the example in the doc??? 379 #test only the example in the doc???
376 print "test_LookupList" 380 print "test_LookupList"
377 example = LookupList(['x','y','z'],[1,2,3]) 381 example = LookupList(['x','y','z'],[1,2,3])
387 example2 = LookupList(['v','w'], ['a','b']) 391 example2 = LookupList(['v','w'], ['a','b'])
388 example3 = LookupList(['x','y','z','u','v','w'], [[1, 2, 3],2,3,0,'a','b']) 392 example3 = LookupList(['x','y','z','u','v','w'], [[1, 2, 3],2,3,0,'a','b'])
389 assert example+example2==example3 393 assert example+example2==example3
390 assert have_raised("var['x']+var['x']",x=example) 394 assert have_raised("var['x']+var['x']",x=example)
391 395
396 del example, example2, example3, x, y ,z
397
392 def test_CachedDataSet(): 398 def test_CachedDataSet():
393 print "test_CacheDataSet" 399 print "test_CacheDataSet"
394 a2 = numpy.random.rand(10,4) 400 a2 = numpy.random.rand(10,4)
395 ds1 = ArrayDataSet(a2,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested 401 ds1 = ArrayDataSet(a2,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested
396 ds2 = CachedDataSet(ds1) 402 ds2 = CachedDataSet(ds1)
400 test_iterate_over_examples(a2, ds2) 406 test_iterate_over_examples(a2, ds2)
401 test_getitem(a2, ds2) 407 test_getitem(a2, ds2)
402 408
403 # - for val1,val2,val3 in dataset(field1, field2,field3): 409 # - for val1,val2,val3 in dataset(field1, field2,field3):
404 test_ds_iterator(a2,ds2('x','y'),ds2('y','z'),ds2('x','y','z')) 410 test_ds_iterator(a2,ds2('x','y'),ds2('y','z'),ds2('x','y','z'))
405 411 test_fields_fct(ds2)
406 412
407 assert len(ds2.fields())==3 413 del a2,ds1,ds2,ds3
408 for field in ds2.fields():
409 for field_value in field: # iterate over the values associated to that field for all the ds examples
410 pass
411 for field in ds2('x','z').fields():
412 pass
413 for field in ds2.fields('x','y'):
414 pass
415 for field_examples in ds2.fields():
416 for example_value in field_examples:
417 pass
418
419 assert ds2 == ds2.fields().examples()
420 # for ((x,y),a_v) in (ds('x','y'),a): #???don't work # haven't found a variant that work.# will not work
421 # assert numpy.append(x,y)==z
422 414
423 415
424 def test_DataSetFields(): 416 def test_DataSetFields():
425 print "test_DataSetFields" 417 print "test_DataSetFields"
426 raise NotImplementedError() 418 raise NotImplementedError()