# view sandbox/statscollector.py @ 465:8cde974b6486
#
# merge
# author Joseph Turian <turian@iro.umontreal.ca>
# date Wed, 15 Oct 2008 17:00:35 -0400
# parents d7611a3811f2
# children
# line wrap: on
# line source


# Here is how I see stats collectors:

def my_stats(graph):
    """Attach two statistics to graph and return them as a list.

    Sets graph.mse to the examplewise mean of the squared norm of
    graph.residue, and graph.training_loss to graph.regularizer plus the
    examplewise sum of graph.nll.  Returns [graph.mse, graph.training_loss].
    """
    squared_residue = square_norm(graph.residue)
    graph.mse = examplewise_mean(squared_residue)

    penalty = graph.regularizer
    graph.training_loss = penalty + examplewise_sum(graph.nll)

    collected = [graph.mse, graph.training_loss]
    return collected
    

#    def my_stats(residue,nll,regularizer):
#            mse=examplewise_mean(square_norm(residue))
#            training_loss=regularizer+examplewise_sum(nll)
#            set_names(locals())
#            return ((residue,nll),(regularizer),(),(mse,training_loss))
#    my_stats_collector = make_stats_collector(my_stats)
#
# where make_stats_collector calls my_stats(examplewise_fields, attributes) to
# construct its update function, and figures out which input fields (here "residue"
# and "nll") and input attributes (here "regularizer") it needs, and which output
# attributes it computes (here "mse" and "training_loss"). Remember that
# fields are examplewise quantities, but attributes are not, in my jargon.
# In the above example, I am highlighting that some operations done in my_stats
# are examplewise and some are not.  I am hoping that theano Ops can do these
# kinds of internal side-effect operations (and proper initialization of these hidden
# variables). I expect that a StatsCollector (returned by make_stats_collector)
# knows the following methods:
#     stats_collector.input_fieldnames
#     stats_collector.input_attribute_names
#     stats_collector.output_attribute_names
#     stats_collector.update(mini_dataset)
#     stats_collector['mse']
# where mini_dataset has the input_fieldnames as fields and the input_attribute_names
# as attributes, and in the resulting dataset the output_attribute_names are set to the
# proper numeric values.



import theano
from theano import tensor as t
from Learner import Learner
from lookup_list import LookupList

class StatsCollectorModel(AttributesHolder):
    """Current values of the statistics described by a stats collector.

    The values live in self.outputs (a LookupList keyed by the collector's
    output names) and are mirrored as attributes of this object under the
    same names.  Until update() is called every value is None.
    """

    def __init__(self,stats_collector):
        """Compile the update function and initialize all outputs to None.

        stats_collector -- object exposing input_attributes, input_fields,
        outputs and output_names; the first three are handed to
        theano.function as its inputs and outputs.
        """
        self.stats_collector = stats_collector
        self.outputs = LookupList(stats_collector.output_names,[None for name in stats_collector.output_names])
        # the statistics get initialized here
        # The symbolic inputs/outputs are read from the collector; the bare
        # names previously used here were not defined in this scope.
        self.update_function = theano.function(
            stats_collector.input_attributes+stats_collector.input_fields,
            stats_collector.outputs,
            linker="c|py")
        self._publish_outputs()

    def _publish_outputs(self):
        """Mirror every (name, value) pair of self.outputs as an attribute."""
        for name,value in self.outputs.items():
            setattr(self,name,value)

    def update(self,dataset):
        """Recompute the statistics from dataset and refresh the outputs.

        dataset -- must provide the collector's input fields and input
        attributes through fields() and getAttributes().
        """
        input_fields = dataset.fields()(self.stats_collector.input_field_names)
        input_attributes = dataset.getAttributes(self.stats_collector.input_attribute_names)
        # A compiled theano function takes one positional argument per
        # declared input, in declaration order (attributes, then fields).
        # NOTE(review): assumes the concatenation iterates over the values in
        # that order -- confirm against the dataset API.
        self.outputs._values = self.update_function(*(input_attributes+input_fields))
        self._publish_outputs()

    def __call__(self):
        """Return the LookupList holding the current output values."""
        return self.outputs

    def attributeNames(self):
        """Return the names of the output statistics."""
        return self.outputs.keys()
    
