Mercurial > pylearn
comparison dataset.py @ 257:4ad6bc9b4f03
beginning to hack on #20, fixing for Thierry
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Tue, 03 Jun 2008 16:05:28 -0400 |
parents | c8f19a9eb10f |
children | 19b14afe04b7 |
comparison
equal
deleted
inserted
replaced
249:e93e511deb9a | 257:4ad6bc9b4f03 |
---|---|
276 # we must concatenate (vstack) the bottom and top parts of our minibatch | 276 # we must concatenate (vstack) the bottom and top parts of our minibatch |
277 # first get the beginning of our minibatch (top of dataset) | 277 # first get the beginning of our minibatch (top of dataset) |
278 first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next() | 278 first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next() |
279 second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next() | 279 second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next() |
280 minibatch = Example(self.fieldnames, | 280 minibatch = Example(self.fieldnames, |
281 [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) | 281 [self.dataset.valuesAppend(name,[first_part[name],second_part[name]]) |
282 for name in self.fieldnames]) | 282 for name in self.fieldnames]) |
283 self.next_row=upper | 283 self.next_row=upper |
284 self.n_batches_done+=1 | 284 self.n_batches_done+=1 |
285 if upper >= self.L and self.n_batches: | 285 if upper >= self.L and self.n_batches: |
286 self.next_row -= self.L | 286 self.next_row -= self.L |
951 class ArrayFieldsDataSet(DataSet): | 951 class ArrayFieldsDataSet(DataSet): |
952 """ | 952 """ |
953 Virtual super-class of datasets whose field values are numpy array, | 953 Virtual super-class of datasets whose field values are numpy array, |
954 thus defining valuesHStack and valuesVStack for sub-classes. | 954 thus defining valuesHStack and valuesVStack for sub-classes. |
955 """ | 955 """ |
956 def __init__(self,description=None,field_types=None): | 956 def __init__(self, description=None, field_types=None): |
957 DataSet.__init__(self,description,field_types) | 957 DataSet.__init__(self, description, field_types) |
958 def valuesHStack(self,fieldnames,fieldvalues): | 958 def valuesHStack(self, fieldnames, fieldvalues): |
959 """Concatenate field values horizontally, e.g. two vectors | 959 """Concatenate field values horizontally, e.g. two vectors |
960 become a longer vector, two matrices become a wider matrix, etc.""" | 960 become a longer vector, two matrices become a wider matrix, etc.""" |
961 return numpy.hstack(fieldvalues) | 961 return numpy.hstack(fieldvalues) |
962 def valuesVStack(self,fieldname,values): | 962 def valuesVStack(self, fieldname, values): |
963 """Concatenate field values vertically, e.g. two vectors | 963 """Concatenate field values vertically, e.g. two vectors |
964 become a two-row matrix, two matrices become a longer matrix, etc.""" | 964 become a two-row matrix, two matrices become a longer matrix, etc.""" |
965 return numpy.vstack(values) | 965 return numpy.vstack(values) |
| | 966 def valuesAppend(self, fieldname, values): |
| | 967 s0 = sum([v.shape[0] for v in values]) |
| | 968 #TODO: there's gotta be a better way to do this! |
| | 969 rval = numpy.ndarray([s0] + values[0].shape[1:],dtype=values[0].dtype) |
| | 970 cur_row = 0 |
| | 971 for v in values: |
| | 972 rval[cur_row:cur_row+v.shape[0]] = v |
| | 973 cur_row += v.shape[0] |
| | 974 return rval |
966 | 975 |
967 class ArrayDataSet(ArrayFieldsDataSet): | 976 class ArrayDataSet(ArrayFieldsDataSet): |
968 """ | 977 """ |
969 An ArrayDataSet stores the fields as groups of columns in a numpy tensor, | 978 An ArrayDataSet stores the fields as groups of columns in a numpy tensor, |
970 whose first axis iterates over examples, second axis determines fields. | 979 whose first axis iterates over examples, second axis determines fields. |
985 | 994 |
986 # check consistency and complete slices definitions | 995 # check consistency and complete slices definitions |
987 for fieldname, fieldcolumns in self.fields_columns.items(): | 996 for fieldname, fieldcolumns in self.fields_columns.items(): |
988 if type(fieldcolumns) is int: | 997 if type(fieldcolumns) is int: |
989 assert fieldcolumns>=0 and fieldcolumns<data_array.shape[1] | 998 assert fieldcolumns>=0 and fieldcolumns<data_array.shape[1] |
990 if 1: | 999 if 0: |
991 #I changed this because it didn't make sense to me, | 1000 #I changed this because it didn't make sense to me, |
992 # and it made it more difficult to write my learner. | 1001 # and it made it more difficult to write my learner. |
993 # If it breaks stuff, let's talk about it. | 1002 # If it breaks stuff, let's talk about it. |
994 # - James 22/05/2008 | 1003 # - James 22/05/2008 |
995 self.fields_columns[fieldname]=[fieldcolumns] | 1004 self.fields_columns[fieldname]=[fieldcolumns] |