Mercurial > pylearn
comparison dataset.py @ 101:a1740a99b81f
By default, when iterating over minibatches without a fixed number of batches, iteration must stop at the end of the dataset. Now a final minibatch is returned at the end even if its size differs from the given minibatch_size.
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Tue, 06 May 2008 16:01:53 -0400 |
parents | a8da709eb6a9 |
children | 8c0a1b11b007 |
comparison
equal
deleted
inserted
replaced
100:574f4db76022 | 101:a1740a99b81f |
---|---|
229 return self.next_row | 229 return self.next_row |
230 | 230 |
231 def next(self): | 231 def next(self): |
232 if self.n_batches and self.n_batches_done==self.n_batches: | 232 if self.n_batches and self.n_batches_done==self.n_batches: |
233 raise StopIteration | 233 raise StopIteration |
234 elif not self.n_batches and self.next_row ==self.L: | |
235 raise StopIteration | |
234 upper = self.next_row+self.minibatch_size | 236 upper = self.next_row+self.minibatch_size |
235 if upper <=self.L: | 237 if upper <=self.L: |
236 minibatch = self.iterator.next() | 238 minibatch = self.iterator.next() |
237 else: | 239 else: |
238 if not self.n_batches: | 240 if not self.n_batches: |
239 raise StopIteration | 241 upper=min(upper, self.L) |
240 # we must concatenate (vstack) the bottom and top parts of our minibatch | 242 # if their is not a fixed number of batch, we continue to the end of the dataset. |
241 # first get the beginning of our minibatch (top of dataset) | 243 # this can create a minibatch that is smaller then the minibatch_size |
242 first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next() | 244 assert (self.L-self.next_row)<=self.minibatch_size |
243 second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next() | 245 minibatch = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next() |
244 minibatch = Example(self.fieldnames, | 246 else: |
245 [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) | 247 # we must concatenate (vstack) the bottom and top parts of our minibatch |
246 for name in self.fieldnames]) | 248 # first get the beginning of our minibatch (top of dataset) |
249 first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next() | |
250 second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next() | |
251 minibatch = Example(self.fieldnames, | |
252 [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) | |
253 for name in self.fieldnames]) | |
247 self.next_row=upper | 254 self.next_row=upper |
248 self.n_batches_done+=1 | 255 self.n_batches_done+=1 |
249 if upper >= self.L and self.n_batches: | 256 if upper >= self.L and self.n_batches: |
250 self.next_row -= self.L | 257 self.next_row -= self.L |
251 ds_nbatches = (self.L-self.next_row)/self.minibatch_size | 258 ds_nbatches = (self.L-self.next_row)/self.minibatch_size |