# HG changeset patch
# User James Bergstra
# Date 1212440998 14400
# Node ID c702abb7f87557ae97a5576f46d6850623a57d72
# Parent  3156a9976183b20c5faa6e4a092fdbc3a573193f
# Parent  c8f19a9eb10fac6eec3cf7c77f4e2476a1c25c24
merged

diff -r 3156a9976183 -r c702abb7f875 dataset.py
--- a/dataset.py	Mon Jun 02 17:08:17 2008 -0400
+++ b/dataset.py	Mon Jun 02 17:09:58 2008 -0400
@@ -47,14 +47,14 @@
     columns/attributes are called fields. The field value for a particular example can be an
     arbitrary python object, which depends on the particular dataset.

-    We call a DataSet a 'stream' when its length is unbounded (otherwise its __len__ method
+    We call a DataSet a 'stream' when its length is unbounded (in which case its __len__ method
     should return sys.maxint).

     A DataSet is a generator of iterators; these iterators can run through the
     examples or the fields in a variety of ways.  A DataSet need not necessarily have a finite
     or known length, so this class can be used to interface to a 'stream' which
     feeds on-line learning (however, as noted below, some operations are not
-    feasible or not recommanded on streams).
+    feasible or not recommended on streams).

     To iterate over examples, there are several possibilities:
     - for example in dataset:
@@ -81,7 +81,7 @@
     - for field_examples in dataset.fields():
        for example_value in field_examples:
           ...
-    but when the dataset is a stream (unbounded length), it is not recommanded to do
+    but when the dataset is a stream (unbounded length), it is not recommended to do
     such things because the underlying dataset may refuse to access the different fields in
     an unsynchronized ways. Hence the fields() method is illegal for streams, by default.
     The result of fields() is a L{DataSetFields} object, which iterates over fields,
@@ -599,7 +599,7 @@
     * for field_examples in dataset.fields():
        for example_value in field_examples:
           ...
-    but when the dataset is a stream (unbounded length), it is not recommanded to do
+    but when the dataset is a stream (unbounded length), it is not recommended to do
     such things because the underlying dataset may refuse to access the different fields in
     an unsynchronized ways. Hence the fields() method is illegal for streams, by default.
     The result of fields() is a DataSetFields object, which iterates over fields,
@@ -1016,12 +1016,13 @@
     def __getitem__(self,key):
         """More efficient implementation than the default __getitem__"""
         fieldnames=self.fields_columns.keys()
+        values=self.fields_columns.values()
         if type(key) is int:
             return Example(fieldnames,
-                           [self.data[key,self.fields_columns[f]] for f in fieldnames])
+                           [self.data[key,col] for col in values])
         if type(key) is slice:
             return MinibatchDataSet(Example(fieldnames,
-                                            [self.data[key,self.fields_columns[f]] for f in fieldnames]))
+                                            [self.data[key,col] for col in values]))
         if type(key) is list:
             for i in range(len(key)):
                 if self.hasFields(key[i]):
@@ -1030,9 +1031,10 @@
                 #we must separate differently for list as numpy
                 # doesn't support self.data[[i1,...],[i2,...]]
                 # when their is more then two i1 and i2
-                [self.data[key,:][:,self.fields_columns[f]]
-                 if isinstance(self.fields_columns[f],list) else
-                 self.data[key,self.fields_columns[f]] for f in fieldnames]),
+                [self.data[key,:][:,col]
+                 if isinstance(col,list) else
+                 self.data[key,col] for col in values]),
+                self.valuesVStack,self.valuesHStack)
@@ -1054,6 +1056,8 @@
         assert offset>=0 and offset=self.dataset.data.shape[0]: raise StopIteration
             sub_data = self.dataset.data[self.current]
-            self.minibatch._values = [sub_data[self.dataset.fields_columns[f]] for f in self.minibatch._names]
+            self.minibatch._values = [sub_data[c] for c in self.columns]
+            self.current+=self.minibatch_size
             return self.minibatch
diff -r 3156a9976183 -r c702abb7f875 test_dataset.py
--- a/test_dataset.py	Mon Jun 02 17:08:17 2008 -0400
+++ b/test_dataset.py	Mon Jun 02 17:09:58 2008 -0400
@@ -194,38 +194,41 @@
     m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4)
     assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
     for x,y in m:
-        assert len(x)==3
-        assert len(y)==3
-        for id in range(3):
+        assert len(x)==m.minibatch_size
+        assert len(y)==m.minibatch_size
+        for id in range(m.minibatch_size):
            assert (numpy.append(x[id],y[id])==array[i+4]).all()
            i+=1
-    assert i==3
+    assert i==m.n_batches*m.minibatch_size
     del x,y,i,id,m

     i=0
     m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4)
     assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
     for x,y in m:
-        assert len(x)==3
-        assert len(y)==3
-        for id in range(3):
+        assert len(x)==m.minibatch_size
+        assert len(y)==m.minibatch_size
+        for id in range(m.minibatch_size):
            assert (numpy.append(x[id],y[id])==array[i+4]).all()
            i+=1
-    assert i==6
+    assert i==m.n_batches*m.minibatch_size
     del x,y,i,id,m

     i=0
     m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=3,offset=4)
     assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
     for x,y in m:
-        assert len(x)==3
-        assert len(y)==3
-        for id in range(3):
+        assert len(x)==m.minibatch_size
+        assert len(y)==m.minibatch_size
+        for id in range(m.minibatch_size):
            assert (numpy.append(x[id],y[id])==array[(i+4)%array.shape[0]]).all()
            i+=1
     assert i==m.n_batches*m.minibatch_size
     del x,y,i,id
+    #@todo: we can't do minibatch bigger then the size of the dataset???
+    assert have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array)+1,offset=0)
+    assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array),offset=0)

 def test_ds_iterator(array,iterator1,iterator2,iterator3):
     l=len(iterator1)
@@ -494,10 +497,7 @@
     print "test_speed"
     import time
     a2 = numpy.random.rand(100000,400)
-    ds = ArrayDataSet(a2,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested
-    ds = ArrayDataSet(a2,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested
     ds = ArrayDataSet(a2,{'all':slice(0,a2.shape[1],1)})
-    #assert ds==a? should this work?
     mat = numpy.random.rand(400,100)
     @print_timing
     def f_array1(a):
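
Below is a short illustrative sketch, separate from the changeset itself, of the indexing pattern the dataset.py hunks introduce: the column specifications are taken from fields_columns once with values() and reused, instead of doing a dict lookup per field name on every access. The names data, fields_columns and get_example are stand-ins invented for this example only; in the patch the same idea lives inside ArrayDataSet.__getitem__ and the minibatch iterator.

import numpy

# Stand-ins (hypothetical names, this sketch only) for ArrayDataSet.data and
# ArrayDataSet.fields_columns: 4 examples, 5 columns, and three fields that map
# to a slice, a single column index, and a list of column indices.
data = numpy.arange(20).reshape(4, 5)
fields_columns = {'x': slice(0, 3), 'y': 3, 'z': [0, 2]}

fieldnames = list(fields_columns.keys())
values = list(fields_columns.values())   # fetched once, as in the patched __getitem__

def get_example(row):
    # One numpy indexing operation per field; no per-field dict lookup inside the loop.
    return dict(zip(fieldnames, [data[row, col] for col in values]))

print(get_example(1))  # e.g. {'x': array([5, 6, 7]), 'y': 8, 'z': array([5, 7])}

The substitution is safe in the patch for the same reason it is here: keys() and values() are read from the same fields_columns dict with no mutation in between, so the i-th value is the column spec of the i-th field name.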