# HG changeset patch
# User James Bergstra
# Date 1212783197 14400
# Node ID 4bfdda107a1755373091046006c22fadebbeb2d7
# Parent  174374d5940555ad068fc9d4b13cf4dfa147b733
still merging

diff -r 174374d59405 -r 4bfdda107a17 _test_dataset.py
--- a/_test_dataset.py	Fri Jun 06 15:56:18 2008 -0400
+++ b/_test_dataset.py	Fri Jun 06 16:13:17 2008 -0400
@@ -47,6 +47,11 @@
     #not in doc!!!
     i=0
     for example in range(len(ds)):
+        wanted = array[example][:3]
+        returned = ds[example]['x']
+        if (wanted != returned).all():
+            print 'returned:', returned
+            print 'wanted:', wanted
         assert (ds[example]['x']==array[example][:3]).all()
         assert ds[example]['y']==array[example][3]
         assert (ds[example]['z']==array[example][[0,2]]).all()
@@ -226,8 +231,7 @@
         assert i==m.n_batches*m.minibatch_size
     del x,y,i,id
 
-    #@todo: we can't do minibatch bigger then the size of the dataset???
-    assert have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array)+1,offset=0)
+    assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array)+1,offset=0)
     assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array),offset=0)
 
 def test_ds_iterator(array,iterator1,iterator2,iterator3):
diff -r 174374d59405 -r 4bfdda107a17 dataset.py
--- a/dataset.py	Fri Jun 06 15:56:18 2008 -0400
+++ b/dataset.py	Fri Jun 06 16:13:17 2008 -0400
@@ -456,7 +456,7 @@
         if type(i) is int:
             #TODO: consider asserting that i >= 0
             i_batch = self.minibatches_nowrap(self.fieldNames(),
-                    minibatch_size=1, n_batches=1, offset=i % len(self))
+                    minibatch_size=1, n_batches=1, offset=i)
             return DataSet.MinibatchToSingleExampleIterator(i_batch).next()
 
         #if i is a contiguous slice
@@ -483,7 +483,7 @@
                 raise TypeError(idx)
             # call back into self.__getitem__
             examples = [self.minibatches_nowrap(self.fieldNames(),
-                    minibatch_size=1, n_batches=1, offset=ii%len(self)).next()
+                    minibatch_size=1, n_batches=1, offset=ii).next()
                     for ii in i]
             # re-index the fields in each example by field instead of by example
             field_values = [[] for blah in self.fieldNames()]
@@ -1253,26 +1253,28 @@
         return self.output_names
 
     def minibatches_nowrap(self, fieldnames, *args, **kwargs):
-        for fields in self.input_dataset.minibatches_nowrap(fieldnames, *args, **kwargs):
+        for input_fields in self.input_dataset.minibatches_nowrap(fieldnames, *args, **kwargs):
             #function_inputs = self.input_iterator.next()
             if self.minibatch_mode:
-                function_outputs = self.function(*fields)
+                function_outputs = self.function(*input_fields)
             else:
-                input_examples = zip(*fields)
+                input_examples = zip(*input_fields)
                 output_examples = [self.function(*input_example)
                         for input_example in input_examples]
                 function_outputs = [self.valuesVStack(name,values)
                         for name,values in zip(self.output_names,
                                                 zip(*output_examples))]
             all_outputs = Example(self.output_names, function_outputs)
-            print fields
-            print all_outputs
-            print '--------'
+            print 'input_fields', input_fields
+            print 'all_outputs', all_outputs
             if fieldnames==self.output_names:
-                yield all_outputs
+                rval = all_outputs
             else:
-                yield Example(fieldnames,[all_outputs[name] for name in fieldnames])
+                rval = Example(fieldnames,[all_outputs[name] for name in fieldnames])
+            print 'rval', rval
+            print '--------'
+            yield rval
 
     def untested__iter__(self): # only implemented for increased efficiency
         class ApplyFunctionSingleExampleIterator(object):