comparison doc/v2_planning/shared_dataset.py @ 1117:c1943feada10

Proposal for theano dataset wrapper. The details still have to be worked out.
author Arnaud Bergeron <abergeron@gmail.com>
date Tue, 14 Sep 2010 15:22:48 -0400
parents
children
comparison
equal deleted inserted replaced
1116:18a092001752 1117:c1943feada10
1 import theano
2
3 # This is not final and may not even run for now. It is just to give
4 # a feeling of what the interface could look like.
5
def shared_dataset(dataset, mem_size):
    """Wrap ``dataset`` in the dataset Op best suited to available memory.

    A dataset small enough to fit within ``mem_size`` is preloaded into
    shared variables (``MemoryDataset``); anything larger is streamed one
    batch at a time (``OnlineDataset``).
    """
    fits_in_memory = dataset.total_size <= mem_size
    return MemoryDataset(dataset) if fits_in_memory else OnlineDataset(dataset)
11
class MemoryDataset(theano.Op):
    """Op that serves minibatches from a dataset held entirely in memory.

    The full input/output arrays are uploaded once into Theano shared
    variables at construction time; each application of the Op returns
    the ``idx``-th minibatch of ``batch_size`` rows.
    """

    def __init__(self, dataset):
        # Keep the whole dataset resident as shared variables.
        # NOTE(review): assumes dataset exposes .input, .output and
        # .batch_size — confirm against the dataset interface.
        self.input = theano.shared(dataset.input)
        self.output = theano.shared(dataset.output)
        self.batch_size = dataset.batch_size

    def make_node(self, idx):
        # Was theano.as_tensor_variable; the canonical location is
        # theano.tensor.as_tensor_variable.
        idx_ = theano.tensor.as_tensor_variable(idx)
        return theano.Apply(self,
                            inputs=[idx_],
                            outputs=[self.input.type(),
                                     self.output.type()])

    def perform(self, node, inputs, output_storage):
        # Renamed from "preform" (typo) so Theano's Op machinery actually
        # dispatches to it.
        idx, = inputs
        lo = idx * self.batch_size
        hi = lo + self.batch_size
        # Slice the underlying numpy arrays, not the shared variables:
        # indexing a shared variable would build a symbolic expression
        # instead of producing a concrete value.  Also fixed
        # self.output_storage -> output_storage (the parameter).
        output_storage[0][0] = self.input.get_value(borrow=True)[lo:hi]
        output_storage[1][0] = self.output.get_value(borrow=True)[lo:hi]
29
class OnlineDataset(theano.Op):
    """Op that fetches minibatches on demand from the wrapped dataset.

    Nothing is preloaded: each application of the Op asks the dataset for
    the requested batch, so this suits datasets too large for memory.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def make_node(self, idx):
        # Was theano.as_tensor_variable; the canonical location is
        # theano.tensor.as_tensor_variable.
        idx_ = theano.tensor.as_tensor_variable(idx)
        return theano.Apply(self,
                            inputs=[idx_],
                            outputs=[theano.tensor.fmatrix(),
                                     theano.tensor.fmatrix()])
        # TODO: fix this so it's not fmatrix(),
        # but whatever the dataset outputs

    def perform(self, node, inputs, output_storage):
        idx, = inputs
        # In perform() the inputs are already concrete numpy values per the
        # Theano Op contract, so use idx directly (was idx.value, which
        # would fail on a plain value).
        b = self.dataset.get_batch(idx)
        # NOTE(review): assumes get_batch returns an object with .input
        # and .output attributes — confirm against the dataset interface.
        output_storage[0][0] = b.input
        output_storage[1][0] = b.output