pylearn: comparison of dataset.py @ changeset 270:1cafd495098c

code cleanup and small optimisation

author   : Frederic Bastien <bastienf@iro.umontreal.ca>
date     : Tue, 03 Jun 2008 16:11:59 -0400
parents  : 8ec867d12428
children : 38e7d90a1218
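In the new revision the per-example iterator returned by __iter__ is renamed from ArrayDataSetIterator2 to ArrayDataSetIteratorIter and loses the minibatch_size/n_batches/offset bookkeeping it never needed: the old call site always constructed it with minibatch_size=1 and offset=0, so the new constructor takes only the dataset and the field names, and next() simply steps by one. The "small optimisation" is reading the row count once into self.l instead of re-evaluating self.dataset.data.shape[0] on every call to next(). The sketch below is a minimal, self-contained illustration of that pattern, not the pylearn code: RowIterator, the plain dict standing in for pylearn's LookupList, and the hand-written fields_columns mapping are assumptions made for the example, and it uses Python 3's __next__ where the 2008 code defines Python 2's next. The full comparison against 256:aef979d5bad9 follows the sketch.

# Illustrative only: a per-row iterator over a 2-D array that reuses one
# result object and caches the row count once, as the new iterator does.
import numpy

class RowIterator:
    """Yield one example at a time from a 2-D array, reusing a single dict."""
    def __init__(self, data, fields_columns, fieldnames=None):
        if fieldnames is None:
            fieldnames = list(fields_columns.keys())
        self.data = data
        self.fieldnames = fieldnames
        self.columns = [fields_columns[f] for f in fieldnames]
        self.minibatch = dict.fromkeys(fieldnames)  # reused on every step
        self.current = 0
        self.l = data.shape[0]   # cache the row count once (the changeset's self.l)

    def __iter__(self):
        return self

    def __next__(self):          # Python 2 spells this next()
        if self.current >= self.l:
            raise StopIteration
        row = self.data[self.current]
        for name, col in zip(self.fieldnames, self.columns):
            self.minibatch[name] = row[col]
        self.current += 1        # advance by one example; no minibatch bookkeeping
        return self.minibatch

# usage: two fields mapped onto columns of a small array
data = numpy.arange(12).reshape(4, 3)
for example in RowIterator(data, {"x": slice(0, 2), "y": 2}):
    print(example["x"], example["y"])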
--- dataset.py  256:aef979d5bad9
+++ dataset.py  270:1cafd495098c
@@ -1044,36 +1044,35 @@
         # else we are trying to access a property of the dataset
         assert key in self.__dict__ # else it means we are trying to access a non-existing property
         return self.__dict__[key]
 
     def __iter__(self):
-        class ArrayDataSetIterator2(object):
-            def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
+        class ArrayDataSetIteratorIter(object):
+            def __init__(self,dataset,fieldnames):
                 if fieldnames is None: fieldnames = dataset.fieldNames()
                 # store the resulting minibatch in a lookup-list of values
                 self.minibatch = LookupList(fieldnames,[0]*len(fieldnames))
                 self.dataset=dataset
-                self.minibatch_size=minibatch_size
-                assert offset>=0 and offset<len(dataset.data)
-                assert offset+minibatch_size<=len(dataset.data)
-                self.current=offset
+                assert 1<=len(dataset.data)
+                self.current=0
                 self.columns = [self.dataset.fields_columns[f]
                                 for f in self.minibatch._names]
+                self.l = self.dataset.data.shape[0]
             def __iter__(self):
                 return self
             def next(self):
                 #@todo: we suppose that we need to stop only when minibatch_size == 1.
                 # Otherwise, MinibatchWrapAroundIterator do it.
-                if self.current>=self.dataset.data.shape[0]:
+                if self.current>=self.l:
                     raise StopIteration
                 sub_data = self.dataset.data[self.current]
                 self.minibatch._values = [sub_data[c] for c in self.columns]
 
-                self.current+=self.minibatch_size
+                self.current+=1
                 return self.minibatch
 
-        return ArrayDataSetIterator2(self,self.fieldNames(),1,0,0)
+        return ArrayDataSetIteratorIter(self,self.fieldNames())
 
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
         class ArrayDataSetIterator(object):
            def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
                if fieldnames is None: fieldnames = dataset.fieldNames()
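The cost difference behind the self.l caching is small but measurable; a hypothetical way to check it (not part of the changeset) is to time the two loop shapes with the standard timeit module:

# Hypothetical micro-benchmark (not part of the changeset): compare reading the
# array length through attribute lookups on every step against a value cached
# once, as the new iterator's self.l does.
import timeit

setup = """
import numpy

class DS(object):
    pass

ds = DS()
ds.data = numpy.zeros((100000, 5))
"""

uncached = """
current = 0
while current < ds.data.shape[0]:
    current += 1
"""

cached = """
current = 0
l = ds.data.shape[0]
while current < l:
    current += 1
"""

print("uncached:", min(timeit.repeat(uncached, setup, number=10)))
print("cached:  ", min(timeit.repeat(cached, setup, number=10)))

On a typical CPython build the cached form saves a couple of attribute lookups and a tuple indexing per iteration, which is the kind of saving the commit message calls a "small optimisation".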