changeset 192:f62a03c9d485

Redesign of StatsCollector
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Thu, 15 May 2008 12:46:21 -0400
parents e816821c1e50
children cb6b945acf5a
files statscollector.py
diffstat 1 files changed, 109 insertions(+), 25 deletions(-) [+]
line wrap: on
line diff
--- a/statscollector.py	Wed May 14 20:04:44 2008 -0400
+++ b/statscollector.py	Thu May 15 12:46:21 2008 -0400
@@ -1,34 +1,118 @@
 
-from numpy import *
+# Here is how I see stats collectors:
 
-class StatsCollector(object):
-    """A StatsCollector object is used to record performance statistics during training
-    or testing of a learner. It can be configured to measure different things and
-    accumulate the appropriate statistics. From these statistics it can be interrogated
-    to obtain performance measures of interest (such as maxima, minima, mean, standard
-    deviation, standard error, etc.). Optionally, the observations can be weighted
-    (yielded weighted mean, weighted variance, etc., where applicable). The statistics
-    that are desired can be specified among a list supported by the StatsCollector
-    class or subclass. When some statistics are requested, others become automatically
-    available (e.g., sum or mean)."""
+#    def my_stats((residue,nll),(regularizer)):
+#            mse=examplewise_mean(square_norm(residue))
+# 	         training_loss=regularizer+examplewise_sum(nll)
+#            set_names(locals())
+#            return ((residue,nll),(regularizer),(),(mse,training_loss))
+#    my_stats_collector = make_stats_collector(my_stats)
+#
+# where make_stats_collector calls my_stats(examplewise_fields, attributes) to
+# construct its update function, and figure out what are the input fields (here "residue"
+# and "nll") and input attributes (here "regularizer") it needs, and the output
+# attributes that it computes (here "mse" and "training_loss"). Remember that
+# fields are examplewise quantities, but attributes are not, in my jargon.
+# In the above example, I am highlighting that some operations done in my_stats
+# are examplewise and some are not.  I am hoping that theano Ops can do these
+# kinds of internal side-effect operations (and proper initialization of these hidden
+# variables). I expect that a StatsCollector (returned by make_stats_collector)
+# knows the following methods:
+#     stats_collector.input_fieldnames
+#     stats_collector.input_attribute_names
+#     stats_collector.output_attribute_names
+#     stats_collector.update(mini_dataset)
+#     stats_collector['mse']
+# where mini_dataset has the input_fieldnames() as fields and the input_attribute_names()
+# as attributes, and in the resulting dataset the output_attribute_names() are set to the
+# proper numeric values.
 
-    default_statistics = [mean,standard_deviation,min,max]
+
+
+import theano
+from theano import tensor as t
+from Learner import Learner
+from lookup_list import LookupList
+
+class StatsCollectorModel(AttributesHolder):
+    def __init__(self,stats_collector):
+        self.stats_collector = stats_collector
+        self.outputs = LookupList(stats_collector.output_names,[None for name in stats_collector.output_names])
+        # the statistics get initialized here
+        self.update_function = theano.function(input_attributes+input_fields,output_attributes+output_fields,linker="c|py")
+        for name,value in self.outputs.items():
+            self.__setattribute__(name,value)
+    def update(self,dataset):
+        input_fields = dataset.fields()(self.stats_collector.input_field_names)
+        input_attributes = dataset.getAttributes(self.stats_collector.input_attribute_names)
+        self.outputs._values = self.update_function(input_attributes+input_fields)
+        for name,value in self.outputs.items():
+            self.__setattribute__(name,value)
+    def __call__(self):
+        return self.outputs
+    def attributeNames(self):
+        return self.outputs.keys()
     
-    __init__(self,n_quantities_observed, statistics=default_statistics):
-        self.n_quantities_observed=n_quantities_observed
+class StatsCollector(AttributesHolder):
+        
+    def __init__(self,input_attributes, input_fields, outputs):
+        self.input_attributes = input_attributes
+        self.input_fields = input_fields
+        self.outputs = outputs
+        self.input_attribute_names = [v.name for v in input_attributes]
+        self.input_field_names = [v.name for v in input_fields]
+        self.output_names = [v.name for v in output_attributes]
+            
+    def __call__(self,dataset=None):
+        model = StatsCollectorModel(self)
+        if dataset:
+            self.update(dataset)
+        return model
 
-    clear(self):
-        raise NotImplementedError
+if __name__ == '__main__':
+    def my_statscollector():
+        regularizer = t.scalar()
+        nll = t.matrix()
+        class_error = t.matrix()
+        total_loss = regularizer+t.examplewise_sum(nll)
+        avg_nll = t.examplewise_mean(nll)
+        avg_class_error = t.examplewise_mean(class_error)
+        for name,val in locals(): val.name = name
+        return StatsCollector([regularizer],[nll,class_error],[total_loss,avg_nll,avg_class_error])
+    
+
+
 
-    update(self,observations):
-        """The observations is a numpy vector of length n_quantities_observed. Some
-        entries can be 'missing' (with a NaN entry) and will not be counted in the
-        statistics."""
-        raise NotImplementedError
-
-    __getattr__(self, statistic)
-        """Return a particular statistic, which may be inferred from the collected statistics.
-        The argument is a string naming that statistic."""
+# OLD DESIGN:
+#
+# class StatsCollector(object):
+#     """A StatsCollector object is used to record performance statistics during training
+#     or testing of a learner. It can be configured to measure different things and
+#     accumulate the appropriate statistics. From these statistics it can be interrogated
+#     to obtain performance measures of interest (such as maxima, minima, mean, standard
+#     deviation, standard error, etc.). Optionally, the observations can be weighted
+#     (yielded weighted mean, weighted variance, etc., where applicable). The statistics
+#     that are desired can be specified among a list supported by the StatsCollector
+#     class or subclass. When some statistics are requested, others become automatically
+#     available (e.g., sum or mean)."""
+#
+#     default_statistics = [mean,standard_deviation,min,max]
+#    
+#     __init__(self,n_quantities_observed, statistics=default_statistics):
+#         self.n_quantities_observed=n_quantities_observed
+#
+#     clear(self):
+#         raise NotImplementedError
+#
+#     update(self,observations):
+#         """The observations is a numpy vector of length n_quantities_observed. Some
+#         entries can be 'missing' (with a NaN entry) and will not be counted in the
+#         statistics."""
+#         raise NotImplementedError
+#
+#     __getattr__(self, statistic)
+#         """Return a particular statistic, which may be inferred from the collected statistics.
+#         The argument is a string naming that statistic."""