Mercurial > pylearn
changeset 316:5fe6d0c93109
getitem in ArrayDataSet is set up again, supposed to be faster than default one, has been tested agains the default behaviour. In particular, now always return a LookupList
author | Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca> |
---|---|
date | Wed, 11 Jun 2008 16:28:09 -0400 |
parents | b48cf8dce2bf |
children | 14081904d8f3 |
files | dataset.py |
diffstat | 1 files changed, 53 insertions(+), 13 deletions(-) [+] |
line wrap: on
line diff
--- a/dataset.py Wed Jun 11 16:26:41 2008 -0400 +++ b/dataset.py Wed Jun 11 16:28:09 2008 -0400 @@ -1028,6 +1028,50 @@ become a two-row matrix, two matrices become a longer matrix, etc.""" return numpy.vstack(values) + + +class NArraysDataSet(ArrayFieldsDataSet) : + """ + An NArraysDataSet stores fields that are numpy tensor, whose first axis + iterates over examples. It's a generalization of ArrayDataSet. + """ + #@TODO not completely implemented yet + def __init__(self, data_arrays, fieldnames, **kwargs) : + """ + Construct an NArraysDataSet from a list of numpy tensor (data_arrays) and a list + of fieldnames. The number of arrays must be the same as the number of + fieldnames. Each set of numpy tensor must have the same first dimension (first + axis) corresponding to the number of examples. + + Every tensor is treated as a numpy matrix (using numpy.asmatrix) + """ + ArrayFieldsDataSet.__init(self,**kwargs) + assert len(data_arrays) == len(fieldnames) + assert len(fieldnames) > 0 + num_examples = numpy.asmatrix(data_arrays[0]).shape[0] + for k in range(data_arrays) : + assert numpy.asmatrix(data_arrays[k]).shape[0] == num_examples + self._fieldnames = fieldnames + self._datas = [] + for k in range(data_arrays) : + self._datas.append( numpy.asmatrix(data_arrays[k]) ) + raise NotImplemented + + + def __len__(self) : + """ + Length of the dataset is based on the first array = data_arrays[0], using its shape + """ + return self.datas[0].shape[0] + + def fieldNames(self) : + """ + Returns the fieldnames as set in self.__init__ + """ + return self._fieldnames + + + class ArrayDataSet(ArrayFieldsDataSet): """ An ArrayDataSet stores the fields as groups of columns in a numpy tensor, @@ -1077,7 +1121,7 @@ def __len__(self): return len(self.data) - def dontuse__getitem__(self,key): + def __getitem__(self,key): """More efficient implementation than the default __getitem__""" fieldnames=self.fields_columns.keys() values=self.fields_columns.values() @@ -1085,22 +1129,18 @@ return Example(fieldnames, [self.data[key,col] for col in values]) if type(key) is slice: - return MinibatchDataSet(Example(fieldnames, - [self.data[key,col] for col in values])) + return Example(fieldnames,[self.data[key,col] for col in values]) if type(key) is list: for i in range(len(key)): if self.hasFields(key[i]): key[i]=self.fields_columns[key[i]] - return MinibatchDataSet(Example(fieldnames, - #we must separate differently for list as numpy - # doesn't support self.data[[i1,...],[i2,...]] - # when their is more then two i1 and i2 - [self.data[key,:][:,col] - if isinstance(col,list) else - self.data[key,col] for col in values]), - - - self.valuesVStack,self.valuesHStack) + return Example(fieldnames, + #we must separate differently for list as numpy + # doesn't support self.data[[i1,...],[i2,...]] + # when their is more then two i1 and i2 + [self.data[key,:][:,col] + if isinstance(col,list) else + self.data[key,col] for col in values]) # else check for a fieldname if self.hasFields(key):