comparison dataset.py @ 272:6226ebafefc3
Automated merge with ssh://projects@lgcm.iro.umontreal.ca/hg/pylearn
author   | Frederic Bastien <bastienf@iro.umontreal.ca>
date     | Tue, 03 Jun 2008 16:13:42 -0400
parents  | 19b14afe04b7 38e7d90a1218
children | fa8abc813bd2
--- dataset.py	258:19b14afe04b7
+++ dataset.py	272:6226ebafefc3
@@ -1053,36 +1053,34 @@
         # else we are trying to access a property of the dataset
         assert key in self.__dict__ # else it means we are trying to access a non-existing property
         return self.__dict__[key]

     def __iter__(self):
-        class ArrayDataSetIterator2(object):
-            def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
+        class ArrayDataSetIteratorIter(object):
+            def __init__(self,dataset,fieldnames):
                 if fieldnames is None: fieldnames = dataset.fieldNames()
                 # store the resulting minibatch in a lookup-list of values
                 self.minibatch = LookupList(fieldnames,[0]*len(fieldnames))
                 self.dataset=dataset
-                self.minibatch_size=minibatch_size
-                assert offset>=0 and offset<len(dataset.data)
-                assert offset+minibatch_size<=len(dataset.data)
-                self.current=offset
+                self.current=0
                 self.columns = [self.dataset.fields_columns[f]
                                 for f in self.minibatch._names]
+                self.l = self.dataset.data.shape[0]
             def __iter__(self):
                 return self
             def next(self):
                 #@todo: we suppose that we need to stop only when minibatch_size == 1.
                 # Otherwise, MinibatchWrapAroundIterator do it.
-                if self.current>=self.dataset.data.shape[0]:
+                if self.current>=self.l:
                     raise StopIteration
                 sub_data = self.dataset.data[self.current]
                 self.minibatch._values = [sub_data[c] for c in self.columns]

-                self.current+=self.minibatch_size
+                self.current+=1
                 return self.minibatch

-        return ArrayDataSetIterator2(self,self.fieldNames(),1,0,0)
+        return ArrayDataSetIteratorIter(self,self.fieldNames())

     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
         class ArrayDataSetIterator(object):
             def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
                 if fieldnames is None: fieldnames = dataset.fieldNames()
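
The net effect of this hunk is that ArrayDataSet.__iter__ no longer builds a minibatch-style iterator (the old ArrayDataSetIterator2, which took minibatch_size, n_batches and offset) but a plain per-example iterator (ArrayDataSetIteratorIter) that caches the dataset length in self.l and advances one row at a time; minibatch iteration stays in minibatches_nowrap. The sketch below illustrates that per-example pattern in isolation. It is only an assumption-laden toy, not pylearn's API: ToyArrayDataSet, its constructor arguments and the slice-based field layout are hypothetical, and a generator is used in place of the explicit iterator class shown in the diff.

    import numpy

    class ToyArrayDataSet(object):
        """Hypothetical stand-in for an array-backed dataset (not pylearn's ArrayDataSet)."""
        def __init__(self, data, fields_columns):
            self.data = data                      # 2-D array: one row per example
            self.fields_columns = fields_columns  # field name -> column index or slice

        def fieldNames(self):
            return list(self.fields_columns)

        def __iter__(self):
            # Per-example iteration, mirroring ArrayDataSetIteratorIter:
            # cache the length once, then step one row at a time.
            fieldnames = self.fieldNames()
            columns = [self.fields_columns[f] for f in fieldnames]
            n = self.data.shape[0]                # plays the role of self.l in the diff
            for i in range(n):                    # plays the role of self.current += 1
                row = self.data[i]
                yield dict(zip(fieldnames, [row[c] for c in columns]))

    data = numpy.arange(12.).reshape(4, 3)        # 4 examples, 3 columns
    ds = ToyArrayDataSet(data, {"input": slice(0, 2), "target": 2})
    for example in ds:
        print(example["input"], example["target"])

One design point the toy does not reproduce: in the real iterator the same LookupList is created once and only its _values are replaced on each next() call, so successive items alias the same object, whereas the sketch yields a fresh dict per example. That reuse only matters to callers that keep references to previously returned minibatches.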