comparison test_dataset.py @ 102:4537ac630348

Modified the test to accommodate the last change in dataset.py, i.e. minibatches without a fixed number of batches now return an incomplete minibatch at the end, so that iteration stops at the end of the dataset.
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Tue, 06 May 2008 16:03:17 -0400
parents 574f4db76022
children a90d85fef3d4
comparison
equal deleted inserted replaced
101:a1740a99b81f 102:4537ac630348
110 assert (numpy.append(x,y)==array[i]).all() 110 assert (numpy.append(x,y)==array[i]).all()
111 i+=1 111 i+=1
112 assert i==len(ds) 112 assert i==len(ds)
113 del x,y,i 113 del x,y,i
114 114
115 def test_minibatch_size(minibatch,minibatch_size,len_ds,nb_field,nb_iter_finished):
116 ##full minibatch or the last minibatch
117 for idx in range(nb_field):
118 test_minibatch_field_size(minibatch[idx],minibatch_size,len_ds,nb_iter_finished)
119 del idx
120 def test_minibatch_field_size(minibatch_field,minibatch_size,len_ds,nb_iter_finished):
121 assert len(minibatch_field)==minibatch_size or ((nb_iter_finished*minibatch_size+len(minibatch_field))==len_ds and len(minibatch_field)<minibatch_size)
115 122
116 # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): 123 # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N):
117 i=0 124 i=0
118 for minibatch in ds.minibatches(['x','z'], minibatch_size=3): 125 mi=0
126 m=ds.minibatches(['x','z'], minibatch_size=3)
127 for minibatch in m:
119 assert len(minibatch)==2 128 assert len(minibatch)==2
120 assert len(minibatch[0])==3 129 test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi)
121 assert len(minibatch[1])==3
122 assert (minibatch[0][:,0:3:2]==minibatch[1]).all() 130 assert (minibatch[0][:,0:3:2]==minibatch[1]).all()
123 i+=1 131 mi+=1
124 #assert i==#??? What shoud be the value? #option for the rest. 132 i+=len(minibatch[0])
125 print i 133 assert i==len(ds)
126 del minibatch,i 134 assert mi==4
127 i=0 135 del minibatch,i,m,mi
128 for minibatch in ds.minibatches(['x','y'], minibatch_size=3): 136
137 i=0
138 mi=0
139 m=ds.minibatches(['x','y'], minibatch_size=3)
140 for minibatch in m:
129 assert len(minibatch)==2 141 assert len(minibatch)==2
130 assert len(minibatch[0])==3 142 test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi)
131 assert len(minibatch[1])==3 143 mi+=1
132 for id in range(3): 144 for id in range(len(minibatch[0])):
133 assert (numpy.append(minibatch[0][id],minibatch[1][id])==a[i]).all() 145 assert (numpy.append(minibatch[0][id],minibatch[1][id])==a[i]).all()
134 i+=1 146 i+=1
135 #assert i==#??? What shoud be the value? 147 assert i==len(ds)
136 print i 148 assert mi==4
137 del minibatch,i,id 149 del minibatch,i,id,m,mi
138 150
139 # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): 151 # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N):
140 i=0 152 i=0
141 for x,z in ds.minibatches(['x','z'], minibatch_size=3): 153 mi=0
142 assert len(x)==3 154 m=ds.minibatches(['x','z'], minibatch_size=3)
143 assert len(z)==3 155 for x,z in m:
156 test_minibatch_field_size(x,m.minibatch_size,len(ds),mi)
157 test_minibatch_field_size(z,m.minibatch_size,len(ds),mi)
144 assert (x[:,0:3:2]==z).all() 158 assert (x[:,0:3:2]==z).all()
145 i+=1 159 i+=len(x)
146 #assert i==#??? What shoud be the value? 160 mi+=1
147 print i 161 assert i==len(ds)
148 del x,z,i 162 assert mi==4
149 i=0 163 del x,z,i,m,mi
150 for x,y in ds.minibatches(['x','y'], minibatch_size=3): 164 i=0
151 assert len(x)==3 165 mi=0
152 assert len(y)==3 166 m=ds.minibatches(['x','y'], minibatch_size=3)
153 for id in range(3): 167 for x,y in m:
168 test_minibatch_field_size(x,m.minibatch_size,len(ds),mi)
169 test_minibatch_field_size(y,m.minibatch_size,len(ds),mi)
170 mi+=1
171 for id in range(len(x)):
154 assert (numpy.append(x[id],y[id])==a[i]).all() 172 assert (numpy.append(x[id],y[id])==a[i]).all()
155 i+=1 173 i+=1
156 #assert i==#??? What shoud be the value? 174 assert i==len(ds)
157 print i 175 assert mi==4
158 del x,y,i,id 176 del x,y,i,id,m,mi
159 177
160 #not in doc 178 #not in doc
161 i=0 179 i=0
162 for x,y in ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4): 180 for x,y in ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4):
163 assert len(x)==3 181 assert len(x)==3