changeset 761:60394c460390

merge
author Pascal Lamblin <lamblinp@iro.umontreal.ca>
date Tue, 02 Jun 2009 22:29:49 -0400
parents 1e0fa60bfacd (diff) 61a3608d5767 (current diff)
children c3e1e495d689
files
diffstat 1 files changed, 13 insertions(+), 3 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/datasets/norb_small.py	Tue Jun 02 22:26:29 2009 -0400
+++ b/pylearn/datasets/norb_small.py	Tue Jun 02 22:29:49 2009 -0400
@@ -98,11 +98,21 @@
             train_info = read(open(train['info']))
 
             ordered_itrain = numpy.nonzero(train_info[:,0] != valid_variant)[0]
+            max_ntrain = ordered_itrain.shape[0]
             ordered_ivalid = numpy.nonzero(train_info[:,0] == valid_variant)[0]
+            max_nvalid = ordered_ivalid.shape[0]
+
+            if self.ntrain > max_ntrain:
+                print 'WARNING: ntrain is %i, but there are only %i training samples available' % (self.ntrain, max_ntrain)
+                self.ntrain = max_ntrain
 
-            # TODO: randomize
-            self.itr = ordered_itrain
-            self.ival = ordered_ivalid
+            if self.nvalid > max_nvalid:
+                print 'WARNING: nvalid is %i, but there are only %i validation samples available' % (self.nvalid, max_nvalid)
+                self.nvalid = max_nvalid
+
+            # Randomize
+            self.itr = ordered_itrain[rng.permutation(self.max_ntrain)][0:self.ntrain]
+            self.ival = ordered_ivalid[rng.permutation(self.max_ntrain)][0:self.nvalid]
 
         self.current = None