changeset 762:c3e1e495d689

BUG FIXES !!!!!!!
author desjagui@atchoum.iro.umontreal.ca
date Wed, 03 Jun 2009 04:07:30 -0400
parents 60394c460390
children f353c9a99f95
files pylearn/datasets/norb_small.py
diffstat 1 files changed, 10 insertions(+), 4 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/datasets/norb_small.py	Tue Jun 02 22:29:49 2009 -0400
+++ b/pylearn/datasets/norb_small.py	Wed Jun 03 04:07:30 2009 -0400
@@ -81,7 +81,7 @@
         self.ntrain = ntrain
         self.nvalid = nvalid
         self.ntest = ntest
-        self.downsample_amt = 1
+        self.downsample_amt = downsample_amt
         self.normalize = normalize
         self.dtype = dtype
 
@@ -95,7 +95,7 @@
             # The validation set consists in an instance of each category
             # In order to know which indices correspond to which instance,
             # we need to load the 'info' files.
-            train_info = read(open(train['info']))
+            train_info = read(open(self.path.train['info']))
 
             ordered_itrain = numpy.nonzero(train_info[:,0] != valid_variant)[0]
             max_ntrain = ordered_itrain.shape[0]
@@ -111,11 +111,17 @@
                 self.nvalid = max_nvalid
 
             # Randomize
-            self.itr = ordered_itrain[rng.permutation(self.max_ntrain)][0:self.ntrain]
-            self.ival = ordered_ivalid[rng.permutation(self.max_ntrain)][0:self.nvalid]
+            print 
+            self.itr = ordered_itrain[rng.permutation(max_ntrain)][0:self.ntrain]
+            self.ival = ordered_ivalid[rng.permutation(max_nvalid)][0:self.nvalid]
 
         self.current = None
 
+    def preprocess(self, x):
+        if not self.normalize:
+            return numpy.float64(x *1.0 / 255.0)
+        return x
+
     def load(self, dataset='train'):
 
         if dataset == 'train' or dataset=='valid':