changeset 151:39bb21348fdf

Automated merge with ssh://p-omega1@lgcm.iro.umontreal.ca/tlearn
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Mon, 12 May 2008 15:51:43 -0400
parents 625d2b21ee48 (current diff) 9abd19af822e (diff)
children 3f627e844cba 71107b0ac860 ae5651a3696b
files dataset.py
diffstat 2 files changed, 34 insertions(+), 4 deletions(-) [+]
line wrap: on
line diff
--- a/dataset.py	Mon May 12 15:50:34 2008 -0400
+++ b/dataset.py	Mon May 12 15:51:43 2008 -0400
@@ -1077,7 +1077,7 @@
                   for example in self.dataset.source_dataset[cache_len:upper]:
                       self.dataset.cached_examples.append(example)
               all_fields_minibatch = Example(self.dataset.fieldNames(),
-                                             self.dataset.cached_examples[self.current:self.current+minibatch_size])
+                                             *self.dataset.cached_examples[self.current:self.current+minibatch_size])
               if self.dataset.fieldNames()==fieldnames:
                   return all_fields_minibatch
               return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames])
--- a/test_dataset.py	Mon May 12 15:50:34 2008 -0400
+++ b/test_dataset.py	Mon May 12 15:51:43 2008 -0400
@@ -383,12 +383,41 @@
     assert example+example2==example3
     assert have_raised("var['x']+var['x']",x=example)
 
+def test_CacheDataSet():
+    print "test_CacheDataSet"
+    a2 = numpy.random.rand(10,4)
+    ds1 = ArrayDataSet(a2,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested
+    ds2 = CachedDataSet(ds1)
+    ds3 = CachedDataSet(ds1,cache_all_upon_construction=True)
+    assert len(ds2)==10
+
+    test_iterate_over_examples(a2, ds2)
+    test_getitem(a2, ds2)
+
+#     - for val1,val2,val3 in dataset(field1, field2,field3):
+    test_ds_iterator(a2,ds2('x','y'),ds2('y','z'),ds2('x','y','z'))
+
+
+    assert len(ds2.fields())==3
+    for field in ds2.fields():
+        for field_value in field: # iterate over the values associated to that field for all the ds examples
+            pass
+    for field in ds2('x','z').fields():
+        pass
+    for field in ds2.fields('x','y'):
+        pass
+    for field_examples in ds2.fields():
+        for example_value in field_examples:
+            pass
+
+    assert ds2 == ds2.fields().examples()
+#    for ((x,y),a_v) in (ds('x','y'),a): #???don't work # haven't found a variant that work.# will not work
+#        assert numpy.append(x,y)==z
+
+
 def test_ApplyFunctionDataSet():
     print "test_ApplyFunctionDataSet"
     raise NotImplementedError()
-def test_CacheDataSet():
-    print "test_CacheDataSet"
-    raise NotImplementedError()
 def test_FieldsSubsetDataSet():
     print "test_FieldsSubsetDataSet"
     raise NotImplementedError()
@@ -411,4 +440,5 @@
 test1()
 test_LookupList()
 test_ArrayDataSet()
+test_CacheDataSet()
 #test pmat.py