Mercurial > pylearn
comparison _test_dataset.py @ 221:58e17421c69c
tester on iterator consistency now triggers a bug in dataset, linked to the combination of minibatch and slicing
author | Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca> |
---|---|
date | Fri, 23 May 2008 14:07:53 -0400 |
parents | 1f527fe65e22 |
children | 174374d59405 |
comparison
equal
deleted
inserted
replaced
220:1f527fe65e22 | 221:58e17421c69c |
---|---|
110 self.ds = ds | 110 self.ds = ds |
111 | 111 |
112 if runall : | 112 if runall : |
113 self.test1_basicstats(ds) | 113 self.test1_basicstats(ds) |
114 self.test2_slicing(ds) | 114 self.test2_slicing(ds) |
| 115 self.test3_fields_iterator_consistency(ds) |
115 | 116 |
116 def test1_basicstats(self,ds) : | 117 def test1_basicstats(self,ds) : |
117 """print basics stats on a dataset, like length""" | 118 """print basics stats on a dataset, like length""" |
118 | 119 |
119 print 'len(ds) = ',len(ds) | 120 print 'len(ds) = ',len(ds) |
137 if type(set1[middle-tenpercent+k](k2)[0]) == N.ndarray : | 138 if type(set1[middle-tenpercent+k](k2)[0]) == N.ndarray : |
138 for k3 in range(len(set1[middle-tenpercent+k](k2)[0])) : | 139 for k3 in range(len(set1[middle-tenpercent+k](k2)[0])) : |
139 assert set1[middle-tenpercent+k](k2)[0][k3] == set2[k](k2)[0][k3] | 140 assert set1[middle-tenpercent+k](k2)[0][k3] == set2[k](k2)[0][k3] |
140 else : | 141 else : |
141 assert set1[middle-tenpercent+k](k2)[0] == set2[k](k2)[0] | 142 assert set1[middle-tenpercent+k](k2)[0] == set2[k](k2)[0] |
| 143 assert tenpercent > 1 |
| 144 set3 = ds[middle-tenpercent:middle+tenpercent:2] |
| 145 for k2 in ds.fieldNames() : |
| 146 if type(set2[2](k2)[0]) == N.ndarray : |
| 147 for k3 in range(len(set2[2](k2)[0])) : |
| 148 assert set2[2](k2)[0][k3] == set3[1](k2)[0][k3] |
| 149 else : |
| 150 assert set2[2](k2)[0] == set3[1](k2)[0] |
142 | 151 |
143 print 'done' | 152 print 'done' |
| 153 |
| 154 |
| 155 def test3_fields_iterator_consistency(self,ds) : |
| 156 """ check if the number of iterator corresponds to the number of fields""" |
| 157 print 'testing fields/iterator consistency...', |
| 158 sys.stdout.flush() |
| 159 |
| 160 # basic test |
| 161 maxsize = min(len(ds)-1,100) |
| 162 for iter in ds[:maxsize] : |
| 163 assert len(iter) == len(ds.fieldNames()) |
| 164 if len(ds.fieldNames()) == 1 : |
| 165 print 'done' |
| 166 return |
| 167 |
| 168 # with minibatches iterator |
| 169 ds2 = ds.minibatches[:maxsize]([ds.fieldNames()[0],ds.fieldNames()[1]],minibatch_size=2) |
| 170 for iter in ds2 : |
| 171 assert len(iter) == 2 |
| 172 |
| 173 print 'done' |
| 174 |
| 175 |
| 176 |
144 | 177 |
145 | 178 |
146 ################################################################### | 179 ################################################################### |
147 # main | 180 # main |
148 if __name__ == '__main__': | 181 if __name__ == '__main__': |