diff scripts/setup_batches.py @ 356:b0741ea3ff6f

Extension du choix de la classe principale pour les batches d'entrainement
author Guillaume Sicard <guitch21@gmail.com>
date Wed, 21 Apr 2010 23:47:50 -0400
parents 5b260cc8f477
children 22919039f7ab
line wrap: on
line diff
--- a/scripts/setup_batches.py	Wed Apr 21 15:07:09 2010 -0400
+++ b/scripts/setup_batches.py	Wed Apr 21 23:47:50 2010 -0400
@@ -15,8 +15,14 @@
 
     lower_train_data = 'lower/lower_train_data.ft'
     lower_train_labels = 'lower/lower_train_labels.ft'
+    lower_test_data = 'lower/lower_test_data.ft'
+    lower_test_labels = 'lower/lower_test_labels.ft'
+
     upper_train_data = 'upper/upper_train_data.ft'
     upper_train_labels = 'upper/upper_train_labels.ft'
+    upper_test_data = 'upper/upper_test_data.ft'
+    upper_test_labels = 'upper/upper_test_labels.ft'
+
     test_data = 'all/all_test_data.ft'
     test_labels = 'all/all_test_labels.ft'
 
@@ -29,11 +35,16 @@
 
     f_lower_train_data = open(data_path + lower_train_data)
     f_lower_train_labels = open(data_path + lower_train_labels)
+    f_lower_test_data = open(data_path + lower_test_data)
+    f_lower_test_labels = open(data_path + lower_test_labels)
+
     f_upper_train_data = open(data_path + upper_train_data)
     f_upper_train_labels = open(data_path + upper_train_labels)
+    f_upper_test_data = open(data_path + upper_test_data)
+    f_upper_test_labels = open(data_path + upper_test_labels)
 
-    f_test_data = open(data_path + test_data)
-    f_test_labels = open(data_path + test_labels)
+    #f_test_data = open(data_path + test_data)
+    #f_test_labels = open(data_path + test_labels)
 
     self.raw_digits_train_data = ft.read(f_digits_train_data)
     self.raw_digits_train_labels = ft.read(f_digits_train_labels)
@@ -42,11 +53,16 @@
 
     self.raw_lower_train_data = ft.read(f_lower_train_data)
     self.raw_lower_train_labels = ft.read(f_lower_train_labels)
+    self.raw_lower_test_data = ft.read(f_lower_test_data)
+    self.raw_lower_test_labels = ft.read(f_lower_test_labels)
+
     self.raw_upper_train_data = ft.read(f_upper_train_data)
     self.raw_upper_train_labels = ft.read(f_upper_train_labels)
+    self.raw_upper_test_data = ft.read(f_upper_test_data)
+    self.raw_upper_test_labels = ft.read(f_upper_test_labels)
 
-    self.raw_test_data = ft.read(f_test_data)
-    self.raw_test_labels = ft.read(f_test_labels)
+    #self.raw_test_data = ft.read(f_test_data)
+    #self.raw_test_labels = ft.read(f_test_labels)
 
     f_digits_train_data.close()
     f_digits_train_labels.close()
@@ -55,41 +71,73 @@
 
     f_lower_train_data.close()
     f_lower_train_labels.close()
+    f_lower_test_data.close()
+    f_lower_test_labels.close()
+
     f_upper_train_data.close()
     f_upper_train_labels.close()
+    f_upper_test_data.close()
+    f_upper_test_labels.close()
 
-    f_test_data.close()
-    f_test_labels.close()
+    #f_test_data.close()
+    #f_test_labels.close()
 
     print 'Data opened'
 
-  def set_batches(self, start_ratio = -1, end_ratio = -1, batch_size = 20, verbose = False):
+  def set_batches(self, main_class = "d", start_ratio = -1, end_ratio = -1, batch_size = 20, verbose = False):
     self.batch_size = batch_size
 
     digits_train_size = len(self.raw_digits_train_labels)
     digits_test_size = len(self.raw_digits_test_labels)
 
     lower_train_size = len(self.raw_lower_train_labels)
+
     upper_train_size = len(self.raw_upper_train_labels)
+    upper_test_size = len(self.raw_upper_test_labels)
 
     if verbose == True:
       print 'digits_train_size = %d' %digits_train_size
       print 'digits_test_size = %d' %digits_test_size
       print 'lower_train_size = %d' %lower_train_size
       print 'upper_train_size = %d' %upper_train_size
+      print 'upper_test_size = %d' %upper_test_size
 
-    # define main and other datasets
-    raw_main_train_data = self.raw_digits_train_data
-    raw_other_train_data1 = self.raw_lower_train_labels
-    raw_other_train_data2 = self.raw_upper_train_labels
-    raw_test_data = self.raw_digits_test_data
-    #raw_test_data = self.raw_test_data
+    if main_class == "u":
+	# define main and other datasets
+	raw_main_train_data = self.raw_upper_train_data
+	raw_other_train_data1 = self.raw_lower_train_labels
+	raw_other_train_data2 = self.raw_digits_train_labels
+	raw_test_data = self.raw_upper_test_data
+
+	raw_main_train_labels = self.raw_upper_train_labels
+	raw_other_train_labels1 = self.raw_lower_train_labels
+	raw_other_train_labels2 = self.raw_digits_train_labels
+	raw_test_labels = self.raw_upper_test_labels
 
-    raw_main_train_labels = self.raw_digits_train_labels
-    raw_other_train_labels1 = self.raw_lower_train_labels
-    raw_other_train_labels2 = self.raw_upper_train_labels
-    raw_test_labels = self.raw_digits_test_labels
-    #raw_test_labels = self.raw_test_labels
+    elif main_class == "l":
+	# define main and other datasets
+	raw_main_train_data = self.raw_lower_train_data
+	raw_other_train_data1 = self.raw_upper_train_labels
+	raw_other_train_data2 = self.raw_digits_train_labels
+	raw_test_data = self.raw_lower_test_data
+
+	raw_main_train_labels = self.raw_lower_train_labels
+	raw_other_train_labels1 = self.raw_upper_train_labels
+	raw_other_train_labels2 = self.raw_digits_train_labels
+	raw_test_labels = self.raw_lower_test_labels
+
+    else:
+	main_class = "d"
+	# define main and other datasets
+	raw_main_train_data = self.raw_digits_train_data
+	raw_other_train_data1 = self.raw_lower_train_labels
+	raw_other_train_data2 = self.raw_upper_train_labels
+	raw_test_data = self.raw_digits_test_data
+
+	raw_main_train_labels = self.raw_digits_train_labels
+	raw_other_train_labels1 = self.raw_lower_train_labels
+	raw_other_train_labels2 = self.raw_upper_train_labels
+	raw_test_labels = self.raw_digits_test_labels
 
     main_train_size = len(raw_main_train_labels)
     other_train_size1 = len(raw_other_train_labels1)
@@ -113,6 +161,7 @@
       self.end_ratio = end_ratio
 
     if verbose == True:
+      print 'main class : %s' %main_class
       print 'start_ratio = %f' %self.start_ratio
       print 'end_ratio = %f' %self.end_ratio