Mercurial > pylearn
comparison test_dataset.py @ 161:60e00cce3492
bugfix test in case it is not an ArrayDataSet
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Mon, 12 May 2008 17:25:52 -0400 |
parents | 90104343c665 |
children | 45427d4d64b3 |
comparison
Legend: equal | deleted | inserted | replaced
160:a910141fbe5b | 161:60e00cce3492 |
---|---|
119 i=0 | 119 i=0 |
120 mi=0 | 120 mi=0 |
121 m=ds.minibatches(['x','z'], minibatch_size=3) | 121 m=ds.minibatches(['x','z'], minibatch_size=3) |
122 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) | 122 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) |
123 for minibatch in m: | 123 for minibatch in m: |
124 assert isinstance(minibatch,DataSetFields) | |
124 assert len(minibatch)==2 | 125 assert len(minibatch)==2 |
125 test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) | 126 test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) |
126 assert (minibatch[0][:,0:3:2]==minibatch[1]).all() | 127 if type(ds)==ArrayDataSet: |
128 assert (minibatch[0][:,::2]==minibatch[1]).all() | |
129 else: | |
130 for i in xrange(len(minibatch[0])): | |
131 (minibatch[0][i][::2]==minibatch[1][i]).all() | |
127 mi+=1 | 132 mi+=1 |
128 i+=len(minibatch[0]) | 133 i+=len(minibatch[0]) |
129 assert i==len(ds) | 134 assert i==len(ds) |
130 assert mi==4 | 135 assert mi==4 |
131 del minibatch,i,m,mi | 136 del minibatch,i,m,mi |
413 assert ds2 == ds2.fields().examples() | 418 assert ds2 == ds2.fields().examples() |
414 # for ((x,y),a_v) in (ds('x','y'),a): #???don't work # haven't found a variant that work.# will not work | 419 # for ((x,y),a_v) in (ds('x','y'),a): #???don't work # haven't found a variant that work.# will not work |
415 # assert numpy.append(x,y)==z | 420 # assert numpy.append(x,y)==z |
416 | 421 |
417 | 422 |
423 def test_DataSetFields(): | |
424 print "test_DataSetFields" | |
425 raise NotImplementedError() | |
426 | |
418 def test_ApplyFunctionDataSet(): | 427 def test_ApplyFunctionDataSet(): |
419 print "test_ApplyFunctionDataSet" | 428 print "test_ApplyFunctionDataSet" |
420 raise NotImplementedError() | 429 raise NotImplementedError() |
421 def test_FieldsSubsetDataSet(): | 430 def test_FieldsSubsetDataSet(): |
422 print "test_FieldsSubsetDataSet" | 431 print "test_FieldsSubsetDataSet" |
423 raise NotImplementedError() | |
424 def test_DataSetFields(): | |
425 print "test_DataSetFields" | |
426 raise NotImplementedError() | 432 raise NotImplementedError() |
427 def test_MinibatchDataSet(): | 433 def test_MinibatchDataSet(): |
428 print "test_MinibatchDataSet" | 434 print "test_MinibatchDataSet" |
429 raise NotImplementedError() | 435 raise NotImplementedError() |
430 def test_HStackedDataSet(): | 436 def test_HStackedDataSet(): |