changeset 916:a9b043c032ea

added new AudioRead op to io.audio
author James Bergstra <bergstrj@iro.umontreal.ca>
date Fri, 19 Mar 2010 23:30:50 -0400
parents 5cb947647432
children 09212b8a1edd
files pylearn/io/audio.py
diffstat 1 files changed, 134 insertions(+), 2 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/io/audio.py	Fri Mar 19 17:55:23 2010 -0400
+++ b/pylearn/io/audio.py	Fri Mar 19 23:30:50 2010 -0400
@@ -1,8 +1,28 @@
-
+import subprocess 
 import numpy
 import theano
 
 from wavread import WavRead, wav_read_int16, wav_read_double
+import mad
+
+def gen_mp3(madfile, dtype, scale):
+    printed = False
+
+    while True:
+        b = madfile.read()
+        if b is None:
+            break
+        b = numpy.frombuffer(b, dtype='int16')
+        #print len(b), b.min(), b.max()
+        if not printed:
+            bb = b.reshape((len(b)/2,2))
+            print bb[1000:1020]
+            #print 'first 10 mp3samples', b[:10]
+            #print b[:10] * (1.0 / 2**15)
+            printed = True
+        n = len(b)
+        assert not (n%2)
+        yield scale*numpy.asarray(b, dtype=dtype).reshape((n/2, 2)) #cast and reshape
 
 try: #define audioread and company only if pygmy.audio can be imported
     import pygmy.audio
@@ -42,6 +62,7 @@
         def make_node(self, path):
             out_type = theano.tensor.dvector if self.mono else theano.tensor.dmatrix
             return theano.Apply(self, [path], [out_type(), theano.tensor.dscalar()])
+
         def perform(self, node, (path,), (data_storage, sr_storage)):
             data, sr, dz = pygmy.audio.audioread(path, 
                     mono=self.mono, 
@@ -65,5 +86,116 @@
     audioread = AudioRead()
     audioread_mono = AudioRead(mono=True)
 except ImportError:
-    pass
+
+    class AudioRead(theano.Op):
+        #TODO: add the samplerate as an output
+        """Read an mp3 (other formats not implemented yet)
+
+        Depends on 'madplay' being on system path.
+
+        input - filename
+        output - the contents of the audiofile in pcm format
+        
+        """
+        def __init__(self, channels=2, sr=22050, dtype=theano.config.floatX):
+            """
+            :param channels: output this many channels
+            :param sr: output will be encoded at this samplerate
+            :param dtype: output will have this dtype
+            """
+            self.dtype = dtype
+            if dtype not in ('float32', 'float64', 'int16'):
+                raise NotImplementedError('dtype', dtype)
+            self.channels = channels
+            self.sr = sr
+
+        def __eq__(self, other):
+            return (type(self) == type(other)) and self.dtype == other.dtype \
+                    and self.channels == other.channels and self.sr == other.sr
+
+        def __hash__(self):
+            return hash(type(self)) ^ hash(self.dtype) ^ hash(self.channels) ^ hash(self.sr)
+
+        def make_node(self, path):
+            bcast = (False,) *self.channels
+            otype = theano.tensor.TensorType(broadcastable=bcast, dtype=self.dtype)
+            return theano.Apply(self, [path], [otype(),])
+
+        def perform(self, node, (path,), (data_storage, )):
+            if path.upper().endswith('.MP3'):
+                cmd = ['madplay']
+                cmd.extend(['--sample-rate', str(self.sr)])
+                cmd.extend(['-o', 'raw:/dev/stdout'])
+                cmd.extend(['-d',])
+                if self.channels==1:
+                    cmd.extend(['--mono'])
+                elif self.channels==2:
+                    cmd.extend(['--stereo'])
+                else:
+                    raise NotImplementedError("weird number of channels", self.channels)
+                cmd.append(path)
 
+                rawdata = subprocess.Popen(cmd, stderr=subprocess.PIPE, stdout=subprocess.PIPE).communicate()[0]
+                int16samples= numpy.frombuffer(rawdata, dtype='int16')
+                if self.dtype == 'float32':
+                    typedsamples = numpy.asarray(int16samples, dtype='float32') / numpy.float32(2**15)
+                elif self.dtype == 'float64':
+                    typedsamples = int16samples * (1.0/2**15)
+                elif self.dtype == 'int16':
+                    typedsamples = int16samples
+                else:
+                    raise NotImplementedError()
+
+                if self.channels==2:
+                    typedsamples = typedsamples.reshape((len(typedsamples)/2,2))
+            else: 
+                #TODO: if extension is .wav use the 'wave' module in the stdlib
+                #      see test_audioread below for usage
+                raise NotImplementedError()
+
+            assert typedsamples.dtype == self.dtype
+            assert len(typedsamples.shape) == self.channels, (typedsamples.shape, self.channels)
+            data_storage[0] = typedsamples
+
+        def grad(self, inputs, g_output):
+            return [None for i in inputs]
+
+
+def test_audioread():
+    #
+    # Not really a unit test because it depends on files that are probably not around anymore.
+    # Still, the basic idea is to decode externally, and compare with wavread.
+    #
+
+    mp3path = "/home/bergstra/data/majorminer/mp3/Mono/Formica Blues/03 Slimcea Girl_003.20_003.30.mp3"
+
+    dstorage = [None]
+    AudioRead(channels=1, dtype='float32', sr=44100).perform(None, (mp3path,), (dstorage, ))
+    mp3samples = dstorage[0]
+
+    wavpath = "/home/bergstra/tmp/blah2.wav"
+    import wave, numpy
+    wavfile = wave.open(wavpath)
+    assert wavfile.getsampwidth()==2 # bytes
+    wavsamples = numpy.frombuffer(
+            wavfile.readframes(wavfile.getnframes()),
+            dtype='int16')
+    wavsamples = wavsamples.reshape((wavfile.getnframes(), wavfile.getnchannels()))
+    wavsamples_as_float = numpy.asarray(wavsamples, dtype='float32') / 2**15
+
+    print 'wavsamples 1000:1020:', wavsamples[1000:1020].mean(axis=1)
+    print 'mp3samples 1000:1020:', mp3samples[1000:1020]*2**15
+    print 'wavsample range', wavsamples.min(), wavsamples.max()
+    print 'mp3sample range', mp3samples.min(), mp3samples.max()
+
+    print mp3samples.shape, mp3samples.dtype
+    print wavsamples.shape, wavsamples.dtype
+
+    #assert mp3samples.shape == wavsamples.shape
+    #assert mp3samples.dtype == wavsamples_as_float.dtype
+
+    #print wavsamples_as_float[:5]
+    #print mp3samples[:5]
+
+
+