Mercurial > pylearn
comparison test_dataset.py @ 102:4537ac630348
Modified the test to accommodate the last change in dataset.py,
i.e. minibatches without a fixed number of batches now return an incomplete final minibatch, so that iteration stops at the end of the dataset.
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Tue, 06 May 2008 16:03:17 -0400 |
parents | 574f4db76022 |
children | a90d85fef3d4 |
comparison
equal
deleted
inserted
replaced
101:a1740a99b81f | 102:4537ac630348 |
---|---|
110 assert (numpy.append(x,y)==array[i]).all() | 110 assert (numpy.append(x,y)==array[i]).all() |
111 i+=1 | 111 i+=1 |
112 assert i==len(ds) | 112 assert i==len(ds) |
113 del x,y,i | 113 del x,y,i |
114 | 114 |
| | 115 def test_minibatch_size(minibatch,minibatch_size,len_ds,nb_field,nb_iter_finished): |
| | 116 ##full minibatch or the last minibatch |
| | 117 for idx in range(nb_field): |
| | 118 test_minibatch_field_size(minibatch[idx],minibatch_size,len_ds,nb_iter_finished) |
| | 119 del idx |
| | 120 def test_minibatch_field_size(minibatch_field,minibatch_size,len_ds,nb_iter_finished): |
| | 121 assert len(minibatch_field)==minibatch_size or ((nb_iter_finished*minibatch_size+len(minibatch_field))==len_ds and len(minibatch_field)<minibatch_size) |
115 | 122 |
116 # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): | 123 # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): |
117 i=0 | 124 i=0 |
118 for minibatch in ds.minibatches(['x','z'], minibatch_size=3): | 125 mi=0 |
| | 126 m=ds.minibatches(['x','z'], minibatch_size=3) |
| | 127 for minibatch in m: |
119 assert len(minibatch)==2 | 128 assert len(minibatch)==2 |
120 assert len(minibatch[0])==3 | 129 test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) |
121 assert len(minibatch[1])==3 | |
122 assert (minibatch[0][:,0:3:2]==minibatch[1]).all() | 130 assert (minibatch[0][:,0:3:2]==minibatch[1]).all() |
123 i+=1 | 131 mi+=1 |
124 #assert i==#??? What shoud be the value? #option for the rest. | 132 i+=len(minibatch[0]) |
125 print i | 133 assert i==len(ds) |
126 del minibatch,i | 134 assert mi==4 |
127 i=0 | 135 del minibatch,i,m,mi |
128 for minibatch in ds.minibatches(['x','y'], minibatch_size=3): | 136 |
| | 137 i=0 |
| | 138 mi=0 |
| | 139 m=ds.minibatches(['x','y'], minibatch_size=3) |
| | 140 for minibatch in m: |
129 assert len(minibatch)==2 | 141 assert len(minibatch)==2 |
130 assert len(minibatch[0])==3 | 142 test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) |
131 assert len(minibatch[1])==3 | 143 mi+=1 |
132 for id in range(3): | 144 for id in range(len(minibatch[0])): |
133 assert (numpy.append(minibatch[0][id],minibatch[1][id])==a[i]).all() | 145 assert (numpy.append(minibatch[0][id],minibatch[1][id])==a[i]).all() |
134 i+=1 | 146 i+=1 |
135 #assert i==#??? What shoud be the value? | 147 assert i==len(ds) |
136 print i | 148 assert mi==4 |
137 del minibatch,i,id | 149 del minibatch,i,id,m,mi |
138 | 150 |
139 # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): | 151 # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): |
140 i=0 | 152 i=0 |
141 for x,z in ds.minibatches(['x','z'], minibatch_size=3): | 153 mi=0 |
142 assert len(x)==3 | 154 m=ds.minibatches(['x','z'], minibatch_size=3) |
143 assert len(z)==3 | 155 for x,z in m: |
| | 156 test_minibatch_field_size(x,m.minibatch_size,len(ds),mi) |
| | 157 test_minibatch_field_size(z,m.minibatch_size,len(ds),mi) |
144 assert (x[:,0:3:2]==z).all() | 158 assert (x[:,0:3:2]==z).all() |
145 i+=1 | 159 i+=len(x) |
146 #assert i==#??? What shoud be the value? | 160 mi+=1 |
147 print i | 161 assert i==len(ds) |
148 del x,z,i | 162 assert mi==4 |
149 i=0 | 163 del x,z,i,m,mi |
150 for x,y in ds.minibatches(['x','y'], minibatch_size=3): | 164 i=0 |
151 assert len(x)==3 | 165 mi=0 |
152 assert len(y)==3 | 166 m=ds.minibatches(['x','y'], minibatch_size=3) |
153 for id in range(3): | 167 for x,y in m: |
| | 168 test_minibatch_field_size(x,m.minibatch_size,len(ds),mi) |
| | 169 test_minibatch_field_size(y,m.minibatch_size,len(ds),mi) |
| | 170 mi+=1 |
| | 171 for id in range(len(x)): |
154 assert (numpy.append(x[id],y[id])==a[i]).all() | 172 assert (numpy.append(x[id],y[id])==a[i]).all() |
155 i+=1 | 173 i+=1 |
156 #assert i==#??? What shoud be the value? | 174 assert i==len(ds) |
157 print i | 175 assert mi==4 |
158 del x,y,i,id | 176 del x,y,i,id,m,mi |
159 | 177 |
160 #not in doc | 178 #not in doc |
161 i=0 | 179 i=0 |
162 for x,y in ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4): | 180 for x,y in ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4): |
163 assert len(x)==3 | 181 assert len(x)==3 |