diff dataset.py @ 82:158653a9bc7c

Automated merge with ssh://p-omega1@lgcm.iro.umontreal.ca/tlearn
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Mon, 05 May 2008 11:02:03 -0400
parents 3499918faa9d 40476a7746e8
children c0f211213a58
line wrap: on
line diff
--- a/dataset.py	Mon May 05 09:35:30 2008 -0400
+++ b/dataset.py	Mon May 05 11:02:03 2008 -0400
@@ -281,7 +281,7 @@
 
         The minibatches iterator is expected to return upon each call to next()
         a DataSetFields object, which is a LookupList (indexed by the field names) whose
-        elements are iterable over the minibatch examples, and which keeps a pointer to
+        elements are iterable and indexable over the minibatch examples, and which keeps a pointer to
         a sub-dataset that can be used to iterate over the individual examples
         in the minibatch. Hence a minibatch can be converted back to a regular
         dataset or its fields can be looked at individually (and possibly iterated over).
@@ -632,12 +632,13 @@
         return self.length
 
     def __getitem__(self,i):
+        if type(i) in (slice,list):
+            return DataSetFields(MinibatchDataSet(
+                Example(self._fields.keys(),[field[i] for field in self._fields])),self.fieldNames())
         if type(i) is int:
-            return Example(self._fields.keys(),[field[i] for field in self._fields])
-        if type(i) in (slice,list):
-            return MinibatchDataSet(Example(self._fields.keys(),
-                                            [field[i] for field in self._fields]),
-                                    self.valuesVStack,self.valuesHStack)
+            return DataSetFields(MinibatchDataSet(
+                Example(self._fields.keys(),[[field[i]] for field in self._fields])),self.fieldNames())
+
         if self.hasFields(i):
             return self._fields[i]
         assert i in self.__dict__ # else it means we are trying to access a non-existing property
@@ -939,22 +940,29 @@
     def __len__(self):
         return len(self.data)
 
-    def __getitem__(self,i):
+    def __getitem__(self,key):
         """More efficient implementation than the default __getitem__"""
         fieldnames=self.fields_columns.keys()
-        if type(i) is int:
+        if type(key) is int:
             return Example(fieldnames,
-                           [self.data[i,self.fields_columns[f]] for f in fieldnames])
-        if type(i) in (slice,list):
+                           [self.data[key,self.fields_columns[f]] for f in fieldnames])
+        if type(key) is slice:
             return MinibatchDataSet(Example(fieldnames,
-                                            [self.data[i,self.fields_columns[f]] for f in fieldnames]),
+                                            [self.data[key,self.fields_columns[f]] for f in fieldnames]))
+        if type(key) is list:
+            for i in range(len(key)):
+                if self.hasFields(key[i]):
+                    key[i]=self.fields_columns[key[i]]
+            return MinibatchDataSet(Example(fieldnames,
+                                            [self.data[key,self.fields_columns[f]] for f in fieldnames]),
                                     self.valuesVStack,self.valuesHStack)
+
         # else check for a fieldname
-        if self.hasFields(i):
-            return Example([i],[self.data[self.fields_columns[i],:]])
+        if self.hasFields(key):
+            return self.data[self.fields_columns[key],:]
         # else we are trying to access a property of the dataset
-        assert i in self.__dict__ # else it means we are trying to access a non-existing property
-        return self.__dict__[i]
+        assert key in self.__dict__ # else it means we are trying to access a non-existing property
+        return self.__dict__[key]
         
             
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):