comparison _test_dataset.py @ 298:5987415496df

Better testing of the MultiLengthDataSet, now called Exotic1 (test class T_Exotic1)
author Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
date Fri, 06 Jun 2008 17:55:14 -0400
parents d08b71d186c8
children eded3cb54930
comparison
equal deleted inserted replaced
297:d08b71d186c8 298:5987415496df
1 #!/bin/env python 1 #!/bin/env python
2 from dataset import * 2 from dataset import *
3 from math import * 3 from math import *
4 import numpy, unittest, sys 4 import numpy, unittest, sys
5 from misc import * 5 from misc import *
6 from lookup_list import LookupList
6 7
7 def have_raised(to_eval, **var): 8 def have_raised(to_eval, **var):
8 have_thrown = False 9 have_thrown = False
9 try: 10 try:
10 eval(to_eval) 11 eval(to_eval)
436 437
437 test_all(a,ds) 438 test_all(a,ds)
438 439
439 del a, ds 440 del a, ds
440 441
441 def test_MultiLengthDataSet(self): 442 def test_MinibatchDataSet(self):
442 class MultiLengthDataSet(DataSet): 443 raise NotImplementedError()
444 def test_HStackedDataSet(self):
445 raise NotImplementedError()
446 def test_VStackedDataSet(self):
447 raise NotImplementedError()
448 def test_ArrayFieldsDataSet(self):
449 raise NotImplementedError()
450
451
452 class T_Exotic1(unittest.TestCase):
453 class DataSet(DataSet):
443 """ Dummy dataset, where one field is a ndarray of variables size. """ 454 """ Dummy dataset, where one field is a ndarray of variables size. """
444 def __len__(self) : 455 def __len__(self) :
445 return 100 456 return 100
446 def fieldNames(self) : 457 def fieldNames(self) :
447 return 'input','target','name' 458 return 'input','target','name'
454 def __iter__(self): 465 def __iter__(self):
455 return self 466 return self
456 def next(self): 467 def next(self):
457 for k in self.minibatch._names : 468 for k in self.minibatch._names :
458 self.minibatch[k] = [] 469 self.minibatch[k] = []
459 for ex in range(self.minibatch_size) : 470 for ex in range(self.minibatch_size) :
460 if 'input' in self.minibatch._names: 471 if 'input' in self.minibatch._names:
461 self.minibatch['input'].append( numpy.array( range(self.current + 1) ) ) 472 self.minibatch['input'].append( numpy.array( range(self.current + 1) ) )
462 if 'target' in self.minibatch._names: 473 if 'target' in self.minibatch._names:
463 self.minibatch['target'].append( self.current % 2 ) 474 self.minibatch['target'].append( self.current % 2 )
464 if 'name' in self.minibatch._names: 475 if 'name' in self.minibatch._names:
465 self.minibatch['name'].append( str(self.current) ) 476 self.minibatch['name'].append( str(self.current) )
466 self.current += 1 477 self.current += 1
467 return self.minibatch 478 return self.minibatch
468 return MultiLengthDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset) 479 return MultiLengthDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset)
469 ds = MultiLengthDataSet() 480
470 for k in range(len(ds)): 481 def test_ApplyFunctionDataSet(self):
471 x = ds[k] 482 ds = T_Exotic1.DataSet()
472 dsa = ApplyFunctionDataset(ds,lambda x,y,z: (x[-1],y*10,int(z)),['input','target','name'],minibatch_mode=True) 483 dsa = ApplyFunctionDataSet(ds,lambda x,y,z: ([x[-1]],[y*10],[int(z)]),['input','target','name'],minibatch_mode=False) #broken!!!!!!
473 # needs more testing using ds, dsa, dscache, ... 484 for k in range(len(dsa)):
474 raise NotImplementedError() 485 res = dsa[k]
475 486 self.failUnless(ds[k]('input')[0][-1] == res('input')[0] , 'problem in first applied function')
476 def test_MinibatchDataSet(self): 487 res = dsa[33:96:3]
477 raise NotImplementedError() 488
478 def test_HStackedDataSet(self): 489 def test_CachedDataSet(self):
479 raise NotImplementedError() 490 ds = T_Exotic1.DataSet()
480 def test_VStackedDataSet(self): 491 dsc = CachedDataSet(ds)
481 raise NotImplementedError() 492 for k in range(len(dsc)) :
482 def test_ArrayFieldsDataSet(self): 493 self.failUnless(numpy.all( dsc[k]('input')[0] == ds[k]('input')[0] ) , (dsc[k],ds[k]) )
483 raise NotImplementedError() 494 res = dsc[:]
484
485 495
486 if __name__=='__main__': 496 if __name__=='__main__':
487 if len(sys.argv)==2: 497 if len(sys.argv)==2:
488 if sys.argv[1]=="--debug": 498 if sys.argv[1]=="--debug":
489 module = __import__("_test_dataset") 499 module = __import__("_test_dataset")