diff dataset.py @ 316:5fe6d0c93109

getitem in ArrayDataSet is set up again, supposed to be faster than default one, has been tested agains the default behaviour. In particular, now always return a LookupList
author Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
date Wed, 11 Jun 2008 16:28:09 -0400
parents 105b54ac8260
children 14081904d8f3
line wrap: on
line diff
--- a/dataset.py	Wed Jun 11 16:26:41 2008 -0400
+++ b/dataset.py	Wed Jun 11 16:28:09 2008 -0400
@@ -1028,6 +1028,50 @@
         become a two-row matrix, two matrices become a longer matrix, etc."""
         return numpy.vstack(values)
 
+
+
+class NArraysDataSet(ArrayFieldsDataSet) :
+    """
+    An NArraysDataSet stores fields that are numpy tensor, whose first axis
+    iterates over examples. It's a generalization of ArrayDataSet.
+    """
+    #@TODO not completely implemented yet
+    def __init__(self, data_arrays, fieldnames, **kwargs) :
+        """
+        Construct an NArraysDataSet from a list of numpy tensor (data_arrays) and a list
+        of fieldnames. The number of arrays must be the same as the number of
+        fieldnames. Each set of numpy tensor must have the same first dimension (first
+        axis) corresponding to the number of examples.
+
+        Every tensor is treated as a numpy matrix (using numpy.asmatrix)
+        """
+        ArrayFieldsDataSet.__init(self,**kwargs)
+        assert len(data_arrays) == len(fieldnames)
+        assert len(fieldnames) > 0
+        num_examples = numpy.asmatrix(data_arrays[0]).shape[0]
+        for k in range(data_arrays) :
+            assert numpy.asmatrix(data_arrays[k]).shape[0] == num_examples
+        self._fieldnames = fieldnames
+        self._datas = []
+        for k in range(data_arrays) :
+            self._datas.append( numpy.asmatrix(data_arrays[k]) )
+        raise NotImplemented
+
+
+    def __len__(self) :
+        """
+        Length of the dataset is based on the first array = data_arrays[0], using its shape
+        """
+        return self.datas[0].shape[0]
+
+    def fieldNames(self) :
+        """
+        Returns the fieldnames as set in self.__init__
+        """
+        return self._fieldnames
+
+
+
 class ArrayDataSet(ArrayFieldsDataSet):
     """
     An ArrayDataSet stores the fields as groups of columns in a numpy tensor,
@@ -1077,7 +1121,7 @@
     def __len__(self):
         return len(self.data)
 
-    def dontuse__getitem__(self,key):
+    def __getitem__(self,key):
         """More efficient implementation than the default __getitem__"""
         fieldnames=self.fields_columns.keys()
         values=self.fields_columns.values()
@@ -1085,22 +1129,18 @@
             return Example(fieldnames,
                            [self.data[key,col] for col in values])
         if type(key) is slice:
-            return MinibatchDataSet(Example(fieldnames,
-                                            [self.data[key,col] for col in values]))
+            return Example(fieldnames,[self.data[key,col] for col in values])
         if type(key) is list:
             for i in range(len(key)):
                 if self.hasFields(key[i]):
                     key[i]=self.fields_columns[key[i]]
-            return MinibatchDataSet(Example(fieldnames,
-                                            #we must separate differently for list as numpy
-                                            # doesn't support self.data[[i1,...],[i2,...]]
-                                            # when their is more then two i1 and i2
-                                            [self.data[key,:][:,col]
-                                             if isinstance(col,list) else
-                                             self.data[key,col] for col in values]),
-
-
-                                    self.valuesVStack,self.valuesHStack)
+            return Example(fieldnames,
+                               #we must separate differently for list as numpy
+                               # doesn't support self.data[[i1,...],[i2,...]]
+                               # when their is more then two i1 and i2
+                               [self.data[key,:][:,col]
+                               if isinstance(col,list) else
+                               self.data[key,col] for col in values])
 
         # else check for a fieldname
         if self.hasFields(key):