Mercurial > pylearn
view datasets/shapeset1.py @ 528:cfe3f62a08cb
bugfix, outputs needed to be in a list in the past, not anymore.
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Mon, 17 Nov 2008 13:17:00 -0500 |
parents | 6fe692b93b69 |
children |
line wrap: on
line source
"""Routines to load/access the Shapeset1 dataset."""
from __future__ import absolute_import

import os

import numpy

from ..amat import AMat
from .config import data_root


def _shapeset1_path(filename):
    """Return the path of `filename` inside the 'shapeset1' directory of data_root()."""
    return os.path.join(data_root(), 'shapeset1', filename)


def _head(path, n):
    """Read the first `n` examples of the AMat file at `path`.

    Returns (x, y): x is the `input` array loaded by AMat, and y is the
    `target` array converted to an int64 numpy vector of length n.

    Raises ValueError (an Exception subclass, so existing `except Exception`
    handlers still catch it) if the loaded inputs or targets do not have
    exactly `n` rows. Errors raised by AMat itself propagate unchanged.
    """
    dat = AMat(path=path, head=n)
    # Validate with an explicit check rather than `assert`: assertions are
    # stripped under `python -O`, which would let a short read slip through.
    try:
        n_input = dat.input.shape[0]
        n_target = dat.target.shape[0]
    except (AttributeError, IndexError, TypeError):
        # Missing or shapeless arrays count as a failed read.
        n_input = n_target = None
    if n_input != n or n_target != n:
        raise ValueError("failed to read %i lines from file %s" % (n, path))
    # Flatten the targets into a 1-D vector of n labels (AMat presumably
    # returns them as an (n, 1) column -- TODO confirm).
    return dat.input, numpy.asarray(dat.target, dtype='int64').reshape(n)


def head_train(n=10000):
    """Load the first `n` Shapeset1 training examples.

    Returns two matrices: x, y. x has n rows of 1024 columns. Each row of x
    represents the 32x32 grey-scale pixels in raster order. y is a vector of
    n integers (int64). Each element y[i] is the label of the i'th row of x.

    Raises ValueError if the file does not yield exactly n examples.
    """
    path = _shapeset1_path('shapeset1_1cspo_2_3.10000.train.shape.amat')
    return _head(path, n)


def head_valid(n=5000):
    """Load the first `n` Shapeset1 validation examples.

    Returns two matrices: x, y. x has n rows of 1024 columns. Each row of x
    represents the 32x32 grey-scale pixels in raster order. y is a vector of
    n integers (int64). Each element y[i] is the label of the i'th row of x.

    Raises ValueError if the file does not yield exactly n examples.
    """
    path = _shapeset1_path('shapeset1_1cspo_2_3.5000.valid.shape.amat')
    return _head(path, n)


def head_test(n=5000):
    """Load the first `n` Shapeset1 testing examples.

    Returns two matrices: x, y. x has n rows of 1024 columns. Each row of x
    represents the 32x32 grey-scale pixels in raster order. y is a vector of
    n integers (int64). Each element y[i] is the label of the i'th row of x.

    Raises ValueError if the file does not yield exactly n examples.
    """
    path = _shapeset1_path('shapeset1_1cspo_2_3.5000.test.shape.amat')
    return _head(path, n)


def train_valid_test(ntrain=10000, nvalid=5000, ntest=5000):
    """Load the first examples of all three Shapeset1 splits.

    Returns ((x_train, y_train), (x_valid, y_valid), (x_test, y_test)),
    each pair as produced by head_train, head_valid and head_test with
    n=ntrain, n=nvalid and n=ntest respectively.
    """
    return head_train(n=ntrain), head_valid(n=nvalid), head_test(n=ntest)