comparison dataset.py @ 318:e2eab74b6a28

NArraysDataSet, a generalization ArrayDataSet where every field is a ndarray, is implemented. Not really tested aside basic stuff...
author Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
date Wed, 11 Jun 2008 16:59:03 -0400
parents 14081904d8f3
children 6441e568745e
comparison
equal deleted inserted replaced
317:14081904d8f3 318:e2eab74b6a28
1054 fieldnames. Each set of numpy tensor must have the same first dimension (first 1054 fieldnames. Each set of numpy tensor must have the same first dimension (first
1055 axis) corresponding to the number of examples. 1055 axis) corresponding to the number of examples.
1056 1056
1057 Every tensor is treated as a numpy matrix (using numpy.asmatrix) 1057 Every tensor is treated as a numpy matrix (using numpy.asmatrix)
1058 """ 1058 """
1059 ArrayFieldsDataSet.__init(self,**kwargs) 1059 ArrayFieldsDataSet.__init__(self,**kwargs)
1060 assert len(data_arrays) == len(fieldnames) 1060 assert len(data_arrays) == len(fieldnames)
1061 assert len(fieldnames) > 0 1061 assert len(fieldnames) > 0
1062 num_examples = numpy.asmatrix(data_arrays[0]).shape[0] 1062 num_examples = numpy.asmatrix(data_arrays[0]).shape[0]
1063 for k in range(data_arrays) : 1063 for k in range(len(data_arrays)) :
1064 assert numpy.asmatrix(data_arrays[k]).shape[0] == num_examples 1064 assert numpy.asmatrix(data_arrays[k]).shape[0] == num_examples
1065 self._fieldnames = fieldnames 1065 self._fieldnames = fieldnames
1066 self._datas = [] 1066 self._datas = []
1067 for k in range(data_arrays) : 1067 for k in range(len(data_arrays)) :
1068 self._datas.append( numpy.asmatrix(data_arrays[k]) ) 1068 self._datas.append( numpy.asmatrix(data_arrays[k]) )
1069 raise NotImplemented 1069 #raise NotImplemented
1070 1070
1071 1071
1072 def __len__(self) : 1072 def __len__(self) :
1073 """ 1073 """
1074 Length of the dataset is based on the first array = data_arrays[0], using its shape 1074 Length of the dataset is based on the first array = data_arrays[0], using its shape
1075 """ 1075 """
1076 return self.datas[0].shape[0] 1076 return self._datas[0].shape[0]
1077 1077
1078 def fieldNames(self) : 1078 def fieldNames(self) :
1079 """ 1079 """
1080 Returns the fieldnames as set in self.__init__ 1080 Returns the fieldnames as set in self.__init__
1081 """ 1081 """
1082 return self._fieldnames 1082 return self._fieldnames
1083
1084 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
1085 cursor = Example(fieldnames,[0]*len(fieldnames))
1086 fieldnames = self.fieldNames() if fieldnames is None else fieldnames
1087 for n in xrange(n_batches):
1088 if offset == len(self):
1089 break
1090 for f in range(len(cursor._names)) :
1091 idx = self._fieldnames.index(cursor._names[f])
1092 assert idx >= 0
1093 sub_data = self._datas[f][offset : offset+minibatch_size]
1094 cursor._values[f] = sub_data
1095 offset += len(sub_data) #can be less than minibatch_size at end
1096 yield cursor
1097
1098 #return ArrayDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset)
1099
1083 1100
1084 1101
1085 1102
1086 class ArrayDataSet(ArrayFieldsDataSet): 1103 class ArrayDataSet(ArrayFieldsDataSet):
1087 """ 1104 """