diff dataset.py @ 257:4ad6bc9b4f03

beginning to hack on #20, fixing for Thierry
author James Bergstra <bergstrj@iro.umontreal.ca>
date Tue, 03 Jun 2008 16:05:28 -0400
parents c8f19a9eb10f
children 19b14afe04b7
line wrap: on
line diff
--- a/dataset.py	Tue Jun 03 13:18:33 2008 -0400
+++ b/dataset.py	Tue Jun 03 16:05:28 2008 -0400
@@ -278,7 +278,7 @@
                     first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next()
                     second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next()
                     minibatch = Example(self.fieldnames,
-                                        [self.dataset.valuesVStack(name,[first_part[name],second_part[name]])
+                                        [self.dataset.valuesAppend(name,[first_part[name],second_part[name]])
                                          for name in self.fieldnames])
             self.next_row=upper
             self.n_batches_done+=1
@@ -953,16 +953,25 @@
     Virtual super-class of datasets whose field values are numpy array,
     thus defining valuesHStack and valuesVStack for sub-classes.
     """
-    def __init__(self,description=None,field_types=None):
-        DataSet.__init__(self,description,field_types)
-    def valuesHStack(self,fieldnames,fieldvalues):
+    def __init__(self, description=None, field_types=None):
+        DataSet.__init__(self, description, field_types)
+    def valuesHStack(self, fieldnames, fieldvalues):
         """Concatenate field values horizontally, e.g. two vectors
         become a longer vector, two matrices become a wider matrix, etc."""
         return numpy.hstack(fieldvalues)
-    def valuesVStack(self,fieldname,values):
+    def valuesVStack(self, fieldname, values):
         """Concatenate field values vertically, e.g. two vectors
         become a two-row matrix, two matrices become a longer matrix, etc."""
         return numpy.vstack(values)
+    def valuesAppend(self, fieldname, values):
+        s0 = sum([v.shape[0] for v in values])
+        #TODO: there's gotta be a better way to do this!
+        rval = numpy.ndarray([s0] + values[0].shape[1:],dtype=values[0].dtype)
+        cur_row = 0
+        for v in values:
+            rval[cur_row:cur_row+v.shape[0]] = v
+            cur_row += v.shape[0]
+        return rval
 
 class ArrayDataSet(ArrayFieldsDataSet):
     """
@@ -987,7 +996,7 @@
         for fieldname, fieldcolumns in self.fields_columns.items():
             if type(fieldcolumns) is int:
                 assert fieldcolumns>=0 and fieldcolumns<data_array.shape[1]
-                if 1:
+                if 0:
                     #I changed this because it didn't make sense to me,
                     # and it made it more difficult to write my learner.
                     # If it breaks stuff, let's talk about it.