comparison dataset.py @ 273:fa8abc813bd2
Automated merge with ssh://projects@lgcm.iro.umontreal.ca/hg/pylearn
author    Frederic Bastien <bastienf@iro.umontreal.ca>
date      Thu, 05 Jun 2008 11:47:44 -0400
parents   fdce496c3b56 6226ebafefc3
children  ed70580f2324
--- dataset.py  269:fdce496c3b56
+++ dataset.py  273:fa8abc813bd2
@@ -1049,36 +1049,34 @@
         # else we are trying to access a property of the dataset
         assert key in self.__dict__ # else it means we are trying to access a non-existing property
         return self.__dict__[key]
 
     def __iter__(self):
-        class ArrayDataSetIterator2(object):
-            def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
+        class ArrayDataSetIteratorIter(object):
+            def __init__(self,dataset,fieldnames):
                 if fieldnames is None: fieldnames = dataset.fieldNames()
                 # store the resulting minibatch in a lookup-list of values
                 self.minibatch = LookupList(fieldnames,[0]*len(fieldnames))
                 self.dataset=dataset
-                self.minibatch_size=minibatch_size
-                assert offset>=0 and offset<len(dataset.data)
-                assert offset+minibatch_size<=len(dataset.data)
-                self.current=offset
+                self.current=0
                 self.columns = [self.dataset.fields_columns[f]
                                 for f in self.minibatch._names]
+                self.l = self.dataset.data.shape[0]
             def __iter__(self):
                 return self
             def next(self):
                 #@todo: we suppose that we need to stop only when minibatch_size == 1.
                 # Otherwise, MinibatchWrapAroundIterator do it.
-                if self.current>=self.dataset.data.shape[0]:
+                if self.current>=self.l:
                     raise StopIteration
                 sub_data = self.dataset.data[self.current]
                 self.minibatch._values = [sub_data[c] for c in self.columns]
 
-                self.current+=self.minibatch_size
+                self.current+=1
                 return self.minibatch
 
-        return ArrayDataSetIterator2(self,self.fieldNames(),1,0,0)
+        return ArrayDataSetIteratorIter(self,self.fieldNames())
 
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
         class ArrayDataSetIterator(object):
            def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
                if fieldnames is None: fieldnames = dataset.fieldNames()
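The change above replaces the minibatch-based ArrayDataSetIterator2 in __iter__ with ArrayDataSetIteratorIter, which walks the underlying data array one example at a time and caches its length in self.l. As a rough illustration of that iteration pattern only, here is a minimal, self-contained sketch: SingleExampleIterator, its constructor arguments, and the plain-dict return value are invented for the example and stand in for pylearn's actual ArrayDataSet / LookupList API.

    import numpy

    class SingleExampleIterator(object):
        """Hypothetical stand-in for ArrayDataSetIteratorIter: walk a 2-D data
        array one row at a time and expose the requested columns per field."""
        def __init__(self, data, fields_columns, fieldnames):
            self.data = data
            self.fieldnames = fieldnames
            self.columns = [fields_columns[f] for f in fieldnames]
            self.current = 0
            self.l = data.shape[0]      # cache the length once, as the new code does
        def __iter__(self):
            return self
        def __next__(self):             # 'next' in the Python 2-era original
            if self.current >= self.l:
                raise StopIteration
            row = self.data[self.current]
            self.current += 1           # advance one example per step
            return dict(zip(self.fieldnames, (row[c] for c in self.columns)))

    # Usage: 4 examples, field 'x' spans columns 0-1 and 'y' is column 2.
    data = numpy.arange(12).reshape(4, 3)
    for example in SingleExampleIterator(data, {'x': slice(0, 2), 'y': 2}, ['x', 'y']):
        print(example['x'], example['y'])

Unlike the sketch, the real iterator re-uses a single LookupList minibatch object and overwrites its _values on every step, so callers must copy the values if they need to keep them across iterations.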