changeset 295:7380376816e5

started a test for datasets where one field has a variable length. Not obvious, all tests requires a matrix as a reference
author Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
date Fri, 06 Jun 2008 17:11:25 -0400
parents f7924e13e426
children d08b71d186c8
files _test_dataset.py
diffstat 1 files changed, 36 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- a/_test_dataset.py	Fri Jun 06 16:15:47 2008 -0400
+++ b/_test_dataset.py	Fri Jun 06 17:11:25 2008 -0400
@@ -437,6 +437,42 @@
         test_all(a,ds)
 
         del a, ds
+
+    def test_MultiLengthDataSet(self):
+        class MultiLengthDataSet(DataSet):
+            """ Dummy dataset, where one field is a ndarray of variables size. """
+            def __len__(self) :
+                return 100
+            def fieldNames(self) :
+                return 'input','target','name'
+            def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
+                class MultiLengthDataSetIterator(object):
+                    def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
+                        if fieldnames is None: fieldnames = dataset.fieldNames()
+                        self.minibatch = LookupList(fieldnames,range(len(fieldnames)))
+                        self.dataset, self.minibatch_size, self.current = dataset, minibatch_size, offset
+                    def __iter__(self):
+                            return self
+                    def next(self):
+                        for k in self.minibatch._names :
+                            self.minibatch[k] = []
+                            for ex in range(self.minibatch_size) :
+                                if 'input' in self.minibatch._names:
+                                    self.minibatch['input'].append( numpy.array( range(self.current + 1) ) )
+                                if 'target' in self.minibatch._names:
+                                    self.minibatch['target'].append( self.current % 2 )
+                                if 'name' in self.minibatch._names:
+                                    self.minibatch['name'].append( str(self.current) )
+                                self.current += 1
+                        return self.minibatch
+                return MultiLengthDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset)
+        ds = MultiLengthDataSet()
+        for k in range(len(ds)):
+            x = ds[k]
+        dsa = ApplyFunctionDataset(ds,lambda x,y,z: (x[-1],y*10,int(z)),['input','target','name'],minibatch_mode=True)
+        # needs more testing using ds, dsa, dscache, ...
+        raise NotImplementedError()
+
     def test_MinibatchDataSet(self):
         raise NotImplementedError()
     def test_HStackedDataSet(self):