Mercurial > pylearn
comparison dataset.py @ 322:ad8be93b3c55
small bugs fixed with NArrayDataSet
author | Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca> |
---|---|
date | Thu, 12 Jun 2008 12:45:42 -0400 |
parents | f03ae06fadc8 |
children | 09140ba68e17 9ce791fb2cbf |
comparison
legend: equal, deleted, inserted, replaced
321:f03ae06fadc8 | 322:ad8be93b3c55 |
---|---|
190 | 190 |
191 self.description = default_desc() if description is None \ | 191 self.description = default_desc() if description is None \ |
192 else description | 192 else description |
193 self._attribute_names = ["description"] | 193 self._attribute_names = ["description"] |
194 | 194 |
195 # create dictionnary of fieldnames index | |
196 self.map_field_idx = dict() | |
197 for k in len(range(self.fieldNames())): | |
198 map_field_idx[ self.fieldNames[k] ] = k | |
199 | |
200 | 195 |
201 attributeNames = property(lambda self: copy.copy(self._attribute_names)) | 196 attributeNames = property(lambda self: copy.copy(self._attribute_names)) |
202 | 197 |
203 def __contains__(self, fieldname): | 198 def __contains__(self, fieldname): |
204 return (fieldname in self.fieldNames()) \ | 199 return (fieldname in self.fieldNames()) \ |
… | … |
1063 Every tensor is treated as a numpy array (using numpy.asarray) | 1058 Every tensor is treated as a numpy array (using numpy.asarray) |
1064 """ | 1059 """ |
1065 ArrayFieldsDataSet.__init__(self,**kwargs) | 1060 ArrayFieldsDataSet.__init__(self,**kwargs) |
1066 assert len(data_arrays) == len(fieldnames) | 1061 assert len(data_arrays) == len(fieldnames) |
1067 assert len(fieldnames) > 0 | 1062 assert len(fieldnames) > 0 |
1068 ndarrays = [numpy.ndarray(a) for a in data_arrays] | 1063 ndarrays = [numpy.asarray(a) for a in data_arrays] |
1069 lens = [a.shape[0] for a in ndarrays] | 1064 lens = [a.shape[0] for a in ndarrays] |
1070 num_examples = lens[0] #they must all be equal anyway | 1065 num_examples = lens[0] #they must all be equal anyway |
1071 self._fieldnames = fieldnames | 1066 self._fieldnames = fieldnames |
1072 self._datas = [] | 1067 for k in ndarrays : |
1073 for k in self.ndarrays : | |
1074 assert k.shape[0] == num_examples | 1068 assert k.shape[0] == num_examples |
1075 self._datas = ndarrays | 1069 self._datas = ndarrays |
1076 # create dict | 1070 # create dict |
1077 self.map_field_idx = dict() | 1071 self.map_field_idx = dict() |
1078 for k in range(len(fieldnames)): | 1072 for k in range(len(fieldnames)): |