comparison dataset.py @ 257:4ad6bc9b4f03

beginning to hack on #20, fixing for Thierry
author James Bergstra <bergstrj@iro.umontreal.ca>
date Tue, 03 Jun 2008 16:05:28 -0400
parents c8f19a9eb10f
children 19b14afe04b7
comparison
equal deleted inserted replaced
249:e93e511deb9a 257:4ad6bc9b4f03
276 # we must concatenate (vstack) the bottom and top parts of our minibatch 276 # we must concatenate (vstack) the bottom and top parts of our minibatch
277 # first get the beginning of our minibatch (top of dataset) 277 # first get the beginning of our minibatch (top of dataset)
278 first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next() 278 first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next()
279 second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next() 279 second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next()
280 minibatch = Example(self.fieldnames, 280 minibatch = Example(self.fieldnames,
281 [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) 281 [self.dataset.valuesAppend(name,[first_part[name],second_part[name]])
282 for name in self.fieldnames]) 282 for name in self.fieldnames])
283 self.next_row=upper 283 self.next_row=upper
284 self.n_batches_done+=1 284 self.n_batches_done+=1
285 if upper >= self.L and self.n_batches: 285 if upper >= self.L and self.n_batches:
286 self.next_row -= self.L 286 self.next_row -= self.L
951 class ArrayFieldsDataSet(DataSet): 951 class ArrayFieldsDataSet(DataSet):
952 """ 952 """
953 Virtual super-class of datasets whose field values are numpy array, 953 Virtual super-class of datasets whose field values are numpy array,
954 thus defining valuesHStack and valuesVStack for sub-classes. 954 thus defining valuesHStack and valuesVStack for sub-classes.
955 """ 955 """
956 def __init__(self,description=None,field_types=None): 956 def __init__(self, description=None, field_types=None):
957 DataSet.__init__(self,description,field_types) 957 DataSet.__init__(self, description, field_types)
958 def valuesHStack(self,fieldnames,fieldvalues): 958 def valuesHStack(self, fieldnames, fieldvalues):
959 """Concatenate field values horizontally, e.g. two vectors 959 """Concatenate field values horizontally, e.g. two vectors
960 become a longer vector, two matrices become a wider matrix, etc.""" 960 become a longer vector, two matrices become a wider matrix, etc."""
961 return numpy.hstack(fieldvalues) 961 return numpy.hstack(fieldvalues)
962 def valuesVStack(self,fieldname,values): 962 def valuesVStack(self, fieldname, values):
963 """Concatenate field values vertically, e.g. two vectors 963 """Concatenate field values vertically, e.g. two vectors
964 become a two-row matrix, two matrices become a longer matrix, etc.""" 964 become a two-row matrix, two matrices become a longer matrix, etc."""
965 return numpy.vstack(values) 965 return numpy.vstack(values)
def valuesAppend(self, fieldname, values):
    """Append field values along their first axis.

    Unlike valuesVStack, the number of dimensions is preserved: two
    vectors become a longer vector, two matrices become a longer matrix.

    fieldname -- name of the field being assembled (not used here).
    values    -- non-empty sequence of numpy arrays that agree on every
                 axis except the first.

    Returns a new array holding the rows of every element of `values`,
    in order, with the dtype of values[0].
    """
    first = values[0]
    total_rows = sum(v.shape[0] for v in values)
    # ndarray.shape is a tuple, so the result shape must be built as a
    # tuple too; the earlier `[s0] + shape[1:]` mixed list and tuple and
    # raised TypeError on every call.
    rval = numpy.empty((total_rows,) + tuple(first.shape[1:]),
                       dtype=first.dtype)
    cur_row = 0
    for v in values:
        n_rows = v.shape[0]
        rval[cur_row:cur_row + n_rows] = v
        cur_row += n_rows
    return rval
966 975
967 class ArrayDataSet(ArrayFieldsDataSet): 976 class ArrayDataSet(ArrayFieldsDataSet):
968 """ 977 """
969 An ArrayDataSet stores the fields as groups of columns in a numpy tensor, 978 An ArrayDataSet stores the fields as groups of columns in a numpy tensor,
970 whose first axis iterates over examples, second axis determines fields. 979 whose first axis iterates over examples, second axis determines fields.
985 994
986 # check consistency and complete slices definitions 995 # check consistency and complete slices definitions
987 for fieldname, fieldcolumns in self.fields_columns.items(): 996 for fieldname, fieldcolumns in self.fields_columns.items():
988 if type(fieldcolumns) is int: 997 if type(fieldcolumns) is int:
989 assert fieldcolumns>=0 and fieldcolumns<data_array.shape[1] 998 assert fieldcolumns>=0 and fieldcolumns<data_array.shape[1]
990 if 1: 999 if 0:
991 #I changed this because it didn't make sense to me, 1000 #I changed this because it didn't make sense to me,
992 # and it made it more difficult to write my learner. 1001 # and it made it more difficult to write my learner.
993 # If it breaks stuff, let's talk about it. 1002 # If it breaks stuff, let's talk about it.
994 # - James 22/05/2008 1003 # - James 22/05/2008
995 self.fields_columns[fieldname]=[fieldcolumns] 1004 self.fields_columns[fieldname]=[fieldcolumns]