changeset 245:c702abb7f875

merged
author James Bergstra <bergstrj@iro.umontreal.ca>
date Mon, 02 Jun 2008 17:09:58 -0400
parents 3156a9976183 (current diff) c8f19a9eb10f (diff)
children 9502f100eda5 82ba488b2c24
files
diffstat 2 files changed, 29 insertions(+), 24 deletions(-) [+]
line wrap: on
line diff
--- a/dataset.py	Mon Jun 02 17:08:17 2008 -0400
+++ b/dataset.py	Mon Jun 02 17:09:58 2008 -0400
@@ -47,14 +47,14 @@
     columns/attributes are called fields. The field value for a particular example can be an arbitrary
     python object, which depends on the particular dataset.
     
-    We call a DataSet a 'stream' when its length is unbounded (otherwise its __len__ method
+    We call a DataSet a 'stream' when its length is unbounded (in which case its __len__ method
     should return sys.maxint).
 
     A DataSet is a generator of iterators; these iterators can run through the
     examples or the fields in a variety of ways.  A DataSet need not necessarily have a finite
     or known length, so this class can be used to interface to a 'stream' which
     feeds on-line learning (however, as noted below, some operations are not
-    feasible or not recommanded on streams).
+    feasible or not recommended on streams).
 
     To iterate over examples, there are several possibilities:
      - for example in dataset:
@@ -81,7 +81,7 @@
      - for field_examples in dataset.fields():
         for example_value in field_examples:
            ...
-    but when the dataset is a stream (unbounded length), it is not recommanded to do 
+    but when the dataset is a stream (unbounded length), it is not recommended to do 
     such things because the underlying dataset may refuse to access the different fields in
     an unsynchronized ways. Hence the fields() method is illegal for streams, by default.
     The result of fields() is a L{DataSetFields} object, which iterates over fields,
@@ -599,7 +599,7 @@
     * for field_examples in dataset.fields():
         for example_value in field_examples:
            ...
-    but when the dataset is a stream (unbounded length), it is not recommanded to do 
+    but when the dataset is a stream (unbounded length), it is not recommended to do 
     such things because the underlying dataset may refuse to access the different fields in
     an unsynchronized ways. Hence the fields() method is illegal for streams, by default.
     The result of fields() is a DataSetFields object, which iterates over fields,
@@ -1016,12 +1016,13 @@
     def __getitem__(self,key):
         """More efficient implementation than the default __getitem__"""
         fieldnames=self.fields_columns.keys()
+        values=self.fields_columns.values()
         if type(key) is int:
             return Example(fieldnames,
-                           [self.data[key,self.fields_columns[f]] for f in fieldnames])
+                           [self.data[key,col] for col in values])
         if type(key) is slice:
             return MinibatchDataSet(Example(fieldnames,
-                                            [self.data[key,self.fields_columns[f]] for f in fieldnames]))
+                                            [self.data[key,col] for col in values]))
         if type(key) is list:
             for i in range(len(key)):
                 if self.hasFields(key[i]):
@@ -1030,9 +1031,10 @@
                                             #we must separate differently for list as numpy
                                             # doesn't support self.data[[i1,...],[i2,...]]
                                             # when their is more then two i1 and i2
-                                            [self.data[key,:][:,self.fields_columns[f]]
-                                             if isinstance(self.fields_columns[f],list) else
-                                             self.data[key,self.fields_columns[f]] for f in fieldnames]),
+                                            [self.data[key,:][:,col]
+                                             if isinstance(col,list) else
+                                             self.data[key,col] for col in values]),
+
 
                                     self.valuesVStack,self.valuesHStack)
 
@@ -1054,6 +1056,8 @@
                 assert offset>=0 and offset<len(dataset.data)
                 assert offset+minibatch_size<=len(dataset.data)
                 self.current=offset
+                self.columns = [self.dataset.fields_columns[f] 
+                                for f in self.minibatch._names]
             def __iter__(self):
                 return self
             def next(self):
@@ -1062,7 +1066,8 @@
                 if self.current>=self.dataset.data.shape[0]:
                     raise StopIteration
                 sub_data =  self.dataset.data[self.current]
-                self.minibatch._values = [sub_data[self.dataset.fields_columns[f]] for f in self.minibatch._names]
+                self.minibatch._values = [sub_data[c] for c in self.columns]
+
                 self.current+=self.minibatch_size
                 return self.minibatch
 
--- a/test_dataset.py	Mon Jun 02 17:08:17 2008 -0400
+++ b/test_dataset.py	Mon Jun 02 17:09:58 2008 -0400
@@ -194,38 +194,41 @@
     m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4)
     assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
     for x,y in m:
-        assert len(x)==3
-        assert len(y)==3
-        for id in range(3):
+        assert len(x)==m.minibatch_size
+        assert len(y)==m.minibatch_size
+        for id in range(m.minibatch_size):
             assert (numpy.append(x[id],y[id])==array[i+4]).all()
             i+=1
-    assert i==3
+    assert i==m.n_batches*m.minibatch_size
     del x,y,i,id,m
 
     i=0
     m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4)
     assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
     for x,y in m:
-        assert len(x)==3
-        assert len(y)==3
-        for id in range(3):
+        assert len(x)==m.minibatch_size
+        assert len(y)==m.minibatch_size
+        for id in range(m.minibatch_size):
             assert (numpy.append(x[id],y[id])==array[i+4]).all()
             i+=1
-    assert i==6
+    assert i==m.n_batches*m.minibatch_size
     del x,y,i,id,m
 
     i=0
     m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=3,offset=4)
     assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
     for x,y in m:
-        assert len(x)==3
-        assert len(y)==3
-        for id in range(3):
+        assert len(x)==m.minibatch_size
+        assert len(y)==m.minibatch_size
+        for id in range(m.minibatch_size):
             assert (numpy.append(x[id],y[id])==array[(i+4)%array.shape[0]]).all()
             i+=1
     assert i==m.n_batches*m.minibatch_size
     del x,y,i,id
 
+    #@todo: we can't do minibatch bigger then the size of the dataset???
+    assert have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array)+1,offset=0)
+    assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array),offset=0)
 
 def test_ds_iterator(array,iterator1,iterator2,iterator3):
     l=len(iterator1)
@@ -494,10 +497,7 @@
     print "test_speed"
     import time
     a2 = numpy.random.rand(100000,400)
-    ds = ArrayDataSet(a2,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested
-    ds = ArrayDataSet(a2,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested
     ds = ArrayDataSet(a2,{'all':slice(0,a2.shape[1],1)})
-    #assert ds==a? should this work?
     mat = numpy.random.rand(400,100)
     @print_timing
     def f_array1(a):