Mercurial > pylearn
diff _test_dataset.py @ 295:7380376816e5
started a test for datasets where one field has a variable length. Not obvious, all tests requires a matrix as a reference
author | Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca> |
---|---|
date | Fri, 06 Jun 2008 17:11:25 -0400 |
parents | f7924e13e426 |
children | d08b71d186c8 |
line wrap: on
line diff
--- a/_test_dataset.py Fri Jun 06 16:15:47 2008 -0400 +++ b/_test_dataset.py Fri Jun 06 17:11:25 2008 -0400 @@ -437,6 +437,42 @@ test_all(a,ds) del a, ds + + def test_MultiLengthDataSet(self): + class MultiLengthDataSet(DataSet): + """ Dummy dataset, where one field is a ndarray of variables size. """ + def __len__(self) : + return 100 + def fieldNames(self) : + return 'input','target','name' + def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): + class MultiLengthDataSetIterator(object): + def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): + if fieldnames is None: fieldnames = dataset.fieldNames() + self.minibatch = LookupList(fieldnames,range(len(fieldnames))) + self.dataset, self.minibatch_size, self.current = dataset, minibatch_size, offset + def __iter__(self): + return self + def next(self): + for k in self.minibatch._names : + self.minibatch[k] = [] + for ex in range(self.minibatch_size) : + if 'input' in self.minibatch._names: + self.minibatch['input'].append( numpy.array( range(self.current + 1) ) ) + if 'target' in self.minibatch._names: + self.minibatch['target'].append( self.current % 2 ) + if 'name' in self.minibatch._names: + self.minibatch['name'].append( str(self.current) ) + self.current += 1 + return self.minibatch + return MultiLengthDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset) + ds = MultiLengthDataSet() + for k in range(len(ds)): + x = ds[k] + dsa = ApplyFunctionDataset(ds,lambda x,y,z: (x[-1],y*10,int(z)),['input','target','name'],minibatch_mode=True) + # needs more testing using ds, dsa, dscache, ... + raise NotImplementedError() + def test_MinibatchDataSet(self): raise NotImplementedError() def test_HStackedDataSet(self):