changeset 919:3901d06e2d96

using madplay in audio.AudioRead
author James Bergstra <bergstrj@iro.umontreal.ca>
date Sat, 20 Mar 2010 15:18:54 -0400
parents bb8ef344d0a9
children a5c33f01c9a4
files pylearn/io/audio.py
diffstat 1 files changed, 120 insertions(+), 116 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/io/audio.py	Fri Mar 19 23:32:23 2010 -0400
+++ b/pylearn/io/audio.py	Sat Mar 20 15:18:54 2010 -0400
@@ -1,4 +1,4 @@
-import subprocess 
+import subprocess, sys
 import numpy
 import theano
 
@@ -24,7 +24,125 @@
         assert not (n%2)
         yield scale*numpy.asarray(b, dtype=dtype).reshape((n/2, 2)) #cast and reshape
 
-try: #define audioread and company only if pygmy.audio can be imported
+class AudioRead(theano.Op):
+    #TODO: add the samplerate as an output
+    """Read an mp3 (other formats not implemented yet)
+
+    Depends on 'madplay' being on system path.
+
+    input - filename
+    output - the contents of the audiofile in pcm format
+    
+    """
+    def __init__(self, channels=2, sr=22050, dtype=theano.config.floatX):
+        """
+        :param channels: output this many channels
+        :param sr: output will be encoded at this samplerate
+        :param dtype: output will have this dtype
+        """
+        self.dtype = dtype
+        if dtype not in ('float32', 'float64', 'int16'):
+            raise NotImplementedError('dtype', dtype)
+        self.channels = channels
+        self.sr = sr
+
+    def __eq__(self, other):
+        return (type(self) == type(other)) and self.dtype == other.dtype \
+                and self.channels == other.channels and self.sr == other.sr
+
+    def __hash__(self):
+        return hash(type(self)) ^ hash(self.dtype) ^ hash(self.channels) ^ hash(self.sr)
+
+    def make_node(self, path):
+        bcast = (False,) *self.channels
+        otype = theano.tensor.TensorType(broadcastable=bcast, dtype=self.dtype)
+        return theano.Apply(self, [path], [otype(),])
+
+    def perform(self, node, (path,), (data_storage, )):
+        if path.upper().endswith('.MP3'):
+            cmd = ['madplay']
+            cmd.extend(['--sample-rate', str(self.sr)])
+            cmd.extend(['-o', 'raw:/dev/stdout'])
+            cmd.extend(['-d',])
+            if self.channels==1:
+                cmd.extend(['--mono'])
+            elif self.channels==2:
+                cmd.extend(['--stereo'])
+            else:
+                raise NotImplementedError("weird number of channels", self.channels)
+            cmd.append(path)
+
+            proc = subprocess.Popen(cmd, stderr=subprocess.PIPE, stdout=subprocess.PIPE)
+            proc_stdout, proc_stderr = proc.communicate()
+            assert proc.returncode is not None # process should be finished
+            if proc.returncode:
+                print >> sys.stderr, proc_stderr
+                raise Exception('cmd %s returned code %i'%(' '.join(cmd),proc.returncode))
+
+            int16samples= numpy.frombuffer(proc_stdout, dtype='int16')
+            if self.dtype == 'float32':
+                typedsamples = numpy.asarray(int16samples, dtype='float32') / numpy.float32(2**15)
+            elif self.dtype == 'float64':
+                typedsamples = int16samples * (1.0/2**15)
+            elif self.dtype == 'int16':
+                typedsamples = int16samples
+            else:
+                raise NotImplementedError()
+
+            if self.channels==2:
+                typedsamples = typedsamples.reshape((len(typedsamples)/2,2))
+        else: 
+            #TODO: if extension is .wav use the 'wave' module in the stdlib
+            #      see test_audioread below for usage
+            raise NotImplementedError()
+
+        assert typedsamples.dtype == self.dtype
+        assert len(typedsamples.shape) == self.channels, (typedsamples.shape, self.channels)
+        data_storage[0] = typedsamples
+
+    def grad(self, inputs, g_output):
+        return [None for i in inputs]
+
+
+def test_audioread():
+    #
+    # Not really a unit test because it depends on files that are probably not around anymore.
+    # Still, the basic idea is to decode externally, and compare with wavread.
+    #
+
+    mp3path = "/home/bergstra/data/majorminer/mp3/Mono/Formica Blues/03 Slimcea Girl_003.20_003.30.mp3"
+
+    dstorage = [None]
+    AudioRead(channels=1, dtype='float32', sr=44100).perform(None, (mp3path,), (dstorage, ))
+    mp3samples = dstorage[0]
+
+    wavpath = "/home/bergstra/tmp/blah2.wav"
+    import wave, numpy
+    wavfile = wave.open(wavpath)
+    assert wavfile.getsampwidth()==2 # bytes
+    wavsamples = numpy.frombuffer(
+            wavfile.readframes(wavfile.getnframes()),
+            dtype='int16')
+    wavsamples = wavsamples.reshape((wavfile.getnframes(), wavfile.getnchannels()))
+    wavsamples_as_float = numpy.asarray(wavsamples, dtype='float32') / 2**15
+
+    print 'wavsamples 1000:1020:', wavsamples[1000:1020].mean(axis=1)
+    print 'mp3samples 1000:1020:', mp3samples[1000:1020]*2**15
+    print 'wavsample range', wavsamples.min(), wavsamples.max()
+    print 'mp3sample range', mp3samples.min(), mp3samples.max()
+
+    print mp3samples.shape, mp3samples.dtype
+    print wavsamples.shape, wavsamples.dtype
+
+    #assert mp3samples.shape == wavsamples.shape
+    #assert mp3samples.dtype == wavsamples_as_float.dtype
+
+    #print wavsamples_as_float[:5]
+    #print mp3samples[:5]
+
+
+
+if 0: ### OLD CODE USING PYGMY
     import pygmy.audio
 
     class AudioRead(theano.Op):
