Mercurial > ift6266
comparison scripts/setup_batches.py @ 295:a6b6b1140de9
modifié setup_batches.py pour compatibilité avec mlp_nist.py
author | Guillaume Sicard <guitch21@gmail.com> |
---|---|
date | Mon, 29 Mar 2010 09:18:54 -0400 |
parents | f6d9b6b89c2a |
children | 5b260cc8f477 |
comparison
equal
deleted
inserted
replaced
294:8babd43235dd | 295:a6b6b1140de9 |
---|---|
1 # -*- coding: utf-8 -*- | 1 # -*- coding: utf-8 -*- |
2 | 2 |
3 import random | 3 import random |
4 from numpy import * | |
4 from pylearn.io import filetensor as ft | 5 from pylearn.io import filetensor as ft |
5 | 6 |
6 class Batches(): | 7 class Batches(): |
7 def __init__(self): | 8 def __init__(self): |
8 data_path = '/data/lisa/data/nist/by_class/' | 9 data_path = '/data/lisa/data/nist/by_class/' |
14 | 15 |
15 lower_train_data = 'lower/lower_train_data.ft' | 16 lower_train_data = 'lower/lower_train_data.ft' |
16 lower_train_labels = 'lower/lower_train_labels.ft' | 17 lower_train_labels = 'lower/lower_train_labels.ft' |
17 #upper_train_data = 'upper/upper_train_data.ft' | 18 #upper_train_data = 'upper/upper_train_data.ft' |
18 #upper_train_labels = 'upper/upper_train_labels.ft' | 19 #upper_train_labels = 'upper/upper_train_labels.ft' |
20 | |
21 print 'Opening data...' | |
19 | 22 |
20 f_digits_train_data = open(data_path + digits_train_data) | 23 f_digits_train_data = open(data_path + digits_train_data) |
21 f_digits_train_labels = open(data_path + digits_train_labels) | 24 f_digits_train_labels = open(data_path + digits_train_labels) |
22 f_digits_test_data = open(data_path + digits_test_data) | 25 f_digits_test_data = open(data_path + digits_test_data) |
23 f_digits_test_labels = open(data_path + digits_test_labels) | 26 f_digits_test_labels = open(data_path + digits_test_labels) |
45 f_lower_train_data.close() | 48 f_lower_train_data.close() |
46 f_lower_train_labels.close() | 49 f_lower_train_labels.close() |
47 #f_upper_train_data.close() | 50 #f_upper_train_data.close() |
48 #f_upper_train_labels.close() | 51 #f_upper_train_labels.close() |
49 | 52 |
53 print 'Data opened' | |
54 | |
50 def set_batches(self, start_ratio = -1, end_ratio = -1, batch_size = 20, verbose = False): | 55 def set_batches(self, start_ratio = -1, end_ratio = -1, batch_size = 20, verbose = False): |
51 self.batch_size = batch_size | 56 self.batch_size = batch_size |
52 | 57 |
53 digits_train_size = len(self.raw_digits_train_labels) | 58 digits_train_size = len(self.raw_digits_train_labels) |
54 digits_test_size = len(self.raw_digits_test_labels) | 59 digits_test_size = len(self.raw_digits_test_labels) |
63 #print 'upper_train_size = %d' %upper_train_size | 68 #print 'upper_train_size = %d' %upper_train_size |
64 | 69 |
65 # define main and other datasets | 70 # define main and other datasets |
66 raw_main_train_data = self.raw_digits_train_data | 71 raw_main_train_data = self.raw_digits_train_data |
67 raw_other_train_data = self.raw_lower_train_labels | 72 raw_other_train_data = self.raw_lower_train_labels |
68 raw_test_data = self.raw_digits_test_labels | 73 raw_test_data = self.raw_digits_test_data |
69 | 74 |
70 raw_main_train_labels = self.raw_digits_train_labels | 75 raw_main_train_labels = self.raw_digits_train_labels |
71 raw_other_train_labels = self.raw_lower_train_labels | 76 raw_other_train_labels = self.raw_lower_train_labels |
72 raw_test_labels = self.raw_digits_test_labels | 77 raw_test_labels = self.raw_digits_test_labels |
73 | 78 |
74 main_train_size = len(raw_main_train_data) | 79 main_train_size = len(raw_main_train_data) |
75 other_train_size = len(raw_other_train_data) | 80 other_train_size = len(raw_other_train_data) |
76 test_size = len(raw_test_data) | 81 test_size = len(raw_test_labels) |
77 test_size = int(test_size/batch_size) | 82 test_size = int(test_size/batch_size) |
78 test_size *= batch_size | 83 test_size *= batch_size |
79 validation_size = test_size | 84 validation_size = test_size |
80 | 85 |
81 # default ratio is actual ratio | 86 # default ratio is actual ratio |
107 | 112 |
108 # as long as we have data left in main and other, we create batches | 113 # as long as we have data left in main and other, we create batches |
109 while i_main < main_train_size - batch_size - test_size and i_other < other_train_size - batch_size: | 114 while i_main < main_train_size - batch_size - test_size and i_other < other_train_size - batch_size: |
110 | 115 |
111 ratio = self.start_ratio + i_batch * (self.end_ratio - self.start_ratio) / n_batches | 116 ratio = self.start_ratio + i_batch * (self.end_ratio - self.start_ratio) / n_batches |
112 batch_data = [] | 117 batch_data = raw_main_train_data[0:self.batch_size] |
113 batch_labels = [] | 118 batch_labels = raw_main_train_labels[0:self.batch_size] |
114 | 119 |
115 for i in xrange(0, self.batch_size): # randomly choose between main and other, given the current ratio | 120 for i in xrange(0, self.batch_size): # randomly choose between main and other, given the current ratio |
116 rnd = random.randint(0, 100) | 121 rnd = random.randint(0, 100) |
117 | 122 |
118 if rnd < 100 * ratio: | 123 if rnd < 100 * ratio: |
119 batch_data = batch_data + \ | 124 batch_data[i] = raw_main_train_data[i_main] |
120 [raw_main_train_data[i_main]] | 125 batch_labels[i] = raw_main_train_labels[i_main] |
121 batch_labels = batch_labels + \ | |
122 [raw_main_train_labels[i_main]] | |
123 i_main += 1 | 126 i_main += 1 |
124 else: | 127 else: |
125 batch_data = batch_data + \ | 128 batch_data[i] = raw_other_train_data[i_other] |
126 [raw_other_train_data[i_other]] | 129 batch_labels[i] = raw_other_train_labels[i_other] - 26 #to put values between 10 and 35 for lower case |
127 batch_labels = batch_labels + \ | |
128 [raw_other_train_labels[i_other]] | |
129 i_other += 1 | 130 i_other += 1 |
130 | 131 |
131 self.train_batches = self.train_batches + \ | 132 self.train_batches = self.train_batches + \ |
132 [(batch_data,batch_labels)] | 133 [(batch_data, batch_labels)] |
133 i_batch += 1 | 134 i_batch += 1 |
134 | 135 |
135 offset = i_main | 136 offset = i_main |
136 | |
137 if verbose == True: | |
138 print 'n_main = %d' %i_main | |
139 print 'n_other = %d' %i_other | |
140 print 'nb_train_batches = %d / %d' %(i_batch,n_batches) | |
141 print 'offset = %d' %offset | |
142 | 137 |
143 # test batches | 138 # test batches |
144 self.test_batches = [] | 139 self.test_batches = [] |
145 for i in xrange(0, test_size, batch_size): | 140 for i in xrange(0, test_size, batch_size): |
146 self.test_batches = self.test_batches + \ | 141 self.test_batches = self.test_batches + \ |
149 # validation batches | 144 # validation batches |
150 self.validation_batches = [] | 145 self.validation_batches = [] |
151 for i in xrange(0, test_size, batch_size): | 146 for i in xrange(0, test_size, batch_size): |
152 self.validation_batches = self.validation_batches + \ | 147 self.validation_batches = self.validation_batches + \ |
153 [(raw_main_train_data[offset+i:offset+i+batch_size], raw_main_train_labels[offset+i:offset+i+batch_size])] | 148 [(raw_main_train_data[offset+i:offset+i+batch_size], raw_main_train_labels[offset+i:offset+i+batch_size])] |
149 | |
150 if verbose == True: | |
151 print 'n_main = %d' %i_main | |
152 print 'n_other = %d' %i_other | |
153 print 'nb_train_batches = %d / %d' %(i_batch,n_batches) | |
154 print 'offset = %d' %offset | |
154 | 155 |
155 def get_train_batches(self): | 156 def get_train_batches(self): |
156 return self.train_batches | 157 return self.train_batches |
157 | 158 |
158 def get_test_batches(self): | 159 def get_test_batches(self): |