Mercurial > pylearn
comparison test_dataset.py @ 96:352910e0dbf5
Added tests and some restructuring for future use.
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Tue, 06 May 2008 10:53:21 -0400 |
parents | 9c8f3c9c247b |
children | 574f4db76022 |
comparison
equal
deleted
inserted
replaced
95:6fe972a7393c | 96:352910e0dbf5 |
---|---|
36 #don't test stream | 36 #don't test stream |
37 #tested only with float value | 37 #tested only with float value |
38 #test with y too | 38 #test with y too |
39 #test missing value | 39 #test missing value |
40 | 40 |
41 def test_iterate_over_examples(array,ds): | |
42 #not in doc!!! | |
43 i=0 | |
44 for example in range(len(ds)): | |
45 assert (ds[example]['x']==a[example][:3]).all() | |
46 assert ds[example]['y']==a[example][3] | |
47 assert (ds[example]['z']==a[example][[0,2]]).all() | |
48 i+=1 | |
49 assert i==len(ds) | |
50 del example,i | |
51 | |
52 # - for example in dataset: | |
53 i=0 | |
54 for example in ds: | |
55 assert len(example)==3 | |
56 assert (example['x']==array[i][:3]).all() | |
57 assert example['y']==array[i][3] | |
58 assert (example['z']==array[i][0:3:2]).all() | |
59 assert (numpy.append(example['x'],example['y'])==array[i]).all() | |
60 i+=1 | |
61 assert i==len(ds) | |
62 del example,i | |
63 | |
64 # - for val1,val2,... in dataset: | |
65 i=0 | |
66 for x,y,z in ds: | |
67 assert (x==array[i][:3]).all() | |
68 assert y==array[i][3] | |
69 assert (z==array[i][0:3:2]).all() | |
70 assert (numpy.append(x,y)==array[i]).all() | |
71 i+=1 | |
72 assert i==len(ds) | |
73 del x,y,z,i | |
74 | |
75 # - for example in dataset(field1, field2,field3, ...): | |
76 i=0 | |
77 for example in ds('x','y','z'): | |
78 assert len(example)==3 | |
79 assert (example['x']==array[i][:3]).all() | |
80 assert example['y']==array[i][3] | |
81 assert (example['z']==array[i][0:3:2]).all() | |
82 assert (numpy.append(example['x'],example['y'])==array[i]).all() | |
83 i+=1 | |
84 assert i==len(ds) | |
85 del example,i | |
86 i=0 | |
87 for example in ds('y','x'): | |
88 assert len(example)==2 | |
89 assert (example['x']==array[i][:3]).all() | |
90 assert example['y']==array[i][3] | |
91 assert (numpy.append(example['x'],example['y'])==array[i]).all() | |
92 i+=1 | |
93 assert i==len(ds) | |
94 del example,i | |
95 | |
96 # - for val1,val2,val3 in dataset(field1, field2,field3): | |
97 i=0 | |
98 for x,y,z in ds('x','y','z'): | |
99 assert (x==array[i][:3]).all() | |
100 assert y==array[i][3] | |
101 assert (z==array[i][0:3:2]).all() | |
102 assert (numpy.append(x,y)==array[i]).all() | |
103 i+=1 | |
104 assert i==len(ds) | |
105 del x,y,z,i | |
106 i=0 | |
107 for y,x in ds('y','x',): | |
108 assert (x==array[i][:3]).all() | |
109 assert y==array[i][3] | |
110 assert (numpy.append(x,y)==array[i]).all() | |
111 i+=1 | |
112 assert i==len(ds) | |
113 del x,y,i | |
114 | |
115 | |
116 # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): | |
117 i=0 | |
118 for minibatch in ds.minibatches(['x','z'], minibatch_size=3): | |
119 assert len(minibatch)==2 | |
120 assert len(minibatch[0])==3 | |
121 assert len(minibatch[1])==3 | |
122 assert (minibatch[0][:,0:3:2]==minibatch[1]).all() | |
123 i+=1 | |
124 #assert i==#??? What shoud be the value? | |
125 print i | |
126 del minibatch,i | |
127 i=0 | |
128 for minibatch in ds.minibatches(['x','y'], minibatch_size=3): | |
129 assert len(minibatch)==2 | |
130 assert len(minibatch[0])==3 | |
131 assert len(minibatch[1])==3 | |
132 for id in range(3): | |
133 assert (numpy.append(minibatch[0][id],minibatch[1][id])==a[i]).all() | |
134 i+=1 | |
135 #assert i==#??? What shoud be the value? | |
136 print i | |
137 del minibatch,i,id | |
138 | |
139 # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): | |
140 i=0 | |
141 for x,z in ds.minibatches(['x','z'], minibatch_size=3): | |
142 assert len(x)==3 | |
143 assert len(z)==3 | |
144 assert (x[:,0:3:2]==z).all() | |
145 i+=1 | |
146 #assert i==#??? What shoud be the value? | |
147 print i | |
148 del x,z,i | |
149 i=0 | |
150 for x,y in ds.minibatches(['x','y'], minibatch_size=3): | |
151 assert len(x)==3 | |
152 assert len(y)==3 | |
153 for id in range(3): | |
154 assert (numpy.append(x[id],y[id])==a[i]).all() | |
155 i+=1 | |
156 #assert i==#??? What shoud be the value? | |
157 print i | |
158 del x,y,i,id | |
159 | |
160 #not in doc | |
161 i=0 | |
162 for x,y in ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4): | |
163 assert len(x)==3 | |
164 assert len(y)==3 | |
165 for id in range(3): | |
166 assert (numpy.append(x[id],y[id])==a[i+4]).all() | |
167 i+=1 | |
168 assert i==3 | |
169 del x,y,i,id | |
170 | |
171 i=0 | |
172 for x,y in ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4): | |
173 assert len(x)==3 | |
174 assert len(y)==3 | |
175 for id in range(3): | |
176 assert (numpy.append(x[id],y[id])==a[i+4]).all() | |
177 i+=1 | |
178 assert i==6 | |
179 del x,y,i,id | |
180 | |
181 i=0 | |
182 for x,y in ds.minibatches(['x','y'],n_batches=10,minibatch_size=3,offset=4): | |
183 assert len(x)==3 | |
184 assert len(y)==3 | |
185 for id in range(3): | |
186 assert (numpy.append(x[id],y[id])==a[i+4]).all() | |
187 i+=1 | |
188 assert i==6 | |
189 del x,y,i,id | |
190 | |
191 | |
192 def test_ds_iterator(array,iterator1,iterator2,iterator3): | |
193 i=0 | |
194 for x,y in iterator1: | |
195 assert (x==array[i][:3]).all() | |
196 assert y==array[i][3] | |
197 assert (numpy.append(x,y)==array[i]).all() | |
198 i+=1 | |
199 assert i==len(ds) | |
200 i=0 | |
201 for y,z in iterator2: | |
202 assert y==array[i][3] | |
203 assert (z==array[i][0:3:2]).all() | |
204 i+=1 | |
205 assert i==len(ds) | |
206 i=0 | |
207 for x,y,z in iterator3: | |
208 assert (x==array[i][:3]).all() | |
209 assert y==array[i][3] | |
210 assert (z==array[i][0:3:2]).all() | |
211 assert (numpy.append(x,y)==array[i]).all() | |
212 i+=1 | |
213 assert i==len(ds) | |
214 | |
41 print "test_ArrayDataSet" | 215 print "test_ArrayDataSet" |
42 a = numpy.random.rand(10,4) | 216 a = numpy.random.rand(10,4) |
43 ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested | 217 ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested |
44 ds = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested | 218 ds = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested |
45 assert len(ds)==10 | 219 assert len(ds)==10 |
46 #assert ds==a? should this work? | 220 #assert ds==a? should this work? |
47 | 221 |
48 #not in doc!!! | 222 test_iterate_over_examples(a, ds) |
49 for example in range(len(ds)): | 223 |
50 assert (ds[example]['x']==a[example][:3]).all() | |
51 assert ds[example]['y']==a[example][3] | |
52 assert (ds[example]['z']==a[example][[0,2]]).all() | |
53 | |
54 # - for example in dataset: | |
55 i=0 | |
56 for example in ds: | |
57 assert (example['x']==a[i][:3]).all() | |
58 assert example['y']==a[i][3] | |
59 assert (example['z']==a[i][0:3:2]).all() | |
60 assert (numpy.append(example['x'],example['y'])==a[i]).all() | |
61 i+=1 | |
62 assert i==len(ds) | |
63 # - for val1,val2,... in dataset: | |
64 i=0 | |
65 for x,y,z in ds: | |
66 assert (x==a[i][:3]).all() | |
67 assert y==a[i][3] | |
68 assert (z==a[i][0:3:2]).all() | |
69 assert (numpy.append(x,y)==a[i]).all() | |
70 i+=1 | |
71 assert i==len(ds) | |
72 # - for example in dataset(field1, field2,field3, ...): | |
73 i=0 | |
74 for example in ds('x','y','z'): | |
75 assert (example['x']==a[i][:3]).all() | |
76 assert example['y']==a[i][3] | |
77 assert (example['z']==a[i][0:3:2]).all() | |
78 assert (numpy.append(example['x'],example['y'])==a[i]).all() | |
79 i+=1 | |
80 assert i==len(ds) | |
81 | 224 |
82 # - for val1,val2,val3 in dataset(field1, field2,field3): | 225 # - for val1,val2,val3 in dataset(field1, field2,field3): |
83 | 226 test_ds_iterator(a,ds('x','y'),ds('y','z'),ds('x','y','z')) |
84 # - for example in dataset(field1, field2,field3, ...): | 227 |
85 | |
86 def test_ds_iterator(iterator1,iterator2,iterator3): | |
87 i=0 | |
88 for x,y in iterator1: | |
89 assert (x==a[i][:3]).all() | |
90 assert y==a[i][3] | |
91 assert (numpy.append(x,y)==a[i]).all() | |
92 i+=1 | |
93 assert i==len(ds) | |
94 i=0 | |
95 for y,z in iterator2: | |
96 assert y==a[i][3] | |
97 assert (z==a[i][0:3:2]).all() | |
98 i+=1 | |
99 assert i==len(ds) | |
100 i=0 | |
101 for x,y,z in iterator3: | |
102 assert (x==a[i][:3]).all() | |
103 assert y==a[i][3] | |
104 assert (z==a[i][0:3:2]).all() | |
105 assert (numpy.append(x,y)==a[i]).all() | |
106 i+=1 | |
107 assert i==len(ds) | |
108 | |
109 #not in doc!!! - for val1,val2,val3 in dataset(field1, field2,field3): | |
110 test_ds_iterator(ds('x','y'),ds('y','z'),ds('x','y','z')) | |
111 | |
112 # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): | |
113 for minibatch in ds.minibatches(['x','z'], minibatch_size=3): | |
114 assert len(minibatch)==2 | |
115 assert len(minibatch[0])==3 | |
116 assert len(minibatch[1])==3 | |
117 assert (minibatch[0][:,0:3:2]==minibatch[1]).all() | |
118 i=0 | |
119 for minibatch in ds.minibatches(['x','y'], minibatch_size=3): | |
120 assert len(minibatch)==2 | |
121 assert len(minibatch[0])==3 | |
122 assert len(minibatch[1])==3 | |
123 for id in range(3): | |
124 assert (numpy.append(minibatch[0][id],minibatch[1][id])==a[i]).all() | |
125 i+=1 | |
126 | |
127 # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): | |
128 for x,z in ds.minibatches(['x','z'], minibatch_size=3): | |
129 assert len(x)==3 | |
130 assert len(z)==3 | |
131 assert (x[:,0:3:2]==z).all() | |
132 i=0 | |
133 for x,y in ds.minibatches(['x','y'], minibatch_size=3): | |
134 assert len(x)==3 | |
135 assert len(y)==3 | |
136 for id in range(3): | |
137 assert (numpy.append(x[id],y[id])==a[i]).all() | |
138 i+=1 | |
139 # - for x,y,z in dataset: # fail x,y,z order not fixed as it is a dict. | |
140 # for x,y,z in ds: | |
141 # assert (x==a[i][:2]).all() | |
142 # assert y==a[i][3] | |
143 # assert (z==a[i][0:3:2]).all() | |
144 # assert (numpy.append(x,y)==a[i]).all() | |
145 # i+=1 | |
146 | |
147 # for minibatch in ds.minibatches(['z','y'], minibatch_size=3): | |
148 # print minibatch | |
149 # minibatch_iterator = ds.minibatches(fieldnames=['z','y'],n_batches=1,minibatch_size=3,offset=4) | |
150 # minibatch = minibatch_iterator.__iter__().next() | |
151 # print "minibatch=",minibatch | |
152 # for var in minibatch: | |
153 # print "var=",var | |
154 # print "take a slice and look at field y",ds[1:6:2]["y"] | |
155 assert have_raised("ds['h']") # h is not defined... | 228 assert have_raised("ds['h']") # h is not defined... |
156 assert have_raised("ds[['h']]") # h is not defined... | 229 assert have_raised("ds[['h']]") # h is not defined... |
157 | 230 |
158 assert len(ds.fields())==3 | 231 assert len(ds.fields())==3 |
159 for field in ds.fields(): | 232 for field in ds.fields(): |
213 # - 'description': a textual description or name for the ds | 286 # - 'description': a textual description or name for the ds |
214 # - 'fieldtypes': a list of types (one per field) | 287 # - 'fieldtypes': a list of types (one per field) |
215 #* ds1 | ds2 | ds3 == ds.hstack([ds1,ds2,ds3]) | 288 #* ds1 | ds2 | ds3 == ds.hstack([ds1,ds2,ds3]) |
216 #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3]) | 289 #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3]) |
217 | 290 |
218 # for (x,y) in (ds('x','y'),a): #don't work # haven't found a variant that work. | 291 # for (x,y) in (ds('x','y'),a): #???don't work # haven't found a variant that work. |
219 # assert numpy.append(x,y)==z | 292 # assert numpy.append(x,y)==z |
220 | 293 |
221 def test_LookupList(): | 294 def test_LookupList(): |
222 #test only the example in the doc??? | 295 #test only the example in the doc??? |
223 print "test_LookupList" | 296 print "test_LookupList" |