Mercurial > pylearn
changeset 756:8447bc9bb2d4
merge
author | Pascal Lamblin <lamblinp@iro.umontreal.ca> |
---|---|
date | Tue, 02 Jun 2009 21:15:41 -0400 |
parents | 6a703c5f2391 (diff) 390d8c5a1fee (current diff) |
children | 61a3608d5767 1e0fa60bfacd |
files | |
diffstat | 1 files changed, 25 insertions(+), 7 deletions(-) [+] |
line wrap: on
line diff
--- a/pylearn/datasets/norb_small.py Tue Jun 02 20:21:35 2009 -0400 +++ b/pylearn/datasets/norb_small.py Tue Jun 02 21:15:41 2009 -0400 @@ -63,11 +63,14 @@ test = {} train['dat'] = os.path.join(dirpath, 'smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat') train['cat'] = os.path.join(dirpath, 'smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat') + train['info'] = os.path.join(dirpath, 'smallnorb-5x46789x9x18x6x2x96x96-training-info.mat') test['dat'] = os.path.join(dirpath, 'smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat') test['cat'] = os.path.join(dirpath, 'smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat') + test['info'] = os.path.join(dirpath, 'smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat') path = Paths() - def __init__(self, ntrain=19440, nvalid=4860, ntest=24300, + def __init__(self, ntrain=19440, nvalid=4860, ntest=24300, + valid_variant=None, downsample_amt=1, seed=1, normalize=False, mode='stereo', dtype='int8'): @@ -83,11 +86,26 @@ self.dtype = dtype rng = numpy.random.RandomState(seed) - self.indices = rng.permutation(self.nsamples) - self.itr = self.indices[0:ntrain] - self.ival = self.indices[ntrain:ntrain+nvalid] + if valid_variant is None: + # The validation set is just a random subset of training + self.indices = rng.permutation(self.nsamples) + self.itr = self.indices[0:ntrain] + self.ival = self.indices[ntrain:ntrain+nvalid] + elif valid_variant in (4,6,7,8,9): + # The validation set consists in an instance of each category + # In order to know which indices correspond to which instance, + # we need to load the 'info' files. + train_info = read(open(train['info'])) + + ordered_itrain = numpy.nonzero(train_info[:,0] != valid_variant)[0] + ordered_ivalid = numpy.nonzero(train_info[:,0] == valid_variant)[0] + + # TODO: randomize + self.itr = ordered_itrain + self.ival = ordered_ivalid + self.current = None - + def load(self, dataset='train'): if dataset == 'train' or dataset=='valid': @@ -99,7 +117,7 @@ print 'need to reload from train file' dat, cat = load_file(self.path.train, self.normalize, self.downsample_amt, self.dtype) - + x = dat[self.itr,...].reshape(self.ntrain,-1) y = cat[self.itr] self.dat1 = Dataset.Obj(x=x, y=y) # training @@ -126,7 +144,7 @@ x = dat.reshape(self.nsamples,-1) y = cat self.dat1 = Dataset.Obj(x=x, y=y) - + del dat, cat, x, y rval = self.dat1