comparison dataset.py @ 293:4bfdda107a17

still merging
author James Bergstra <bergstrj@iro.umontreal.ca>
date Fri, 06 Jun 2008 16:13:17 -0400
parents 174374d59405
children f5d33f9c0b9c
comparison
equal deleted inserted replaced
292:174374d59405 293:4bfdda107a17
454 """ 454 """
455 455
456 if type(i) is int: 456 if type(i) is int:
457 #TODO: consider asserting that i >= 0 457 #TODO: consider asserting that i >= 0
458 i_batch = self.minibatches_nowrap(self.fieldNames(), 458 i_batch = self.minibatches_nowrap(self.fieldNames(),
459 minibatch_size=1, n_batches=1, offset=i % len(self)) 459 minibatch_size=1, n_batches=1, offset=i)
460 return DataSet.MinibatchToSingleExampleIterator(i_batch).next() 460 return DataSet.MinibatchToSingleExampleIterator(i_batch).next()
461 461
462 #if i is a contiguous slice 462 #if i is a contiguous slice
463 if type(i) is slice and (i.step in (None, 1)): 463 if type(i) is slice and (i.step in (None, 1)):
464 offset = 0 if i.start is None else i.start 464 offset = 0 if i.start is None else i.start
481 #dis-allow nested slices 481 #dis-allow nested slices
482 if not isinstance(idx, int): 482 if not isinstance(idx, int):
483 raise TypeError(idx) 483 raise TypeError(idx)
484 # call back into self.__getitem__ 484 # call back into self.__getitem__
485 examples = [self.minibatches_nowrap(self.fieldNames(), 485 examples = [self.minibatches_nowrap(self.fieldNames(),
486 minibatch_size=1, n_batches=1, offset=ii%len(self)).next() 486 minibatch_size=1, n_batches=1, offset=ii).next()
487 for ii in i] 487 for ii in i]
488 # re-index the fields in each example by field instead of by example 488 # re-index the fields in each example by field instead of by example
489 field_values = [[] for blah in self.fieldNames()] 489 field_values = [[] for blah in self.fieldNames()]
490 for e in examples: 490 for e in examples:
491 for f,v in zip(field_values, e): 491 for f,v in zip(field_values, e):
1251 1251
1252 def fieldNames(self): 1252 def fieldNames(self):
1253 return self.output_names 1253 return self.output_names
1254 1254
1255 def minibatches_nowrap(self, fieldnames, *args, **kwargs): 1255 def minibatches_nowrap(self, fieldnames, *args, **kwargs):
1256 for fields in self.input_dataset.minibatches_nowrap(fieldnames, *args, **kwargs): 1256 for input_fields in self.input_dataset.minibatches_nowrap(fieldnames, *args, **kwargs):
1257 1257
1258 #function_inputs = self.input_iterator.next() 1258 #function_inputs = self.input_iterator.next()
1259 if self.minibatch_mode: 1259 if self.minibatch_mode:
1260 function_outputs = self.function(*fields) 1260 function_outputs = self.function(*input_fields)
1261 else: 1261 else:
1262 input_examples = zip(*fields) 1262 input_examples = zip(*input_fields)
1263 output_examples = [self.function(*input_example) 1263 output_examples = [self.function(*input_example)
1264 for input_example in input_examples] 1264 for input_example in input_examples]
1265 function_outputs = [self.valuesVStack(name,values) 1265 function_outputs = [self.valuesVStack(name,values)
1266 for name,values in zip(self.output_names, 1266 for name,values in zip(self.output_names,
1267 zip(*output_examples))] 1267 zip(*output_examples))]
1268 all_outputs = Example(self.output_names, function_outputs) 1268 all_outputs = Example(self.output_names, function_outputs)
1269 print fields 1269 print 'input_fields', input_fields
1270 print all_outputs 1270 print 'all_outputs', all_outputs
1271 if fieldnames==self.output_names:
1272 rval = all_outputs
1273 else:
1274 rval = Example(fieldnames,[all_outputs[name] for name in fieldnames])
1275 print 'rval', rval
1271 print '--------' 1276 print '--------'
1272 if fieldnames==self.output_names: 1277 yield rval
1273 yield all_outputs
1274 else:
1275 yield Example(fieldnames,[all_outputs[name] for name in fieldnames])
1276 1278
1277 def untested__iter__(self): # only implemented for increased efficiency 1279 def untested__iter__(self): # only implemented for increased efficiency
1278 class ApplyFunctionSingleExampleIterator(object): 1280 class ApplyFunctionSingleExampleIterator(object):
1279 def __init__(self,output_dataset): 1281 def __init__(self,output_dataset):
1280 self.current=0 1282 self.current=0