Mercurial > pylearn
diff dataset.py @ 293:4bfdda107a17
still merging
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Fri, 06 Jun 2008 16:13:17 -0400 |
parents | 174374d59405 |
children | f5d33f9c0b9c |
line wrap: on
line diff
--- a/dataset.py Fri Jun 06 15:56:18 2008 -0400 +++ b/dataset.py Fri Jun 06 16:13:17 2008 -0400 @@ -456,7 +456,7 @@ if type(i) is int: #TODO: consider asserting that i >= 0 i_batch = self.minibatches_nowrap(self.fieldNames(), - minibatch_size=1, n_batches=1, offset=i % len(self)) + minibatch_size=1, n_batches=1, offset=i) return DataSet.MinibatchToSingleExampleIterator(i_batch).next() #if i is a contiguous slice @@ -483,7 +483,7 @@ raise TypeError(idx) # call back into self.__getitem__ examples = [self.minibatches_nowrap(self.fieldNames(), - minibatch_size=1, n_batches=1, offset=ii%len(self)).next() + minibatch_size=1, n_batches=1, offset=ii).next() for ii in i] # re-index the fields in each example by field instead of by example field_values = [[] for blah in self.fieldNames()] @@ -1253,26 +1253,28 @@ return self.output_names def minibatches_nowrap(self, fieldnames, *args, **kwargs): - for fields in self.input_dataset.minibatches_nowrap(fieldnames, *args, **kwargs): + for input_fields in self.input_dataset.minibatches_nowrap(fieldnames, *args, **kwargs): #function_inputs = self.input_iterator.next() if self.minibatch_mode: - function_outputs = self.function(*fields) + function_outputs = self.function(*input_fields) else: - input_examples = zip(*fields) + input_examples = zip(*input_fields) output_examples = [self.function(*input_example) for input_example in input_examples] function_outputs = [self.valuesVStack(name,values) for name,values in zip(self.output_names, zip(*output_examples))] all_outputs = Example(self.output_names, function_outputs) - print fields - print all_outputs - print '--------' + print 'input_fields', input_fields + print 'all_outputs', all_outputs if fieldnames==self.output_names: - yield all_outputs + rval = all_outputs else: - yield Example(fieldnames,[all_outputs[name] for name in fieldnames]) + rval = Example(fieldnames,[all_outputs[name] for name in fieldnames]) + print 'rval', rval + print '--------' + yield rval def untested__iter__(self): # only implemented for increased efficiency class ApplyFunctionSingleExampleIterator(object):