Mercurial > ift6266
comparison scripts/setup_batches.py @ 332:5b260cc8f477
Correction de bug numpy array et ajout d'une deuxième classe auxiliaire
author | Guillaume Sicard <guitch21@gmail.com> |
---|---|
date | Wed, 14 Apr 2010 11:51:18 -0400 |
parents | a6b6b1140de9 |
children | 7bc555cc9aab b0741ea3ff6f |
comparison (legend: equal, deleted, inserted, replaced)
331:c2331b8e4b89 | 332:5b260cc8f477 |
---|---|
13 digits_test_data = 'digits/digits_test_data.ft' | 13 digits_test_data = 'digits/digits_test_data.ft' |
14 digits_test_labels = 'digits/digits_test_labels.ft' | 14 digits_test_labels = 'digits/digits_test_labels.ft' |
15 | 15 |
16 lower_train_data = 'lower/lower_train_data.ft' | 16 lower_train_data = 'lower/lower_train_data.ft' |
17 lower_train_labels = 'lower/lower_train_labels.ft' | 17 lower_train_labels = 'lower/lower_train_labels.ft' |
18 #upper_train_data = 'upper/upper_train_data.ft' | 18 upper_train_data = 'upper/upper_train_data.ft' |
19 #upper_train_labels = 'upper/upper_train_labels.ft' | 19 upper_train_labels = 'upper/upper_train_labels.ft' |
20 test_data = 'all/all_test_data.ft' | |
21 test_labels = 'all/all_test_labels.ft' | |
20 | 22 |
21 print 'Opening data...' | 23 print 'Opening data...' |
22 | 24 |
23 f_digits_train_data = open(data_path + digits_train_data) | 25 f_digits_train_data = open(data_path + digits_train_data) |
24 f_digits_train_labels = open(data_path + digits_train_labels) | 26 f_digits_train_labels = open(data_path + digits_train_labels) |
25 f_digits_test_data = open(data_path + digits_test_data) | 27 f_digits_test_data = open(data_path + digits_test_data) |
26 f_digits_test_labels = open(data_path + digits_test_labels) | 28 f_digits_test_labels = open(data_path + digits_test_labels) |
27 | 29 |
28 f_lower_train_data = open(data_path + lower_train_data) | 30 f_lower_train_data = open(data_path + lower_train_data) |
29 f_lower_train_labels = open(data_path + lower_train_labels) | 31 f_lower_train_labels = open(data_path + lower_train_labels) |
30 #f_upper_train_data = open(data_path + upper_train_data) | 32 f_upper_train_data = open(data_path + upper_train_data) |
31 #f_upper_train_labels = open(data_path + upper_train_labels) | 33 f_upper_train_labels = open(data_path + upper_train_labels) |
34 | |
35 f_test_data = open(data_path + test_data) | |
36 f_test_labels = open(data_path + test_labels) | |
32 | 37 |
33 self.raw_digits_train_data = ft.read(f_digits_train_data) | 38 self.raw_digits_train_data = ft.read(f_digits_train_data) |
34 self.raw_digits_train_labels = ft.read(f_digits_train_labels) | 39 self.raw_digits_train_labels = ft.read(f_digits_train_labels) |
35 self.raw_digits_test_data = ft.read(f_digits_test_data) | 40 self.raw_digits_test_data = ft.read(f_digits_test_data) |
36 self.raw_digits_test_labels = ft.read(f_digits_test_labels) | 41 self.raw_digits_test_labels = ft.read(f_digits_test_labels) |
37 | 42 |
38 self.raw_lower_train_data = ft.read(f_lower_train_data) | 43 self.raw_lower_train_data = ft.read(f_lower_train_data) |
39 self.raw_lower_train_labels = ft.read(f_lower_train_labels) | 44 self.raw_lower_train_labels = ft.read(f_lower_train_labels) |
40 #self.raw_upper_train_data = ft.read(f_upper_train_data) | 45 self.raw_upper_train_data = ft.read(f_upper_train_data) |
41 #self.raw_upper_train_labels = ft.read(f_upper_train_labels) | 46 self.raw_upper_train_labels = ft.read(f_upper_train_labels) |
47 | |
48 self.raw_test_data = ft.read(f_test_data) | |
49 self.raw_test_labels = ft.read(f_test_labels) | |
42 | 50 |
43 f_digits_train_data.close() | 51 f_digits_train_data.close() |
44 f_digits_train_labels.close() | 52 f_digits_train_labels.close() |
45 f_digits_test_data.close() | 53 f_digits_test_data.close() |
46 f_digits_test_labels.close() | 54 f_digits_test_labels.close() |
47 | 55 |
48 f_lower_train_data.close() | 56 f_lower_train_data.close() |
49 f_lower_train_labels.close() | 57 f_lower_train_labels.close() |
50 #f_upper_train_data.close() | 58 f_upper_train_data.close() |
51 #f_upper_train_labels.close() | 59 f_upper_train_labels.close() |
60 | |
61 f_test_data.close() | |
62 f_test_labels.close() | |
52 | 63 |
53 print 'Data opened' | 64 print 'Data opened' |
54 | 65 |
55 def set_batches(self, start_ratio = -1, end_ratio = -1, batch_size = 20, verbose = False): | 66 def set_batches(self, start_ratio = -1, end_ratio = -1, batch_size = 20, verbose = False): |
56 self.batch_size = batch_size | 67 self.batch_size = batch_size |
57 | 68 |
58 digits_train_size = len(self.raw_digits_train_labels) | 69 digits_train_size = len(self.raw_digits_train_labels) |
59 digits_test_size = len(self.raw_digits_test_labels) | 70 digits_test_size = len(self.raw_digits_test_labels) |
60 | 71 |
61 lower_train_size = len(self.raw_lower_train_labels) | 72 lower_train_size = len(self.raw_lower_train_labels) |
62 #upper_train_size = len(self.raw_upper_train_labels) | 73 upper_train_size = len(self.raw_upper_train_labels) |
63 | 74 |
64 if verbose == True: | 75 if verbose == True: |
65 print 'digits_train_size = %d' %digits_train_size | 76 print 'digits_train_size = %d' %digits_train_size |
66 print 'digits_test_size = %d' %digits_test_size | 77 print 'digits_test_size = %d' %digits_test_size |
67 print 'lower_train_size = %d' %lower_train_size | 78 print 'lower_train_size = %d' %lower_train_size |
68 #print 'upper_train_size = %d' %upper_train_size | 79 print 'upper_train_size = %d' %upper_train_size |
69 | 80 |
70 # define main and other datasets | 81 # define main and other datasets |
71 raw_main_train_data = self.raw_digits_train_data | 82 raw_main_train_data = self.raw_digits_train_data |
72 raw_other_train_data = self.raw_lower_train_labels | 83 raw_other_train_data1 = self.raw_lower_train_labels |
84 raw_other_train_data2 = self.raw_upper_train_labels | |
73 raw_test_data = self.raw_digits_test_data | 85 raw_test_data = self.raw_digits_test_data |
86 #raw_test_data = self.raw_test_data | |
74 | 87 |
75 raw_main_train_labels = self.raw_digits_train_labels | 88 raw_main_train_labels = self.raw_digits_train_labels |
76 raw_other_train_labels = self.raw_lower_train_labels | 89 raw_other_train_labels1 = self.raw_lower_train_labels |
90 raw_other_train_labels2 = self.raw_upper_train_labels | |
77 raw_test_labels = self.raw_digits_test_labels | 91 raw_test_labels = self.raw_digits_test_labels |
78 | 92 #raw_test_labels = self.raw_test_labels |
79 main_train_size = len(raw_main_train_data) | 93 |
80 other_train_size = len(raw_other_train_data) | 94 main_train_size = len(raw_main_train_labels) |
95 other_train_size1 = len(raw_other_train_labels1) | |
96 other_train_size2 = len(raw_other_train_labels2) | |
97 other_train_size = other_train_size1 + other_train_size2 | |
98 | |
81 test_size = len(raw_test_labels) | 99 test_size = len(raw_test_labels) |
82 test_size = int(test_size/batch_size) | 100 test_size = int(test_size/batch_size) |
83 test_size *= batch_size | 101 test_size *= batch_size |
84 validation_size = test_size | 102 validation_size = test_size |
85 | 103 |
86 # default ratio is actual ratio | 104 # default ratio is actual ratio |
87 if start_ratio == -1: | 105 if start_ratio == -1: |
88 self.start_ratio = float(main_train_size) / float(main_train_size + other_train_size) | 106 self.start_ratio = float(main_train_size - test_size) / float(main_train_size + other_train_size) |
89 else: | 107 else: |
90 self.start_ratio = start_ratio | 108 self.start_ratio = start_ratio |
91 | 109 |
92 if start_ratio == -1: | 110 if start_ratio == -1: |
93 self.end_ratio = float(main_train_size) / float(main_train_size + other_train_size) | 111 self.end_ratio = float(main_train_size - test_size) / float(main_train_size + other_train_size) |
94 else: | 112 else: |
95 self.end_ratio = end_ratio | 113 self.end_ratio = end_ratio |
96 | 114 |
97 if verbose == True: | 115 if verbose == True: |
98 print 'start_ratio = %f' %self.start_ratio | 116 print 'start_ratio = %f' %self.start_ratio |
99 print 'end_ratio = %f' %self.end_ratio | 117 print 'end_ratio = %f' %self.end_ratio |
100 | 118 |
101 i_main = 0 | 119 i_main = 0 |
102 i_other = 0 | 120 i_other1 = 0 |
121 i_other2 = 0 | |
103 i_batch = 0 | 122 i_batch = 0 |
104 | 123 |
105 # compute the number of batches given start and end ratios | 124 # compute the number of batches given start and end ratios |
106 n_main_batch = (main_train_size - batch_size * (self.end_ratio - self.start_ratio) / 2 ) / (batch_size * (self.start_ratio + (self.end_ratio - self.start_ratio) / 2)) | 125 n_main_batch = (main_train_size - test_size - batch_size * (self.end_ratio - self.start_ratio) / 2 ) / (batch_size * (self.start_ratio + (self.end_ratio - self.start_ratio) / 2)) |
107 n_other_batch = (other_train_size - batch_size * (self.end_ratio - self.start_ratio) / 2 ) / (batch_size - batch_size * (self.start_ratio + (self.end_ratio - self.start_ratio) / 2)) | 126 if (batch_size != batch_size * (self.start_ratio + (self.end_ratio - self.start_ratio) / 2)): |
127 n_other_batch = (other_train_size - batch_size * (self.end_ratio - self.start_ratio) / 2 ) / (batch_size - batch_size * (self.start_ratio + (self.end_ratio - self.start_ratio) / 2)) | |
128 else: | |
129 n_other_batch = n_main_batch | |
130 | |
108 n_batches = min([n_main_batch, n_other_batch]) | 131 n_batches = min([n_main_batch, n_other_batch]) |
109 | 132 |
110 # train batches | 133 # train batches |
111 self.train_batches = [] | 134 self.train_batches = [] |
112 | 135 |
113 # as long as we have data left in main and other, we create batches | 136 # as long as we have data left in main and other, we create batches |
114 while i_main < main_train_size - batch_size - test_size and i_other < other_train_size - batch_size: | 137 while i_main < main_train_size - batch_size - test_size and i_other1 < other_train_size1 - batch_size and i_other2 < other_train_size2 - batch_size: |
115 | |
116 ratio = self.start_ratio + i_batch * (self.end_ratio - self.start_ratio) / n_batches | 138 ratio = self.start_ratio + i_batch * (self.end_ratio - self.start_ratio) / n_batches |
117 batch_data = raw_main_train_data[0:self.batch_size] | 139 batch_data = copy(raw_main_train_data[0:self.batch_size]) |
118 batch_labels = raw_main_train_labels[0:self.batch_size] | 140 batch_labels = copy(raw_main_train_labels[0:self.batch_size]) |
119 | 141 |
120 for i in xrange(0, self.batch_size): # randomly choose between main and other, given the current ratio | 142 for i in xrange(0, self.batch_size): # randomly choose between main and other, given the current ratio |
121 rnd = random.randint(0, 100) | 143 rnd1 = random.randint(0, 100) |
122 | 144 |
123 if rnd < 100 * ratio: | 145 if rnd1 < 100 * ratio: |
124 batch_data[i] = raw_main_train_data[i_main] | 146 batch_data[i] = raw_main_train_data[i_main] |
125 batch_labels[i] = raw_main_train_labels[i_main] | 147 batch_labels[i] = raw_main_train_labels[i_main] |
126 i_main += 1 | 148 i_main += 1 |
127 else: | 149 else: |
128 batch_data[i] = raw_other_train_data[i_other] | 150 rnd2 = random.randint(0, 100) |
129 batch_labels[i] = raw_other_train_labels[i_other] - 26 #to put values between 10 and 35 for lower case | 151 |
130 i_other += 1 | 152 if rnd2 < 100 * float(other_train_size1) / float(other_train_size): |
153 batch_data[i] = raw_other_train_data1[i_other1] | |
154 batch_labels[i] = raw_other_train_labels1[i_other1] | |
155 i_other1 += 1 | |
156 else: | |
157 batch_data[i] = raw_other_train_data2[i_other2] | |
158 batch_labels[i] = raw_other_train_labels2[i_other2] | |
159 i_other2 += 1 | |
131 | 160 |
132 self.train_batches = self.train_batches + \ | 161 self.train_batches = self.train_batches + \ |
133 [(batch_data, batch_labels)] | 162 [(batch_data, batch_labels)] |
134 i_batch += 1 | 163 i_batch += 1 |
135 | 164 |
141 self.test_batches = self.test_batches + \ | 170 self.test_batches = self.test_batches + \ |
142 [(raw_test_data[i:i+batch_size], raw_test_labels[i:i+batch_size])] | 171 [(raw_test_data[i:i+batch_size], raw_test_labels[i:i+batch_size])] |
143 | 172 |
144 # validation batches | 173 # validation batches |
145 self.validation_batches = [] | 174 self.validation_batches = [] |
146 for i in xrange(0, test_size, batch_size): | 175 for i in xrange(0, validation_size, batch_size): |
147 self.validation_batches = self.validation_batches + \ | 176 self.validation_batches = self.validation_batches + \ |
148 [(raw_main_train_data[offset+i:offset+i+batch_size], raw_main_train_labels[offset+i:offset+i+batch_size])] | 177 [(raw_main_train_data[offset+i:offset+i+batch_size], raw_main_train_labels[offset+i:offset+i+batch_size])] |
149 | 178 |
150 if verbose == True: | 179 if verbose == True: |
151 print 'n_main = %d' %i_main | 180 print 'n_main = %d' %i_main |
152 print 'n_other = %d' %i_other | 181 print 'n_other1 = %d' %i_other1 |
182 print 'n_other2 = %d' %i_other2 | |
153 print 'nb_train_batches = %d / %d' %(i_batch,n_batches) | 183 print 'nb_train_batches = %d / %d' %(i_batch,n_batches) |
154 print 'offset = %d' %offset | 184 print 'offset = %d' %offset |
155 | 185 |
156 def get_train_batches(self): | 186 def get_train_batches(self): |
157 return self.train_batches | 187 return self.train_batches |