changeset 321:f03ae06fadc8

NArraysDataSet improved, use arrays instead of matrix, also a dictionnary of field indexes
author Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
date Thu, 12 Jun 2008 12:35:47 -0400
parents b1da46b9b901
children ad8be93b3c55
files dataset.py
diffstat 1 files changed, 25 insertions(+), 15 deletions(-) [+]
line wrap: on
line diff
--- a/dataset.py	Wed Jun 11 17:11:37 2008 -0400
+++ b/dataset.py	Thu Jun 12 12:35:47 2008 -0400
@@ -192,6 +192,12 @@
                 else description
         self._attribute_names = ["description"]
 
+        # create dictionnary of fieldnames index
+        self.map_field_idx = dict()
+        for k in len(range(self.fieldNames())):
+            map_field_idx[ self.fieldNames[k] ] = k
+
+
     attributeNames = property(lambda self: copy.copy(self._attribute_names))
 
     def __contains__(self, fieldname):
@@ -1054,24 +1060,23 @@
         fieldnames. Each set of numpy tensor must have the same first dimension (first
         axis) corresponding to the number of examples.
 
-        Every tensor is treated as a numpy matrix (using numpy.asmatrix)
+        Every tensor is treated as a numpy array (using numpy.asarray)
         """
         ArrayFieldsDataSet.__init__(self,**kwargs)
         assert len(data_arrays) == len(fieldnames)
         assert len(fieldnames) > 0
-        all_matrix_sizes = map(lambda x : numpy.asmatrix(x).shape[0] , data_arrays)
-        num_examples = max(all_matrix_sizes)
-        if num_examples == 1 :
-            # problem, do we transpose all arrays? is there only one example?
-            raise Exception("wrong initialization, unknow behaviour with 1-d arrays")
+        ndarrays = [numpy.ndarray(a) for a in data_arrays]
+        lens = [a.shape[0] for a in ndarrays]
+        num_examples = lens[0] #they must all be equal anyway
         self._fieldnames = fieldnames
         self._datas = []
-        for k in range(len(data_arrays)) :
-            self._datas.append( numpy.asmatrix(data_arrays[k]) )
-            if self._datas[-1].shape[0] == 1 and self._datas[-1].shape[1] == num_examples :
-                self._datas[-1] = self._datas[-1].transpose()
-        for k in range(len(self._datas)) :
-            assert self._datas[k].shape[0] == num_examples
+        for k in self.ndarrays :
+            assert k.shape[0] == num_examples
+        self._datas = ndarrays
+        # create dict 
+        self.map_field_idx = dict()
+        for k in range(len(fieldnames)):
+            self.map_field_idx[fieldnames[k]] = k
 
 
     def __len__(self) :
@@ -1086,6 +1091,12 @@
         """
         return self._fieldnames
 
+    def field_pos(self,fieldname) :
+        """
+        Returns the index of a given fieldname. Fieldname must exists! see fieldNames().
+        """
+        return self.map_field_idx[fieldname]
+
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
         cursor = Example(fieldnames,[0]*len(fieldnames))
         fieldnames = self.fieldNames() if fieldnames is None else fieldnames
@@ -1093,9 +1104,8 @@
             if offset == len(self):
                 break
             for f in range(len(cursor._names)) :
-                idx = self._fieldnames.index(cursor._names[f])
-                assert idx >= 0
-                sub_data = self._datas[f][offset : offset+minibatch_size]
+                idx = self.field_pos(cursor._names[f])
+                sub_data = self._datas[idx][offset : offset+minibatch_size]
                 cursor._values[f] = sub_data
             offset += len(sub_data) #can be less than minibatch_size at end
             yield cursor