changeset 318:e2eab74b6a28

NArraysDataSet, a generalization ArrayDataSet where every field is a ndarray, is implemented. Not really tested aside basic stuff...
author Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
date Wed, 11 Jun 2008 16:59:03 -0400
parents 14081904d8f3
children 6441e568745e
files dataset.py
diffstat 1 files changed, 22 insertions(+), 5 deletions(-) [+]
line wrap: on
line diff
--- a/dataset.py	Wed Jun 11 16:40:47 2008 -0400
+++ b/dataset.py	Wed Jun 11 16:59:03 2008 -0400
@@ -1056,24 +1056,24 @@
 
         Every tensor is treated as a numpy matrix (using numpy.asmatrix)
         """
-        ArrayFieldsDataSet.__init(self,**kwargs)
+        ArrayFieldsDataSet.__init__(self,**kwargs)
         assert len(data_arrays) == len(fieldnames)
         assert len(fieldnames) > 0
         num_examples = numpy.asmatrix(data_arrays[0]).shape[0]
-        for k in range(data_arrays) :
+        for k in range(len(data_arrays)) :
             assert numpy.asmatrix(data_arrays[k]).shape[0] == num_examples
         self._fieldnames = fieldnames
         self._datas = []
-        for k in range(data_arrays) :
+        for k in range(len(data_arrays)) :
             self._datas.append( numpy.asmatrix(data_arrays[k]) )
-        raise NotImplemented
+        #raise NotImplemented
 
 
     def __len__(self) :
         """
         Length of the dataset is based on the first array = data_arrays[0], using its shape
         """
-        return self.datas[0].shape[0]
+        return self._datas[0].shape[0]
 
     def fieldNames(self) :
         """
@@ -1081,6 +1081,23 @@
         """
         return self._fieldnames
 
+    def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
+        cursor = Example(fieldnames,[0]*len(fieldnames))
+        fieldnames = self.fieldNames() if fieldnames is None else fieldnames
+        for n in xrange(n_batches):
+            if offset == len(self):
+                break
+            for f in range(len(cursor._names)) :
+                idx = self._fieldnames.index(cursor._names[f])
+                assert idx >= 0
+                sub_data = self._datas[f][offset : offset+minibatch_size]
+                cursor._values[f] = sub_data
+            offset += len(sub_data) #can be less than minibatch_size at end
+            yield cursor
+
+        #return ArrayDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset)
+
+
 
 
 class ArrayDataSet(ArrayFieldsDataSet):