Mercurial > ift6266
comparison scripts/setup_batches.py @ 275:7b4507295eba
merge
author | Xavier Glorot <glorotxa@iro.umontreal.ca> |
---|---|
date | Mon, 22 Mar 2010 10:20:10 -0400 |
parents | f6d9b6b89c2a |
children | a6b6b1140de9 |
comparison
equal
deleted
inserted
replaced
274:44409b6652aa | 275:7b4507295eba |
---|---|
1 # -*- coding: utf-8 -*- | |
2 | |
3 import random | |
4 from pylearn.io import filetensor as ft | |
5 | |
class Batches(object):
    """Build train / validation / test mini-batches from two labelled sets.

    The "main" set is the digits training set and the "other" set is the
    lower-case training set.  Training batches mix examples of both sets
    according to a ratio that moves linearly from ``start_ratio`` towards
    ``end_ratio``.  Test batches come from the digits test set, and the
    validation batches are cut from the part of the main training set that
    was left unused by the training batches.

    The upper-case files (``upper/upper_train_*.ft``) are not loaded.
    """

    def __init__(self, data_path='/data/lisa/data/nist/by_class/'):
        """Read the raw digit and lower-case tensors found under ``data_path``.

        Args:
            data_path: directory containing the ``digits/`` and ``lower/``
                sub-directories.  Defaults to the original hard-coded path.
        """
        import os  # local import so this class does not need a new file-level import

        def load(relative_path):
            return self._read_tensor(os.path.join(data_path, relative_path))

        self.raw_digits_train_data = load('digits/digits_train_data.ft')
        self.raw_digits_train_labels = load('digits/digits_train_labels.ft')
        self.raw_digits_test_data = load('digits/digits_test_data.ft')
        self.raw_digits_test_labels = load('digits/digits_test_labels.ft')

        self.raw_lower_train_data = load('lower/lower_train_data.ft')
        self.raw_lower_train_labels = load('lower/lower_train_labels.ft')

    @staticmethod
    def _read_tensor(path):
        """Return ``ft.read`` of the file at ``path``, always closing the file."""
        # Binary mode: the content is handed to the filetensor reader, not
        # decoded as text.
        handle = open(path, 'rb')
        try:
            return ft.read(handle)
        finally:
            handle.close()

    def set_batches(self, start_ratio=-1, end_ratio=-1, batch_size=20, verbose=False):
        """Create ``train_batches``, ``test_batches`` and ``validation_batches``.

        Every batch is a ``(data, labels)`` tuple.

        Args:
            start_ratio: probability of drawing a main-set example in the
                first training batch.  ``-1`` means "use the actual share of
                the main set in main + other".
            end_ratio: probability targeted after the estimated number of
                training batches.  ``-1`` has the same meaning as above.
            batch_size: number of examples per batch (must be >= 1).
            verbose: when true, print the dataset sizes and batch statistics.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError('batch_size must be >= 1, got %r' % (batch_size,))
        self.batch_size = batch_size

        if verbose:
            print('digits_train_size = %d' % len(self.raw_digits_train_labels))
            print('digits_test_size = %d' % len(self.raw_digits_test_labels))
            print('lower_train_size = %d' % len(self.raw_lower_train_labels))

        # define main and other datasets (data and labels must stay paired)
        raw_main_train_data = self.raw_digits_train_data
        raw_other_train_data = self.raw_lower_train_data
        raw_test_data = self.raw_digits_test_data

        raw_main_train_labels = self.raw_digits_train_labels
        raw_other_train_labels = self.raw_lower_train_labels
        raw_test_labels = self.raw_digits_test_labels

        main_train_size = len(raw_main_train_data)
        other_train_size = len(raw_other_train_data)
        # Keep only whole batches of test data; the validation set is given
        # the same number of examples.
        test_size = (len(raw_test_data) // batch_size) * batch_size

        def resolve(ratio):
            # default ratio is actual ratio
            if ratio == -1:
                return float(main_train_size) / float(main_train_size + other_train_size)
            # float() avoids integer division when callers pass 0 or 1.
            return float(ratio)

        self.start_ratio = resolve(start_ratio)
        self.end_ratio = resolve(end_ratio)

        if verbose:
            print('start_ratio = %f' % self.start_ratio)
            print('end_ratio = %f' % self.end_ratio)

        # compute the number of batches given start and end ratios
        span = self.end_ratio - self.start_ratio
        mean_ratio = self.start_ratio + span / 2
        half_span = batch_size * span / 2
        estimates = []
        if mean_ratio > 0:  # main examples are consumed
            estimates.append((main_train_size - half_span) / (batch_size * mean_ratio))
        if mean_ratio < 1:  # other examples are consumed
            estimates.append((other_train_size - half_span) / (batch_size - batch_size * mean_ratio))
        n_batches = min(estimates)

        # train batches
        train_batches = []
        i_main = 0
        i_other = 0
        i_batch = 0
        # The tail of the main set is reserved for the validation batches.
        main_limit = main_train_size - batch_size - test_size
        other_limit = other_train_size - batch_size

        # as long as we have data left in main and other, we create batches
        while i_main < main_limit and i_other < other_limit:
            # Linear interpolation; if more batches than estimated are
            # produced the ratio keeps moving past end_ratio.
            if n_batches > 0:
                ratio = self.start_ratio + i_batch * span / n_batches
            else:
                ratio = self.start_ratio

            batch_data = []
            batch_labels = []
            for _ in range(batch_size):
                # randomly choose between main and other, given the current
                # ratio; random() is in [0, 1) so ratio 1 always picks main
                # and ratio 0 always picks other.
                if random.random() < ratio:
                    batch_data.append(raw_main_train_data[i_main])
                    batch_labels.append(raw_main_train_labels[i_main])
                    i_main += 1
                else:
                    batch_data.append(raw_other_train_data[i_other])
                    batch_labels.append(raw_other_train_labels[i_other])
                    i_other += 1

            train_batches.append((batch_data, batch_labels))
            i_batch += 1

        self.train_batches = train_batches
        offset = i_main

        if verbose:
            print('n_main = %d' % i_main)
            print('n_other = %d' % i_other)
            print('nb_train_batches = %d / %d' % (i_batch, n_batches))
            print('offset = %d' % offset)

        # test batches
        self.test_batches = [
            (raw_test_data[i:i + batch_size], raw_test_labels[i:i + batch_size])
            for i in range(0, test_size, batch_size)
        ]

        # validation batches: main examples located right after the ones
        # used for training
        self.validation_batches = [
            (raw_main_train_data[offset + i:offset + i + batch_size],
             raw_main_train_labels[offset + i:offset + i + batch_size])
            for i in range(0, test_size, batch_size)
        ]

    def get_train_batches(self):
        """Return the training batches built by ``set_batches``."""
        return self.train_batches

    def get_test_batches(self):
        """Return the test batches built by ``set_batches``."""
        return self.test_batches

    def get_validation_batches(self):
        """Return the validation batches built by ``set_batches``."""
        return self.validation_batches

    def test_set_batches(self, intervall=1000):
        """Print the share of main examples in every ``intervall``-th train batch.

        An example is counted as "main" when its label is below 10.
        """
        for i in range(0, len(self.train_batches), intervall):
            labels = self.train_batches[i][1]
            n_main = sum(1 for j in range(self.batch_size) if labels[j] < 10)
            print('ratio batch %d : %f' % (i, float(n_main) / float(self.batch_size)))
172 | |
if __name__ == '__main__':
    # Script entry point: load the data, build batches whose main-set ratio
    # goes from 0.5 to 1 with 20 examples per batch (verbose output), then
    # print the observed ratio of a sample of the training batches.
    demo = Batches()
    demo.set_batches(start_ratio=0.5, end_ratio=1, batch_size=20, verbose=True)
    demo.test_set_batches()