comparison pylearn/datasets/flickr.py @ 602:28f7dc848efc

fixed flickr relpath mistake
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 14 Jan 2009 17:22:23 -0500
parents fd95ff96dd47
children f6c74f34cd35
comparison
equal deleted inserted replaced
601:fd95ff96dd47 602:28f7dc848efc
5 5
6 import os 6 import os
7 import numpy 7 import numpy
8 8
9 from ..io import filetensor 9 from ..io import filetensor
10 from .config import data_root 10 if 0:
11 from .config import data_root
12 else:
13 def data_root():
14 return '/u/lisa/db/flickr/filetensor'
11 from .dataset import Dataset 15 from .dataset import Dataset
12 16
13 17
14 def test_10class(): 18 path_test_10class ='flickr_10classes_test.ft'
15 #TODO: make path an option,
16 #TODO: make default path relative to data_root()
17 f = open('flickr_10classes_test.ft')
18 return filetensor.read(f)
19 19
20 def train_10class(): 20 path_train_10class = 'flickr_10classes_train.ft'
21 #TODO: make path an option,
22 #TODO: make default path relative to data_root()
23 f = open('flickr_10classes_train.ft')
24 return filetensor.read(f)
25 21
26 def valid_10class(): 22 path_valid_10class = 'flickr_10classes_valid.ft'
27 #TODO: make path an option,
28 #TODO: make default path relative to data_root()
29 f = open('flickr_10classes_valid.ft')
30 return filetensor.read(f)
31 23
32 def basic_10class(): 24 def basic_10class(folder = None):
33 """Return the basic flickr image classification problem. 25 """Return the basic flickr image classification problem.
34 The images are 75x75, and there are 7500 training examples. 26 The images are 75x75, and there are 7500 training examples.
35 """ 27 """
36 train = train_10class() 28 root = data_root() if folder is None else folder
37 valid = valid_10class() 29 train = filetensor.read(open(os.path.join(root, path_train_10class)))
38 test = test_10class() 30 valid = filetensor.read(open(os.path.join(root, path_valid_10class)))
31 test = filetensor.read(open(os.path.join(root, path_test_10class)))
39 32
40 rval = Dataset() 33 rval = Dataset()
41 34
42 rval.train = Dataset.Obj( 35 rval.train = Dataset.Obj(
43 x=train[:, 0:-1], 36 x=train[:, 0:-1],
49 x=test[:, 0:-1], 42 x=test[:, 0:-1],
50 y=numpy.asarray(test[:, -1], dtype='int64')) 43 y=numpy.asarray(test[:, -1], dtype='int64'))
51 44
52 rval.n_classes = 10 45 rval.n_classes = 10
53 rval.img_shape = (75,75) 46 rval.img_shape = (75,75)
47
54 return rval 48 return rval
55 49
56 def translations_10class(): 50 def translations_10class():
57 raise NotImplementedError('TODO') 51 raise NotImplementedError('TODO')
58 52