comparison test_dataset.py @ 161:60e00cce3492

Bugfix: make the test work when the dataset is not an ArrayDataSet
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Mon, 12 May 2008 17:25:52 -0400
parents 90104343c665
children 45427d4d64b3
comparison
equal deleted inserted replaced
160:a910141fbe5b 161:60e00cce3492
119 i=0 119 i=0
120 mi=0 120 mi=0
121 m=ds.minibatches(['x','z'], minibatch_size=3) 121 m=ds.minibatches(['x','z'], minibatch_size=3)
122 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) 122 assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
123 for minibatch in m: 123 for minibatch in m:
124 assert isinstance(minibatch,DataSetFields)
124 assert len(minibatch)==2 125 assert len(minibatch)==2
125 test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) 126 test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi)
126 assert (minibatch[0][:,0:3:2]==minibatch[1]).all() 127 if type(ds)==ArrayDataSet:
128 assert (minibatch[0][:,::2]==minibatch[1]).all()
129 else:
130 for i in xrange(len(minibatch[0])):
131 (minibatch[0][i][::2]==minibatch[1][i]).all()
127 mi+=1 132 mi+=1
128 i+=len(minibatch[0]) 133 i+=len(minibatch[0])
129 assert i==len(ds) 134 assert i==len(ds)
130 assert mi==4 135 assert mi==4
131 del minibatch,i,m,mi 136 del minibatch,i,m,mi
413 assert ds2 == ds2.fields().examples() 418 assert ds2 == ds2.fields().examples()
414 # for ((x,y),a_v) in (ds('x','y'),a): #???don't work # haven't found a variant that work.# will not work 419 # for ((x,y),a_v) in (ds('x','y'),a): #???don't work # haven't found a variant that work.# will not work
415 # assert numpy.append(x,y)==z 420 # assert numpy.append(x,y)==z
416 421
417 422
423 def test_DataSetFields():
424 print "test_DataSetFields"
425 raise NotImplementedError()
426
418 def test_ApplyFunctionDataSet(): 427 def test_ApplyFunctionDataSet():
419 print "test_ApplyFunctionDataSet" 428 print "test_ApplyFunctionDataSet"
420 raise NotImplementedError() 429 raise NotImplementedError()
421 def test_FieldsSubsetDataSet(): 430 def test_FieldsSubsetDataSet():
422 print "test_FieldsSubsetDataSet" 431 print "test_FieldsSubsetDataSet"
423 raise NotImplementedError()
424 def test_DataSetFields():
425 print "test_DataSetFields"
426 raise NotImplementedError() 432 raise NotImplementedError()
427 def test_MinibatchDataSet(): 433 def test_MinibatchDataSet():
428 print "test_MinibatchDataSet" 434 print "test_MinibatchDataSet"
429 raise NotImplementedError() 435 raise NotImplementedError()
430 def test_HStackedDataSet(): 436 def test_HStackedDataSet():