changeset 218:df3fae88ab46
small debugging
author | Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
---|---
date | Fri, 23 May 2008 12:22:54 -0400
parents | 44dd9b6448c5
children | 5b3afda2f1ad
files | dataset.py denoising_aa.py mlp_factory_approach.py nnet_ops.py test_dataset.py
diffstat | 5 files changed, 37 insertions(+), 20 deletions(-)
--- a/dataset.py	Thu May 22 19:08:46 2008 -0400
+++ b/dataset.py	Fri May 23 12:22:54 2008 -0400
@@ -245,8 +245,7 @@
         if n_batches is not None:
             ds_nbatches = min(n_batches,ds_nbatches)
         if fieldnames:
-            if not dataset.hasFields(*fieldnames):
-                raise ValueError('field not present', fieldnames)
+            assert dataset.hasFields(*fieldnames)
         else:
             self.fieldnames=dataset.fieldNames()
         self.iterator = self.dataset.minibatches_nowrap(self.fieldnames,self.minibatch_size,
@@ -970,16 +969,7 @@
         for fieldname, fieldcolumns in self.fields_columns.items():
             if type(fieldcolumns) is int:
                 assert fieldcolumns>=0 and fieldcolumns<data_array.shape[1]
-
-                if 0:
-                    #I changed this because it didn't make sense to me,
-                    # and it made it more difficult to write my learner.
-                    # If it breaks stuff, let's talk about it.
-                    # - James 22/05/2008
-                    self.fields_columns[fieldname]=[fieldcolumns]
-                else:
-                    self.fields_columns[fieldname]=fieldcolumns
-
+                self.fields_columns[fieldname]=[fieldcolumns]
             elif type(fieldcolumns) is slice:
                 start,step=None,None
                 if not fieldcolumns.start:
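The second hunk reverts James's earlier change, so an int in `fields_columns` is wrapped in a list again. A minimal numpy sketch (mine, not from the changeset; the slicing pattern is an assumption suggested by the `data_array.shape[1]` check above) of why the wrapping matters:

```python
import numpy

# Illustrative only: mimics slicing a named field out of the underlying
# 2-D array.  A bare int collapses the field to 1-D; [int] keeps it 2-D,
# consistent with fields indexed by slices or lists of columns.
data_array = numpy.arange(12.).reshape(4, 3)

print data_array[:, 1].shape    # (4,)   bare int drops the column axis
print data_array[:, [1]].shape  # (4, 1) wrapped int keeps the field 2-D
```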
--- a/denoising_aa.py	Thu May 22 19:08:46 2008 -0400
+++ b/denoising_aa.py	Fri May 23 12:22:54 2008 -0400
@@ -31,6 +31,7 @@
 def squash_affine_formula(squash_function=sigmoid):
     """
+    Simply does: squash_function(b + xW)
     By convention prefix the parameters by _
     """
     class SquashAffineFormula(Formulas):
@@ -53,7 +54,7 @@
     class ProbabilisticClassifierLossFormula(Formulas):
         a = t.matrix() # of dimensions minibatch_size x n_classes, pre-softmax output
         target_class = t.ivector() # dimension (minibatch_size)
-        nll, probability_predictions = crossentropy_softmax_1hot(a, target_class)
+        nll, probability_predictions = crossentropy_softmax_1hot(a, target_class) # defined in nnet_ops.py
     return ProbabilisticClassifierLossFormula()
 
 def binomial_cross_entropy_formula():
@@ -64,6 +65,8 @@
         # using the identity softplus(a) - softplus(-a) = a,
         # we obtain that q log(p) + (1-q) log(1-p) = q a - softplus(a)
         nll = -t.sum(q*a - softplus(-a))
+    # next line was missing... hope it's all correct above
+    return BinomialCrossEntropyFormula()
 
 def squash_affine_autoencoder_formula(hidden_squash=t.tanh,
                                       reconstruction_squash=sigmoid,
@@ -102,9 +105,33 @@
         self.denoising_autoencoder_formula = corruption_formula + autoencoder.rename(x='corrupted_x')
 
     def __call__(self, training_set=None):
+        """ Allocate and optionally train a model"""
         model = DenoisingAutoEncoderModel(self)
         if training_set:
-            print 'what do I do if training set????'
+            print 'DenoisingAutoEncoder(): what do I do if training_set????'
+            # copied from mlp_factory_approach:
+            if len(trainset) == sys.maxint:
+                raise NotImplementedError('Learning from infinite streams is not supported')
+            nval = int(self.validation_portion * len(trainset))
+            nmin = len(trainset) - nval
+            assert nmin >= 0
+            minset = trainset[:nmin] #real training set for minimizing loss
+            valset = trainset[nmin:] #validation set for early stopping
+            best = model
+            for stp in self.early_stopper():
+                model.update(
+                    minset.minibatches([input, target], minibatch_size=min(32,
+                        len(trainset))))
+                #print 'mlp.__call__(), we did an update'
+                if stp.set_score:
+                    stp.score = model(valset, ['loss_01'])
+                    if (stp.score < stp.best_score):
+                        best = copy.copy(model)
+            model = best
+            # end of the copy from mlp_factory_approach
+
         return model
 
     def compile(self, inputs, outputs):
         return theano.function(inputs,outputs,unpack_single=False,linker=self.linker)
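A quick numeric check (mine, not part of the changeset) of the identity stated in the comment above: q·log(p) + (1−q)·log(1−p) = q·a − softplus(a) for p = sigmoid(a). Note the committed line sums `q*a - softplus(-a)` while the comment derives `softplus(a)`; the check below agrees with the comment, so the sign of the argument may be worth a second look:

```python
import numpy

def softplus(a):
    return numpy.log1p(numpy.exp(a))

def sigmoid(a):
    return 1. / (1. + numpy.exp(-a))

a = numpy.linspace(-3., 3., 13)  # arbitrary pre-sigmoid activations
q = 0.25                         # arbitrary target probability

lhs = q * numpy.log(sigmoid(a)) + (1. - q) * numpy.log(1. - sigmoid(a))
print numpy.allclose(lhs, q * a - softplus(a))   # True: matches the comment
print numpy.allclose(lhs, q * a - softplus(-a))  # False: the committed line
```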
--- a/mlp_factory_approach.py	Thu May 22 19:08:46 2008 -0400
+++ b/mlp_factory_approach.py	Fri May 23 12:22:54 2008 -0400
@@ -4,7 +4,7 @@
 import theano
 from theano import tensor as t
 
-import dataset, nnet_ops, stopper
+from pylearn import dataset, nnet_ops, stopper
 
 
 def _randshape(*shape):
@@ -31,18 +31,19 @@
            """Update this model from more training data."""
            params = self.params
            #TODO: why should we have to unpack target like this?
+           # tbm : creates problem...
            for input, target in input_target:
                rval= self.update_fn(input, target[:,0], *params)
                #print rval[0]
 
-       def __call__(self, testset, fieldnames=['output_class']):
+       def __call__(self, testset, fieldnames=['output_class'],input='input',target='target'):
            """Apply this model (as a function) to new data"""
            #TODO: cache fn between calls
-           assert 'input' == testset.fieldNames()[0]
+           assert input == testset.fieldNames()[0] # why first one???
            assert len(testset.fieldNames()) <= 2
            v = self.algo.v
            outputs = [getattr(v, name) for name in fieldnames]
-           inputs = [v.input] + ([v.target] if 'target' in testset else [])
+           inputs = [v.input] + ([v.target] if target in testset else [])
            inputs.extend(v.params)
            theano_fn = _cache(self._fn_cache, (tuple(inputs), tuple(outputs)),
                    lambda: self.algo._fn(inputs, outputs))
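A hedged usage sketch (not runnable on its own; `learn_algo`, `trainset`, `testset`, and the field names `x`/`y` are hypothetical) of what the new keyword arguments buy:

```python
# Before this change, Model.__call__ hard-coded the field names:
#     assert 'input' == testset.fieldNames()[0]
# With the new signature, a test set whose fields are not literally
# named 'input' and 'target' can be scored without renaming them:
model = learn_algo(trainset)                  # some trained Model
predictions = model(testset,
                    fieldnames=['output_class'],
                    input='x', target='y')    # hypothetical field names
```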
--- a/nnet_ops.py	Thu May 22 19:08:46 2008 -0400
+++ b/nnet_ops.py	Fri May 23 12:22:54 2008 -0400
@@ -44,7 +44,7 @@
         return ScalarSoftplus.static_impl(x)
     def grad(self, (x,), (gz,)):
         return [gz * scalar_sigmoid(x)]
-    def c_code(self, node, name, (x,), (z,), sub):
+    def c_code(self, name, node, (x,), (z,), sub):
         if node.inputs[0].type in [scalar.float32, scalar.float64]:
             return """%(z)s = %(x)s < -30.0
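Separate from the argument-order fix, the `grad` method visible in the context lines relies on d/dx softplus(x) = sigmoid(x). A small finite-difference check (mine, not in the changeset) of that gradient:

```python
import numpy

def softplus(x):
    return numpy.log1p(numpy.exp(x))

x = numpy.linspace(-5., 5., 21)
eps = 1e-6

# Central finite difference approximates the analytic derivative.
finite_diff = (softplus(x + eps) - softplus(x - eps)) / (2. * eps)
sigmoid = 1. / (1. + numpy.exp(-x))
print numpy.allclose(finite_diff, sigmoid)  # True: grad returns gz * sigmoid(x)
```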