diff dataset.py @ 245:c702abb7f875

merged
author James Bergstra <bergstrj@iro.umontreal.ca>
date Mon, 02 Jun 2008 17:09:58 -0400
parents c8f19a9eb10f
children 7e6edee187e3 4ad6bc9b4f03
line wrap: on
line diff
--- a/dataset.py	Mon Jun 02 17:08:17 2008 -0400
+++ b/dataset.py	Mon Jun 02 17:09:58 2008 -0400
@@ -47,14 +47,14 @@
     columns/attributes are called fields. The field value for a particular example can be an arbitrary
     python object, which depends on the particular dataset.
     
-    We call a DataSet a 'stream' when its length is unbounded (otherwise its __len__ method
+    We call a DataSet a 'stream' when its length is unbounded (in which case its __len__ method
     should return sys.maxint).
 
     A DataSet is a generator of iterators; these iterators can run through the
     examples or the fields in a variety of ways.  A DataSet need not necessarily have a finite
     or known length, so this class can be used to interface to a 'stream' which
     feeds on-line learning (however, as noted below, some operations are not
-    feasible or not recommanded on streams).
+    feasible or not recommended on streams).
 
     To iterate over examples, there are several possibilities:
      - for example in dataset:
@@ -81,7 +81,7 @@
      - for field_examples in dataset.fields():
         for example_value in field_examples:
            ...
-    but when the dataset is a stream (unbounded length), it is not recommanded to do 
+    but when the dataset is a stream (unbounded length), it is not recommended to do 
     such things because the underlying dataset may refuse to access the different fields in
     an unsynchronized ways. Hence the fields() method is illegal for streams, by default.
     The result of fields() is a L{DataSetFields} object, which iterates over fields,
@@ -599,7 +599,7 @@
     * for field_examples in dataset.fields():
         for example_value in field_examples:
            ...
-    but when the dataset is a stream (unbounded length), it is not recommanded to do 
+    but when the dataset is a stream (unbounded length), it is not recommended to do 
     such things because the underlying dataset may refuse to access the different fields in
     an unsynchronized ways. Hence the fields() method is illegal for streams, by default.
     The result of fields() is a DataSetFields object, which iterates over fields,
@@ -1016,12 +1016,13 @@
     def __getitem__(self,key):
         """More efficient implementation than the default __getitem__"""
         fieldnames=self.fields_columns.keys()
+        values=self.fields_columns.values()
         if type(key) is int:
             return Example(fieldnames,
-                           [self.data[key,self.fields_columns[f]] for f in fieldnames])
+                           [self.data[key,col] for col in values])
         if type(key) is slice:
             return MinibatchDataSet(Example(fieldnames,
-                                            [self.data[key,self.fields_columns[f]] for f in fieldnames]))
+                                            [self.data[key,col] for col in values]))
         if type(key) is list:
             for i in range(len(key)):
                 if self.hasFields(key[i]):
@@ -1030,9 +1031,10 @@
                                             #we must separate differently for list as numpy
                                             # doesn't support self.data[[i1,...],[i2,...]]
                                             # when their is more then two i1 and i2
-                                            [self.data[key,:][:,self.fields_columns[f]]
-                                             if isinstance(self.fields_columns[f],list) else
-                                             self.data[key,self.fields_columns[f]] for f in fieldnames]),
+                                            [self.data[key,:][:,col]
+                                             if isinstance(col,list) else
+                                             self.data[key,col] for col in values]),
+
 
                                     self.valuesVStack,self.valuesHStack)
 
@@ -1054,6 +1056,8 @@
                 assert offset>=0 and offset<len(dataset.data)
                 assert offset+minibatch_size<=len(dataset.data)
                 self.current=offset
+                self.columns = [self.dataset.fields_columns[f] 
+                                for f in self.minibatch._names]
             def __iter__(self):
                 return self
             def next(self):
@@ -1062,7 +1066,8 @@
                 if self.current>=self.dataset.data.shape[0]:
                     raise StopIteration
                 sub_data =  self.dataset.data[self.current]
-                self.minibatch._values = [sub_data[self.dataset.fields_columns[f]] for f in self.minibatch._names]
+                self.minibatch._values = [sub_data[c] for c in self.columns]
+
                 self.current+=self.minibatch_size
                 return self.minibatch