comparison dataset.py @ 242:ef70a665aaaf

Hmm... I think that was committed by Fred; I got confused by Mercurial.
author delallea@opale.iro.umontreal.ca
date Fri, 30 May 2008 10:19:16 -0400
parents ddb88a8e9fd2 ae1d85aca858
children c8f19a9eb10f
comparison
equal deleted inserted replaced
241:ddb88a8e9fd2 242:ef70a665aaaf
985 985
986 # check consistency and complete slices definitions 986 # check consistency and complete slices definitions
987 for fieldname, fieldcolumns in self.fields_columns.items(): 987 for fieldname, fieldcolumns in self.fields_columns.items():
988 if type(fieldcolumns) is int: 988 if type(fieldcolumns) is int:
989 assert fieldcolumns>=0 and fieldcolumns<data_array.shape[1] 989 assert fieldcolumns>=0 and fieldcolumns<data_array.shape[1]
990 if 0: 990 if 1:
991 #I changed this because it didn't make sense to me, 991 #I changed this because it didn't make sense to me,
992 # and it made it more difficult to write my learner. 992 # and it made it more difficult to write my learner.
993 # If it breaks stuff, let's talk about it. 993 # If it breaks stuff, let's talk about it.
994 # - James 22/05/2008 994 # - James 22/05/2008
995 self.fields_columns[fieldname]=[fieldcolumns] 995 self.fields_columns[fieldname]=[fieldcolumns]
1052 self.dataset=dataset 1052 self.dataset=dataset
1053 self.minibatch_size=minibatch_size 1053 self.minibatch_size=minibatch_size
1054 assert offset>=0 and offset<len(dataset.data) 1054 assert offset>=0 and offset<len(dataset.data)
1055 assert offset+minibatch_size<=len(dataset.data) 1055 assert offset+minibatch_size<=len(dataset.data)
1056 self.current=offset 1056 self.current=offset
1057 self.columns = [self.dataset.fields_columns[f]
1058 for f in self.minibatch._names]
1057 def __iter__(self): 1059 def __iter__(self):
1058 return self 1060 return self
1059 def next(self): 1061 def next(self):
1060 #@todo: we suppose that we need to stop only when minibatch_size == 1. 1062 #@todo: we suppose that we need to stop only when minibatch_size == 1.
1061 # Otherwise, MinibatchWrapAroundIterator do it. 1063 # Otherwise, MinibatchWrapAroundIterator do it.
1062 if self.current>=self.dataset.data.shape[0]: 1064 if self.current>=self.dataset.data.shape[0]:
1063 raise StopIteration 1065 raise StopIteration
1064 sub_data = self.dataset.data[self.current] 1066 sub_data = self.dataset.data[self.current]
1065 self.minibatch._values = [sub_data[self.dataset.fields_columns[f]] for f in self.minibatch._names] 1067 self.minibatch._values = [sub_data[c] for c in self.columns]
1068
1066 self.current+=self.minibatch_size 1069 self.current+=self.minibatch_size
1067 return self.minibatch 1070 return self.minibatch
1068 1071
1069 return ArrayDataSetIterator2(self,self.fieldNames(),1,0,0) 1072 return ArrayDataSetIterator2(self,self.fieldNames(),1,0,0)
1070 1073