changeset 690:7d8bb6d087bc

additions to datasets/tagatune
author James Bergstra <bergstrj@iro.umontreal.ca>
date Thu, 14 May 2009 16:59:20 -0400
parents 651eb6506d91
children e69249897f89
files pylearn/datasets/tagatune.py
diffstat 1 files changed, 43 insertions(+), 8 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/datasets/tagatune.py	Thu May 14 16:58:14 2009 -0400
+++ b/pylearn/datasets/tagatune.py	Thu May 14 16:59:20 2009 -0400
@@ -9,7 +9,9 @@
 import os
 import numpy
 
-from config import data_root
+import theano
+
+from .config import data_root
 
 def read_annotations_final(path):
     """Return a parsed (column-wise) representation of the tagatune/annotations_final.csv file
@@ -35,7 +37,7 @@
             #strip the leading and trailing '"' symbol from each token
             column_values = [tok[1:-1] for tok in line[:-2].split('\t')]
             assert len(column_values) == 190
-            clip_ids.append(column_values[0])
+            clip_ids.append(int(column_values[0]))
             mp3_paths.append(column_values[-1])
             # assert we didn't chop off too many chars
             assert column_values[-1].endswith('.mp3')
@@ -43,7 +45,8 @@
 
             # assert that the data is binary
             assert all(c in '01' for c in attributes_this_line)
-            attributes.append(attributes_this_line)
+            attributes.append(numpy.asarray([int(c) for c in attributes_this_line],
+            dtype='int8'))
 
     # assert that we read all the lines of the file
     assert len(clip_ids) == 25863
@@ -53,10 +56,42 @@
     attribute_names = column_names[1:-1] #all but clip_id and mp3_path
     return clip_ids, attributes, mp3_paths, attribute_names
 
+def cached_read_annotations_final(path):
+    if not hasattr(cached_read_annotations_final, 'rval'):
+        cached_read_annotations_final.rval = {}
+    if not path in cached_read_annotations_final.rval:
+        cached_read_annotations_final.rval[path] = read_annotations_final(path)
+    return cached_read_annotations_final.rval[path]
+
 def test_read_annotations_final():
-    return read_annotations_final(data_root() +'/tagatune/annotations_final.csv')
+    return read_annotations_final(data_root() + '/tagatune/annotations_final.csv')
 
-if __name__ == '__main__':
-    print 'starting'
-    test_read_annotations_final()
-    print 'done'
+class TagatuneExample(theano.Op):
+    """
+    input - index into tagatune database (not clip_id)
+    output - clip_id, attributes, path to clip's mp3 file
+    """
+    def __init__(self, music_dbs='/data/gamme/data/music_dbs'):
+        self.music_dbs = music_dbs
+        annotations_path = music_dbs + '/tagatune/annotations_final.csv'
+        self.clip_ids, self.attributes, self.mp3_paths, self.attribute_names =\
+                cached_read_annotations_final(annotations_path)
+
+    n_examples = property(lambda self: len(self.clip_ids))
+
+    def make_node(self, idx):
+        _idx = theano.tensor.as_tensor_variable(idx, ndim=0)
+        return theano.Apply(self, 
+                [_idx], 
+                [theano.tensor.lscalar('clip_id'),
+                    theano.tensor.bvector('clip_attributes'),
+                    theano.generic('clip_path')])
+    def perform(self, node, (idx,), out_storage):
+        out_storage[0][0] = self.clip_ids[idx]
+        out_storage[1][0] = self.attributes[idx]
+        out_storage[2][0] = self.music_dbs + '/tagatune/clips/mp3/' + self.mp3_paths[idx]
+
+    def grad(self, inputs, output):
+        return [None for i in inputs]
+
+#tagatune_example = TagatuneExample() #requires reading a big data file