diff dataset.py @ 80:40476a7746e8

bugfix
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Mon, 05 May 2008 10:56:58 -0400
parents dde1fb1b63ba
children 158653a9bc7c
line wrap: on
line diff
--- a/dataset.py	Mon May 05 10:28:58 2008 -0400
+++ b/dataset.py	Mon May 05 10:56:58 2008 -0400
@@ -259,7 +259,7 @@
 
         The minibatches iterator is expected to return upon each call to next()
         a DataSetFields object, which is a LookupList (indexed by the field names) whose
-        elements are iterable over the minibatch examples, and which keeps a pointer to
+        elements are iterable and indexable over the minibatch examples, and which keeps a pointer to
         a sub-dataset that can be used to iterate over the individual examples
         in the minibatch. Hence a minibatch can be converted back to a regular
         dataset or its fields can be looked at individually (and possibly iterated over).
@@ -609,9 +609,13 @@
         return self.length
 
     def __getitem__(self,i):
-        if type(i) in (int,slice,list):
+        if type(i) in (slice,list):
             return DataSetFields(MinibatchDataSet(
-                Example(self._fields.keys(),[field[i] for field in self._fields])),self._fields)
+                Example(self._fields.keys(),[field[i] for field in self._fields])),self.fieldNames())
+        if type(i) is int:
+            return DataSetFields(MinibatchDataSet(
+                Example(self._fields.keys(),[[field[i]] for field in self._fields])),self.fieldNames())
+
         if self.hasFields(i):
             return self._fields[i]
         assert i in self.__dict__ # else it means we are trying to access a non-existing property
@@ -918,21 +922,28 @@
     def __len__(self):
         return len(self.data)
 
-    def __getitem__(self,i):
+    def __getitem__(self,key):
         """More efficient implementation than the default __getitem__"""
         fieldnames=self.fields_columns.keys()
-        if type(i) is int:
+        if type(key) is int:
             return Example(fieldnames,
-                           [self.data[i,self.fields_columns[f]] for f in fieldnames])
-        if type(i) in (slice,list):
+                           [self.data[key,self.fields_columns[f]] for f in fieldnames])
+        if type(key) is slice:
             return MinibatchDataSet(Example(fieldnames,
-                                            [self.data[i,self.fields_columns[f]] for f in fieldnames]))
+                                            [self.data[key,self.fields_columns[f]] for f in fieldnames]))
+        if type(key) is list:
+            for i in range(len(key)):
+                if self.hasFields(key[i]):
+                    key[i]=self.fields_columns[key[i]]
+            return MinibatchDataSet(Example(fieldnames,
+                                            [self.data[key,self.fields_columns[f]] for f in fieldnames]))
+
         # else check for a fieldname
-        if self.hasFields(i):
-            return Example([i],[self.data[self.fields_columns[i],:]])
+        if self.hasFields(key):
+            return self.data[self.fields_columns[key],:]
         # else we are trying to access a property of the dataset
-        assert i in self.__dict__ # else it means we are trying to access a non-existing property
-        return self.__dict__[i]
+        assert key in self.__dict__ # else it means we are trying to access a non-existing property
+        return self.__dict__[key]
         
             
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):