# HG changeset patch # User Frederic Bastien # Date 1210616116 14400 # Node ID ad144fa72bf555dc3aa5942077dc77900e4ab00b # Parent 3f4e5c9bdc5e4b116b6c3e2399e81a8d79b12ad9# Parent f5f235bebee44b4b08a3428a076d95fd6c8dd866 Automated merge with ssh://p-omega1@lgcm.iro.umontreal.ca/tlearn diff -r f5f235bebee4 -r ad144fa72bf5 dataset.py --- a/dataset.py Mon May 12 14:13:39 2008 -0400 +++ b/dataset.py Mon May 12 14:15:16 2008 -0400 @@ -16,6 +16,11 @@ raise AbstractFunction() def setAttributes(self,attribute_names,attribute_values,make_copies=False): + """ + Allow the attribute_values to not be a list (but a single value) if the attribute_names is of length 1. + """ + if len(attribute_names)==1 and not (isinstance(attribute_values,list) or isinstance(attribute_values,tuple) ): + attribute_values = [attribute_values] if make_copies: for name,value in zip(attribute_names,attribute_values): self.__setattr__(name,copy.deepcopy(value)) @@ -69,7 +74,7 @@ but when the dataset is a stream (unbounded length), it is not recommanded to do such things because the underlying dataset may refuse to access the different fields in an unsynchronized ways. Hence the fields() method is illegal for streams, by default. - The result of fields() is a DataSetFields object, which iterates over fields, + The result of fields() is a L{DataSetFields} object, which iterates over fields, and whose elements are iterable over examples. A DataSetFields object can be turned back into a DataSet with its examples() method:: dataset2 = dataset1.fields().examples() @@ -981,12 +986,12 @@ key[i]=self.fields_columns[key[i]] return MinibatchDataSet(Example(fieldnames, #we must separate differently for list as numpy - # don't support self.data[[i1,...],[i2,...]] + # doesn't support self.data[[i1,...],[i2,...]] # when their is more then two i1 and i2 [self.data[key,:][:,self.fields_columns[f]] if isinstance(self.fields_columns[f],list) else - self.data[key,self.fields_columns[f]] - for f in fieldnames]), + self.data[key,self.fields_columns[f]] for f in fieldnames]), + self.valuesVStack,self.valuesHStack) # else check for a fieldname @@ -1116,14 +1121,14 @@ self.function=function self.output_names=output_names self.minibatch_mode=minibatch_mode - DataSet.__init__(description,fieldtypes) + DataSet.__init__(self,description,fieldtypes) self.valuesHStack = values_hstack if values_hstack else input_dataset.valuesHStack self.valuesVStack = values_vstack if values_vstack else input_dataset.valuesVStack def __len__(self): return len(self.input_dataset) - def fieldnames(self): + def fieldNames(self): return self.output_names def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): @@ -1131,8 +1136,8 @@ def __init__(self,output_dataset): self.input_dataset=output_dataset.input_dataset self.output_dataset=output_dataset - self.input_iterator=input_dataset.minibatches(minibatch_size=minibatch_size, - n_batches=n_batches,offset=offset).__iter__() + self.input_iterator=self.input_dataset.minibatches(minibatch_size=minibatch_size, + n_batches=n_batches,offset=offset).__iter__() def __iter__(self): return self @@ -1140,7 +1145,7 @@ function_inputs = self.input_iterator.next() all_output_names = self.output_dataset.output_names if self.output_dataset.minibatch_mode: - function_outputs = self.output_dataset.function(function_inputs) + function_outputs = self.output_dataset.function(*function_inputs) else: input_examples = zip(*function_inputs) output_examples = [self.output_dataset.function(input_example) @@ -1153,7 +1158,7 @@ return all_outputs return Example(fieldnames,[all_outputs[name] for name in fieldnames]) - return ApplyFunctionIterator(self.input_dataset,self) + return ApplyFunctionIterator(self) def __iter__(self): # only implemented for increased efficiency class ApplyFunctionSingleExampleIterator(object): diff -r f5f235bebee4 -r ad144fa72bf5 learner.py --- a/learner.py Mon May 12 14:13:39 2008 -0400 +++ b/learner.py Mon May 12 14:15:16 2008 -0400 @@ -1,13 +1,16 @@ -from dataset import AttributesHolder -import compile +from dataset import AttributesHolder,AbstractFunction,ApplyFunctionDataSet,DataSet,CachedDataSet +import theano +from theano import compile +from theano import tensor as t class Learner(AttributesHolder): - """Base class for learning algorithms, provides an interface + """ + Base class for learning algorithms, provides an interface that allows various algorithms to be applicable to generic learning algorithms. - A Learner can be seen as a learning algorithm, a function that when + A L{Learner} can be seen as a learning algorithm, a function that when applied to training data returns a learned function, an object that can be applied to other data and return some output data. """ @@ -32,7 +35,7 @@ The result is a function that can be applied on data, with the same semantics of the Learner.use method. - The user may optionally provide a training StatsCollector that is used to record + The user may optionally provide a training L{StatsCollector} that is used to record some statistics of the outputs computed during training. It is update(d) during training. """ @@ -45,19 +48,76 @@ and return the learned function. """ self.forget() - return self.update(learning_task,train_stats_collector) + return self.update(training_set,train_stats_collector) + + def use(self,input_dataset,output_fieldnames=None, + test_stats_collector=None,copy_inputs=True, + put_stats_in_output_dataset=True, + output_attributes=[]): + """ + Once a L{Learner} has been trained by one or more call to 'update', it can + be used with one or more calls to 'use'. The argument is an input L{DataSet} (possibly + containing a single example) and the result is an output L{DataSet} of the same length. + If output_fieldnames is specified, it may be use to indicate which fields should + be constructed in the output L{DataSet} (for example ['output','classification_error']). + Otherwise, self.defaultOutputFields is called to choose the output fields. + Optionally, if copy_inputs, the input fields (of the input_dataset) can be made + visible in the output L{DataSet} returned by this method. + Optionally, attributes of the learner can be copied in the output dataset, + and statistics computed by the stats collector also put in the output dataset. + Note the distinction between fields (which are example-wise quantities, e.g. 'input') + and attributes (which are not, e.g. 'regularization_term'). + + We provide here a default implementation that does all this using + a sub-class defined method: minibatchwiseUseFunction. + + @todo check if some of the learner attributes are actually SPECIFIED + as attributes of the input_dataset, and if so use their values instead + of the ones in the learner. + + The learner tries to compute in the output dataset the output fields specified. + If None is specified then self.defaultOutputFields(input_dataset.fieldNames()) + is called to determine the output fields. - def use(self,input_dataset,output_fields=None,copy_inputs=True): - """Once a Learner has been trained by one or more call to 'update', it can - be used with one or more calls to 'use'. The argument is a DataSet (possibly - containing a single example) and the result is a DataSet of the same length. - If output_fields is specified, it may be use to indicate which fields should - be constructed in the output DataSet (for example ['output','classification_error']). - Optionally, if copy_inputs, the input fields (of the input_dataset) can be made - visible in the output DataSet returned by this method. + Attributes of the learner can also optionally be copied into the output dataset. + If output_attributes is None then all of the attributes in self.AttributeNames() + are copied in the output dataset, but if it is [] (the default), then none are copied. + If a test_stats_collector is provided, then its attributes (test_stats_collector.AttributeNames()) + are also copied into the output dataset attributes. """ - raise NotImplementedError + minibatchwise_use_function = self.minibatchwiseUseFunction(input_dataset.fieldNames(), + output_fieldnames, + test_stats_collector) + virtual_output_dataset = ApplyFunctionDataSet(input_dataset, + minibatchwise_use_function, + True,DataSet.numpy_vstack, + DataSet.numpy_hstack) + # actually force the computation + output_dataset = CachedDataSet(virtual_output_dataset,True) + if copy_inputs: + output_dataset = input_dataset | output_dataset + # copy the wanted attributes in the dataset + if output_attributes is None: + output_attributes = self.attributeNames() + if output_attributes: + assert set(attribute_names) <= set(self.attributeNames()) + output_dataset.setAttributes(output_attributes, + self.names2attributes(output_attributes,return_copy=True)) + if test_stats_collector: + test_stats_collector.update(output_dataset) + if put_stats_in_output_dataset: + output_dataset.setAttributes(test_stats_collector.attributeNames(), + test_stats_collector.attributes()) + return output_dataset + def minibatchwiseUseFunction(self, input_fields, output_fields, stats_collector): + """ + Returns a function that can map the given input fields to the given output fields + and to the attributes that the stats collector needs for its computation. + That function is expected to operate on minibatches. + The function returned makes use of the self.useInputAttributes() and + sets the attributes specified by self.useOutputAttributes(). + """ def attributeNames(self): """ A Learner may have attributes that it wishes to export to other objects. To automate @@ -67,12 +127,22 @@ """ return [] - def updateInputAttributes(self): + def attributes(self,return_copy=False): + """ + Return a list with the values of the learner's attributes (or optionally, a deep copy). + """ + return self.names2attributes(self.attributeNames(),return_copy) + + def names2attributes(self,names): """ - A subset of self.attributeNames() which are the names of attributes needed by update() in order - to do its work. + Private helper function that maps a list of attribute names to a list + of (optionally copies) values of attributes. """ - raise AbstractFunction() + res=[] + for name in names: + assert name in names + res.append(self.__getattribute__(name)) + return res def useInputAttributes(self): """ @@ -81,15 +151,6 @@ """ raise AbstractFunction() - def updateOutputAttributes(self): - """ - A subset of self.attributeNames() which are the names of attributes modified/created by update() in order - to do its work. - - By default these are inferred from the various update output attributes: - """ - return ["parameters"] + self.updateMinibatchOutputAttributes() + self.updateEndOutputAttributes() - def useOutputAttributes(self): """ A subset of self.attributeNames() which are the names of attributes modified/created by use() in order @@ -135,6 +196,7 @@ def __init__(self): Learner.__init__(self) + self.use_functions_dictionary={} def defaultOutputFields(self, input_fields): """ @@ -145,22 +207,10 @@ """ raise AbstractFunction() - def allocate(self, minibatch): - """ - This function is called at the beginning of each updateMinibatch - and should be used to check that all required attributes have been - allocated and initialized (usually this function calls forget() - when it has to do an initialization). + def minibatchwiseUseFunction(self, input_fields, output_fields, stats_collector): """ - raise AbstractFunction() - - def minibatchwise_use_functions(self, input_fields, output_fields, stats_collector): - """ - Private helper function called by the generic TLearner.use. It returns a function - that can map the given input fields to the given output fields (along with the - attributes that the stats collector needs for its computation. The function - called also automatically makes use of the self.useInputAttributes() and - sets the self.useOutputAttributes(). + Implement minibatchwiseUseFunction by exploiting Theano compilation + and the expression graph defined by a sub-class constructor. """ if not output_fields: output_fields = self.defaultOutputFields(input_fields) @@ -169,7 +219,7 @@ for attribute in stats_collector_inputs: if attribute not in input_fields: output_fields.append(attribute) - key = (input_fields,output_fields) + key = (tuple(input_fields),tuple(output_fields)) if key not in self.use_functions_dictionary: use_input_attributes = self.useInputAttributes() use_output_attributes = self.useOutputAttributes() @@ -177,7 +227,7 @@ self.names2OpResults(output_fields+use_output_attributes)) def f(*input_field_values): input_attribute_values = self.names2attributes(use_input_attributes) - results = complete_f(*(input_field_values + input_attribute_values)) + results = complete_f(*(list(input_field_values) + input_attribute_values)) output_field_values = results[0:len(output_fields)] output_attribute_values = results[len(output_fields):len(results)] if use_output_attributes: @@ -186,77 +236,17 @@ self.use_functions_dictionary[key]=f return self.use_functions_dictionary[key] - def attributes(self,return_copy=False): - """ - Return a list with the values of the learner's attributes (or optionally, a deep copy). - """ - return self.names2attributes(self.attributeNames(),return_copy) - - def names2attributes(self,names,return_copy=False): - """ - Private helper function that maps a list of attribute names to a list - of (optionally copies) values of attributes. - """ - if return_copy: - return [copy.deepcopy(self.__getattr__(name).data) for name in names] - else: - return [self.__getattr__(name).data for name in names] - def names2OpResults(self,names): """ Private helper function that maps a list of attribute names to a list of corresponding Op Results (with the same name but with a '_' prefix). """ - return [self.__getattr__('_'+name).data for name in names] - - def use(self,input_dataset,output_fieldnames=None,output_attributes=[], - test_stats_collector=None,copy_inputs=True, put_stats_in_output_dataset=True): - """ - The learner tries to compute in the output dataset the output fields specified - - @todo check if some of the learner attributes are actually SPECIFIED - as attributes of the input_dataset, and if so use their values instead - of the ones in the learner. - - The learner tries to compute in the output dataset the output fields specified. - If None is specified then self.defaultOutputFields(input_dataset.fieldNames()) - is called to determine the output fields. - - Attributes of the learner can also optionally be copied into the output dataset. - If output_attributes is None then all of the attributes in self.AttributeNames() - are copied in the output dataset, but if it is [] (the default), then none are copied. - If a test_stats_collector is provided, then its attributes (test_stats_collector.AttributeNames()) - are also copied into the output dataset attributes. - """ - minibatchwise_use_function = minibatchwise_use_functions(input_dataset.fieldNames(), - output_fieldnames, - test_stats_collector) - virtual_output_dataset = ApplyFunctionDataSet(input_dataset, - minibatchwise_use_function, - True,DataSet.numpy_vstack, - DataSet.numpy_hstack) - # actually force the computation - output_dataset = CachedDataSet(virtual_output_dataset,True) - if copy_inputs: - output_dataset = input_dataset | output_dataset - # copy the wanted attributes in the dataset - if output_attributes is None: - output_attributes = self.attributeNames() - if output_attributes: - assert set(attribute_names) <= set(self.attributeNames()) - output_dataset.setAttributes(output_attributes, - self.names2attributes(output_attributes,return_copy=True)) - if test_stats_collector: - test_stats_collector.update(output_dataset) - if put_stats_in_output_dataset: - output_dataset.setAttributes(test_stats_collector.attributeNames(), - test_stats_collector.attributes()) - return output_dataset + return [self.__getattribute__('_'+name) for name in names] class MinibatchUpdatesTLearner(TLearner): """ - This adds to TLearner a + This adds to L{TLearner} a - updateStart(), updateEnd(), updateMinibatch(minibatch), isLastEpoch(): functions executed at the beginning, the end, in the middle (for each minibatch) of the update method, and at the end @@ -273,14 +263,21 @@ def __init__(self): TLearner.__init__(self) - self.update_minibatch_function = compile.function - (self.names2OpResults(self.updateMinibatchOutputAttributes()+ - self.updateMinibatchInputFields()), - self.names2OpResults(self.updateMinibatchOutputAttributes())) - self.update_end_function = compile.function - (self.names2OpResults(self.updateEndInputAttributes()), - self.names2OpResults(self.updateEndOutputAttributes())) + self.update_minibatch_function = compile.function(self.names2OpResults(self.updateMinibatchOutputAttributes()+ + self.updateMinibatchInputFields()), + self.names2OpResults(self.updateMinibatchOutputAttributes())) + self.update_end_function = compile.function(self.names2OpResults(self.updateEndInputAttributes()), + self.names2OpResults(self.updateEndOutputAttributes())) + def allocate(self, minibatch): + """ + This function is called at the beginning of each L{updateMinibatch} + and should be used to check that all required attributes have been + allocated and initialized (usually this function calls forget() + when it has to do an initialization). + """ + raise AbstractFunction() + def updateMinibatchInputFields(self): raise AbstractFunction() @@ -299,22 +296,22 @@ def parameterAttributes(self): raise AbstractFunction() - def updateStart(self): pass + def updateStart(self,training_set): + pass def updateEnd(self): self.setAttributes(self.updateEndOutputAttributes(), - self.update_end_function - (self.names2attributes(self.updateEndInputAttributes()))) + self.update_end_function(*self.names2attributes(self.updateEndInputAttributes()))) self.parameters = self.names2attributes(self.parameterAttributes()) def updateMinibatch(self,minibatch): # make sure all required fields are allocated and initialized self.allocate(minibatch) + input_attributes = self.names2attributes(self.updateMinibatchInputAttributes()) + input_fields = minibatch(*self.updateMinibatchInputFields()) self.setAttributes(self.updateMinibatchOutputAttributes(), # concatenate the attribute values and field values and then apply update fn - self.update_minibatch_function(*(self.names2attributes - (self.updateMinibatchInputAttributes())) - + minibatch(self.updateMinibatchInputFields()))) + self.update_minibatch_function(*(input_attributes+input_fields))) def isLastEpoch(self): """ @@ -331,12 +328,15 @@ """ self.updateStart(training_set) stop=False + if hasattr(self,'_minibatch_size') and self._minibatch_size: + minibatch_size=self._minibatch_size + else: + minibatch_size=min(100,len(training_set)) while not stop: if train_stats_collector: train_stats_collector.forget() # restart stats collectin at the beginning of each epoch - for minibatch in training_set.minibatches(self.training_set_input_fields, - minibatch_size=self.minibatch_size): - self.update_minibatch(minibatch) + for minibatch in training_set.minibatches(minibatch_size=minibatch_size): + self.updateMinibatch(minibatch) if train_stats_collector: minibatch_set = minibatch.examples() minibatch_set.setAttributes(self.attributeNames(),self.attributes()) @@ -345,17 +345,14 @@ self.updateEnd() return self.use -class OnlineGradientBasedTLearner(MinibatchUpdatesTLearner): +class OnlineGradientTLearner(MinibatchUpdatesTLearner): """ - Specialization of MinibatchUpdatesTLearner in which the minibatch updates + Specialization of L{MinibatchUpdatesTLearner} in which the minibatch updates are obtained by performing an online (minibatch-based) gradient step. Sub-classes must define the following: - - self._learning_rate (may be changed by the sub-class between epochs or minibatches) - - self.lossAttribute() = name of the loss field - + - self._learning_rate (may be changed by the sub-class between epochs or minibatches) + - self.lossAttribute() = name of the loss field """ def __init__(self,truly_online=False): """ @@ -366,14 +363,32 @@ self.truly_online=truly_online # create the formulas for the gradient update - old_params = [self.__getattr__("_"+name) for name in self.parameterAttributes()] + old_params = [self.__getattribute__("_"+name) for name in self.parameterAttributes()] new_params_names = ["_new_"+name for name in self.parameterAttributes()] - loss = self.__getattr__(self.lossAttribute()) + loss = self.__getattribute__("_"+self.lossAttribute()) self.setAttributes(new_params_names, - [t.add_inplace(self.param, - self._learning_rate*t.grad(loss,param)) + [t.add_inplace(param,self._learning_rate*t.grad(loss,param)) for param in old_params]) + MinibatchUpdatesTLearner.__init__(self) + + def namesOfAttributesToComputeOutputs(self,output_names): + """ + The output_names are attribute names (not the corresponding Result names, which have leading _). + Return the corresponding input names + """ + all_inputs = t.gof.graph.inputs(self.names2OpResults(output_names)) + # remove constants and leading '_' in name + + return [r.name for r in all_inputs if isinstance(r,theano.Result) and \ + not isinstance(r,theano.Constant) and not isinstance(r,theano.Value)] + #inputs = [] + #for r in all_inputs: + # if isinstance(r,theano.Result) and \ + # not isinstance(r,theano.Constant) and not isinstance(r,theano.Value): + # inputs.append(r.name) + #return inputs + def isLastEpoch(self): return self.truly_online @@ -381,9 +396,14 @@ return self.parameterAttributes() def updateMinibatchOutputAttributes(self): - return ["_new"+name for name in self.parameterAttributes()] + return ["new_"+name for name in self.parameterAttributes()] def updateEndInputAttributes(self): + return self.namesOfAttributesToComputeOutputs(self.updateEndOutputAttributes()) + + def useInputAttributes(self): return self.parameterAttributes() + def useOutputAttributes(self): + return [] diff -r f5f235bebee4 -r ad144fa72bf5 linear_regression.py --- a/linear_regression.py Mon May 12 14:13:39 2008 -0400 +++ b/linear_regression.py Mon May 12 14:15:16 2008 -0400 @@ -1,10 +1,13 @@ +""" +Implementation of linear regression, with or without L2 regularization. +This is one of the simplest example of L{learner}, and illustrates +the use of theano. +""" from learner import * from theano import tensor as t from theano.scalar import as_scalar -# this is one of the simplest example of learner, and illustrates -# the use of theano class LinearRegression(MinibatchUpdatesTLearner): """ Implement linear regression, with or without L2 regularization diff -r f5f235bebee4 -r ad144fa72bf5 lookup_list.py --- a/lookup_list.py Mon May 12 14:13:39 2008 -0400 +++ b/lookup_list.py Mon May 12 14:15:16 2008 -0400 @@ -49,7 +49,7 @@ The key in example[key] can either be an integer to index the fields or the name of the field. """ - if isinstance(key,int) or isinstance(key,slice) or isinstance(key,list): + if isinstance(key,int) or isinstance(key,slice) or (isinstance(key,list) and all([isinstance(i,int) for i in key])): return self._values[key] else: # if not an int, key must be a name # expecting key to be a valid field name @@ -101,10 +101,10 @@ def __ne__(self, other): return not self.__eq__(other) - def __hash__(): + def __hash__(self): raise NotImplementedError() - def __call__(*names): + def __call__(self,*names): """ Return a list of values associated with the given names (which must all be keys of the lookup list). """ diff -r f5f235bebee4 -r ad144fa72bf5 mlp.py --- a/mlp.py Mon May 12 14:13:39 2008 -0400 +++ b/mlp.py Mon May 12 14:15:16 2008 -0400 @@ -1,13 +1,17 @@ +""" +A straightforward classicial feedforward +one-hidden-layer neural net, with L2 regularization. +This is one of the simplest example of L{Learner}, and illustrates +the use of theano. +""" from learner import * from theano import tensor as t from nnet_ops import * - -# this is one of the simplest example of learner, and illustrates -# the use of theano +import math -class OneHiddenLayerNNetClassifier(MinibatchUpdatesTLearner): +class OneHiddenLayerNNetClassifier(OnlineGradientTLearner): """ Implement a straightforward classicial feedforward one-hidden-layer neural net, with L2 regularization. @@ -64,26 +68,31 @@ """ - def __init__(self,n_hidden,n_classes,learning_rate,init_range=1.): + def __init__(self,n_hidden,n_classes,learning_rate,max_n_epochs,L2_regularizer=0,init_range=1.,n_inputs=None,minibatch_size=None): + self._n_inputs = n_inputs self._n_outputs = n_classes self._n_hidden = n_hidden self._init_range = init_range + self._max_n_epochs = max_n_epochs + self._minibatch_size = minibatch_size self.learning_rate = learning_rate # this is the float + self.L2_regularizer = L2_regularizer self._learning_rate = t.scalar('learning_rate') # this is the symbol self._input = t.matrix('input') # n_examples x n_inputs - self._target = t.matrix('target','int32') # n_examples x n_outputs + self._target = t.imatrix('target') # n_examples x 1 + self._target_vector = self._target[:,0] self._L2_regularizer = t.scalar('L2_regularizer') self._W1 = t.matrix('W1') self._W2 = t.matrix('W2') self._b1 = t.row('b1') self._b2 = t.row('b2') - self._regularization_term = self._L2_regularizer * (t.dot(self._W1,self._W1) + t.dot(self._W2,self._W2)) + self._regularization_term = self._L2_regularizer * (t.sum(self._W1*self._W1) + t.sum(self._W2*self._W2)) self._output_activations =self._b2+t.dot(t.tanh(self._b1+t.dot(self._input,self._W1.T)),self._W2.T) - self._nll,self._output = crossentropy_softmax_1hot(self._output_activations,self._target) - self._output_class = t.argmax(self._output,1) - self._class_error = self._output_class != self._target + self._nll,self._output = crossentropy_softmax_1hot(self._output_activations,self._target_vector) + self._output_class, self._max_output = t.argmax(self._output,1) + self._class_error = t.neq(self._output_class,self._target_vector) self._minibatch_criterion = self._nll + self._regularization_term / t.shape(self._input)[0] - MinibatchUpdatesTLearner.__init__(self) + OnlineGradientTLearner.__init__(self) def attributeNames(self): return ["parameters","b1","W2","b2","W2", "L2_regularizer","regularization_term"] @@ -91,15 +100,6 @@ def parameterAttributes(self): return ["b1","W1", "b2", "W2"] - def useInputAttributes(self): - return self.parameterAttributes() - - def useOutputAttributes(self): - return [] - - def updateInputAttributes(self): - return self.parameterAttributes() + ["L2_regularizer"] - def updateMinibatchInputFields(self): return ["input","target"] @@ -119,8 +119,8 @@ minibatch_n_inputs = minibatch["input"].shape[1] if not self._n_inputs: self._n_inputs = minibatch_n_inputs - self.b1 = numpy.zeros(self._n_hidden) - self.b2 = numpy.zeros(self._n_outputs) + self.b1 = numpy.zeros((1,self._n_hidden)) + self.b2 = numpy.zeros((1,self._n_outputs)) self.forget() elif self._n_inputs!=minibatch_n_inputs: # if the input changes dimension on the fly, we resize and forget everything @@ -136,7 +136,11 @@ size=(self._n_outputs,self._n_hidden)) self.b1[:]=0 self.b2[:]=0 + self._n_epochs=0 + def isLastEpoch(self): + self._n_epochs +=1 + return self._n_epochs>=self._max_n_epochs class MLP(MinibatchUpdatesTLearner): """ diff -r f5f235bebee4 -r ad144fa72bf5 test_mlp.py --- a/test_mlp.py Mon May 12 14:13:39 2008 -0400 +++ b/test_mlp.py Mon May 12 14:15:16 2008 -0400 @@ -1,9 +1,17 @@ from mlp import * +import dataset def test0(): - nnet = OneHiddenLayerNNetClassifier(10,3,.1) + nnet = OneHiddenLayerNNetClassifier(10,3,.1,1000) + training_set = dataset.ArrayDataSet(numpy.array([[0, 0, 0], + [0, 1, 1], + [1, 0, 1], + [1, 1, 1]]), + {'input':slice(2),'target':2}) + fprop=nnet(training_set) + print fprop(training_set) test0()