changeset 297:d08b71d186c8
merged
| author | James Bergstra <bergstrj@iro.umontreal.ca> |
| --- | --- |
| date | Fri, 06 Jun 2008 17:52:00 -0400 |
| parents | f5d33f9c0b9c (current diff), 7380376816e5 (diff) |
| children | 5987415496df |
| files | _test_dataset.py |
| diffstat | 1 files changed, 37 insertions(+), 1 deletions(-) |
--- a/_test_dataset.py	Fri Jun 06 17:50:29 2008 -0400
+++ b/_test_dataset.py	Fri Jun 06 17:52:00 2008 -0400
@@ -431,12 +431,48 @@
     def test_FieldsSubsetDataSet(self):
         a = numpy.random.rand(10,4)
-        ds = ArrayDataSet(a,LookupList(['x','y','z','w'],[slice(3),3,[0,2],0]))
+        ds = ArrayDataSet(a,Example(['x','y','z','w'],[slice(3),3,[0,2],0]))
         ds = FieldsSubsetDataSet(ds,['x','y','z'])
         test_all(a,ds)
         del a, ds
+
+    def test_MultiLengthDataSet(self):
+        class MultiLengthDataSet(DataSet):
+            """ Dummy dataset, where one field is a ndarray of variables size. """
+            def __len__(self) :
+                return 100
+            def fieldNames(self) :
+                return 'input','target','name'
+            def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
+                class MultiLengthDataSetIterator(object):
+                    def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
+                        if fieldnames is None: fieldnames = dataset.fieldNames()
+                        self.minibatch = Example(fieldnames,range(len(fieldnames)))
+                        self.dataset, self.minibatch_size, self.current = dataset, minibatch_size, offset
+                    def __iter__(self):
+                        return self
+                    def next(self):
+                        for k in self.minibatch._names :
+                            self.minibatch[k] = []
+                        for ex in range(self.minibatch_size) :
+                            if 'input' in self.minibatch._names:
+                                self.minibatch['input'].append( numpy.array( range(self.current + 1) ) )
+                            if 'target' in self.minibatch._names:
+                                self.minibatch['target'].append( self.current % 2 )
+                            if 'name' in self.minibatch._names:
+                                self.minibatch['name'].append( str(self.current) )
+                            self.current += 1
+                        return self.minibatch
+                return MultiLengthDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset)
+        ds = MultiLengthDataSet()
+        for k in range(len(ds)):
+            x = ds[k]
+        dsa = ApplyFunctionDataset(ds,lambda x,y,z: (x[-1],y*10,int(z)),['input','target','name'],minibatch_mode=True)
+        # needs more testing using ds, dsa, dscache, ...
+        raise NotImplementedError()
+
     def test_MinibatchDataSet(self):
         raise NotImplementedError()
     def test_HStackedDataSet(self):
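As a side note (not part of the changeset), the pattern exercised by the new test_MultiLengthDataSet can be shown with a minimal standalone sketch: a minibatch iterator whose 'input' entries are numpy arrays of growing length, while 'target' and 'name' stay one value per example. The MultiLengthIterator class below is hypothetical and uses a plain dict instead of pylearn's Example/DataSet classes.

```python
import numpy

class MultiLengthIterator(object):
    """Hypothetical stand-in for the MultiLengthDataSetIterator in the test:
    each example's 'input' is a numpy array whose length grows with the
    running example index, so minibatches are ragged."""
    def __init__(self, minibatch_size, offset=0):
        self.minibatch_size = minibatch_size
        self.current = offset

    def __iter__(self):
        return self

    def __next__(self):
        # Build a fresh minibatch: one list of per-example values per field.
        batch = {'input': [], 'target': [], 'name': []}
        for _ in range(self.minibatch_size):
            batch['input'].append(numpy.arange(self.current + 1))  # length = index + 1
            batch['target'].append(self.current % 2)
            batch['name'].append(str(self.current))
            self.current += 1
        return batch

    next = __next__  # the diff targets Python 2, which uses next()

# The second minibatch of size 3 covers examples 3, 4, 5,
# so its 'input' arrays have lengths 4, 5 and 6.
it = MultiLengthIterator(minibatch_size=3)
first, second = next(it), next(it)
assert [len(x) for x in second['input']] == [4, 5, 6]
```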