diff dataset.py @ 293:4bfdda107a17

still merging
author James Bergstra <bergstrj@iro.umontreal.ca>
date Fri, 06 Jun 2008 16:13:17 -0400
parents 174374d59405
children f5d33f9c0b9c
line wrap: on
line diff
--- a/dataset.py	Fri Jun 06 15:56:18 2008 -0400
+++ b/dataset.py	Fri Jun 06 16:13:17 2008 -0400
@@ -456,7 +456,7 @@
         if type(i) is int:
             #TODO: consider asserting that i >= 0
             i_batch = self.minibatches_nowrap(self.fieldNames(),
-                    minibatch_size=1, n_batches=1, offset=i % len(self))
+                    minibatch_size=1, n_batches=1, offset=i)
             return DataSet.MinibatchToSingleExampleIterator(i_batch).next()
 
         #if i is a contiguous slice
@@ -483,7 +483,7 @@
                     raise TypeError(idx)
             # call back into self.__getitem__
             examples = [self.minibatches_nowrap(self.fieldNames(),
-                    minibatch_size=1, n_batches=1, offset=ii%len(self)).next()
+                    minibatch_size=1, n_batches=1, offset=ii).next()
                     for ii in i]
             # re-index the fields in each example by field instead of by example
             field_values = [[] for blah in  self.fieldNames()]
@@ -1253,26 +1253,28 @@
         return self.output_names
 
     def minibatches_nowrap(self, fieldnames, *args, **kwargs):
-        for fields in self.input_dataset.minibatches_nowrap(fieldnames, *args, **kwargs):
+        for input_fields in self.input_dataset.minibatches_nowrap(fieldnames, *args, **kwargs):
 
             #function_inputs = self.input_iterator.next()
             if self.minibatch_mode:
-                function_outputs = self.function(*fields)
+                function_outputs = self.function(*input_fields)
             else:
-                input_examples = zip(*fields)
+                input_examples = zip(*input_fields)
                 output_examples = [self.function(*input_example)
                                     for input_example in input_examples]
                 function_outputs = [self.valuesVStack(name,values)
                                     for name,values in zip(self.output_names,
                                                            zip(*output_examples))]
             all_outputs = Example(self.output_names, function_outputs)
-            print fields
-            print all_outputs
-            print '--------'
+            print 'input_fields', input_fields
+            print 'all_outputs', all_outputs
             if fieldnames==self.output_names:
-                yield all_outputs
+                rval = all_outputs
             else:
-                yield Example(fieldnames,[all_outputs[name] for name in fieldnames])
+                rval = Example(fieldnames,[all_outputs[name] for name in fieldnames])
+            print 'rval', rval
+            print '--------'
+            yield rval
 
     def untested__iter__(self): # only implemented for increased efficiency
         class ApplyFunctionSingleExampleIterator(object):