comparison _test_dataset.py @ 295:7380376816e5

Started a test for datasets where one field has a variable length. Not obvious: all tests require a matrix as a reference.
author Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
date Fri, 06 Jun 2008 17:11:25 -0400
parents f7924e13e426
children d08b71d186c8
comparison
equal deleted inserted replaced
294:f7924e13e426 295:7380376816e5
435 ds = FieldsSubsetDataSet(ds,['x','y','z']) 435 ds = FieldsSubsetDataSet(ds,['x','y','z'])
436 436
437 test_all(a,ds) 437 test_all(a,ds)
438 438
439 del a, ds 439 del a, ds
440
441 def test_MultiLengthDataSet(self):
442 class MultiLengthDataSet(DataSet):
443 """ Dummy dataset, where one field is a ndarray of variables size. """
444 def __len__(self) :
445 return 100
446 def fieldNames(self) :
447 return 'input','target','name'
448 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
449 class MultiLengthDataSetIterator(object):
450 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
451 if fieldnames is None: fieldnames = dataset.fieldNames()
452 self.minibatch = LookupList(fieldnames,range(len(fieldnames)))
453 self.dataset, self.minibatch_size, self.current = dataset, minibatch_size, offset
454 def __iter__(self):
455 return self
456 def next(self):
457 for k in self.minibatch._names :
458 self.minibatch[k] = []
459 for ex in range(self.minibatch_size) :
460 if 'input' in self.minibatch._names:
461 self.minibatch['input'].append( numpy.array( range(self.current + 1) ) )
462 if 'target' in self.minibatch._names:
463 self.minibatch['target'].append( self.current % 2 )
464 if 'name' in self.minibatch._names:
465 self.minibatch['name'].append( str(self.current) )
466 self.current += 1
467 return self.minibatch
468 return MultiLengthDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset)
469 ds = MultiLengthDataSet()
470 for k in range(len(ds)):
471 x = ds[k]
472 dsa = ApplyFunctionDataset(ds,lambda x,y,z: (x[-1],y*10,int(z)),['input','target','name'],minibatch_mode=True)
473 # needs more testing using ds, dsa, dscache, ...
474 raise NotImplementedError()
475
440 def test_MinibatchDataSet(self): 476 def test_MinibatchDataSet(self):
441 raise NotImplementedError() 477 raise NotImplementedError()
442 def test_HStackedDataSet(self): 478 def test_HStackedDataSet(self):
443 raise NotImplementedError() 479 raise NotImplementedError()
444 def test_VStackedDataSet(self): 480 def test_VStackedDataSet(self):