Mercurial > pylearn
comparison dataset.py @ 43:e92244f30116
Corrected iterator logic errors
author | bengioy@grenat.iro.umontreal.ca |
---|---|
date | Mon, 28 Apr 2008 11:41:28 -0400 |
parents | 9b68774fcc6b |
children | 5a85fda9b19b |
comparison
equal
deleted
inserted
replaced
42:9b68774fcc6b | 43:e92244f30116 |
---|---|
195 | 195 |
196 def next_index(self): | 196 def next_index(self): |
197 return self.next_row | 197 return self.next_row |
198 | 198 |
199 def next(self): | 199 def next(self): |
200 if self.n_batches and self.n_batches_done==self.n_batches | 200 if self.n_batches and self.n_batches_done==self.n_batches: |
201 raise StopIteration | 201 raise StopIteration |
202 upper = self.next_row+self.minibatch_size | 202 upper = self.next_row+self.minibatch_size |
203 if upper <=self.L: | 203 if upper <=self.L: |
204 minibatch = self.iterator.next() | 204 minibatch = self.iterator.next() |
205 else: | 205 else: |
212 minibatch = Example(self.fieldnames, | 212 minibatch = Example(self.fieldnames, |
213 [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) | 213 [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) |
214 for name in self.fieldnames]) | 214 for name in self.fieldnames]) |
215 self.next_row=upper | 215 self.next_row=upper |
216 self.n_batches_done+=1 | 216 self.n_batches_done+=1 |
217 if upper >= self.L: | 217 if upper >= self.L and self.n_batches: |
218 self.next_row -= self.L | 218 self.next_row -= self.L |
219 return minibatch | 219 return minibatch |
220 | 220 |
221 | 221 |
222 minibatches_fieldnames = None | 222 minibatches_fieldnames = None |
898 def __iter__(self): | 898 def __iter__(self): |
899 return self | 899 return self |
900 def next(self): | 900 def next(self): |
901 sub_data = self.dataset.data[self.current:self.current+self.minibatch_size] | 901 sub_data = self.dataset.data[self.current:self.current+self.minibatch_size] |
902 self._values = [sub_data[:,self.dataset.fields_columns[f]] for f in self._names] | 902 self._values = [sub_data[:,self.dataset.fields_columns[f]] for f in self._names] |
903 self.current+=self.minibatch_size | |
903 return self | 904 return self |
904 | 905 |
905 return Iterator(self,fieldnames,minibatch_size,n_batches,offset) | 906 return Iterator(self,fieldnames,minibatch_size,n_batches,offset) |
906 | 907 |
907 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): | 908 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): |