comparison _test_dataset.py @ 376:c9a89be5cb0a

Redesigning linear_regression
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Mon, 07 Jul 2008 10:08:35 -0400
parents 18702ceb2096
children 82da179d95b2
comparison
equal deleted inserted replaced
375:12ce29abf27d 376:c9a89be5cb0a
1 #!/bin/env python 1 #!/bin/env python
2 from dataset import * 2 from dataset import *
3 from math import * 3 from math import *
4 import numpy, unittest, sys 4 import numpy, unittest, sys
5 from misc import * 5 #from misc import *
6 from lookup_list import LookupList 6 from lookup_list import LookupList
7 7
8 def have_raised(to_eval, **var): 8 def have_raised(to_eval, **var):
9 have_thrown = False 9 have_thrown = False
10 try: 10 try:
132 assert len(minibatch_field)==minibatch_size or ((nb_iter_finished*minibatch_size+len(minibatch_field))==len_ds and len(minibatch_field)<minibatch_size) 132 assert len(minibatch_field)==minibatch_size or ((nb_iter_finished*minibatch_size+len(minibatch_field))==len_ds and len(minibatch_field)<minibatch_size)
133 133
134 # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): 134 # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N):
135 i=0 135 i=0
136 mi=0 136 mi=0
137 m=ds.minibatches(['x','z'], minibatch_size=3) 137 size=3
138 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) 138 m=ds.minibatches(['x','z'], minibatch_size=size)
139 assert hasattr(m,'__iter__')
139 for minibatch in m: 140 for minibatch in m:
140 assert isinstance(minibatch,DataSetFields) 141 assert isinstance(minibatch,LookupList)
141 assert len(minibatch)==2 142 assert len(minibatch)==2
142 test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) 143 test_minibatch_size(minibatch,size,len(ds),2,mi)
143 if type(ds)==ArrayDataSet: 144 if type(ds)==ArrayDataSet:
144 assert (minibatch[0][:,::2]==minibatch[1]).all() 145 assert (minibatch[0][:,::2]==minibatch[1]).all()
145 else: 146 else:
146 for j in xrange(len(minibatch[0])): 147 for j in xrange(len(minibatch[0])):
147 (minibatch[0][j][::2]==minibatch[1][j]).all() 148 (minibatch[0][j][::2]==minibatch[1][j]).all()
148 mi+=1 149 mi+=1
149 i+=len(minibatch[0]) 150 i+=len(minibatch[0])
150 assert i==len(ds) 151 assert i==(len(ds)/size)*size
151 assert mi==4 152 assert mi==(len(ds)/size)
152 del minibatch,i,m,mi 153 del minibatch,i,m,mi,size
153 154
154 i=0 155 i=0
155 mi=0 156 mi=0
156 m=ds.minibatches(['x','y'], minibatch_size=3) 157 size=3
157 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) 158 m=ds.minibatches(['x','y'], minibatch_size=size)
159 assert hasattr(m,'__iter__')
158 for minibatch in m: 160 for minibatch in m:
161 assert isinstance(minibatch,LookupList)
159 assert len(minibatch)==2 162 assert len(minibatch)==2
160 test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) 163 test_minibatch_size(minibatch,size,len(ds),2,mi)
161 mi+=1 164 mi+=1
162 for id in range(len(minibatch[0])): 165 for id in range(len(minibatch[0])):
163 assert (numpy.append(minibatch[0][id],minibatch[1][id])==array[i]).all() 166 assert (numpy.append(minibatch[0][id],minibatch[1][id])==array[i]).all()
164 i+=1 167 i+=1
165 assert i==len(ds) 168 assert i==(len(ds)/size)*size
166 assert mi==4 169 assert mi==(len(ds)/size)
167 del minibatch,i,id,m,mi 170 del minibatch,i,id,m,mi,size
168 171
169 # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): 172 # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N):
170 i=0 173 i=0
171 mi=0 174 mi=0
172 m=ds.minibatches(['x','z'], minibatch_size=3) 175 size=3
173 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) 176 m=ds.minibatches(['x','z'], minibatch_size=size)
177 assert hasattr(m,'__iter__')
174 for x,z in m: 178 for x,z in m:
175 test_minibatch_field_size(x,m.minibatch_size,len(ds),mi) 179 test_minibatch_field_size(x,size,len(ds),mi)
176 test_minibatch_field_size(z,m.minibatch_size,len(ds),mi) 180 test_minibatch_field_size(z,size,len(ds),mi)
177 for id in range(len(x)): 181 for id in range(len(x)):
178 assert (x[id][::2]==z[id]).all() 182 assert (x[id][::2]==z[id]).all()
179 i+=1 183 i+=1
180 mi+=1 184 mi+=1
181 assert i==len(ds) 185 assert i==(len(ds)/size)*size
182 assert mi==4 186 assert mi==(len(ds)/size)
183 del x,z,i,m,mi 187 del x,z,i,m,mi,size
188
184 i=0 189 i=0
185 mi=0 190 mi=0
191 size=3
186 m=ds.minibatches(['x','y'], minibatch_size=3) 192 m=ds.minibatches(['x','y'], minibatch_size=3)
193 assert hasattr(m,'__iter__')
187 for x,y in m: 194 for x,y in m:
188 test_minibatch_field_size(x,m.minibatch_size,len(ds),mi) 195 assert len(x)==size
189 test_minibatch_field_size(y,m.minibatch_size,len(ds),mi) 196 assert len(y)==size
197 test_minibatch_field_size(x,size,len(ds),mi)
198 test_minibatch_field_size(y,size,len(ds),mi)
190 mi+=1 199 mi+=1
191 for id in range(len(x)): 200 for id in range(len(x)):
192 assert (numpy.append(x[id],y[id])==array[i]).all() 201 assert (numpy.append(x[id],y[id])==array[i]).all()
193 i+=1 202 i+=1
194 assert i==len(ds) 203 assert i==(len(ds)/size)*size
195 assert mi==4 204 assert mi==(len(ds)/size)
196 del x,y,i,id,m,mi 205 del x,y,i,id,m,mi,size
197 206
198 #not in doc 207 #not in doc
199 i=0 208 i=0
200 m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4) 209 size=3
201 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) 210 m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=size,offset=4)
211 assert hasattr(m,'__iter__')
202 for x,y in m: 212 for x,y in m:
203 assert len(x)==m.minibatch_size 213 assert len(x)==size
204 assert len(y)==m.minibatch_size 214 assert len(y)==size
205 for id in range(m.minibatch_size): 215 for id in range(size):
206 assert (numpy.append(x[id],y[id])==array[i+4]).all() 216 assert (numpy.append(x[id],y[id])==array[i+4]).all()
207 i+=1 217 i+=1
208 assert i==m.n_batches*m.minibatch_size 218 assert i==size
209 del x,y,i,id,m 219 del x,y,i,id,m,size
210 220
211 i=0 221 i=0
212 m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4) 222 size=3
213 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) 223 m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=size,offset=4)
224 assert hasattr(m,'__iter__')
214 for x,y in m: 225 for x,y in m:
215 assert len(x)==m.minibatch_size 226 assert len(x)==size
216 assert len(y)==m.minibatch_size 227 assert len(y)==size
217 for id in range(m.minibatch_size): 228 for id in range(size):
218 assert (numpy.append(x[id],y[id])==array[i+4]).all() 229 assert (numpy.append(x[id],y[id])==array[i+4]).all()
219 i+=1 230 i+=1
220 assert i==m.n_batches*m.minibatch_size 231 assert i==2*size
221 del x,y,i,id,m 232 del x,y,i,id,m,size
222 233
223 i=0 234 i=0
224 m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=3,offset=4) 235 size=3
225 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) 236 m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=size,offset=4)
237 assert hasattr(m,'__iter__')
226 for x,y in m: 238 for x,y in m:
227 assert len(x)==m.minibatch_size 239 assert len(x)==size
228 assert len(y)==m.minibatch_size 240 assert len(y)==size
229 for id in range(m.minibatch_size): 241 for id in range(size):
230 assert (numpy.append(x[id],y[id])==array[(i+4)%array.shape[0]]).all() 242 assert (numpy.append(x[id],y[id])==array[(i+4)%array.shape[0]]).all()
231 i+=1 243 i+=1
232 assert i==m.n_batches*m.minibatch_size 244 assert i==2*size # should not wrap
233 del x,y,i,id 245 del x,y,i,id,size
234 246
235 assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array)+1,offset=0) 247 assert have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array)+1,offset=0)
236 assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array),offset=0) 248 assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array),offset=0)
237 249
238 def test_ds_iterator(array,iterator1,iterator2,iterator3): 250 def test_ds_iterator(array,iterator1,iterator2,iterator3):
239 l=len(iterator1) 251 l=len(iterator1)
240 i=0 252 i=0
260 assert i==l 272 assert i==l
261 273
262 def test_getitem(array,ds): 274 def test_getitem(array,ds):
263 def test_ds(orig,ds,index): 275 def test_ds(orig,ds,index):
264 i=0 276 i=0
265 assert len(ds)==len(index) 277 assert isinstance(ds,LookupList)
266 for x,z,y in ds('x','z','y'): 278 assert len(ds)==3
267 assert (orig[index[i]]['x']==array[index[i]][:3]).all() 279 assert len(ds[0])==len(index)
268 assert (orig[index[i]]['x']==x).all() 280 # for x,z,y in ds('x','z','y'):
269 assert orig[index[i]]['y']==array[index[i]][3] 281 for idx in index:
270 assert (orig[index[i]]['y']==y).all() # why does it crash sometimes? 282 assert (orig[idx]['x']==array[idx][:3]).all()
271 assert (orig[index[i]]['z']==array[index[i]][0:3:2]).all() 283 assert (orig[idx]['x']==ds['x'][i]).all()
272 assert (orig[index[i]]['z']==z).all() 284 assert orig[idx]['y']==array[idx][3]
285 assert (orig[idx]['y']==ds['y'][i]).all() # why does it crash sometimes?
286 assert (orig[idx]['z']==array[idx][0:3:2]).all()
287 assert (orig[idx]['z']==ds['z'][i]).all()
273 i+=1 288 i+=1
274 del i 289 del i
275 ds[0] 290 ds[0]
276 if len(ds)>2: 291 if len(ds)>2:
277 ds[:1] 292 ds[:1]
280 if len(ds)>5: 295 if len(ds)>5:
281 ds[[1,2,3]] 296 ds[[1,2,3]]
282 for x in ds: 297 for x in ds:
283 pass 298 pass
284 299
285 #ds[:n] returns a dataset with the n first examples. 300 #ds[:n] returns a LookupList with the n first examples.
286 ds2=ds[:3] 301 ds2=ds[:3]
287 assert isinstance(ds2,LookupList)
288 test_ds(ds,ds2,index=[0,1,2]) 302 test_ds(ds,ds2,index=[0,1,2])
289 del ds2 303 del ds2
290 304
291 #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s. 305 #ds[i:j] returns a LookupList with examples i,i+1,...,j-1.
292 ds2=ds.subset[1:7:2] 306 ds2=ds[1:3]
293 assert isinstance(ds2,DataSet) 307 test_ds(ds,ds2,index=[1,2])
308 del ds2
309
310 #ds[i1:i2:s] returns a LookupList with the examples i1,i1+s,...i2-s.
311 ds2=ds[1:7:2]
294 test_ds(ds,ds2,[1,3,5]) 312 test_ds(ds,ds2,[1,3,5])
295 del ds2 313 del ds2
296 314
297 #ds[i] 315 #ds[i] returns the (i+1)-th example of the dataset.
298 ds2=ds[5] 316 ds2=ds[5]
299 assert isinstance(ds2,Example) 317 assert isinstance(ds2,Example)
300 assert have_raised("var['ds']["+str(len(ds))+"]",ds=ds) # index not defined 318 assert have_raised("var['ds']["+str(len(ds))+"]",ds=ds) # index not defined
301 assert not have_raised("var['ds']["+str(len(ds)-1)+"]",ds=ds) 319 assert not have_raised("var['ds']["+str(len(ds)-1)+"]",ds=ds)
302 del ds2 320 del ds2
303 321
304 #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in. 322 #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in.
305 ds2=ds.subset[[4,7,2,8]] 323 ds2=ds[[4,7,2,8]]
306 assert isinstance(ds2,DataSet) 324 # assert isinstance(ds2,DataSet)
307 test_ds(ds,ds2,[4,7,2,8]) 325 test_ds(ds,ds2,[4,7,2,8])
308 del ds2 326 del ds2
309 327
310 #ds.<property># returns the value of a property associated with 328 #ds.<property># returns the value of a property associated with
311 #the name <property>. The following properties should be supported: 329 #the name <property>. The following properties should be supported:
323 # for example in hstack([ds('x'),ds('y'),ds('z')]): 341 # for example in hstack([ds('x'),ds('y'),ds('z')]):
324 # example==ds[i] 342 # example==ds[i]
325 # i+=1 343 # i+=1
326 # del i,example 344 # del i,example
327 #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])#???? 345 #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])#????
346
347 def test_subset(array,ds):
348 def test_ds(orig,ds,index):
349 i=0
350 assert isinstance(ds2,DataSet)
351 assert len(ds)==len(index)
352 for x,z,y in ds('x','z','y'):
353 assert (orig[index[i]]['x']==array[index[i]][:3]).all()
354 assert (orig[index[i]]['x']==x).all()
355 assert orig[index[i]]['y']==array[index[i]][3]
356 assert orig[index[i]]['y']==y
357 assert (orig[index[i]]['z']==array[index[i]][0:3:2]).all()
358 assert (orig[index[i]]['z']==z).all()
359 i+=1
360 del i
361 ds[0]
362 if len(ds)>2:
363 ds[:1]
364 ds[1:1]
365 ds[1:1:1]
366 if len(ds)>5:
367 ds[[1,2,3]]
368 for x in ds:
369 pass
370
371 #ds[:n] returns a dataset with the n first examples.
372 ds2=ds.subset[:3]
373 test_ds(ds,ds2,index=[0,1,2])
374 # del ds2
375
376 #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s.
377 ds2=ds.subset[1:7:2]
378 test_ds(ds,ds2,[1,3,5])
379 # del ds2
380
381 # #ds[i]
382 # ds2=ds.subset[5]
383 # assert isinstance(ds2,Example)
384 # assert have_raised("var['ds']["+str(len(ds))+"]",ds=ds) # index not defined
385 # assert not have_raised("var['ds']["+str(len(ds)-1)+"]",ds=ds)
386 # del ds2
387
388 #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in.
389 ds2=ds.subset[[4,7,2,8]]
390 test_ds(ds,ds2,[4,7,2,8])
391 # del ds2
392
393 #ds.<property># returns the value of a property associated with
394 #the name <property>. The following properties should be supported:
395 # - 'description': a textual description or name for the ds
396 # - 'fieldtypes': a list of types (one per field)
397
398 #* ds1 | ds2 | ds3 == ds.hstack([ds1,ds2,ds3])#????
399 #assert hstack([ds('x','y'),ds('z')])==ds
400 #hstack([ds('z','y'),ds('x')])==ds
401 assert have_raised2(hstack,[ds('x'),ds('x')])
402 assert have_raised2(hstack,[ds('y','x'),ds('x')])
403 assert not have_raised2(hstack,[ds('x'),ds('y')])
404
405 # i=0
406 # for example in hstack([ds('x'),ds('y'),ds('z')]):
407 # example==ds[i]
408 # i+=1
409 # del i,example
410 #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])#????
328 411
329 def test_fields_fct(ds): 412 def test_fields_fct(ds):
330 #@todo, fill correctly 413 #@todo, fill correctly
331 assert len(ds.fields())==3 414 assert len(ds.fields())==3
332 i=0 415 i=0
453 def test_all(array,ds): 536 def test_all(array,ds):
454 assert len(ds)==10 537 assert len(ds)==10
455 test_iterate_over_examples(array, ds) 538 test_iterate_over_examples(array, ds)
456 test_overrides(ds) 539 test_overrides(ds)
457 test_getitem(array, ds) 540 test_getitem(array, ds)
541 test_subset(array, ds)
458 test_ds_iterator(array,ds('x','y'),ds('y','z'),ds('x','y','z')) 542 test_ds_iterator(array,ds('x','y'),ds('y','z'),ds('x','y','z'))
459 test_fields_fct(ds) 543 test_fields_fct(ds)
460 544
461 545
462 class T_DataSet(unittest.TestCase): 546 class T_DataSet(unittest.TestCase):
508 592
509 def test_FieldsSubsetDataSet(self): 593 def test_FieldsSubsetDataSet(self):
510 a = numpy.random.rand(10,4) 594 a = numpy.random.rand(10,4)
511 ds = ArrayDataSet(a,Example(['x','y','z','w'],[slice(3),3,[0,2],0])) 595 ds = ArrayDataSet(a,Example(['x','y','z','w'],[slice(3),3,[0,2],0]))
512 ds = FieldsSubsetDataSet(ds,['x','y','z']) 596 ds = FieldsSubsetDataSet(ds,['x','y','z'])
597
598 test_all(a,ds)
599
600 del a, ds
601
602 def test_RenamedFieldsDataSet(self):
603 a = numpy.random.rand(10,4)
604 ds = ArrayDataSet(a,Example(['x1','y1','z1','w1'],[slice(3),3,[0,2],0]))
605 ds = RenamedFieldsDataSet(ds,['x1','y1','z1'],['x','y','z'])
513 606
514 test_all(a,ds) 607 test_all(a,ds)
515 608
516 del a, ds 609 del a, ds
517 610
568 for k in range(len(dsc)) : 661 for k in range(len(dsc)) :
569 self.failUnless(numpy.all( dsc[k]('input')[0] == ds[k]('input')[0] ) , (dsc[k],ds[k]) ) 662 self.failUnless(numpy.all( dsc[k]('input')[0] == ds[k]('input')[0] ) , (dsc[k],ds[k]) )
570 res = dsc[:] 663 res = dsc[:]
571 664
572 if __name__=='__main__': 665 if __name__=='__main__':
573 if len(sys.argv)==2: 666 tests = []
574 if sys.argv[1]=="--debug": 667 debug=False
668 if len(sys.argv)==1:
669 unittest.main()
670 else:
671 assert sys.argv[1]=="--debug"
672 for arg in sys.argv[2:]:
673 tests.append(arg)
674 if tests:
675 unittest.TestSuite(map(T_DataSet, tests)).debug()
676 else:
575 module = __import__("_test_dataset") 677 module = __import__("_test_dataset")
576 tests = unittest.TestLoader().loadTestsFromModule(module) 678 tests = unittest.TestLoader().loadTestsFromModule(module)
577 tests.debug() 679 tests.debug()
578 print "bad argument: only --debug is accepted"
579 elif len(sys.argv)==1:
580 unittest.main()
581 else:
582 print "bad argument: only --debug is accepted"
583