comparison datasets/shapeset1.py @ 509:6fe692b93b69

added shapeset1
author James Bergstra <bergstrj@iro.umontreal.ca>
date Thu, 30 Oct 2008 11:27:04 -0400
parents
children
comparison
equal deleted inserted replaced
508:60b7dd5be860 509:6fe692b93b69
1 """
2 Routines to load/access Shapeset1
3 """
4
5 from __future__ import absolute_import
6
7 import os
8 import numpy
9
10 from ..amat import AMat
11 from .config import data_root
12
def _head(path, n):
    """Load the first `n` examples from the AMat file at `path`.

    Returns a pair (inputs, targets): `inputs` is `dat.input` exactly as produced by AMat,
    and `targets` is `dat.target` converted to int64 and flattened to a vector with one
    entry per row.

    Raises Exception if the loaded inputs or targets do not both contain exactly `n` rows.
    """
    # NOTE(review): presumably `head=n` tells AMat to stop after n rows -- confirm in amat.py.
    dat = AMat(path=path, head=n)

    # Validate with an explicit test rather than `assert`: assertions are removed when
    # Python runs with -O, which would let a truncated read slip through unnoticed.
    try:
        complete = dat.input.shape[0] == n and dat.target.shape[0] == n
    except Exception:
        # A missing or malformed input/target (e.g. no `.shape`) counts as a failed read,
        # so it is reported through the same error below.
        complete = False
    if not complete:
        raise Exception("failed to read %i lines from file %s" % (n, path))

    return dat.input, numpy.asarray(dat.target, dtype='int64').reshape(dat.target.shape[0])
23
24
def head_train(n=10000):
    """Return the leading `n` examples of the Shapeset1 training set.

    The result is a pair (x, y).  x has one row per example; each row holds the 1024
    grey-scale pixel values of a 32x32 image in raster order.  y is a vector with one
    integer label per example, so y[i] labels row i of x.

    """
    filename = 'shapeset1_1cspo_2_3.10000.train.shape.amat'
    amat_path = os.path.join(data_root(), 'shapeset1', filename)
    return _head(amat_path, n)
35
def head_valid(n=5000):
    """Return the leading `n` examples of the Shapeset1 validation set.

    The result is a pair (x, y).  x has one row per example; each row holds the 1024
    grey-scale pixel values of a 32x32 image in raster order.  y is a vector with one
    integer label per example, so y[i] labels row i of x.

    """
    filename = 'shapeset1_1cspo_2_3.5000.valid.shape.amat'
    amat_path = os.path.join(data_root(), 'shapeset1', filename)
    return _head(amat_path, n)
46
def head_test(n=5000):
    """Return the leading `n` examples of the Shapeset1 test set.

    The result is a pair (x, y).  x has one row per example; each row holds the 1024
    grey-scale pixel values of a 32x32 image in raster order.  y is a vector with one
    integer label per example, so y[i] labels row i of x.

    """
    filename = 'shapeset1_1cspo_2_3.5000.test.shape.amat'
    amat_path = os.path.join(data_root(), 'shapeset1', filename)
    return _head(amat_path, n)
57
def train_valid_test(ntrain=10000, nvalid=5000, ntest=5000):
    """Load the heads of the Shapeset1 training, validation and test sets in one call.

    Returns a 3-tuple (train, valid, test), where each element is the value returned by
    head_train(n=ntrain), head_valid(n=nvalid) and head_test(n=ntest) respectively.
    The three splits are loaded in that order.
    """
    train = head_train(n=ntrain)
    valid = head_valid(n=nvalid)
    test = head_test(n=ntest)
    return train, valid, test
60
61