diff dataset.py @ 216:4b7e89b75e2b

Modified ArrayDataSet's handling of column fields. Previously, if a fieldname was associated with an integer column index (as opposed to a column range or slice), the field would be returned as an Nx1 matrix. Now, if a fieldname is associated with an integer column index, the resulting field is a vector of length N. The old behaviour can still be achieved by associating the fieldname with slice(col, col+1).
author James Bergstra <bergstrj@iro.umontreal.ca>
date Thu, 22 May 2008 19:07:51 -0400
parents bd728c83faff
children df3fae88ab46
line wrap: on
line diff
--- a/dataset.py	Thu May 22 17:41:14 2008 -0400
+++ b/dataset.py	Thu May 22 19:07:51 2008 -0400
@@ -245,7 +245,8 @@
             if n_batches is not None:
                 ds_nbatches = min(n_batches,ds_nbatches)
             if fieldnames:
-                assert dataset.hasFields(*fieldnames)
+                if not dataset.hasFields(*fieldnames):
+                    raise ValueError('field not present', fieldnames)
             else:
                 self.fieldnames=dataset.fieldNames()
             self.iterator = self.dataset.minibatches_nowrap(self.fieldnames,self.minibatch_size,
@@ -969,7 +970,16 @@
         for fieldname, fieldcolumns in self.fields_columns.items():
             if type(fieldcolumns) is int:
                 assert fieldcolumns>=0 and fieldcolumns<data_array.shape[1]
-                self.fields_columns[fieldname]=[fieldcolumns]
+
+                if 0:
+                    #I changed this because it didn't make sense to me,
+                    # and it made it more difficult to write my learner.
+                    # If it breaks stuff, let's talk about it.
+                    # - James 22/05/2008
+                    self.fields_columns[fieldname]=[fieldcolumns]
+                else:
+                    self.fields_columns[fieldname]=fieldcolumns
+
             elif type(fieldcolumns) is slice:
                 start,step=None,None
                 if not fieldcolumns.start: