Mercurial > pylearn
comparison test_dataset.py @ 165:2a12e7437c56
small refactoring
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Mon, 12 May 2008 17:51:28 -0400 |
parents | 3518710e16ec |
children | c704a66706fe |
comparison (legend: equal | deleted | inserted | replaced)
164:3518710e16ec | 165:2a12e7437c56 |
---|---|
31 print "minibatch=",minibatch | 31 print "minibatch=",minibatch |
32 for var in minibatch: | 32 for var in minibatch: |
33 print "var=",var | 33 print "var=",var |
34 print "take a slice and look at field y",ds[1:6:2]["y"] | 34 print "take a slice and look at field y",ds[1:6:2]["y"] |
35 | 35 |
36 del a,ds,x,y,minibatch_iterator,minibatch,var | |
37 | |
36 def test_iterate_over_examples(array,ds): | 38 def test_iterate_over_examples(array,ds): |
37 #not in doc!!! | 39 #not in doc!!! |
38 i=0 | 40 i=0 |
39 for example in range(len(ds)): | 41 for example in range(len(ds)): |
40 assert (ds[example]['x']==array[example][:3]).all() | 42 assert (ds[example]['x']==array[example][:3]).all() |
331 # example==ds[i] | 333 # example==ds[i] |
332 # i+=1 | 334 # i+=1 |
333 # del i,example | 335 # del i,example |
334 #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])#???? | 336 #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])#???? |
335 | 337 |
338 def test_fields_fct(ds): | |
339 #@todo, fill correctly | |
340 assert len(ds.fields())==3 | |
341 for field in ds.fields(): | |
342 for field_value in field: # iterate over the values associated to that field for all the ds examples | |
343 pass | |
344 for field in ds('x','z').fields(): | |
345 pass | |
346 for field in ds.fields('x','y'): | |
347 pass | |
348 for field_examples in ds.fields(): | |
349 for example_value in field_examples: | |
350 pass | |
351 | |
352 assert ds == ds.fields().examples() | |
353 | |
354 | |
336 | 355 |
337 def test_ArrayDataSet(): | 356 def test_ArrayDataSet(): |
338 #don't test stream | 357 #don't test stream |
339 #tested only with float value | 358 #tested only with float value |
340 #don't always test with y | 359 #don't always test with y |
351 test_iterate_over_examples(a2, ds) | 370 test_iterate_over_examples(a2, ds) |
352 test_getitem(a2, ds) | 371 test_getitem(a2, ds) |
353 | 372 |
354 # - for val1,val2,val3 in dataset(field1, field2,field3): | 373 # - for val1,val2,val3 in dataset(field1, field2,field3): |
355 test_ds_iterator(a2,ds('x','y'),ds('y','z'),ds('x','y','z')) | 374 test_ds_iterator(a2,ds('x','y'),ds('y','z'),ds('x','y','z')) |
356 | 375 test_fields_fct(ds) |
357 | 376 del a2, ds |
358 assert len(ds.fields())==3 | |
359 for field in ds.fields(): | |
360 for field_value in field: # iterate over the values associated to that field for all the ds examples | |
361 pass | |
362 for field in ds('x','z').fields(): | |
363 pass | |
364 for field in ds.fields('x','y'): | |
365 pass | |
366 for field_examples in ds.fields(): | |
367 for example_value in field_examples: | |
368 pass | |
369 | |
370 assert ds == ds.fields().examples() | |
371 # for ((x,y),a_v) in (ds('x','y'),a): #???don't work # haven't found a variant that work.# will not work | |
372 # assert numpy.append(x,y)==z | |
373 | 377 |
374 def test_LookupList(): | 378 def test_LookupList(): |
375 #test only the example in the doc??? | 379 #test only the example in the doc??? |
376 print "test_LookupList" | 380 print "test_LookupList" |
377 example = LookupList(['x','y','z'],[1,2,3]) | 381 example = LookupList(['x','y','z'],[1,2,3]) |
387 example2 = LookupList(['v','w'], ['a','b']) | 391 example2 = LookupList(['v','w'], ['a','b']) |
388 example3 = LookupList(['x','y','z','u','v','w'], [[1, 2, 3],2,3,0,'a','b']) | 392 example3 = LookupList(['x','y','z','u','v','w'], [[1, 2, 3],2,3,0,'a','b']) |
389 assert example+example2==example3 | 393 assert example+example2==example3 |
390 assert have_raised("var['x']+var['x']",x=example) | 394 assert have_raised("var['x']+var['x']",x=example) |
391 | 395 |
396 del example, example2, example3, x, y ,z | |
397 | |
392 def test_CachedDataSet(): | 398 def test_CachedDataSet(): |
393 print "test_CacheDataSet" | 399 print "test_CacheDataSet" |
394 a2 = numpy.random.rand(10,4) | 400 a2 = numpy.random.rand(10,4) |
395 ds1 = ArrayDataSet(a2,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested | 401 ds1 = ArrayDataSet(a2,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested |
396 ds2 = CachedDataSet(ds1) | 402 ds2 = CachedDataSet(ds1) |
400 test_iterate_over_examples(a2, ds2) | 406 test_iterate_over_examples(a2, ds2) |
401 test_getitem(a2, ds2) | 407 test_getitem(a2, ds2) |
402 | 408 |
403 # - for val1,val2,val3 in dataset(field1, field2,field3): | 409 # - for val1,val2,val3 in dataset(field1, field2,field3): |
404 test_ds_iterator(a2,ds2('x','y'),ds2('y','z'),ds2('x','y','z')) | 410 test_ds_iterator(a2,ds2('x','y'),ds2('y','z'),ds2('x','y','z')) |
405 | 411 test_fields_fct(ds2) |
406 | 412 |
407 assert len(ds2.fields())==3 | 413 del a2,ds1,ds2,ds3 |
408 for field in ds2.fields(): | |
409 for field_value in field: # iterate over the values associated to that field for all the ds examples | |
410 pass | |
411 for field in ds2('x','z').fields(): | |
412 pass | |
413 for field in ds2.fields('x','y'): | |
414 pass | |
415 for field_examples in ds2.fields(): | |
416 for example_value in field_examples: | |
417 pass | |
418 | |
419 assert ds2 == ds2.fields().examples() | |
420 # for ((x,y),a_v) in (ds('x','y'),a): #???don't work # haven't found a variant that work.# will not work | |
421 # assert numpy.append(x,y)==z | |
422 | 414 |
423 | 415 |
424 def test_DataSetFields(): | 416 def test_DataSetFields(): |
425 print "test_DataSetFields" | 417 print "test_DataSetFields" |
426 raise NotImplementedError() | 418 raise NotImplementedError() |