changeset 602:28f7dc848efc

fixed flickr relpath mistake
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 14 Jan 2009 17:22:23 -0500
parents fd95ff96dd47
children 52a99d83f06d 20953adfdef8
files pylearn/datasets/flickr.py
diffstat 1 files changed, 15 insertions(+), 21 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/datasets/flickr.py	Wed Jan 14 17:00:57 2009 -0500
+++ b/pylearn/datasets/flickr.py	Wed Jan 14 17:22:23 2009 -0500
@@ -7,35 +7,28 @@
 import numpy
 
 from ..io import filetensor
-from .config import data_root
+if 0:
+    from .config import data_root
+else:
+    def data_root():
+        return '/u/lisa/db/flickr/filetensor'
 from .dataset import Dataset
 
 
-def test_10class():
-    #TODO: make path an option,
-    #TODO: make default path relative to data_root()
-    f = open('flickr_10classes_test.ft')
-    return filetensor.read(f)
+path_test_10class ='flickr_10classes_test.ft'
+
+path_train_10class = 'flickr_10classes_train.ft'
 
-def train_10class():
-    #TODO: make path an option,
-    #TODO: make default path relative to data_root()
-    f = open('flickr_10classes_train.ft')
-    return filetensor.read(f)
+path_valid_10class = 'flickr_10classes_valid.ft'
 
-def valid_10class():
-    #TODO: make path an option,
-    #TODO: make default path relative to data_root()
-    f = open('flickr_10classes_valid.ft')
-    return filetensor.read(f)
-
-def basic_10class():
+def basic_10class(folder = None):
     """Return the basic flickr image classification problem.
     The images are 75x75, and there are 7500 training examples.
     """
-    train = train_10class()
-    valid = valid_10class()
-    test = test_10class()
+    root = data_root() if folder is None else folder
+    train = filetensor.read(open(os.path.join(root, path_train_10class)))
+    valid = filetensor.read(open(os.path.join(root, path_valid_10class)))
+    test = filetensor.read(open(os.path.join(root, path_test_10class)))
 
     rval = Dataset()
 
@@ -51,6 +44,7 @@
 
     rval.n_classes = 10
     rval.img_shape = (75,75)
+
     return rval
 
 def translations_10class():