Mercurial > pylearn
comparison test_dataset.py @ 89:05dc4804357b
more test and refactoring
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Mon, 05 May 2008 16:54:16 -0400 |
parents | 3fd6879e0f76 |
children | eee739fefdff |
comparison
equal
deleted
inserted
replaced
88:6749d18e11c8 | 89:05dc4804357b |
---|---|
42 a = numpy.random.rand(10,4) | 42 a = numpy.random.rand(10,4) |
43 ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested | 43 ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested |
44 | 44 |
45 assert len(ds)==10 | 45 assert len(ds)==10 |
46 #assert ds==a? should this work? | 46 #assert ds==a? should this work? |
47 for i in range(len(ds)): | 47 |
48 assert ds[i]['x'].all()==a[i][:2].all() | 48 #not in doc!!! |
49 assert ds[i]['y']==a[i][3] | 49 for example in range(len(ds)): |
50 assert ds[i]['z'].all()==a[i][0:3:2].all() | 50 assert ds[example]['x'].all()==a[example][:2].all() |
51 assert ds[example]['y']==a[example][3] | |
52 assert ds[example]['z'].all()==a[example][0:3:2].all() | |
53 # - for example in dataset:: | |
51 i=0 | 54 i=0 |
52 for x in ds('x','y'): | 55 for example in ds: |
53 assert numpy.append(x['x'],x['y']).all()==a[i].all() | 56 assert example['x'].all()==a[i][:2].all() |
57 assert example['y']==a[i][3] | |
58 assert example['z'].all()==a[i][0:3:2].all() | |
59 assert numpy.append(example['x'],example['y']).all()==a[i].all() | |
54 i+=1 | 60 i+=1 |
55 | 61 assert i==len(ds) |
62 | |
63 def test_ds_iterator(iterator1,iterator2,iterator3): | |
64 i=0 | |
65 for x,y in iterator1: | |
66 assert x.all()==a[i][:2].all() | |
67 assert y==a[i][3] | |
68 assert numpy.append(x,y).all()==a[i].all() | |
69 i+=1 | |
70 assert i==len(ds) | |
71 i=0 | |
72 for y,z in iterator2: | |
73 assert y==a[i][3] | |
74 assert z.all()==a[i][0:3:2].all() | |
75 i+=1 | |
76 assert i==len(ds) | |
77 i=0 | |
78 for x,y,z in iterator3: | |
79 assert x.all()==a[i][:2].all() | |
80 assert y==a[i][3] | |
81 assert z.all()==a[i][0:3:2].all() | |
82 assert numpy.append(x,y).all()==a[i].all() | |
83 i+=1 | |
84 assert i==len(ds) | |
85 | |
86 #not in doc!!! - for val1,val2,val3 in dataset(field1, field2,field3): | |
87 test_ds_iterator(ds('x','y'),ds('y','z'),ds('x','y','z')) | |
88 | |
89 #not in doc!!! - for val1,val2,val3 in dataset((field1, field2,field3)): | |
90 test_ds_iterator(ds(('x','y')),ds(('y','z')),ds(('x','y','z'))) | |
91 | |
92 # - for val1,val2,val3 in dataset([field1, field2,field3]): #was bugged | |
93 test_ds_iterator(ds(['x','y']),ds(['y','z']),ds(['x','y','z'])) | |
94 | |
95 # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): | |
96 for minibatch in ds.minibatches(['x','z'], minibatch_size=3): | |
97 assert len(minibatch)==2 | |
98 assert len(minibatch[0])==3 | |
99 assert len(minibatch[1])==3 | |
100 assert minibatch[0][:,0:3:2].all()==minibatch[1].all() | |
56 i=0 | 101 i=0 |
57 for x,y in ds('x','y'): | 102 for minibatch in ds.minibatches(['x','y'], minibatch_size=3): |
58 assert numpy.append(x,y).all()==a[i].all() | 103 assert len(minibatch)==2 |
59 i+=1 | 104 assert len(minibatch[0])==3 |
60 for minibatch in ds.minibatches(['x','z'], minibatch_size=3): | 105 assert len(minibatch[1])==3 |
61 assert minibatch[0][:,0:3:2].all()==minibatch[1].all() | 106 for id in range(3): |
107 assert numpy.append(minibatch[0][id],minibatch[1][id]).all()==a[i].all() | |
108 i+=1 | |
109 | |
110 # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): | |
62 for x,z in ds.minibatches(['x','z'], minibatch_size=3): | 111 for x,z in ds.minibatches(['x','z'], minibatch_size=3): |
112 assert len(x)==3 | |
113 assert len(z)==3 | |
63 assert x[:,0:3:2].all()==z.all() | 114 assert x[:,0:3:2].all()==z.all() |
115 i=0 | |
116 for x,y in ds.minibatches(['x','y'], minibatch_size=3): | |
117 assert len(x)==3 | |
118 assert len(y)==3 | |
119 for id in range(3): | |
120 assert numpy.append(minibatch[0][id],minibatch[1][id]).all()==a[i].all() | |
121 i+=1 | |
122 # - for x,y,z in dataset: # fail x,y,z order not fixed as it is a dict. | |
123 # for x,y,z in ds: | |
124 # assert x.all()==a[i][:2].all() | |
125 # assert y==a[i][3] | |
126 # assert z.all()==a[i][0:3:2].all() | |
127 # assert numpy.append(x,y).all()==a[i].all() | |
128 # i+=1 | |
64 | 129 |
65 # for minibatch in ds.minibatches(['z','y'], minibatch_size=3): | 130 # for minibatch in ds.minibatches(['z','y'], minibatch_size=3): |
66 # print minibatch | 131 # print minibatch |
67 # minibatch_iterator = ds.minibatches(fieldnames=['z','y'],n_batches=1,minibatch_size=3,offset=4) | 132 # minibatch_iterator = ds.minibatches(fieldnames=['z','y'],n_batches=1,minibatch_size=3,offset=4) |
68 # minibatch = minibatch_iterator.__iter__().next() | 133 # minibatch = minibatch_iterator.__iter__().next() |
97 assert orig[index[i]]['y']==y | 162 assert orig[index[i]]['y']==y |
98 assert orig[index[i]]['z'].all()==a[index[i]][0:3:2].all() | 163 assert orig[index[i]]['z'].all()==a[index[i]][0:3:2].all() |
99 assert orig[index[i]]['z'].all()==z.all() | 164 assert orig[index[i]]['z'].all()==z.all() |
100 i+=1 | 165 i+=1 |
101 del i | 166 del i |
167 ds[0] | |
168 if len(ds)>2: | |
169 ds[:1] | |
170 ds[1:1] | |
171 ds[1:1:1] | |
172 if len(ds)>5: | |
173 ds[[1,2,3]] | |
174 for x in ds: | |
175 pass | |
102 | 176 |
103 #ds[:n] returns a dataset with the n first examples. | 177 #ds[:n] returns a dataset with the n first examples. |
104 ds2=ds[:3] | 178 ds2=ds[:3] |
105 test_ds(ds,ds2,index=[0,1,2]) | 179 test_ds(ds,ds2,index=[0,1,2]) |
106 | 180 |
107 #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s. | 181 #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s. |
108 ds2=ds[1:7:2] | 182 ds2=ds[1:7:2] |
109 ds2[1] | |
110 test_ds(ds,ds2,[1,3,5]) | 183 test_ds(ds,ds2,[1,3,5]) |
184 | |
111 #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in. | 185 #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in. |
112 # ds2=ds[[4,7,2,8]]# fail??? | 186 ds2=ds[[4,7,2,8]] |
113 # assert len(ds2)==4 | 187 test_ds(ds,ds2,[4,7,2,8]) |
114 # i=0 | |
115 # index=[4,7,2,8] | |
116 # for x in ds2: | |
117 # assert ds[index[i]]['x'].all()==a[index[i]][:3].all() | |
118 # assert ds[index[i]]['x'].all()==x.all() | |
119 # assert ds[index[i]]['y']==a[index[i]][3] | |
120 # assert ds[index[i]]['y']==y | |
121 # assert ds[index[i]]['z'].all()==a[index[i]][0:3:2].all() | |
122 # assert ds[index[i]]['z'].all()==z.all() | |
123 # i+=1 | |
124 #ds[i1,i2,...]# should we accept???? | 188 #ds[i1,i2,...]# should we accept???? |
189 | |
125 #ds[fieldname]# an iterable over the values of the field fieldname across | 190 #ds[fieldname]# an iterable over the values of the field fieldname across |
126 #the ds (the iterable is obtained by default by calling valuesVStack | 191 #the ds (the iterable is obtained by default by calling valuesVStack |
127 #over the values for individual examples). | 192 #over the values for individual examples). |
128 | 193 |
129 #ds.<property># returns the value of a property associated with | 194 #ds.<property># returns the value of a property associated with |
181 | 246 |
182 test1() | 247 test1() |
183 test_LookupList() | 248 test_LookupList() |
184 test_ArrayDataSet() | 249 test_ArrayDataSet() |
185 | 250 |
186 | |
187 |