pylearn: changeset 192:f62a03c9d485
Redesign of StatsCollector
author | Yoshua Bengio <bengioy@iro.umontreal.ca> |
---|---|
date | Thu, 15 May 2008 12:46:21 -0400 |
parents | e816821c1e50 |
children | cb6b945acf5a |
files | statscollector.py |
diffstat | 1 files changed, 109 insertions(+), 25 deletions(-) |
--- a/statscollector.py	Wed May 14 20:04:44 2008 -0400
+++ b/statscollector.py	Thu May 15 12:46:21 2008 -0400
@@ -1,34 +1,118 @@
-from numpy import *
+# Here is how I see stats collectors:
-class StatsCollector(object):
-    """A StatsCollector object is used to record performance statistics during training
-    or testing of a learner. It can be configured to measure different things and
-    accumulate the appropriate statistics. From these statistics it can be interrogated
-    to obtain performance measures of interest (such as maxima, minima, mean, standard
-    deviation, standard error, etc.). Optionally, the observations can be weighted
-    (yielded weighted mean, weighted variance, etc., where applicable). The statistics
-    that are desired can be specified among a list supported by the StatsCollector
-    class or subclass. When some statistics are requested, others become automatically
-    available (e.g., sum or mean)."""
+# def my_stats((residue,nll),(regularizer)):
+#        mse=examplewise_mean(square_norm(residue))
+#        training_loss=regularizer+examplewise_sum(nll)
+#        set_names(locals())
+#        return ((residue,nll),(regularizer),(),(mse,training_loss))
+# my_stats_collector = make_stats_collector(my_stats)
+#
+# where make_stats_collector calls my_stats(examplewise_fields, attributes) to
+# construct its update function, and figure out what are the input fields (here "residue"
+# and "nll") and input attributes (here "regularizer") it needs, and the output
+# attributes that it computes (here "mse" and "training_loss"). Remember that
+# fields are examplewise quantities, but attributes are not, in my jargon.
+# In the above example, I am highlighting that some operations done in my_stats
+# are examplewise and some are not. I am hoping that theano Ops can do these
+# kinds of internal side-effect operations (and proper initialization of these hidden
+# variables). I expect that a StatsCollector (returned by make_stats_collector)
+# knows the following methods:
+#    stats_collector.input_fieldnames
+#    stats_collector.input_attribute_names
+#    stats_collector.output_attribute_names
+#    stats_collector.update(mini_dataset)
+#    stats_collector['mse']
+# where mini_dataset has the input_fieldnames() as fields and the input_attribute_names()
+# as attributes, and in the resulting dataset the output_attribute_names() are set to the
+# proper numeric values.
-    default_statistics = [mean,standard_deviation,min,max]
+
+
+import theano
+from theano import tensor as t
+from Learner import Learner
+from lookup_list import LookupList
+
+class StatsCollectorModel(AttributesHolder):
+    def __init__(self,stats_collector):
+        self.stats_collector = stats_collector
+        self.outputs = LookupList(stats_collector.output_names,[None for name in stats_collector.output_names])
+        # the statistics get initialized here
+        self.update_function = theano.function(input_attributes+input_fields,output_attributes+output_fields,linker="c|py")
+        for name,value in self.outputs.items():
+            self.__setattribute__(name,value)
+    def update(self,dataset):
+        input_fields = dataset.fields()(self.stats_collector.input_field_names)
+        input_attributes = dataset.getAttributes(self.stats_collector.input_attribute_names)
+        self.outputs._values = self.update_function(input_attributes+input_fields)
+        for name,value in self.outputs.items():
+            self.__setattribute__(name,value)
+    def __call__(self):
+        return self.outputs
+    def attributeNames(self):
+        return self.outputs.keys()
-    __init__(self,n_quantities_observed, statistics=default_statistics):
-        self.n_quantities_observed=n_quantities_observed
+class StatsCollector(AttributesHolder):
+
+    def __init__(self,input_attributes, input_fields, outputs):
+        self.input_attributes = input_attributes
+        self.input_fields = input_fields
+        self.outputs = outputs
+        self.input_attribute_names = [v.name for v in input_attributes]
+        self.input_field_names = [v.name for v in input_fields]
+        self.output_names = [v.name for v in output_attributes]
+
+    def __call__(self,dataset=None):
+        model = StatsCollectorModel(self)
+        if dataset:
+            self.update(dataset)
+        return model
-    clear(self):
-        raise NotImplementedError
+if __name__ == '__main__':
+    def my_statscollector():
+        regularizer = t.scalar()
+        nll = t.matrix()
+        class_error = t.matrix()
+        total_loss = regularizer+t.examplewise_sum(nll)
+        avg_nll = t.examplewise_mean(nll)
+        avg_class_error = t.examplewise_mean(class_error)
+        for name,val in locals(): val.name = name
+        return StatsCollector([regularizer],[nll,class_error],[total_loss,avg_nll,avg_class_error])
+
+
+
-    update(self,observations):
-        """The observations is a numpy vector of length n_quantities_observed. Some
-        entries can be 'missing' (with a NaN entry) and will not be counted in the
-        statistics."""
-        raise NotImplementedError
-
-    __getattr__(self, statistic)
-        """Return a particular statistic, which may be inferred from the collected statistics.
-        The argument is a string naming that statistic."""
+# OLD DESIGN:
+#
+# class StatsCollector(object):
+#     """A StatsCollector object is used to record performance statistics during training
+#     or testing of a learner. It can be configured to measure different things and
+#     accumulate the appropriate statistics. From these statistics it can be interrogated
+#     to obtain performance measures of interest (such as maxima, minima, mean, standard
+#     deviation, standard error, etc.). Optionally, the observations can be weighted
+#     (yielded weighted mean, weighted variance, etc., where applicable). The statistics
+#     that are desired can be specified among a list supported by the StatsCollector
+#     class or subclass. When some statistics are requested, others become automatically
+#     available (e.g., sum or mean)."""
+#
+#     default_statistics = [mean,standard_deviation,min,max]
+#
+#     __init__(self,n_quantities_observed, statistics=default_statistics):
+#         self.n_quantities_observed=n_quantities_observed
+#
+#     clear(self):
+#         raise NotImplementedError
+#
+#     update(self,observations):
+#         """The observations is a numpy vector of length n_quantities_observed. Some
+#         entries can be 'missing' (with a NaN entry) and will not be counted in the
+#         statistics."""
+#         raise NotImplementedError
+#
+#     __getattr__(self, statistic)
+#         """Return a particular statistic, which may be inferred from the collected statistics.
+#         The argument is a string naming that statistic."""
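Reviewer note: to make the contract described in the design comment concrete, here is a small NumPy-only sketch of a collector exposing `input_fieldnames`, `input_attribute_names`, `output_attribute_names`, `update(...)` and item access. `SimpleStatsCollector`, `my_stats` and the field/attribute names are illustrative only, and `update()` takes a plain dict in place of the `mini_dataset` object described above; this is not the pylearn API.

```python
# Minimal sketch of the StatsCollector contract sketched in the comment block.
# All names here are illustrative, not part of pylearn.
import numpy as np

class SimpleStatsCollector(object):
    """Computes named statistics from examplewise fields and dataset-level attributes."""
    def __init__(self, input_field_names, input_attribute_names, compute):
        # `compute` maps (fields, attributes) dicts to a dict of output attributes.
        self.input_fieldnames = input_field_names
        self.input_attribute_names = input_attribute_names
        self._compute = compute
        self._outputs = {}

    @property
    def output_attribute_names(self):
        return list(self._outputs.keys())

    def update(self, mini_dataset):
        fields = dict((name, np.asarray(mini_dataset[name]))
                      for name in self.input_fieldnames)
        attributes = dict((name, mini_dataset[name])
                          for name in self.input_attribute_names)
        self._outputs = self._compute(fields, attributes)

    def __getitem__(self, name):
        return self._outputs[name]


def my_stats(fields, attributes):
    # Fields are examplewise (reduced here); attributes are not.
    mse = np.mean(np.sum(fields["residue"] ** 2, axis=1))
    training_loss = attributes["regularizer"] + np.sum(fields["nll"])
    return {"mse": mse, "training_loss": training_loss}


my_stats_collector = SimpleStatsCollector(["residue", "nll"], ["regularizer"], my_stats)
my_stats_collector.update({"residue": [[0.5, -0.5], [1.0, 0.0]],
                           "nll": [[0.7], [0.3]],
                           "regularizer": 0.01})
print(my_stats_collector["mse"], my_stats_collector["training_loss"])
```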
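Reviewer note on the new code: as committed, `StatsCollectorModel.__init__` references names that are never defined (`input_attributes`, `input_fields`, `output_attributes`, `output_fields`); `self.__setattribute__` does not exist as an object method (the built-in hook is `__setattr__`, or simply `setattr`); `StatsCollector.__init__` builds `output_names` from the undefined `output_attributes` rather than `outputs`; `StatsCollector.__call__` calls `self.update` although `update` is defined on the model; and the naming loop in `my_statscollector` iterates `locals()` (keys only) rather than `locals().items()`. The sketch below is one hedged way those pieces could fit together, not the pylearn code: it passes attribute and field values directly rather than through a `DataSet` (whose `fields()`/`getAttributes()` interface lives outside this file), and it uses Theano's `sum`/`mean` as stand-ins for the hoped-for `examplewise_sum`/`examplewise_mean` Ops.

```python
# Hedged sketch of the redesigned collector with the apparent bugs fixed.
import theano
from theano import tensor as t

class StatsCollectorModel(object):
    def __init__(self, stats_collector):
        sc = stats_collector
        self.stats_collector = sc
        # Compile one Theano function from the collector's symbolic graph:
        # inputs are the dataset-level attributes followed by the examplewise
        # fields, outputs are the computed statistics.
        self.update_function = theano.function(sc.input_attributes + sc.input_fields,
                                               sc.outputs)
        self.output_values = dict((name, None) for name in sc.output_names)

    def update(self, attribute_values, field_values):
        # Values are passed positionally, in the same order as the symbolic inputs.
        results = self.update_function(*(list(attribute_values) + list(field_values)))
        for name, value in zip(self.stats_collector.output_names, results):
            self.output_values[name] = value
            setattr(self, name, value)   # the committed __setattribute__ does not exist

class StatsCollector(object):
    def __init__(self, input_attributes, input_fields, outputs):
        self.input_attributes = input_attributes
        self.input_fields = input_fields
        self.outputs = outputs
        self.input_attribute_names = [v.name for v in input_attributes]
        self.input_field_names = [v.name for v in input_fields]
        self.output_names = [v.name for v in outputs]   # was the undefined output_attributes

    def __call__(self, attribute_values=None, field_values=None):
        model = StatsCollectorModel(self)
        if attribute_values is not None:
            model.update(attribute_values, field_values)  # update lives on the model
        return model

def my_statscollector():
    regularizer = t.dscalar('regularizer')
    nll = t.dmatrix('nll')
    class_error = t.dmatrix('class_error')
    total_loss = regularizer + nll.sum()   # stand-in for examplewise_sum
    avg_nll = nll.mean()                   # stand-in for examplewise_mean
    avg_class_error = class_error.mean()
    # Name the outputs explicitly; the committed loop over locals() yields only keys.
    total_loss.name, avg_nll.name, avg_class_error.name = \
        'total_loss', 'avg_nll', 'avg_class_error'
    return StatsCollector([regularizer], [nll, class_error],
                          [total_loss, avg_nll, avg_class_error])

# Example: compile once, then evaluate on a minibatch.
model = my_statscollector()([0.01], [[[0.7], [0.3]], [[0.0], [1.0]]])
print(model.output_values)
```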
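For reference, the OLD DESIGN retained in the comment block can be made runnable along these lines. The running-sum implementation below is an assumption (the original left `clear`/`update` as `NotImplementedError`), `LegacyStatsCollector` is an illustrative name, and only mean, standard deviation, min and max are shown; NaN entries are treated as missing, as the old docstring specifies.

```python
# Sketch of the old-style collector: per-quantity running statistics with NaN = missing.
import numpy as np

class LegacyStatsCollector(object):
    def __init__(self, n_quantities_observed):
        self.n_quantities_observed = n_quantities_observed
        self.clear()

    def clear(self):
        n = self.n_quantities_observed
        self._count = np.zeros(n)        # number of non-missing observations
        self._sum = np.zeros(n)
        self._sum_sq = np.zeros(n)
        self._min = np.full(n, np.inf)
        self._max = np.full(n, -np.inf)

    def update(self, observations):
        obs = np.asarray(observations, dtype=float)
        present = ~np.isnan(obs)         # NaN entries are 'missing' and not counted
        vals = np.where(present, obs, 0.0)
        self._count += present
        self._sum += vals
        self._sum_sq += vals ** 2
        self._min = np.where(present, np.minimum(self._min, vals), self._min)
        self._max = np.where(present, np.maximum(self._max, vals), self._max)

    @property
    def mean(self):
        return self._sum / self._count

    @property
    def standard_deviation(self):
        return np.sqrt(self._sum_sq / self._count - self.mean ** 2)

    @property
    def min(self):
        return self._min

    @property
    def max(self):
        return self._max

collector = LegacyStatsCollector(2)
collector.update([1.0, np.nan])
collector.update([3.0, 4.0])
print(collector.mean)                    # [2.0, 4.0]
```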