Mercurial > pylearn
comparison dataset.py @ 293:4bfdda107a17
still merging
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Fri, 06 Jun 2008 16:13:17 -0400 |
parents | 174374d59405 |
children | f5d33f9c0b9c |
comparison
equal
deleted
inserted
replaced
292:174374d59405 | 293:4bfdda107a17 |
---|---|
454 """ | 454 """ |
455 | 455 |
456 if type(i) is int: | 456 if type(i) is int: |
457 #TODO: consider asserting that i >= 0 | 457 #TODO: consider asserting that i >= 0 |
458 i_batch = self.minibatches_nowrap(self.fieldNames(), | 458 i_batch = self.minibatches_nowrap(self.fieldNames(), |
459 minibatch_size=1, n_batches=1, offset=i % len(self)) | 459 minibatch_size=1, n_batches=1, offset=i) |
460 return DataSet.MinibatchToSingleExampleIterator(i_batch).next() | 460 return DataSet.MinibatchToSingleExampleIterator(i_batch).next() |
461 | 461 |
462 #if i is a contiguous slice | 462 #if i is a contiguous slice |
463 if type(i) is slice and (i.step in (None, 1)): | 463 if type(i) is slice and (i.step in (None, 1)): |
464 offset = 0 if i.start is None else i.start | 464 offset = 0 if i.start is None else i.start |
481 #dis-allow nested slices | 481 #dis-allow nested slices |
482 if not isinstance(idx, int): | 482 if not isinstance(idx, int): |
483 raise TypeError(idx) | 483 raise TypeError(idx) |
484 # call back into self.__getitem__ | 484 # call back into self.__getitem__ |
485 examples = [self.minibatches_nowrap(self.fieldNames(), | 485 examples = [self.minibatches_nowrap(self.fieldNames(), |
486 minibatch_size=1, n_batches=1, offset=ii%len(self)).next() | 486 minibatch_size=1, n_batches=1, offset=ii).next() |
487 for ii in i] | 487 for ii in i] |
488 # re-index the fields in each example by field instead of by example | 488 # re-index the fields in each example by field instead of by example |
489 field_values = [[] for blah in self.fieldNames()] | 489 field_values = [[] for blah in self.fieldNames()] |
490 for e in examples: | 490 for e in examples: |
491 for f,v in zip(field_values, e): | 491 for f,v in zip(field_values, e): |
1251 | 1251 |
1252 def fieldNames(self): | 1252 def fieldNames(self): |
1253 return self.output_names | 1253 return self.output_names |
1254 | 1254 |
1255 def minibatches_nowrap(self, fieldnames, *args, **kwargs): | 1255 def minibatches_nowrap(self, fieldnames, *args, **kwargs): |
1256 for fields in self.input_dataset.minibatches_nowrap(fieldnames, *args, **kwargs): | 1256 for input_fields in self.input_dataset.minibatches_nowrap(fieldnames, *args, **kwargs): |
1257 | 1257 |
1258 #function_inputs = self.input_iterator.next() | 1258 #function_inputs = self.input_iterator.next() |
1259 if self.minibatch_mode: | 1259 if self.minibatch_mode: |
1260 function_outputs = self.function(*fields) | 1260 function_outputs = self.function(*input_fields) |
1261 else: | 1261 else: |
1262 input_examples = zip(*fields) | 1262 input_examples = zip(*input_fields) |
1263 output_examples = [self.function(*input_example) | 1263 output_examples = [self.function(*input_example) |
1264 for input_example in input_examples] | 1264 for input_example in input_examples] |
1265 function_outputs = [self.valuesVStack(name,values) | 1265 function_outputs = [self.valuesVStack(name,values) |
1266 for name,values in zip(self.output_names, | 1266 for name,values in zip(self.output_names, |
1267 zip(*output_examples))] | 1267 zip(*output_examples))] |
1268 all_outputs = Example(self.output_names, function_outputs) | 1268 all_outputs = Example(self.output_names, function_outputs) |
1269 print fields | 1269 print 'input_fields', input_fields |
1270 print all_outputs | 1270 print 'all_outputs', all_outputs |
1271 if fieldnames==self.output_names: | |
1272 rval = all_outputs | |
1273 else: | |
1274 rval = Example(fieldnames,[all_outputs[name] for name in fieldnames]) | |
1275 print 'rval', rval | |
1271 print '--------' | 1276 print '--------' |
1272 if fieldnames==self.output_names: | 1277 yield rval |
1273 yield all_outputs | |
1274 else: | |
1275 yield Example(fieldnames,[all_outputs[name] for name in fieldnames]) | |
1276 | 1278 |
1277 def untested__iter__(self): # only implemented for increased efficiency | 1279 def untested__iter__(self): # only implemented for increased efficiency |
1278 class ApplyFunctionSingleExampleIterator(object): | 1280 class ApplyFunctionSingleExampleIterator(object): |
1279 def __init__(self,output_dataset): | 1281 def __init__(self,output_dataset): |
1280 self.current=0 | 1282 self.current=0 |