comparison dataset.py @ 43:e92244f30116

Corrected iterator logic errors
author bengioy@grenat.iro.umontreal.ca
date Mon, 28 Apr 2008 11:41:28 -0400
parents 9b68774fcc6b
children 5a85fda9b19b
comparison
equal deleted inserted replaced
42:9b68774fcc6b 43:e92244f30116
195 195
196 def next_index(self): 196 def next_index(self):
197 return self.next_row 197 return self.next_row
198 198
199 def next(self): 199 def next(self):
200 if self.n_batches and self.n_batches_done==self.n_batches 200 if self.n_batches and self.n_batches_done==self.n_batches:
201 raise StopIteration 201 raise StopIteration
202 upper = self.next_row+self.minibatch_size 202 upper = self.next_row+self.minibatch_size
203 if upper <=self.L: 203 if upper <=self.L:
204 minibatch = self.iterator.next() 204 minibatch = self.iterator.next()
205 else: 205 else:
212 minibatch = Example(self.fieldnames, 212 minibatch = Example(self.fieldnames,
213 [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) 213 [self.dataset.valuesVStack(name,[first_part[name],second_part[name]])
214 for name in self.fieldnames]) 214 for name in self.fieldnames])
215 self.next_row=upper 215 self.next_row=upper
216 self.n_batches_done+=1 216 self.n_batches_done+=1
217 if upper >= self.L: 217 if upper >= self.L and self.n_batches:
218 self.next_row -= self.L 218 self.next_row -= self.L
219 return minibatch 219 return minibatch
220 220
221 221
222 minibatches_fieldnames = None 222 minibatches_fieldnames = None
898 def __iter__(self): 898 def __iter__(self):
899 return self 899 return self
900 def next(self): 900 def next(self):
901 sub_data = self.dataset.data[self.current:self.current+self.minibatch_size] 901 sub_data = self.dataset.data[self.current:self.current+self.minibatch_size]
902 self._values = [sub_data[:,self.dataset.fields_columns[f]] for f in self._names] 902 self._values = [sub_data[:,self.dataset.fields_columns[f]] for f in self._names]
903 self.current+=self.minibatch_size
903 return self 904 return self
904 905
905 return Iterator(self,fieldnames,minibatch_size,n_batches,offset) 906 return Iterator(self,fieldnames,minibatch_size,n_batches,offset)
906 907
907 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): 908 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None):