class StatsCollector(AttributesHolder):
    """Description of a set of statistics: its inputs and its outputs.

    Calling an instance builds a StatsCollectorModel that holds the actual
    values of the statistics.
    """

    def __init__(self,input_attributes, input_fields, outputs):
        """Store the variables and record their names.

        input_attributes -- sequence of input variables treated as attributes
        input_fields     -- sequence of input variables treated as fields
        outputs          -- sequence of variables to compute

        Every element must have a .name attribute; the names are kept in
        input_attribute_names, input_field_names and output_names.
        """
        self.input_attributes = input_attributes
        self.input_fields = input_fields
        self.outputs = outputs
        self.input_attribute_names = [v.name for v in input_attributes]
        self.input_field_names = [v.name for v in input_fields]
        # Names come from the `outputs` argument (the previous code iterated
        # over a name that is not defined in this scope).
        self.output_names = [v.name for v in outputs]

    def __call__(self,dataset=None):
        """Return a new StatsCollectorModel, updated with dataset if given.

        dataset -- optional; when supplied, the model's update() is called
        with it before the model is returned.
        """
        model = StatsCollectorModel(self)
        # Test for "no dataset supplied" explicitly rather than relying on
        # the truthiness of the dataset object itself.
        if dataset is not None:
            # update() lives on the model, not on the collector.
            model.update(dataset)
        return model

if __name__ == '__main__':
    def my_statscollector():
        """Build an example StatsCollector.

        Inputs: one attribute (regularizer) and two fields (nll and
        class_error).  Outputs: total_loss, avg_nll and avg_class_error.
        Each variable's .name is set to the local identifier it is bound to.
        """
        # inputs
        regularizer = t.scalar()
        nll = t.matrix()
        class_error = t.matrix()

        # outputs
        total_loss = regularizer + t.examplewise_sum(nll)
        avg_nll = t.examplewise_mean(nll)
        avg_class_error = t.examplewise_mean(class_error)

        # Name every variable defined so far after its local identifier.
        for identifier, variable in list(locals().items()):
            setattr(variable, 'name', identifier)

        attribute_inputs = [regularizer]
        field_inputs = [nll, class_error]
        computed = [total_loss, avg_nll, avg_class_error]
        return StatsCollector(attribute_inputs, field_inputs, computed)
    



# OLD DESIGN:
#
# class StatsCollector(object):
#     """A StatsCollector object is used to record performance statistics during training
#     or testing of a learner. It can be configured to measure different things and
#     accumulate the appropriate statistics. From these statistics it can be interrogated
#     to obtain performance measures of interest (such as maxima, minima, mean, standard
#     deviation, standard error, etc.). Optionally, the observations can be weighted
#     (yielding weighted mean, weighted variance, etc., where applicable). The statistics
#     that are desired can be specified among a list supported by the StatsCollector
#     class or subclass. When some statistics are requested, others become automatically
#     available (e.g., sum or mean)."""
#
#     default_statistics = [mean,standard_deviation,min,max]
#    
#     __init__(self,n_quantities_observed, statistics=default_statistics):
#         self.n_quantities_observed=n_quantities_observed
#
#     clear(self):
#         raise NotImplementedError
#
#     update(self,observations):
#         """The observations argument is a numpy vector of length n_quantities_observed. Some
#         entries can be 'missing' (with a NaN entry) and will not be counted in the
#         statistics."""
#         raise NotImplementedError
#
#     __getattr__(self, statistic)
#         """Return a particular statistic, which may be inferred from the collected statistics.
#         The argument is a string naming that statistic."""