Mercurial > pylearn
comparison pylearn/datasets/flickr.py @ 602:28f7dc848efc
fixed flickr relpath mistake
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Wed, 14 Jan 2009 17:22:23 -0500 |
parents | fd95ff96dd47 |
children | f6c74f34cd35 |
comparison
equal
deleted
inserted
replaced
601:fd95ff96dd47 | 602:28f7dc848efc |
---|---|
5 | 5 |
6 import os | 6 import os |
7 import numpy | 7 import numpy |
8 | 8 |
9 from ..io import filetensor | 9 from ..io import filetensor |
10 from .config import data_root | 10 if 0: |
11 from .config import data_root | |
12 else: | |
13 def data_root(): | |
14 return '/u/lisa/db/flickr/filetensor' | |
11 from .dataset import Dataset | 15 from .dataset import Dataset |
12 | 16 |
13 | 17 |
14 def test_10class(): | 18 path_test_10class ='flickr_10classes_test.ft' |
15 #TODO: make path an option, | |
16 #TODO: make default path relative to data_root() | |
17 f = open('flickr_10classes_test.ft') | |
18 return filetensor.read(f) | |
19 | 19 |
20 def train_10class(): | 20 path_train_10class = 'flickr_10classes_train.ft' |
21 #TODO: make path an option, | |
22 #TODO: make default path relative to data_root() | |
23 f = open('flickr_10classes_train.ft') | |
24 return filetensor.read(f) | |
25 | 21 |
26 def valid_10class(): | 22 path_valid_10class = 'flickr_10classes_valid.ft' |
27 #TODO: make path an option, | |
28 #TODO: make default path relative to data_root() | |
29 f = open('flickr_10classes_valid.ft') | |
30 return filetensor.read(f) | |
31 | 23 |
32 def basic_10class(): | 24 def basic_10class(folder = None): |
33 """Return the basic flickr image classification problem. | 25 """Return the basic flickr image classification problem. |
34 The images are 75x75, and there are 7500 training examples. | 26 The images are 75x75, and there are 7500 training examples. |
35 """ | 27 """ |
36 train = train_10class() | 28 root = data_root() if folder is None else folder |
37 valid = valid_10class() | 29 train = filetensor.read(open(os.path.join(root, path_train_10class))) |
38 test = test_10class() | 30 valid = filetensor.read(open(os.path.join(root, path_valid_10class))) |
31 test = filetensor.read(open(os.path.join(root, path_test_10class))) | |
39 | 32 |
40 rval = Dataset() | 33 rval = Dataset() |
41 | 34 |
42 rval.train = Dataset.Obj( | 35 rval.train = Dataset.Obj( |
43 x=train[:, 0:-1], | 36 x=train[:, 0:-1], |
49 x=test[:, 0:-1], | 42 x=test[:, 0:-1], |
50 y=numpy.asarray(test[:, -1], dtype='int64')) | 43 y=numpy.asarray(test[:, -1], dtype='int64')) |
51 | 44 |
52 rval.n_classes = 10 | 45 rval.n_classes = 10 |
53 rval.img_shape = (75,75) | 46 rval.img_shape = (75,75) |
47 | |
54 return rval | 48 return rval |
55 | 49 |
56 def translations_10class(): | 50 def translations_10class(): |
57 raise NotImplementedError('TODO') | 51 raise NotImplementedError('TODO') |
58 | 52 |