changeset 297:d08b71d186c8

merged
author James Bergstra <bergstrj@iro.umontreal.ca>
date Fri, 06 Jun 2008 17:52:00 -0400
parents f5d33f9c0b9c (current diff) 7380376816e5 (diff)
children 5987415496df
files _test_dataset.py
diffstat 1 files changed, 37 insertions(+), 1 deletions(-) [+]
line wrap: on
line diff
--- a/_test_dataset.py	Fri Jun 06 17:50:29 2008 -0400
+++ b/_test_dataset.py	Fri Jun 06 17:52:00 2008 -0400
@@ -431,12 +431,48 @@
 
     def test_FieldsSubsetDataSet(self):
         a = numpy.random.rand(10,4)
-        ds = ArrayDataSet(a,LookupList(['x','y','z','w'],[slice(3),3,[0,2],0]))
+        ds = ArrayDataSet(a,Example(['x','y','z','w'],[slice(3),3,[0,2],0]))
         ds = FieldsSubsetDataSet(ds,['x','y','z'])
 
         test_all(a,ds)
 
         del a, ds
+
+    def test_MultiLengthDataSet(self):
+        class MultiLengthDataSet(DataSet):
+            """ Dummy dataset, where one field is a ndarray of variables size. """
+            def __len__(self) :
+                return 100
+            def fieldNames(self) :
+                return 'input','target','name'
+            def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
+                class MultiLengthDataSetIterator(object):
+                    def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
+                        if fieldnames is None: fieldnames = dataset.fieldNames()
+                        self.minibatch = Example(fieldnames,range(len(fieldnames)))
+                        self.dataset, self.minibatch_size, self.current = dataset, minibatch_size, offset
+                    def __iter__(self):
+                            return self
+                    def next(self):
+                        for k in self.minibatch._names :
+                            self.minibatch[k] = []
+                            for ex in range(self.minibatch_size) :
+                                if 'input' in self.minibatch._names:
+                                    self.minibatch['input'].append( numpy.array( range(self.current + 1) ) )
+                                if 'target' in self.minibatch._names:
+                                    self.minibatch['target'].append( self.current % 2 )
+                                if 'name' in self.minibatch._names:
+                                    self.minibatch['name'].append( str(self.current) )
+                                self.current += 1
+                        return self.minibatch
+                return MultiLengthDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset)
+        ds = MultiLengthDataSet()
+        for k in range(len(ds)):
+            x = ds[k]
+        dsa = ApplyFunctionDataset(ds,lambda x,y,z: (x[-1],y*10,int(z)),['input','target','name'],minibatch_mode=True)
+        # needs more testing using ds, dsa, dscache, ...
+        raise NotImplementedError()
+
     def test_MinibatchDataSet(self):
         raise NotImplementedError()
     def test_HStackedDataSet(self):