comparison dataset.py @ 273:fa8abc813bd2

Automated merge with ssh://projects@lgcm.iro.umontreal.ca/hg/pylearn
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Thu, 05 Jun 2008 11:47:44 -0400
parents fdce496c3b56 6226ebafefc3
children ed70580f2324
comparison
equal deleted inserted replaced
269:fdce496c3b56 273:fa8abc813bd2
1049 # else we are trying to access a property of the dataset 1049 # else we are trying to access a property of the dataset
1050 assert key in self.__dict__ # else it means we are trying to access a non-existing property 1050 assert key in self.__dict__ # else it means we are trying to access a non-existing property
1051 return self.__dict__[key] 1051 return self.__dict__[key]
1052 1052
1053 def __iter__(self): 1053 def __iter__(self):
1054 class ArrayDataSetIterator2(object): 1054 class ArrayDataSetIteratorIter(object):
1055 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): 1055 def __init__(self,dataset,fieldnames):
1056 if fieldnames is None: fieldnames = dataset.fieldNames() 1056 if fieldnames is None: fieldnames = dataset.fieldNames()
1057 # store the resulting minibatch in a lookup-list of values 1057 # store the resulting minibatch in a lookup-list of values
1058 self.minibatch = LookupList(fieldnames,[0]*len(fieldnames)) 1058 self.minibatch = LookupList(fieldnames,[0]*len(fieldnames))
1059 self.dataset=dataset 1059 self.dataset=dataset
1060 self.minibatch_size=minibatch_size 1060 self.current=0
1061 assert offset>=0 and offset<len(dataset.data)
1062 assert offset+minibatch_size<=len(dataset.data)
1063 self.current=offset
1064 self.columns = [self.dataset.fields_columns[f] 1061 self.columns = [self.dataset.fields_columns[f]
1065 for f in self.minibatch._names] 1062 for f in self.minibatch._names]
1063 self.l = self.dataset.data.shape[0]
1066 def __iter__(self): 1064 def __iter__(self):
1067 return self 1065 return self
1068 def next(self): 1066 def next(self):
1069 #@todo: we suppose that we need to stop only when minibatch_size == 1. 1067 #@todo: we suppose that we need to stop only when minibatch_size == 1.
1070 # Otherwise, MinibatchWrapAroundIterator do it. 1068 # Otherwise, MinibatchWrapAroundIterator do it.
1071 if self.current>=self.dataset.data.shape[0]: 1069 if self.current>=self.l:
1072 raise StopIteration 1070 raise StopIteration
1073 sub_data = self.dataset.data[self.current] 1071 sub_data = self.dataset.data[self.current]
1074 self.minibatch._values = [sub_data[c] for c in self.columns] 1072 self.minibatch._values = [sub_data[c] for c in self.columns]
1075 1073
1076 self.current+=self.minibatch_size 1074 self.current+=1
1077 return self.minibatch 1075 return self.minibatch
1078 1076
1079 return ArrayDataSetIterator2(self,self.fieldNames(),1,0,0) 1077 return ArrayDataSetIteratorIter(self,self.fieldNames())
1080 1078
1081 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): 1079 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
1082 class ArrayDataSetIterator(object): 1080 class ArrayDataSetIterator(object):
1083 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): 1081 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
1084 if fieldnames is None: fieldnames = dataset.fieldNames() 1082 if fieldnames is None: fieldnames = dataset.fieldNames()