comparison _test_dataset.py @ 297:d08b71d186c8

merged
author James Bergstra <bergstrj@iro.umontreal.ca>
date Fri, 06 Jun 2008 17:52:00 -0400
parents 7380376816e5
children 5987415496df
comparison
equal deleted inserted replaced
296:f5d33f9c0b9c 297:d08b71d186c8
429 429
430 del a,ds1,ds2,ds3 430 del a,ds1,ds2,ds3
431 431
432 def test_FieldsSubsetDataSet(self): 432 def test_FieldsSubsetDataSet(self):
433 a = numpy.random.rand(10,4) 433 a = numpy.random.rand(10,4)
434 ds = ArrayDataSet(a,LookupList(['x','y','z','w'],[slice(3),3,[0,2],0])) 434 ds = ArrayDataSet(a,Example(['x','y','z','w'],[slice(3),3,[0,2],0]))
435 ds = FieldsSubsetDataSet(ds,['x','y','z']) 435 ds = FieldsSubsetDataSet(ds,['x','y','z'])
436 436
437 test_all(a,ds) 437 test_all(a,ds)
438 438
439 del a, ds 439 del a, ds
440
def test_MultiLengthDataSet(self):
    """Exercise a dummy dataset whose 'input' field varies in length.

    A throwaway DataSet subclass is defined, every example is fetched by
    index, and the dataset is wrapped in an ApplyFunctionDataset.  The test
    is still a stub: it always ends by raising NotImplementedError.
    """

    class MultiLengthDataSet(DataSet):
        """Dummy dataset in which one field is an ndarray of variable size."""

        def __len__(self):
            return 100

        def fieldNames(self):
            return ('input', 'target', 'name')

        def minibatches_nowrap(self, fieldnames, minibatch_size, n_batches, offset):

            class MultiLengthDataSetIterator(object):
                """Iterator that fabricates minibatches starting at `offset`.

                NOTE(review): `n_batches` is accepted but never read, and
                `next` never raises StopIteration on its own -- presumably
                the caller is expected to bound the iteration; confirm
                against the DataSet base class.
                """

                def __init__(self, dataset, fieldnames, minibatch_size, n_batches, offset):
                    if fieldnames is None:
                        fieldnames = dataset.fieldNames()
                    # Placeholder values; each call to next() replaces them
                    # with fresh lists.
                    self.minibatch = Example(fieldnames, range(len(fieldnames)))
                    self.dataset = dataset
                    self.minibatch_size = minibatch_size
                    self.current = offset

                def __iter__(self):
                    return self

                def next(self):
                    # The same Example object is reused on every call; only
                    # its per-field lists are reset.
                    for field_name in self.minibatch._names:
                        self.minibatch[field_name] = []
                    for _position in range(self.minibatch_size):
                        if 'input' in self.minibatch._names:
                            # Example number i gets an input of length i + 1.
                            variable_input = numpy.array(range(self.current + 1))
                            self.minibatch['input'].append(variable_input)
                        if 'target' in self.minibatch._names:
                            self.minibatch['target'].append(self.current % 2)
                        if 'name' in self.minibatch._names:
                            self.minibatch['name'].append(str(self.current))
                        self.current += 1
                    return self.minibatch

            return MultiLengthDataSetIterator(self, fieldnames, minibatch_size, n_batches, offset)

    ds = MultiLengthDataSet()
    # Fetch every example once by integer index; the values are not checked.
    for index in range(len(ds)):
        example = ds[index]
    dsa = ApplyFunctionDataset(
        ds,
        lambda x, y, z: (x[-1], y * 10, int(z)),
        ['input', 'target', 'name'],
        minibatch_mode=True,
    )
    # needs more testing using ds, dsa, dscache, ...
    raise NotImplementedError()
475
440 def test_MinibatchDataSet(self): 476 def test_MinibatchDataSet(self):
441 raise NotImplementedError() 477 raise NotImplementedError()
442 def test_HStackedDataSet(self): 478 def test_HStackedDataSet(self):
443 raise NotImplementedError() 479 raise NotImplementedError()
444 def test_VStackedDataSet(self): 480 def test_VStackedDataSet(self):