Mercurial > pylearn
comparison _test_dataset.py @ 295:7380376816e5
Started a test for datasets where one field has a variable length. Not obvious: all tests require a matrix as a reference.
author | Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca> |
---|---|
date | Fri, 06 Jun 2008 17:11:25 -0400 |
parents | f7924e13e426 |
children | d08b71d186c8 |
comparison
equal
deleted
inserted
replaced
294:f7924e13e426 | 295:7380376816e5 |
---|---|
435 ds = FieldsSubsetDataSet(ds,['x','y','z']) | 435 ds = FieldsSubsetDataSet(ds,['x','y','z']) |
436 | 436 |
437 test_all(a,ds) | 437 test_all(a,ds) |
438 | 438 |
439 del a, ds | 439 del a, ds |
440 | |
441 def test_MultiLengthDataSet(self): | |
442 class MultiLengthDataSet(DataSet): | |
443 """ Dummy dataset, where one field is a ndarray of variables size. """ | |
444 def __len__(self) : | |
445 return 100 | |
446 def fieldNames(self) : | |
447 return 'input','target','name' | |
448 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): | |
449 class MultiLengthDataSetIterator(object): | |
450 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): | |
451 if fieldnames is None: fieldnames = dataset.fieldNames() | |
452 self.minibatch = LookupList(fieldnames,range(len(fieldnames))) | |
453 self.dataset, self.minibatch_size, self.current = dataset, minibatch_size, offset | |
454 def __iter__(self): | |
455 return self | |
456 def next(self): | |
457 for k in self.minibatch._names : | |
458 self.minibatch[k] = [] | |
459 for ex in range(self.minibatch_size) : | |
460 if 'input' in self.minibatch._names: | |
461 self.minibatch['input'].append( numpy.array( range(self.current + 1) ) ) | |
462 if 'target' in self.minibatch._names: | |
463 self.minibatch['target'].append( self.current % 2 ) | |
464 if 'name' in self.minibatch._names: | |
465 self.minibatch['name'].append( str(self.current) ) | |
466 self.current += 1 | |
467 return self.minibatch | |
468 return MultiLengthDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset) | |
469 ds = MultiLengthDataSet() | |
470 for k in range(len(ds)): | |
471 x = ds[k] | |
472 dsa = ApplyFunctionDataset(ds,lambda x,y,z: (x[-1],y*10,int(z)),['input','target','name'],minibatch_mode=True) | |
473 # needs more testing using ds, dsa, dscache, ... | |
474 raise NotImplementedError() | |
475 | |
440 def test_MinibatchDataSet(self): | 476 def test_MinibatchDataSet(self): |
441 raise NotImplementedError() | 477 raise NotImplementedError() |
442 def test_HStackedDataSet(self): | 478 def test_HStackedDataSet(self): |
443 raise NotImplementedError() | 479 raise NotImplementedError() |
444 def test_VStackedDataSet(self): | 480 def test_VStackedDataSet(self): |