comparison statscollector.py @ 192:f62a03c9d485

Redesign of StatsCollector
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Thu, 15 May 2008 12:46:21 -0400
parents 2cd82666b9a7
children 50a8302addaf
comparison
equal deleted inserted replaced
191:e816821c1e50 192:f62a03c9d485
1 1
2 from numpy import * 2 # Here is how I see stats collectors:
3 3
4 class StatsCollector(object): 4 # def my_stats((residue,nll),(regularizer)):
5 """A StatsCollector object is used to record performance statistics during training 5 # mse=examplewise_mean(square_norm(residue))
6 or testing of a learner. It can be configured to measure different things and 6 # training_loss=regularizer+examplewise_sum(nll)
7 accumulate the appropriate statistics. From these statistics it can be interrogated 7 # set_names(locals())
8 to obtain performance measures of interest (such as maxima, minima, mean, standard 8 # return ((residue,nll),(regularizer),(),(mse,training_loss))
9 deviation, standard error, etc.). Optionally, the observations can be weighted 9 # my_stats_collector = make_stats_collector(my_stats)
10 (yielded weighted mean, weighted variance, etc., where applicable). The statistics 10 #
11 that are desired can be specified among a list supported by the StatsCollector 11 # where make_stats_collector calls my_stats(examplewise_fields, attributes) to
12 class or subclass. When some statistics are requested, others become automatically 12 # construct its update function, and figure out what are the input fields (here "residue"
13 available (e.g., sum or mean).""" 13 # and "nll") and input attributes (here "regularizer") it needs, and the output
14 # attributes that it computes (here "mse" and "training_loss"). Remember that
15 # fields are examplewise quantities, but attributes are not, in my jargon.
16 # In the above example, I am highlighting that some operations done in my_stats
17 # are examplewise and some are not. I am hoping that theano Ops can do these
18 # kinds of internal side-effect operations (and proper initialization of these hidden
19 # variables). I expect that a StatsCollector (returned by make_stats_collector)
20 # knows the following methods:
21 # stats_collector.input_field_names
22 # stats_collector.input_attribute_names
23 # stats_collector.output_attribute_names
24 # stats_collector.update(mini_dataset)
25 # stats_collector['mse']
26 # where mini_dataset has the input_field_names() as fields and the input_attribute_names()
27 # as attributes, and in the resulting dataset the output_attribute_names() are set to the
28 # proper numeric values.
14 29
15 default_statistics = [mean,standard_deviation,min,max] 30
31
32 import theano
33 from theano import tensor as t
34 from Learner import Learner
35 from lookup_list import LookupList
36
class StatsCollectorModel(AttributesHolder):
    """Stateful model produced by a StatsCollector; holds current statistic values.

    Every output statistic of the parent StatsCollector is exposed in two ways:
    through the LookupList ``self.outputs``, and as an attribute of the same name.

    NOTE(review): ``AttributesHolder`` is not imported in the visible imports of
    this file -- confirm which module provides it.
    """

    def __init__(self, stats_collector):
        """Compile the collector's update function and initialize outputs to None.

        :param stats_collector: the StatsCollector describing inputs and outputs.
        """
        self.stats_collector = stats_collector
        self.outputs = LookupList(stats_collector.output_names,
                                  [None for name in stats_collector.output_names])
        # The statistics get initialized here. The compiled function takes the
        # symbolic input attributes followed by the input fields. This is the
        # same order in which update() passes concrete values.
        self.update_function = theano.function(
            stats_collector.input_attributes + stats_collector.input_fields,
            stats_collector.outputs,
            linker="c|py")
        self._publish_outputs()

    def update(self, dataset):
        """Run ``dataset`` through the compiled update function and refresh outputs.

        :param dataset: must provide the collector's input fields (looked up via
            ``dataset.fields()``) and input attributes (via ``dataset.getAttributes``).
        """
        input_fields = dataset.fields()(self.stats_collector.input_field_names)
        input_attributes = dataset.getAttributes(self.stats_collector.input_attribute_names)
        # Compiled theano functions take their inputs positionally, hence the unpacking.
        # NOTE(review): assumes both lookups return list-like sequences that
        # support ``+`` -- confirm against the dataset API.
        self.outputs._values = self.update_function(*(input_attributes + input_fields))
        self._publish_outputs()

    def _publish_outputs(self):
        """Mirror every entry of ``self.outputs`` as an attribute of the same name."""
        for name, value in self.outputs.items():
            setattr(self, name, value)

    def __call__(self):
        """Return the LookupList of current statistic values."""
        return self.outputs

    def attributeNames(self):
        """Return the names of the statistics exposed as attributes."""
        return self.outputs.keys()
