Mercurial > pylearn
comparison sandbox/statscollector.py @ 426:d7611a3811f2
Moved incomplete stuff to sandbox
author | Yoshua Bengio <bengioy@iro.umontreal.ca> |
---|---|
date | Tue, 22 Jul 2008 15:20:25 -0400 |
parents | statscollector.py@4efb503fd0da |
children |
comparison
equal
deleted
inserted
replaced
425:e2b46a8f2b7b | 426:d7611a3811f2 |
---|---|
1 | |
2 # Here is how I see stats collectors: | |
3 | |
def my_stats(graph):
    """Attach the two example statistics to `graph` and return them.

    Reads `graph.residue`, `graph.regularizer` and `graph.nll`; stores the
    results as `graph.mse` and `graph.training_loss`.

    NOTE(review): `examplewise_mean`, `square_norm` and `examplewise_sum`
    are not defined in this file -- this function is a design sketch.

    Returns:
        The list `[graph.mse, graph.training_loss]`.
    """
    # Mean over examples of the squared norm of the residue.
    mean_squared_error = examplewise_mean(square_norm(graph.residue))
    graph.mse = mean_squared_error

    # Regularizer plus the negative log-likelihood summed over examples.
    summed_nll = examplewise_sum(graph.nll)
    graph.training_loss = graph.regularizer + summed_nll

    return [graph.mse, graph.training_loss]
8 | |
9 | |
10 # def my_stats(residue,nll,regularizer): | |
11 # mse=examplewise_mean(square_norm(residue)) | |
12 # training_loss=regularizer+examplewise_sum(nll) | |
13 # set_names(locals()) | |
14 # return ((residue,nll),(regularizer),(),(mse,training_loss)) | |
15 # my_stats_collector = make_stats_collector(my_stats) | |
16 # | |
17 # where make_stats_collector calls my_stats(examplewise_fields, attributes) to | |
18 # construct its update function, and figure out what are the input fields (here "residue" | |
19 # and "nll") and input attributes (here "regularizer") it needs, and the output | |
20 # attributes that it computes (here "mse" and "training_loss"). Remember that | |
21 # fields are examplewise quantities, but attributes are not, in my jargon. | |
22 # In the above example, I am highlighting that some operations done in my_stats | |
23 # are examplewise and some are not. I am hoping that theano Ops can do these | |
24 # kinds of internal side-effect operations (and proper initialization of these hidden | |
25 # variables). I expect that a StatsCollector (returned by make_stats_collector) | |
26 # knows the following methods: | |
27 # stats_collector.input_fieldnames | |
28 # stats_collector.input_attribute_names | |
29 # stats_collector.output_attribute_names | |
30 # stats_collector.update(mini_dataset) | |
31 # stats_collector['mse'] | |
32 # where mini_dataset has the input_fieldnames() as fields and the input_attribute_names() | |
33 # as attributes, and in the resulting dataset the output_attribute_names() are set to the | |
34 # proper numeric values. | |
35 | |
36 | |
37 | |
38 import theano | |
39 from theano import tensor as t | |
40 from Learner import Learner | |
41 from lookup_list import LookupList | |
42 | |
class StatsCollectorModel(AttributesHolder):
    """Current values of the statistics described by a StatsCollector.

    The model compiles the collector's symbolic outputs into a Theano
    function and, after every `update`, exposes each statistic both through
    `self.outputs` (a LookupList keyed by output name) and as an instance
    attribute of the same name.

    NOTE(review): `AttributesHolder` is not imported in the visible part of
    this file; it must be brought into scope for the class to be defined.
    """

    def __init__(self, stats_collector):
        """Compile the update function for `stats_collector`.

        Args:
            stats_collector: object providing `input_attributes`,
                `input_fields`, `outputs` (symbolic variables) and
                `output_names` (their names), as StatsCollector does.
        """
        self.stats_collector = stats_collector
        output_names = stats_collector.output_names
        # Every statistic starts out as None until the first update().
        self.outputs = LookupList(output_names, [None for name in output_names])
        # Inputs are ordered attributes first, then fields; update() must
        # pass its values in the same order.
        self.update_function = theano.function(
            list(stats_collector.input_attributes) + list(stats_collector.input_fields),
            stats_collector.outputs,
            linker="c|py")
        self._publish_outputs()

    def _publish_outputs(self):
        """Mirror every (name, value) pair of `self.outputs` as an attribute."""
        for name, value in self.outputs.items():
            setattr(self, name, value)

    def update(self, dataset):
        """Recompute the statistics from `dataset` and publish the new values.

        Args:
            dataset: object whose `fields()` result can be called with a list
                of field names and which provides `getAttributes(names)`.
        """
        input_fields = dataset.fields()(self.stats_collector.input_field_names)
        input_attributes = dataset.getAttributes(self.stats_collector.input_attribute_names)
        # The compiled function takes one positional argument per input.
        # NOTE(review): assumes attributes + fields concatenates into a
        # sequence that iterates over the values -- confirm for LookupList.
        self.outputs._values = self.update_function(*(input_attributes + input_fields))
        self._publish_outputs()

    def __call__(self):
        """Return the LookupList holding the current statistics."""
        return self.outputs

    def attributeNames(self):
        """Return the names of the statistics held by this model."""
        return self.outputs.keys()
