Mercurial > pylearn
view doc/v2_planning/shared_dataset.py @ 1183:bc1b445e22fa
API_coding_style: Added code example to explain the point about the number of spaces after a period
author | Olivier Delalleau <delallea@iro> |
---|---|
date | Fri, 17 Sep 2010 16:51:09 -0400 |
parents | c1943feada10 |
children |
line wrap: on
line source
import theano

# This is not final and may not even run for now. It is just to give
# a feeling of what the interface could look like.


def shared_dataset(dataset, mem_size):
    """Wrap `dataset` in a Theano Op that serves it one minibatch at a time.

    :param dataset: source dataset. It must expose ``total_size``. The
        returned Op additionally reads ``input``, ``output`` and
        ``batch_size`` (``MemoryDataset``) or calls ``get_batch(idx)``
        (``OnlineDataset``).
    :param mem_size: memory budget compared against ``dataset.total_size``
        (presumably expressed in the same unit -- TODO confirm).
    :returns: an ``OnlineDataset`` when ``dataset.total_size`` exceeds
        ``mem_size``, otherwise a ``MemoryDataset`` that copies the data
        into Theano shared variables.
    """
    if dataset.total_size > mem_size:
        return OnlineDataset(dataset)
    return MemoryDataset(dataset)


class MemoryDataset(theano.Op):
    """Op returning minibatch ``idx`` of a dataset held in shared variables.

    The whole dataset is stored in two Theano shared variables (``input``
    and ``output``). Applying the Op to an index ``idx`` yields the rows
    ``[idx * batch_size:(idx + 1) * batch_size]`` of each of them.
    """

    def __init__(self, dataset):
        self.input = theano.shared(dataset.input)
        self.output = theano.shared(dataset.output)
        self.batch_size = dataset.batch_size

    def make_node(self, idx):
        """Build the Apply node: one index input, two batch outputs.

        The outputs have the same Theano types as the shared variables
        they are sliced from.
        """
        idx_ = theano.tensor.as_tensor_variable(idx)
        return theano.Apply(self,
                            inputs=[idx_],
                            outputs=[self.input.type(), self.output.type()])

    def perform(self, node, inputs, output_storage):
        """Store minibatch ``inputs[0]`` into ``output_storage``.

        Theano calls ``perform`` (with numeric input values) to evaluate
        the Op in Python.
        """
        idx, = inputs
        start = int(idx) * self.batch_size
        stop = start + self.batch_size
        # perform works on numeric values, not symbolic variables: read the
        # data held by the shared variables instead of slicing the shared
        # variables themselves. The slice is copied so that the returned
        # batch does not alias the stored dataset, which later in-place
        # Ops could otherwise overwrite.
        output_storage[0][0] = \
            self.input.get_value(borrow=True)[start:stop].copy()
        output_storage[1][0] = \
            self.output.get_value(borrow=True)[start:stop].copy()


class OnlineDataset(theano.Op):
    """Op fetching minibatch ``idx`` from the dataset on demand.

    Nothing is preloaded: each evaluation calls ``dataset.get_batch(idx)``
    and returns that batch's ``input`` and ``output`` attributes.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def make_node(self, idx):
        """Build the Apply node: one index input, two float32 matrix outputs."""
        idx_ = theano.tensor.as_tensor_variable(idx)
        # TODO: fix this so the outputs are not fmatrix(), but whatever
        # the dataset outputs.
        return theano.Apply(self,
                            inputs=[idx_],
                            outputs=[theano.tensor.fmatrix(),
                                     theano.tensor.fmatrix()])

    def perform(self, node, inputs, output_storage):
        """Fetch minibatch ``inputs[0]`` and store it into ``output_storage``."""
        idx, = inputs
        # During perform, inputs already hold numeric values, so idx is the
        # index itself (it has no ``.value`` attribute).
        batch = self.dataset.get_batch(int(idx))
        output_storage[0][0] = batch.input
        output_storage[1][0] = batch.output