Mercurial > pylearn
comparison pylearn/gd/sgd.py @ 1462:7c72948e1d53
Catch only Exception so that Ctrl-C (KeyboardInterrupt) is not caught.
author | Frederic Bastien <nouiz@nouiz.org> |
---|---|
date | Fri, 08 Apr 2011 14:03:09 -0400 |
parents | 86bf03990aad |
children | cac29ca79a74 |
comparison
equal
deleted
inserted
replaced
1461:2aa80f5b5bbc | 1462:7c72948e1d53 |
---|---|
14 :param stepsizes: step by this amount times the negative gradient on each iteration | 14 :param stepsizes: step by this amount times the negative gradient on each iteration |
15 :type stepsizes: [symbolic] scalar or list of one [symbolic] scalar per param | 15 :type stepsizes: [symbolic] scalar or list of one [symbolic] scalar per param |
16 """ | 16 """ |
17 try: | 17 try: |
18 iter(stepsizes) | 18 iter(stepsizes) |
19 except: | 19 except Exception: |
20 stepsizes = [stepsizes for p in params] | 20 stepsizes = [stepsizes for p in params] |
21 if len(params) != len(grads): | 21 if len(params) != len(grads): |
22 raise ValueError('params and grads have different lens') | 22 raise ValueError('params and grads have different lens') |
23 updates = [(p, p - step * gp) for (step, p, gp) in zip(stepsizes, params, grads)] | 23 updates = [(p, p - step * gp) for (step, p, gp) in zip(stepsizes, params, grads)] |
24 return updates | 24 return updates |
25 | 25 |
26 def sgd_momentum_updates(params, grads, stepsizes, momentum=0.9): | 26 def sgd_momentum_updates(params, grads, stepsizes, momentum=0.9): |
27 # if stepsizes is just a scalar, expand it to match params | 27 # if stepsizes is just a scalar, expand it to match params |
28 try: | 28 try: |
29 iter(stepsizes) | 29 iter(stepsizes) |
30 except: | 30 except Exception: |
31 stepsizes = [stepsizes for p in params] | 31 stepsizes = [stepsizes for p in params] |
32 try: | 32 try: |
33 iter(momentum) | 33 iter(momentum) |
34 except: | 34 except Exception: |
35 momentum = [momentum for p in params] | 35 momentum = [momentum for p in params] |
36 if len(params) != len(grads): | 36 if len(params) != len(grads): |
37 raise ValueError('params and grads have different lens') | 37 raise ValueError('params and grads have different lens') |
38 headings = [theano.shared(p.get_value(borrow=False)*0) for p in params] | 38 headings = [theano.shared(p.get_value(borrow=False)*0) for p in params] |
39 updates = [] | 39 updates = [] |