# HG changeset patch # User Frederic Bastien # Date 1210879260 14400 # Node ID 80731832c62bdb8772b5d2368ba8a196c5295df3 # Parent cb6b945acf5a1c9e06f36435a38705867dda7e70 # Parent b9950ae5e54bc7909593be5efb31f9a1f220bcba Automated merge with ssh://p-omega1@lgcm.iro.umontreal.ca/tlearn diff -r cb6b945acf5a -r 80731832c62b dataset.py --- a/dataset.py Thu May 15 12:55:21 2008 -0400 +++ b/dataset.py Thu May 15 15:21:00 2008 -0400 @@ -1111,6 +1111,8 @@ given function example-wise or minibatch-wise to all the fields of an input dataset. The output of the function should be an iterable (e.g. a list or a LookupList) over the resulting values. + + The function takes as input the fields of the dataset, not the examples. In minibatch mode, the function is expected to work on minibatches (takes a minibatch in input and returns a minibatch in output). More @@ -1170,7 +1172,7 @@ function_outputs = self.output_dataset.function(*function_inputs) else: input_examples = zip(*function_inputs) - output_examples = [self.output_dataset.function(input_example) + output_examples = [self.output_dataset.function(*input_example) for input_example in input_examples] function_outputs = [self.output_dataset.valuesVStack(name,values) for name,values in zip(all_output_names, @@ -1190,11 +1192,14 @@ self.input_iterator=output_dataset.input_dataset.__iter__() def __iter__(self): return self def next(self): - function_inputs = self.input_iterator.next() if self.output_dataset.minibatch_mode: - function_outputs = [output[0] for output in self.output_dataset.function(function_inputs)] + function_inputs = [[input] for input in self.input_iterator.next()] + outputs = self.output_dataset.function(*function_inputs) + assert all([hasattr(output,'__iter__') for output in outputs]) + function_outputs = [output[0] for output in outputs] else: - function_outputs = self.output_dataset.function(function_inputs) + function_inputs = self.input_iterator.next() + function_outputs = 
self.output_dataset.function(*function_inputs) return Example(self.output_dataset.output_names,function_outputs) return ApplyFunctionSingleExampleIterator(self) diff -r cb6b945acf5a -r 80731832c62b test_dataset.py --- a/test_dataset.py Thu May 15 12:55:21 2008 -0400 +++ b/test_dataset.py Thu May 15 15:21:00 2008 -0400 @@ -394,6 +394,13 @@ assert len(ds('y').fields()) == 1 del field +def test_all(array,ds): + assert len(ds)==10 + + test_iterate_over_examples(array, ds) + test_getitem(array, ds) + test_ds_iterator(array,ds('x','y'),ds('y','z'),ds('x','y','z')) + test_fields_fct(ds) def test_ArrayDataSet(): #don't test stream @@ -406,13 +413,9 @@ a2 = numpy.random.rand(10,4) ds = ArrayDataSet(a2,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested ds = ArrayDataSet(a2,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested - assert len(ds)==10 #assert ds==a? should this work? - test_iterate_over_examples(a2, ds) - test_getitem(a2, ds) - test_ds_iterator(a2,ds('x','y'),ds('y','z'),ds('x','y','z')) - test_fields_fct(ds) + test_all(a2,ds) del a2, ds @@ -442,18 +445,9 @@ ds1 = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested ds2 = CachedDataSet(ds1) ds3 = CachedDataSet(ds1,cache_all_upon_construction=True) - assert len(ds2)==10 - assert len(ds3)==10 - test_iterate_over_examples(a, ds2) - test_getitem(a, ds2) - test_ds_iterator(a,ds2('x','y'),ds2('y','z'),ds2('x','y','z')) - test_fields_fct(ds2) - - test_iterate_over_examples(a, ds3) - test_getitem(a, ds3) - test_ds_iterator(a,ds3('x','y'),ds3('y','z'),ds3('x','y','z')) - test_fields_fct(ds3) + test_all(a,ds2) + test_all(a,ds3) del a,ds1,ds2,ds3 @@ -464,7 +458,18 @@ def test_ApplyFunctionDataSet(): print "test_ApplyFunctionDataSet" - raise NotImplementedError() + a = numpy.random.rand(10,4) + a2 = a+1 + ds1 = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested + + ds2 = ApplyFunctionDataSet(ds1,lambda x,y,z: (x+1,y+1,z+1), 
['x','y','z'],minibatch_mode=False) + ds3 = ApplyFunctionDataSet(ds1,lambda x,y,z: (numpy.array(x)+1,numpy.array(y)+1,numpy.array(z)+1), ['x','y','z'],minibatch_mode=True) + + test_all(a2,ds2) + test_all(a2,ds3) + + del a,ds1,ds2,ds3 + def test_FieldsSubsetDataSet(): print "test_FieldsSubsetDataSet" raise NotImplementedError() @@ -485,4 +490,5 @@ test_LookupList() test_ArrayDataSet() test_CachedDataSet() + test_ApplyFunctionDataSet() #test pmat.py