Mercurial > pylearn
changeset 509:6fe692b93b69
added shapeset1
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Thu, 30 Oct 2008 11:27:04 -0400 |
parents | 60b7dd5be860 |
children | 919125098a3b da916044454c |
files | datasets/shapeset1.py |
diffstat | 1 files changed, 61 insertions(+), 0 deletions(-) [+] |
line wrap: on
line diff
# datasets/shapeset1.py -- new file in this changeset
# (diff: /dev/null -> b/datasets/shapeset1.py, @@ -0,0 +1,61 @@)
"""
Routines to load/access Shapeset1
"""

from __future__ import absolute_import

import os
import numpy

from ..amat import AMat
from .config import data_root


def _shapeset1_path(filename):
    """Return the full path of `filename` inside the 'shapeset1' data directory."""
    return os.path.join(data_root(), 'shapeset1', filename)


def _head(path, n):
    """Load the first `n` examples from the AMat file at `path`.

    Returns a pair (x, y): x is the `input` array of the loaded AMat, and y is
    its `target` array converted to a 1-d int64 vector with one entry per row.

    Raises Exception if the loaded input or target does not have exactly `n`
    rows (for instance when the file holds fewer than `n` examples).
    """
    dat = AMat(path=path, head=n)

    # Validate explicitly rather than with `assert`, which is stripped when
    # Python runs with -O and would let a short read pass silently.
    try:
        n_inputs = dat.input.shape[0]
        n_targets = dat.target.shape[0]
    except (AttributeError, IndexError):
        # input/target missing or not at least 1-d: treat as a failed read.
        n_inputs = n_targets = None

    if n_inputs != n or n_targets != n:
        raise Exception("failed to read %i lines from file %s" % (n, path))

    return dat.input, numpy.asarray(dat.target, dtype='int64').reshape(n_targets)


def head_train(n=10000):
    """Load the first `n` Shapeset1 training examples.

    Returns a pair (x, y).  x is a matrix with `n` rows of 1024 columns; each
    row of x represents the 32x32 grey-scale pixels in raster order.  y is a
    vector of `n` integers; y[i] is the label of the i'th row of x.

    """
    return _head(_shapeset1_path('shapeset1_1cspo_2_3.10000.train.shape.amat'), n)


def head_valid(n=5000):
    """Load the first `n` Shapeset1 validation examples.

    Returns a pair (x, y).  x is a matrix with `n` rows of 1024 columns; each
    row of x represents the 32x32 grey-scale pixels in raster order.  y is a
    vector of `n` integers; y[i] is the label of the i'th row of x.

    """
    return _head(_shapeset1_path('shapeset1_1cspo_2_3.5000.valid.shape.amat'), n)


def head_test(n=5000):
    """Load the first `n` Shapeset1 testing examples.

    Returns a pair (x, y).  x is a matrix with `n` rows of 1024 columns; each
    row of x represents the 32x32 grey-scale pixels in raster order.  y is a
    vector of `n` integers; y[i] is the label of the i'th row of x.

    """
    return _head(_shapeset1_path('shapeset1_1cspo_2_3.5000.test.shape.amat'), n)


def train_valid_test(ntrain=10000, nvalid=5000, ntest=5000):
    """Load the heads of the training, validation and testing sets.

    Returns a 3-tuple ((train_x, train_y), (valid_x, valid_y), (test_x, test_y))
    built from head_train(ntrain), head_valid(nvalid) and head_test(ntest).
    """
    return head_train(n=ntrain), head_valid(n=nvalid), head_test(n=ntest)