changeset 671:9e62fd6b6677

adding wavread and tzanetakis dataset
author James Bergstra <bergstrj@iro.umontreal.ca>
date Mon, 30 Mar 2009 19:51:13 -0400
parents 63bcc7024378
children 8fff4bc26f4c f3b7d6956209
files pylearn/datasets/test_tzanetakis.py pylearn/datasets/tzanetakis.py pylearn/io/wavread.py
diffstat 3 files changed, 118 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/datasets/test_tzanetakis.py	Mon Mar 30 19:51:13 2009 -0400
@@ -0,0 +1,22 @@
+import theano
+
+from pylearn.io import wavread
+from pylearn.datasets import tzanetakis
+
+def test_tzanetakis():
+    """Smoke-test the tzanetakis example Op chained into the int16 wav reader.
+
+    Nothing is asserted: the test passes when every example can be looked up
+    and read without raising.
+    """
+    idx = theano.tensor.lscalar()
+    path, label = tzanetakis.tzanetakis_example(idx)
+    print path, label
+
+    n_examples = len(tzanetakis.tzanetakis_example)
+
+    # Compile index -> [path, label] and print the result for every index.
+    lookup = theano.function([idx], [path, label])
+    for i in xrange(n_examples):
+        print i, lookup(i)
+
+    # Feed the symbolic path into the wav reader and print each result's shape.
+    wav, sr = wavread.wav_read_int16(path)
+    read_wav = theano.function([idx], wav)
+    for i in xrange(n_examples):
+        print i, read_wav(i).shape
+
--- a/pylearn/datasets/tzanetakis.py	Mon Mar 30 16:15:24 2009 -0400
+++ b/pylearn/datasets/tzanetakis.py	Mon Mar 30 19:51:13 2009 -0400
@@ -58,4 +58,43 @@
 
     return rval
 
+import theano
 
+class TzanetakisExample(theano.Op):
+    """Op that maps an integer index to a (path, label) pair.
+
+    The pairs are loaded once, when the op is constructed, from
+    ``data_root() + '/tzanetakis/tracklist.txt'``.
+    """
+
+    @staticmethod
+    def read_tzanetakis_file():
+        """Parse the tracklist into two parallel lists.
+
+        Each line is split on whitespace; the first token is stored as the
+        path and the second as the label.
+
+        :returns: ``(path, label)``, two lists with one entry per line.
+        :raises IndexError: if a line has fewer than two tokens.
+        :raises AssertionError: if the file does not yield exactly 1000 paths.
+        """
+        path = []
+        label = []
+        tracklist = open(data_root() + '/tzanetakis/tracklist.txt')
+        # try/finally so the file is closed even when a bad line re-raises.
+        try:
+            for line in tracklist:
+                toks = line.split()
+                try:
+                    path.append(toks[0])
+                    label.append(toks[1])
+                except IndexError:
+                    # Show the offending line before propagating the error.
+                    print 'BAD LINE IN TZANETAKIS TRACKLIST'
+                    print line, toks
+                    raise
+        finally:
+            tracklist.close()
+        assert len(path) == 1000
+        return path, label
+
+    def __init__(self):
+        """Load the tracklist into ``self.path`` and ``self.label``."""
+        self.path, self.label = self.read_tzanetakis_file()
+
+    def __len__(self):
+        """Return the number of entries loaded from the tracklist."""
+        return len(self.path)
+
+    def make_node(self, idx):
+        """Build the Apply node: `idx` as a tensor variable, two generic outputs."""
+        return theano.Apply(self, 
+                [theano.tensor.as_tensor_variable(idx)],
+                [theano.generic(), theano.generic()])
+
+    def perform(self, node, (idx,), outputs):
+        """Store the path and the label at position `idx` in the two output cells."""
+        assert len(outputs) == 2
+        outputs[0][0] = self.path[idx]
+        outputs[1][0] = self.label[idx]
+
+    def grad(self, inputs, g_output):
+        """Return None for every input."""
+        return [None for i in inputs]
+# Module-level instance of the op. Constructing it reads the tracklist file,
+# so that file is opened as soon as this module is imported.
+tzanetakis_example = TzanetakisExample()
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/io/wavread.py	Mon Mar 30 19:51:13 2009 -0400
@@ -0,0 +1,57 @@
+"""`WavRead` Op"""
+__docformat__ = "restructuredtext en"
+
+import numpy
+import theano
+import wave
+
+class WavRead(theano.Op):
+    """Read a wave file
+
+    input - the path to a wave file
+    output - the contents of the wave file in pcm format, and the samplerate
+
+    Only single-channel files with 2-byte (16 bit) samples are handled;
+    `perform` raises NotImplementedError for anything else.
+    """
+
+    out_type = None
+    """The type for the output of this op. 
+
+    Currently only wvector (aka int16) and dvector (aka double) are supported
+    """
+
+    def __init__(self, out_type):
+        """Remember `out_type`, which is called to create the samples output.
+
+        :raises TypeError: if `out_type` is neither dvector nor wvector.
+        """
+        # Validate first so that a rejected type is never stored on the op.
+        if out_type not in [theano.tensor.dvector, theano.tensor.wvector]:
+            raise TypeError(out_type)
+        self.out_type = out_type
+
+    def __eq__(self, other):
+        # Two ops compare equal when they share both class and out_type.
+        return (type(self) == type(other)) and (self.out_type == other.out_type)
+
+    def __hash__(self):
+        # Built from the same two ingredients as __eq__.
+        return hash(type(self)) ^ hash(self.out_type)
+
+    def make_node(self, path):
+        """Build the Apply node: input `path`, outputs (samples, samplerate)."""
+        return theano.Apply(self, [path], [self.out_type(), theano.tensor.dscalar()])
+
+    def perform(self, node, (path,), (out, sr)):
+        """Read the wave file at `path` and fill the two output cells.
+
+        ``out[0]`` receives the samples (int16, or rescaled to double when
+        `out_type` is dvector) and ``sr[0]`` receives the frame rate.
+
+        :raises NotImplementedError: if the file is not single-channel or
+            its sample width is not 2 bytes.
+        """
+        w = wave.open(path)
+        # try/finally so the file is closed on the error paths as well.
+        try:
+            if w.getnchannels() != 1:
+                raise NotImplementedError('only single-channel wave files are supported')
+            if w.getsampwidth() != 2: #2 bytes means 16bit samples
+                raise NotImplementedError('only 16bit samples are supported')
+
+            samples = numpy.frombuffer(w.readframes(w.getnframes()), dtype='int16')
+            framerate = w.getframerate()
+        finally:
+            w.close()
+
+        if self.out_type == theano.tensor.wvector:
+            out[0] = samples
+        elif self.out_type == theano.tensor.dvector:
+            # Rescale the int16 samples by 1 / 2**15.
+            out[0] = samples * (1.0 / 2**15)
+        else:
+            raise NotImplementedError()
+
+        # make_node declares this output as a dscalar, so store a 0-d float64
+        # array rather than the plain Python int returned by getframerate().
+        sr[0] = numpy.asarray(framerate, dtype='float64')
+
+    def grad(self, inputs, g_output):
+        """Return None for every input."""
+        return [None for i in inputs]
+
+# Ready-made instances, one for each output type that WavRead accepts.
+wav_read_int16 = WavRead(theano.tensor.wvector)
+wav_read_double = WavRead(theano.tensor.dvector)
+