pylearn: comparison of dataset.py @ 228:6f55e301c687
optimisation of ArrayDataSet
| author | Frederic Bastien <bastienf@iro.umontreal.ca> |
| --- | --- |
| date | Fri, 16 May 2008 16:38:07 -0400 |
| parents | 80731832c62b |
| children | 38beb81f4e8b |
comparison of dataset.py between 203:80731832c62b (parent) and 228:6f55e301c687:

```diff
--- dataset.py	203:80731832c62b
+++ dataset.py	228:6f55e301c687
@@ -1013,11 +1013,35 @@
             return self.data[:,self.fields_columns[key]]
         # else we are trying to access a property of the dataset
         assert key in self.__dict__ # else it means we are trying to access a non-existing property
         return self.__dict__[key]
 
+    def __iter__(self):
+        class ArrayDataSetIterator2(object):
+            def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
+                if fieldnames is None: fieldnames = dataset.fieldNames()
+                # store the resulting minibatch in a lookup-list of values
+                self.minibatch = LookupList(fieldnames,[0]*len(fieldnames))
+                self.dataset=dataset
+                self.minibatch_size=minibatch_size
+                assert offset>=0 and offset<len(dataset.data)
+                assert offset+minibatch_size<=len(dataset.data)
+                self.current=offset
+            def __iter__(self):
+                return self
+            def next(self):
+                #@todo: we suppose that we need to stop only when minibatch_size == 1.
+                # Otherwise, MinibatchWrapAroundIterator do it.
+                if self.current>=self.dataset.data.shape[0]:
+                    raise StopIteration
+                sub_data = self.dataset.data[self.current]
+                self.minibatch._values = [sub_data[self.dataset.fields_columns[f]] for f in self.minibatch._names]
+                self.current+=self.minibatch_size
+                return self.minibatch
+
+        return ArrayDataSetIterator2(self,self.fieldNames(),1,0,0)
 
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
         class ArrayDataSetIterator(object):
             def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
                 if fieldnames is None: fieldnames = dataset.fieldNames()
                 # store the resulting minibatch in a lookup-list of values
```
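The heart of the change is the new `__iter__` above: iterating an `ArrayDataSet` example by example now indexes a single row (`data[self.current]`) with `minibatch_size` fixed to 1, instead of going through the generic minibatch machinery. Below is a minimal, self-contained sketch of that idea in plain NumPy, written in the Python 2 style of the original; `SimpleArrayDataSet`, `RowIterator` and the use of a plain `dict` in place of pylearn's `LookupList` are illustrative assumptions, not pylearn's actual API.

```python
import numpy

class RowIterator(object):
    """Single-example iterator in the spirit of ArrayDataSetIterator2."""
    def __init__(self, data, fields_columns):
        self.data = data
        self.fields_columns = fields_columns
        self.current = 0
    def __iter__(self):
        return self
    def next(self):                       # Python 2 iterator protocol
        if self.current >= self.data.shape[0]:
            raise StopIteration
        row = self.data[self.current]     # index one row; no 1-row slice needed
        self.current += 1
        return dict((f, row[cols]) for f, cols in self.fields_columns.items())

class SimpleArrayDataSet(object):
    def __init__(self, data, fields_columns):
        self.data = data                      # 2-D array, one example per row
        self.fields_columns = fields_columns  # field name -> column index or slice
    def __iter__(self):
        # fast path: iterate single examples directly, as the changeset does
        return RowIterator(self.data, self.fields_columns)
    def minibatches(self, minibatch_size):
        # generic path: slice blocks of rows, then split each block by field
        for start in xrange(0, self.data.shape[0], minibatch_size):
            rows = self.data[start:start + minibatch_size]
            yield dict((f, rows[:, cols]) for f, cols in self.fields_columns.items())

data = numpy.arange(12.).reshape(4, 3)
ds = SimpleArrayDataSet(data, {'x': slice(0, 2), 'y': 2})
for example in ds:                        # single example: example['x'].shape == (2,)
    assert example['x'].shape == (2,)
for batch in ds.minibatches(2):           # minibatch path: batch['x'].shape == (2, 2)
    assert batch['x'].shape == (2, 2)
```

The second hunk of the changeset, shown next, touches the generic minibatch iterator only to document that stopping is handled elsewhere, by MinibatchWrapAroundIterator.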
```diff
@@ -1028,10 +1052,11 @@
                 assert offset+minibatch_size<=len(dataset.data)
                 self.current=offset
             def __iter__(self):
                 return self
             def next(self):
+                #@todo: we suppose that MinibatchWrapAroundIterator stop the iterator
                 sub_data = self.dataset.data[self.current:self.current+self.minibatch_size]
                 self.minibatch._values = [sub_data[:,self.dataset.fields_columns[f]] for f in self.minibatch._names]
                 self.current+=self.minibatch_size
                 return self.minibatch
 
```
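For completeness, here is a hedged sketch of what the new fast path looks like from the caller's side. The `ArrayDataSet` constructor arguments, import path, and `minibatches()` signature below are assumptions inferred from the attributes and parameters visible in the diff (`data`, `fields_columns`, `fieldnames`, `minibatch_size`), not a confirmed pylearn API.

```python
import numpy
from dataset import ArrayDataSet   # import path and constructor are assumptions

# assumed constructor: a 2-D array plus a field-name -> column mapping
data = numpy.random.rand(100, 4)
ds = ArrayDataSet(data, {'input': slice(0, 3), 'target': 3})

# generic path: 1-example minibatches through minibatches(), which presumably
# wraps minibatches_nowrap() in a MinibatchWrapAroundIterator
for batch in ds.minibatches(['input', 'target'], minibatch_size=1):
    pass   # each field holds a (1, n)-shaped slice of the data array

# fast path added by this changeset: plain iteration goes through
# ArrayDataSetIterator2 with minibatch_size fixed to 1
for example in ds:
    pass   # each field holds a single row, e.g. shape (3,) for 'input'
```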