diff doc/v2_planning/shared_dataset.py @ 1117:c1943feada10

Proposal for theano dataset wrapper. The details still have to be worked out.
author Arnaud Bergeron <abergeron@gmail.com>
date Tue, 14 Sep 2010 15:22:48 -0400
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/doc/v2_planning/shared_dataset.py	Tue Sep 14 15:22:48 2010 -0400
@@ -0,0 +1,47 @@
+import theano
+
+# This is not final and may not even run for now.  It is just to give
+# a feeling of what the interface could look like.
+
+def shared_dataset(dataset, mem_size):
+    if dataset.total_size > mem_size:
+        return OnlineDataset(dataset)
+    else:
+        return MemoryDataset(dataset)
+
+class MemoryDataset(theano.Op):
+    def __init__(self, dataset):
+        self.input = theano.shared(dataset.input)
+        self.output = theano.shared(dataset.output)
+        self.batch_size = dataset.batch_size
+
+    def make_node(self, idx):
+        idx_ = theano.as_tensor_variable(idx)
+        return theano.Apply(self,
+                            inputs = [idx_],
+                            outputs = [self.input.type(), 
+                                       self.output.type()])
+
+    def preform(self, node, inputs, output_storage):
+        idx, = inputs
+        self.output_storage[0][0] = self.input[idx*self.batch_size:(idx+1)*self.batch_size]
+        self.output_storage[1][0] = self.output[idx*self.batch_size:(idx+1)*self.batch_size]
+
+class OnlineDataset(theano.Op):
+    def __init__(self, dataset):
+        self.dataset = dataset
+
+    def make_node(self, idx):
+        idx_ = theano.as_tensor_variable(idx)
+        return theano.Apply(self,
+                            inputs = [idx_],
+                            outputs = [theano.tensor.fmatrix(), 
+                                       theano.tensor.fmatrix()])
+                            # fix this so its not fmatrix(), 
+                            # but whatever the dataset outputs
+
+    def perform(self, node, inputs, output_storage):
+        idx, = inputs
+        b = self.dataset.get_batch(idx.value)
+        output_storage[0][0] = b.input
+        output_storage[1][0] = b.output