comparison dataset.py @ 322:ad8be93b3c55

small bugs fixed with NArrayDataSet
author Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
date Thu, 12 Jun 2008 12:45:42 -0400
parents f03ae06fadc8
children 09140ba68e17 9ce791fb2cbf
comparison
equal deleted inserted replaced
321:f03ae06fadc8 322:ad8be93b3c55
190 190
191 self.description = default_desc() if description is None \ 191 self.description = default_desc() if description is None \
192 else description 192 else description
193 self._attribute_names = ["description"] 193 self._attribute_names = ["description"]
194 194
195 # create dictionnary of fieldnames index
196 self.map_field_idx = dict()
197 for k in len(range(self.fieldNames())):
198 map_field_idx[ self.fieldNames[k] ] = k
199
200 195
201 attributeNames = property(lambda self: copy.copy(self._attribute_names)) 196 attributeNames = property(lambda self: copy.copy(self._attribute_names))
202 197
203 def __contains__(self, fieldname): 198 def __contains__(self, fieldname):
204 return (fieldname in self.fieldNames()) \ 199 return (fieldname in self.fieldNames()) \
1063 Every tensor is treated as a numpy array (using numpy.asarray) 1058 Every tensor is treated as a numpy array (using numpy.asarray)
1064 """ 1059 """
1065 ArrayFieldsDataSet.__init__(self,**kwargs) 1060 ArrayFieldsDataSet.__init__(self,**kwargs)
1066 assert len(data_arrays) == len(fieldnames) 1061 assert len(data_arrays) == len(fieldnames)
1067 assert len(fieldnames) > 0 1062 assert len(fieldnames) > 0
1068 ndarrays = [numpy.ndarray(a) for a in data_arrays] 1063 ndarrays = [numpy.asarray(a) for a in data_arrays]
1069 lens = [a.shape[0] for a in ndarrays] 1064 lens = [a.shape[0] for a in ndarrays]
1070 num_examples = lens[0] #they must all be equal anyway 1065 num_examples = lens[0] #they must all be equal anyway
1071 self._fieldnames = fieldnames 1066 self._fieldnames = fieldnames
1072 self._datas = [] 1067 for k in ndarrays :
1073 for k in self.ndarrays :
1074 assert k.shape[0] == num_examples 1068 assert k.shape[0] == num_examples
1075 self._datas = ndarrays 1069 self._datas = ndarrays
1076 # create dict 1070 # create dict
1077 self.map_field_idx = dict() 1071 self.map_field_idx = dict()
1078 for k in range(len(fieldnames)): 1072 for k in range(len(fieldnames)):