@@ -85,117 +203,3 @@
 
     audioread = AudioRead()
     audioread_mono = AudioRead(mono=True)
-except ImportError:
-
-    class AudioRead(theano.Op):
-        #TODO: add the samplerate as an output
-        """Read an mp3 (other formats not implemented yet)
-
-        Depends on 'madplay' being on system path.
-
-        input - filename
-        output - the contents of the audiofile in pcm format
-        
-        """
-        def __init__(self, channels=2, sr=22050, dtype=theano.config.floatX):
-            """
-            :param channels: output this many channels
-            :param sr: output will be encoded at this samplerate
-            :param dtype: output will have this dtype
-            """
-            self.dtype = dtype
-            if dtype not in ('float32', 'float64', 'int16'):
-                raise NotImplementedError('dtype', dtype)
-            self.channels = channels
-            self.sr = sr
-
-        def __eq__(self, other):
-            return (type(self) == type(other)) and self.dtype == other.dtype \
-                    and self.channels == other.channels and self.sr == other.sr
-
-        def __hash__(self):
-            return hash(type(self)) ^ hash(self.dtype) ^ hash(self.channels) ^ hash(self.sr)
-
-        def make_node(self, path):
-            bcast = (False,) *self.channels
-            otype = theano.tensor.TensorType(broadcastable=bcast, dtype=self.dtype)
-            return theano.Apply(self, [path], [otype(),])
-
-        def perform(self, node, (path,), (data_storage, )):
-            if path.upper().endswith('.MP3'):
-                cmd = ['madplay']
-                cmd.extend(['--sample-rate', str(self.sr)])
-                cmd.extend(['-o', 'raw:/dev/stdout'])
-                cmd.extend(['-d',])
-                if self.channels==1:
-                    cmd.extend(['--mono'])
-                elif self.channels==2:
-                    cmd.extend(['--stereo'])
-                else:
-                    raise NotImplementedError("weird number of channels", self.channels)
-                cmd.append(path)
-
-                rawdata = subprocess.Popen(cmd, stderr=subprocess.PIPE, stdout=subprocess.PIPE).communicate()[0]
-                int16samples= numpy.frombuffer(rawdata, dtype='int16')
-                if self.dtype == 'float32':
-                    typedsamples = numpy.asarray(int16samples, dtype='float32') / numpy.float32(2**15)
-                elif self.dtype == 'float64':
-                    typedsamples = int16samples * (1.0/2**15)
-                elif self.dtype == 'int16':
-                    typedsamples = int16samples
-                else:
-                    raise NotImplementedError()
-
-                if self.channels==2:
-                    typedsamples = typedsamples.reshape((len(typedsamples)/2,2))
-            else: 
-                #TODO: if extension is .wav use the 'wave' module in the stdlib
-                #      see test_audioread below for usage
-                raise NotImplementedError()
-
-            assert typedsamples.dtype == self.dtype
-            assert len(typedsamples.shape) == self.channels, (typedsamples.shape, self.channels)
-            data_storage[0] = typedsamples
-
-        def grad(self, inputs, g_output):
-            return [None for i in inputs]
-
-
-def test_audioread():
-    #
-    # Not really a unit test because it depends on files that are probably not around anymore.
-    # Still, the basic idea is to decode externally, and compare with wavread.
-    #
-
-    mp3path = "/home/bergstra/data/majorminer/mp3/Mono/Formica Blues/03 Slimcea Girl_003.20_003.30.mp3"
-
-    dstorage = [None]
-    AudioRead(channels=1, dtype='float32', sr=44100).perform(None, (mp3path,), (dstorage, ))
-    mp3samples = dstorage[0]
-
-    wavpath = "/home/bergstra/tmp/blah2.wav"
-    import wave, numpy
-    wavfile = wave.open(wavpath)
-    assert wavfile.getsampwidth()==2 # bytes
-    wavsamples = numpy.frombuffer(
-            wavfile.readframes(wavfile.getnframes()),
-            dtype='int16')
-    wavsamples = wavsamples.reshape((wavfile.getnframes(), wavfile.getnchannels()))
-    wavsamples_as_float = numpy.asarray(wavsamples, dtype='float32') / 2**15
-
-    print 'wavsamples 1000:1020:', wavsamples[1000:1020].mean(axis=1)
-    print 'mp3samples 1000:1020:', mp3samples[1000:1020]*2**15
-    print 'wavsample range', wavsamples.min(), wavsamples.max()
-    print 'mp3sample range', mp3samples.min(), mp3samples.max()
-
-    print mp3samples.shape, mp3samples.dtype
-    print wavsamples.shape, wavsamples.dtype
-
-    #assert mp3samples.shape == wavsamples.shape
-    #assert mp3samples.dtype == wavsamples_as_float.dtype
-
-    #print wavsamples_as_float[:5]
-    #print mp3samples[:5]
-
-
-