comparison scripts/setup_batches.py @ 356:b0741ea3ff6f

Extension du choix de la classe principale pour les batches d'entrainement
author Guillaume Sicard <guitch21@gmail.com>
date Wed, 21 Apr 2010 23:47:50 -0400
parents 5b260cc8f477
children 22919039f7ab
comparison
equal deleted inserted replaced
355:76b7182dd32e 356:b0741ea3ff6f
13 digits_test_data = 'digits/digits_test_data.ft' 13 digits_test_data = 'digits/digits_test_data.ft'
14 digits_test_labels = 'digits/digits_test_labels.ft' 14 digits_test_labels = 'digits/digits_test_labels.ft'
15 15
16 lower_train_data = 'lower/lower_train_data.ft' 16 lower_train_data = 'lower/lower_train_data.ft'
17 lower_train_labels = 'lower/lower_train_labels.ft' 17 lower_train_labels = 'lower/lower_train_labels.ft'
18 lower_test_data = 'lower/lower_test_data.ft'
19 lower_test_labels = 'lower/lower_test_labels.ft'
20
18 upper_train_data = 'upper/upper_train_data.ft' 21 upper_train_data = 'upper/upper_train_data.ft'
19 upper_train_labels = 'upper/upper_train_labels.ft' 22 upper_train_labels = 'upper/upper_train_labels.ft'
23 upper_test_data = 'upper/upper_test_data.ft'
24 upper_test_labels = 'upper/upper_test_labels.ft'
25
20 test_data = 'all/all_test_data.ft' 26 test_data = 'all/all_test_data.ft'
21 test_labels = 'all/all_test_labels.ft' 27 test_labels = 'all/all_test_labels.ft'
22 28
23 print 'Opening data...' 29 print 'Opening data...'
24 30
27 f_digits_test_data = open(data_path + digits_test_data) 33 f_digits_test_data = open(data_path + digits_test_data)
28 f_digits_test_labels = open(data_path + digits_test_labels) 34 f_digits_test_labels = open(data_path + digits_test_labels)
29 35
30 f_lower_train_data = open(data_path + lower_train_data) 36 f_lower_train_data = open(data_path + lower_train_data)
31 f_lower_train_labels = open(data_path + lower_train_labels) 37 f_lower_train_labels = open(data_path + lower_train_labels)
38 f_lower_test_data = open(data_path + lower_test_data)
39 f_lower_test_labels = open(data_path + lower_test_labels)
40
32 f_upper_train_data = open(data_path + upper_train_data) 41 f_upper_train_data = open(data_path + upper_train_data)
33 f_upper_train_labels = open(data_path + upper_train_labels) 42 f_upper_train_labels = open(data_path + upper_train_labels)
34 43 f_upper_test_data = open(data_path + upper_test_data)
35 f_test_data = open(data_path + test_data) 44 f_upper_test_labels = open(data_path + upper_test_labels)
36 f_test_labels = open(data_path + test_labels) 45
46 #f_test_data = open(data_path + test_data)
47 #f_test_labels = open(data_path + test_labels)
37 48
38 self.raw_digits_train_data = ft.read(f_digits_train_data) 49 self.raw_digits_train_data = ft.read(f_digits_train_data)
39 self.raw_digits_train_labels = ft.read(f_digits_train_labels) 50 self.raw_digits_train_labels = ft.read(f_digits_train_labels)
40 self.raw_digits_test_data = ft.read(f_digits_test_data) 51 self.raw_digits_test_data = ft.read(f_digits_test_data)
41 self.raw_digits_test_labels = ft.read(f_digits_test_labels) 52 self.raw_digits_test_labels = ft.read(f_digits_test_labels)
42 53
43 self.raw_lower_train_data = ft.read(f_lower_train_data) 54 self.raw_lower_train_data = ft.read(f_lower_train_data)
44 self.raw_lower_train_labels = ft.read(f_lower_train_labels) 55 self.raw_lower_train_labels = ft.read(f_lower_train_labels)
56 self.raw_lower_test_data = ft.read(f_lower_test_data)
57 self.raw_lower_test_labels = ft.read(f_lower_test_labels)
58
45 self.raw_upper_train_data = ft.read(f_upper_train_data) 59 self.raw_upper_train_data = ft.read(f_upper_train_data)
46 self.raw_upper_train_labels = ft.read(f_upper_train_labels) 60 self.raw_upper_train_labels = ft.read(f_upper_train_labels)
47 61 self.raw_upper_test_data = ft.read(f_upper_test_data)
48 self.raw_test_data = ft.read(f_test_data) 62 self.raw_upper_test_labels = ft.read(f_upper_test_labels)
49 self.raw_test_labels = ft.read(f_test_labels) 63
64 #self.raw_test_data = ft.read(f_test_data)
65 #self.raw_test_labels = ft.read(f_test_labels)
50 66
51 f_digits_train_data.close() 67 f_digits_train_data.close()
52 f_digits_train_labels.close() 68 f_digits_train_labels.close()
53 f_digits_test_data.close() 69 f_digits_test_data.close()
54 f_digits_test_labels.close() 70 f_digits_test_labels.close()
55 71
56 f_lower_train_data.close() 72 f_lower_train_data.close()
57 f_lower_train_labels.close() 73 f_lower_train_labels.close()
74 f_lower_test_data.close()
75 f_lower_test_labels.close()
76
58 f_upper_train_data.close() 77 f_upper_train_data.close()
59 f_upper_train_labels.close() 78 f_upper_train_labels.close()
60 79 f_upper_test_data.close()
61 f_test_data.close() 80 f_upper_test_labels.close()
62 f_test_labels.close() 81
82 #f_test_data.close()
83 #f_test_labels.close()
63 84
64 print 'Data opened' 85 print 'Data opened'
65 86
66 def set_batches(self, start_ratio = -1, end_ratio = -1, batch_size = 20, verbose = False): 87 def set_batches(self, main_class = "d", start_ratio = -1, end_ratio = -1, batch_size = 20, verbose = False):
67 self.batch_size = batch_size 88 self.batch_size = batch_size
68 89
69 digits_train_size = len(self.raw_digits_train_labels) 90 digits_train_size = len(self.raw_digits_train_labels)
70 digits_test_size = len(self.raw_digits_test_labels) 91 digits_test_size = len(self.raw_digits_test_labels)
71 92
72 lower_train_size = len(self.raw_lower_train_labels) 93 lower_train_size = len(self.raw_lower_train_labels)
94
73 upper_train_size = len(self.raw_upper_train_labels) 95 upper_train_size = len(self.raw_upper_train_labels)
96 upper_test_size = len(self.raw_upper_test_labels)
74 97
75 if verbose == True: 98 if verbose == True:
76 print 'digits_train_size = %d' %digits_train_size 99 print 'digits_train_size = %d' %digits_train_size
77 print 'digits_test_size = %d' %digits_test_size 100 print 'digits_test_size = %d' %digits_test_size
78 print 'lower_train_size = %d' %lower_train_size 101 print 'lower_train_size = %d' %lower_train_size
79 print 'upper_train_size = %d' %upper_train_size 102 print 'upper_train_size = %d' %upper_train_size
80 103 print 'upper_test_size = %d' %upper_test_size
81 # define main and other datasets 104
82 raw_main_train_data = self.raw_digits_train_data 105 if main_class == "u":
83 raw_other_train_data1 = self.raw_lower_train_labels 106 # define main and other datasets
84 raw_other_train_data2 = self.raw_upper_train_labels 107 raw_main_train_data = self.raw_upper_train_data
85 raw_test_data = self.raw_digits_test_data 108 raw_other_train_data1 = self.raw_lower_train_labels
86 #raw_test_data = self.raw_test_data 109 raw_other_train_data2 = self.raw_digits_train_labels
87 110 raw_test_data = self.raw_upper_test_data
88 raw_main_train_labels = self.raw_digits_train_labels 111
89 raw_other_train_labels1 = self.raw_lower_train_labels 112 raw_main_train_labels = self.raw_upper_train_labels
90 raw_other_train_labels2 = self.raw_upper_train_labels 113 raw_other_train_labels1 = self.raw_lower_train_labels
91 raw_test_labels = self.raw_digits_test_labels 114 raw_other_train_labels2 = self.raw_digits_train_labels
92 #raw_test_labels = self.raw_test_labels 115 raw_test_labels = self.raw_upper_test_labels
116
117 elif main_class == "l":
118 # define main and other datasets
119 raw_main_train_data = self.raw_lower_train_data
120 raw_other_train_data1 = self.raw_upper_train_labels
121 raw_other_train_data2 = self.raw_digits_train_labels
122 raw_test_data = self.raw_lower_test_data
123
124 raw_main_train_labels = self.raw_lower_train_labels
125 raw_other_train_labels1 = self.raw_upper_train_labels
126 raw_other_train_labels2 = self.raw_digits_train_labels
127 raw_test_labels = self.raw_lower_test_labels
128
129 else:
130 main_class = "d"
131 # define main and other datasets
132 raw_main_train_data = self.raw_digits_train_data
133 raw_other_train_data1 = self.raw_lower_train_labels
134 raw_other_train_data2 = self.raw_upper_train_labels
135 raw_test_data = self.raw_digits_test_data
136
137 raw_main_train_labels = self.raw_digits_train_labels
138 raw_other_train_labels1 = self.raw_lower_train_labels
139 raw_other_train_labels2 = self.raw_upper_train_labels
140 raw_test_labels = self.raw_digits_test_labels
93 141
94 main_train_size = len(raw_main_train_labels) 142 main_train_size = len(raw_main_train_labels)
95 other_train_size1 = len(raw_other_train_labels1) 143 other_train_size1 = len(raw_other_train_labels1)
96 other_train_size2 = len(raw_other_train_labels2) 144 other_train_size2 = len(raw_other_train_labels2)
97 other_train_size = other_train_size1 + other_train_size2 145 other_train_size = other_train_size1 + other_train_size2
111 self.end_ratio = float(main_train_size - test_size) / float(main_train_size + other_train_size) 159 self.end_ratio = float(main_train_size - test_size) / float(main_train_size + other_train_size)
112 else: 160 else:
113 self.end_ratio = end_ratio 161 self.end_ratio = end_ratio
114 162
115 if verbose == True: 163 if verbose == True:
164 print 'main class : %s' %main_class
116 print 'start_ratio = %f' %self.start_ratio 165 print 'start_ratio = %f' %self.start_ratio
117 print 'end_ratio = %f' %self.end_ratio 166 print 'end_ratio = %f' %self.end_ratio
118 167
119 i_main = 0 168 i_main = 0
120 i_other1 = 0 169 i_other1 = 0