comparison test_dataset.py @ 104:e1a004b21daa

more test
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Tue, 06 May 2008 16:12:37 -0400
parents a90d85fef3d4
children cf9bdb1d9656
comparison
equal deleted inserted replaced
103:a90d85fef3d4 104:e1a004b21daa
122 122
123 # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): 123 # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N):
124 i=0 124 i=0
125 mi=0 125 mi=0
126 m=ds.minibatches(['x','z'], minibatch_size=3) 126 m=ds.minibatches(['x','z'], minibatch_size=3)
127 assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
127 for minibatch in m: 128 for minibatch in m:
128 assert len(minibatch)==2 129 assert len(minibatch)==2
129 test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) 130 test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi)
130 assert (minibatch[0][:,0:3:2]==minibatch[1]).all() 131 assert (minibatch[0][:,0:3:2]==minibatch[1]).all()
131 mi+=1 132 mi+=1
135 del minibatch,i,m,mi 136 del minibatch,i,m,mi
136 137
137 i=0 138 i=0
138 mi=0 139 mi=0
139 m=ds.minibatches(['x','y'], minibatch_size=3) 140 m=ds.minibatches(['x','y'], minibatch_size=3)
141 assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
140 for minibatch in m: 142 for minibatch in m:
141 assert len(minibatch)==2 143 assert len(minibatch)==2
142 test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) 144 test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi)
143 mi+=1 145 mi+=1
144 for id in range(len(minibatch[0])): 146 for id in range(len(minibatch[0])):
150 152
151 # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): 153 # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N):
152 i=0 154 i=0
153 mi=0 155 mi=0
154 m=ds.minibatches(['x','z'], minibatch_size=3) 156 m=ds.minibatches(['x','z'], minibatch_size=3)
157 assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
155 for x,z in m: 158 for x,z in m:
156 test_minibatch_field_size(x,m.minibatch_size,len(ds),mi) 159 test_minibatch_field_size(x,m.minibatch_size,len(ds),mi)
157 test_minibatch_field_size(z,m.minibatch_size,len(ds),mi) 160 test_minibatch_field_size(z,m.minibatch_size,len(ds),mi)
158 assert (x[:,0:3:2]==z).all() 161 assert (x[:,0:3:2]==z).all()
159 i+=len(x) 162 i+=len(x)
175 assert mi==4 178 assert mi==4
176 del x,y,i,id,m,mi 179 del x,y,i,id,m,mi
177 180
178 #not in doc 181 #not in doc
179 i=0 182 i=0
180 for x,y in ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4): 183 m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4)
184 assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
185 for x,y in m:
181 assert len(x)==3 186 assert len(x)==3
182 assert len(y)==3 187 assert len(y)==3
183 for id in range(3): 188 for id in range(3):
184 assert (numpy.append(x[id],y[id])==a[i+4]).all() 189 assert (numpy.append(x[id],y[id])==a[i+4]).all()
185 i+=1 190 i+=1
186 assert i==3 191 assert i==3
187 del x,y,i,id 192 del x,y,i,id,m
188 193
189 i=0 194 i=0
190 for x,y in ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4): 195 m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4)
196 assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
197 for x,y in m:
191 assert len(x)==3 198 assert len(x)==3
192 assert len(y)==3 199 assert len(y)==3
193 for id in range(3): 200 for id in range(3):
194 assert (numpy.append(x[id],y[id])==a[i+4]).all() 201 assert (numpy.append(x[id],y[id])==a[i+4]).all()
195 i+=1 202 i+=1
196 assert i==6 203 assert i==6
197 del x,y,i,id 204 del x,y,i,id,m
198 205
199 i=0 206 i=0
200 m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=3,offset=4) 207 m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=3,offset=4)
208 assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
201 for x,y in m: 209 for x,y in m:
202 assert len(x)==3 210 assert len(x)==3
203 assert len(y)==3 211 assert len(y)==3
204 for id in range(3): 212 for id in range(3):
205 assert (numpy.append(x[id],y[id])==a[(i+4)%a.shape[0]]).all() 213 assert (numpy.append(x[id],y[id])==a[(i+4)%a.shape[0]]).all()