# HG changeset patch # User James Bergstra # Date 1269035723 14400 # Node ID 5cb947647432c49ccb8f306d53ff0ff90e286cc2 # Parent 519e82748a553b35de3078d3defe4e1b9e765de2 adding majorminer dataset diff -r 519e82748a55 -r 5cb947647432 pylearn/dataset_ops/majorminer.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pylearn/dataset_ops/majorminer.py Fri Mar 19 17:55:23 2010 -0400 @@ -0,0 +1,65 @@ +from __future__ import absolute_import + +import os +import numpy + +import theano +import theano.sparse +import scipy.sparse + +from ..datasets.majorminer import Meta + +_meta = None + +class MajorMiner(theano.Op): + """Meta-information of major-miner dataset""" + + def __init__(self, meta=None): + global _meta + # on construction we make sure a *global* configuration is set + # this is done because self.* might get pickled and we don't want to pickle + # the whole dataset + if _meta is None: + if meta is None: _meta = Meta() + else: _meta = meta + else: + if meta is None: pass # no problem, we use global _meta + else: raise NotImplementedError('global MajorMiner meta-information already set') + + def __eq__(self, other): + return type(self) == type(other) + def __hash__(self): + return hash(type(self)) + + def make_node(self, idx): + _idx = theano.tensor.as_tensor_variable(idx, ndim=1) + return theano.Apply(self, + [_idx], + [theano.sparse.csr_matrix('MajorMiner.tag_counts'), + theano.generic('MajorMiner.:track_path')]) + def perform(self, node, (idx,), out_storage): + global _meta + lil = scipy.sparse.lil_matrix((len(idx), len(_meta.tags)), dtype='int8') + tracks = [] + for j,i in enumerate(idx): + for tag_id, count in _meta.track_tags[i]: + lil[j,tag_id] = count + tracks.append(_meta.tracks[i]) + + out_storage[0][0] = lil.tocsr() + out_storage[1][0] = tracks + + def grad(self, inputs, output): + return [None for i in inputs] + + +def test_basic(): + a = theano.tensor.lvector() + f = theano.function([a], MajorMiner()(a)) + print 'f([0]):', f([0]) + rval_0_1 = f([0,1]) + rval_0_8 = f([0,8]) + + assert rval_0_1[1][0] == rval_0_8[1][0] #compare strings + assert rval_0_1[1][1] != rval_0_8[1][1] #track 1 != track 8 + diff -r 519e82748a55 -r 5cb947647432 pylearn/datasets/majorminer.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pylearn/datasets/majorminer.py Fri Mar 19 17:55:23 2010 -0400 @@ -0,0 +1,78 @@ +""" +Load the MajorMiner dataset +""" + +import logging, os +from .config import data_root +_logger = logging.getLogger('pylearn.datasets.majorminer') + +def three_column(tagfile=None, trackroot=None, expected_tagfile_len=51556): + """Load meta-information of major-miner dataset + + Data is stored as a three-column file: + + + + This function returns the parsed file as a list of 3-tuples. + + """ + if tagfile is None: + tagfile = os.path.join(data_root(), 'majorminer', 'three_column.txt') + _logger.info('Majorminer loading %s'%tagfile) + + if trackroot is None: + trackroot = os.path.join(data_root(), 'majorminer') + _logger.info('Majorminer using trackroot %s'%tagfile) + + tag_count_track = [] + + for line in open(tagfile): + if line: + tag, count, track = line.split('\t') + tag_count_track.append((tag, int(count), os.path.join(trackroot, track))) + + if expected_tagfile_len: + if len(tag_count_track) != expected_tagfile_len: + raise Exception('Wrong number of files listed') + + return tag_count_track + +def list_tracks(three_col): + tracks = list(set(tup[2] for tup in three_col)) + tracks.sort() + return tracks + +def list_tags(three_col): + tags = list(set(tup[0] for tup in three_col)) + tags.sort() + return tags + +def track_tags(three_col, tracks, tags): + """Return the count of each tag for each track + [ [(tag_id, count), (tag_id, count), ...], <---- for tracks[0] + [(tag_id, count), (tag_id, count), ...], <---- for tracks[1] + ... + ] + """ + tag_id = dict(((t,i) for i,t in enumerate(tags))) + track_id = dict(((t,i) for i,t in enumerate(tracks))) + rval = [[] for t in tracks] + for tag, count, track in three_col: + rval[track_id[track]].append((tag_id[tag], count)) + return rval + + + +class Meta(object): + def __init__(self, tagfile=None, trackroot=None, expected_tagfile_len=51556): + self.three_column = three_column(tagfile, trackroot, expected_tagfile_len) + self.tracks = list_tracks(self.three_column) + self.tags = list_tags(self.three_column) + self.track_tags = track_tags(self.three_column, self.tracks, self.tags) + + _logger.info('MajorMiner meta-information: %i tracks, %i tags' %( + len(self.tracks), len(self.tags))) + + #for tt in self.track_tags: + # print tt +