comparison dataset.py @ 228:6f55e301c687

optimisation of ArrayDataSet
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Fri, 16 May 2008 16:38:07 -0400
parents 80731832c62b
children 38beb81f4e8b
comparison
equal deleted inserted replaced
203:80731832c62b 228:6f55e301c687
1013 return self.data[:,self.fields_columns[key]] 1013 return self.data[:,self.fields_columns[key]]
1014 # else we are trying to access a property of the dataset 1014 # else we are trying to access a property of the dataset
1015 assert key in self.__dict__ # else it means we are trying to access a non-existing property 1015 assert key in self.__dict__ # else it means we are trying to access a non-existing property
1016 return self.__dict__[key] 1016 return self.__dict__[key]
1017 1017
1018 1018 def __iter__(self):
1019 class ArrayDataSetIterator2(object):
1020 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
1021 if fieldnames is None: fieldnames = dataset.fieldNames()
1022 # store the resulting minibatch in a lookup-list of values
1023 self.minibatch = LookupList(fieldnames,[0]*len(fieldnames))
1024 self.dataset=dataset
1025 self.minibatch_size=minibatch_size
1026 assert offset>=0 and offset<len(dataset.data)
1027 assert offset+minibatch_size<=len(dataset.data)
1028 self.current=offset
1029 def __iter__(self):
1030 return self
1031 def next(self):
1032 #@todo: we suppose that we need to stop only when minibatch_size == 1.
1033 # Otherwise, MinibatchWrapAroundIterator do it.
1034 if self.current>=self.dataset.data.shape[0]:
1035 raise StopIteration
1036 sub_data = self.dataset.data[self.current]
1037 self.minibatch._values = [sub_data[self.dataset.fields_columns[f]] for f in self.minibatch._names]
1038 self.current+=self.minibatch_size
1039 return self.minibatch
1040
1041 return ArrayDataSetIterator2(self,self.fieldNames(),1,0,0)
1042
1019 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): 1043 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
1020 class ArrayDataSetIterator(object): 1044 class ArrayDataSetIterator(object):
1021 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): 1045 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
1022 if fieldnames is None: fieldnames = dataset.fieldNames() 1046 if fieldnames is None: fieldnames = dataset.fieldNames()
1023 # store the resulting minibatch in a lookup-list of values 1047 # store the resulting minibatch in a lookup-list of values
1028 assert offset+minibatch_size<=len(dataset.data) 1052 assert offset+minibatch_size<=len(dataset.data)
1029 self.current=offset 1053 self.current=offset
1030 def __iter__(self): 1054 def __iter__(self):
1031 return self 1055 return self
1032 def next(self): 1056 def next(self):
1057 #@todo: we suppose that MinibatchWrapAroundIterator stop the iterator
1033 sub_data = self.dataset.data[self.current:self.current+self.minibatch_size] 1058 sub_data = self.dataset.data[self.current:self.current+self.minibatch_size]
1034 self.minibatch._values = [sub_data[:,self.dataset.fields_columns[f]] for f in self.minibatch._names] 1059 self.minibatch._values = [sub_data[:,self.dataset.fields_columns[f]] for f in self.minibatch._names]
1035 self.current+=self.minibatch_size 1060 self.current+=self.minibatch_size
1036 return self.minibatch 1061 return self.minibatch
1037 1062