# HG changeset patch # User Pascal Lamblin # Date 1243996189 14400 # Node ID 60394c46039033bd45da32785197422a7b1880b3 # Parent 1e0fa60bfacd07bb55cbff4c3f4e023e73dd5584# Parent 61a3608d5767232867dc64eba207b5396ce08311 merge diff -r 61a3608d5767 -r 60394c460390 pylearn/datasets/norb_small.py --- a/pylearn/datasets/norb_small.py Tue Jun 02 22:26:29 2009 -0400 +++ b/pylearn/datasets/norb_small.py Tue Jun 02 22:29:49 2009 -0400 @@ -98,11 +98,21 @@ train_info = read(open(train['info'])) ordered_itrain = numpy.nonzero(train_info[:,0] != valid_variant)[0] + max_ntrain = ordered_itrain.shape[0] ordered_ivalid = numpy.nonzero(train_info[:,0] == valid_variant)[0] + max_nvalid = ordered_ivalid.shape[0] + + if self.ntrain > max_ntrain: + print 'WARNING: ntrain is %i, but there are only %i training samples available' % (self.ntrain, max_ntrain) + self.ntrain = max_ntrain - # TODO: randomize - self.itr = ordered_itrain - self.ival = ordered_ivalid + if self.nvalid > max_nvalid: + print 'WARNING: nvalid is %i, but there are only %i validation samples available' % (self.nvalid, max_nvalid) + self.nvalid = max_nvalid + + # Randomize + self.itr = ordered_itrain[rng.permutation(self.max_ntrain)][0:self.ntrain] + self.ival = ordered_ivalid[rng.permutation(self.max_ntrain)][0:self.nvalid] self.current = None