changeset 298:5987415496df

better testing of the MultiLengthDataSet, now called exotic1
author Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
date Fri, 06 Jun 2008 17:55:14 -0400
parents d08b71d186c8
children eded3cb54930
files _test_dataset.py
diffstat 1 files changed, 36 insertions(+), 26 deletions(-) [+]
line wrap: on
line diff
--- a/_test_dataset.py	Fri Jun 06 17:52:00 2008 -0400
+++ b/_test_dataset.py	Fri Jun 06 17:55:14 2008 -0400
@@ -3,6 +3,7 @@
 from math import *
 import numpy, unittest, sys
 from misc import *
+from lookup_list import LookupList
 
 def have_raised(to_eval, **var):
     have_thrown = False
@@ -438,8 +439,18 @@
 
         del a, ds
 
-    def test_MultiLengthDataSet(self):
-        class MultiLengthDataSet(DataSet):
+    def test_MinibatchDataSet(self):
+        raise NotImplementedError()
+    def test_HStackedDataSet(self):
+        raise NotImplementedError()
+    def test_VStackedDataSet(self):
+        raise NotImplementedError()
+    def test_ArrayFieldsDataSet(self):
+        raise NotImplementedError()
+
+
+class T_Exotic1(unittest.TestCase):
+    class DataSet(DataSet):
             """ Dummy dataset, where one field is a ndarray of variables size. """
             def __len__(self) :
                 return 100
@@ -456,32 +467,31 @@
                     def next(self):
                         for k in self.minibatch._names :
                             self.minibatch[k] = []
-                            for ex in range(self.minibatch_size) :
-                                if 'input' in self.minibatch._names:
-                                    self.minibatch['input'].append( numpy.array( range(self.current + 1) ) )
-                                if 'target' in self.minibatch._names:
-                                    self.minibatch['target'].append( self.current % 2 )
-                                if 'name' in self.minibatch._names:
-                                    self.minibatch['name'].append( str(self.current) )
-                                self.current += 1
+                        for ex in range(self.minibatch_size) :
+                            if 'input' in self.minibatch._names:
+                                self.minibatch['input'].append( numpy.array( range(self.current + 1) ) )
+                            if 'target' in self.minibatch._names:
+                                self.minibatch['target'].append( self.current % 2 )
+                            if 'name' in self.minibatch._names:
+                                self.minibatch['name'].append( str(self.current) )
+                            self.current += 1
                         return self.minibatch
                 return MultiLengthDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset)
-        ds = MultiLengthDataSet()
-        for k in range(len(ds)):
-            x = ds[k]
-        dsa = ApplyFunctionDataset(ds,lambda x,y,z: (x[-1],y*10,int(z)),['input','target','name'],minibatch_mode=True)
-        # needs more testing using ds, dsa, dscache, ...
-        raise NotImplementedError()
-
-    def test_MinibatchDataSet(self):
-        raise NotImplementedError()
-    def test_HStackedDataSet(self):
-        raise NotImplementedError()
-    def test_VStackedDataSet(self):
-        raise NotImplementedError()
-    def test_ArrayFieldsDataSet(self):
-        raise NotImplementedError()
-
+    
+    def test_ApplyFunctionDataSet(self):
+        ds = T_Exotic1.DataSet()
+        dsa = ApplyFunctionDataSet(ds,lambda x,y,z: ([x[-1]],[y*10],[int(z)]),['input','target','name'],minibatch_mode=False) #broken!!!!!!
+        for k in range(len(dsa)):
+            res = dsa[k]
+            self.failUnless(ds[k]('input')[0][-1] == res('input')[0] , 'problem in first applied function')
+        res = dsa[33:96:3]
+          
+    def test_CachedDataSet(self):
+        ds = T_Exotic1.DataSet()
+        dsc = CachedDataSet(ds)
+        for k in range(len(dsc)) :
+            self.failUnless(numpy.all( dsc[k]('input')[0] == ds[k]('input')[0] ) , (dsc[k],ds[k]) )
+        res = dsc[:]
 
 if __name__=='__main__':
     if len(sys.argv)==2: