changeset 64:863da25a60f1

trying to fix infinite loop
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Fri, 02 May 2008 11:01:28 -0400
parents 14589f02a372
children d48eba49a2f4
files dataset.py test_dataset.py
diffstat 2 files changed, 18 insertions(+), 6 deletions(-) [+]
line wrap: on
line diff
--- a/dataset.py	Fri May 02 10:14:24 2008 -0400
+++ b/dataset.py	Fri May 02 11:01:28 2008 -0400
@@ -552,11 +552,14 @@
             dataset = FieldsSubsetDataSet(dataset,fieldnames)
         assert dataset.hasFields(*fieldnames)
         self.dataset=dataset
-        minibatch_iterator = dataset.minibatches(fieldnames,
-                                                 minibatch_size=len(dataset),
-                                                 n_batches=1)
-        minibatch=minibatch_iterator.next()
-        LookupList.__init__(self,fieldnames,minibatch)
+        if isinstance(dataset,MinibatchDataSet):
+            LookupList.__init__(self,fieldnames,list(dataset._fields))
+        else:
+            minibatch_iterator = dataset.minibatches(fieldnames,
+                                                     minibatch_size=len(dataset),
+                                                     n_batches=1)
+            minibatch=minibatch_iterator.next()
+            LookupList.__init__(self,fieldnames,minibatch)
         
     def examples(self):
         return self.dataset
--- a/test_dataset.py	Fri May 02 10:14:24 2008 -0400
+++ b/test_dataset.py	Fri May 02 11:01:28 2008 -0400
@@ -23,6 +23,13 @@
         print "var=",var
     print "take a slice and look at field y",ds[1:6:2]["y"]
 
+def test2():
+    a = numpy.random.rand(10,4)
+    print a
+    ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]})
+    for x,z in ds[:3]('x','z'):
+        assert ds[i]['z'].all()==a[i][0:3:2].all()
+    
 def test_ArrayDataSet():
     #don't test stream
     #tested only with float value
@@ -109,5 +116,7 @@
     #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])
 
 
-test_ArrayDataSet()
+test2()
 
+#test_ArrayDataSet()
+