diff test_dataset.py @ 161:60e00cce3492

bugfix test in case it is not an ArrayDataSet
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Mon, 12 May 2008 17:25:52 -0400
parents 90104343c665
children 45427d4d64b3
line wrap: on
line diff
--- a/test_dataset.py	Mon May 12 16:52:00 2008 -0400
+++ b/test_dataset.py	Mon May 12 17:25:52 2008 -0400
@@ -121,9 +121,14 @@
     m=ds.minibatches(['x','z'], minibatch_size=3)
     assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
     for minibatch in m:
+        assert isinstance(minibatch,DataSetFields)
         assert len(minibatch)==2
         test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi)
-        assert (minibatch[0][:,0:3:2]==minibatch[1]).all()
+        if type(ds)==ArrayDataSet:
+            assert (minibatch[0][:,::2]==minibatch[1]).all()
+        else:
+            for i in xrange(len(minibatch[0])):
+                (minibatch[0][i][::2]==minibatch[1][i]).all()
         mi+=1
         i+=len(minibatch[0])
     assert i==len(ds)
@@ -415,15 +420,16 @@
 #        assert numpy.append(x,y)==z
 
 
+def test_DataSetFields():
+    print "test_DataSetFields"
+    raise NotImplementedError()
+
 def test_ApplyFunctionDataSet():
     print "test_ApplyFunctionDataSet"
     raise NotImplementedError()
 def test_FieldsSubsetDataSet():
     print "test_FieldsSubsetDataSet"
     raise NotImplementedError()
-def test_DataSetFields():
-    print "test_DataSetFields"
-    raise NotImplementedError()
 def test_MinibatchDataSet():
     print "test_MinibatchDataSet"
     raise NotImplementedError()