comparison scripts/setup_batches.py @ 332:5b260cc8f477

Correction de bug numpy array et ajout d'une deuxième classe auxiliaire
author Guillaume Sicard <guitch21@gmail.com>
date Wed, 14 Apr 2010 11:51:18 -0400
parents a6b6b1140de9
children 7bc555cc9aab b0741ea3ff6f
comparison
equal deleted inserted replaced
331:c2331b8e4b89 332:5b260cc8f477
13 digits_test_data = 'digits/digits_test_data.ft' 13 digits_test_data = 'digits/digits_test_data.ft'
14 digits_test_labels = 'digits/digits_test_labels.ft' 14 digits_test_labels = 'digits/digits_test_labels.ft'
15 15
16 lower_train_data = 'lower/lower_train_data.ft' 16 lower_train_data = 'lower/lower_train_data.ft'
17 lower_train_labels = 'lower/lower_train_labels.ft' 17 lower_train_labels = 'lower/lower_train_labels.ft'
18 #upper_train_data = 'upper/upper_train_data.ft' 18 upper_train_data = 'upper/upper_train_data.ft'
19 #upper_train_labels = 'upper/upper_train_labels.ft' 19 upper_train_labels = 'upper/upper_train_labels.ft'
20 test_data = 'all/all_test_data.ft'
21 test_labels = 'all/all_test_labels.ft'
20 22
21 print 'Opening data...' 23 print 'Opening data...'
22 24
23 f_digits_train_data = open(data_path + digits_train_data) 25 f_digits_train_data = open(data_path + digits_train_data)
24 f_digits_train_labels = open(data_path + digits_train_labels) 26 f_digits_train_labels = open(data_path + digits_train_labels)
25 f_digits_test_data = open(data_path + digits_test_data) 27 f_digits_test_data = open(data_path + digits_test_data)
26 f_digits_test_labels = open(data_path + digits_test_labels) 28 f_digits_test_labels = open(data_path + digits_test_labels)
27 29
28 f_lower_train_data = open(data_path + lower_train_data) 30 f_lower_train_data = open(data_path + lower_train_data)
29 f_lower_train_labels = open(data_path + lower_train_labels) 31 f_lower_train_labels = open(data_path + lower_train_labels)
30 #f_upper_train_data = open(data_path + upper_train_data) 32 f_upper_train_data = open(data_path + upper_train_data)
31 #f_upper_train_labels = open(data_path + upper_train_labels) 33 f_upper_train_labels = open(data_path + upper_train_labels)
34
35 f_test_data = open(data_path + test_data)
36 f_test_labels = open(data_path + test_labels)
32 37
33 self.raw_digits_train_data = ft.read(f_digits_train_data) 38 self.raw_digits_train_data = ft.read(f_digits_train_data)
34 self.raw_digits_train_labels = ft.read(f_digits_train_labels) 39 self.raw_digits_train_labels = ft.read(f_digits_train_labels)
35 self.raw_digits_test_data = ft.read(f_digits_test_data) 40 self.raw_digits_test_data = ft.read(f_digits_test_data)
36 self.raw_digits_test_labels = ft.read(f_digits_test_labels) 41 self.raw_digits_test_labels = ft.read(f_digits_test_labels)
37 42
38 self.raw_lower_train_data = ft.read(f_lower_train_data) 43 self.raw_lower_train_data = ft.read(f_lower_train_data)
39 self.raw_lower_train_labels = ft.read(f_lower_train_labels) 44 self.raw_lower_train_labels = ft.read(f_lower_train_labels)
40 #self.raw_upper_train_data = ft.read(f_upper_train_data) 45 self.raw_upper_train_data = ft.read(f_upper_train_data)
41 #self.raw_upper_train_labels = ft.read(f_upper_train_labels) 46 self.raw_upper_train_labels = ft.read(f_upper_train_labels)
47
48 self.raw_test_data = ft.read(f_test_data)
49 self.raw_test_labels = ft.read(f_test_labels)
42 50
43 f_digits_train_data.close() 51 f_digits_train_data.close()
44 f_digits_train_labels.close() 52 f_digits_train_labels.close()
45 f_digits_test_data.close() 53 f_digits_test_data.close()
46 f_digits_test_labels.close() 54 f_digits_test_labels.close()
47 55
48 f_lower_train_data.close() 56 f_lower_train_data.close()
49 f_lower_train_labels.close() 57 f_lower_train_labels.close()
50 #f_upper_train_data.close() 58 f_upper_train_data.close()
51 #f_upper_train_labels.close() 59 f_upper_train_labels.close()
60
61 f_test_data.close()
62 f_test_labels.close()
52 63
53 print 'Data opened' 64 print 'Data opened'
54 65
55 def set_batches(self, start_ratio = -1, end_ratio = -1, batch_size = 20, verbose = False): 66 def set_batches(self, start_ratio = -1, end_ratio = -1, batch_size = 20, verbose = False):
56 self.batch_size = batch_size 67 self.batch_size = batch_size
57 68
58 digits_train_size = len(self.raw_digits_train_labels) 69 digits_train_size = len(self.raw_digits_train_labels)
59 digits_test_size = len(self.raw_digits_test_labels) 70 digits_test_size = len(self.raw_digits_test_labels)
60 71
61 lower_train_size = len(self.raw_lower_train_labels) 72 lower_train_size = len(self.raw_lower_train_labels)
62 #upper_train_size = len(self.raw_upper_train_labels) 73 upper_train_size = len(self.raw_upper_train_labels)
63 74
64 if verbose == True: 75 if verbose == True:
65 print 'digits_train_size = %d' %digits_train_size 76 print 'digits_train_size = %d' %digits_train_size
66 print 'digits_test_size = %d' %digits_test_size 77 print 'digits_test_size = %d' %digits_test_size
67 print 'lower_train_size = %d' %lower_train_size 78 print 'lower_train_size = %d' %lower_train_size
68 #print 'upper_train_size = %d' %upper_train_size 79 print 'upper_train_size = %d' %upper_train_size
69 80
70 # define main and other datasets 81 # define main and other datasets
71 raw_main_train_data = self.raw_digits_train_data 82 raw_main_train_data = self.raw_digits_train_data
72 raw_other_train_data = self.raw_lower_train_labels 83 raw_other_train_data1 = self.raw_lower_train_labels
84 raw_other_train_data2 = self.raw_upper_train_labels
73 raw_test_data = self.raw_digits_test_data 85 raw_test_data = self.raw_digits_test_data
86 #raw_test_data = self.raw_test_data
74 87
75 raw_main_train_labels = self.raw_digits_train_labels 88 raw_main_train_labels = self.raw_digits_train_labels
76 raw_other_train_labels = self.raw_lower_train_labels 89 raw_other_train_labels1 = self.raw_lower_train_labels
90 raw_other_train_labels2 = self.raw_upper_train_labels
77 raw_test_labels = self.raw_digits_test_labels 91 raw_test_labels = self.raw_digits_test_labels
78 92 #raw_test_labels = self.raw_test_labels
79 main_train_size = len(raw_main_train_data) 93
80 other_train_size = len(raw_other_train_data) 94 main_train_size = len(raw_main_train_labels)
95 other_train_size1 = len(raw_other_train_labels1)
96 other_train_size2 = len(raw_other_train_labels2)
97 other_train_size = other_train_size1 + other_train_size2
98
81 test_size = len(raw_test_labels) 99 test_size = len(raw_test_labels)
82 test_size = int(test_size/batch_size) 100 test_size = int(test_size/batch_size)
83 test_size *= batch_size 101 test_size *= batch_size
84 validation_size = test_size 102 validation_size = test_size
85 103
86 # default ratio is actual ratio 104 # default ratio is actual ratio
87 if start_ratio == -1: 105 if start_ratio == -1:
88 self.start_ratio = float(main_train_size) / float(main_train_size + other_train_size) 106 self.start_ratio = float(main_train_size - test_size) / float(main_train_size + other_train_size)
89 else: 107 else:
90 self.start_ratio = start_ratio 108 self.start_ratio = start_ratio
91 109
92 if start_ratio == -1: 110 if start_ratio == -1:
93 self.end_ratio = float(main_train_size) / float(main_train_size + other_train_size) 111 self.end_ratio = float(main_train_size - test_size) / float(main_train_size + other_train_size)
94 else: 112 else:
95 self.end_ratio = end_ratio 113 self.end_ratio = end_ratio
96 114
97 if verbose == True: 115 if verbose == True:
98 print 'start_ratio = %f' %self.start_ratio 116 print 'start_ratio = %f' %self.start_ratio
99 print 'end_ratio = %f' %self.end_ratio 117 print 'end_ratio = %f' %self.end_ratio
100 118
101 i_main = 0 119 i_main = 0
102 i_other = 0 120 i_other1 = 0
121 i_other2 = 0
103 i_batch = 0 122 i_batch = 0
104 123
105 # compute the number of batches given start and end ratios 124 # compute the number of batches given start and end ratios
106 n_main_batch = (main_train_size - batch_size * (self.end_ratio - self.start_ratio) / 2 ) / (batch_size * (self.start_ratio + (self.end_ratio - self.start_ratio) / 2)) 125 n_main_batch = (main_train_size - test_size - batch_size * (self.end_ratio - self.start_ratio) / 2 ) / (batch_size * (self.start_ratio + (self.end_ratio - self.start_ratio) / 2))
107 n_other_batch = (other_train_size - batch_size * (self.end_ratio - self.start_ratio) / 2 ) / (batch_size - batch_size * (self.start_ratio + (self.end_ratio - self.start_ratio) / 2)) 126 if (batch_size != batch_size * (self.start_ratio + (self.end_ratio - self.start_ratio) / 2)):
127 n_other_batch = (other_train_size - batch_size * (self.end_ratio - self.start_ratio) / 2 ) / (batch_size - batch_size * (self.start_ratio + (self.end_ratio - self.start_ratio) / 2))
128 else:
129 n_other_batch = n_main_batch
130
108 n_batches = min([n_main_batch, n_other_batch]) 131 n_batches = min([n_main_batch, n_other_batch])
109 132
110 # train batches 133 # train batches
111 self.train_batches = [] 134 self.train_batches = []
112 135
113 # as long as we have data left in main and other, we create batches 136 # as long as we have data left in main and other, we create batches
114 while i_main < main_train_size - batch_size - test_size and i_other < other_train_size - batch_size: 137 while i_main < main_train_size - batch_size - test_size and i_other1 < other_train_size1 - batch_size and i_other2 < other_train_size2 - batch_size:
115
116 ratio = self.start_ratio + i_batch * (self.end_ratio - self.start_ratio) / n_batches 138 ratio = self.start_ratio + i_batch * (self.end_ratio - self.start_ratio) / n_batches
117 batch_data = raw_main_train_data[0:self.batch_size] 139 batch_data = copy(raw_main_train_data[0:self.batch_size])
118 batch_labels = raw_main_train_labels[0:self.batch_size] 140 batch_labels = copy(raw_main_train_labels[0:self.batch_size])
119 141
120 for i in xrange(0, self.batch_size): # randomly choose between main and other, given the current ratio 142 for i in xrange(0, self.batch_size): # randomly choose between main and other, given the current ratio
121 rnd = random.randint(0, 100) 143 rnd1 = random.randint(0, 100)
122 144
123 if rnd < 100 * ratio: 145 if rnd1 < 100 * ratio:
124 batch_data[i] = raw_main_train_data[i_main] 146 batch_data[i] = raw_main_train_data[i_main]
125 batch_labels[i] = raw_main_train_labels[i_main] 147 batch_labels[i] = raw_main_train_labels[i_main]
126 i_main += 1 148 i_main += 1
127 else: 149 else:
128 batch_data[i] = raw_other_train_data[i_other] 150 rnd2 = random.randint(0, 100)
129 batch_labels[i] = raw_other_train_labels[i_other] - 26 #to put values between 10 and 35 for lower case 151
130 i_other += 1 152 if rnd2 < 100 * float(other_train_size1) / float(other_train_size):
153 batch_data[i] = raw_other_train_data1[i_other1]
154 batch_labels[i] = raw_other_train_labels1[i_other1]
155 i_other1 += 1
156 else:
157 batch_data[i] = raw_other_train_data2[i_other2]
158 batch_labels[i] = raw_other_train_labels2[i_other2]
159 i_other2 += 1
131 160
132 self.train_batches = self.train_batches + \ 161 self.train_batches = self.train_batches + \
133 [(batch_data, batch_labels)] 162 [(batch_data, batch_labels)]
134 i_batch += 1 163 i_batch += 1
135 164
141 self.test_batches = self.test_batches + \ 170 self.test_batches = self.test_batches + \
142 [(raw_test_data[i:i+batch_size], raw_test_labels[i:i+batch_size])] 171 [(raw_test_data[i:i+batch_size], raw_test_labels[i:i+batch_size])]
143 172
144 # validation batches 173 # validation batches
145 self.validation_batches = [] 174 self.validation_batches = []
146 for i in xrange(0, test_size, batch_size): 175 for i in xrange(0, validation_size, batch_size):
147 self.validation_batches = self.validation_batches + \ 176 self.validation_batches = self.validation_batches + \
148 [(raw_main_train_data[offset+i:offset+i+batch_size], raw_main_train_labels[offset+i:offset+i+batch_size])] 177 [(raw_main_train_data[offset+i:offset+i+batch_size], raw_main_train_labels[offset+i:offset+i+batch_size])]
149 178
150 if verbose == True: 179 if verbose == True:
151 print 'n_main = %d' %i_main 180 print 'n_main = %d' %i_main
152 print 'n_other = %d' %i_other 181 print 'n_other1 = %d' %i_other1
182 print 'n_other2 = %d' %i_other2
153 print 'nb_train_batches = %d / %d' %(i_batch,n_batches) 183 print 'nb_train_batches = %d / %d' %(i_batch,n_batches)
154 print 'offset = %d' %offset 184 print 'offset = %d' %offset
155 185
156 def get_train_batches(self): 186 def get_train_batches(self):
157 return self.train_batches 187 return self.train_batches