diff dataset.py @ 12:ff4e551490f1

Added LookupList type in lookup_list.py and used it to keep order of field names in Example in ArrayDataSet. Example is now just an alias for LookupList (Example = LookupList).
author bengioy@esprit.iro.umontreal.ca
date Wed, 26 Mar 2008 18:21:57 -0400
parents be128b9127c8
children 759d17112b23
line wrap: on
line diff
--- a/dataset.py	Wed Mar 26 15:01:30 2008 -0400
+++ b/dataset.py	Wed Mar 26 18:21:57 2008 -0400
@@ -1,42 +1,7 @@
 
-class Example(object):
-    """
-    An example is something that is like a tuple but whose elements can be named, to that
-    following syntactic constructions work as one would expect:
-       example.x = [1, 2, 3] # set a field
-       x, y, z = example
-       x = example[0]
-       x = example["x"]
-    """
-    def __init__(self,names,values):
-        assert len(values)==len(names)
-        self.__dict__['values']=values
-        self.__dict__['fields']={}
-        for i in xrange(len(values)):
-            self.fields[names[i]]=i
-            
-    def __getitem__(self,i):
-        if isinstance(i,int):
-            return self.values[i]
-        else:
-            return self.values[self.fields[i]]
-    
-    def __setitem__(self,i,value):
-        if isinstance(i,int):
-            self.values[i]=value
-        else:
-            self.values[self.fields[i]]=value
-
-    def __getattr__(self,name):
-        return self.values[self.fields[name]]
-
-    def __setattr__(self,name,value):
-        self.values[self.fields[name]]=value
-
-    def __len__(self):
-        return len(self.values)
-
-    
+from lookup_list import LookupList
+Example = LookupList
+        
 class DataSet(object):
     """
     This is a virtual base class or interface for datasets.
@@ -192,15 +157,15 @@
     by the numpy.array(dataset) call.
     """
 
-    def __init__(self,dataset=None,data=None,fields={}):
+    def __init__(self,dataset=None,data=None,fields=None):
         """
         There are two ways to construct an ArrayDataSet: (1) from an
         existing dataset (which may result in a copy of the data in a numpy array),
         or (2) from a numpy.array (the data argument), along with an optional description
-        of the fields (dictionary of column slices indexed by field names).
+        of the fields (a LookupList of column slices indexed by field names).
         """
         if dataset!=None:
-            assert data==None and fields=={}
+            assert data==None and fields==None
             # Make ONE big minibatch with all the examples, to separate the fields.
             n_examples=len(dataset)
             batch = dataset.minibatches(n_examples).next()
@@ -210,6 +175,7 @@
             fieldnames = batch.fields.keys()
             total_width = 0
             type = None
+            fields = LookupList()
             for i in xrange(n_fields):
                 field = array(batch[i])
                 assert field.shape[0]==n_examples
@@ -227,19 +193,19 @@
             self.data=data
             self.fields=fields
             self.width = data.shape[1]
-            for fieldname in fields:
-                fieldslice=fields[fieldname]
-                # make sure fieldslice.start and fieldslice.step are defined
-                start=fieldslice.start
-                step=fieldslice.step
-                if not start:
-                    start=0
-                if not step:
-                    step=1
-                if not fieldslice.start or not fieldslice.step:
-                    fields[fieldname] = fieldslice = slice(start,fieldslice.stop,step)
-                # and coherent with the data array
-                assert fieldslice.start>=0 and fieldslice.stop<=self.width
+            if fields:
+                for fieldname,fieldslice in fields.items():
+                    # make sure fieldslice.start and fieldslice.step are defined
+                    start=fieldslice.start
+                    step=fieldslice.step
+                    if not start:
+                        start=0
+                    if not step:
+                        step=1
+                    if not fieldslice.start or not fieldslice.step:
+                        fields[fieldname] = fieldslice = slice(start,fieldslice.stop,step)
+                    # and coherent with the data array
+                    assert fieldslice.start>=0 and fieldslice.stop<=self.width
 
     def __getattr__(self,fieldname):
         """
@@ -258,9 +224,9 @@
         for field_slice in self.fields.values():
             min_col=min(min_col,field_slice.start)
             max_col=max(max_col,field_slice.stop)
-        new_fields={}
-        for field in self.fields:
-            new_fields[field[0]]=slice(field[1].start-min_col,field[1].stop-min_col,field[1].step)
+        new_fields=LookupList()
+        for fieldname,fieldslice in self.fields.items():
+            new_fields[fieldname]=slice(fieldslice.start-min_col,fieldslice.stop-min_col,fieldslice.step)
         return ArrayDataSet(data=self.data[:,min_col:max_col],fields=new_fields)
 
     def fieldNames(self):
@@ -278,7 +244,7 @@
         """
         if self.fields:
             fieldnames,fieldslices=zip(*self.fields.items())
-            return Example(fieldnames,[self.data[i,fieldslice] for fieldslice in fieldslices])
+            return Example(self.fields.keys(),[self.data[i,fieldslice] for fieldslice in self.fields.values()])
         else:
             return self.data[i]