changeset 540:85d3300c9a9c
m
| author | James Bergstra <bergstrj@iro.umontreal.ca> |
|---|---|
| date | Thu, 13 Nov 2008 17:54:56 -0500 |
| parents | e3f84d260023 |
| children | 5b4ccbf022c8 ee5324c21e60 |
| files | pylearn/algorithms/logistic_regression.py pylearn/algorithms/rbm.py pylearn/algorithms/sgd.py pylearn/dbdict/tools.py |
| diffstat | 4 files changed, 44 insertions(+), 8 deletions(-) |
```diff
--- a/pylearn/algorithms/logistic_regression.py	Thu Nov 13 17:01:44 2008 -0500
+++ b/pylearn/algorithms/logistic_regression.py	Thu Nov 13 17:54:56 2008 -0500
@@ -72,11 +72,15 @@
         raise NotImplementedError()
     else:
         rval = LogRegN(n_in=n_in, n_out=n_out, l1=l1, l2=l2)
+        print 'RVAL input target', rval.input, rval.target
         rval.minimizer = minimizer([rval.input, rval.target], rval.regularized_cost,
                 rval.params)
-    return rval.make()
+    return rval.make(mode='FAST_RUN')
 
 #TODO: grouping parameters by prefix does not play well with providing defaults. Think...
+#FIX : Guillaume suggested a convention: plugin handlers (dataset_factory, minimizer_factory,
+# etc.) should never provide default arguments for parameters, and accept **kwargs to catch
+# irrelevant parameters.
 class _fit_logreg_defaults(object):
     minimizer_algo = 'dummy'
     #minimizer_lr = 0.001
```
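The `#FIX` comment above proposes a calling convention for plugin factories. A minimal sketch of that convention, with hypothetical factory and parameter names (nothing below is part of the changeset itself):

```python
# Sketch of the convention from the FIX comment: a factory handler declares
# the parameters it actually uses with no default values, and accepts
# **kwargs so that irrelevant keys from a flat configuration dict are
# silently absorbed rather than raising TypeError.

def minimizer_factory(minimizer_algo, minimizer_lr, **kwargs):
    # no defaults: the caller's configuration must supply both parameters
    return (minimizer_algo, minimizer_lr)

config = {
    'minimizer_algo': 'sgd',
    'minimizer_lr': 0.001,
    'dataset_name': 'mnist',   # irrelevant to this factory; caught by **kwargs
}
minimizer = minimizer_factory(**config)
```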
```diff
--- a/pylearn/algorithms/rbm.py	Thu Nov 13 17:01:44 2008 -0500
+++ b/pylearn/algorithms/rbm.py	Thu Nov 13 17:54:56 2008 -0500
@@ -1,13 +1,17 @@
 import sys, copy
 import theano
 from theano import tensor as T
-from theano.tensor import nnet
+from theano.tensor.nnet import sigmoid
 from theano.compile import module
 from theano import printing, pprint
 from theano import compile
 
 import numpy as N
 
+from ..datasets import make_dataset
+from .minimizer import make_minimizer
+from .stopper import make_stopper
+
 class RBM(module.FancyModule):
 
     # is it really necessary to pass ALL of these ? - GD
@@ -15,6 +19,7 @@
             nvis=None, nhid=None,
             input=None,
             w=None, hidb=None, visb=None):
+        super(RBM, self).__init__()
 
         # symbolic theano stuff
         # what about multidimensional inputs/outputs ? do they have to be
@@ -24,11 +29,11 @@
         self.hidb = hidb if hidb is not None else module.Member(T.dvector())
 
         # 1-step Markov chain
-        self.hid = T.sigmoid(T.dot(self.w,self.input) + self.hidb)
+        self.hid = sigmoid(T.dot(self.w,self.input) + self.hidb)
         self.hid_sample = self.hid #TODO: sample!
-        self.vis = T.sigmoid(T.dot(self.w.T, self.hid) + self.visb)
+        self.vis = sigmoid(T.dot(self.w.T, self.hid) + self.visb)
         self.vis_sample = self.vis #TODO: sample!
-        self.neg_hid = T.sigmoid(T.dot(self.w, self.vis) + self.hidb)
+        self.neg_hid = sigmoid(T.dot(self.w, self.vis) + self.hidb)
 
         # cd1 updates:
         self.params = [self.w, self.visb, self.hidb]
@@ -44,3 +49,26 @@
 
 def RBM_cd():
     pass;
+
+def train_rbm(state, channel=lambda *args, **kwargs:None):
+    dataset = make_dataset(**state.subdict(prefix='dataset_'))
+    train = dataset.train
+
+    rbm_module = RBM(
+            nvis=train.x.shape[1],
+            nhid=state.size_hidden)
+
+    batchsize = state.batchsize
+    verbose = state.verbose
+    iter = [0]
+
+    while iter[0] != state.max_iters:
+        for j in xrange(0,len(train.x)-batchsize+1,batchsize):
+            rbm.cd1(train.x[j:j+batchsize])
+            if verbose > 1:
+                print 'estimated train cost...'
+            if iter[0] == state.max_iters:
+                break
+            else:
+                iter[0] += 1
+
```
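One caveat in the new `train_rbm`: the loop body calls `rbm.cd1(...)`, but only `rbm_module` is ever bound, so as committed the loop would raise a `NameError`. Presumably the module is meant to be compiled first, in the same spirit as `rval.make(mode='FAST_RUN')` in `logistic_regression.py`. A sketch of that assumed intent, in the context of `train_rbm` above (it also assumes `RBM` defines a `cd1` Method, which would live in the elided part of the class):

```python
# Assumed fix, not part of the changeset: compile the module and drive the
# compiled instance, so that cd1() exists as a callable method.
rbm_module = RBM(nvis=train.x.shape[1], nhid=state.size_hidden)
rbm = rbm_module.make(mode='FAST_RUN')  # compiled module instance

for j in xrange(0, len(train.x) - batchsize + 1, batchsize):
    rbm.cd1(train.x[j:j + batchsize])   # one CD-1 update per minibatch
```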
```diff
--- a/pylearn/algorithms/sgd.py	Thu Nov 13 17:01:44 2008 -0500
+++ b/pylearn/algorithms/sgd.py	Thu Nov 13 17:54:56 2008 -0500
@@ -15,13 +15,12 @@
         self.gparams = T.grad(cost, self.params) if gradients is None else gradients
 
         self.updates = dict((p, p - self.lr * g) for p, g in zip(self.params, self.gparams))
 
-        self.args = args
         self.step = module.Method(
-                self.args, None,
+                args, [],
                 updates=self.updates)
         self.step_cost = module.Method(
-                self.args, cost,
+                args, cost,
                 updates=self.updates)
 
     #no initialization is done here.
```
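This hunk drops the redundant `self.args` attribute and passes `args` straight through to `module.Method`; it also replaces `None` with `[]` as the output list of `step`, making "no outputs" explicit. Usage would then look roughly like this (the class name, constructor signature, and variables are assumptions based on the hunk, since the rest of the file is not shown):

```python
# Hypothetical usage of the patched minimizer module: step() applies the
# parameter updates and returns nothing; step_cost() applies the same
# updates and also returns the cost.
minimizer = sgd_minimizer([x, y], cost, params, lr=0.01)  # assumed constructor
m = minimizer.make()            # compile the FancyModule
m.step(x_val, y_val)            # updates only; output list is []
c = m.step_cost(x_val, y_val)   # updates plus the cost value
```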
```diff
--- a/pylearn/dbdict/tools.py	Thu Nov 13 17:01:44 2008 -0500
+++ b/pylearn/dbdict/tools.py	Thu Nov 13 17:54:56 2008 -0500
@@ -4,6 +4,7 @@
 
 MODULE = 'dbdict_module'
 SYMBOL = 'dbdict_symbol'
+PREIMPORT = 'dbdict_preimport'
 
 def dummy_channel(*args, **kwargs):
     return None
@@ -52,6 +53,10 @@
     dbdict_module_name = getattr(state,MODULE)
     dbdict_symbol = getattr(state, SYMBOL)
 
+    preimport_list = getattr(state, PREIMPORT, "").split()
+    for preimport in preimport_list:
+        __import__(preimport, fromlist=[None], level=0)
+
     try:
         dbdict_module = __import__(dbdict_module_name, fromlist=[None], level=0)
         dbdict_fn = getattr(dbdict_module, dbdict_symbol)
```
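The `tools.py` hunk adds a `PREIMPORT` hook: before `dbdict_module`/`dbdict_symbol` are resolved, every module named in the whitespace-separated `dbdict_preimport` string is imported for its side effects. A small sketch of a state object that would exercise it; only the `dbdict_*` attribute names come from `tools.py`, the values are illustrative:

```python
# Illustrative state object for the new PREIMPORT hook.
class State(object):
    dbdict_module = 'pylearn.algorithms.rbm'
    dbdict_symbol = 'train_rbm'
    # whitespace-separated module names, imported (e.g. to register
    # plugins) before dbdict_module itself is resolved:
    dbdict_preimport = 'pylearn.datasets pylearn.algorithms.minimizer'

state = State()
preimport_list = getattr(state, 'dbdict_preimport', "").split()
# -> ['pylearn.datasets', 'pylearn.algorithms.minimizer']; the "" default
#    makes the hook a no-op for state objects that predate this change.
```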