# HG changeset patch # User James Bergstra # Date 1212523528 14400 # Node ID 4ad6bc9b4f039e474b1e823683c33f3bb8876829 # Parent e93e511deb9a7450f59b9b590b00446fc418e79a beginning to hack on #20, fixing for Thierry diff -r e93e511deb9a -r 4ad6bc9b4f03 dataset.py --- a/dataset.py Tue Jun 03 13:18:33 2008 -0400 +++ b/dataset.py Tue Jun 03 16:05:28 2008 -0400 @@ -278,7 +278,7 @@ first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next() second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next() minibatch = Example(self.fieldnames, - [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) + [self.dataset.valuesAppend(name,[first_part[name],second_part[name]]) for name in self.fieldnames]) self.next_row=upper self.n_batches_done+=1 @@ -953,16 +953,25 @@ Virtual super-class of datasets whose field values are numpy array, thus defining valuesHStack and valuesVStack for sub-classes. """ - def __init__(self,description=None,field_types=None): - DataSet.__init__(self,description,field_types) - def valuesHStack(self,fieldnames,fieldvalues): + def __init__(self, description=None, field_types=None): + DataSet.__init__(self, description, field_types) + def valuesHStack(self, fieldnames, fieldvalues): """Concatenate field values horizontally, e.g. two vectors become a longer vector, two matrices become a wider matrix, etc.""" return numpy.hstack(fieldvalues) - def valuesVStack(self,fieldname,values): + def valuesVStack(self, fieldname, values): """Concatenate field values vertically, e.g. two vectors become a two-row matrix, two matrices become a longer matrix, etc.""" return numpy.vstack(values) + def valuesAppend(self, fieldname, values): + s0 = sum([v.shape[0] for v in values]) + #TODO: there's gotta be a better way to do this! + rval = numpy.ndarray([s0] + values[0].shape[1:],dtype=values[0].dtype) + cur_row = 0 + for v in values: + rval[cur_row:cur_row+v.shape[0]] = v + cur_row += v.shape[0] + return rval class ArrayDataSet(ArrayFieldsDataSet): """ @@ -987,7 +996,7 @@ for fieldname, fieldcolumns in self.fields_columns.items(): if type(fieldcolumns) is int: assert fieldcolumns>=0 and fieldcolumns