Mercurial > pylearn
comparison test_dataset.py @ 91:eee739fefdff
corrected test from discussion about the syntax with Yoshua
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Mon, 05 May 2008 17:13:53 -0400 |
parents | 05dc4804357b |
children | 9c8f3c9c247b |
comparison
equal
deleted
inserted
replaced
90:a289b8bed64c | 91:eee739fefdff |
---|---|
39 #test missing value | 39 #test missing value |
40 | 40 |
41 print "test_ArrayDataSet" | 41 print "test_ArrayDataSet" |
42 a = numpy.random.rand(10,4) | 42 a = numpy.random.rand(10,4) |
43 ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested | 43 ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested |
44 | 44 ds = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested |
45 assert len(ds)==10 | 45 assert len(ds)==10 |
46 #assert ds==a? should this work? | 46 #assert ds==a? should this work? |
47 | 47 |
48 #not in doc!!! | 48 #not in doc!!! |
49 for example in range(len(ds)): | 49 for example in range(len(ds)): |
50 assert ds[example]['x'].all()==a[example][:2].all() | 50 assert ds[example]['x'].all()==a[example][:2].all() |
51 assert ds[example]['y']==a[example][3] | 51 assert ds[example]['y']==a[example][3] |
52 assert ds[example]['z'].all()==a[example][0:3:2].all() | 52 assert ds[example]['z'].all()==a[example][0:3:2].all() |
53 # - for example in dataset:: | 53 |
54 # - for example in dataset: | |
54 i=0 | 55 i=0 |
55 for example in ds: | 56 for example in ds: |
56 assert example['x'].all()==a[i][:2].all() | 57 assert example['x'].all()==a[i][:2].all() |
57 assert example['y']==a[i][3] | 58 assert example['y']==a[i][3] |
58 assert example['z'].all()==a[i][0:3:2].all() | 59 assert example['z'].all()==a[i][0:3:2].all() |
59 assert numpy.append(example['x'],example['y']).all()==a[i].all() | 60 assert numpy.append(example['x'],example['y']).all()==a[i].all() |
60 i+=1 | 61 i+=1 |
61 assert i==len(ds) | 62 assert i==len(ds) |
63 # - for val1,val2,... in dataset: | |
64 i=0 | |
65 for x,y,z in ds: | |
66 assert x.all()==a[i][:2].all() | |
67 assert y==a[i][3] | |
68 assert z.all()==a[i][0:3:2].all() | |
69 assert numpy.append(example['x'],example['y']).all()==a[i].all() | |
70 i+=1 | |
71 assert i==len(ds) | |
72 # - for example in dataset(field1, field2,field3, ...): | |
73 i=0 | |
74 for example in ds('x','y','z'): | |
75 assert example['x'].all()==a[i][:2].all() | |
76 assert example['y']==a[i][3] | |
77 assert example['z'].all()==a[i][0:3:2].all() | |
78 assert numpy.append(example['x'],example['y']).all()==a[i].all() | |
79 i+=1 | |
80 assert i==len(ds) | |
81 | |
82 # - for val1,val2,val3 in dataset(field1, field2,field3): | |
83 | |
84 # - for example in dataset(field1, field2,field3, ...): | |
62 | 85 |
63 def test_ds_iterator(iterator1,iterator2,iterator3): | 86 def test_ds_iterator(iterator1,iterator2,iterator3): |
64 i=0 | 87 i=0 |
65 for x,y in iterator1: | 88 for x,y in iterator1: |
66 assert x.all()==a[i][:2].all() | 89 assert x.all()==a[i][:2].all() |
84 assert i==len(ds) | 107 assert i==len(ds) |
85 | 108 |
86 #not in doc!!! - for val1,val2,val3 in dataset(field1, field2,field3): | 109 #not in doc!!! - for val1,val2,val3 in dataset(field1, field2,field3): |
87 test_ds_iterator(ds('x','y'),ds('y','z'),ds('x','y','z')) | 110 test_ds_iterator(ds('x','y'),ds('y','z'),ds('x','y','z')) |
88 | 111 |
89 #not in doc!!! - for val1,val2,val3 in dataset((field1, field2,field3)): | |
90 test_ds_iterator(ds(('x','y')),ds(('y','z')),ds(('x','y','z'))) | |
91 | |
92 # - for val1,val2,val3 in dataset([field1, field2,field3]): #was bugged | |
93 test_ds_iterator(ds(['x','y']),ds(['y','z']),ds(['x','y','z'])) | |
94 | |
95 # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): | 112 # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): |
96 for minibatch in ds.minibatches(['x','z'], minibatch_size=3): | 113 for minibatch in ds.minibatches(['x','z'], minibatch_size=3): |
97 assert len(minibatch)==2 | 114 assert len(minibatch)==2 |
98 assert len(minibatch[0])==3 | 115 assert len(minibatch[0])==3 |
99 assert len(minibatch[1])==3 | 116 assert len(minibatch[1])==3 |