changeset 600:e56303df3c77

initial flickr
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 14 Jan 2009 15:54:39 -0500
parents bd777e960c7c
children fd95ff96dd47
files pylearn/datasets/flickr.py
diffstat 1 files changed, 61 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/datasets/flickr.py	Wed Jan 14 15:54:39 2009 -0500
@@ -0,0 +1,61 @@
+"""
+Routines to load variations on the Flickr image dataset.
+"""
+from __future__ import absolute_import
+
+import os
+import numpy
+
+from ..io import filetensor
+from .config import data_root
+from .dataset import Dataset
+
+
+def test_10class():
+    """Load the 10-class Flickr test split, one example per row."""
+    #TODO: make path an option; make default path relative to data_root()
+    with open('flickr_10classes_test.ft', 'rb') as f:  # binary file; closed on exit
+        data = filetensor.read(f)
+        return data.T.copy()  # transpose + copy: one example per row, row major
+
+def train_10class():
+    """Load the 10-class Flickr training split, one example per row."""
+    #TODO: make path an option; make default path relative to data_root()
+    with open('flickr_10classes_train.ft', 'rb') as f:  # binary file; closed on exit
+        data = filetensor.read(f)
+        return data.T.copy()  # transpose + copy: one example per row, row major
+
+def valid_10class():
+    """Load the 10-class Flickr validation split, one example per row."""
+    #TODO: make path an option; make default path relative to data_root()
+    with open('flickr_10classes_valid.ft', 'rb') as f:  # binary file; closed on exit
+        data = filetensor.read(f)
+        return data.T.copy()  # transpose + copy: one example per row, row major
+
+def basic_10class():
+    """Build the basic 10-class Flickr image classification dataset.
+
+    Loads the train, valid and test splits; in each split the last column is
+    used as the int64 target ``y`` and all other columns as the input ``x``.
+    The images are 75x75, and there are 7500 training examples.
+    """
+    train_set = train_10class()
+    valid_set = valid_10class()
+    test_set = test_10class()
+
+    def as_split(examples):
+        # Last column is the label (cast to int64); the rest are features.
+        features, labels = examples[:, 0:-1], examples[:, -1]
+        return Dataset.Obj(x=features, y=numpy.asarray(labels, dtype='int64'))
+
+    dataset = Dataset()
+    dataset.train = as_split(train_set)
+    dataset.valid = as_split(valid_set)
+    dataset.test = as_split(test_set)
+    dataset.n_classes = 10
+    dataset.img_shape = (75,75)
+    return dataset
+
+def translations_10class():
+    raise NotImplementedError('TODO')  # placeholder: this variant is not written yet
+