# HG changeset patch # User Guillaume Sicard # Date 1271908070 14400 # Node ID b0741ea3ff6fdde2ced4ee8760f82c14224c6b18 # Parent 76b7182dd32e33f2f6ea3c5f13d33a517a0d86c5 Extension du choix de la classe principale pour les batches d'entrainement diff -r 76b7182dd32e -r b0741ea3ff6f scripts/setup_batches.py --- a/scripts/setup_batches.py Wed Apr 21 15:07:09 2010 -0400 +++ b/scripts/setup_batches.py Wed Apr 21 23:47:50 2010 -0400 @@ -15,8 +15,14 @@ lower_train_data = 'lower/lower_train_data.ft' lower_train_labels = 'lower/lower_train_labels.ft' + lower_test_data = 'lower/lower_test_data.ft' + lower_test_labels = 'lower/lower_test_labels.ft' + upper_train_data = 'upper/upper_train_data.ft' upper_train_labels = 'upper/upper_train_labels.ft' + upper_test_data = 'upper/upper_test_data.ft' + upper_test_labels = 'upper/upper_test_labels.ft' + test_data = 'all/all_test_data.ft' test_labels = 'all/all_test_labels.ft' @@ -29,11 +35,16 @@ f_lower_train_data = open(data_path + lower_train_data) f_lower_train_labels = open(data_path + lower_train_labels) + f_lower_test_data = open(data_path + lower_test_data) + f_lower_test_labels = open(data_path + lower_test_labels) + f_upper_train_data = open(data_path + upper_train_data) f_upper_train_labels = open(data_path + upper_train_labels) + f_upper_test_data = open(data_path + upper_test_data) + f_upper_test_labels = open(data_path + upper_test_labels) - f_test_data = open(data_path + test_data) - f_test_labels = open(data_path + test_labels) + #f_test_data = open(data_path + test_data) + #f_test_labels = open(data_path + test_labels) self.raw_digits_train_data = ft.read(f_digits_train_data) self.raw_digits_train_labels = ft.read(f_digits_train_labels) @@ -42,11 +53,16 @@ self.raw_lower_train_data = ft.read(f_lower_train_data) self.raw_lower_train_labels = ft.read(f_lower_train_labels) + self.raw_lower_test_data = ft.read(f_lower_test_data) + self.raw_lower_test_labels = ft.read(f_lower_test_labels) + self.raw_upper_train_data = ft.read(f_upper_train_data) self.raw_upper_train_labels = ft.read(f_upper_train_labels) + self.raw_upper_test_data = ft.read(f_upper_test_data) + self.raw_upper_test_labels = ft.read(f_upper_test_labels) - self.raw_test_data = ft.read(f_test_data) - self.raw_test_labels = ft.read(f_test_labels) + #self.raw_test_data = ft.read(f_test_data) + #self.raw_test_labels = ft.read(f_test_labels) f_digits_train_data.close() f_digits_train_labels.close() @@ -55,41 +71,73 @@ f_lower_train_data.close() f_lower_train_labels.close() + f_lower_test_data.close() + f_lower_test_labels.close() + f_upper_train_data.close() f_upper_train_labels.close() + f_upper_test_data.close() + f_upper_test_labels.close() - f_test_data.close() - f_test_labels.close() + #f_test_data.close() + #f_test_labels.close() print 'Data opened' - def set_batches(self, start_ratio = -1, end_ratio = -1, batch_size = 20, verbose = False): + def set_batches(self, main_class = "d", start_ratio = -1, end_ratio = -1, batch_size = 20, verbose = False): self.batch_size = batch_size digits_train_size = len(self.raw_digits_train_labels) digits_test_size = len(self.raw_digits_test_labels) lower_train_size = len(self.raw_lower_train_labels) + upper_train_size = len(self.raw_upper_train_labels) + upper_test_size = len(self.raw_upper_test_labels) if verbose == True: print 'digits_train_size = %d' %digits_train_size print 'digits_test_size = %d' %digits_test_size print 'lower_train_size = %d' %lower_train_size print 'upper_train_size = %d' %upper_train_size + print 'upper_test_size = %d' %upper_test_size - # define main and other datasets - raw_main_train_data = self.raw_digits_train_data - raw_other_train_data1 = self.raw_lower_train_labels - raw_other_train_data2 = self.raw_upper_train_labels - raw_test_data = self.raw_digits_test_data - #raw_test_data = self.raw_test_data + if main_class == "u": + # define main and other datasets + raw_main_train_data = self.raw_upper_train_data + raw_other_train_data1 = self.raw_lower_train_labels + raw_other_train_data2 = self.raw_digits_train_labels + raw_test_data = self.raw_upper_test_data + + raw_main_train_labels = self.raw_upper_train_labels + raw_other_train_labels1 = self.raw_lower_train_labels + raw_other_train_labels2 = self.raw_digits_train_labels + raw_test_labels = self.raw_upper_test_labels - raw_main_train_labels = self.raw_digits_train_labels - raw_other_train_labels1 = self.raw_lower_train_labels - raw_other_train_labels2 = self.raw_upper_train_labels - raw_test_labels = self.raw_digits_test_labels - #raw_test_labels = self.raw_test_labels + elif main_class == "l": + # define main and other datasets + raw_main_train_data = self.raw_lower_train_data + raw_other_train_data1 = self.raw_upper_train_labels + raw_other_train_data2 = self.raw_digits_train_labels + raw_test_data = self.raw_lower_test_data + + raw_main_train_labels = self.raw_lower_train_labels + raw_other_train_labels1 = self.raw_upper_train_labels + raw_other_train_labels2 = self.raw_digits_train_labels + raw_test_labels = self.raw_lower_test_labels + + else: + main_class = "d" + # define main and other datasets + raw_main_train_data = self.raw_digits_train_data + raw_other_train_data1 = self.raw_lower_train_labels + raw_other_train_data2 = self.raw_upper_train_labels + raw_test_data = self.raw_digits_test_data + + raw_main_train_labels = self.raw_digits_train_labels + raw_other_train_labels1 = self.raw_lower_train_labels + raw_other_train_labels2 = self.raw_upper_train_labels + raw_test_labels = self.raw_digits_test_labels main_train_size = len(raw_main_train_labels) other_train_size1 = len(raw_other_train_labels1) @@ -113,6 +161,7 @@ self.end_ratio = end_ratio if verbose == True: + print 'main class : %s' %main_class print 'start_ratio = %f' %self.start_ratio print 'end_ratio = %f' %self.end_ratio