changeset 45:a5c70dc42972

Test functions for dataset.py
author bengioy@grenat.iro.umontreal.ca
date Tue, 29 Apr 2008 11:25:36 -0400
parents 5a85fda9b19b
children c5b07e87b0cb
files dataset.py lookup_list.py test_dataset.py
diffstat 3 files changed, 29 insertions(+), 2 deletions(-) [+]
line wrap: on
line diff
--- a/dataset.py	Mon Apr 28 13:52:54 2008 -0400
+++ b/dataset.py	Tue Apr 29 11:25:36 2008 -0400
@@ -4,6 +4,7 @@
 from misc import unique_elements_list_intersection
 from string import join
 from sys import maxint
+import numpy
 
 class AbstractFunction (Exception): """Derived class must override this function"""
 class NotImplementedYet (NotImplementedError): """Work in progress, this should eventually be implemented"""
@@ -376,7 +377,7 @@
             rows = i
         if rows is not None:
             fields_values = zip(*[self[row] for row in rows])
-            return DataSet.MinibatchDataSet(
+            return MinibatchDataSet(
                 Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values)
                                             for fieldname,field_values
                                             in zip(self.fieldNames(),fields_values)]))
--- a/lookup_list.py	Mon Apr 28 13:52:54 2008 -0400
+++ b/lookup_list.py	Tue Apr 29 11:25:36 2008 -0400
@@ -46,9 +46,11 @@
         The key in example[key] can either be an integer to index the fields
         or the name of the field.
         """
-        if isinstance(key,int):
+        if isinstance(key,int) or isinstance(key,slice) or isinstance(key,list):
             return self._values[key]
         else: # if not an int, key must be a name
+            # expecting key to be a valid field name
+            assert isinstance(key,str)
             return self._values[self._name2index[key]]
     
     def __setitem__(self,key,value):
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test_dataset.py	Tue Apr 29 11:25:36 2008 -0400
@@ -0,0 +1,24 @@
+
+from dataset import *
+from math import *
+import numpy
+
+def test1():  # Smoke test: build an ArrayDataSet over random data and print several views of it (prints only, no assertions).
+    global a,ds  # a and ds are bound as module-level globals, not locals
+    a = numpy.random.rand(10,4)  # 10x4 array of uniform random samples in [0, 1)
+    print a
+    ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]})  # maps field names to a slice, an int and a list; presumably column selectors into a -- TODO confirm against ArrayDataSet
+    print "len(ds)=",len(ds)
+    print "example 0 = ",ds[0]  # index the dataset by an integer
+    print "x=",ds["x"]  # index the dataset by a field name
+    print "x|y"
+    for x,y in ds("x","y"):  # call ds with two field names and unpack each iterated item into x and y
+        print x,y
+    minibatch_iterator = ds.minibatches(fieldnames=['z','y'],n_batches=1,minibatch_size=3,offset=4)  # presumably one 3-example batch of fields 'z' and 'y' starting at offset 4 -- confirm against dataset.py
+    minibatch = minibatch_iterator.__iter__().next()  # fetch the first item via the Python 2 iterator protocol (.next())
+    print "minibatch=",minibatch
+    for var in minibatch:  # iterate over the fetched minibatch and print each item
+        print "var=",var
+    print "take a slice:",ds[1:6:2]  # extended slice: start 1, stop 6, step 2
+
+test1()  # NOTE(review): called unconditionally at module level (no __main__ guard), so it also runs on import