Mercurial > pylearn
comparison dataset.py @ 313:009ce84e9f52
Dataset indexing in pylearn now behaves like a Python list: if len(ds) == 10, ds[10] raises an IndexError, as does ds[[1,10]], and ds[0:14:1] returns 10 elements.
author | Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca> |
---|---|
date | Wed, 11 Jun 2008 13:53:39 -0400 |
parents | ebccfd05ccd5 |
children | 105b54ac8260 |
comparison
equal
deleted
inserted
replaced
312:96cca78de3ed | 313:009ce84e9f52 |
---|---|
453 the data are actually stored in a memory array. | 453 the data are actually stored in a memory array. |
454 """ | 454 """ |
455 | 455 |
456 if type(i) is int: | 456 if type(i) is int: |
457 assert i >= 0 # TBM: see if someone complains and want negative i | 457 assert i >= 0 # TBM: see if someone complains and want negative i |
458 if i >= len(self) : | |
459 raise IndexError | |
458 i_batch = self.minibatches_nowrap(self.fieldNames(), | 460 i_batch = self.minibatches_nowrap(self.fieldNames(), |
459 minibatch_size=1, n_batches=1, offset=i) | 461 minibatch_size=1, n_batches=1, offset=i) |
460 return DataSet.MinibatchToSingleExampleIterator(i_batch).next() | 462 return DataSet.MinibatchToSingleExampleIterator(i_batch).next() |
461 | 463 |
462 #if i is a contiguous slice | 464 #if i is a contiguous slice |
463 if type(i) is slice and (i.step in (None, 1)): | 465 if type(i) is slice and (i.step in (None, 1)): |
464 offset = 0 if i.start is None else i.start | 466 offset = 0 if i.start is None else i.start |
465 upper_bound = len(self) if i.stop is None else i.stop | 467 upper_bound = len(self) if i.stop is None else i.stop |
468 upper_bound = min(len(self) , upper_bound) | |
466 #return MinibatchDataSet(self.minibatches_nowrap(self.fieldNames(), | 469 #return MinibatchDataSet(self.minibatches_nowrap(self.fieldNames(), |
467 # minibatch_size=upper_bound - offset, | 470 # minibatch_size=upper_bound - offset, |
468 # n_batches=1, | 471 # n_batches=1, |
469 # offset=offset).next()) | 472 # offset=offset).next()) |
470 # now returns a LookupList | 473 # now returns a LookupList |
484 if hasattr(i, '__getitem__'): | 487 if hasattr(i, '__getitem__'): |
485 for idx in i: | 488 for idx in i: |
486 #dis-allow nested slices | 489 #dis-allow nested slices |
487 if not isinstance(idx, int): | 490 if not isinstance(idx, int): |
488 raise TypeError(idx) | 491 raise TypeError(idx) |
492 if idx >= len(self) : | |
493 raise IndexError | |
489 # call back into self.__getitem__ | 494 # call back into self.__getitem__ |
490 examples = [self.minibatches_nowrap(self.fieldNames(), | 495 examples = [self.minibatches_nowrap(self.fieldNames(), |
491 minibatch_size=1, n_batches=1, offset=ii).next() | 496 minibatch_size=1, n_batches=1, offset=ii).next() |
492 for ii in i] | 497 for ii in i] |
493 # re-index the fields in each example by field instead of by example | 498 # re-index the fields in each example by field instead of by example |