changeset 231:38beb81f4e8b

Automated merge with ssh://projects@lgcm.iro.umontreal.ca/hg/pylearn
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Tue, 27 May 2008 13:46:03 -0400
parents 17c5d080964b (current diff) 4d1bd2513e06 (diff)
children c047238e5b3f 9e96fe8b955c
files dataset.py test_dataset.py
diffstat 3 files changed, 98 insertions(+), 2 deletions(-) [+]
line wrap: on
line diff
--- a/dataset.py	Tue May 27 13:23:05 2008 -0400
+++ b/dataset.py	Tue May 27 13:46:03 2008 -0400
@@ -1043,7 +1043,31 @@
         assert key in self.__dict__ # else it means we are trying to access a non-existing property
         return self.__dict__[key]
         
-            
+    def __iter__(self):
+        class ArrayDataSetIterator2(object):
+            def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
+                if fieldnames is None: fieldnames = dataset.fieldNames()
+                # store the resulting minibatch in a lookup-list of values
+                self.minibatch = LookupList(fieldnames,[0]*len(fieldnames))
+                self.dataset=dataset
+                self.minibatch_size=minibatch_size
+                assert offset>=0 and offset<len(dataset.data)
+                assert offset+minibatch_size<=len(dataset.data)
+                self.current=offset
+            def __iter__(self):
+                return self
+            def next(self):
+                #@todo: we suppose that we need to stop only when minibatch_size == 1.
+                # Otherwise, MinibatchWrapAroundIterator do it.
+                if self.current>=self.dataset.data.shape[0]:
+                    raise StopIteration
+                sub_data =  self.dataset.data[self.current]
+                self.minibatch._values = [sub_data[self.dataset.fields_columns[f]] for f in self.minibatch._names]
+                self.current+=self.minibatch_size
+                return self.minibatch
+
+        return ArrayDataSetIterator2(self,self.fieldNames(),1,0,0)
+
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
         class ArrayDataSetIterator(object):
             def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
@@ -1058,6 +1082,7 @@
             def __iter__(self):
                 return self
             def next(self):
+                #@todo: we suppose that MinibatchWrapAroundIterator stop the iterator
                 sub_data =  self.dataset.data[self.current:self.current+self.minibatch_size]
                 self.minibatch._values = [sub_data[:,self.dataset.fields_columns[f]] for f in self.minibatch._names]
                 self.current+=self.minibatch_size
--- a/misc.py	Tue May 27 13:23:05 2008 -0400
+++ b/misc.py	Tue May 27 13:46:03 2008 -0400
@@ -27,3 +27,13 @@
     This should run in O(n1+n2) where n1=|list1|, n2=|list2|.
     """
     return list(set.intersection(set(list1),set(list2)))
+import time
+#http://www.daniweb.com/code/snippet368.html
+def print_timing(func):
+    def wrapper(*arg):
+        t1 = time.time()
+        res = func(*arg)
+        t2 = time.time()
+        print '%s took %0.3f ms' % (func.func_name, (t2-t1)*1000.0)
+        return res
+    return wrapper
--- a/test_dataset.py	Tue May 27 13:23:05 2008 -0400
+++ b/test_dataset.py	Tue May 27 13:46:03 2008 -0400
@@ -2,6 +2,7 @@
 from dataset import *
 from math import *
 import numpy
+from misc import *
 
 def have_raised(to_eval, **var):
     have_thrown = False
@@ -463,7 +464,9 @@
     ds1 = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested
 
     ds2 = ApplyFunctionDataSet(ds1,lambda x,y,z: (x+1,y+1,z+1), ['x','y','z'],minibatch_mode=False)
-    ds3 = ApplyFunctionDataSet(ds1,lambda x,y,z: (numpy.array(x)+1,numpy.array(y)+1,numpy.array(z)+1), ['x','y','z'],minibatch_mode=True)
+    ds3 = ApplyFunctionDataSet(ds1,lambda x,y,z: (numpy.array(x)+1,numpy.array(y)+1,numpy.array(z)+1),
+                               ['x','y','z'],
+                               minibatch_mode=True)
 
     test_all(a2,ds2)
     test_all(a2,ds3)
@@ -485,10 +488,68 @@
 def test_ArrayFieldsDataSet():
     print "test_ArrayFieldsDataSet"
     raise NotImplementedError()
+
+
+def test_speed():
+    print "test_speed"
+    import time
+    a2 = numpy.random.rand(100000,400)
+    ds = ArrayDataSet(a2,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested
+    ds = ArrayDataSet(a2,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested
+    ds = ArrayDataSet(a2,{'all':slice(0,a2.shape[1],1)})
+    #assert ds==a? should this work?
+    mat = numpy.random.rand(400,100)
+    @print_timing
+    def f_array1(a):
+        a+1
+    @print_timing
+    def f_array2(a):
+        for id in range(a.shape[0]):
+#            pass
+            a[id]+1
+#            a[id]*mat
+    @print_timing
+    def f_ds(ds):
+        for ex in ds:
+#            pass
+            ex[0]+1
+#            a[id]*mat
+    @print_timing
+    def f_ds_mb1(ds,mb_size):
+        for exs in ds.minibatches(minibatch_size = mb_size):
+            for ex in exs:
+#                pass
+                ex[0]+1
+#                ex[id]*mat
+    @print_timing
+    def f_ds_mb2(ds,mb_size):
+        for exs in ds.minibatches(minibatch_size = mb_size):
+#            pass
+            exs[0]+1
+#            ex[id]*mat
+
+    f_array1(a2)
+    f_array2(a2)
+
+    f_ds(ds)
+
+    f_ds_mb1(ds,10)
+    f_ds_mb1(ds,100)
+    f_ds_mb1(ds,1000)
+    f_ds_mb1(ds,10000)
+    f_ds_mb2(ds,10)
+    f_ds_mb2(ds,100)
+    f_ds_mb2(ds,1000)
+    f_ds_mb2(ds,10000)
+
+    del a2, ds
+
 if __name__=='__main__':
     test1()
     test_LookupList()
     test_ArrayDataSet()
     test_CachedDataSet()
     test_ApplyFunctionDataSet()
+    #test_speed()
 #test pmat.py
+