changeset 756:8447bc9bb2d4

merge
author Pascal Lamblin <lamblinp@iro.umontreal.ca>
date Tue, 02 Jun 2009 21:15:41 -0400
parents 6a703c5f2391 (diff) 390d8c5a1fee (current diff)
children 61a3608d5767 1e0fa60bfacd
files
diffstat 1 files changed, 25 insertions(+), 7 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/datasets/norb_small.py	Tue Jun 02 20:21:35 2009 -0400
+++ b/pylearn/datasets/norb_small.py	Tue Jun 02 21:15:41 2009 -0400
@@ -63,11 +63,14 @@
         test = {}
         train['dat'] = os.path.join(dirpath, 'smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat')
         train['cat'] = os.path.join(dirpath, 'smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat')
+        train['info'] = os.path.join(dirpath, 'smallnorb-5x46789x9x18x6x2x96x96-training-info.mat')
         test['dat']  = os.path.join(dirpath, 'smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat')
         test['cat']  = os.path.join(dirpath, 'smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat')
+        test['info']  = os.path.join(dirpath, 'smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat')
     path = Paths()
 
-    def __init__(self, ntrain=19440, nvalid=4860, ntest=24300, 
+    def __init__(self, ntrain=19440, nvalid=4860, ntest=24300,
+               valid_variant=None,
                downsample_amt=1, seed=1, normalize=False,
                mode='stereo', dtype='int8'):
 
@@ -83,11 +86,26 @@
         self.dtype = dtype
 
         rng = numpy.random.RandomState(seed)
-        self.indices = rng.permutation(self.nsamples)
-        self.itr  = self.indices[0:ntrain]
-        self.ival = self.indices[ntrain:ntrain+nvalid]
+        if valid_variant is None:
+            # The validation set is just a random subset of training
+            self.indices = rng.permutation(self.nsamples)
+            self.itr  = self.indices[0:ntrain]
+            self.ival = self.indices[ntrain:ntrain+nvalid]
+        elif valid_variant in (4,6,7,8,9):
+            # The validation set consists in an instance of each category
+            # In order to know which indices correspond to which instance,
+            # we need to load the 'info' files.
+            train_info = read(open(train['info']))
+
+            ordered_itrain = numpy.nonzero(train_info[:,0] != valid_variant)[0]
+            ordered_ivalid = numpy.nonzero(train_info[:,0] == valid_variant)[0]
+
+            # TODO: randomize
+            self.itr = ordered_itrain
+            self.ival = ordered_ivalid
+
         self.current = None
- 
+
     def load(self, dataset='train'):
 
         if dataset == 'train' or dataset=='valid':
@@ -99,7 +117,7 @@
                 print 'need to reload from train file'
                 dat, cat  = load_file(self.path.train, self.normalize,
                                       self.downsample_amt, self.dtype)
-                
+
                 x = dat[self.itr,...].reshape(self.ntrain,-1)
                 y = cat[self.itr]
                 self.dat1 = Dataset.Obj(x=x, y=y) # training
@@ -126,7 +144,7 @@
                 x = dat.reshape(self.nsamples,-1)
                 y = cat
                 self.dat1 = Dataset.Obj(x=x, y=y)
-                
+
                 del dat, cat, x, y
 
             rval = self.dat1