# HG changeset patch # User Frederic Bastien # Date 1211910363 14400 # Node ID 38beb81f4e8be838a2ec7b36eeb4754ffb049951 # Parent 17c5d080964bc9505407b9da193e42a26d420cea# Parent 4d1bd2513e0625b661f483208b62f40f6035f870 Automated merge with ssh://projects@lgcm.iro.umontreal.ca/hg/pylearn diff -r 17c5d080964b -r 38beb81f4e8b dataset.py --- a/dataset.py Tue May 27 13:23:05 2008 -0400 +++ b/dataset.py Tue May 27 13:46:03 2008 -0400 @@ -1043,7 +1043,31 @@ assert key in self.__dict__ # else it means we are trying to access a non-existing property return self.__dict__[key] - + def __iter__(self): + class ArrayDataSetIterator2(object): + def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): + if fieldnames is None: fieldnames = dataset.fieldNames() + # store the resulting minibatch in a lookup-list of values + self.minibatch = LookupList(fieldnames,[0]*len(fieldnames)) + self.dataset=dataset + self.minibatch_size=minibatch_size + assert offset>=0 and offset=self.dataset.data.shape[0]: + raise StopIteration + sub_data = self.dataset.data[self.current] + self.minibatch._values = [sub_data[self.dataset.fields_columns[f]] for f in self.minibatch._names] + self.current+=self.minibatch_size + return self.minibatch + + return ArrayDataSetIterator2(self,self.fieldNames(),1,0,0) + def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): class ArrayDataSetIterator(object): def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): @@ -1058,6 +1082,7 @@ def __iter__(self): return self def next(self): + #@todo: we suppose that MinibatchWrapAroundIterator stop the iterator sub_data = self.dataset.data[self.current:self.current+self.minibatch_size] self.minibatch._values = [sub_data[:,self.dataset.fields_columns[f]] for f in self.minibatch._names] self.current+=self.minibatch_size diff -r 17c5d080964b -r 38beb81f4e8b misc.py --- a/misc.py Tue May 27 13:23:05 2008 -0400 +++ b/misc.py Tue May 27 13:46:03 2008 -0400 @@ -27,3 +27,13 @@ This should run in O(n1+n2) where n1=|list1|, n2=|list2|. """ return list(set.intersection(set(list1),set(list2))) +import time +#http://www.daniweb.com/code/snippet368.html +def print_timing(func): + def wrapper(*arg): + t1 = time.time() + res = func(*arg) + t2 = time.time() + print '%s took %0.3f ms' % (func.func_name, (t2-t1)*1000.0) + return res + return wrapper diff -r 17c5d080964b -r 38beb81f4e8b test_dataset.py --- a/test_dataset.py Tue May 27 13:23:05 2008 -0400 +++ b/test_dataset.py Tue May 27 13:46:03 2008 -0400 @@ -2,6 +2,7 @@ from dataset import * from math import * import numpy +from misc import * def have_raised(to_eval, **var): have_thrown = False @@ -463,7 +464,9 @@ ds1 = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested ds2 = ApplyFunctionDataSet(ds1,lambda x,y,z: (x+1,y+1,z+1), ['x','y','z'],minibatch_mode=False) - ds3 = ApplyFunctionDataSet(ds1,lambda x,y,z: (numpy.array(x)+1,numpy.array(y)+1,numpy.array(z)+1), ['x','y','z'],minibatch_mode=True) + ds3 = ApplyFunctionDataSet(ds1,lambda x,y,z: (numpy.array(x)+1,numpy.array(y)+1,numpy.array(z)+1), + ['x','y','z'], + minibatch_mode=True) test_all(a2,ds2) test_all(a2,ds3) @@ -485,10 +488,68 @@ def test_ArrayFieldsDataSet(): print "test_ArrayFieldsDataSet" raise NotImplementedError() + + +def test_speed(): + print "test_speed" + import time + a2 = numpy.random.rand(100000,400) + ds = ArrayDataSet(a2,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested + ds = ArrayDataSet(a2,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested + ds = ArrayDataSet(a2,{'all':slice(0,a2.shape[1],1)}) + #assert ds==a? should this work? + mat = numpy.random.rand(400,100) + @print_timing + def f_array1(a): + a+1 + @print_timing + def f_array2(a): + for id in range(a.shape[0]): +# pass + a[id]+1 +# a[id]*mat + @print_timing + def f_ds(ds): + for ex in ds: +# pass + ex[0]+1 +# a[id]*mat + @print_timing + def f_ds_mb1(ds,mb_size): + for exs in ds.minibatches(minibatch_size = mb_size): + for ex in exs: +# pass + ex[0]+1 +# ex[id]*mat + @print_timing + def f_ds_mb2(ds,mb_size): + for exs in ds.minibatches(minibatch_size = mb_size): +# pass + exs[0]+1 +# ex[id]*mat + + f_array1(a2) + f_array2(a2) + + f_ds(ds) + + f_ds_mb1(ds,10) + f_ds_mb1(ds,100) + f_ds_mb1(ds,1000) + f_ds_mb1(ds,10000) + f_ds_mb2(ds,10) + f_ds_mb2(ds,100) + f_ds_mb2(ds,1000) + f_ds_mb2(ds,10000) + + del a2, ds + if __name__=='__main__': test1() test_LookupList() test_ArrayDataSet() test_CachedDataSet() test_ApplyFunctionDataSet() + #test_speed() #test pmat.py +