Mercurial > pylearn
diff test_dataset.py @ 161:60e00cce3492
bugfix test in case it is not an ArrayDataSet
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Mon, 12 May 2008 17:25:52 -0400 |
parents | 90104343c665 |
children | 45427d4d64b3 |
line wrap: on
line diff
--- a/test_dataset.py Mon May 12 16:52:00 2008 -0400 +++ b/test_dataset.py Mon May 12 17:25:52 2008 -0400 @@ -121,9 +121,14 @@ m=ds.minibatches(['x','z'], minibatch_size=3) assert isinstance(m,DataSet.MinibatchWrapAroundIterator) for minibatch in m: + assert isinstance(minibatch,DataSetFields) assert len(minibatch)==2 test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) - assert (minibatch[0][:,0:3:2]==minibatch[1]).all() + if type(ds)==ArrayDataSet: + assert (minibatch[0][:,::2]==minibatch[1]).all() + else: + for i in xrange(len(minibatch[0])): + (minibatch[0][i][::2]==minibatch[1][i]).all() mi+=1 i+=len(minibatch[0]) assert i==len(ds) @@ -415,15 +420,16 @@ # assert numpy.append(x,y)==z +def test_DataSetFields(): + print "test_DataSetFields" + raise NotImplementedError() + def test_ApplyFunctionDataSet(): print "test_ApplyFunctionDataSet" raise NotImplementedError() def test_FieldsSubsetDataSet(): print "test_FieldsSubsetDataSet" raise NotImplementedError() -def test_DataSetFields(): - print "test_DataSetFields" - raise NotImplementedError() def test_MinibatchDataSet(): print "test_MinibatchDataSet" raise NotImplementedError()