# HG changeset patch # User Thierry Bertin-Mahieux # Date 1213288547 14400 # Node ID f03ae06fadc8d38f83ebbc1b902d9ab89ac23091 # Parent b1da46b9b901c0d49bb49e329fc18fdf41a84d07 NArraysDataSet improved, use arrays instead of matrix, also a dictionnary of field indexes diff -r b1da46b9b901 -r f03ae06fadc8 dataset.py --- a/dataset.py Wed Jun 11 17:11:37 2008 -0400 +++ b/dataset.py Thu Jun 12 12:35:47 2008 -0400 @@ -192,6 +192,12 @@ else description self._attribute_names = ["description"] + # create dictionnary of fieldnames index + self.map_field_idx = dict() + for k in len(range(self.fieldNames())): + map_field_idx[ self.fieldNames[k] ] = k + + attributeNames = property(lambda self: copy.copy(self._attribute_names)) def __contains__(self, fieldname): @@ -1054,24 +1060,23 @@ fieldnames. Each set of numpy tensor must have the same first dimension (first axis) corresponding to the number of examples. - Every tensor is treated as a numpy matrix (using numpy.asmatrix) + Every tensor is treated as a numpy array (using numpy.asarray) """ ArrayFieldsDataSet.__init__(self,**kwargs) assert len(data_arrays) == len(fieldnames) assert len(fieldnames) > 0 - all_matrix_sizes = map(lambda x : numpy.asmatrix(x).shape[0] , data_arrays) - num_examples = max(all_matrix_sizes) - if num_examples == 1 : - # problem, do we transpose all arrays? is there only one example? - raise Exception("wrong initialization, unknow behaviour with 1-d arrays") + ndarrays = [numpy.ndarray(a) for a in data_arrays] + lens = [a.shape[0] for a in ndarrays] + num_examples = lens[0] #they must all be equal anyway self._fieldnames = fieldnames self._datas = [] - for k in range(len(data_arrays)) : - self._datas.append( numpy.asmatrix(data_arrays[k]) ) - if self._datas[-1].shape[0] == 1 and self._datas[-1].shape[1] == num_examples : - self._datas[-1] = self._datas[-1].transpose() - for k in range(len(self._datas)) : - assert self._datas[k].shape[0] == num_examples + for k in self.ndarrays : + assert k.shape[0] == num_examples + self._datas = ndarrays + # create dict + self.map_field_idx = dict() + for k in range(len(fieldnames)): + self.map_field_idx[fieldnames[k]] = k def __len__(self) : @@ -1086,6 +1091,12 @@ """ return self._fieldnames + def field_pos(self,fieldname) : + """ + Returns the index of a given fieldname. Fieldname must exists! see fieldNames(). + """ + return self.map_field_idx[fieldname] + def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): cursor = Example(fieldnames,[0]*len(fieldnames)) fieldnames = self.fieldNames() if fieldnames is None else fieldnames @@ -1093,9 +1104,8 @@ if offset == len(self): break for f in range(len(cursor._names)) : - idx = self._fieldnames.index(cursor._names[f]) - assert idx >= 0 - sub_data = self._datas[f][offset : offset+minibatch_size] + idx = self.field_pos(cursor._names[f]) + sub_data = self._datas[idx][offset : offset+minibatch_size] cursor._values[f] = sub_data offset += len(sub_data) #can be less than minibatch_size at end yield cursor