# HG changeset patch # User James Bergstra # Date 1225380424 14400 # Node ID 6fe692b93b691a37f01bcd5c7fa6263edbd74363 # Parent 60b7dd5be8608e61891379d5a3a1b9eda099d7f0 added shapeset1 diff -r 60b7dd5be860 -r 6fe692b93b69 datasets/shapeset1.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/datasets/shapeset1.py Thu Oct 30 11:27:04 2008 -0400 @@ -0,0 +1,61 @@ +""" +Routines to load/access Shapeset1 +""" + +from __future__ import absolute_import + +import os +import numpy + +from ..amat import AMat +from .config import data_root + +def _head(path, n): + dat = AMat(path=path, head=n) + + try: + assert dat.input.shape[0] == n + assert dat.target.shape[0] == n + except Exception , e: + raise Exception("failed to read %i lines from file %s" % (n, path)) + + return dat.input, numpy.asarray(dat.target, dtype='int64').reshape(dat.target.shape[0]) + + +def head_train(n=10000): + """Load the first Shapeset1 training examples. + + Returns two matrices: x, y. x has N rows of 1024 columns. Each row of x represents the + 32x32 grey-scale pixels in raster order. y is a vector of N integers. Each element y[i] + is the label of the i'th row of x. + + """ + path = os.path.join(data_root(), 'shapeset1','shapeset1_1cspo_2_3.10000.train.shape.amat') + return _head(path, n) + +def head_valid(n=5000): + """Load the first Shapeset1 validation examples. + + Returns two matrices: x, y. x has N rows of 1024 columns. Each row of x represents the + 32x32 grey-scale pixels in raster order. y is a vector of N integers. Each element y[i] + is the label of the i'th row of x. + + """ + path = os.path.join(data_root(), 'shapeset1','shapeset1_1cspo_2_3.5000.valid.shape.amat') + return _head(path, n) + +def head_test(n=5000): + """Load the first Shapeset1 testing examples. + + Returns two matrices: x, y. x has N rows of 1024 columns. Each row of x represents the + 32x32 grey-scale pixels in raster order. y is a vector of N integers. Each element y[i] + is the label of the i'th row of x. + + """ + path = os.path.join(data_root(), 'shapeset1','shapeset1_1cspo_2_3.5000.test.shape.amat') + return _head(path, n) + +def train_valid_test(ntrain=10000, nvalid=5000, ntest=5000): + return head_train(n=ntrain), head_valid(n=nvalid), head_test(n=ntest) + +