Mercurial > pylearn
changeset 762:c3e1e495d689
BUG FIXES !!!!!!!
author | desjagui@atchoum.iro.umontreal.ca |
---|---|
date | Wed, 03 Jun 2009 04:07:30 -0400 |
parents | 60394c460390 |
children | f353c9a99f95 |
files | pylearn/datasets/norb_small.py |
diffstat | 1 files changed, 10 insertions(+), 4 deletions(-) [+] |
line wrap: on
line diff
--- a/pylearn/datasets/norb_small.py Tue Jun 02 22:29:49 2009 -0400 +++ b/pylearn/datasets/norb_small.py Wed Jun 03 04:07:30 2009 -0400 @@ -81,7 +81,7 @@ self.ntrain = ntrain self.nvalid = nvalid self.ntest = ntest - self.downsample_amt = 1 + self.downsample_amt = downsample_amt self.normalize = normalize self.dtype = dtype @@ -95,7 +95,7 @@ # The validation set consists in an instance of each category # In order to know which indices correspond to which instance, # we need to load the 'info' files. - train_info = read(open(train['info'])) + train_info = read(open(self.path.train['info'])) ordered_itrain = numpy.nonzero(train_info[:,0] != valid_variant)[0] max_ntrain = ordered_itrain.shape[0] @@ -111,11 +111,17 @@ self.nvalid = max_nvalid # Randomize - self.itr = ordered_itrain[rng.permutation(self.max_ntrain)][0:self.ntrain] - self.ival = ordered_ivalid[rng.permutation(self.max_ntrain)][0:self.nvalid] + print + self.itr = ordered_itrain[rng.permutation(max_ntrain)][0:self.ntrain] + self.ival = ordered_ivalid[rng.permutation(max_nvalid)][0:self.nvalid] self.current = None + def preprocess(self, x): + if not self.normalize: + return numpy.float64(x *1.0 / 255.0) + return x + def load(self, dataset='train'): if dataset == 'train' or dataset=='valid':