Mercurial > pylearn
changeset 318:e2eab74b6a28
NArraysDataSet, a generalization ArrayDataSet where every field is a ndarray, is implemented. Not really tested aside basic stuff...
author | Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca> |
---|---|
date | Wed, 11 Jun 2008 16:59:03 -0400 |
parents | 14081904d8f3 |
children | 6441e568745e |
files | dataset.py |
diffstat | 1 files changed, 22 insertions(+), 5 deletions(-) [+] |
line wrap: on
line diff
--- a/dataset.py Wed Jun 11 16:40:47 2008 -0400 +++ b/dataset.py Wed Jun 11 16:59:03 2008 -0400 @@ -1056,24 +1056,24 @@ Every tensor is treated as a numpy matrix (using numpy.asmatrix) """ - ArrayFieldsDataSet.__init(self,**kwargs) + ArrayFieldsDataSet.__init__(self,**kwargs) assert len(data_arrays) == len(fieldnames) assert len(fieldnames) > 0 num_examples = numpy.asmatrix(data_arrays[0]).shape[0] - for k in range(data_arrays) : + for k in range(len(data_arrays)) : assert numpy.asmatrix(data_arrays[k]).shape[0] == num_examples self._fieldnames = fieldnames self._datas = [] - for k in range(data_arrays) : + for k in range(len(data_arrays)) : self._datas.append( numpy.asmatrix(data_arrays[k]) ) - raise NotImplemented + #raise NotImplemented def __len__(self) : """ Length of the dataset is based on the first array = data_arrays[0], using its shape """ - return self.datas[0].shape[0] + return self._datas[0].shape[0] def fieldNames(self) : """ @@ -1081,6 +1081,23 @@ """ return self._fieldnames + def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): + cursor = Example(fieldnames,[0]*len(fieldnames)) + fieldnames = self.fieldNames() if fieldnames is None else fieldnames + for n in xrange(n_batches): + if offset == len(self): + break + for f in range(len(cursor._names)) : + idx = self._fieldnames.index(cursor._names[f]) + assert idx >= 0 + sub_data = self._datas[f][offset : offset+minibatch_size] + cursor._values[f] = sub_data + offset += len(sub_data) #can be less than minibatch_size at end + yield cursor + + #return ArrayDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset) + + class ArrayDataSet(ArrayFieldsDataSet):