comparison test_dataset.py @ 89:05dc4804357b

more test and refactoring
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Mon, 05 May 2008 16:54:16 -0400
parents 3fd6879e0f76
children eee739fefdff
comparison
equal deleted inserted replaced
88:6749d18e11c8 89:05dc4804357b
42 a = numpy.random.rand(10,4) 42 a = numpy.random.rand(10,4)
43 ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested 43 ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested
44 44
45 assert len(ds)==10 45 assert len(ds)==10
46 #assert ds==a? should this work? 46 #assert ds==a? should this work?
47 for i in range(len(ds)): 47
48 assert ds[i]['x'].all()==a[i][:2].all() 48 #not in doc!!!
49 assert ds[i]['y']==a[i][3] 49 for example in range(len(ds)):
50 assert ds[i]['z'].all()==a[i][0:3:2].all() 50 assert ds[example]['x'].all()==a[example][:2].all()
51 assert ds[example]['y']==a[example][3]
52 assert ds[example]['z'].all()==a[example][0:3:2].all()
53 # - for example in dataset::
51 i=0 54 i=0
52 for x in ds('x','y'): 55 for example in ds:
53 assert numpy.append(x['x'],x['y']).all()==a[i].all() 56 assert example['x'].all()==a[i][:2].all()
57 assert example['y']==a[i][3]
58 assert example['z'].all()==a[i][0:3:2].all()
59 assert numpy.append(example['x'],example['y']).all()==a[i].all()
54 i+=1 60 i+=1
55 61 assert i==len(ds)
62
63 def test_ds_iterator(iterator1,iterator2,iterator3):
64 i=0
65 for x,y in iterator1:
66 assert x.all()==a[i][:2].all()
67 assert y==a[i][3]
68 assert numpy.append(x,y).all()==a[i].all()
69 i+=1
70 assert i==len(ds)
71 i=0
72 for y,z in iterator2:
73 assert y==a[i][3]
74 assert z.all()==a[i][0:3:2].all()
75 i+=1
76 assert i==len(ds)
77 i=0
78 for x,y,z in iterator3:
79 assert x.all()==a[i][:2].all()
80 assert y==a[i][3]
81 assert z.all()==a[i][0:3:2].all()
82 assert numpy.append(x,y).all()==a[i].all()
83 i+=1
84 assert i==len(ds)
85
86 #not in doc!!! - for val1,val2,val3 in dataset(field1, field2,field3):
87 test_ds_iterator(ds('x','y'),ds('y','z'),ds('x','y','z'))
88
89 #not in doc!!! - for val1,val2,val3 in dataset((field1, field2,field3)):
90 test_ds_iterator(ds(('x','y')),ds(('y','z')),ds(('x','y','z')))
91
92 # - for val1,val2,val3 in dataset([field1, field2,field3]): #was bugged
93 test_ds_iterator(ds(['x','y']),ds(['y','z']),ds(['x','y','z']))
94
95 # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N):
96 for minibatch in ds.minibatches(['x','z'], minibatch_size=3):
97 assert len(minibatch)==2
98 assert len(minibatch[0])==3
99 assert len(minibatch[1])==3
100 assert minibatch[0][:,0:3:2].all()==minibatch[1].all()
56 i=0 101 i=0
57 for x,y in ds('x','y'): 102 for minibatch in ds.minibatches(['x','y'], minibatch_size=3):
58 assert numpy.append(x,y).all()==a[i].all() 103 assert len(minibatch)==2
59 i+=1 104 assert len(minibatch[0])==3
60 for minibatch in ds.minibatches(['x','z'], minibatch_size=3): 105 assert len(minibatch[1])==3
61 assert minibatch[0][:,0:3:2].all()==minibatch[1].all() 106 for id in range(3):
107 assert numpy.append(minibatch[0][id],minibatch[1][id]).all()==a[i].all()
108 i+=1
109
110 # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N):
62 for x,z in ds.minibatches(['x','z'], minibatch_size=3): 111 for x,z in ds.minibatches(['x','z'], minibatch_size=3):
112 assert len(x)==3
113 assert len(z)==3
63 assert x[:,0:3:2].all()==z.all() 114 assert x[:,0:3:2].all()==z.all()
115 i=0
116 for x,y in ds.minibatches(['x','y'], minibatch_size=3):
117 assert len(x)==3
118 assert len(y)==3
119 for id in range(3):
120 assert numpy.append(minibatch[0][id],minibatch[1][id]).all()==a[i].all()
121 i+=1
122 # - for x,y,z in dataset: # fail x,y,z order not fixed as it is a dict.
123 # for x,y,z in ds:
124 # assert x.all()==a[i][:2].all()
125 # assert y==a[i][3]
126 # assert z.all()==a[i][0:3:2].all()
127 # assert numpy.append(x,y).all()==a[i].all()
128 # i+=1
64 129
65 # for minibatch in ds.minibatches(['z','y'], minibatch_size=3): 130 # for minibatch in ds.minibatches(['z','y'], minibatch_size=3):
66 # print minibatch 131 # print minibatch
67 # minibatch_iterator = ds.minibatches(fieldnames=['z','y'],n_batches=1,minibatch_size=3,offset=4) 132 # minibatch_iterator = ds.minibatches(fieldnames=['z','y'],n_batches=1,minibatch_size=3,offset=4)
68 # minibatch = minibatch_iterator.__iter__().next() 133 # minibatch = minibatch_iterator.__iter__().next()
97 assert orig[index[i]]['y']==y 162 assert orig[index[i]]['y']==y
98 assert orig[index[i]]['z'].all()==a[index[i]][0:3:2].all() 163 assert orig[index[i]]['z'].all()==a[index[i]][0:3:2].all()
99 assert orig[index[i]]['z'].all()==z.all() 164 assert orig[index[i]]['z'].all()==z.all()
100 i+=1 165 i+=1
101 del i 166 del i
167 ds[0]
168 if len(ds)>2:
169 ds[:1]
170 ds[1:1]
171 ds[1:1:1]
172 if len(ds)>5:
173 ds[[1,2,3]]
174 for x in ds:
175 pass
102 176
103 #ds[:n] returns a dataset with the n first examples. 177 #ds[:n] returns a dataset with the n first examples.
104 ds2=ds[:3] 178 ds2=ds[:3]
105 test_ds(ds,ds2,index=[0,1,2]) 179 test_ds(ds,ds2,index=[0,1,2])
106 180
107 #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s. 181 #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s.
108 ds2=ds[1:7:2] 182 ds2=ds[1:7:2]
109 ds2[1]
110 test_ds(ds,ds2,[1,3,5]) 183 test_ds(ds,ds2,[1,3,5])
184
111 #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in. 185 #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in.
112 # ds2=ds[[4,7,2,8]]# fail??? 186 ds2=ds[[4,7,2,8]]
113 # assert len(ds2)==4 187 test_ds(ds,ds2,[4,7,2,8])
114 # i=0
115 # index=[4,7,2,8]
116 # for x in ds2:
117 # assert ds[index[i]]['x'].all()==a[index[i]][:3].all()
118 # assert ds[index[i]]['x'].all()==x.all()
119 # assert ds[index[i]]['y']==a[index[i]][3]
120 # assert ds[index[i]]['y']==y
121 # assert ds[index[i]]['z'].all()==a[index[i]][0:3:2].all()
122 # assert ds[index[i]]['z'].all()==z.all()
123 # i+=1
124 #ds[i1,i2,...]# should we accept???? 188 #ds[i1,i2,...]# should we accept????
189
125 #ds[fieldname]# an iterable over the values of the field fieldname across 190 #ds[fieldname]# an iterable over the values of the field fieldname across
126 #the ds (the iterable is obtained by default by calling valuesVStack 191 #the ds (the iterable is obtained by default by calling valuesVStack
127 #over the values for individual examples). 192 #over the values for individual examples).
128 193
129 #ds.<property># returns the value of a property associated with 194 #ds.<property># returns the value of a property associated with
181 246
182 test1() 247 test1()
183 test_LookupList() 248 test_LookupList()
184 test_ArrayDataSet() 249 test_ArrayDataSet()
185 250
186
187