comparison dataset.py @ 272:6226ebafefc3

Automated merge with ssh://projects@lgcm.iro.umontreal.ca/hg/pylearn
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Tue, 03 Jun 2008 16:13:42 -0400
parents 19b14afe04b7 38e7d90a1218
children fa8abc813bd2
comparison
equal deleted inserted replaced
258:19b14afe04b7 272:6226ebafefc3
1053 # else we are trying to access a property of the dataset 1053 # else we are trying to access a property of the dataset
1054 assert key in self.__dict__ # else it means we are trying to access a non-existing property 1054 assert key in self.__dict__ # else it means we are trying to access a non-existing property
1055 return self.__dict__[key] 1055 return self.__dict__[key]
1056 1056
1057 def __iter__(self): 1057 def __iter__(self):
1058 class ArrayDataSetIterator2(object): 1058 class ArrayDataSetIteratorIter(object):
1059 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): 1059 def __init__(self,dataset,fieldnames):
1060 if fieldnames is None: fieldnames = dataset.fieldNames() 1060 if fieldnames is None: fieldnames = dataset.fieldNames()
1061 # store the resulting minibatch in a lookup-list of values 1061 # store the resulting minibatch in a lookup-list of values
1062 self.minibatch = LookupList(fieldnames,[0]*len(fieldnames)) 1062 self.minibatch = LookupList(fieldnames,[0]*len(fieldnames))
1063 self.dataset=dataset 1063 self.dataset=dataset
1064 self.minibatch_size=minibatch_size 1064 self.current=0
1065 assert offset>=0 and offset<len(dataset.data)
1066 assert offset+minibatch_size<=len(dataset.data)
1067 self.current=offset
1068 self.columns = [self.dataset.fields_columns[f] 1065 self.columns = [self.dataset.fields_columns[f]
1069 for f in self.minibatch._names] 1066 for f in self.minibatch._names]
1067 self.l = self.dataset.data.shape[0]
1070 def __iter__(self): 1068 def __iter__(self):
1071 return self 1069 return self
1072 def next(self): 1070 def next(self):
1073 #@todo: we suppose that we need to stop only when minibatch_size == 1. 1071 #@todo: we suppose that we need to stop only when minibatch_size == 1.
1074 # Otherwise, MinibatchWrapAroundIterator do it. 1072 # Otherwise, MinibatchWrapAroundIterator do it.
1075 if self.current>=self.dataset.data.shape[0]: 1073 if self.current>=self.l:
1076 raise StopIteration 1074 raise StopIteration
1077 sub_data = self.dataset.data[self.current] 1075 sub_data = self.dataset.data[self.current]
1078 self.minibatch._values = [sub_data[c] for c in self.columns] 1076 self.minibatch._values = [sub_data[c] for c in self.columns]
1079 1077
1080 self.current+=self.minibatch_size 1078 self.current+=1
1081 return self.minibatch 1079 return self.minibatch
1082 1080
1083 return ArrayDataSetIterator2(self,self.fieldNames(),1,0,0) 1081 return ArrayDataSetIteratorIter(self,self.fieldNames())
1084 1082
1085 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): 1083 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
1086 class ArrayDataSetIterator(object): 1084 class ArrayDataSetIterator(object):
1087 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): 1085 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
1088 if fieldnames is None: fieldnames = dataset.fieldNames() 1086 if fieldnames is None: fieldnames = dataset.fieldNames()