view pylearn/dataset_ops/majorminer.py @ 998:8ba8b08e0442

Added the image_patches dataset used in RanzatoHinton2010; modified mcRBM to use it.
author James Bergstra <bergstrj@iro.umontreal.ca>
date Tue, 24 Aug 2010 16:51:53 -0400
parents bb8ef344d0a9
children
line wrap: on
line source

from __future__ import absolute_import

import os
import numpy

import theano
import theano.sparse
import scipy.sparse

from ..datasets.majorminer import Meta

# Module-level MajorMiner meta-information shared by every MajorMiner Op.
# It starts unset and is filled in by MajorMiner.__init__; keeping it here
# (instead of on the Op instance) keeps it out of pickled Ops.
_meta = None

class MajorMiner(theano.Op):
    """Theano Op giving per-track meta-information of the major-miner dataset.

    Applied to a scalar index, the Op yields two outputs:

    * ``MajorMiner.tag_counts`` -- a sparse CSR matrix with one row and one
      column per tag, holding the counts listed for the indexed track;
    * ``MajorMiner.track_path`` -- the entry of ``tracks`` at that index.

    The meta-information itself lives in the module-level ``_meta`` rather
    than on the instance, because the instance might get pickled and the
    whole dataset should not be pickled with it.  As a consequence all
    instances are interchangeable: they compare equal and hash alike.
    """

    def __init__(self, meta=None):
        """Make sure the module-level meta-information is set.

        :param meta: optional meta-information object to install globally.
            If omitted, the already-installed object is used, or a default
            ``Meta()`` is created when nothing has been installed yet.
        :raises NotImplementedError: if meta-information is already
            installed and ``meta`` is a different object (replacing the
            global configuration is not supported).
        """
        global _meta
        if _meta is None:
            # first construction: install the caller's object or a default
            _meta = Meta() if meta is None else meta
        elif meta is not None and meta is not _meta:
            # Passing the object that is already installed is harmless, so
            # only an attempt to install a *different* one is rejected.
            raise NotImplementedError(
                'global MajorMiner meta-information already set')

    def __eq__(self, other):
        """Two Ops are equal iff they have exactly the same type."""
        return type(self) == type(other)

    def __hash__(self):
        """Hash on the type only, consistently with ``__eq__``."""
        return hash(type(self))

    def make_node(self, idx):
        """Build the Apply node for a scalar track index.

        :param idx: track index; converted to a 0-dimensional tensor
            variable.
        :returns: an Apply node with one input (the index) and two outputs
            (sparse CSR tag counts, generic track path).
        """
        _idx = theano.tensor.as_tensor_variable(idx, ndim=0)
        outputs = [
            theano.sparse.csr_matrix('MajorMiner.tag_counts'),
            theano.generic('MajorMiner.track_path'),
        ]
        return theano.Apply(self, [_idx], outputs)

    def perform(self, node, inputs, out_storage):
        """Compute both outputs for one track index.

        :param node: the Apply node being evaluated (unused).
        :param inputs: one-element sequence holding the track index.
        :param out_storage: output storage cells; cell 0 receives the CSR
            tag-count row, cell 1 receives ``_meta.tracks[idx]``.
        """
        # Unpack in the body instead of the signature: tuple parameters were
        # removed in Python 3 (PEP 3113), while this form runs on 2 and 3.
        (idx,) = inputs

        # LIL format is used for the incremental fill, then converted to CSR.
        tag_row = scipy.sparse.lil_matrix((1, len(_meta.tags)), dtype='int8')
        for tag_id, count in _meta.track_tags[idx]:
            tag_row[0, tag_id] = count

        out_storage[0][0] = tag_row.tocsr()
        out_storage[1][0] = _meta.tracks[idx]

    def grad(self, inputs, output):
        """Return ``None`` for every input: no gradient is provided."""
        return [None for _ in inputs]


def test_basic():
    """Smoke-test MajorMiner: compile the Op and compare track outputs.

    ``MajorMiner()`` creates a default ``Meta()`` when no meta-information
    has been installed yet, so this test needs that to be constructible.
    """
    # make_node converts its input with ndim=0, so the index has to be a
    # scalar variable rather than a vector.
    idx = theano.tensor.lscalar()
    f = theano.function([idx], MajorMiner()(idx))

    # Single-argument print with %-formatting gives the same output whether
    # print is a statement or a function.
    print('f(0): %s' % (f(0),))

    # Each call returns [tag_counts, track_path]; element 1 is the track.
    track_0 = f(0)[1]
    track_0_again = f(0)[1]
    track_1 = f(1)[1]
    track_8 = f(8)[1]

    assert track_0 == track_0_again  # same index gives the same track
    assert track_1 != track_8  # track 1 != track 8