changeset 375:e36ccffb3870

Changes to cast NIST data to floatX for rbm code
author Xavier Glorot <glorotxa@iro.umontreal.ca>
date Sun, 25 Apr 2010 14:53:10 -0400
parents 846f0678ffe8
children 0b7e64e8e93f
files deep/rbm/rbm.py
diffstat 1 files changed, 10 insertions(+), 5 deletions(-) [+]
line wrap: on
line diff
--- a/deep/rbm/rbm.py	Sun Apr 25 13:55:07 2010 -0400
+++ b/deep/rbm/rbm.py	Sun Apr 25 14:53:10 2010 -0400
@@ -273,9 +273,13 @@
     h = open(data_path+'all/all_test_data.ft')
     i = open(data_path+'all/all_test_labels.ft')
     
-    train_set_x = theano.shared(ft.read(f))
+    train_set_x_uint8 = theano.shared(ft.read(f))
+    test_set_x_uint8 = theano.shared(ft.read(h))
+
+
+    train_set_x = T.cast(train_set_x_uint8/255.,theano.config.floatX)
     train_set_y = ft.read(g)
-    test_set_x = ft.read(h)
+    test_set_x = T.cast(test_set_x_uint8/255.,theano.config.floatX)
     test_set_y = ft.read(i)
     
     f.close()
@@ -284,7 +288,6 @@
     h.close()
 
     #t = len(train_set_x)
-    print  train_set_x.value.shape
     
     # revoir la recuperation des donnees
 ##    dataset = load_data(dataset)
@@ -296,10 +299,10 @@
     batch_size = b_size    # size of the minibatch
 
     # compute number of minibatches for training, validation and testing
-    n_train_batches = train_set_x.value.shape[0] / batch_size
+    n_train_batches = train_set_x_uint8.value.shape[0] / batch_size
 
     # allocate symbolic variables for the data
-    index = T.scalar()    # index to a [mini]batch 
+    index = T.lscalar()    # index to a [mini]batch 
     x     = T.matrix('x')  # the data is presented as rasterized images
 
     rng        = numpy.random.RandomState(123)
@@ -332,6 +335,8 @@
     print 'yes'
     # it is ok for a theano function to have no output
     # the purpose of train_rbm is solely to update the RBM parameters
+    print type(batch_size)
+    print index.dtype
     train_rbm = theano.function([index], cost,
            updates = updates, 
            givens = { x: train_set_x[index*batch_size:(index+1)*batch_size]})