Mercurial > ift6266
comparison scripts/setup_batches.py @ 356:b0741ea3ff6f
Extension du choix de la classe principale pour les batches d'entrainement
author | Guillaume Sicard <guitch21@gmail.com> |
---|---|
date | Wed, 21 Apr 2010 23:47:50 -0400 |
parents | 5b260cc8f477 |
children | 22919039f7ab |
comparison
equal
deleted
inserted
replaced
355:76b7182dd32e | 356:b0741ea3ff6f |
---|---|
13 digits_test_data = 'digits/digits_test_data.ft' | 13 digits_test_data = 'digits/digits_test_data.ft' |
14 digits_test_labels = 'digits/digits_test_labels.ft' | 14 digits_test_labels = 'digits/digits_test_labels.ft' |
15 | 15 |
16 lower_train_data = 'lower/lower_train_data.ft' | 16 lower_train_data = 'lower/lower_train_data.ft' |
17 lower_train_labels = 'lower/lower_train_labels.ft' | 17 lower_train_labels = 'lower/lower_train_labels.ft' |
| 18 lower_test_data = 'lower/lower_test_data.ft' |
| 19 lower_test_labels = 'lower/lower_test_labels.ft' |
| 20 |
18 upper_train_data = 'upper/upper_train_data.ft' | 21 upper_train_data = 'upper/upper_train_data.ft' |
19 upper_train_labels = 'upper/upper_train_labels.ft' | 22 upper_train_labels = 'upper/upper_train_labels.ft' |
| 23 upper_test_data = 'upper/upper_test_data.ft' |
| 24 upper_test_labels = 'upper/upper_test_labels.ft' |
| 25 |
20 test_data = 'all/all_test_data.ft' | 26 test_data = 'all/all_test_data.ft' |
21 test_labels = 'all/all_test_labels.ft' | 27 test_labels = 'all/all_test_labels.ft' |
22 | 28 |
23 print 'Opening data...' | 29 print 'Opening data...' |
24 | 30 |
27 f_digits_test_data = open(data_path + digits_test_data) | 33 f_digits_test_data = open(data_path + digits_test_data) |
28 f_digits_test_labels = open(data_path + digits_test_labels) | 34 f_digits_test_labels = open(data_path + digits_test_labels) |
29 | 35 |
30 f_lower_train_data = open(data_path + lower_train_data) | 36 f_lower_train_data = open(data_path + lower_train_data) |
31 f_lower_train_labels = open(data_path + lower_train_labels) | 37 f_lower_train_labels = open(data_path + lower_train_labels) |
| 38 f_lower_test_data = open(data_path + lower_test_data) |
| 39 f_lower_test_labels = open(data_path + lower_test_labels) |
| 40 |
32 f_upper_train_data = open(data_path + upper_train_data) | 41 f_upper_train_data = open(data_path + upper_train_data) |
33 f_upper_train_labels = open(data_path + upper_train_labels) | 42 f_upper_train_labels = open(data_path + upper_train_labels) |
34 | 43 f_upper_test_data = open(data_path + upper_test_data) |
35 f_test_data = open(data_path + test_data) | 44 f_upper_test_labels = open(data_path + upper_test_labels) |
36 f_test_labels = open(data_path + test_labels) | 45 |
| 46 #f_test_data = open(data_path + test_data) |
| 47 #f_test_labels = open(data_path + test_labels) |
37 | 48 |
38 self.raw_digits_train_data = ft.read(f_digits_train_data) | 49 self.raw_digits_train_data = ft.read(f_digits_train_data) |
39 self.raw_digits_train_labels = ft.read(f_digits_train_labels) | 50 self.raw_digits_train_labels = ft.read(f_digits_train_labels) |
40 self.raw_digits_test_data = ft.read(f_digits_test_data) | 51 self.raw_digits_test_data = ft.read(f_digits_test_data) |
41 self.raw_digits_test_labels = ft.read(f_digits_test_labels) | 52 self.raw_digits_test_labels = ft.read(f_digits_test_labels) |
42 | 53 |
43 self.raw_lower_train_data = ft.read(f_lower_train_data) | 54 self.raw_lower_train_data = ft.read(f_lower_train_data) |
44 self.raw_lower_train_labels = ft.read(f_lower_train_labels) | 55 self.raw_lower_train_labels = ft.read(f_lower_train_labels) |
| 56 self.raw_lower_test_data = ft.read(f_lower_test_data) |
| 57 self.raw_lower_test_labels = ft.read(f_lower_test_labels) |
| 58 |
45 self.raw_upper_train_data = ft.read(f_upper_train_data) | 59 self.raw_upper_train_data = ft.read(f_upper_train_data) |
46 self.raw_upper_train_labels = ft.read(f_upper_train_labels) | 60 self.raw_upper_train_labels = ft.read(f_upper_train_labels) |
47 | 61 self.raw_upper_test_data = ft.read(f_upper_test_data) |
48 self.raw_test_data = ft.read(f_test_data) | 62 self.raw_upper_test_labels = ft.read(f_upper_test_labels) |
49 self.raw_test_labels = ft.read(f_test_labels) | 63 |
| 64 #self.raw_test_data = ft.read(f_test_data) |
| 65 #self.raw_test_labels = ft.read(f_test_labels) |
50 | 66 |
51 f_digits_train_data.close() | 67 f_digits_train_data.close() |
52 f_digits_train_labels.close() | 68 f_digits_train_labels.close() |
53 f_digits_test_data.close() | 69 f_digits_test_data.close() |
54 f_digits_test_labels.close() | 70 f_digits_test_labels.close() |
55 | 71 |
56 f_lower_train_data.close() | 72 f_lower_train_data.close() |
57 f_lower_train_labels.close() | 73 f_lower_train_labels.close() |
| 74 f_lower_test_data.close() |
| 75 f_lower_test_labels.close() |
| 76 |
58 f_upper_train_data.close() | 77 f_upper_train_data.close() |
59 f_upper_train_labels.close() | 78 f_upper_train_labels.close() |
60 | 79 f_upper_test_data.close() |
61 f_test_data.close() | 80 f_upper_test_labels.close() |
62 f_test_labels.close() | 81 |
| 82 #f_test_data.close() |
| 83 #f_test_labels.close() |
63 | 84 |
64 print 'Data opened' | 85 print 'Data opened' |
65 | 86 |
66 def set_batches(self, start_ratio = -1, end_ratio = -1, batch_size = 20, verbose = False): | 87 def set_batches(self, main_class = "d", start_ratio = -1, end_ratio = -1, batch_size = 20, verbose = False): |
67 self.batch_size = batch_size | 88 self.batch_size = batch_size |
68 | 89 |
69 digits_train_size = len(self.raw_digits_train_labels) | 90 digits_train_size = len(self.raw_digits_train_labels) |
70 digits_test_size = len(self.raw_digits_test_labels) | 91 digits_test_size = len(self.raw_digits_test_labels) |
71 | 92 |
72 lower_train_size = len(self.raw_lower_train_labels) | 93 lower_train_size = len(self.raw_lower_train_labels) |
| 94 |
73 upper_train_size = len(self.raw_upper_train_labels) | 95 upper_train_size = len(self.raw_upper_train_labels) |
| 96 upper_test_size = len(self.raw_upper_test_labels) |
74 | 97 |
75 if verbose == True: | 98 if verbose == True: |
76 print 'digits_train_size = %d' %digits_train_size | 99 print 'digits_train_size = %d' %digits_train_size |
77 print 'digits_test_size = %d' %digits_test_size | 100 print 'digits_test_size = %d' %digits_test_size |
78 print 'lower_train_size = %d' %lower_train_size | 101 print 'lower_train_size = %d' %lower_train_size |
79 print 'upper_train_size = %d' %upper_train_size | 102 print 'upper_train_size = %d' %upper_train_size |
80 | 103 print 'upper_test_size = %d' %upper_test_size |
81 # define main and other datasets | 104 |
82 raw_main_train_data = self.raw_digits_train_data | 105 if main_class == "u": |
83 raw_other_train_data1 = self.raw_lower_train_labels | 106 # define main and other datasets |
84 raw_other_train_data2 = self.raw_upper_train_labels | 107 raw_main_train_data = self.raw_upper_train_data |
85 raw_test_data = self.raw_digits_test_data | 108 raw_other_train_data1 = self.raw_lower_train_labels |
86 #raw_test_data = self.raw_test_data | 109 raw_other_train_data2 = self.raw_digits_train_labels |
87 | 110 raw_test_data = self.raw_upper_test_data |
88 raw_main_train_labels = self.raw_digits_train_labels | 111 |
89 raw_other_train_labels1 = self.raw_lower_train_labels | 112 raw_main_train_labels = self.raw_upper_train_labels |
90 raw_other_train_labels2 = self.raw_upper_train_labels | 113 raw_other_train_labels1 = self.raw_lower_train_labels |
91 raw_test_labels = self.raw_digits_test_labels | 114 raw_other_train_labels2 = self.raw_digits_train_labels |
92 #raw_test_labels = self.raw_test_labels | 115 raw_test_labels = self.raw_upper_test_labels |
| 116 |
| 117 elif main_class == "l": |
| 118 # define main and other datasets |
| 119 raw_main_train_data = self.raw_lower_train_data |
| 120 raw_other_train_data1 = self.raw_upper_train_labels |
| 121 raw_other_train_data2 = self.raw_digits_train_labels |
| 122 raw_test_data = self.raw_lower_test_data |
| 123 |
| 124 raw_main_train_labels = self.raw_lower_train_labels |
| 125 raw_other_train_labels1 = self.raw_upper_train_labels |
| 126 raw_other_train_labels2 = self.raw_digits_train_labels |
| 127 raw_test_labels = self.raw_lower_test_labels |
| 128 |
| 129 else: |
| 130 main_class = "d" |
| 131 # define main and other datasets |
| 132 raw_main_train_data = self.raw_digits_train_data |
| 133 raw_other_train_data1 = self.raw_lower_train_labels |
| 134 raw_other_train_data2 = self.raw_upper_train_labels |
| 135 raw_test_data = self.raw_digits_test_data |
| 136 |
| 137 raw_main_train_labels = self.raw_digits_train_labels |
| 138 raw_other_train_labels1 = self.raw_lower_train_labels |
| 139 raw_other_train_labels2 = self.raw_upper_train_labels |
| 140 raw_test_labels = self.raw_digits_test_labels |
93 | 141 |
94 main_train_size = len(raw_main_train_labels) | 142 main_train_size = len(raw_main_train_labels) |
95 other_train_size1 = len(raw_other_train_labels1) | 143 other_train_size1 = len(raw_other_train_labels1) |
96 other_train_size2 = len(raw_other_train_labels2) | 144 other_train_size2 = len(raw_other_train_labels2) |
97 other_train_size = other_train_size1 + other_train_size2 | 145 other_train_size = other_train_size1 + other_train_size2 |
111 self.end_ratio = float(main_train_size - test_size) / float(main_train_size + other_train_size) | 159 self.end_ratio = float(main_train_size - test_size) / float(main_train_size + other_train_size) |
112 else: | 160 else: |
113 self.end_ratio = end_ratio | 161 self.end_ratio = end_ratio |
114 | 162 |
115 if verbose == True: | 163 if verbose == True: |
| 164 print 'main class : %s' %main_class |
116 print 'start_ratio = %f' %self.start_ratio | 165 print 'start_ratio = %f' %self.start_ratio |
117 print 'end_ratio = %f' %self.end_ratio | 166 print 'end_ratio = %f' %self.end_ratio |
118 | 167 |
119 i_main = 0 | 168 i_main = 0 |
120 i_other1 = 0 | 169 i_other1 = 0 |