pylearn: changeset 435:eac0a7d44ff0
merge
| author | Olivier Breuleux <breuleuo@iro.umontreal.ca> |
|---|---|
| date | Mon, 04 Aug 2008 16:29:30 -0400 |
| parents | 0f366ecb11ee (current diff) 200a5b0e24ea (diff) |
| children | d7ed780364b3 |
| files | gradient_learner.py sandbox/simple_autoassociator/globals.py statscollector.py |
| diffstat | 23 files changed, 1070 insertions(+), 273 deletions(-) |
--- a/_test_filetensor.py  Mon Aug 04 16:21:59 2008 -0400
+++ b/_test_filetensor.py  Mon Aug 04 16:29:30 2008 -0400
@@ -30,8 +30,12 @@
 
     def test_filename(self):
         gen = numpy.random.rand(1)
-        write(self.fname, gen)
-        mat = read(self.fname, None, debug=False) #load from filename
+        f = file(self.fname, 'w')
+        write(f, gen)
+        f.close()
+        f = file(self.fname, 'r')
+        mat = read(f, None, debug=False) #load from filename
+        f.close()
         self.failUnless(gen.shape == mat.shape)
         self.failUnless(numpy.all(gen == mat))
 
--- /dev/null  Thu Jan 01 00:00:00 1970 +0000
+++ b/_test_linear_regression.py  Mon Aug 04 16:29:30 2008 -0400
@@ -0,0 +1,25 @@
+
+import unittest
+from linear_regression import *
+from make_test_datasets import *
+import numpy
+
+class test_linear_regression(unittest.TestCase):
+
+    def test1(self):
+        trainset,testset,theta=make_artificial_datasets_from_function(n_inputs=3,
+                                                                      n_targets=2,
+                                                                      n_examples=100,
+                                                                      f=linear_predictor)
+
+        assert trainset.fields()['input'].shape==(50,3)
+        assert testset.fields()['target'].shape==(50,2)
+        regressor = LinearRegression(L2_regularizer=0.1)
+        predictor = regressor(trainset)
+        test_data = testset.fields()
+        mse = predictor.compute_mse(test_data['input'],test_data['target'])
+        print 'mse = ',mse
+
+if __name__ == '__main__':
+    unittest.main()
+
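For reference, the quantity checked by `compute_mse` in this test can be reproduced with plain numpy. This is a hedged sketch only: the exact reduction used by the predictor (sum over outputs then mean over examples, versus mean over all entries) is not visible in this changeset, so the helper and the shapes below are illustrative.

```python
import numpy

def mse(predictions, targets):
    """Mean squared error: per-example squared error, averaged (illustrative reduction)."""
    errors = numpy.sum((predictions - targets) ** 2, axis=1)
    return numpy.sum(errors) / errors.size

# Made-up data with the shapes asserted in the test: 50 test examples, 3 inputs, 2 targets.
rng = numpy.random.RandomState(0)
X = rng.randn(50, 3)
theta = rng.randn(4, 2)                          # bias row + weights
Xb = numpy.hstack([numpy.ones((50, 1)), X])
Y = numpy.dot(Xb, theta)
print(mse(numpy.dot(Xb, theta), Y))              # exact predictor -> 0.0
```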
--- a/_test_nnet_ops.py Mon Aug 04 16:21:59 2008 -0400 +++ b/_test_nnet_ops.py Mon Aug 04 16:29:30 2008 -0400 @@ -1,5 +1,6 @@ import unittest +import theano import theano._test_tensor as TT import numpy @@ -35,6 +36,43 @@ return crossentropy_softmax_1hot(a, y_idx)[0:1] TT.verify_grad(self, Dummy(), [numpy.random.rand(3,4)]) +class T_prepend(unittest.TestCase): + def test0(self): + """basic functionality""" + x=tensor.matrix('x') + y=Prepend_scalar_constant_to_each_row(4.)(x) + f=theano.function([x],[y]) + m=numpy.random.rand(3,5) + my = f(m) + self.failUnless(my.shape == (3, 6), my.shape) + self.failUnless(numpy.all( my[:,0] == 4.0)) + + +class T_prepend(unittest.TestCase): + def test0(self): + """basic functionality""" + x=tensor.matrix('x') + y=Prepend_scalar_to_each_row()(5.,x) + f=theano.function([x],[y]) + m=numpy.ones((3,5),dtype="float32") + my = f(m) + self.failUnless(str(my.dtype) == 'float64') + self.failUnless(my.shape == (3, 6)) + self.failUnless(numpy.all(my[:,0] == 5.0)) + +class T_solve(unittest.TestCase): + def setUp(self): + self.rng = numpy.random.RandomState(666) + + def test0(self): + A=self.rng.randn(5,5) + b=numpy.array(range(5),dtype=float) + x=numpy.linalg.solve(A,b) + Ax = numpy.dot(A,x) + are = theano.gradient.numeric_grad.abs_rel_err(Ax, b) + self.failUnless(numpy.all(are < 1.0e-5), (are, Ax, b)) + #print A,b + #print numpy.dot(A,x) if __name__ == '__main__':
--- a/dataset.py Mon Aug 04 16:21:59 2008 -0400 +++ b/dataset.py Mon Aug 04 16:29:30 2008 -0400 @@ -220,7 +220,8 @@ Sub-classes which implement finite-length datasets should redefine this method. Some methods only make sense for finite-length datasets. """ - return None + from sys import maxint + return maxint class MinibatchToSingleExampleIterator(object): @@ -943,6 +944,9 @@ del self.fieldname2dataset[fieldname] self.fieldname2dataset[rename_field(fieldname,self.datasets[i],i)]=i + def __len__(self): + return len(self.datasets[0]) + def hasFields(self,*fieldnames): for fieldname in fieldnames: if not fieldname in self.fieldname2dataset: @@ -1223,13 +1227,12 @@ else: self.fields_columns[fieldname]=fieldcolumns elif type(fieldcolumns) is slice: - start,step=None,None - if not fieldcolumns.start: + start,step=fieldcolumns.start,fieldcolumns.step + if not start: start=0 - if not fieldcolumns.step: + if not step: step=1 - if start or step: - self.fields_columns[fieldname]=slice(start,fieldcolumns.stop,step) + self.fields_columns[fieldname]=slice(start,fieldcolumns.stop,step) elif hasattr(fieldcolumns,"__iter__"): # something like a list for i in fieldcolumns: assert i>=0 and i<data_array.shape[1] @@ -1451,12 +1454,14 @@ (it takes minibatches of inputs and produces minibatches of outputs, as documented in the class comment). - TBM: are filedtypes the old field types (from input_dataset) or the new ones + TBM: are fieldtypes the old field types (from input_dataset) or the new ones (for the new dataset created)? """ self.input_dataset=input_dataset self.function=function self.output_names=output_names + #print 'self.output_names in afds:', self.output_names + #print 'length in afds:', len(self.output_names) self.minibatch_mode=minibatch_mode DataSet.__init__(self,description,fieldtypes) self.valuesHStack = values_hstack if values_hstack else input_dataset.valuesHStack @@ -1481,9 +1486,10 @@ for input_example in input_examples] all_output_fields = zip(*output_examples) + #print 'output_names=', self.output_names + #print 'all_output_fields', all_output_fields + #print 'len(all_output_fields)=', len(all_output_fields) all_outputs = Example(self.output_names, all_output_fields) - #print 'input_fields', input_fields - #print 'all_outputs', all_outputs if fieldnames==self.output_names: rval = all_outputs else:
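The ArrayDataSet hunk above rewrites the handling of a slice-valued fieldcolumns: previously start and step were initialized to None and only replaced when the caller left them unspecified, so an explicitly supplied start or step could be thrown away; the new code copies the caller's values and only fills in the missing defaults. A standalone sketch of that normalization, outside the DataSet class:

```python
def normalize_columns_slice(fieldcolumns):
    """Return an equivalent slice with an explicit start and step."""
    start, step = fieldcolumns.start, fieldcolumns.step
    if not start:
        start = 0
    if not step:
        step = 1
    return slice(start, fieldcolumns.stop, step)

assert normalize_columns_slice(slice(None, 5, None)) == slice(0, 5, 1)
assert normalize_columns_slice(slice(2, 8, None)) == slice(2, 8, 1)   # start is now preserved
```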
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/examples/linear_classifier.py Mon Aug 04 16:29:30 2008 -0400 @@ -0,0 +1,224 @@ +#! /usr/bin/env python +""" +T. Bertin-Mahieux (2008) University of Montreal +bertinmt@iro.umontreal.ca + +linear_classifier.py +Simple script that creates a linear_classifier, and +learns the paramters using backpropagation. + +This is to illustrate how to use theano/pylearn. +Anyone that knows how to make this script simpler/clearer is welcomed to +make the modifications. +""" + + +import os +import sys +import time +import copy +import pickle +import numpy +import numpy as N +import numpy.random as NR +from pylearn import cost +import theano +from theano import tensor as T + + +def cost_function(*args,**kwargs) : + """ default cost function, quadratic """ + return cost.quadratic(*args,**kwargs) + + +class modelgraph() : + """ class that contains the graph of the model """ + lr = T.scalar() # learning rate + inputs = T.matrix() # inputs (one example per line) + true_outputs = T.matrix() # outputs (one example per line) + W = T.matrix() # weights input * W + b= output + b = T.vector() # bias + outputs = T.dot(inputs,W) + b # output, one per line + costs = cost_function(true_outputs,outputs) # costs + g_W = T.grad(costs,W) # gradient of W + g_b = T.grad(costs,b) # gradient of b + new_W = T.sub_inplace(W, lr * g_W) # update inplace of W + new_b = T.sub_inplace(b, lr * g_b) # update inplace of b + + +class model() : + """ + The model! + Contains needed matrices, needed functions, and a link to the model graph. + """ + + def __init__(self,input_size,output_size) : + """ init matrix and bias, creates the graph, create a dict of compiled functions """ + # graph + self.graph = modelgraph() + # weights and bias, saved in self.params + seed = 666 + r = NR.RandomState(seed) + W = r.uniform(size = [input_size, output_size], low = -1/N.sqrt(input_size), high = 1/N.sqrt(input_size)) + b = numpy.zeros((output_size, )) + self.params = [W,b] + # dictionary of compiled functions + self.func_dict = dict() + # keep some init_infos (may not be necessary) + self.init_params = [input_size,output_size] + + + def update(self,lr,true_inputs,true_outputs) : + """ does an update of the model, one gradient descent """ + # do we already have the proper theano function? + if self.func_dict.has_key('update_func') : + self.func_dict['update_func'](lr,true_inputs,true_outputs,self.params[0],self.params[1]) + return + else : + # create the theano function, tell him what are the inputs and outputs) + func = theano.function([self.graph.lr,self.graph.inputs,self.graph.true_outputs, + self.graph.W, self.graph.b], + [self.graph.new_W,self.graph.new_b]) + # add function to dictionary, so we don't compile it again + self.func_dict['update_func'] = func + # use this function + func(lr,true_inputs,true_outputs,self.params[0],self.params[1]) + return + + def costs(self,true_inputs,true_outputs) : + """ get the costs for given examples, don't update """ + # do we already have the proper theano function? 
+ if self.func_dict.has_key('costs_func') : + return self.func_dict['costs_func'](true_inputs,true_outputs,self.params[0],self.params[1]) + else : + # create the theano function, tell him what are the inputs and outputs) + func = theano.function([self.graph.inputs,self.graph.true_outputs,self.graph.W,self.graph.b], + [self.graph.costs]) + # add function to dictionary, se we don't compile it again + self.func_dict['costs_func'] = func + # use this function + return func(true_inputs,true_outputs,self.params[0],self.params[1]) + + def outputs(self,true_inputs) : + """ get the output for a set of examples (could be called 'predict') """ + # do we already have the proper theano function? + if self.func_dict.has_key('outputs_func') : + return self.func_dict['outputs_func'](true_inputs,self.params[0],self.params[1]) + else : + # create the theano function, tell him what are the inputs and outputs) + func = theano.function([self.graph.inputs, self.graph.W, self.graph.b], + [self.graph.outputs]) + # add function to dictionary, se we don't compile it again + self.func_dict['outputs_func'] = func + # use this function + return func(true_inputs,self.params[0],self.params[1]) + + def __getitem__(self,inputs) : + """ for simplicity, we can use the model this way: predictions = model[inputs] """ + return self.outputs(inputs) + + def __getstate__(self) : + """ + To save/copy the model, used by pickle.dump() and by copy.deepcopy(). + @return a dictionnary with the params (matrix + bias) + """ + d = dict() + d['params'] = self.params + d['init_params'] = self.init_params + return d + + def __setstate__(self,d) : + """ + Get the dictionary created by __getstate__(), use it to recreate the model. + """ + self.params = d['params'] + self.init_params = d['init_params'] + self.graph = modelgraph() # we did not save the model graph + + def __str__(self) : + """ returns a string representing the model """ + res = "Linear regressor, input size =",str(self.init_params[0]) + res += ", output size =", str(self.init_params[1]) + return res + + def __equal__(self,other) : + """ + Compares the model based on the params. 
+ @return True if the params are the same, False otherwise + """ + # class + if not isinstance(other,model) : + return False + # input size + if self.params[0].shape[0] != other.params[0].shape[0] : + return False + # output size + if self.params[0].shape[1] != other.params[0].shape[1] : + return False + # actual values + if not (self.params[0] == other.params[0]).all(): + return False + if not (self.params[1] == other.params[1]).all(): + return False + # all good + return True + + +def die_with_usage() : + """ help menu """ + print 'simple script to illustrate how to use theano/pylearn' + print 'to launch:' + print ' python linear_classifier.py -launch' + sys.exit(0) + + + +#************************************************************ +# main + +if __name__ == '__main__' : + + if len(sys.argv) < 2 : + die_with_usage() + + # print create data + inputs = numpy.array([[.1,.2], + [.2,.8], + [.9,.3], + [.6,.5]]) + outputs = numpy.array([[0], + [0], + [1], + [1]]) + assert inputs.shape[0] == outputs.shape[0] + + # create model + m = model(2,1) + + # predict + print 'prediction before training:' + print m[inputs] + + # update it for 100 iterations + for k in range(50) : + m.update(.1,inputs,outputs) + + # predict + print 'prediction after training:' + print m[inputs] + + # show points + import pylab as P + colors = outputs.flatten().tolist() + x = inputs[:,0] + y = inputs[:,1] + P.plot(x[numpy.where(outputs==0)[0]],y[numpy.where(outputs==0)[0]],'r+') + P.plot(x[numpy.where(outputs==1)[0]],y[numpy.where(outputs==1)[0]],'b+') + # decision line + p1 = (.5 - m.params[1] * 1.) / m.params[0][1,0] # abs = 0 + p2 = (.5 - m.params[1] * 1.) / m.params[0][0,0] # ord = 0 + P.plot((0,p2[0],2*p2[0]),(p1[0],0,-p1[0]),'g-') + # show + P.axis([-1,2,-1,2]) + P.show() +
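The script above relies on Theano to compute the gradients of a quadratic cost and to apply in-place parameter updates. As a point of comparison, here is a minimal pure-numpy sketch of the same gradient step on the same toy data. The cost is assumed to be the mean of squared differences, which may differ by a constant factor from `pylearn.cost.quadratic`, so the learning rate is not directly comparable to the one used in the script.

```python
import numpy

def gradient_step(W, b, inputs, targets, lr):
    """One batch gradient-descent step on outputs = inputs W + b with quadratic cost."""
    outputs = numpy.dot(inputs, W) + b          # one example per row
    diff = outputs - targets
    cost = numpy.mean(diff ** 2)
    scale = 2.0 / (inputs.shape[0] * targets.shape[1])
    W -= lr * scale * numpy.dot(inputs.T, diff)
    b -= lr * scale * numpy.sum(diff, axis=0)
    return cost

inputs = numpy.array([[.1, .2], [.2, .8], [.9, .3], [.6, .5]])
targets = numpy.array([[0.], [0.], [1.], [1.]])
W = numpy.zeros((2, 1))
b = numpy.zeros(1)
for _ in range(50):
    gradient_step(W, b, inputs, targets, lr=0.5)
print(numpy.dot(inputs, W) + b)                 # predictions after training
```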
--- /dev/null  Thu Jan 01 00:00:00 1970 +0000
+++ b/examples/theano_update.py  Mon Aug 04 16:29:30 2008 -0400
@@ -0,0 +1,56 @@
+import theano
+from theano import tensor
+
+import numpy
+
+# Two scalar symbolic variables
+a = tensor.scalar()
+b = tensor.scalar()
+
+# Definition of output symbolic variable
+c = a * b
+# Definition of the function computing it
+fprop = theano.function([a,b], [c])
+
+# Initialize numerical variables
+a_val = numpy.array(12.)
+b_val = numpy.array(2.)
+print 'a_val =', a_val
+print 'b_val =', b_val
+
+# Numerical value of output is returned by the call to "fprop"
+c_val = fprop(a_val, b_val)
+print 'c_val =', c_val
+
+
+# Definition of simple update (increment by one)
+new_b = b + 1
+update = theano.function([b], [new_b])
+
+# New numerical value of b is returned by the call to "update"
+b_val = update(b_val)
+print 'new b_val =', b_val
+# We can use the new value in "fprop"
+c_val = fprop(a_val, b_val)
+print 'c_val =', c_val
+
+
+# Definition of in-place update (increment by one)
+re_new_b = tensor.add_inplace(b, 1.)
+re_update = theano.function([b], [re_new_b])
+
+# "re_update" can be used the same way as "update"
+b_val = re_update(b_val)
+print 'new b_val =', b_val
+# We can use the new value in "fprop"
+c_val = fprop(a_val, b_val)
+print 'c_val =', c_val
+
+# It is not necessary to keep the return value when the update is done in place
+re_update(b_val)
+print 'new b_val =', b_val
+c_val = fprop(a_val, b_val)
+print 'c_val =', c_val
+
+
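The point of `add_inplace` in this example is that the update mutates the buffer that was passed in, so keeping the return value is optional. The same distinction exists in plain numpy, which may make the behaviour easier to see:

```python
import numpy

b_val = numpy.array(2.)

# Out-of-place: a new array is created and b_val is left untouched.
new_b = b_val + 1
print(new_b)    # 3.0
print(b_val)    # still 2.0

# In-place: the buffer behind b_val is modified; keeping a return value is optional.
b_val += 1
print(b_val)    # 3.0
```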
--- a/filetensor.py Mon Aug 04 16:21:59 2008 -0400 +++ b/filetensor.py Mon Aug 04 16:29:30 2008 -0400 @@ -56,7 +56,8 @@ def read(f, subtensor=None, debug=False): """Load all or part of file 'f' into a numpy ndarray - If f is a string, it will be treated as a filename, and opened in read mode. + @param f: file from which to read + @type f: file-like object If subtensor is not None, it should be like the argument to numpy.ndarray.__getitem__. The following two expressions should return @@ -74,10 +75,6 @@ s_array = numpy.fromstring(s, dtype='int32') return s_array.item() - if isinstance(f, str): - if debug: print 'f', f - f = file(f, 'r') - #what is the data type of this matrix? #magic_s = f.read(4) #magic = numpy.fromstring(magic_s, dtype='int32') @@ -116,15 +113,17 @@ def write(f, mat): """Write a numpy.ndarray to file. - If 'f' is a string, then it will be interpreted as a filename. This filename - will be opened in 'w+' mode, and (automatically) closed at the end of the function. + @param f: file into which to write + @type f: file-like object + + @param mat: array to write to file + @type mat: numpy ndarray or compatible + """ def _write_int32(f, i): i_array = numpy.asarray(i, dtype='int32') if 0: print 'writing int32', i, i_array i_array.tofile(f) - if isinstance(f, str): - f = file(f, 'w+') try: _write_int32(f, _dtype_magic[str(mat.dtype)])
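After this change `read` and `write` accept only file-like objects; the branches that opened a filename on the caller's behalf are gone, which is why the `_test_filetensor.py` hunk above now opens and closes the file itself. If the filename convenience is still wanted at call sites, thin wrappers can restore it. These helpers are hypothetical, not part of the changeset:

```python
from filetensor import read, write

def read_from_filename(fname, subtensor=None):
    """Open fname, read the tensor stored in it, and close the file."""
    f = open(fname, 'rb')
    try:
        return read(f, subtensor)
    finally:
        f.close()

def write_to_filename(fname, mat):
    """Open fname, write the array to it, and close the file."""
    f = open(fname, 'wb')
    try:
        write(f, mat)
    finally:
        f.close()
```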
--- a/gradient_learner.py Mon Aug 04 16:21:59 2008 -0400 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,71 +0,0 @@ - -from learner import * -from tensor import * -import gradient -from compile import Function - -class GradientLearner(Learner): - """ - Base class for gradient-based optimization of a training criterion - that can consist in two parts, an additive part over examples, and - an example-independent part (usually called the regularizer). - The user provides a Theano formula that maps the fields of a minibatch (each being a tensor with the - same number of rows = minibatch size) and parameters to output fields (for the use function), one of which - must be a cost that is the training criterion to be minimized. Subclasses implement - a training strategy that uses the Theano formula to compute gradients and - to compute outputs in the update method. - The inputs, parameters, and outputs are lists of Theano tensors, - while the example_wise_cost and regularization_term are Theano tensors. - The user can specify a regularization coefficient that multiplies the regularization term. - The training algorithm looks for parameters that minimize - regularization_coefficient * regularization_term(parameters) + - sum_{inputs in training_set} example_wise_cost(inputs,parameters) - i.e. the regularization_term should not depend on the inputs, only on the parameters. - The learned function can map a subset of inputs to a subset of outputs (as long as the inputs subset - includes all the inputs required in the Theano expression for the selected outputs). - It is assumed that all the inputs are provided in the training set (as dataset fields - with the corresponding name), but not necessarily when using the learned function. - """ - def __init__(self, inputs, parameters, outputs, example_wise_cost, regularization_term=astensor(0.0), - regularization_coefficient = astensor(1.0)): - self.inputs = inputs - self.outputs = outputs - self.parameters = parameters - self.example_wise_cost = example_wise_cost - self.regularization_term = regularization_term - self.regularization_coefficient = regularization_coefficient - self.parameters_example_wise_gradient = gradient.grad(example_wise_cost, parameters) - self.parameters_regularization_gradient = gradient.grad(self.regularization_coefficient * regularization_term, parameters) - if example_wise_cost not in outputs: - outputs.append(example_wise_cost) - if regularization_term not in outputs: - outputs.append(regularization_term) - self.example_wise_gradient_fn = Function(inputs + parameters, - [self.parameters_example_wise_gradient + self.parameters_regularization_gradient]) - self.use_functions = {frozenset([input.name for input in inputs]+[output.name for output in outputs]) - : Function(inputs, outputs)} - - def use(self,input_dataset,output_fields=None,copy_inputs=True): - # obtain the function that maps the desired inputs to desired outputs - input_fields = input_dataset.fieldNames() - # map names of input fields to Theano tensors in self.inputs - input_variables = ??? - if output_fields is None: output_fields = [output.name for output in outputs] - # handle special case of inputs that are directly copied into outputs - # map names of output fields to Theano tensors in self.outputs - output_variables = ??? 
- use_function_key = input_fields+output_fields - if not self.use_functions.has_key(use_function_key): - self.use_function[use_function_key]=Function(input_variables,output_variables) - use_function = self.use_functions[use_function_key] - # return a dataset that computes the outputs - return input_dataset.apply_function(use_function,input_fields,output_fields,copy_inputs,compute_now=True) - - -class StochasticGradientDescent(object): - def update_parameters(self): - -class StochasticGradientLearner(GradientLearner,StochasticGradientDescent): - def __init__(self,inputs, parameters, outputs, example_wise_cost, regularization_term=astensor(0.0), - regularization_coefficient = astensor(1.0),) - def update()
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/kernel_regression.py Mon Aug 04 16:29:30 2008 -0400 @@ -0,0 +1,231 @@ +""" +Implementation of kernel regression: +""" + +from pylearn.learner import OfflineLearningAlgorithm +from theano import tensor as T +from nnet_ops import prepend_1_to_each_row +from theano.scalar import as_scalar +from common.autoname import AutoName +import theano +import numpy + +# map a N-vector to a 1xN matrix +row_vector = theano.elemwise.DimShuffle((False,),['x',0]) +# map a N-vector to a Nx1 matrix +col_vector = theano.elemwise.DimShuffle((False,),[0,'x']) + +class KernelRegression(OfflineLearningAlgorithm): + """ +Implementation of kernel regression: +* the data are n (x_t,y_t) pairs and we want to estimate E[y|x] +* the predictor computes + f(x) = b + \sum_{t=1}^n \alpha_t K(x,x_t) + with free parameters b and alpha, training inputs x_t, + and kernel function K (gaussian by default). + Clearly, each prediction involves O(n) computations. +* the learner chooses b and alpha to minimize + lambda alpha' G' G alpha + \sum_{t=1}^n (f(x_t)-y_t)^2 + where G is the matrix with entries G_ij = K(x_i,x_j). + The first (L2 regularization) term is the squared L2 + norm of the primal weights w = \sum_t \alpha_t phi(x_t) + where phi is the function s.t. K(u,v)=phi(u).phi(v). +* this involves solving a linear system with (n+1,n+1) + matrix, which is an O(n^3) computation. In addition, + that linear system matrix requires O(n^2) memory. + So this learning algorithm should be used only for + small datasets. +* the linear system is + (M + lambda I_n) theta = (1, y)' + where theta = (b, alpha), I_n is the (n+1)x(n+1) matrix that is the identity + except with a 0 at (0,0), M is the matrix with G in the sub-matrix starting + at (1,1), 1's in column 0, except for a value of n at (0,0), and sum_i G_{i,j} + in the rest of row 0. + +Note that this is gives an estimate of E[y|x,training_set] that is the +same as obtained with a Gaussian process regression. The GP +regression would also provide a Bayesian Var[y|x,training_set]. +It corresponds to an assumption that f is a random variable +with Gaussian (process) prior distribution with covariance +function K. Because we assume Gaussian noise we obtain a Gaussian +posterior for f (whose mean is computed here). + + + Usage: + + kernel_regressor=KernelRegression(L2_regularizer=0.1,gamma=0.5) (kernel=GaussianKernel(gamma=0.5)) + kernel_predictor=kernel_regressor(training_set) + all_results_dataset=kernel_predictor(test_set) # creates a dataset with "output" and "squared_error" field + outputs = kernel_predictor.compute_outputs(inputs) # inputs and outputs are numpy arrays + outputs, errors = kernel_predictor.compute_outputs_and_errors(inputs,targets) + errors = kernel_predictor.compute_errors(inputs,targets) + mse = kernel_predictor.compute_mse(inputs,targets) + + + + The training_set must have fields "input" and "target". + The test_set must have field "input", and needs "target" if + we want to compute the squared errors. + + The predictor parameters are obtained analytically from the training set. + Training is only done on a whole training set rather than on minibatches + (no online implementation). 
+ + The dataset fields expected and produced by the learning algorithm and the trained model + are the following: + + - Input and output dataset fields (example-wise quantities): + + - 'input' (always expected as an input_dataset field) + - 'target' (always expected by the learning algorithm, optional for learned model) + - 'output' (always produced by learned model) + - 'squared_error' (optionally produced by learned model if 'target' is provided) + = example-wise squared error + """ + def __init__(self, kernel=None, L2_regularizer=0, gamma=1, use_bias=False): + # THE VERSION WITH BIAS DOES NOT SEEM RIGHT + self.kernel = kernel + self.L2_regularizer=L2_regularizer + self.use_bias=use_bias + self.gamma = gamma # until we fix things, the kernel type is fixed, Gaussian + self.equations = KernelRegressionEquations() + + def __call__(self,trainset): + n_examples = len(trainset) + first_example = trainset[0] + n_inputs = first_example['input'].size + n_outputs = first_example['target'].size + b1=1 if self.use_bias else 0 + M = numpy.zeros((n_examples+b1,n_examples+b1)) + Y = numpy.zeros((n_examples+b1,n_outputs)) + for i in xrange(n_examples): + M[i+b1,i+b1]=self.L2_regularizer + data = trainset.fields() + train_inputs = numpy.array(data['input']) + if self.use_bias: + Y[0]=1 + Y[b1:,:] = numpy.array(data['target']) + train_inputs_square,sumG,G=self.equations.compute_system_matrix(train_inputs,self.gamma) + M[b1:,b1:] += G + if self.use_bias: + M[0,1:] = sumG + M[1:,0] = 1 + M[0,0] = M.shape[0] + self.M=M + self.Y=Y + theta=numpy.linalg.solve(M,Y) + return KernelPredictor(theta,self.gamma, train_inputs, train_inputs_square) + +class KernelPredictorEquations(AutoName): + train_inputs = T.matrix() # n_examples x n_inputs + train_inputs_square = T.vector() # n_examples + inputs = T.matrix() # minibatchsize x n_inputs + targets = T.matrix() # minibatchsize x n_outputs + theta = T.matrix() # (n_examples+1) x n_outputs + b1 = T.shape(train_inputs_square)[0]<T.shape(theta)[0] + gamma = T.scalar() + inv_gamma2 = 1./(gamma*gamma) + b = b1*theta[0] + alpha = theta[b1:,:] + inputs_square = T.sum(inputs*inputs,axis=1) + Kx = T.exp(-(row_vector(train_inputs_square)-2*T.dot(inputs,train_inputs.T)+col_vector(inputs_square))*inv_gamma2) + outputs = T.dot(Kx,alpha) + b # minibatchsize x n_outputs + squared_errors = T.sum(T.sqr(targets-outputs),axis=1) + + __compiled = False + @classmethod + def compile(cls,linker='c|py'): + if cls.__compiled: + return + def fn(input_vars,output_vars): + return staticmethod(theano.function(input_vars,output_vars, linker=linker)) + + cls.compute_outputs = fn([cls.inputs,cls.theta,cls.gamma,cls.train_inputs,cls.train_inputs_square],[cls.outputs]) + cls.compute_errors = fn([cls.outputs,cls.targets],[cls.squared_errors]) + + cls.__compiled = True + + def __init__(self): + self.compile() + +class KernelRegressionEquations(KernelPredictorEquations): + #M = T.matrix() # (n_examples+1) x (n_examples+1) + inputs = T.matrix() # n_examples x n_inputs + gamma = T.scalar() + inv_gamma2 = 1./(gamma*gamma) + inputs_square = T.sum(inputs*inputs,axis=1) + #new_G = G+T.dot(inputs,inputs.T) + #new_G = T.gemm(G,1.,inputs,inputs.T,1.) 
+ G = T.exp(-(row_vector(inputs_square)-2*T.dot(inputs,inputs.T)+col_vector(inputs_square))*inv_gamma2) + sumG = T.sum(G,axis=0) + + __compiled = False + + @classmethod + def compile(cls,linker='c|py'): + if cls.__compiled: + return + def fn(input_vars,output_vars): + return staticmethod(theano.function(input_vars,output_vars, linker=linker)) + + cls.compute_system_matrix = fn([cls.inputs,cls.gamma],[cls.inputs_square,cls.sumG,cls.G]) + + cls.__compiled = True + + def __init__(self): + self.compile() + +class KernelPredictor(object): + """ + A kernel predictor has parameters theta (a bias vector and a weight matrix alpha) + it can use to make a non-linear prediction (according to the KernelPredictorEquations). + It can compute its output (bias + alpha * kernel(train_inputs,input) and a squared error (||output - target||^2). + """ + def __init__(self, theta, gamma, train_inputs, train_inputs_square=None): + self.theta=theta + self.gamma=gamma + self.train_inputs=train_inputs + if train_inputs_square==None: + train_inputs_square = numpy.sum(train_inputs*train_inputs,axis=1) + self.train_inputs_square=train_inputs_square + self.equations = KernelPredictorEquations() + + def compute_outputs(self,inputs): + return self.equations.compute_outputs(inputs,self.theta,self.gamma,self.train_inputs,self.train_inputs_square) + def compute_errors(self,inputs,targets): + return self.equations.compute_errors(self.compute_outputs(inputs),targets) + def compute_outputs_and_errors(self,inputs,targets): + outputs = self.compute_outputs(inputs) + return [outputs,self.equations.compute_errors(outputs,targets)] + def compute_mse(self,inputs,targets): + errors = self.compute_errors(inputs,targets) + return numpy.sum(errors)/errors.size + + def __call__(self,dataset,output_fieldnames=None,cached_output_dataset=False): + assert dataset.hasFields(["input"]) + if output_fieldnames is None: + if dataset.hasFields(["target"]): + output_fieldnames = ["output","squared_error"] + else: + output_fieldnames = ["output"] + output_fieldnames.sort() + if output_fieldnames == ["squared_error"]: + f = self.compute_errors + elif output_fieldnames == ["output"]: + f = self.compute_outputs + elif output_fieldnames == ["output","squared_error"]: + f = self.compute_outputs_and_errors + else: + raise ValueError("unknown field(s) in output_fieldnames: "+str(output_fieldnames)) + + ds=ApplyFunctionDataSet(dataset,f,output_fieldnames) + if cached_output_dataset: + return CachedDataSet(ds) + else: + return ds + + +def kernel_predictor(inputs,params,*otherargs): + p = KernelPredictor(params,*otherargs[0]) + return p.compute_outputs(inputs) +
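Without the bias term (`use_bias=False`, the default above), the linear system described in the class docstring reduces to ordinary kernel ridge regression: solve (G + lambda I) alpha = Y with the Gaussian kernel G_ij = exp(-||x_i - x_j||^2 / gamma^2). A small numpy-only sketch of that special case follows; it is illustrative and not the `KernelRegression` API.

```python
import numpy

def gaussian_gram(X, gamma):
    """G[i, j] = exp(-||x_i - x_j||^2 / gamma^2), as in KernelRegressionEquations."""
    sq = numpy.sum(X * X, axis=1)
    d2 = sq[:, None] - 2 * numpy.dot(X, X.T) + sq[None, :]
    return numpy.exp(-d2 / (gamma * gamma))

def fit_kernel_ridge(X, Y, L2_regularizer, gamma):
    """Solve (G + lambda I) alpha = Y; O(n^3) time and O(n^2) memory, as noted above."""
    G = gaussian_gram(X, gamma)
    return numpy.linalg.solve(G + L2_regularizer * numpy.eye(X.shape[0]), Y)

def predict_kernel_ridge(alpha, X_train, X_new, gamma):
    sq_train = numpy.sum(X_train * X_train, axis=1)
    sq_new = numpy.sum(X_new * X_new, axis=1)
    d2 = sq_train[None, :] - 2 * numpy.dot(X_new, X_train.T) + sq_new[:, None]
    Kx = numpy.exp(-d2 / (gamma * gamma))
    return numpy.dot(Kx, alpha)

rng = numpy.random.RandomState(0)
X = rng.randn(30, 3)
Y = numpy.sin(X[:, :1])
alpha = fit_kernel_ridge(X, Y, L2_regularizer=0.1, gamma=1.0)
print(predict_kernel_ridge(alpha, X, X[:5], gamma=1.0))
```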
--- a/linear_regression.py Mon Aug 04 16:21:59 2008 -0400 +++ b/linear_regression.py Mon Aug 04 16:29:30 2008 -0400 @@ -4,9 +4,9 @@ the use of theano. """ -from pylearn.learner import OfflineLearningAlgorithm +from pylearn.learner import OfflineLearningAlgorithm,OnlineLearningAlgorithm from theano import tensor as T -from theano.others_ops import prepend_1_to_each_row +from nnet_ops import prepend_1_to_each_row from theano.scalar import as_scalar from common.autoname import AutoName import theano @@ -34,11 +34,6 @@ we want to compute the squared errors. The predictor parameters are obtained analytically from the training set. - Training can proceed sequentially (with multiple calls to update with - different disjoint subsets of the training sets). After each call to - update the predictor is ready to be used (and optimized for the union - of all the training sets passed to update since construction or since - the last call to forget). For each (input[t],output[t]) pair in a minibatch,:: @@ -74,7 +69,7 @@ def __init__(self, L2_regularizer=0,minibatch_size=10000): self.L2_regularizer=L2_regularizer self.equations = LinearRegressionEquations() - self.minibatch_size=1000 + self.minibatch_size=minibatch_size def __call__(self,trainset): first_example = trainset[0] @@ -186,3 +181,21 @@ return ds +def linear_predictor(inputs,params,*otherargs): + p = LinearPredictor(params) + return p.compute_outputs(inputs) + +#TODO : an online version +class OnlineLinearRegression(OnlineLearningAlgorithm): + """ + Training can proceed sequentially (with multiple calls to update with + different disjoint subsets of the training sets). After each call to + update the predictor is ready to be used (and optimized for the union + of all the training sets passed to update since construction or since + the last call to forget). + """ + pass + + + +
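`LinearRegression` still solves the regularized normal equations analytically, now importing `prepend_1_to_each_row` from the local `nnet_ops` module to add the bias column. A hedged numpy sketch of that closed-form solution; whether the bias row is regularized in the actual implementation is not visible in this hunk, so the sketch leaves it unregularized.

```python
import numpy

def ridge_fit(X, Y, L2_regularizer):
    """theta minimizing ||Xb theta - Y||^2 + L2_regularizer * ||theta[1:]||^2,
    with Xb = [1 X] (bias column prepended)."""
    n, d = X.shape
    Xb = numpy.hstack([numpy.ones((n, 1)), X])
    reg = L2_regularizer * numpy.eye(d + 1)
    reg[0, 0] = 0.0     # assumption: do not regularize the bias
    return numpy.linalg.solve(numpy.dot(Xb.T, Xb) + reg, numpy.dot(Xb.T, Y))

def ridge_predict(theta, X):
    Xb = numpy.hstack([numpy.ones((X.shape[0], 1)), X])
    return numpy.dot(Xb, theta)
```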
--- a/lookup_list.py  Mon Aug 04 16:21:59 2008 -0400
+++ b/lookup_list.py  Mon Aug 04 16:29:30 2008 -0400
@@ -29,6 +29,10 @@
     U{http://epydoc.sourceforge.net/manual-epytext.html#doctest-blocks}
     """
     def __init__(self,names=[],values=[]):
+        #print 'values=', values
+        #print 'length=', len(values)
+        #print 'names=', names
+        #print 'length=',len(names)
         assert len(values)==len(names)
         self.__dict__['_values']=values
         self.__dict__['_name2index']={}
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/make_test_datasets.py Mon Aug 04 16:29:30 2008 -0400 @@ -0,0 +1,87 @@ +from pylearn.dataset import ArrayDataSet +from shapeset.dset import Polygons +from linear_regression import linear_predictor +from kernel_regression import kernel_predictor +from numpy import * + +""" +General-purpose code to generate artificial datasets that can be used +to test different learning algorithms. +""" + +def make_triangles_rectangles_datasets(n_examples=600,train_frac=0.5,image_size=(10,10)): + """ + Make a binary classification dataset to discriminate triangle images from rectangle images. + """ + def convert_dataset(dset): + # convert the n_vert==3 into target==0 and n_vert==4 into target==1 + def mapf(images,n_vertices): + n=len(n_vertices) + targets = ndarray((n,1),dtype='float64') + for i in xrange(n): + targets[i,0] = array([0. if vertices[i]==3 else 1.],dtype='float64') + return images.reshape(len(images),images[0].size).astype('float64'),targets + return dataset.CachedDataSet(dataset.ApplyFunctionDataSet(dset("image","nvert"),mapf,["input","target"]),True) + + p=Polygons(image_size,[3,4],fg_min=1./255,fg_max=1./255,rot_max=1.,scale_min=0.35,scale_max=0.9,pos_min=0.1, pos_max=0.9) + data = p.subset[0:n_examples] + save_polygon_data(data,"shapes") + n_train=int(n_examples*train_frac) + trainset=convert_dataset(data.subset[0:n_train]) + testset=convert_dataset(data.subset[n_train:n_examples]) + return trainset,testset + +def make_artificial_datasets_from_function(n_inputs=1, + n_targets=1, + n_examples=20, + train_frac=0.5, + noise_level=0.1, # add Gaussian noise, noise_level=sigma + params_shape=None, + f=None, # function computing E[Y|X] + otherargs=None, # extra args to f + b=None): # force theta[0] with this value + """ + Make regression data of the form + Y | X ~ Normal(f(X,theta,otherargs),noise_level^2) + If n_inputs==1 then X is chosen at regular locations on the [-1,1] interval. + Otherwise X is sampled according to a Normal(0,1) on all dimensions (independently). + The parameters theta is a matrix of shape params_shape that is sampled from Normal(0,1). + Optionally theta[0] is set to the argument 'b', if b is provided. + + Return a training set and a test set, by splitting the generated n_examples + according to the 'train_frac'tion. 
+ """ + n_train=int(train_frac*n_examples) + n_test=n_examples-n_train + if n_inputs==1: + delta1=2./n_train + delta2=2./n_test + inputs = vstack((array(zip(range(n_train)))*delta1-1, + 0.5*delta2+array(zip(range(n_test)))*delta2-1)) + else: + inputs = random.normal(size=(n_examples,n_inputs)) + if not f: + f = linear_predictor + if f==kernel_predictor and not otherargs[1]: + otherargs=(otherargs[0],inputs[0:n_train]) + if not params_shape: + if f==linear_predictor: + params_shape = (n_inputs+1,n_targets) + elif f==kernel_predictor: + params_shape = (otherargs[1].shape[0]+1,n_targets) + theta = random.normal(size=params_shape) if params_shape else None + if b: + theta[0]=b + outputs = f(inputs,theta,otherargs) + targets = outputs + random.normal(scale=noise_level,size=(n_examples,n_targets)) + # the | stacking creates a strange bug in LookupList constructor: + # trainset = ArrayDataSet(inputs[0:n_examples/2],{'input':slice(0,n_inputs)}) | \ + # ArrayDataSet(targets[0:n_examples/2],{'target':slice(0,n_targets)}) + # testset = ArrayDataSet(inputs[n_examples/2:],{'input':slice(0,n_inputs)}) | \ + # ArrayDataSet(targets[n_examples/2:],{'target':slice(0,n_targets)}) + data = hstack((inputs,targets)) + trainset = ArrayDataSet(data[0:n_train], + {'input':slice(0,n_inputs),'target':slice(n_inputs,n_inputs+n_targets)}) + testset = ArrayDataSet(data[n_train:], + {'input':slice(0,n_inputs),'target':slice(n_inputs,n_inputs+n_targets)}) + return trainset,testset,theta
--- a/nnet_ops.py Mon Aug 04 16:21:59 2008 -0400 +++ b/nnet_ops.py Mon Aug 04 16:29:30 2008 -0400 @@ -1,3 +1,6 @@ +## This file contain ops that are not currently integrated in the core of threano. +## Not all of those ops have been thoroughly tested. + import theano from theano import tensor, scalar import numpy @@ -387,3 +390,104 @@ @todo: Rewrite as a scalar, and then broadcast to tensor. """ return -(target * tensor.log(output) + (1 - target) * tensor.log(1 - output)) + + + +class Prepend_scalar_constant_to_each_row(theano.Op): + def __init__(self, val = 0): + if isinstance(val, float): + val = scalar.constant(val) + self.val = val + + def make_node(self, mat): + #check type of input + if not isinstance(mat,theano.Result) or not mat.type==tensor.matrix().type: + raise TypeError("Expected a matrix as input") + x = tensor.as_tensor(mat) + y = tensor.as_tensor(self.val) + if x.type.dtype != y.type.dtype: + TypeError("the value to prepend don't have the same type as the matrix") + + node = theano.Apply(op=self, inputs=[mat], outputs=[tensor.matrix()]) + return node + + def perform(self, node, (mat, ), (output, )): + new_shape=(mat.shape[0],mat.shape[1]+1) + if output[0] == None: + output[0]=numpy.empty(new_shape,dtype=mat.dtype) + out=output[0] + else: + if output[0].shape!=new_shape: + try: + output[0].resize(new_shape) + except: + output[0]=numpy.empty(new_shape, dtype=mat.dtype) + out=output[0] + + out[:,0].fill(self.val.data) + out[:,1:]=mat + + def grad(self, (mat,), (goutput,)): + return goutput[:,1:] + +class Prepend_scalar_to_each_row(theano.Op): + def make_node(self, val, mat): + #check type of input + if isinstance(val, float): + val = scalar.constant(val) + if not isinstance(mat,theano.Result) or not mat.type==tensor.matrix().type: + raise TypeError("Expected a matrix as input") + x = tensor.as_tensor(mat) + y = tensor.as_tensor(val) + if x.type.dtype != y.type.dtype: + TypeError("the value to prepend don't have the same type as the matrix") + + node = theano.Apply(op=self, inputs=[val,mat], outputs=[tensor.matrix()]) + return node + + def perform(self, node, (val,mat), (output, )): + new_shape=(mat.shape[0],mat.shape[1]+1) + if output[0] == None: + output[0]=numpy.empty(new_shape,dtype=mat.dtype) + out=output[0] + else: + if output[0].shape!=new_shape: + try: + output[0].resize(new_shape) + except: + output[0]=numpy.empty(new_shape, dtype=mat.dtype) + out=output[0] + out[:,0].fill(val) + out[:,1:]=mat + + def grad(self, (val, mat), (goutput,)): + return goutput[:,0], goutput[:,1:] + +prepend_scalar_to_each_row = Prepend_scalar_to_each_row() +prepend_0_to_each_row = Prepend_scalar_constant_to_each_row(0.) +prepend_1_to_each_row = Prepend_scalar_constant_to_each_row(1.) + +class solve(theano.Op): + """ + Find the solution to the linear equation Ax=b, + where A is a 2d matrix and b is a 1d or 2d matrix. + It use numpy.solve to find the solution. + """ + + def make_node(self, A, b): + if not isinstance(A, theano.Result) or not A.type==tensor.matrix().type: + raise TypeError("We expected that A had a matrix type") + if not isinstance(B, theano.Result) or not B.type==tensor.matrix().type: + raise TypeError("We expected that B had a matrix type") + + node = theano.Apply(op=self, inputs=[A, B], outputs=[tensor.matrix()]) + return node + + def perform(self, node, (A, B), (output, )): + ret=numpy.solve(A,B) + output[0]=ret + + def grad(self, (theta, A, B), (gtheta,)): + raise NotImplementedError() + +
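The two `Prepend_*` Ops added above compute, in graph form, what a single `numpy.hstack` call does eagerly, and the `solve` Op wraps the same computation as `numpy.linalg.solve`. Eager numpy equivalents for reference; the Op plumbing, type checks, and gradients are omitted.

```python
import numpy

def prepend_scalar_to_each_row(val, mat):
    """Same result as the perform() of Prepend_scalar_to_each_row: add a constant first column."""
    col = numpy.empty((mat.shape[0], 1), dtype=mat.dtype)
    col.fill(val)
    return numpy.hstack([col, mat])

mat = numpy.random.rand(3, 5)
out = prepend_scalar_to_each_row(1.0, mat)
assert out.shape == (3, 6)
assert numpy.all(out[:, 0] == 1.0)
assert numpy.all(out[:, 1:] == mat)

# What the solve Op computes: x such that A x = b.
A = numpy.random.randn(5, 5)
b = numpy.arange(5, dtype=float)
x = numpy.linalg.solve(A, b)
assert numpy.allclose(numpy.dot(A, x), b)
```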
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/sandbox/gradient_learner.py Mon Aug 04 16:29:30 2008 -0400 @@ -0,0 +1,71 @@ + +from learner import * +from tensor import * +import gradient +from compile import Function + +class GradientLearner(Learner): + """ + Base class for gradient-based optimization of a training criterion + that can consist in two parts, an additive part over examples, and + an example-independent part (usually called the regularizer). + The user provides a Theano formula that maps the fields of a minibatch (each being a tensor with the + same number of rows = minibatch size) and parameters to output fields (for the use function), one of which + must be a cost that is the training criterion to be minimized. Subclasses implement + a training strategy that uses the Theano formula to compute gradients and + to compute outputs in the update method. + The inputs, parameters, and outputs are lists of Theano tensors, + while the example_wise_cost and regularization_term are Theano tensors. + The user can specify a regularization coefficient that multiplies the regularization term. + The training algorithm looks for parameters that minimize + regularization_coefficient * regularization_term(parameters) + + sum_{inputs in training_set} example_wise_cost(inputs,parameters) + i.e. the regularization_term should not depend on the inputs, only on the parameters. + The learned function can map a subset of inputs to a subset of outputs (as long as the inputs subset + includes all the inputs required in the Theano expression for the selected outputs). + It is assumed that all the inputs are provided in the training set (as dataset fields + with the corresponding name), but not necessarily when using the learned function. + """ + def __init__(self, inputs, parameters, outputs, example_wise_cost, regularization_term=astensor(0.0), + regularization_coefficient = astensor(1.0)): + self.inputs = inputs + self.outputs = outputs + self.parameters = parameters + self.example_wise_cost = example_wise_cost + self.regularization_term = regularization_term + self.regularization_coefficient = regularization_coefficient + self.parameters_example_wise_gradient = gradient.grad(example_wise_cost, parameters) + self.parameters_regularization_gradient = gradient.grad(self.regularization_coefficient * regularization_term, parameters) + if example_wise_cost not in outputs: + outputs.append(example_wise_cost) + if regularization_term not in outputs: + outputs.append(regularization_term) + self.example_wise_gradient_fn = Function(inputs + parameters, + [self.parameters_example_wise_gradient + self.parameters_regularization_gradient]) + self.use_functions = {frozenset([input.name for input in inputs]+[output.name for output in outputs]) + : Function(inputs, outputs)} + + def use(self,input_dataset,output_fields=None,copy_inputs=True): + # obtain the function that maps the desired inputs to desired outputs + input_fields = input_dataset.fieldNames() + # map names of input fields to Theano tensors in self.inputs + input_variables = ??? + if output_fields is None: output_fields = [output.name for output in outputs] + # handle special case of inputs that are directly copied into outputs + # map names of output fields to Theano tensors in self.outputs + output_variables = ??? 
+ use_function_key = input_fields+output_fields + if not self.use_functions.has_key(use_function_key): + self.use_function[use_function_key]=Function(input_variables,output_variables) + use_function = self.use_functions[use_function_key] + # return a dataset that computes the outputs + return input_dataset.apply_function(use_function,input_fields,output_fields,copy_inputs,compute_now=True) + + +class StochasticGradientDescent(object): + def update_parameters(self): + +class StochasticGradientLearner(GradientLearner,StochasticGradientDescent): + def __init__(self,inputs, parameters, outputs, example_wise_cost, regularization_term=astensor(0.0), + regularization_coefficient = astensor(1.0),) + def update()
--- a/sandbox/rbm/model.py  Mon Aug 04 16:21:59 2008 -0400
+++ b/sandbox/rbm/model.py  Mon Aug 04 16:29:30 2008 -0400
@@ -59,7 +59,7 @@
 
         random.seed(random_seed)
 
-        self.parameters = parameters.Parameters(input_dimension=self.input_dimension, hidden_dimension=self.hidden_dimension, randomly_initialize=False, random_seed=self.random_seed)
+        self.parameters = parameters.Parameters(input_dimension=self.input_dimension, hidden_dimension=self.hidden_dimension, randomly_initialize=True, random_seed=self.random_seed)
         self.prev_dw = 0
         self.prev_db = 0
         self.prev_dc = 0
--- a/sandbox/simple_autoassociator/README.txt  Mon Aug 04 16:21:59 2008 -0400
+++ b/sandbox/simple_autoassociator/README.txt  Mon Aug 04 16:29:30 2008 -0400
@@ -1,1 +1,5 @@
 This seems to work.
+
+@todo:
+   * Add momentum.
+   * Add learning rate decay schedule.
--- a/sandbox/simple_autoassociator/globals.py  Mon Aug 04 16:21:59 2008 -0400
+++ /dev/null  Thu Jan 01 00:00:00 1970 +0000
@@ -1,12 +0,0 @@
-"""
-Global variables.
-"""
-
-#INPUT_DIMENSION = 1000
-#INPUT_DIMENSION = 100
-INPUT_DIMENSION = 4
-#HIDDEN_DIMENSION = 10
-HIDDEN_DIMENSION = 1
-LEARNING_RATE = 0.1
-LR = LEARNING_RATE
-SEED = 666
--- a/sandbox/simple_autoassociator/graph.py  Mon Aug 04 16:21:59 2008 -0400
+++ b/sandbox/simple_autoassociator/graph.py  Mon Aug 04 16:29:30 2008 -0400
@@ -6,7 +6,7 @@
 from pylearn.nnet_ops import sigmoid, binary_crossentropy
 from theano import tensor as t
 from theano.tensor import dot
-x  = t.dvector()
+x  = t.dmatrix()
 w1 = t.dmatrix()
 b1 = t.dvector()
 w2 = t.dmatrix()
--- a/sandbox/simple_autoassociator/main.py Mon Aug 04 16:21:59 2008 -0400 +++ b/sandbox/simple_autoassociator/main.py Mon Aug 04 16:29:30 2008 -0400 @@ -7,9 +7,6 @@ y = sigmoid(dot(h, w2) + b2) Binary xent loss. - - LIMITATIONS: - - Only does pure stochastic gradient (batchsize = 1). """ @@ -24,11 +21,11 @@ ##nonzero_instances.append({1: 0.2, 2: 0.3, 5: 0.5}) import model -model = model.Model() +model = model.Model(input_dimension=10, hidden_dimension=4) for i in xrange(100000): - # Select an instance - instance = nonzero_instances[i % len(nonzero_instances)] +# # Select an instance +# instance = nonzero_instances[i % len(nonzero_instances)] - # SGD update over instance - model.update(instance) + # Update over instance + model.update(nonzero_instances)
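The model trained here is the one described in the `main.py` and `graph.py` docstrings: h = sigmoid(x w1 + b1), y = sigmoid(h w2 + b2), with a binary cross-entropy reconstruction loss. A numpy sketch of that forward pass and loss; the actual computation is compiled by Theano in `graph.py`.

```python
import numpy

def sigmoid(a):
    return 1.0 / (1.0 + numpy.exp(-a))

def reconstruct(x, w1, b1, w2, b2):
    """Deterministic reconstruction y and hidden code h for a minibatch x (one row per example)."""
    h = sigmoid(numpy.dot(x, w1) + b1)
    y = sigmoid(numpy.dot(h, w2) + b2)
    return y, h

def binary_crossentropy(y, x):
    """Binary cross-entropy between reconstruction y and input x, summed over the minibatch."""
    return -numpy.sum(x * numpy.log(y) + (1 - x) * numpy.log(1 - y))
```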
--- a/sandbox/simple_autoassociator/model.py Mon Aug 04 16:21:59 2008 -0400 +++ b/sandbox/simple_autoassociator/model.py Mon Aug 04 16:29:30 2008 -0400 @@ -6,27 +6,42 @@ from graph import trainfn import parameters -import globals -from globals import LR - import numpy import random -random.seed(globals.SEED) + +import pylearn.sparse_instance class Model: - def __init__(self): - self.parameters = parameters.Parameters(randomly_initialize=True) + """ + @todo: Add momentum. + @todo: Add learning rate decay schedule. + """ + def __init__(self, input_dimension, hidden_dimension, learning_rate = 0.1, weight_decay = 0.0002, random_seed = 666): + self.input_dimension = input_dimension + self.hidden_dimension = hidden_dimension + self.learning_rate = learning_rate + self.weight_decay = weight_decay + self.random_seed = random_seed - def update(self, instance): + random.seed(random_seed) + + self.parameters = parameters.Parameters(input_dimension=self.input_dimension, hidden_dimension=self.hidden_dimension, randomly_initialize=True, random_seed=self.random_seed) + + def deterministic_reconstruction(self, x): + (y, h, loss, gw1, gb1, gw2, gb2) = trainfn(x, self.parameters.w1, self.parameters.b1, self.parameters.w2, self.parameters.b2) + return y + + def update(self, instances): """ Update the L{Model} using one training instance. - @param instance: A dict from feature index to (non-zero) value. + @param instances: A list of dict from feature index to (non-zero) value. @todo: Should assert that nonzero_indices and zero_indices are correct (i.e. are truly nonzero/zero). + @todo: Multiply L{self.weight_decay} by L{self.learning_rate}, as done in Semantic Hashing? + @todo: Decay the biases too? """ - x = numpy.zeros(globals.INPUT_DIMENSION) - for idx in instance.keys(): - x[idx] = instance[idx] + minibatch = len(instances) + x = pylearn.sparse_instance.to_vector(instances, self.input_dimension) (y, h, loss, gw1, gb1, gw2, gb2) = trainfn(x, self.parameters.w1, self.parameters.b1, self.parameters.w2, self.parameters.b2) # print @@ -39,15 +54,18 @@ # print "gw2:", gw2 # print "gb2:", gb2 - # SGD update - self.parameters.w1 -= LR * gw1 - self.parameters.b1 -= LR * gb1 - self.parameters.w2 -= LR * gw2 - self.parameters.b2 -= LR * gb2 + self.parameters.w1 *= (1 - self.weight_decay) + self.parameters.w2 *= (1 - self.weight_decay) - # Recompute the loss, to make sure it's descreasing - (y, h, loss, gw1, gb1, gw2, gb2) = trainfn(x, self.parameters.w1, self.parameters.b1, self.parameters.w2, self.parameters.b2) -# print "NEW y:", y - print "NEW total loss:", loss -# print "h:", h -# print self.parameters + # SGD update + self.parameters.w1 -= self.learning_rate * gw1 / minibatch + self.parameters.b1 -= self.learning_rate * gb1 / minibatch + self.parameters.w2 -= self.learning_rate * gw2 / minibatch + self.parameters.b2 -= self.learning_rate * gb2 / minibatch + +# # Recompute the loss, to make sure it's descreasing +# (y, h, loss, gw1, gb1, gw2, gb2) = trainfn(x, self.parameters.w1, self.parameters.b1, self.parameters.w2, self.parameters.b2) +## print "NEW y:", y +# print "NEW total loss:", loss +## print "h:", h +## print self.parameters
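The new `update` method first applies weight decay to the weight matrices and then takes an SGD step averaged over the minibatch. Isolated from the class, the parameter update looks like the sketch below; `gradients` stands in for the values returned by the compiled `trainfn`.

```python
def sgd_update(params, gradients, learning_rate, weight_decay, minibatch_size):
    """params and gradients are dicts keyed by 'w1', 'b1', 'w2', 'b2' (numpy arrays)."""
    # Decay the weight matrices but not the biases, as in the model above.
    for name in ('w1', 'w2'):
        params[name] *= (1 - weight_decay)
    # SGD step, averaged over the minibatch.
    for name in ('w1', 'b1', 'w2', 'b2'):
        params[name] -= learning_rate * gradients[name] / minibatch_size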
--- a/sandbox/simple_autoassociator/parameters.py Mon Aug 04 16:21:59 2008 -0400 +++ b/sandbox/simple_autoassociator/parameters.py Mon Aug 04 16:29:30 2008 -0400 @@ -3,20 +3,19 @@ """ import numpy -import globals class Parameters: """ Parameters used by the L{Model}. """ - def __init__(self, input_dimension=globals.INPUT_DIMENSION, hidden_dimension=globals.HIDDEN_DIMENSION, randomly_initialize=False, seed=globals.SEED): + def __init__(self, input_dimension, hidden_dimension, randomly_initialize, random_seed): """ Initialize L{Model} parameters. @param randomly_initialize: If True, then randomly initialize according to the given seed. If False, then just use zeroes. """ if randomly_initialize: - numpy.random.seed(seed) + numpy.random.seed(random_seed) self.w1 = (numpy.random.rand(input_dimension, hidden_dimension)-0.5)/input_dimension self.w2 = (numpy.random.rand(hidden_dimension, input_dimension)-0.5)/hidden_dimension self.b1 = numpy.zeros(hidden_dimension)
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/sandbox/statscollector.py Mon Aug 04 16:29:30 2008 -0400 @@ -0,0 +1,127 @@ + +# Here is how I see stats collectors: + +def my_stats(graph): + graph.mse=examplewise_mean(square_norm(graph.residue)) + graph.training_loss=graph.regularizer+examplewise_sum(graph.nll) + return [graph.mse,graph.training_loss] + + +# def my_stats(residue,nll,regularizer): +# mse=examplewise_mean(square_norm(residue)) +# training_loss=regularizer+examplewise_sum(nll) +# set_names(locals()) +# return ((residue,nll),(regularizer),(),(mse,training_loss)) +# my_stats_collector = make_stats_collector(my_stats) +# +# where make_stats_collector calls my_stats(examplewise_fields, attributes) to +# construct its update function, and figure out what are the input fields (here "residue" +# and "nll") and input attributes (here "regularizer") it needs, and the output +# attributes that it computes (here "mse" and "training_loss"). Remember that +# fields are examplewise quantities, but attributes are not, in my jargon. +# In the above example, I am highlighting that some operations done in my_stats +# are examplewise and some are not. I am hoping that theano Ops can do these +# kinds of internal side-effect operations (and proper initialization of these hidden +# variables). I expect that a StatsCollector (returned by make_stats_collector) +# knows the following methods: +# stats_collector.input_fieldnames +# stats_collector.input_attribute_names +# stats_collector.output_attribute_names +# stats_collector.update(mini_dataset) +# stats_collector['mse'] +# where mini_dataset has the input_fieldnames() as fields and the input_attribute_names() +# as attributes, and in the resulting dataset the output_attribute_names() are set to the +# proper numeric values. 
+ + + +import theano +from theano import tensor as t +from Learner import Learner +from lookup_list import LookupList + +class StatsCollectorModel(AttributesHolder): + def __init__(self,stats_collector): + self.stats_collector = stats_collector + self.outputs = LookupList(stats_collector.output_names,[None for name in stats_collector.output_names]) + # the statistics get initialized here + self.update_function = theano.function(input_attributes+input_fields,output_attributes+output_fields,linker="c|py") + for name,value in self.outputs.items(): + self.__setattribute__(name,value) + def update(self,dataset): + input_fields = dataset.fields()(self.stats_collector.input_field_names) + input_attributes = dataset.getAttributes(self.stats_collector.input_attribute_names) + self.outputs._values = self.update_function(input_attributes+input_fields) + for name,value in self.outputs.items(): + self.__setattribute__(name,value) + def __call__(self): + return self.outputs + def attributeNames(self): + return self.outputs.keys() + +class StatsCollector(AttributesHolder): + + def __init__(self,input_attributes, input_fields, outputs): + self.input_attributes = input_attributes + self.input_fields = input_fields + self.outputs = outputs + self.input_attribute_names = [v.name for v in input_attributes] + self.input_field_names = [v.name for v in input_fields] + self.output_names = [v.name for v in output_attributes] + + def __call__(self,dataset=None): + model = StatsCollectorModel(self) + if dataset: + self.update(dataset) + return model + +if __name__ == '__main__': + def my_statscollector(): + regularizer = t.scalar() + nll = t.matrix() + class_error = t.matrix() + total_loss = regularizer+t.examplewise_sum(nll) + avg_nll = t.examplewise_mean(nll) + avg_class_error = t.examplewise_mean(class_error) + for name,val in locals().items(): val.name = name + return StatsCollector([regularizer],[nll,class_error],[total_loss,avg_nll,avg_class_error]) + + + + +# OLD DESIGN: +# +# class StatsCollector(object): +# """A StatsCollector object is used to record performance statistics during training +# or testing of a learner. It can be configured to measure different things and +# accumulate the appropriate statistics. From these statistics it can be interrogated +# to obtain performance measures of interest (such as maxima, minima, mean, standard +# deviation, standard error, etc.). Optionally, the observations can be weighted +# (yielded weighted mean, weighted variance, etc., where applicable). The statistics +# that are desired can be specified among a list supported by the StatsCollector +# class or subclass. When some statistics are requested, others become automatically +# available (e.g., sum or mean).""" +# +# default_statistics = [mean,standard_deviation,min,max] +# +# __init__(self,n_quantities_observed, statistics=default_statistics): +# self.n_quantities_observed=n_quantities_observed +# +# clear(self): +# raise NotImplementedError +# +# update(self,observations): +# """The observations is a numpy vector of length n_quantities_observed. Some +# entries can be 'missing' (with a NaN entry) and will not be counted in the +# statistics.""" +# raise NotImplementedError +# +# __getattr__(self, statistic) +# """Return a particular statistic, which may be inferred from the collected statistics. +# The argument is a string naming that statistic.""" + + + + + +
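The design notes at the top of this new file sketch a StatsCollector whose `update(mini_dataset)` accumulates example-wise statistics and whose results are read back by name. Independently of the Theano-based design above, a minimal running-statistics collector with that interface might look as follows; this is a hypothetical helper, not part of the changeset, and it only mirrors the `update()` / `stats_collector['mse']` usage sketched in the comments.

```python
import numpy

class RunningStats(object):
    """Accumulate the example-wise mean of named quantities across successive updates."""
    def __init__(self, names):
        self.names = list(names)
        self.clear()

    def clear(self):
        self._n = 0
        self._sums = dict((name, 0.0) for name in self.names)

    def update(self, minibatch):
        """minibatch maps each name to an array with one row per example."""
        self._n += len(minibatch[self.names[0]])
        for name in self.names:
            self._sums[name] += numpy.sum(minibatch[name], axis=0)

    def __getitem__(self, name):        # e.g. stats['mse']
        return self._sums[name] / self._n

stats = RunningStats(['mse'])
stats.update({'mse': numpy.array([[0.5], [1.5]])})
stats.update({'mse': numpy.array([[1.0]])})
print(stats['mse'])                     # [ 1.]
```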
--- a/statscollector.py Mon Aug 04 16:21:59 2008 -0400 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,127 +0,0 @@ - -# Here is how I see stats collectors: - -def my_stats(graph): - graph.mse=examplewise_mean(square_norm(graph.residue)) - graph.training_loss=graph.regularizer+examplewise_sum(graph.nll) - return [graph.mse,graph.training_loss] - - -# def my_stats(residue,nll,regularizer): -# mse=examplewise_mean(square_norm(residue)) -# training_loss=regularizer+examplewise_sum(nll) -# set_names(locals()) -# return ((residue,nll),(regularizer),(),(mse,training_loss)) -# my_stats_collector = make_stats_collector(my_stats) -# -# where make_stats_collector calls my_stats(examplewise_fields, attributes) to -# construct its update function, and figure out what are the input fields (here "residue" -# and "nll") and input attributes (here "regularizer") it needs, and the output -# attributes that it computes (here "mse" and "training_loss"). Remember that -# fields are examplewise quantities, but attributes are not, in my jargon. -# In the above example, I am highlighting that some operations done in my_stats -# are examplewise and some are not. I am hoping that theano Ops can do these -# kinds of internal side-effect operations (and proper initialization of these hidden -# variables). I expect that a StatsCollector (returned by make_stats_collector) -# knows the following methods: -# stats_collector.input_fieldnames -# stats_collector.input_attribute_names -# stats_collector.output_attribute_names -# stats_collector.update(mini_dataset) -# stats_collector['mse'] -# where mini_dataset has the input_fieldnames() as fields and the input_attribute_names() -# as attributes, and in the resulting dataset the output_attribute_names() are set to the -# proper numeric values. 
- - - -import theano -from theano import tensor as t -from Learner import Learner -from lookup_list import LookupList - -class StatsCollectorModel(AttributesHolder): - def __init__(self,stats_collector): - self.stats_collector = stats_collector - self.outputs = LookupList(stats_collector.output_names,[None for name in stats_collector.output_names]) - # the statistics get initialized here - self.update_function = theano.function(input_attributes+input_fields,output_attributes+output_fields,linker="c|py") - for name,value in self.outputs.items(): - self.__setattribute__(name,value) - def update(self,dataset): - input_fields = dataset.fields()(self.stats_collector.input_field_names) - input_attributes = dataset.getAttributes(self.stats_collector.input_attribute_names) - self.outputs._values = self.update_function(input_attributes+input_fields) - for name,value in self.outputs.items(): - self.__setattribute__(name,value) - def __call__(self): - return self.outputs - def attributeNames(self): - return self.outputs.keys() - -class StatsCollector(AttributesHolder): - - def __init__(self,input_attributes, input_fields, outputs): - self.input_attributes = input_attributes - self.input_fields = input_fields - self.outputs = outputs - self.input_attribute_names = [v.name for v in input_attributes] - self.input_field_names = [v.name for v in input_fields] - self.output_names = [v.name for v in output_attributes] - - def __call__(self,dataset=None): - model = StatsCollectorModel(self) - if dataset: - self.update(dataset) - return model - -if __name__ == '__main__': - def my_statscollector(): - regularizer = t.scalar() - nll = t.matrix() - class_error = t.matrix() - total_loss = regularizer+t.examplewise_sum(nll) - avg_nll = t.examplewise_mean(nll) - avg_class_error = t.examplewise_mean(class_error) - for name,val in locals().items(): val.name = name - return StatsCollector([regularizer],[nll,class_error],[total_loss,avg_nll,avg_class_error]) - - - - -# OLD DESIGN: -# -# class StatsCollector(object): -# """A StatsCollector object is used to record performance statistics during training -# or testing of a learner. It can be configured to measure different things and -# accumulate the appropriate statistics. From these statistics it can be interrogated -# to obtain performance measures of interest (such as maxima, minima, mean, standard -# deviation, standard error, etc.). Optionally, the observations can be weighted -# (yielded weighted mean, weighted variance, etc., where applicable). The statistics -# that are desired can be specified among a list supported by the StatsCollector -# class or subclass. When some statistics are requested, others become automatically -# available (e.g., sum or mean).""" -# -# default_statistics = [mean,standard_deviation,min,max] -# -# __init__(self,n_quantities_observed, statistics=default_statistics): -# self.n_quantities_observed=n_quantities_observed -# -# clear(self): -# raise NotImplementedError -# -# update(self,observations): -# """The observations is a numpy vector of length n_quantities_observed. Some -# entries can be 'missing' (with a NaN entry) and will not be counted in the -# statistics.""" -# raise NotImplementedError -# -# __getattr__(self, statistic) -# """Return a particular statistic, which may be inferred from the collected statistics. -# The argument is a string naming that statistic.""" - - - - - -