# HG changeset patch # User James Bergstra # Date 1238629712 14400 # Node ID f3b7d69562092a4ff0fbacca594f0602168dd54c # Parent 9e62fd6b6677586a4a75f53f722196072fa891fa changes to tzanetakis and wavread diff -r 9e62fd6b6677 -r f3b7d6956209 pylearn/datasets/tzanetakis.py --- a/pylearn/datasets/tzanetakis.py Mon Mar 30 19:51:13 2009 -0400 +++ b/pylearn/datasets/tzanetakis.py Wed Apr 01 19:48:32 2009 -0400 @@ -61,8 +61,13 @@ import theano class TzanetakisExample(theano.Op): + """Return the i'th file, label pair from the Tzanetakis dataset.""" @staticmethod def read_tzanetakis_file(): + """Read the tzanetakis dataset file + :rtype: (list, list) + :returns: paths, labels + """ tracklist = open(data_root() + '/tzanetakis/tracklist.txt') path = [] label = [] @@ -77,24 +82,45 @@ raise assert len(path) == 1000 return path, label + + nclasses = 10 + + class_idx_dict = dict(blues=numpy.asarray(0), + classical=1, + country=2, + disco=3, + hiphop=4, + jazz=5, + metal=6, + pop=7, + reggae=8, + rock=9) def __init__(self): self.path, self.label = self.read_tzanetakis_file() + self.class_idx_dict = {} + classes = ('blues classical country disco hiphop jazz metal pop reggae rock').split() + for i, c in enumerate(classes): + self.class_idx_dict[c] = numpy.asarray(i, dtype='int64') def __len__(self): return len(self.path) def make_node(self, idx): + idx_ = theano.tensor.as_tensor_variable(idx) + if idx_.type not in theano.tensor.int_types: + raise TypeError(idx) return theano.Apply(self, - [theano.tensor.as_tensor_variable(idx)], - [theano.generic(), theano.generic()]) + [idx_], + [theano.generic('tzanetakis_path'), + theano.tensor.lscalar('tzanetakis_label')]) - def perform(self, node, (idx,), outputs): - assert len(outputs) == 2 - outputs[0][0] = self.path[idx] - outputs[1][0] = self.label[idx] + def perform(self, node, (idx,), (path, label)): + path[0] = self.path[idx] + label[0] = self.class_idx_dict[self.label[idx]] def grad(self, inputs, g_output): return [None for i in inputs] + tzanetakis_example = TzanetakisExample() diff -r 9e62fd6b6677 -r f3b7d6956209 pylearn/io/wavread.py --- a/pylearn/io/wavread.py Mon Mar 30 19:51:13 2009 -0400 +++ b/pylearn/io/wavread.py Wed Apr 01 19:48:32 2009 -0400 @@ -47,7 +47,7 @@ else: raise NotImplementedError() - sr[0] = w.getframerate() + sr[0] = numpy.asarray(w.getframerate(),dtype='float64') def grad(self, inputs, g_output): return [None for i in inputs]