diff dataset.py @ 319:6441e568745e

bug fixed when one matrix is an array, a 1-d matrix
author Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
date Wed, 11 Jun 2008 17:05:56 -0400
parents e2eab74b6a28
children b1da46b9b901
line wrap: on
line diff
--- a/dataset.py	Wed Jun 11 16:59:03 2008 -0400
+++ b/dataset.py	Wed Jun 11 17:05:56 2008 -0400
@@ -1059,14 +1059,16 @@
         ArrayFieldsDataSet.__init__(self,**kwargs)
         assert len(data_arrays) == len(fieldnames)
         assert len(fieldnames) > 0
-        num_examples = numpy.asmatrix(data_arrays[0]).shape[0]
-        for k in range(len(data_arrays)) :
-            assert numpy.asmatrix(data_arrays[k]).shape[0] == num_examples
+        all_matrix_sizes = map(lambda x : numpy.asmatrix(x).shape[0] , data_arrays)
+        num_examples = max(all_matrix_sizes)
         self._fieldnames = fieldnames
         self._datas = []
         for k in range(len(data_arrays)) :
             self._datas.append( numpy.asmatrix(data_arrays[k]) )
-        #raise NotImplemented
+            if self._datas[-1].shape[0] == 1 and self._datas[-1].shape[1] == num_examples :
+                self._datas[-1] = self._datas[-1].transpose()
+        for k in range(len(self._datas)) :
+            assert self._datas[k].shape[0] == num_examples
 
 
     def __len__(self) :