comparison dataset.py @ 101:a1740a99b81f

By default, when iterating over minibatches without a fixed number of batches, we need to finish at the end of the dataset. Now we return a minibatch at the end even if this minibatch's size != the given minibatch_size.
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Tue, 06 May 2008 16:01:53 -0400
parents a8da709eb6a9
children 8c0a1b11b007
comparison
equal deleted inserted replaced
100:574f4db76022 101:a1740a99b81f
229 return self.next_row 229 return self.next_row
230 230
231 def next(self): 231 def next(self):
232 if self.n_batches and self.n_batches_done==self.n_batches: 232 if self.n_batches and self.n_batches_done==self.n_batches:
233 raise StopIteration 233 raise StopIteration
234 elif not self.n_batches and self.next_row ==self.L:
235 raise StopIteration
234 upper = self.next_row+self.minibatch_size 236 upper = self.next_row+self.minibatch_size
235 if upper <=self.L: 237 if upper <=self.L:
236 minibatch = self.iterator.next() 238 minibatch = self.iterator.next()
237 else: 239 else:
238 if not self.n_batches: 240 if not self.n_batches:
239 raise StopIteration 241 upper=min(upper, self.L)
240 # we must concatenate (vstack) the bottom and top parts of our minibatch 242 # if their is not a fixed number of batch, we continue to the end of the dataset.
241 # first get the beginning of our minibatch (top of dataset) 243 # this can create a minibatch that is smaller then the minibatch_size
242 first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next() 244 assert (self.L-self.next_row)<=self.minibatch_size
243 second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next() 245 minibatch = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next()
244 minibatch = Example(self.fieldnames, 246 else:
245 [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) 247 # we must concatenate (vstack) the bottom and top parts of our minibatch
246 for name in self.fieldnames]) 248 # first get the beginning of our minibatch (top of dataset)
249 first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next()
250 second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next()
251 minibatch = Example(self.fieldnames,
252 [self.dataset.valuesVStack(name,[first_part[name],second_part[name]])
253 for name in self.fieldnames])
247 self.next_row=upper 254 self.next_row=upper
248 self.n_batches_done+=1 255 self.n_batches_done+=1
249 if upper >= self.L and self.n_batches: 256 if upper >= self.L and self.n_batches:
250 self.next_row -= self.L 257 self.next_row -= self.L
251 ds_nbatches = (self.L-self.next_row)/self.minibatch_size 258 ds_nbatches = (self.L-self.next_row)/self.minibatch_size