61 | |
class StatsCollector(AttributesHolder):
    """Symbolic description of a set of statistics.

    Stores the symbolic input attributes, input fields and outputs, together
    with their names. Calling the collector builds a StatsCollectorModel that
    holds the actual values.

    NOTE(review): `AttributesHolder` is not imported in the visible part of
    this file; it must be brought into scope for the class to be defined.
    """

    def __init__(self, input_attributes, input_fields, outputs):
        """Record the symbolic variables and derive their name lists.

        Args:
            input_attributes: iterable of variables, each with a `.name`.
            input_fields: iterable of variables, each with a `.name`.
            outputs: iterable of variables, each with a `.name`.
        """
        self.input_attributes = input_attributes
        self.input_fields = input_fields
        self.outputs = outputs
        self.input_attribute_names = [v.name for v in input_attributes]
        self.input_field_names = [v.name for v in input_fields]
        # Names come from the `outputs` parameter (there is no separate
        # `output_attributes` argument).
        self.output_names = [v.name for v in outputs]

    def __call__(self, dataset=None):
        """Create a StatsCollectorModel, optionally updated with `dataset`.

        Args:
            dataset: if not None, passed to the new model's `update`.

        Returns:
            The newly created StatsCollectorModel.
        """
        model = StatsCollectorModel(self)
        # Explicit None test: do not rely on the dataset's truthiness.
        if dataset is not None:
            # update() lives on the model, not on the collector.
            model.update(dataset)
        return model
77 | |
if __name__ == '__main__':
    def my_statscollector():
        """Build an example StatsCollector.

        Inputs are one attribute (`regularizer`) and two fields (`nll`,
        `class_error`); outputs are `total_loss`, `avg_nll` and
        `avg_class_error`.

        NOTE(review): `t.examplewise_sum` / `t.examplewise_mean` are taken
        from the imported `theano.tensor` alias; confirm they exist there.
        """
        # Symbolic inputs.
        regularizer = t.scalar()
        nll = t.matrix()
        class_error = t.matrix()

        # Symbolic statistics derived from the inputs.
        total_loss = regularizer + t.examplewise_sum(nll)
        avg_nll = t.examplewise_mean(nll)
        avg_class_error = t.examplewise_mean(class_error)

        # Give every symbolic variable the name of the local that holds it.
        for label, variable in (
                ('regularizer', regularizer),
                ('nll', nll),
                ('class_error', class_error),
                ('total_loss', total_loss),
                ('avg_nll', avg_nll),
                ('avg_class_error', avg_class_error)):
            variable.name = label

        return StatsCollector(
            [regularizer],
            [nll, class_error],
            [total_loss, avg_nll, avg_class_error])
88 | |
89 | |
90 | |
91 | |
92 # OLD DESIGN: | |
93 # | |
94 # class StatsCollector(object): | |
95 # """A StatsCollector object is used to record performance statistics during training | |
96 # or testing of a learner. It can be configured to measure different things and | |
97 # accumulate the appropriate statistics. From these statistics it can be interrogated | |
98 # to obtain performance measures of interest (such as maxima, minima, mean, standard | |
99 # deviation, standard error, etc.). Optionally, the observations can be weighted | |
100 # (yielding weighted mean, weighted variance, etc., where applicable). The statistics | |
101 # that are desired can be specified among a list supported by the StatsCollector | |
102 # class or subclass. When some statistics are requested, others become automatically | |
103 # available (e.g., sum or mean).""" | |
104 # | |
105 # default_statistics = [mean,standard_deviation,min,max] | |
106 # | |
107 # __init__(self,n_quantities_observed, statistics=default_statistics): | |
108 # self.n_quantities_observed=n_quantities_observed | |
109 # | |
110 # clear(self): | |
111 # raise NotImplementedError | |
112 # | |
113 # update(self,observations): | |
114 # """`observations` is a numpy vector of length n_quantities_observed. Some | |
115 # entries can be 'missing' (with a NaN entry) and will not be counted in the | |
116 # statistics.""" | |
117 # raise NotImplementedError | |
118 # | |
119 # __getattr__(self, statistic) | |
120 # """Return a particular statistic, which may be inferred from the collected statistics. | |
121 # The argument is a string naming that statistic.""" | |
122 | |
123 | |
124 | |
125 | |
126 | |
127 |