comparison dataset.py @ 98:7186e4f502d1

bugfix in DataSet.minibatch to correctly wrap in all corner cases.
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Tue, 06 May 2008 13:50:54 -0400
parents 6fe972a7393c
children a8da709eb6a9
comparison
equal deleted inserted replaced
97:05cfe011ca20 98:7186e4f502d1
207 self.fieldnames=fieldnames 207 self.fieldnames=fieldnames
208 self.minibatch_size=minibatch_size 208 self.minibatch_size=minibatch_size
209 self.n_batches=n_batches 209 self.n_batches=n_batches
210 self.n_batches_done=0 210 self.n_batches_done=0
211 self.next_row=offset 211 self.next_row=offset
212 self.offset=offset
212 self.L=len(dataset) 213 self.L=len(dataset)
213 assert offset+minibatch_size<=self.L 214 assert offset+minibatch_size<=self.L
214 ds_nbatches = (self.L-offset)/minibatch_size 215 ds_nbatches = (self.L-self.next_row)/self.minibatch_size
215 if n_batches is not None: 216 if n_batches is not None:
216 ds_nbatches = max(n_batches,ds_nbatches) 217 ds_nbatches = min(n_batches,ds_nbatches)
217 if fieldnames: 218 if fieldnames:
218 assert dataset.hasFields(*fieldnames) 219 assert dataset.hasFields(*fieldnames)
219 else: 220 else:
220 fieldnames=dataset.fieldNames() 221 self.fieldnames=dataset.fieldNames()
221 self.iterator = dataset.minibatches_nowrap(fieldnames,minibatch_size,ds_nbatches,offset) 222 self.iterator = self.dataset.minibatches_nowrap(self.fieldnames,self.minibatch_size,
223 ds_nbatches,self.next_row)
222 224
223 def __iter__(self): 225 def __iter__(self):
224 return self 226 return self
225 227
226 def next_index(self): 228 def next_index(self):
235 else: 237 else:
236 if not self.n_batches: 238 if not self.n_batches:
237 raise StopIteration 239 raise StopIteration
238 # we must concatenate (vstack) the bottom and top parts of our minibatch 240 # we must concatenate (vstack) the bottom and top parts of our minibatch
239 # first get the beginning of our minibatch (top of dataset) 241 # first get the beginning of our minibatch (top of dataset)
240 first_part = self.dataset.minibatches_nowrap(fieldnames,self.L-self.next_row,1,self.next_row).next() 242 first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next()
241 second_part = self.dataset.minibatches_nowrap(fieldnames,upper-self.L,1,0).next() 243 second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next()
242 minibatch = Example(self.fieldnames, 244 minibatch = Example(self.fieldnames,
243 [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) 245 [self.dataset.valuesVStack(name,[first_part[name],second_part[name]])
244 for name in self.fieldnames]) 246 for name in self.fieldnames])
245 self.next_row=upper 247 self.next_row=upper
246 self.n_batches_done+=1 248 self.n_batches_done+=1
247 if upper >= self.L and self.n_batches: 249 if upper >= self.L and self.n_batches:
248 self.next_row -= self.L 250 self.next_row -= self.L
251 ds_nbatches = (self.L-self.next_row)/self.minibatch_size
252 if self.n_batches is not None:
253 ds_nbatches = min(self.n_batches,ds_nbatches)
254 self.iterator = self.dataset.minibatches_nowrap(self.fieldnames,self.minibatch_size,
255 ds_nbatches,self.next_row)
249 return DataSetFields(MinibatchDataSet(minibatch,self.dataset.valuesVStack, 256 return DataSetFields(MinibatchDataSet(minibatch,self.dataset.valuesVStack,
250 self.dataset.valuesHStack), 257 self.dataset.valuesHStack),
251 minibatch.keys()) 258 minibatch.keys())
252 259
253 260