changeset 55:66619ce44497

Efficient implementation of getitem for ArrayDataSet
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Tue, 29 Apr 2008 15:05:12 -0400
parents 718befdc8671
children 1729ad44f175
files dataset.py
diffstat 1 files changed, 23 insertions(+), 8 deletions(-) [+]
line wrap: on
line diff
--- a/dataset.py	Tue Apr 29 14:34:40 2008 -0400
+++ b/dataset.py	Tue Apr 29 15:05:12 2008 -0400
@@ -603,6 +603,7 @@
                 Example(self.fields.keys(),[field[i] for field in self.fields])),self.fields)
         if self.hasFields(i):
             return self.fields[i]
+        assert i in self.__dict__ # else it means we are trying to access a non-existing property
         return self.__dict__[i]
 
     def fieldNames(self):
@@ -874,13 +875,13 @@
     values are (N-2)-dimensional objects (i.e. ordinary numbers if N=2).
     """
 
-    """
-    Construct an ArrayDataSet from the underlying numpy array (data) and
-    a map (fields_columns) from fieldnames to field columns. The columns of a field are specified
-    using the standard arguments for indexing/slicing: integer for a column index,
-    slice for an interval of columns (with possible stride), or iterable of column indices.
-    """
     def __init__(self, data_array, fields_columns):
+        """
+        Construct an ArrayDataSet from the underlying numpy array (data) and
+        a map (fields_columns) from fieldnames to field columns. The columns of a field are specified
+        using the standard arguments for indexing/slicing: integer for a column index,
+        slice for an interval of columns (with possible stride), or iterable of column indices.
+        """
         self.data=data_array
         self.fields_columns=fields_columns
 
@@ -906,8 +907,22 @@
     def __len__(self):
         return len(self.data)
 
-    #def __getitem__(self,i):
-    #    """More efficient implementation than the default"""
+    def __getitem__(self,i):
+        """More efficient implementation than the default __getitem__"""
+        fieldnames=self.fields_columns.keys()
+        if type(i) is int:
+            return Example(fieldnames,
+                           [self.data[i,self.fields_columns[f]] for f in fieldnames])
+        if type(i) in (slice,list):
+            return MinibatchDataSet(Example(fieldnames,
+                                            [self.data[i,self.fields_columns[f]] for f in fieldnames]))
+        # else check for a fieldname
+        if self.hasFields(i):
+            return Example([i],[self.data[self.fields_columns[i],:]])
+        # else we are trying to access a property of the dataset
+        assert i in self.__dict__ # else it means we are trying to access a non-existing property
+        return self.__dict__[i]
+        
             
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
         class ArrayDataSetIterator(object):