Mercurial > pylearn
comparison statscollector.py @ 192:f62a03c9d485
Redesign of StatsCollector
author | Yoshua Bengio <bengioy@iro.umontreal.ca> |
---|---|
date | Thu, 15 May 2008 12:46:21 -0400 |
parents | 2cd82666b9a7 |
children | 50a8302addaf |
comparison
equal
deleted
inserted
replaced
191:e816821c1e50 | 192:f62a03c9d485 |
---|---|
1 | 1 |
2 from numpy import * | 2 # Here is how I see stats collectors: |
3 | 3 |
4 class StatsCollector(object): | 4 # def my_stats((residue,nll),(regularizer)): |
5 """A StatsCollector object is used to record performance statistics during training | 5 # mse=examplewise_mean(square_norm(residue)) |
6 or testing of a learner. It can be configured to measure different things and | 6 # training_loss=regularizer+examplewise_sum(nll) |
7 accumulate the appropriate statistics. From these statistics it can be interrogated | 7 # set_names(locals()) |
8 to obtain performance measures of interest (such as maxima, minima, mean, standard | 8 # return ((residue,nll),(regularizer),(),(mse,training_loss)) |
9 deviation, standard error, etc.). Optionally, the observations can be weighted | 9 # my_stats_collector = make_stats_collector(my_stats) |
10 (yielding weighted mean, weighted variance, etc., where applicable). The statistics | 10 # |
11 that are desired can be specified among a list supported by the StatsCollector | 11 # where make_stats_collector calls my_stats(examplewise_fields, attributes) to |
12 class or subclass. When some statistics are requested, others become automatically | 12 # construct its update function, and figure out what are the input fields (here "residue" |
13 available (e.g., sum or mean).""" | 13 # and "nll") and input attributes (here "regularizer") it needs, and the output |
14 # attributes that it computes (here "mse" and "training_loss"). Remember that | |
15 # fields are examplewise quantities, but attributes are not, in my jargon. | |
16 # In the above example, I am highlighting that some operations done in my_stats | |
17 # are examplewise and some are not. I am hoping that theano Ops can do these | |
18 # kinds of internal side-effect operations (and proper initialization of these hidden | |
19 # variables). I expect that a StatsCollector (returned by make_stats_collector) | |
20 # knows the following methods: | |
21 # stats_collector.input_fieldnames | |
22 # stats_collector.input_attribute_names | |
23 # stats_collector.output_attribute_names | |
24 # stats_collector.update(mini_dataset) | |
25 # stats_collector['mse'] | |
26 # where mini_dataset has the input_fieldnames() as fields and the input_attribute_names() | |
27 # as attributes, and in the resulting dataset the output_attribute_names() are set to the | |
28 # proper numeric values. | |
14 | 29 |
15 default_statistics = [mean,standard_deviation,min,max] | 30 |
31 | |
32 import theano | |
33 from theano import tensor as t | |
34 from Learner import Learner | |
35 from lookup_list import LookupList | |
36 | |
37 class StatsCollectorModel(AttributesHolder): | |
38 def __init__(self,stats_collector): | |
39 self.stats_collector = stats_collector | |
40 self.outputs = LookupList(stats_collector.output_names,[None for name in stats_collector.output_names]) | |
41 # the statistics get initialized here | |
42 self.update_function = theano.function(input_attributes+input_fields,output_attributes+output_fields,linker="c|py") | |
43 for name,value in self.outputs.items(): | |
44 self.__setattribute__(name,value) | |
45 def update(self,dataset): | |
46 input_fields = dataset.fields()(self.stats_collector.input_field_names) | |
47 input_attributes = dataset.getAttributes(self.stats_collector.input_attribute_names) | |
48 self.outputs._values = self.update_function(input_attributes+input_fields) | |
49 for name,value in self.outputs.items(): | |
50 self.__setattribute__(name,value) | |
51 def __call__(self): | |
52 return self.outputs | |
53 def attributeNames(self): | |
54 return self.outputs.keys() | |
16 | 55 |
17 __init__(self,n_quantities_observed, statistics=default_statistics): | 56 class StatsCollector(AttributesHolder): |
18 self.n_quantities_observed=n_quantities_observed | 57 |
58 def __init__(self,input_attributes, input_fields, outputs): | |
59 self.input_attributes = input_attributes | |
60 self.input_fields = input_fields | |
61 self.outputs = outputs | |
62 self.input_attribute_names = [v.name for v in input_attributes] | |
63 self.input_field_names = [v.name for v in input_fields] | |
64 self.output_names = [v.name for v in output_attributes] | |
65 | |
66 def __call__(self,dataset=None): | |
67 model = StatsCollectorModel(self) | |
68 if dataset: | |
69 self.update(dataset) | |
70 return model | |
19 | 71 |
20 clear(self): | 72 if __name__ == '__main__': |
21 raise NotImplementedError | 73 def my_statscollector(): |
74 regularizer = t.scalar() | |
75 nll = t.matrix() | |
76 class_error = t.matrix() | |
77 total_loss = regularizer+t.examplewise_sum(nll) | |
78 avg_nll = t.examplewise_mean(nll) | |
79 avg_class_error = t.examplewise_mean(class_error) | |
80 for name,val in locals(): val.name = name | |
81 return StatsCollector([regularizer],[nll,class_error],[total_loss,avg_nll,avg_class_error]) | |
82 | |
22 | 83 |
23 update(self,observations): | |
24 """The observations argument is a numpy vector of length n_quantities_observed. Some | |
25 entries can be 'missing' (with a NaN entry) and will not be counted in the | |
26 statistics.""" | |
27 raise NotImplementedError | |
28 | 84 |
29 __getattr__(self, statistic) | 85 |
30 """Return a particular statistic, which may be inferred from the collected statistics. | 86 # OLD DESIGN: |
31 The argument is a string naming that statistic.""" | 87 # |
88 # class StatsCollector(object): | |
89 # """A StatsCollector object is used to record performance statistics during training | |
90 # or testing of a learner. It can be configured to measure different things and | |
91 # accumulate the appropriate statistics. From these statistics it can be interrogated | |
92 # to obtain performance measures of interest (such as maxima, minima, mean, standard | |
93 # deviation, standard error, etc.). Optionally, the observations can be weighted | |
94 # (yielding weighted mean, weighted variance, etc., where applicable). The statistics | |
95 # that are desired can be specified among a list supported by the StatsCollector | |
96 # class or subclass. When some statistics are requested, others become automatically | |
97 # available (e.g., sum or mean).""" | |
98 # | |
99 # default_statistics = [mean,standard_deviation,min,max] | |
100 # | |
101 # __init__(self,n_quantities_observed, statistics=default_statistics): | |
102 # self.n_quantities_observed=n_quantities_observed | |
103 # | |
104 # clear(self): | |
105 # raise NotImplementedError | |
106 # | |
107 # update(self,observations): | |
108 # """The observations argument is a numpy vector of length n_quantities_observed. Some | |
109 # entries can be 'missing' (with a NaN entry) and will not be counted in the | |
110 # statistics.""" | |
111 # raise NotImplementedError | |
112 # | |
113 # __getattr__(self, statistic) | |
114 # """Return a particular statistic, which may be inferred from the collected statistics. | |
115 # The argument is a string naming that statistic.""" | |
32 | 116 |
33 | 117 |
34 | 118 |
35 | 119 |
36 | 120 |