Mercurial > pylearn
diff dataset.py @ 257:4ad6bc9b4f03
beginning to hack on #20, fixing for Thierry
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Tue, 03 Jun 2008 16:05:28 -0400 |
parents | c8f19a9eb10f |
children | 19b14afe04b7 |
line wrap: on
line diff
--- a/dataset.py Tue Jun 03 13:18:33 2008 -0400
+++ b/dataset.py Tue Jun 03 16:05:28 2008 -0400
@@ -278,7 +278,7 @@
                     first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next()
                     second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next()
                     minibatch = Example(self.fieldnames,
-                                        [self.dataset.valuesVStack(name,[first_part[name],second_part[name]])
+                                        [self.dataset.valuesAppend(name,[first_part[name],second_part[name]])
                                          for name in self.fieldnames])
                 self.next_row=upper
                 self.n_batches_done+=1
@@ -953,16 +953,25 @@
     Virtual super-class of datasets whose field values are numpy array,
     thus defining valuesHStack and valuesVStack for sub-classes.
     """
-    def __init__(self,description=None,field_types=None):
-        DataSet.__init__(self,description,field_types)
-    def valuesHStack(self,fieldnames,fieldvalues):
+    def __init__(self, description=None, field_types=None):
+        DataSet.__init__(self, description, field_types)
+    def valuesHStack(self, fieldnames, fieldvalues):
         """Concatenate field values horizontally, e.g. two vectors
         become a longer vector, two matrices become a wider matrix, etc."""
         return numpy.hstack(fieldvalues)
-    def valuesVStack(self,fieldname,values):
+    def valuesVStack(self, fieldname, values):
         """Concatenate field values vertically, e.g. two vectors
         become a two-row matrix, two matrices become a longer matrix, etc."""
         return numpy.vstack(values)
+    def valuesAppend(self, fieldname, values):
+        s0 = sum([v.shape[0] for v in values])
+        #TODO: there's gotta be a better way to do this!
+        rval = numpy.ndarray([s0] + values[0].shape[1:],dtype=values[0].dtype)
+        cur_row = 0
+        for v in values:
+            rval[cur_row:cur_row+v.shape[0]] = v
+            cur_row += v.shape[0]
+        return rval
 
 class ArrayDataSet(ArrayFieldsDataSet):
     """
@@ -987,7 +996,7 @@
         for fieldname, fieldcolumns in self.fields_columns.items():
             if type(fieldcolumns) is int:
                 assert fieldcolumns>=0 and fieldcolumns<data_array.shape[1]
-                if 1:
+                if 0:
                     #I changed this because it didn't make sense to me,
                     # and it made it more difficult to write my learner.
                     # If it breaks stuff, let's talk about it.