# HG changeset patch # User Pascal Lamblin # Date 1243991741 14400 # Node ID 8447bc9bb2d423a0d83605a4ba3ab8055d3f7539 # Parent 6a703c5f2391494953d5ac25e1d9486da125f438# Parent 390d8c5a1fee04b81bee86dd2e0f8255571fd268 merge diff -r 390d8c5a1fee -r 8447bc9bb2d4 pylearn/datasets/norb_small.py --- a/pylearn/datasets/norb_small.py Tue Jun 02 20:21:35 2009 -0400 +++ b/pylearn/datasets/norb_small.py Tue Jun 02 21:15:41 2009 -0400 @@ -63,11 +63,14 @@ test = {} train['dat'] = os.path.join(dirpath, 'smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat') train['cat'] = os.path.join(dirpath, 'smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat') + train['info'] = os.path.join(dirpath, 'smallnorb-5x46789x9x18x6x2x96x96-training-info.mat') test['dat'] = os.path.join(dirpath, 'smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat') test['cat'] = os.path.join(dirpath, 'smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat') + test['info'] = os.path.join(dirpath, 'smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat') path = Paths() - def __init__(self, ntrain=19440, nvalid=4860, ntest=24300, + def __init__(self, ntrain=19440, nvalid=4860, ntest=24300, + valid_variant=None, downsample_amt=1, seed=1, normalize=False, mode='stereo', dtype='int8'): @@ -83,11 +86,26 @@ self.dtype = dtype rng = numpy.random.RandomState(seed) - self.indices = rng.permutation(self.nsamples) - self.itr = self.indices[0:ntrain] - self.ival = self.indices[ntrain:ntrain+nvalid] + if valid_variant is None: + # The validation set is just a random subset of training + self.indices = rng.permutation(self.nsamples) + self.itr = self.indices[0:ntrain] + self.ival = self.indices[ntrain:ntrain+nvalid] + elif valid_variant in (4,6,7,8,9): + # The validation set consists in an instance of each category + # In order to know which indices correspond to which instance, + # we need to load the 'info' files. + train_info = read(open(train['info'])) + + ordered_itrain = numpy.nonzero(train_info[:,0] != valid_variant)[0] + ordered_ivalid = numpy.nonzero(train_info[:,0] == valid_variant)[0] + + # TODO: randomize + self.itr = ordered_itrain + self.ival = ordered_ivalid + self.current = None - + def load(self, dataset='train'): if dataset == 'train' or dataset=='valid': @@ -99,7 +117,7 @@ print 'need to reload from train file' dat, cat = load_file(self.path.train, self.normalize, self.downsample_amt, self.dtype) - + x = dat[self.itr,...].reshape(self.ntrain,-1) y = cat[self.itr] self.dat1 = Dataset.Obj(x=x, y=y) # training @@ -126,7 +144,7 @@ x = dat.reshape(self.nsamples,-1) y = cat self.dat1 = Dataset.Obj(x=x, y=y) - + del dat, cat, x, y rval = self.dat1