# HG changeset patch # User Frederic Bastien # Date 1212524022 14400 # Node ID 6226ebafefc399b2520c9c58f6f2dac6a26dd3e6 # Parent 19b14afe04b72f919bb5f3112b05a4682a1515ce# Parent 38e7d90a12186ecf6798efcc58e8a4b49b69b7c3 Automated merge with ssh://projects@lgcm.iro.umontreal.ca/hg/pylearn diff -r 38e7d90a1218 -r 6226ebafefc3 dataset.py --- a/dataset.py Tue Jun 03 16:13:38 2008 -0400 +++ b/dataset.py Tue Jun 03 16:13:42 2008 -0400 @@ -278,7 +278,7 @@ first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next() second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next() minibatch = Example(self.fieldnames, - [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) + [self.dataset.valuesAppend(name,[first_part[name],second_part[name]]) for name in self.fieldnames]) self.next_row=upper self.n_batches_done+=1 @@ -953,16 +953,25 @@ Virtual super-class of datasets whose field values are numpy array, thus defining valuesHStack and valuesVStack for sub-classes. """ - def __init__(self,description=None,field_types=None): - DataSet.__init__(self,description,field_types) - def valuesHStack(self,fieldnames,fieldvalues): + def __init__(self, description=None, field_types=None): + DataSet.__init__(self, description, field_types) + def valuesHStack(self, fieldnames, fieldvalues): """Concatenate field values horizontally, e.g. two vectors become a longer vector, two matrices become a wider matrix, etc.""" return numpy.hstack(fieldvalues) - def valuesVStack(self,fieldname,values): + def valuesVStack(self, fieldname, values): """Concatenate field values vertically, e.g. two vectors become a two-row matrix, two matrices become a longer matrix, etc.""" return numpy.vstack(values) + def valuesAppend(self, fieldname, values): + s0 = sum([v.shape[0] for v in values]) + #TODO: there's gotta be a better way to do this! + rval = numpy.ndarray([s0] + values[0].shape[1:],dtype=values[0].dtype) + cur_row = 0 + for v in values: + rval[cur_row:cur_row+v.shape[0]] = v + cur_row += v.shape[0] + return rval class ArrayDataSet(ArrayFieldsDataSet): """ @@ -987,7 +996,7 @@ for fieldname, fieldcolumns in self.fields_columns.items(): if type(fieldcolumns) is int: assert fieldcolumns>=0 and fieldcolumns