changeset 4:f7dcfb5f9d5b

Added test for dataset.
author bengioy@bengiomac.local
date Sun, 23 Mar 2008 22:14:10 -0400
parents 378b68d5c4ad
children 8039918516fe
files _test_dataset.py dataset.py
diffstat 2 files changed, 54 insertions(+), 15 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/_test_dataset.py	Sun Mar 23 22:14:10 2008 -0400
@@ -0,0 +1,25 @@
+from dataset import *
+from math import *
+import unittest
+
+def _sum_all(a):
+    """Apply sum() to `a` repeatedly until the result is no longer a numpy.ndarray."""
+    if not isinstance(a,numpy.ndarray):
+        return a
+    return _sum_all(sum(a))
+    
+class T_arraydataset(unittest.TestCase):
+    """Checks iteration over an ArrayDataSet and access to one of its fields."""
+    def setUp(self):
+        numpy.random.seed(123456) # fixed seed: test0 compares against a hard-coded sum
+
+    def test0(self):
+        a=ArrayDataSet(data=numpy.random.rand(8,3),fields={"x":slice(2),"y":slice(1,3)})
+        s=0
+        for example in a:
+            s+=_sum_all(example.x)
+        self.failUnless(abs(s-11.4674133)<1e-6,"sum over field x was %r"%(s,))
+
+if __name__ == '__main__': # run the test cases when this file is executed as a script
+    unittest.main()
+    
--- a/dataset.py	Sun Mar 23 14:41:22 2008 -0400
+++ b/dataset.py	Sun Mar 23 22:14:10 2008 -0400
@@ -58,7 +58,7 @@
 
 # we may want ArrayDataSet defined in another python file
 
-from numpy import *
+import numpy
 
 class ArrayDataSet(FiniteDataSet):
     """
@@ -70,41 +70,57 @@
     by the asarray(dataset) call.
     """
 
-    def __self__(self,dataset=None,data=None,fields={}):
+    def __init__(self,dataset=None,data=None,fields={}):
         """
         Construct an ArrayDataSet, either from a DataSet, or from
         a numpy.array plus an optional specification of fields (by
         a dictionary of column slices indexed by field names).
         """
         self.current_row=-1 # used for view of this dataset as an iterator
-        if dataset:
+        if dataset is not None:
             assert data==None and fields=={}
             # convert dataset to an ArrayDataSet
             raise NotImplementedError
-        if data:
+        if data is not None: # identity test: != on a numpy array compares elementwise
             assert dataset==None
             self.data=data
             self.fields=fields
             self.width = data.shape[1]
             for fieldname in fields:
                 fieldslice=fields[fieldname]
-                assert fieldslice.start>=0 and fieldslice.stop<=width)
+                # fill in a missing start/step so later code can use them as plain integers
+                start=fieldslice.start
+                if not start:
+                    start=0
+                step=fieldslice.step
+                if not step:
+                    step=1
+                # store the result (self.fields is the fields dict): rebinding only the local would discard it
+                fieldslice=self.fields[fieldname]=slice(start,fieldslice.stop,step)
+                # and coherent with the data array
+                assert fieldslice.start>=0 and fieldslice.stop<=self.width
 
     def next(self):
-        """Return the next example in the dataset. If the dataset has fields,
-        the 'example' is just a one-row ArrayDataSet, otherwise it is a numpy.array."""
-        if fields:
-            self.current_row+=1
-            if self.current_row==len(self.data):
-                self.current_row=0
+        """
+        Return the next example in the dataset (a one-row ArrayDataSet if it has
+        fields, otherwise a numpy.array); raise StopIteration after the last row.
+        """
+        self.current_row+=1 # advance whether or not the dataset has fields
+        if self.current_row==len(self.data):
+            self.current_row=-1 # rewind so the dataset can be iterated again
+            raise StopIteration
+        if self.fields:
             return self[self.current_row]
         else:
             return self.data[self.current_row]
 
     def __getattr__(self,fieldname):
         """Return a sub-dataset containing only the given fieldname as field."""
-        data = self.fields[fieldname]
-        return ArrayDataSet(data=data)
+        if fieldname not in (self.__dict__.get("fields") or ()): # via __dict__: self.fields would re-enter __getattr__ when unset
+            raise AttributeError(fieldname) # unknown names must not surface as KeyError
+        data=self.data[self.fields[fieldname]] # NOTE(review): this indexes rows although fields are described as column slices -- confirm
+        if len(data)==1: return data # a single row is returned unwrapped
+        return ArrayDataSet(data=data)
 
     def __call__(self,*fieldnames):
         """Return a sub-dataset containing only the given fieldnames as fields."""
@@ -144,7 +160,7 @@
 
     def asarray(self):
         if self.fields:
-            columns_used = zeros((self.data.shape[1]),dtype=bool)
+            columns_used = numpy.zeros((self.data.shape[1]),dtype=bool)
             for field_slice in self.fields.values():
                 for c in xrange(field_slice.start,field_slice.stop,field_slice.step):
                     columns_used[c]=True
@@ -175,9 +191,7 @@
             result = zeros((len(self.data),n_columns)+self.data.shape[2:],self.data.dtype)
             c=0
             for field_slice in self.fields.values():
-               slice_width=field_slice.stop-field_slice.start
-               if field_slice.step:
-                   slice_width /= field_slice.step
+               slice_width=field_slice.stop-field_slice.start/field_slice.step
                # copy the field here
                result[:,slice(c,slice_width)]=self.data[field_slice]
                c+=slice_width