Mercurial > pylearn
comparison _test_dataset.py @ 298:5987415496df
better testing of the MultiLengthDataSet, now called exotic1
author | Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca> |
---|---|
date | Fri, 06 Jun 2008 17:55:14 -0400 |
parents | d08b71d186c8 |
children | eded3cb54930 |
comparison
equal
deleted
inserted
replaced
297:d08b71d186c8 | 298:5987415496df |
---|---|
1 #!/bin/env python | 1 #!/bin/env python |
2 from dataset import * | 2 from dataset import * |
3 from math import * | 3 from math import * |
4 import numpy, unittest, sys | 4 import numpy, unittest, sys |
5 from misc import * | 5 from misc import * |
6 from lookup_list import LookupList | |
6 | 7 |
7 def have_raised(to_eval, **var): | 8 def have_raised(to_eval, **var): |
8 have_thrown = False | 9 have_thrown = False |
9 try: | 10 try: |
10 eval(to_eval) | 11 eval(to_eval) |
436 | 437 |
437 test_all(a,ds) | 438 test_all(a,ds) |
438 | 439 |
439 del a, ds | 440 del a, ds |
440 | 441 |
441 def test_MultiLengthDataSet(self): | 442 def test_MinibatchDataSet(self): |
442 class MultiLengthDataSet(DataSet): | 443 raise NotImplementedError() |
444 def test_HStackedDataSet(self): | |
445 raise NotImplementedError() | |
446 def test_VStackedDataSet(self): | |
447 raise NotImplementedError() | |
448 def test_ArrayFieldsDataSet(self): | |
449 raise NotImplementedError() | |
450 | |
451 | |
452 class T_Exotic1(unittest.TestCase): | |
453 class DataSet(DataSet): | |
443 """ Dummy dataset, where one field is a ndarray of variables size. """ | 454 """ Dummy dataset, where one field is a ndarray of variables size. """ |
444 def __len__(self) : | 455 def __len__(self) : |
445 return 100 | 456 return 100 |
446 def fieldNames(self) : | 457 def fieldNames(self) : |
447 return 'input','target','name' | 458 return 'input','target','name' |
454 def __iter__(self): | 465 def __iter__(self): |
455 return self | 466 return self |
456 def next(self): | 467 def next(self): |
457 for k in self.minibatch._names : | 468 for k in self.minibatch._names : |
458 self.minibatch[k] = [] | 469 self.minibatch[k] = [] |
459 for ex in range(self.minibatch_size) : | 470 for ex in range(self.minibatch_size) : |
460 if 'input' in self.minibatch._names: | 471 if 'input' in self.minibatch._names: |
461 self.minibatch['input'].append( numpy.array( range(self.current + 1) ) ) | 472 self.minibatch['input'].append( numpy.array( range(self.current + 1) ) ) |
462 if 'target' in self.minibatch._names: | 473 if 'target' in self.minibatch._names: |
463 self.minibatch['target'].append( self.current % 2 ) | 474 self.minibatch['target'].append( self.current % 2 ) |
464 if 'name' in self.minibatch._names: | 475 if 'name' in self.minibatch._names: |
465 self.minibatch['name'].append( str(self.current) ) | 476 self.minibatch['name'].append( str(self.current) ) |
466 self.current += 1 | 477 self.current += 1 |
467 return self.minibatch | 478 return self.minibatch |
468 return MultiLengthDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset) | 479 return MultiLengthDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset) |
469 ds = MultiLengthDataSet() | 480 |
470 for k in range(len(ds)): | 481 def test_ApplyFunctionDataSet(self): |
471 x = ds[k] | 482 ds = T_Exotic1.DataSet() |
472 dsa = ApplyFunctionDataset(ds,lambda x,y,z: (x[-1],y*10,int(z)),['input','target','name'],minibatch_mode=True) | 483 dsa = ApplyFunctionDataSet(ds,lambda x,y,z: ([x[-1]],[y*10],[int(z)]),['input','target','name'],minibatch_mode=False) #broken!!!!!! |
473 # needs more testing using ds, dsa, dscache, ... | 484 for k in range(len(dsa)): |
474 raise NotImplementedError() | 485 res = dsa[k] |
475 | 486 self.failUnless(ds[k]('input')[0][-1] == res('input')[0] , 'problem in first applied function') |
476 def test_MinibatchDataSet(self): | 487 res = dsa[33:96:3] |
477 raise NotImplementedError() | 488 |
478 def test_HStackedDataSet(self): | 489 def test_CachedDataSet(self): |
479 raise NotImplementedError() | 490 ds = T_Exotic1.DataSet() |
480 def test_VStackedDataSet(self): | 491 dsc = CachedDataSet(ds) |
481 raise NotImplementedError() | 492 for k in range(len(dsc)) : |
482 def test_ArrayFieldsDataSet(self): | 493 self.failUnless(numpy.all( dsc[k]('input')[0] == ds[k]('input')[0] ) , (dsc[k],ds[k]) ) |
483 raise NotImplementedError() | 494 res = dsc[:] |
484 | |
485 | 495 |
486 if __name__=='__main__': | 496 if __name__=='__main__': |
487 if len(sys.argv)==2: | 497 if len(sys.argv)==2: |
488 if sys.argv[1]=="--debug": | 498 if sys.argv[1]=="--debug": |
489 module = __import__("_test_dataset") | 499 module = __import__("_test_dataset") |