Mercurial > pylearn
comparison dataset.py @ 98:7186e4f502d1
bugfix in DataSet.minibatch to correctly wrap in all corner cases.
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Tue, 06 May 2008 13:50:54 -0400 |
parents | 6fe972a7393c |
children | a8da709eb6a9 |
comparison
equal
deleted
inserted
replaced
97:05cfe011ca20 | 98:7186e4f502d1 |
---|---|
207 self.fieldnames=fieldnames | 207 self.fieldnames=fieldnames |
208 self.minibatch_size=minibatch_size | 208 self.minibatch_size=minibatch_size |
209 self.n_batches=n_batches | 209 self.n_batches=n_batches |
210 self.n_batches_done=0 | 210 self.n_batches_done=0 |
211 self.next_row=offset | 211 self.next_row=offset |
212 self.offset=offset | |
212 self.L=len(dataset) | 213 self.L=len(dataset) |
213 assert offset+minibatch_size<=self.L | 214 assert offset+minibatch_size<=self.L |
214 ds_nbatches = (self.L-offset)/minibatch_size | 215 ds_nbatches = (self.L-self.next_row)/self.minibatch_size |
215 if n_batches is not None: | 216 if n_batches is not None: |
216 ds_nbatches = max(n_batches,ds_nbatches) | 217 ds_nbatches = min(n_batches,ds_nbatches) |
217 if fieldnames: | 218 if fieldnames: |
218 assert dataset.hasFields(*fieldnames) | 219 assert dataset.hasFields(*fieldnames) |
219 else: | 220 else: |
220 fieldnames=dataset.fieldNames() | 221 self.fieldnames=dataset.fieldNames() |
221 self.iterator = dataset.minibatches_nowrap(fieldnames,minibatch_size,ds_nbatches,offset) | 222 self.iterator = self.dataset.minibatches_nowrap(self.fieldnames,self.minibatch_size, |
223 ds_nbatches,self.next_row) | |
222 | 224 |
223 def __iter__(self): | 225 def __iter__(self): |
224 return self | 226 return self |
225 | 227 |
226 def next_index(self): | 228 def next_index(self): |
235 else: | 237 else: |
236 if not self.n_batches: | 238 if not self.n_batches: |
237 raise StopIteration | 239 raise StopIteration |
238 # we must concatenate (vstack) the bottom and top parts of our minibatch | 240 # we must concatenate (vstack) the bottom and top parts of our minibatch |
239 # first get the beginning of our minibatch (top of dataset) | 241 # first get the beginning of our minibatch (top of dataset) |
240 first_part = self.dataset.minibatches_nowrap(fieldnames,self.L-self.next_row,1,self.next_row).next() | 242 first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next() |
241 second_part = self.dataset.minibatches_nowrap(fieldnames,upper-self.L,1,0).next() | 243 second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next() |
242 minibatch = Example(self.fieldnames, | 244 minibatch = Example(self.fieldnames, |
243 [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) | 245 [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) |
244 for name in self.fieldnames]) | 246 for name in self.fieldnames]) |
245 self.next_row=upper | 247 self.next_row=upper |
246 self.n_batches_done+=1 | 248 self.n_batches_done+=1 |
247 if upper >= self.L and self.n_batches: | 249 if upper >= self.L and self.n_batches: |
248 self.next_row -= self.L | 250 self.next_row -= self.L |
251 ds_nbatches = (self.L-self.next_row)/self.minibatch_size | |
252 if self.n_batches is not None: | |
253 ds_nbatches = min(self.n_batches,ds_nbatches) | |
254 self.iterator = self.dataset.minibatches_nowrap(self.fieldnames,self.minibatch_size, | |
255 ds_nbatches,self.next_row) | |
249 return DataSetFields(MinibatchDataSet(minibatch,self.dataset.valuesVStack, | 256 return DataSetFields(MinibatchDataSet(minibatch,self.dataset.valuesVStack, |
250 self.dataset.valuesHStack), | 257 self.dataset.valuesHStack), |
251 minibatch.keys()) | 258 minibatch.keys()) |
252 | 259 |
253 | 260 |