diff scripts/setup_batches.py @ 295:a6b6b1140de9

modifié setup_batches.py pour compatibilité avec mlp_nist.py
author Guillaume Sicard <guitch21@gmail.com>
date Mon, 29 Mar 2010 09:18:54 -0400
parents f6d9b6b89c2a
children 5b260cc8f477
line wrap: on
line diff
--- a/scripts/setup_batches.py	Sat Mar 27 13:39:48 2010 -0400
+++ b/scripts/setup_batches.py	Mon Mar 29 09:18:54 2010 -0400
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 
 import random
+from numpy import *
 from pylearn.io import filetensor as ft
 
 class Batches():
@@ -17,6 +18,8 @@
     #upper_train_data = 'upper/upper_train_data.ft'
     #upper_train_labels = 'upper/upper_train_labels.ft'
 
+    print 'Opening data...'
+
     f_digits_train_data = open(data_path + digits_train_data)
     f_digits_train_labels = open(data_path + digits_train_labels)
     f_digits_test_data = open(data_path + digits_test_data)
@@ -47,6 +50,8 @@
     #f_upper_train_data.close()
     #f_upper_train_labels.close()
 
+    print 'Data opened'
+
   def set_batches(self, start_ratio = -1, end_ratio = -1, batch_size = 20, verbose = False):
     self.batch_size = batch_size
 
@@ -65,7 +70,7 @@
     # define main and other datasets
     raw_main_train_data = self.raw_digits_train_data
     raw_other_train_data = self.raw_lower_train_labels
-    raw_test_data = self.raw_digits_test_labels
+    raw_test_data = self.raw_digits_test_data
 
     raw_main_train_labels = self.raw_digits_train_labels
     raw_other_train_labels = self.raw_lower_train_labels
@@ -73,7 +78,7 @@
 
     main_train_size = len(raw_main_train_data)
     other_train_size = len(raw_other_train_data)
-    test_size = len(raw_test_data)
+    test_size = len(raw_test_labels)
     test_size = int(test_size/batch_size)
     test_size *= batch_size
     validation_size = test_size 
@@ -109,37 +114,27 @@
     while i_main < main_train_size - batch_size - test_size  and i_other < other_train_size - batch_size:
 
       ratio = self.start_ratio + i_batch * (self.end_ratio - self.start_ratio) / n_batches
-      batch_data = []
-      batch_labels = []
+      batch_data = raw_main_train_data[0:self.batch_size]
+      batch_labels = raw_main_train_labels[0:self.batch_size]
 
       for i in xrange(0, self.batch_size): # randomly choose between main and other, given the current ratio
 	rnd = random.randint(0, 100)
 
 	if rnd < 100 * ratio:
-	  batch_data = batch_data + \
-		[raw_main_train_data[i_main]]
-	  batch_labels = batch_labels + \
-		[raw_main_train_labels[i_main]]
+	  batch_data[i] = raw_main_train_data[i_main]
+	  batch_labels[i] = raw_main_train_labels[i_main]
 	  i_main += 1
 	else:
-	  batch_data = batch_data + \
-		[raw_other_train_data[i_other]]
-	  batch_labels = batch_labels + \
-		[raw_other_train_labels[i_other]]
+	  batch_data[i] = raw_other_train_data[i_other]
+	  batch_labels[i] = raw_other_train_labels[i_other] - 26 #to put values between 10 and 35 for lower case
 	  i_other += 1
 
       self.train_batches = self.train_batches + \
-	      [(batch_data,batch_labels)]
+	      [(batch_data, batch_labels)]
       i_batch += 1
 
     offset = i_main
 
-    if verbose == True:
-      print 'n_main = %d' %i_main
-      print 'n_other = %d' %i_other
-      print 'nb_train_batches = %d / %d' %(i_batch,n_batches)
-      print 'offset = %d' %offset
-
     # test batches
     self.test_batches = []
     for i in xrange(0, test_size, batch_size):
@@ -152,6 +147,12 @@
         self.validation_batches = self.validation_batches + \
             [(raw_main_train_data[offset+i:offset+i+batch_size], raw_main_train_labels[offset+i:offset+i+batch_size])]
 
+    if verbose == True:
+      print 'n_main = %d' %i_main
+      print 'n_other = %d' %i_other
+      print 'nb_train_batches = %d / %d' %(i_batch,n_batches)
+      print 'offset = %d' %offset
+
   def get_train_batches(self):
     return self.train_batches