Mercurial > pylearn
diff dataset.py @ 319:6441e568745e
bug fixed when one matrix is an array, a 1-d matrix
author | Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca> |
---|---|
date | Wed, 11 Jun 2008 17:05:56 -0400 |
parents | e2eab74b6a28 |
children | b1da46b9b901 |
line wrap: on
line diff
--- a/dataset.py Wed Jun 11 16:59:03 2008 -0400 +++ b/dataset.py Wed Jun 11 17:05:56 2008 -0400 @@ -1059,14 +1059,16 @@ ArrayFieldsDataSet.__init__(self,**kwargs) assert len(data_arrays) == len(fieldnames) assert len(fieldnames) > 0 - num_examples = numpy.asmatrix(data_arrays[0]).shape[0] - for k in range(len(data_arrays)) : - assert numpy.asmatrix(data_arrays[k]).shape[0] == num_examples + all_matrix_sizes = map(lambda x : numpy.asmatrix(x).shape[0] , data_arrays) + num_examples = max(all_matrix_sizes) self._fieldnames = fieldnames self._datas = [] for k in range(len(data_arrays)) : self._datas.append( numpy.asmatrix(data_arrays[k]) ) - #raise NotImplemented + if self._datas[-1].shape[0] == 1 and self._datas[-1].shape[1] == num_examples : + self._datas[-1] = self._datas[-1].transpose() + for k in range(len(self._datas)) : + assert self._datas[k].shape[0] == num_examples def __len__(self) :