comparison scripts/setup_batches.py @ 272:f6d9b6b89c2a

ajouté : module de préparation de batches en fonction d'un ratio de classes
author Guillaume Sicard <guitch21@gmail.com>
date Mon, 22 Mar 2010 08:34:48 -0400
parents
children a6b6b1140de9
comparison
equal deleted inserted replaced
271:a92ec9939e4f 272:f6d9b6b89c2a
1 # -*- coding: utf-8 -*-
2
3 import random
4 from pylearn.io import filetensor as ft
5
6 class Batches():
7 def __init__(self):
8 data_path = '/data/lisa/data/nist/by_class/'
9
10 digits_train_data = 'digits/digits_train_data.ft'
11 digits_train_labels = 'digits/digits_train_labels.ft'
12 digits_test_data = 'digits/digits_test_data.ft'
13 digits_test_labels = 'digits/digits_test_labels.ft'
14
15 lower_train_data = 'lower/lower_train_data.ft'
16 lower_train_labels = 'lower/lower_train_labels.ft'
17 #upper_train_data = 'upper/upper_train_data.ft'
18 #upper_train_labels = 'upper/upper_train_labels.ft'
19
20 f_digits_train_data = open(data_path + digits_train_data)
21 f_digits_train_labels = open(data_path + digits_train_labels)
22 f_digits_test_data = open(data_path + digits_test_data)
23 f_digits_test_labels = open(data_path + digits_test_labels)
24
25 f_lower_train_data = open(data_path + lower_train_data)
26 f_lower_train_labels = open(data_path + lower_train_labels)
27 #f_upper_train_data = open(data_path + upper_train_data)
28 #f_upper_train_labels = open(data_path + upper_train_labels)
29
30 self.raw_digits_train_data = ft.read(f_digits_train_data)
31 self.raw_digits_train_labels = ft.read(f_digits_train_labels)
32 self.raw_digits_test_data = ft.read(f_digits_test_data)
33 self.raw_digits_test_labels = ft.read(f_digits_test_labels)
34
35 self.raw_lower_train_data = ft.read(f_lower_train_data)
36 self.raw_lower_train_labels = ft.read(f_lower_train_labels)
37 #self.raw_upper_train_data = ft.read(f_upper_train_data)
38 #self.raw_upper_train_labels = ft.read(f_upper_train_labels)
39
40 f_digits_train_data.close()
41 f_digits_train_labels.close()
42 f_digits_test_data.close()
43 f_digits_test_labels.close()
44
45 f_lower_train_data.close()
46 f_lower_train_labels.close()
47 #f_upper_train_data.close()
48 #f_upper_train_labels.close()
49
50 def set_batches(self, start_ratio = -1, end_ratio = -1, batch_size = 20, verbose = False):
51 self.batch_size = batch_size
52
53 digits_train_size = len(self.raw_digits_train_labels)
54 digits_test_size = len(self.raw_digits_test_labels)
55
56 lower_train_size = len(self.raw_lower_train_labels)
57 #upper_train_size = len(self.raw_upper_train_labels)
58
59 if verbose == True:
60 print 'digits_train_size = %d' %digits_train_size
61 print 'digits_test_size = %d' %digits_test_size
62 print 'lower_train_size = %d' %lower_train_size
63 #print 'upper_train_size = %d' %upper_train_size
64
65 # define main and other datasets
66 raw_main_train_data = self.raw_digits_train_data
67 raw_other_train_data = self.raw_lower_train_labels
68 raw_test_data = self.raw_digits_test_labels
69
70 raw_main_train_labels = self.raw_digits_train_labels
71 raw_other_train_labels = self.raw_lower_train_labels
72 raw_test_labels = self.raw_digits_test_labels
73
74 main_train_size = len(raw_main_train_data)
75 other_train_size = len(raw_other_train_data)
76 test_size = len(raw_test_data)
77 test_size = int(test_size/batch_size)
78 test_size *= batch_size
79 validation_size = test_size
80
81 # default ratio is actual ratio
82 if start_ratio == -1:
83 self.start_ratio = float(main_train_size) / float(main_train_size + other_train_size)
84 else:
85 self.start_ratio = start_ratio
86
87 if start_ratio == -1:
88 self.end_ratio = float(main_train_size) / float(main_train_size + other_train_size)
89 else:
90 self.end_ratio = end_ratio
91
92 if verbose == True:
93 print 'start_ratio = %f' %self.start_ratio
94 print 'end_ratio = %f' %self.end_ratio
95
96 i_main = 0
97 i_other = 0
98 i_batch = 0
99
100 # compute the number of batches given start and end ratios
101 n_main_batch = (main_train_size - batch_size * (self.end_ratio - self.start_ratio) / 2 ) / (batch_size * (self.start_ratio + (self.end_ratio - self.start_ratio) / 2))
102 n_other_batch = (other_train_size - batch_size * (self.end_ratio - self.start_ratio) / 2 ) / (batch_size - batch_size * (self.start_ratio + (self.end_ratio - self.start_ratio) / 2))
103 n_batches = min([n_main_batch, n_other_batch])
104
105 # train batches
106 self.train_batches = []
107
108 # as long as we have data left in main and other, we create batches
109 while i_main < main_train_size - batch_size - test_size and i_other < other_train_size - batch_size:
110
111 ratio = self.start_ratio + i_batch * (self.end_ratio - self.start_ratio) / n_batches
112 batch_data = []
113 batch_labels = []
114
115 for i in xrange(0, self.batch_size): # randomly choose between main and other, given the current ratio
116 rnd = random.randint(0, 100)
117
118 if rnd < 100 * ratio:
119 batch_data = batch_data + \
120 [raw_main_train_data[i_main]]
121 batch_labels = batch_labels + \
122 [raw_main_train_labels[i_main]]
123 i_main += 1
124 else:
125 batch_data = batch_data + \
126 [raw_other_train_data[i_other]]
127 batch_labels = batch_labels + \
128 [raw_other_train_labels[i_other]]
129 i_other += 1
130
131 self.train_batches = self.train_batches + \
132 [(batch_data,batch_labels)]
133 i_batch += 1
134
135 offset = i_main
136
137 if verbose == True:
138 print 'n_main = %d' %i_main
139 print 'n_other = %d' %i_other
140 print 'nb_train_batches = %d / %d' %(i_batch,n_batches)
141 print 'offset = %d' %offset
142
143 # test batches
144 self.test_batches = []
145 for i in xrange(0, test_size, batch_size):
146 self.test_batches = self.test_batches + \
147 [(raw_test_data[i:i+batch_size], raw_test_labels[i:i+batch_size])]
148
149 # validation batches
150 self.validation_batches = []
151 for i in xrange(0, test_size, batch_size):
152 self.validation_batches = self.validation_batches + \
153 [(raw_main_train_data[offset+i:offset+i+batch_size], raw_main_train_labels[offset+i:offset+i+batch_size])]
154
155 def get_train_batches(self):
156 return self.train_batches
157
158 def get_test_batches(self):
159 return self.test_batches
160
161 def get_validation_batches(self):
162 return self.validation_batches
163
164 def test_set_batches(self, intervall = 1000):
165 for i in xrange(0, len(self.train_batches) - self.batch_size, intervall):
166 n_main = 0
167
168 for j in xrange(0, self.batch_size):
169 if self.train_batches[i][1][j] < 10:
170 n_main +=1
171 print 'ratio batch %d : %f' %(i,float(n_main) / float(self.batch_size))
172
if __name__ == '__main__':
    # Demo run: the share of digit samples ramps from 0.5 to 1 over batches
    # of 20, with statistics printed, then sampled batch ratios are shown.
    nist_batches = Batches()
    nist_batches.set_batches(start_ratio=0.5, end_ratio=1, batch_size=20, verbose=True)
    nist_batches.test_set_batches()