changeset 557:c49a244bbfb5

Pass the WEIRD_STUFF flag as an init argument to allow the creation of a test.
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Mon, 01 Dec 2008 16:12:47 -0500
parents 81b225224880
children b58e71878bb5
files pylearn/algorithms/sgd.py
diffstat 1 files changed, 7 insertions(+), 8 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/algorithms/sgd.py	Mon Dec 01 15:50:50 2008 -0500
+++ b/pylearn/algorithms/sgd.py	Mon Dec 01 16:12:47 2008 -0500
@@ -6,17 +6,16 @@
 
 from .minimizer import minimizer_factory
 
-WEIRD_STUFF = True
-
 class StochasticGradientDescent(module.FancyModule):
     """Fixed stepsize gradient descent"""
-    def __init__(self, args, cost, params, gradients=None, stepsize=None):
+    def __init__(self, args, cost, params, gradients=None, stepsize=None, WEIRD_STUFF=True):
         """
         :param stepsize: the step to take in (negative) gradient direction
         :type stepsize: None, scalar value, or scalar TensorResult
         """
         super(StochasticGradientDescent, self).__init__()
-
+        self.WEIRD_STUFF = WEIRD_STUFF
+        print WEIRD_STUFF
         self.stepsize_init = None
 
         if stepsize is None:
@@ -24,7 +23,7 @@
         elif isinstance(stepsize, T.TensorResult):
             self.stepsize = stepsize
         else:
-            if WEIRD_STUFF:
+            if self.WEIRD_STUFF:
                 #TODO: why is this necessary? why does the else clause not work?
                 self.stepsize = module.Member(T.dscalar())
                 self.stepsize_init = stepsize
@@ -46,15 +45,15 @@
                 args, cost,
                 updates=self.updates)
     def _instance_initialize(self, obj):
-        if WEIRD_STUFF:
+        if self.WEIRD_STUFF:
             obj.stepsize = self.stepsize_init
         else:
             pass
 
 
 @minimizer_factory('sgd')
-def sgd_minimizer(stepsize=None):
+def sgd_minimizer(stepsize=None, **args):
     def m(i,c,p,g=None):
-        return StochasticGradientDescent(i, c, p, stepsize=stepsize)
+        return StochasticGradientDescent(i, c, p, stepsize=stepsize, **args)
     return m