view doc/v2_planning/shared_dataset.py @ 1435:3dd64c115657

Revised version of pkldu that is a bit more structured code-wise and outputs in human-readable units
author Razvan Pascanu <r.pascanu@gmail.com>
date Tue, 22 Feb 2011 11:23:32 -0500
parents c1943feada10
children
line wrap: on
line source

import theano

# This is not final and may not even run for now.  It is just to give
# a feeling of what the interface could look like.

def shared_dataset(dataset, mem_size):
    """Pick the Theano Op used to serve minibatches of ``dataset``.

    Parameters
    ----------
    dataset : object
        Must expose ``total_size``; it is then passed unchanged to the
        constructor of the chosen Op.
    mem_size : number
        Budget compared against ``dataset.total_size``.

    Returns
    -------
    An ``OnlineDataset`` when ``dataset.total_size`` exceeds ``mem_size``,
    otherwise a ``MemoryDataset``.
    """
    too_big = dataset.total_size > mem_size
    op_class = OnlineDataset if too_big else MemoryDataset
    return op_class(dataset)

class MemoryDataset(theano.Op):
    """Theano Op that serves minibatches from data held in shared variables.

    The dataset's ``input`` and ``output`` are wrapped in Theano shared
    variables at construction time. Applying the Op to an index ``idx``
    yields rows ``idx * batch_size`` up to (but excluding)
    ``(idx + 1) * batch_size`` of each.
    """

    def __init__(self, dataset):
        # NOTE(review): assumes ``dataset`` exposes array-like ``input`` and
        # ``output`` attributes plus an integer ``batch_size`` -- confirm
        # against the dataset interface.
        self.input = theano.shared(dataset.input)
        self.output = theano.shared(dataset.output)
        self.batch_size = dataset.batch_size

    def make_node(self, idx):
        """Build the Apply node that extracts minibatch ``idx``.

        The outputs reuse the Theano types of the stored shared variables,
        because a row slice keeps the type of the sliced array.
        """
        idx_ = theano.as_tensor_variable(idx)
        return theano.Apply(self,
                            inputs=[idx_],
                            outputs=[self.input.type(),
                                     self.output.type()])

    def perform(self, node, inputs, output_storage):
        """Slice minibatch ``idx`` out of the stored data.

        ``inputs`` holds the runtime value of the index (a number, not a
        symbolic variable). ``output_storage`` holds one single-element cell
        per output, and each cell is filled in place.
        """
        idx, = inputs
        start = int(idx) * self.batch_size
        stop = start + self.batch_size
        # perform() must produce concrete arrays, so read the values held by
        # the shared variables instead of slicing the symbolic variables.
        # Only the batch is copied, so that downstream in-place ops can never
        # write through the output into the shared dataset.
        output_storage[0][0] = \
            self.input.get_value(borrow=True)[start:stop].copy()
        output_storage[1][0] = \
            self.output.get_value(borrow=True)[start:stop].copy()

    # Backward-compatible alias for the original, misspelled method name.
    preform = perform

class OnlineDataset(theano.Op):
    """Theano Op that fetches each minibatch from the dataset on demand.

    No data are copied into shared variables. Each batch is obtained at run
    time by calling ``dataset.get_batch`` with the integer batch index.
    """

    def __init__(self, dataset):
        # NOTE(review): assumes ``dataset.get_batch(i)`` returns an object
        # with ``input`` and ``output`` attributes -- confirm against the
        # dataset interface.
        self.dataset = dataset

    def make_node(self, idx):
        """Build the Apply node that fetches minibatch ``idx``."""
        idx_ = theano.as_tensor_variable(idx)
        # TODO: the outputs are hard-coded as float32 matrices; they should
        # use whatever types the dataset actually produces.
        return theano.Apply(self,
                            inputs=[idx_],
                            outputs=[theano.tensor.fmatrix(),
                                     theano.tensor.fmatrix()])

    def perform(self, node, inputs, output_storage):
        """Fetch batch ``idx`` from the dataset and store its two arrays.

        Under the Theano Op contract, ``inputs`` carries runtime values, not
        symbolic variables. The index is therefore converted with ``int()``;
        it has no ``.value`` attribute.
        """
        idx, = inputs
        b = self.dataset.get_batch(int(idx))
        # NOTE(review): stored as-is; these must really be float32 matrices
        # to match the types declared in make_node.
        output_storage[0][0] = b.input
        output_storage[1][0] = b.output