diff deep/rbm/rbm.py @ 372:1e99dc965b5b

correcting some bugs
author goldfinger
date Sun, 25 Apr 2010 13:28:45 -0400
parents d81284e13d77
children e36ccffb3870
line wrap: on
line diff
--- a/deep/rbm/rbm.py	Sat Apr 24 11:32:26 2010 -0400
+++ b/deep/rbm/rbm.py	Sun Apr 25 13:28:45 2010 -0400
@@ -16,10 +16,13 @@
 import time 
 import theano.tensor.nnet
 import pylearn
-import ift6266
-import theano,pylearn.version,ift6266
+#import ift6266
+import theano,pylearn.version #,ift6266
 from pylearn.io import filetensor as ft
-from ift6266 import datasets
+#from ift6266 import datasets
+
+from jobman.tools import DD, flatten
+from jobman import sql
 
 from theano.tensor.shared_randomstreams import RandomStreams
 
@@ -240,8 +243,7 @@
 
 
 
-def test_rbm(b_size = 25, nhidden = 1000, kk = 1, persistance = 0,
-             dataset= 0):
+def test_rbm(b_size = 20, nhidden = 1000, kk = 1, persistance = 0):
     """
     Demonstrate ***
 
@@ -257,28 +259,47 @@
 
     learning_rate=0.1
 
-    if data_set==0:
-    	datasets=datasets.nist_all()
-    elif data_set==1:
-        datasets=datasets.nist_P07()
-    elif data_set==2:
-        datasets=datasets.PNIST07()
+#    if data_set==0:
+#   	datasets=datasets.nist_all()
+#    elif data_set==1:
+#        datasets=datasets.nist_P07()
+#    elif data_set==2:
+#        datasets=datasets.PNIST07()
 
 
+    data_path = '/data/lisa/data/nist/by_class/'
+    f = open(data_path+'all/all_train_data.ft')
+    g = open(data_path+'all/all_train_labels.ft')
+    h = open(data_path+'all/all_test_data.ft')
+    i = open(data_path+'all/all_test_labels.ft')
+    
+    train_set_x = theano.shared(ft.read(f))
+    train_set_y = ft.read(g)
+    test_set_x = ft.read(h)
+    test_set_y = ft.read(i)
+    
+    f.close()
+    g.close()
+    i.close()
+    h.close()
+
+    #t = len(train_set_x)
+    print  train_set_x.value.shape
+    
     # revoir la recuperation des donnees
 ##    dataset = load_data(dataset)
 ##
 ##    train_set_x, train_set_y = datasets[0]
 ##    test_set_x , test_set_y  = datasets[2]
-##    training_epochs = 10 # a determiner
+    training_epochs = 1 # a determiner
 
     batch_size = b_size    # size of the minibatch
 
     # compute number of minibatches for training, validation and testing
-    #n_train_batches = train_set_x.value.shape[0] / batch_size
+    n_train_batches = train_set_x.value.shape[0] / batch_size
 
     # allocate symbolic variables for the data
-    index = T.lscalar()    # index to a [mini]batch 
+    index = T.scalar()    # index to a [mini]batch 
     x     = T.matrix('x')  # the data is presented as rasterized images
 
     rng        = numpy.random.RandomState(123)
@@ -304,16 +325,18 @@
     #################################
     #     Training the RBM          #
     #################################
-    dirname = 'data=%i'%dataset + ' persistance=%i'%persistance + ' n_hidden=%i'%n_hidden + 'batch_size=i%'%b_size
+    #os.chdir('~')
+    dirname = str(persistance) + '_' + str(nhidden) + '_' + str(b_size) + '_'+ str(kk)
     os.makedirs(dirname)
     os.chdir(dirname)
-
+    print 'yes'
     # it is ok for a theano function to have no output
     # the purpose of train_rbm is solely to update the RBM parameters
-    train_rbm = theano.function([x], cost,
+    train_rbm = theano.function([index], cost,
            updates = updates, 
-           )
+           givens = { x: train_set_x[index*batch_size:(index+1)*batch_size]})
 
+    print 'yep'
     plotting_time = 0.0
     start_time = time.clock()  
     bufsize = 1000
@@ -324,8 +347,10 @@
 
         # go through the training set
         mean_cost = []
-        for mini_x, mini_y in datasets.train(b_size):
-           mean_cost += [train_rbm(mini_x)]
+        for batch_index in xrange(n_train_batches):
+            mean_cost += [train_rbm(batch_index)]
+#        for mini_x, mini_y in datasets.train(b_size):
+#           mean_cost += [train_rbm(mini_x)]
 ##           learning_rate = learning_rate - 0.0001
 ##           learning_rate = learning_rate/(tau+( epoch*batch_index*batch_size))
 
@@ -348,18 +373,16 @@
     pretraining_time = (end_time - start_time) - plotting_time
    
     
-
-    
-
   
     #################################
     #     Sampling from the RBM     #
     #################################
 
     # find out the number of test samples  
-    number_of_test_samples = 1000
+    #number_of_test_samples = 100
+    number_of_test_samples = test_set_x.value.shape[0]
 
-    test_set_x, test_y  = datasets.test(100*b_size)
+    #test_set_x, test_y  = datasets.test(100*b_size)
     # pick random test examples, with which to initialize the persistent chain
     test_idx = rng.randint(number_of_test_samples - b_size)
     persistent_vis_chain = theano.shared(test_set_x.value[test_idx:test_idx+b_size])
@@ -403,10 +426,10 @@
     #save the model
     model = [rbm.W, rbm.vbias, rbm.hbias]
     f = fopen('params.txt', 'w')
-    pickle.dump(model, f)
+    cPickle.dump(model, f, protocol = -1)
     f.close()
     #os.chdir('./..')
-    return numpy.mean(costs), pretraining_time/360
+    return numpy.mean(costs), pretraining_time*36
 
 
 def experiment(state, channel):
@@ -415,7 +438,7 @@
                                            nhidden = state.ndidden,\
                                            kk = state.kk,\
                                            persistance = state.persistance,\
-                                           dataset = state.dataset)
+                                           )
 
     state.mean_costs = mean_costs
     state.time_execution = time_execution
@@ -423,5 +446,23 @@
     return channel.COMPLETE
 
 if __name__ == '__main__':
+    
+    TABLE_NAME='RBM_tapha'
 
-    test_rbm()    
+    # DB path...
+    test_rbm()
+    #db = sql.db('postgres://ift6266h10:f0572cd63b@gershwin/ift6266h10_db/'+ TABLE_NAME)
+
+    #state = DD()
+    #for b_size in 50, 75, 100:
+    #    state.b_size = b_size
+    #    for nhidden in 1000,1250,1500:
+    #        state.nhidden = nhidden
+    #        for kk in 1,2,3,4:
+    #            state.kk = kk
+    #            for persistance in 0,1:
+    #                state.persistance = persistance
+    #                sql.insert_job(rbm.experiment, flatten(state), db)
+
+    
+    #db.createView(TABLE_NAME + 'view')