diff dataset.py @ 376:c9a89be5cb0a

Redesigning linear_regression
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Mon, 07 Jul 2008 10:08:35 -0400
parents 18702ceb2096
children 835830e52b42
--- a/dataset.py	Mon Jun 16 17:47:36 2008 -0400
+++ b/dataset.py	Mon Jul 07 10:08:35 2008 -0400
@@ -1,6 +1,6 @@
 
 from lookup_list import LookupList as Example
-from misc import unique_elements_list_intersection
+from common.misc import unique_elements_list_intersection
 from string import join
 from sys import maxint
 import numpy, copy
@@ -381,7 +381,8 @@
         any other object that supports integer indexing and slicing.
 
         @ATTENTION: now minibatches returns minibatches_nowrap, which is supposed to return complete
-        batches only, raise StopIteration
+        batches only, raising StopIteration otherwise.
+        @ATTENTION: minibatches returns a LookupList, so we cannot iterate over individual examples with it.
 
         """
         #return DataSet.MinibatchWrapAroundIterator(self,fieldnames,minibatch_size,n_batches,offset)\
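
The two @ATTENTION notes above describe a contract that is easy to get wrong:
the iterator yields complete minibatches only, and stops as soon as a full
batch can no longer be formed. A minimal, self-contained sketch of that
contract (plain Python, not the pylearn API):

    # Sketch of the "complete batches only" behaviour: any trailing
    # partial batch is dropped rather than yielded.
    def complete_minibatches(data, minibatch_size):
        upper = minibatch_size
        while upper <= len(data):
            yield data[upper - minibatch_size:upper]
            upper += minibatch_size

    list(complete_minibatches(list(range(10)), 4))
    # -> [[0, 1, 2, 3], [4, 5, 6, 7]]; the last two elements never appear
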
@@ -435,6 +436,16 @@
         Return a dataset that sees only the fields whose name are specified.
         """
         assert self.hasFields(*fieldnames)
+        #return self.fields(*fieldnames).examples()
+        fieldnames_list = list(fieldnames)
+        return FieldsSubsetDataSet(self,fieldnames_list)
+
+    def cached_fields_subset(self,*fieldnames):
+        """
+        The behaviour is intended to be the same as __call__(*fieldnames), but the returned dataset is cached.
+        @see: dataset.__call__
+        """
+        assert self.hasFields(*fieldnames)
         return self.fields(*fieldnames).examples()
 
     def fields(self,*fieldnames):
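
__call__ now returns a FieldsSubsetDataSet view over the source instead of
materializing the selected fields, while cached_fields_subset keeps the old
materializing path. A hypothetical sketch of that view-versus-copy
distinction (ColumnView and the dict-of-lists dataset are illustrative, not
pylearn classes):

    # A field-subset "view": reads go through to the parent, nothing
    # is copied up front.
    class ColumnView:
        def __init__(self, parent, fieldnames):
            assert all(f in parent for f in fieldnames)
            self.parent, self.fieldnames = parent, fieldnames
        def __getitem__(self, i):
            return [self.parent[f][i] for f in self.fieldnames]

    data = {'x': [1, 2, 3], 'y': [4, 5, 6], 'z': [7, 8, 9]}
    view = ColumnView(data, ['x', 'z'])
    view[1]   # -> [2, 8], fetched from the parent on demand
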
@@ -692,6 +703,7 @@
         assert len(src_fieldnames)==len(new_fieldnames)
         self.valuesHStack = src.valuesHStack
         self.valuesVStack = src.valuesVStack
+        self.lookup_fields = Example(new_fieldnames,src_fieldnames)
 
     def __len__(self): return len(self.src)
     
@@ -719,9 +731,18 @@
 
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
         assert self.hasFields(*fieldnames)
-        return self.src.minibatches_nowrap(fieldnames,minibatch_size,n_batches,offset)
+        cursor = Example(fieldnames,[0]*len(fieldnames))
+        for batch in self.src.minibatches_nowrap([self.lookup_fields[f] for f in fieldnames],minibatch_size,n_batches,offset):
+            cursor._values=batch._values
+            yield cursor
+    
     def __getitem__(self,i):
-        return FieldsSubsetDataSet(self.src[i],self.new_fieldnames)
+#        return FieldsSubsetDataSet(self.src[i],self.new_fieldnames)
+        complete_example = self.src[i]
+        return Example(self.new_fieldnames,
+                             [complete_example[field]
+                              for field in self.src_fieldnames])
+
 
 
 class DataSetFields(Example):
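
One property of the rewritten minibatches_nowrap above deserves a caveat: it
mutates a single cursor object and yields it repeatedly, so every batch the
consumer has seen aliases the same Example. A self-contained illustration of
the pattern, with a plain dict standing in for the cursor:

    # The generator re-yields one shared object; consumers that keep a
    # batch past the next iteration must copy it first.
    def batches_with_shared_cursor(values, size):
        cursor = {'batch': None}
        for start in range(0, len(values) - size + 1, size):
            cursor['batch'] = values[start:start + size]
            yield cursor                  # same dict object every time

    kept = list(batches_with_shared_cursor(list(range(6)), 2))
    [b['batch'] for b in kept]            # -> [[4, 5], [4, 5], [4, 5]]
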
@@ -859,7 +880,9 @@
                 return self
             def next(self):
                 upper = self.next_example+minibatch_size
-                assert upper<=self.ds.length
+                if upper > len(self.ds):
+                    raise StopIteration()
+                assert upper<=len(self.ds) # instead of self.ds.length
                 #minibatch = Example(self.ds._fields.keys(),
                 #                    [field[self.next_example:upper]
                 #                     for field in self.ds._fields])
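
The change above replaces a hard assertion failure at the end of the data
with a clean StopIteration, which is what the for-statement protocol
expects. A minimal sketch of the fixed iterator (Python 3 spelling; the
original defines Python 2's next(self)):

    # End iteration cleanly when a full minibatch no longer fits,
    # instead of failing an assert.
    class MinibatchIterator:
        def __init__(self, data, size):
            self.data, self.size, self.pos = data, size, 0
        def __iter__(self):
            return self
        def __next__(self):
            upper = self.pos + self.size
            if upper > len(self.data):
                raise StopIteration
            batch = self.data[self.pos:upper]
            self.pos = upper
            return batch

    list(MinibatchIterator(list(range(7)), 3))   # -> [[0, 1, 2], [3, 4, 5]]
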
@@ -1314,7 +1337,10 @@
           # into memory at once, which may be too much
           # the work could possibly be done by minibatches
           # that are as large as possible but no more than what memory allows.
-          fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next()
+          #
+          # fields_values is supposed to be a DataSetFields, which inherits from LookupList
+          #fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next()
+          fields_values = DataSetFields(source_dataset,None)
           assert all([len(self)==len(field_values) for field_values in fields_values])
           for example in fields_values.examples():
               self.cached_examples.append(copy.copy(example))
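
The pre-existing comment notes that loading the whole dataset as one
minibatch may not fit in memory, and that the work could instead be done in
minibatch-sized pieces. A sketch of that chunked alternative
(cache_in_chunks is hypothetical, not part of this module):

    # Same end result as one giant minibatch, but peak memory while
    # reading is bounded by chunk_size examples at a time.
    def cache_in_chunks(source, chunk_size):
        cache = []
        for start in range(0, len(source), chunk_size):
            cache.extend(source[start:start + chunk_size])
        return cache
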
@@ -1333,16 +1359,25 @@
               self.dataset=dataset
               self.current=offset
               self.all_fields = self.dataset.fieldNames()==fieldnames
+              self.n_batches = n_batches
+              self.batch_counter = 0
           def __iter__(self): return self
           def next(self):
+              self.batch_counter += 1
+              if self.n_batches and self.batch_counter > self.n_batches:
+                  raise StopIteration()
               upper = self.current+minibatch_size
+              if upper > len(self.dataset.source_dataset):
+                  raise StopIteration()
               cache_len = len(self.dataset.cached_examples)
               if upper>cache_len: # whole minibatch is not already in cache
                   # cache everything from current length to upper
-                  for example in self.dataset.source_dataset[cache_len:upper]:
+                  #for example in self.dataset.source_dataset[cache_len:upper]:
+                  for example in self.dataset.source_dataset.subset[cache_len:upper]:
                       self.dataset.cached_examples.append(example)
               all_fields_minibatch = Example(self.dataset.fieldNames(),
                                              zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size]))
+
               self.current+=minibatch_size
               if self.all_fields:
                   return all_fields_minibatch
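
Taken together, the iterator changes in this last hunk add two stopping
guards (one for an explicit n_batches limit, one for running off the end of
the source) and grow the example cache lazily, only as far as the current
batch needs. A compact, self-contained sketch of the combined behaviour
(cached_minibatches is illustrative only):

    def cached_minibatches(source, size, n_batches=None):
        cache = []
        served = 0
        current = 0
        while True:
            served += 1
            if n_batches and served > n_batches:
                return                         # hit the explicit batch limit
            upper = current + size
            if upper > len(source):
                return                         # no full batch left
            if upper > len(cache):
                cache.extend(source[len(cache):upper])  # fill cache lazily
            yield cache[current:upper]
            current = upper

    list(cached_minibatches(list(range(10)), 4, n_batches=2))
    # -> [[0, 1, 2, 3], [4, 5, 6, 7]]
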