comparison dataset.py @ 270:1cafd495098c

code cleanup and small optimisation
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Tue, 03 Jun 2008 16:11:59 -0400
parents 8ec867d12428
children 38e7d90a1218
comparison
equal deleted inserted replaced
256:aef979d5bad9 270:1cafd495098c
1044 # else we are trying to access a property of the dataset 1044 # else we are trying to access a property of the dataset
1045 assert key in self.__dict__ # else it means we are trying to access a non-existing property 1045 assert key in self.__dict__ # else it means we are trying to access a non-existing property
1046 return self.__dict__[key] 1046 return self.__dict__[key]
1047 1047
1048 def __iter__(self): 1048 def __iter__(self):
1049 class ArrayDataSetIterator2(object): 1049 class ArrayDataSetIteratorIter(object):
1050 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): 1050 def __init__(self,dataset,fieldnames):
1051 if fieldnames is None: fieldnames = dataset.fieldNames() 1051 if fieldnames is None: fieldnames = dataset.fieldNames()
1052 # store the resulting minibatch in a lookup-list of values 1052 # store the resulting minibatch in a lookup-list of values
1053 self.minibatch = LookupList(fieldnames,[0]*len(fieldnames)) 1053 self.minibatch = LookupList(fieldnames,[0]*len(fieldnames))
1054 self.dataset=dataset 1054 self.dataset=dataset
1055 self.minibatch_size=minibatch_size 1055 assert 1<=len(dataset.data)
1056 assert offset>=0 and offset<len(dataset.data) 1056 self.current=0
1057 assert offset+minibatch_size<=len(dataset.data)
1058 self.current=offset
1059 self.columns = [self.dataset.fields_columns[f] 1057 self.columns = [self.dataset.fields_columns[f]
1060 for f in self.minibatch._names] 1058 for f in self.minibatch._names]
1059 self.l = self.dataset.data.shape[0]
1061 def __iter__(self): 1060 def __iter__(self):
1062 return self 1061 return self
1063 def next(self): 1062 def next(self):
1064 #@todo: we suppose that we need to stop only when minibatch_size == 1. 1063 #@todo: we suppose that we need to stop only when minibatch_size == 1.
1065 # Otherwise, MinibatchWrapAroundIterator do it. 1064 # Otherwise, MinibatchWrapAroundIterator do it.
1066 if self.current>=self.dataset.data.shape[0]: 1065 if self.current>=self.l:
1067 raise StopIteration 1066 raise StopIteration
1068 sub_data = self.dataset.data[self.current] 1067 sub_data = self.dataset.data[self.current]
1069 self.minibatch._values = [sub_data[c] for c in self.columns] 1068 self.minibatch._values = [sub_data[c] for c in self.columns]
1070 1069
1071 self.current+=self.minibatch_size 1070 self.current+=1
1072 return self.minibatch 1071 return self.minibatch
1073 1072
1074 return ArrayDataSetIterator2(self,self.fieldNames(),1,0,0) 1073 return ArrayDataSetIteratorIter(self,self.fieldNames())
1075 1074
1076 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): 1075 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
1077 class ArrayDataSetIterator(object): 1076 class ArrayDataSetIterator(object):
1078 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): 1077 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
1079 if fieldnames is None: fieldnames = dataset.fieldNames() 1078 if fieldnames is None: fieldnames = dataset.fieldNames()