comparison dataset.py @ 231:38beb81f4e8b

Automated merge with ssh://projects@lgcm.iro.umontreal.ca/hg/pylearn
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Tue, 27 May 2008 13:46:03 -0400
parents 17c5d080964b 6f55e301c687
children a70f2c973ea5 ddb88a8e9fd2
comparison
equal deleted inserted replaced
227:17c5d080964b 231:38beb81f4e8b
1041 return self.data[:,self.fields_columns[key]] 1041 return self.data[:,self.fields_columns[key]]
1042 # else we are trying to access a property of the dataset 1042 # else we are trying to access a property of the dataset
1043 assert key in self.__dict__ # else it means we are trying to access a non-existing property 1043 assert key in self.__dict__ # else it means we are trying to access a non-existing property
1044 return self.__dict__[key] 1044 return self.__dict__[key]
1045 1045
def __iter__(self):
    """Return an iterator over the rows of ``self.data``, one example per step.

    Each step yields a LookupList holding, for every field name of the
    dataset, the value(s) of that field in the current row. Iteration
    stops after the last row; a dataset with no rows gives an empty
    iteration.

    NOTE: the same LookupList object is returned at every step and its
    values are overwritten in place when the iterator advances. Copy it
    if an example must outlive the next step.
    """
    class ArrayDataSetIterator2(object):
        def __init__(self, dataset, fieldnames, minibatch_size, n_batches, offset):
            # n_batches mirrors ArrayDataSetIterator's signature; it is not
            # used by this row-by-row iterator.
            if fieldnames is None:
                fieldnames = dataset.fieldNames()
            # store the resulting example in a lookup-list of values
            self.minibatch = LookupList(fieldnames, [0] * len(fieldnames))
            self.dataset = dataset
            self.minibatch_size = minibatch_size
            n_rows = len(dataset.data)
            assert minibatch_size >= 1
            # offset == n_rows is accepted so that iterating an empty dataset
            # (offset 0, zero rows) terminates immediately instead of failing
            # an assertion here.
            assert 0 <= offset <= n_rows
            self.current = offset

        def __iter__(self):
            return self

        def next(self):
            # This iterator does its own stopping: it ends as soon as the
            # cursor reaches or passes the last row of the data.
            if self.current >= len(self.dataset.data):
                raise StopIteration
            row = self.dataset.data[self.current]
            # fields_columns[f] selects the column(s) of field f within a row
            # (presumably an int index or a slice -- confirm in ArrayDataSet).
            self.minibatch._values = [row[self.dataset.fields_columns[f]]
                                      for f in self.minibatch._names]
            self.current += self.minibatch_size
            return self.minibatch

        # Python 3 spelling of the iterator protocol method.
        __next__ = next

    return ArrayDataSetIterator2(self, self.fieldNames(), 1, 0, 0)
1070
1047 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): 1071 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
1048 class ArrayDataSetIterator(object): 1072 class ArrayDataSetIterator(object):
1049 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): 1073 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
1050 if fieldnames is None: fieldnames = dataset.fieldNames() 1074 if fieldnames is None: fieldnames = dataset.fieldNames()
1051 # store the resulting minibatch in a lookup-list of values 1075 # store the resulting minibatch in a lookup-list of values
1056 assert offset+minibatch_size<=len(dataset.data) 1080 assert offset+minibatch_size<=len(dataset.data)
1057 self.current=offset 1081 self.current=offset
1058 def __iter__(self): 1082 def __iter__(self):
1059 return self 1083 return self
1060 def next(self): 1084 def next(self):
1085 #@todo: we suppose that MinibatchWrapAroundIterator stop the iterator
1061 sub_data = self.dataset.data[self.current:self.current+self.minibatch_size] 1086 sub_data = self.dataset.data[self.current:self.current+self.minibatch_size]
1062 self.minibatch._values = [sub_data[:,self.dataset.fields_columns[f]] for f in self.minibatch._names] 1087 self.minibatch._values = [sub_data[:,self.dataset.fields_columns[f]] for f in self.minibatch._names]
1063 self.current+=self.minibatch_size 1088 self.current+=self.minibatch_size
1064 return self.minibatch 1089 return self.minibatch
1065 1090