Mercurial > pylearn
comparison pylearn/datasets/shapeset1.py @ 617:5120bf7c4694
More complete version of shapeset1 dataset.
author | lamblinp@ip03.m |
---|---|
date | Sat, 17 Jan 2009 19:06:21 -0500 |
parents | 16f91ca016b1 |
children | 8aef46b42cb5 |
comparison
equal
deleted
inserted
replaced
616:d0f7a6f87adc | 617:5120bf7c4694 |
---|---|
7 import os | 7 import os |
8 import numpy | 8 import numpy |
9 | 9 |
10 from ..io.amat import AMat | 10 from ..io.amat import AMat |
11 from .config import data_root | 11 from .config import data_root |
12 from .dataset import Dataset | |
12 | 13 |
13 def _head(path, n): | 14 def _head(path, n): |
14 dat = AMat(path=path, head=n) | 15 dat = AMat(path=path, head=n) |
15 | 16 |
16 try: | 17 try: |
23 | 24 |
24 | 25 |
25 def head_train(n=10000): | 26 def head_train(n=10000): |
26 """Load the first Shapeset1 training examples. | 27 """Load the first Shapeset1 training examples. |
27 | 28 |
28 Returns two matrices: x, y. x has N rows of 1024 columns. Each row of x represents the | 29 Returns two matrices: x, y. |
29 32x32 grey-scale pixels in raster order. y is a vector of N integers. Each element y[i] | 30 x has N rows of 1024 columns. |
30 is the label of the i'th row of x. | 31 Each row of x represents the 32x32 grey-scale pixels in raster order. |
31 | 32 y is a vector of N integers between 0 and 2. |
33 Each element y[i] is the label of the i'th row of x. | |
32 """ | 34 """ |
33 path = os.path.join(data_root(), 'shapeset1','shapeset1_1cspo_2_3.10000.train.shape.amat') | 35 path = os.path.join(data_root(), 'shapeset1','shapeset1_1cspo_2_3.10000.train.shape.amat') |
34 return _head(path, n) | 36 return _head(path, n) |
35 | 37 |
36 def head_valid(n=5000): | 38 def head_valid(n=5000): |
37 """Load the first Shapeset1 validation examples. | 39 """Load the first Shapeset1 validation examples. |
38 | 40 |
39 Returns two matrices: x, y. x has N rows of 1024 columns. Each row of x represents the | 41 Returns two matrices: x, y. |
40 32x32 grey-scale pixels in raster order. y is a vector of N integers. Each element y[i] | 42 x has N rows of 1024 columns. |
41 is the label of the i'th row of x. | 43 Each row of x represents the 32x32 grey-scale pixels in raster order. |
42 | 44 y is a vector of N integers between 0 and 2. |
45 Each element y[i] is the label of the i'th row of x. | |
43 """ | 46 """ |
44 path = os.path.join(data_root(), 'shapeset1','shapeset1_1cspo_2_3.5000.valid.shape.amat') | 47 path = os.path.join(data_root(), 'shapeset1','shapeset1_1cspo_2_3.5000.valid.shape.amat') |
45 return _head(path, n) | 48 return _head(path, n) |
46 | 49 |
47 def head_test(n=5000): | 50 def head_test(n=5000): |
48 """Load the first Shapeset1 testing examples. | 51 """Load the first Shapeset1 testing examples. |
49 | 52 |
50 Returns two matrices: x, y. x has N rows of 1024 columns. Each row of x represents the | 53 Returns two matrices: x, y. |
51 32x32 grey-scale pixels in raster order. y is a vector of N integers. Each element y[i] | 54 x has N rows of 1024 columns. |
52 is the label of the i'th row of x. | 55 Each row of x represents the 32x32 grey-scale pixels in raster order. |
53 | 56 y is a vector of N integers between 0 and 2. |
57 Each element y[i] is the label of the i'th row of x. | |
54 """ | 58 """ |
55 path = os.path.join(data_root(), 'shapeset1','shapeset1_1cspo_2_3.5000.test.shape.amat') | 59 path = os.path.join(data_root(), 'shapeset1','shapeset1_1cspo_2_3.5000.test.shape.amat') |
56 return _head(path, n) | 60 return _head(path, n) |
57 | 61 |
58 def train_valid_test(ntrain=10000, nvalid=5000, ntest=5000): | 62 def train_valid_test(ntrain=10000, nvalid=5000, ntest=5000): |
59 return head_train(n=ntrain), head_valid(n=nvalid), head_test(n=ntest) | 63 train_x, train_y = head_train(n=ntrain) |
64 valid_x, valid_y = head_valid(n=nvalid) | |
65 test_x, test_y = head_test(n=test) | |
66 | |
67 rval = Dataset() | |
68 rval.train = Dataset.Obj(x = train_x, y = train_y) | |
69 rval.valid = Dataset.Obj(x = valid_x, y = valid_y) | |
70 rval.test = Dataset.Obj(x = test_x, y = test_y) | |
71 | |
72 rval.n_classes = 3 | |
73 rval.img_shape = (32, 32) | |
74 | |
75 return rval | |
60 | 76 |
61 | 77 |