changeset 293:4bfdda107a17

still merging
author James Bergstra <bergstrj@iro.umontreal.ca>
date Fri, 06 Jun 2008 16:13:17 -0400
parents 174374d59405
children f7924e13e426
files _test_dataset.py dataset.py
diffstat 2 files changed, 18 insertions(+), 12 deletions(-) [+]
line wrap: on
line diff
--- a/_test_dataset.py	Fri Jun 06 15:56:18 2008 -0400
+++ b/_test_dataset.py	Fri Jun 06 16:13:17 2008 -0400
@@ -47,6 +47,11 @@
 #not in doc!!!
     i=0
     for example in range(len(ds)):
+        wanted = array[example][:3]
+        returned = ds[example]['x']
+        if (wanted != returned).all():
+            print 'returned:', returned
+            print 'wanted:', wanted
         assert (ds[example]['x']==array[example][:3]).all()
         assert ds[example]['y']==array[example][3]
         assert (ds[example]['z']==array[example][[0,2]]).all()
@@ -226,8 +231,7 @@
     assert i==m.n_batches*m.minibatch_size
     del x,y,i,id
 
-    #@todo: we can't do minibatch bigger then the size of the dataset???
-    assert have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array)+1,offset=0)
+    assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array)+1,offset=0)
     assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array),offset=0)
 
 def test_ds_iterator(array,iterator1,iterator2,iterator3):
--- a/dataset.py	Fri Jun 06 15:56:18 2008 -0400
+++ b/dataset.py	Fri Jun 06 16:13:17 2008 -0400
@@ -456,7 +456,7 @@
         if type(i) is int:
             #TODO: consider asserting that i >= 0
             i_batch = self.minibatches_nowrap(self.fieldNames(),
-                    minibatch_size=1, n_batches=1, offset=i % len(self))
+                    minibatch_size=1, n_batches=1, offset=i)
             return DataSet.MinibatchToSingleExampleIterator(i_batch).next()
 
         #if i is a contiguous slice
@@ -483,7 +483,7 @@
                     raise TypeError(idx)
             # call back into self.__getitem__
             examples = [self.minibatches_nowrap(self.fieldNames(),
-                    minibatch_size=1, n_batches=1, offset=ii%len(self)).next()
+                    minibatch_size=1, n_batches=1, offset=ii).next()
                     for ii in i]
             # re-index the fields in each example by field instead of by example
             field_values = [[] for blah in  self.fieldNames()]
@@ -1253,26 +1253,28 @@
         return self.output_names
 
     def minibatches_nowrap(self, fieldnames, *args, **kwargs):
-        for fields in self.input_dataset.minibatches_nowrap(fieldnames, *args, **kwargs):
+        for input_fields in self.input_dataset.minibatches_nowrap(fieldnames, *args, **kwargs):
 
             #function_inputs = self.input_iterator.next()
             if self.minibatch_mode:
-                function_outputs = self.function(*fields)
+                function_outputs = self.function(*input_fields)
             else:
-                input_examples = zip(*fields)
+                input_examples = zip(*input_fields)
                 output_examples = [self.function(*input_example)
                                     for input_example in input_examples]
                 function_outputs = [self.valuesVStack(name,values)
                                     for name,values in zip(self.output_names,
                                                            zip(*output_examples))]
             all_outputs = Example(self.output_names, function_outputs)
-            print fields
-            print all_outputs
-            print '--------'
+            print 'input_fields', input_fields
+            print 'all_outputs', all_outputs
             if fieldnames==self.output_names:
-                yield all_outputs
+                rval = all_outputs
             else:
-                yield Example(fieldnames,[all_outputs[name] for name in fieldnames])
+                rval = Example(fieldnames,[all_outputs[name] for name in fieldnames])
+            print 'rval', rval
+            print '--------'
+            yield rval
 
     def untested__iter__(self): # only implemented for increased efficiency
         class ApplyFunctionSingleExampleIterator(object):