Mercurial > pylearn
diff _test_dataset.py @ 298:5987415496df
better testing of the MultiLengthDataSet, now called exotic1
author | Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca> |
---|---|
date | Fri, 06 Jun 2008 17:55:14 -0400 |
parents | d08b71d186c8 |
children | eded3cb54930 |
line wrap: on
line diff
--- a/_test_dataset.py Fri Jun 06 17:52:00 2008 -0400
+++ b/_test_dataset.py Fri Jun 06 17:55:14 2008 -0400
@@ -3,6 +3,7 @@
 from math import *
 import numpy, unittest, sys
 from misc import *
+from lookup_list import LookupList
 
 def have_raised(to_eval, **var):
     have_thrown = False
@@ -438,8 +439,18 @@
 
         del a, ds
 
-    def test_MultiLengthDataSet(self):
-        class MultiLengthDataSet(DataSet):
+    def test_MinibatchDataSet(self):
+        raise NotImplementedError()
+    def test_HStackedDataSet(self):
+        raise NotImplementedError()
+    def test_VStackedDataSet(self):
+        raise NotImplementedError()
+    def test_ArrayFieldsDataSet(self):
+        raise NotImplementedError()
+
+
+class T_Exotic1(unittest.TestCase):
+    class DataSet(DataSet):
             """ Dummy dataset, where one field is a ndarray of variables size. """
             def __len__(self) :
                 return 100
@@ -456,32 +467,31 @@
                     def next(self):
                         for k in self.minibatch._names :
                             self.minibatch[k] = []
-                        for ex in range(self.minibatch_size) :
-                            if 'input' in self.minibatch._names:
-                                self.minibatch['input'].append( numpy.array( range(self.current + 1) ) )
-                            if 'target' in self.minibatch._names:
-                                self.minibatch['target'].append( self.current % 2 )
-                            if 'name' in self.minibatch._names:
-                                self.minibatch['name'].append( str(self.current) )
-                            self.current += 1
+                        for ex in range(self.minibatch_size) :
+                            if 'input' in self.minibatch._names:
+                                self.minibatch['input'].append( numpy.array( range(self.current + 1) ) )
+                            if 'target' in self.minibatch._names:
+                                self.minibatch['target'].append( self.current % 2 )
+                            if 'name' in self.minibatch._names:
+                                self.minibatch['name'].append( str(self.current) )
+                            self.current += 1
                         return self.minibatch
                 return MultiLengthDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset)
-        ds = MultiLengthDataSet()
-        for k in range(len(ds)):
-            x = ds[k]
-        dsa = ApplyFunctionDataset(ds,lambda x,y,z: (x[-1],y*10,int(z)),['input','target','name'],minibatch_mode=True)
-        # needs more testing using ds, dsa, dscache, ...
-        raise NotImplementedError()
-
-    def test_MinibatchDataSet(self):
-        raise NotImplementedError()
-    def test_HStackedDataSet(self):
-        raise NotImplementedError()
-    def test_VStackedDataSet(self):
-        raise NotImplementedError()
-    def test_ArrayFieldsDataSet(self):
-        raise NotImplementedError()
-
+
+    def test_ApplyFunctionDataSet(self):
+        ds = T_Exotic1.DataSet()
+        dsa = ApplyFunctionDataSet(ds,lambda x,y,z: ([x[-1]],[y*10],[int(z)]),['input','target','name'],minibatch_mode=False) #broken!!!!!!
+        for k in range(len(dsa)):
+            res = dsa[k]
+            self.failUnless(ds[k]('input')[0][-1] == res('input')[0] , 'problem in first applied function')
+        res = dsa[33:96:3]
+
+    def test_CachedDataSet(self):
+        ds = T_Exotic1.DataSet()
+        dsc = CachedDataSet(ds)
+        for k in range(len(dsc)) :
+            self.failUnless(numpy.all( dsc[k]('input')[0] == ds[k]('input')[0] ) , (dsc[k],ds[k]) )
+        res = dsc[:]
 
 if __name__=='__main__':
     if len(sys.argv)==2: