Mercurial > pylearn
changeset 761:60394c460390
merge
author | Pascal Lamblin <lamblinp@iro.umontreal.ca> |
---|---|
date | Tue, 02 Jun 2009 22:29:49 -0400 |
parents | 1e0fa60bfacd (diff) 61a3608d5767 (current diff) |
children | c3e1e495d689 |
files | |
diffstat | 1 files changed, 13 insertions(+), 3 deletions(-) [+] |
line wrap: on
line diff
--- a/pylearn/datasets/norb_small.py Tue Jun 02 22:26:29 2009 -0400 +++ b/pylearn/datasets/norb_small.py Tue Jun 02 22:29:49 2009 -0400 @@ -98,11 +98,21 @@ train_info = read(open(train['info'])) ordered_itrain = numpy.nonzero(train_info[:,0] != valid_variant)[0] + max_ntrain = ordered_itrain.shape[0] ordered_ivalid = numpy.nonzero(train_info[:,0] == valid_variant)[0] + max_nvalid = ordered_ivalid.shape[0] + + if self.ntrain > max_ntrain: + print 'WARNING: ntrain is %i, but there are only %i training samples available' % (self.ntrain, max_ntrain) + self.ntrain = max_ntrain - # TODO: randomize - self.itr = ordered_itrain - self.ival = ordered_ivalid + if self.nvalid > max_nvalid: + print 'WARNING: nvalid is %i, but there are only %i validation samples available' % (self.nvalid, max_nvalid) + self.nvalid = max_nvalid + + # Randomize + self.itr = ordered_itrain[rng.permutation(self.max_ntrain)][0:self.ntrain] + self.ival = ordered_ivalid[rng.permutation(self.max_ntrain)][0:self.nvalid] self.current = None