Mercurial > pylearn
comparison _test_dataset.py @ 341:9c08e3af975e
corrected test for dataset.minibatches()
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Tue, 17 Jun 2008 13:33:17 -0400 |
parents | b48cf8dce2bf |
children | 2259f6fa4959 |
comparison
equal
deleted
inserted
replaced
340:d96be0eba3cc | 341:9c08e3af975e |
---|---|
132 assert len(minibatch_field)==minibatch_size or ((nb_iter_finished*minibatch_size+len(minibatch_field))==len_ds and len(minibatch_field)<minibatch_size) | 132 assert len(minibatch_field)==minibatch_size or ((nb_iter_finished*minibatch_size+len(minibatch_field))==len_ds and len(minibatch_field)<minibatch_size) |
133 | 133 |
134 # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): | 134 # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): |
135 i=0 | 135 i=0 |
136 mi=0 | 136 mi=0 |
137 m=ds.minibatches(['x','z'], minibatch_size=3) | 137 size=3 |
138 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) | 138 m=ds.minibatches(['x','z'], minibatch_size=size) |
139 assert hasattr(m,'__iter__') | |
139 for minibatch in m: | 140 for minibatch in m: |
140 assert isinstance(minibatch,DataSetFields) | 141 assert isinstance(minibatch,LookupList) |
141 assert len(minibatch)==2 | 142 assert len(minibatch)==2 |
142 test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) | 143 test_minibatch_size(minibatch,size,len(ds),2,mi) |
143 if type(ds)==ArrayDataSet: | 144 if type(ds)==ArrayDataSet: |
144 assert (minibatch[0][:,::2]==minibatch[1]).all() | 145 assert (minibatch[0][:,::2]==minibatch[1]).all() |
145 else: | 146 else: |
146 for j in xrange(len(minibatch[0])): | 147 for j in xrange(len(minibatch[0])): |
147 (minibatch[0][j][::2]==minibatch[1][j]).all() | 148 (minibatch[0][j][::2]==minibatch[1][j]).all() |
148 mi+=1 | 149 mi+=1 |
149 i+=len(minibatch[0]) | 150 i+=len(minibatch[0]) |
150 assert i==len(ds) | 151 assert i==(len(ds)/size)*size |
151 assert mi==4 | 152 assert mi==(len(ds)/size) |
152 del minibatch,i,m,mi | 153 del minibatch,i,m,mi,size |
153 | 154 |
154 i=0 | 155 i=0 |
155 mi=0 | 156 mi=0 |
156 m=ds.minibatches(['x','y'], minibatch_size=3) | 157 size=3 |
157 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) | 158 m=ds.minibatches(['x','y'], minibatch_size=size) |
159 assert hasattr(m,'__iter__') | |
158 for minibatch in m: | 160 for minibatch in m: |
161 assert isinstance(minibatch,LookupList) | |
159 assert len(minibatch)==2 | 162 assert len(minibatch)==2 |
160 test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) | 163 test_minibatch_size(minibatch,size,len(ds),2,mi) |
161 mi+=1 | 164 mi+=1 |
162 for id in range(len(minibatch[0])): | 165 for id in range(len(minibatch[0])): |
163 assert (numpy.append(minibatch[0][id],minibatch[1][id])==array[i]).all() | 166 assert (numpy.append(minibatch[0][id],minibatch[1][id])==array[i]).all() |
164 i+=1 | 167 i+=1 |
165 assert i==len(ds) | 168 assert i==(len(ds)/size)*size |
166 assert mi==4 | 169 assert mi==(len(ds)/size) |
167 del minibatch,i,id,m,mi | 170 del minibatch,i,id,m,mi,size |
168 | 171 |
169 # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): | 172 # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): |
170 i=0 | 173 i=0 |
171 mi=0 | 174 mi=0 |
172 m=ds.minibatches(['x','z'], minibatch_size=3) | 175 size=3 |
173 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) | 176 m=ds.minibatches(['x','z'], minibatch_size=size) |
177 assert hasattr(m,'__iter__') | |
174 for x,z in m: | 178 for x,z in m: |
175 test_minibatch_field_size(x,m.minibatch_size,len(ds),mi) | 179 test_minibatch_field_size(x,size,len(ds),mi) |
176 test_minibatch_field_size(z,m.minibatch_size,len(ds),mi) | 180 test_minibatch_field_size(z,size,len(ds),mi) |
177 for id in range(len(x)): | 181 for id in range(len(x)): |
178 assert (x[id][::2]==z[id]).all() | 182 assert (x[id][::2]==z[id]).all() |
179 i+=1 | 183 i+=1 |
180 mi+=1 | 184 mi+=1 |
181 assert i==len(ds) | 185 assert i==(len(ds)/size)*size |
182 assert mi==4 | 186 assert mi==(len(ds)/size) |
183 del x,z,i,m,mi | 187 del x,z,i,m,mi,size |
188 | |
184 i=0 | 189 i=0 |
185 mi=0 | 190 mi=0 |
191 size=3 | |
186 m=ds.minibatches(['x','y'], minibatch_size=3) | 192 m=ds.minibatches(['x','y'], minibatch_size=3) |
193 assert hasattr(m,'__iter__') | |
187 for x,y in m: | 194 for x,y in m: |
188 test_minibatch_field_size(x,m.minibatch_size,len(ds),mi) | 195 assert len(x)==size |
189 test_minibatch_field_size(y,m.minibatch_size,len(ds),mi) | 196 assert len(y)==size |
197 test_minibatch_field_size(x,size,len(ds),mi) | |
198 test_minibatch_field_size(y,size,len(ds),mi) | |
190 mi+=1 | 199 mi+=1 |
191 for id in range(len(x)): | 200 for id in range(len(x)): |
192 assert (numpy.append(x[id],y[id])==array[i]).all() | 201 assert (numpy.append(x[id],y[id])==array[i]).all() |
193 i+=1 | 202 i+=1 |
194 assert i==len(ds) | 203 assert i==(len(ds)/size)*size |
195 assert mi==4 | 204 assert mi==(len(ds)/size) |
196 del x,y,i,id,m,mi | 205 del x,y,i,id,m,mi,size |
197 | 206 |
198 #not in doc | 207 #not in doc |
199 i=0 | 208 i=0 |
200 m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4) | 209 size=3 |
201 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) | 210 m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=size,offset=4) |
211 assert hasattr(m,'__iter__') | |
202 for x,y in m: | 212 for x,y in m: |
203 assert len(x)==m.minibatch_size | 213 assert len(x)==size |
204 assert len(y)==m.minibatch_size | 214 assert len(y)==size |
205 for id in range(m.minibatch_size): | 215 for id in range(size): |
206 assert (numpy.append(x[id],y[id])==array[i+4]).all() | 216 assert (numpy.append(x[id],y[id])==array[i+4]).all() |
207 i+=1 | 217 i+=1 |
208 assert i==m.n_batches*m.minibatch_size | 218 assert i==size |
209 del x,y,i,id,m | 219 del x,y,i,id,m,size |
210 | 220 |
211 i=0 | 221 i=0 |
212 m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4) | 222 size=3 |
213 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) | 223 m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=size,offset=4) |
224 assert hasattr(m,'__iter__') | |
214 for x,y in m: | 225 for x,y in m: |
215 assert len(x)==m.minibatch_size | 226 assert len(x)==size |
216 assert len(y)==m.minibatch_size | 227 assert len(y)==size |
217 for id in range(m.minibatch_size): | 228 for id in range(size): |
218 assert (numpy.append(x[id],y[id])==array[i+4]).all() | 229 assert (numpy.append(x[id],y[id])==array[i+4]).all() |
219 i+=1 | 230 i+=1 |
220 assert i==m.n_batches*m.minibatch_size | 231 assert i==2*size |
221 del x,y,i,id,m | 232 del x,y,i,id,m,size |
222 | 233 |
223 i=0 | 234 i=0 |
224 m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=3,offset=4) | 235 size=3 |
225 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) | 236 m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=size,offset=4) |
237 assert hasattr(m,'__iter__') | |
226 for x,y in m: | 238 for x,y in m: |
227 assert len(x)==m.minibatch_size | 239 assert len(x)==size |
228 assert len(y)==m.minibatch_size | 240 assert len(y)==size |
229 for id in range(m.minibatch_size): | 241 for id in range(size): |
230 assert (numpy.append(x[id],y[id])==array[(i+4)%array.shape[0]]).all() | 242 assert (numpy.append(x[id],y[id])==array[(i+4)%array.shape[0]]).all() |
231 i+=1 | 243 i+=1 |
232 assert i==m.n_batches*m.minibatch_size | 244 assert i==2*size # should not wrap |
233 del x,y,i,id | 245 del x,y,i,id,size |
234 | 246 |
235 assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array)+1,offset=0) | 247 assert have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array)+1,offset=0) |
236 assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array),offset=0) | 248 assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array),offset=0) |
237 | 249 |
238 def test_ds_iterator(array,iterator1,iterator2,iterator3): | 250 def test_ds_iterator(array,iterator1,iterator2,iterator3): |
239 l=len(iterator1) | 251 l=len(iterator1) |
240 i=0 | 252 i=0 |