16 55
class StatsCollector(AttributesHolder):
    """Description of a set of statistics computed from symbolic theano variables.

    :param input_attributes: symbolic variables for dataset attributes
        (quantities that are not examplewise) read by the statistics.
    :param input_fields: symbolic variables for dataset fields
        (examplewise quantities) read by the statistics.
    :param outputs: symbolic variables computing the statistics.

    Every variable must carry a ``name``. These names are used to look inputs up
    in a dataset and to label the outputs. Calling a collector builds a fresh
    StatsCollectorModel, optionally updated with a first dataset.

    NOTE(review): ``AttributesHolder`` is not imported in the visible imports of
    this file -- confirm which module provides it.
    """

    def __init__(self, input_attributes, input_fields, outputs):
        self.input_attributes = input_attributes
        self.input_fields = input_fields
        self.outputs = outputs
        self.input_attribute_names = [v.name for v in input_attributes]
        self.input_field_names = [v.name for v in input_fields]
        # Output names label the statistics held by each StatsCollectorModel.
        self.output_names = [v.name for v in outputs]

    def __call__(self, dataset=None):
        """Return a new StatsCollectorModel, updated with ``dataset`` if given."""
        model = StatsCollectorModel(self)
        # Explicit None test: a dataset object might evaluate as falsy (e.g. via
        # __len__) and would then be silently ignored.
        if dataset is not None:
            # update() lives on the model, which owns the statistic values.
            model.update(dataset)
        return model
19 71
if __name__ == '__main__':
    def my_statscollector():
        """Example collector: total regularized loss, average NLL and class error."""
        regularizer = t.scalar()
        nll = t.matrix()
        class_error = t.matrix()
        # NOTE(review): t.examplewise_sum / t.examplewise_mean are the ops
        # proposed in the design notes at the top of this file; confirm that
        # they exist in theano.tensor.
        total_loss = regularizer + t.examplewise_sum(nll)
        avg_nll = t.examplewise_mean(nll)
        avg_class_error = t.examplewise_mean(class_error)
        # Name every symbolic variable after its local name, so StatsCollector
        # can look inputs up by name and label its outputs. Iterating a dict
        # yields only keys, so .items() is required. The loop works on a
        # snapshot because it binds new locals itself.
        for name, val in list(locals().items()):
            val.name = name
        return StatsCollector([regularizer], [nll, class_error],
                              [total_loss, avg_nll, avg_class_error])
    # NOTE(review): my_statscollector() is defined here but never called.
82
22 83
23 update(self,observations):
24 """The observations is a numpy vector of length n_quantities_observed. Some
25 entries can be 'missing' (with a NaN entry) and will not be counted in the
26 statistics."""
27 raise NotImplementedError
28 84
29 __getattr__(self, statistic) 85
30 """Return a particular statistic, which may be inferred from the collected statistics. 86 # OLD DESIGN:
31 The argument is a string naming that statistic.""" 87 #
88 # class StatsCollector(object):
89 # """A StatsCollector object is used to record performance statistics during training
90 # or testing of a learner. It can be configured to measure different things and
91 # accumulate the appropriate statistics. From these statistics it can be interrogated
92 # to obtain performance measures of interest (such as maxima, minima, mean, standard
93 # deviation, standard error, etc.). Optionally, the observations can be weighted
94 # (yielding weighted mean, weighted variance, etc., where applicable). The statistics
95 # that are desired can be specified among a list supported by the StatsCollector
96 # class or subclass. When some statistics are requested, others become automatically
97 # available (e.g., sum or mean)."""
98 #
99 # default_statistics = [mean,standard_deviation,min,max]
100 #
101 # __init__(self,n_quantities_observed, statistics=default_statistics):
102 # self.n_quantities_observed=n_quantities_observed
103 #
104 # clear(self):
105 # raise NotImplementedError
106 #
107 # update(self,observations):
108 # """The observations is a numpy vector of length n_quantities_observed. Some
109 # entries can be 'missing' (with a NaN entry) and will not be counted in the
110 # statistics."""
111 # raise NotImplementedError
112 #
113 # __getattr__(self, statistic)
114 # """Return a particular statistic, which may be inferred from the collected statistics.
115 # The argument is a string naming that statistic."""
32 116
33 117
34 118
35 119
36 120