# HG changeset patch # User Frederic Bastien # Date 1209999723 14400 # Node ID 158653a9bc7c949bcecbfe90a043cb62e10c51f3 # Parent 3499918faa9db49c131674cd46473d472ff0eb99# Parent 4b0859606d05037463a863e24ab05e2a46ae9750 Automated merge with ssh://p-omega1@lgcm.iro.umontreal.ca/tlearn diff -r 4b0859606d05 -r 158653a9bc7c _nnet_ops.py --- a/_nnet_ops.py Mon May 05 10:57:33 2008 -0400 +++ b/_nnet_ops.py Mon May 05 11:02:03 2008 -0400 @@ -11,6 +11,11 @@ def test_elemwise(self): TT.verify_grad(self, Sigmoid, [numpy.random.rand(3,4)]) +class T_softplus(unittest.TestCase): + def setUp(self): + numpy.random.seed(9999) + def test_elemwise(self): + TT.verify_grad(self, Softplus, [numpy.random.rand(3,4)]) class T_CrossentropySoftmax1Hot(unittest.TestCase): def setUp(self): @@ -18,10 +23,16 @@ def test0(self): y_idx = [0,1,3] def output1(a,b): - return crossentropy_softmax_1hot(a, b, y_idx)[0:1] + return crossentropy_softmax_1hot_with_bias(a, b, y_idx)[0:1] TT.verify_grad(self, output1, [numpy.random.rand(3,4), numpy.random.rand(4)]) + def test1(self): + y_idx = [0,1,3] + def output1(a): + return crossentropy_softmax_1hot(a, y_idx)[0:1] + TT.verify_grad(self, output1, [numpy.random.rand(3,4)]) + if __name__ == '__main__': diff -r 4b0859606d05 -r 158653a9bc7c dataset.py --- a/dataset.py Mon May 05 10:57:33 2008 -0400 +++ b/dataset.py Mon May 05 11:02:03 2008 -0400 @@ -27,29 +27,29 @@ feasible or not recommanded on streams). To iterate over examples, there are several possibilities: - * for example in dataset([field1, field2,field3, ...]): - * for val1,val2,val3 in dataset([field1, field2,field3]): - * for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): - * for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): - * for example in dataset: + - for example in dataset([field1, field2,field3, ...]): + - for val1,val2,val3 in dataset([field1, field2,field3]): + - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): + - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): + - for example in dataset:: print example['x'] - * for x,y,z in dataset: - Each of these is documented below. All of these iterators are expected - to provide, in addition to the usual 'next()' method, a 'next_index()' method - which returns a non-negative integer pointing to the position of the next - example that will be returned by 'next()' (or of the first example in the - next minibatch returned). This is important because these iterators - can wrap around the dataset in order to do multiple passes through it, - in possibly unregular ways if the minibatch size is not a divisor of the - dataset length. + - for x,y,z in dataset: + Each of these is documented below. All of these iterators are expected + to provide, in addition to the usual 'next()' method, a 'next_index()' method + which returns a non-negative integer pointing to the position of the next + example that will be returned by 'next()' (or of the first example in the + next minibatch returned). This is important because these iterators + can wrap around the dataset in order to do multiple passes through it, + in possibly unregular ways if the minibatch size is not a divisor of the + dataset length. To iterate over fields, one can do - * for field in dataset.fields(): + - for field in dataset.fields(): for field_value in field: # iterate over the values associated to that field for all the dataset examples - * for field in dataset(field1,field2,...).fields() to select a subset of fields - * for field in dataset.fields(field1,field2,...) to select a subset of fields + - for field in dataset(field1,field2,...).fields() to select a subset of fields + - for field in dataset.fields(field1,field2,...) to select a subset of fields and each of these fields is iterable over the examples: - * for field_examples in dataset.fields(): + - for field_examples in dataset.fields(): for example_value in field_examples: ... but when the dataset is a stream (unbounded length), it is not recommanded to do @@ -57,7 +57,7 @@ an unsynchronized ways. Hence the fields() method is illegal for streams, by default. The result of fields() is a DataSetFields object, which iterates over fields, and whose elements are iterable over examples. A DataSetFields object can - be turned back into a DataSet with its examples() method: + be turned back into a DataSet with its examples() method:: dataset2 = dataset1.fields().examples() and dataset2 should behave exactly like dataset1 (in fact by default dataset2==dataset1). @@ -72,34 +72,37 @@ of examples) can be extracted. These operations are not supported by default in the case of streams. - * dataset[:n] returns a dataset with the n first examples. + - dataset[:n] returns a dataset with the n first examples. - * dataset[i1:i2:s] returns a dataset with the examples i1,i1+s,...i2-s. + - dataset[i1:i2:s] returns a dataset with the examples i1,i1+s,...i2-s. - * dataset[i] returns an Example. + - dataset[i] returns an Example. - * dataset[[i1,i2,...in]] returns a dataset with examples i1,i2,...in. + - dataset[[i1,i2,...in]] returns a dataset with examples i1,i2,...in. - * dataset[fieldname] an iterable over the values of the field fieldname across - the dataset (the iterable is obtained by default by calling valuesVStack - over the values for individual examples). + - dataset[fieldname] an iterable over the values of the field fieldname across + the dataset (the iterable is obtained by default by calling valuesVStack + over the values for individual examples). - * dataset. returns the value of a property associated with - the name . The following properties should be supported: + - dataset. returns the value of a property associated with + the name . The following properties should be supported: - 'description': a textual description or name for the dataset - 'fieldtypes': a list of types (one per field) + A DataSet may have other attributes that it makes visible to other objects. These are + used to store information that is not example-wise but global to the dataset. + The list of names of these attributes is given by the attribute_names() method. Datasets can be concatenated either vertically (increasing the length) or horizontally (augmenting the set of fields), if they are compatible, using the following operations (with the same basic semantics as numpy.hstack and numpy.vstack): - * dataset1 | dataset2 | dataset3 == dataset.hstack([dataset1,dataset2,dataset3]) + - dataset1 | dataset2 | dataset3 == dataset.hstack([dataset1,dataset2,dataset3]) creates a new dataset whose list of fields is the concatenation of the list of fields of the argument datasets. This only works if they all have the same length. - * dataset1 & dataset2 & dataset3 == dataset.vstack([dataset1,dataset2,dataset3]) + - dataset1 & dataset2 & dataset3 == dataset.vstack([dataset1,dataset2,dataset3]) creates a new dataset that concatenates the examples from the argument datasets (and whose length is the sum of the length of the argument datasets). This only @@ -114,25 +117,42 @@ or other properties of the dataset or associated with the dataset or the result of a computation stored in a dataset. These can be accessed through the [key] syntax when key is a string (or more specifically, neither an integer, a slice, nor a list). - + A DataSet sub-class should always redefine the following methods: - * __len__ if it is not a stream - * fieldNames - * minibatches_nowrap (called by DataSet.minibatches()) - * valuesHStack - * valuesVStack + - __len__ if it is not a stream + - fieldNames + - minibatches_nowrap (called by DataSet.minibatches()) + - valuesHStack + - valuesVStack For efficiency of implementation, a sub-class might also want to redefine - * hasFields - * __getitem__ may not be feasible with some streams - * __iter__ + - hasFields + - __getitem__ may not be feasible with some streams + - __iter__ + A sub-class should also append attributes to self._attribute_names + (the default value returned by attributeNames()). + By convention, attributes not in attributeNames() should have a name + starting with an underscore. + @todo enforce/test that convention! """ + numpy_vstack = lambda fieldname,values: return numpy.vstack(values) + numpy_hstack = lambda fieldnames,values: return numpy.hstack(values) + def __init__(self,description=None,fieldtypes=None): if description is None: # by default return "(,,...)" description = type(self).__name__ + " ( " + join([x.__name__ for x in type(self).__bases__]) + " )" self.description=description self.fieldtypes=fieldtypes + self._attribute_names = ["description"] + if fieldtypes: + self._attribute_names.append("fieldtypes") + + def attributeNames(self): return self._attribute_names + + def setAttributes(self,attribute_names,attribute_values): + for name,value in zip(attribute_names,attribute_values): + self.__setattr__(name,value) class MinibatchToSingleExampleIterator(object): """ @@ -227,7 +247,9 @@ self.n_batches_done+=1 if upper >= self.L and self.n_batches: self.next_row -= self.L - return minibatch + return DataSetFields(MinibatchDataSet(minibatch,self.dataset.valuesVStack, + self.dataset.valuesHStack), + minibatch.keys()) minibatches_fieldnames = None @@ -346,7 +368,7 @@ """ Return a DataSetFields object associated with this dataset. """ - return DataSetFields(self,*fieldnames) + return DataSetFields(self,fieldnames) def __getitem__(self,i): """ @@ -392,7 +414,8 @@ return MinibatchDataSet( Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values) for fieldname,field_values - in zip(self.fieldNames(),fields_values)])) + in zip(self.fieldNames(),fields_values)]), + self.valuesVStack,self.valuesHStack) # else check for a fieldname if self.hasFields(i): return self.minibatches(fieldnames=[i],minibatch_size=len(self),n_batches=1,offset=0).next()[0] @@ -545,7 +568,7 @@ the syntax used for DataSets, the | concatenates the fields and the & concatenates the examples. """ - def __init__(self,dataset,*fieldnames): + def __init__(self,dataset,fieldnames): original_dataset=dataset if not fieldnames: fieldnames=dataset.fieldNames() @@ -647,7 +670,7 @@ [field[self.next_example:upper] for field in self.ds._fields]) self.next_example+=minibatch_size - return DataSetFields(MinibatchDataSet(minibatch),*fieldnames) + return minibatch return Iterator(self) @@ -720,11 +743,7 @@ return self def next(self): # concatenate all the fields of the minibatches - minibatch = reduce(LookupList.__add__,[iterator.next() for iterator in self.iterators]) - # and return a DataSetFields whose dataset is the transpose (=examples()) of this minibatch - return DataSetFields(MinibatchDataSet(minibatch,self.hsds.valuesVStack, - self.hsds.valuesHStack), - fieldnames if fieldnames else hsds.fieldNames()) + return reduce(LookupList.__add__,[iterator.next() for iterator in self.iterators]) assert self.hasfields(fieldnames) # find out which underlying datasets are necessary to service the required fields @@ -853,11 +872,10 @@ while self.n_left_in_mb>0: self.move_to_next_dataset() extra_mb.append(self.next_iterator.next()) - examples = Example(names, + mb = Example(fieldnames, [dataset.valuesVStack(name, [mb[name]]+[b[name] for b in extra_mb]) for name in fieldnames]) - mb = DataSetFields(MinibatchDataSet(examples),fieldnames) self.next_row+=minibatch_size self.next_dataset_row+=minibatch_size @@ -936,7 +954,8 @@ if self.hasFields(key[i]): key[i]=self.fields_columns[key[i]] return MinibatchDataSet(Example(fieldnames, - [self.data[key,self.fields_columns[f]] for f in fieldnames])) + [self.data[key,self.fields_columns[f]] for f in fieldnames]), + self.valuesVStack,self.valuesHStack) # else check for a fieldname if self.hasFields(key): @@ -978,20 +997,147 @@ Optionally, for finite-length dataset, all the values can be computed (and cached) upon construction of the CachedDataSet, rather at the first access. + + @todo when cache_all_upon_construction create mini-batches that are as + large as possible but not so large as to fill up memory. + + @todo add disk-buffering capability, so that when the cache becomes too + big for memory, we cache things on disk, trying to keep in memory only + the record most likely to be accessed next. """ + def __init__(self,source_dataset,cache_all_upon_construction=False): + self.source_dataset=source_dataset + self.cache_all_upon_construction=cache_all_upon_construction + if cache_all_upon_construction: + # this potentially brings all the source examples + # into memory at once, which may be too much + # the work could possibly be done by minibatches + # that are as large as possible but no more than what memory allows. + self.cached_examples = zip(*source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next()) + else: + self.cached_examples = [] + self.fieldNames = source_dataset.fieldNames + self.hasFields = source_dataset.hasFields + self.valuesHStack = source_dataset.valuesHStack + self.valuesVStack = source_dataset.valuesVStack + + def __len__(self): + return len(self.source_dataset) + + def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): + class CacheIterator(object): + def __init__(self,dataset): + self.dataset=dataset + self.current=offset + def __iter__(self): return self + def next(self): + upper = self.current+minibatch_size + cache_len = len(self.dataset.cached_examples) + if upper>=cache_len: # whole minibatch is not already in cache + # cache everything from current length to upper + for example in self.dataset.source_dataset[cache_len:upper]: + self.dataset.cached_examples.append(example) + all_fields_minibatch = Example(self.dataset.fieldNames(), + self.dataset.cached_examples[self.current:self.current+minibatch_size]) + if self.dataset.fieldNames()==fieldnames: + return all_fields_minibatch + return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames]) + return CacheIterator(self) + + class ApplyFunctionDataSet(DataSet): """ A dataset that contains as fields the results of applying a given function example-wise or minibatch-wise to all the fields of an input dataset. The output of the function should be an iterable (e.g. a list or a LookupList) - over the resulting values. In minibatch mode, the function is expected - to work on minibatches (takes a minibatch in input and returns a minibatch - in output). + over the resulting values. + + In minibatch mode, the function is expected to work on minibatches (takes + a minibatch in input and returns a minibatch in output). More precisely, + it means that each element of the input or output list should be iterable + and indexable over the individual example values (typically these + elements will be numpy arrays). All of the elements in the input and + output lists should have the same length, which is the length of the + minibatch. The function is applied each time an example or a minibatch is accessed. To avoid re-doing computation, wrap this dataset inside a CachedDataSet. + + If the values_{h,v}stack functions are not provided, then + the input_dataset.values{H,V}Stack functions are used by default. """ + def __init__(self,input_dataset,function,output_names,minibatch_mode=True, + values_hstack=None,values_vstack=None, + description=None,fieldtypes=None): + """ + Constructor takes an input dataset that has as many fields as the function + expects as inputs. The resulting dataset has as many fields as the function + produces as outputs, and that should correspond to the number of output names + (provided in a list). + + Note that the expected semantics of the function differs in minibatch mode + (it takes minibatches of inputs and produces minibatches of outputs, as + documented in the class comment). + """ + self.input_dataset=input_dataset + self.function=function + self.output_names=output_names + self.minibatch_mode=minibatch_mode + DataSet.__init__(description,fieldtypes) + self.valuesHStack = values_hstack if values_hstack else input_dataset.valuesHStack + self.valuesVStack = values_vstack if values_vstack else input_dataset.valuesVStack + + def __len__(self): + return len(self.input_dataset) + + def fieldnames(self): + return self.output_names + + def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): + class ApplyFunctionIterator(object): + def __init__(self,output_dataset): + self.input_dataset=output_dataset.input_dataset + self.output_dataset=output_dataset + self.input_iterator=input_dataset.minibatches(minibatch_size=minibatch_size, + n_batches=n_batches,offset=offset).__iter__() + + def __iter__(self): return self + + def next(self): + function_inputs = self.input_iterator.next() + all_output_names = self.output_dataset.output_names + if self.output_dataset.minibatch_mode: + function_outputs = self.output_dataset.function(function_inputs) + else: + input_examples = zip(*function_inputs) + output_examples = [self.output_dataset.function(input_example) + for input_example in input_examples] + function_outputs = [self.output_dataset.valuesVStack(name,values) + for name,values in zip(all_output_names, + zip(*output_examples))] + all_outputs = Example(all_output_names,function_outputs) + if fieldnames==all_output_names: + return all_outputs + return Example(fieldnames,[all_outputs[name] for name in fieldnames]) + + return ApplyFunctionIterator(self.input_dataset,self) + + def __iter__(self): # only implemented for increased efficiency + class ApplyFunctionSingleExampleIterator(object): + def __init__(self,output_dataset): + self.current=0 + self.output_dataset=output_dataset + self.input_iterator=output_dataset.input_dataset.__iter__() + def __iter__(self): return self + def next(self): + function_inputs = self.input_iterator.next() + if self.output_dataset.minibatch_mode: + function_outputs = [output[0] for output in self.output_dataset.function(function_inputs)] + else: + function_outputs = self.output_dataset.function(function_inputs) + return Example(self.output_dataset.output_names,function_outputs) + return ApplyFunctionSingleExampleIterator(self) def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): diff -r 4b0859606d05 -r 158653a9bc7c filetensor.py --- a/filetensor.py Mon May 05 10:57:33 2008 -0400 +++ b/filetensor.py Mon May 05 11:02:03 2008 -0400 @@ -1,18 +1,18 @@ """ Read and write the matrix file format described at -http://www.cs.nyu.edu/~ylclab/data/norb-v1.0/index.html +U{http://www.cs.nyu.edu/~ylclab/data/norb-v1.0/index.html} The format is for dense tensors: - magic number indicating type and endianness - 4bytes - rank of tensor - int32 - dimensions - int32, int32, int32, ... - + - magic number indicating type and endianness - 4bytes + - rank of tensor - int32 + - dimensions - int32, int32, int32, ... + - The number of dimensions and rank is slightly tricky: - for scalar: rank=0, dimensions = [1, 1, 1] - for vector: rank=1, dimensions = [?, 1, 1] - for matrix: rank=2, dimensions = [?, ?, 1] + - for scalar: rank=0, dimensions = [1, 1, 1] + - for vector: rank=1, dimensions = [?, 1, 1] + - for matrix: rank=2, dimensions = [?, ?, 1] For rank >= 3, the number of dimensions matches the rank exactly. diff -r 4b0859606d05 -r 158653a9bc7c learner.py --- a/learner.py Mon May 05 10:57:33 2008 -0400 +++ b/learner.py Mon May 05 11:02:03 2008 -0400 @@ -30,8 +30,10 @@ on-line setting or the sequential (Bayesian or not) settings. The result is a function that can be applied on data, with the same semantics of the Learner.use method. + The user may optionally provide a training StatsCollector that is used to record - some statistics of the outputs computed during training. + some statistics of the outputs computed during training. It is update(d) during + training. """ return self.use # default behavior is 'non-adaptive', i.e. update does not do anything @@ -51,6 +53,39 @@ If output_fields is specified, it may be use to indicate which fields should be constructed in the output DataSet (for example ['output','classification_error']). Optionally, if copy_inputs, the input fields (of the input_dataset) can be made - visible in the output DataSet returned by this function. + visible in the output DataSet returned by this method. """ raise NotImplementedError + + def attribute_names(self): + """ + A Learner may have attributes that it wishes to export to other objects. To automate + such export, sub-classes should define here the names (list of strings) of these attributes. + """ + return [] + +class TLearner(Learner): + """ + TLearner is a virtual class of Learners that attempts to factor out of the definition + of a learner the steps that are common to many implementations of learning algorithms, + so as to leave only "the equations" to define in particular sub-classes, using Theano. + + In the default implementations of use and update, it is assumed that the 'use' and 'update' methods + visit examples in the input dataset sequentially. In the 'use' method only one pass through the dataset is done, + whereas the sub-learner may wish to iterate over the examples multiple times. Subclasses where this + basic model is not appropriate can simply redefine update or use. + + Sub-classes must provide the following functions and functionalities: + - attributeNames(): defines all the names of attributes which can be used as fields or + attributes in input/output datasets or in stats collectors. + All these attributes are expected to be theano.Result objects + (with a .data property and recognized by theano.Function for compilation). + The sub-class constructor defines the relations between + the Theano variables that may be used by 'use' and 'update' + or by a stats collector. + - defaultOutputFields(input_fields): return a list of default dataset output fields when + None are provided by the caller of use. + - + + """ + diff -r 4b0859606d05 -r 158653a9bc7c linear_regression.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/linear_regression.py Mon May 05 11:02:03 2008 -0400 @@ -0,0 +1,219 @@ + +from learner import * +from theano import tensor as t +from compile import Function +from theano.scalar import as_scalar + +# this is one of the simplest example of learner, and illustrates +# the use of theano +class LinearRegression(Learner): + """ + Implement linear regression, with or without L2 regularization + (the former is called Ridge Regression and the latter Ordinary Least Squares). + + The predictor is obtained analytically. + + The L2 regularization coefficient is obtained analytically. + For each (input[t],output[t]) pair in a minibatch,:: + + output_t = b + W * input_t + + where b and W are obtained by minimizing:: + + lambda sum_{ij} W_{ij}^2 + sum_t ||output_t - target_t||^2 + + Let X be the whole training set inputs matrix (one input example per row), + with the first column full of 1's, and Let Y the whole training set + targets matrix (one example's target vector per row). + Let theta = the matrix with b in its first column and W in the others, + then each theta[:,i] is the solution of the linear system:: + + XtX * theta[:,i] = XtY[:,i] + + where XtX is a (n_inputs+1)x(n_inputs+1) matrix containing X'*X + plus lambda on the diagonal except at (0,0), + and XtY is a (n_inputs+1)*n_outputs matrix containing X'*Y. + + The fields and attributes expected and produced by use and update are the following: + + - Input and output fields (example-wise quantities): + + - 'input' (always expected by use and update as an input_dataset field) + - 'target' (optionally expected by use and update as an input_dataset field) + - 'output' (optionally produced by use as an output dataset field) + - 'squared_error' (optionally produced by use as an output dataset field, needs 'target') = example-wise squared error + + - optional input attributes (optionally expected as input_dataset attributes) + + - 'lambda' (only used by update) + - 'b' (only used by use) + - 'W' (only used by use) + + - optional output attributes (available in self and optionally in output dataset) + + - 'b' (only set by update) + - 'W' (only set by update) + - 'regularization_term' (only set by update) + - 'XtX' (only set by update) + - 'XtY' (only set by update) + + """ + +# definitions specifiques a la regression lineaire: + + def global_inputs(self): + self.lambda = as_scalar(0.,'lambda') + self.theta = t.matrix('theta') + self.W = self.theta[:,1:] + self.b = self.theta[:,0] + self.XtX = t.matrix('XtX') + self.XtY = t.matrix('XtY') + + def global_outputs(self): + self.regularizer = self.lambda * t.dot(self.W,self.W) + self.loss = self.regularizer + t.sum(self.squared_error) # this only makes sense if the whole training set fits in memory in a minibatch + self.loss_function = Function([self.W,self.lambda,self.squared_error],[self.loss]) + + def initialize(self): + self.XtX.resize((1+self.n_inputs,1+self.n_inputs)) + self.XtY.resize((1+self.n_inputs,self.n_outputs)) + self.XtX.data[:,:]=0 + self.XtY.data[:,:]=0 + numpy.diag(self.XtX.data)[1:]=self.lambda.data + + def updated_variables(self): + self.new_XtX = self.XtX + t.dot(self.extended_input.T,self.extended_input) + self.new_XtY = self.XtY + t.dot(self.extended_input.T,self.target) + self.new_theta = t.solve(self.XtX,self.XtY) + + def minibatch_wise_inputs(self): + self.input = t.matrix('input') # n_examples x n_inputs + self.target = t.matrix('target') # n_examples x n_outputs + + def minibatch_wise_outputs(self): + # self.input is a (n_examples, n_inputs) minibatch matrix + self.extended_input = t.prepend_one_to_each_row(self.input) + self.output = t.dot(self.input,self.W.T) + self.b # (n_examples , n_outputs) matrix + self.squared_error = t.sum_within_rows(t.sqr(self.output-self.target)) # (n_examples ) vector + + def attributeNames(self): + return ["lambda","b","W","regularization_term","XtX","XtY"] + + def defaultOutputFields(self, input_fields): + output_fields = ["output"] + if "target" in input_fields: + output_fields.append("squared_error") + return output_fields + + # poutine generale basee sur ces fonctions + + def minibatchwise_use_functions(self, input_fields, output_fields, stats_collector): + if not output_fields: + output_fields = self.defaultOutputFields(input_fields) + if stats_collector: + stats_collector_inputs = stats_collector.inputUpdateAttributes() + for attribute in stats_collector_inputs: + if attribute not in input_fields: + output_fields.append(attribute) + key = (input_fields,output_fields) + if key not in self.use_functions_dictionary: + self.use_functions_dictionary[key]=Function(self.names2attributes(input_fields), + self.names2attributes(output_fields)) + return self.use_functions_dictionary[key] + + def attributes(self,return_copy=False): + return self.names2attributes(self.attributeNames()) + + def names2attributes(self,names,return_Result=False, return_copy=False): + if return_Result: + if return_copy: + return [copy.deepcopy(self.__getattr__(name)) for name in names] + else: + return [self.__getattr__(name) for name in names] + else: + if return_copy: + return [copy.deepcopy(self.__getattr__(name).data) for name in names] + else: + return [self.__getattr__(name).data for name in names] + + def use(self,input_dataset,output_fieldnames=None,test_stats_collector=None,copy_inputs=True): + minibatchwise_use_function = minibatchwise_use_functions(input_dataset.fieldNames(),output_fieldnames,test_stats_collector) + virtual_output_dataset = ApplyFunctionDataSet(input_dataset, + minibatchwise_use_function, + True,DataSet.numpy_vstack, + DataSet.numpy_hstack) + # actually force the computation + output_dataset = CachedDataSet(virtual_output_dataset,True) + if copy_inputs: + output_dataset = input_dataset | output_dataset + # compute the attributes that should be copied in the dataset + output_dataset.setAttributes(self.attributeNames(),self.attributes(return_copy=True)) + if test_stats_collector: + test_stats_collector.update(output_dataset) + for attribute in test_stats_collector.attributeNames(): + output_dataset[attribute] = copy.deepcopy(test_stats_collector[attribute]) + return output_dataset + + def update(self,training_set,train_stats_collector=None): + self.update_start() + for minibatch in training_set.minibatches(self.training_set_input_fields, minibatch_size=self.minibatch_size): + self.update_minibatch(minibatch) + if train_stats_collector: + minibatch_set = minibatch.examples() + minibatch_set.setAttributes(self.attributeNames(),self.attributes()) + train_stats_collector.update(minibatch_set) + self.update_end() + return self.use + + def __init__(self,lambda=0.,max_memory_use=500): + """ + @type lambda: float + @param lambda: regularization coefficient + """ + + W=t.matrix('W') + # b is a broadcastable row vector (can be replicated into + # as many rows as there are examples in the minibach) + b=t.row('b') + minibatch_input = t.matrix('input') # n_examples x n_inputs + minibatch_target = t.matrix('target') # n_examples x n_outputs + minibatch_output = t.dot(minibatch_input,W.T) + b # n_examples x n_outputs + lambda = as_scalar(lambda) + regularizer = self.lambda * t.dot(W,W) + example_squared_error = t.sum_within_rows(t.sqr(minibatch_output-minibatch_target)) + self.output_function = Function([W,b,minibatch_input],[minibatch_output]) + self.squared_error_function = Function([minibatch_output,minibatch_target],[self.example_squared_error]) + self.loss_function = Function([W,squared_error],[self.regularizer + t.sum(self.example_squared_error)]) + self.W=None + self.b=None + self.XtX=None + self.XtY=None + + def forget(self): + if self.W: + self.XtX *= 0 + self.XtY *= 0 + + def use(self,input_dataset,output_fieldnames=None,copy_inputs=True): + input_fieldnames = input_dataset.fieldNames() + assert "input" in input_fieldnames + if not output_fields: + output_fields = ["output"] + if "target" in input_fieldnames: + output_fields += ["squared_error"] + else: + if "squared_error" in output_fields or "total_loss" in output_fields: + assert "target" in input_fieldnames + + use_functions = [] + for output_fieldname in output_fieldnames: + if output_fieldname=="output": + use_functions.append(self.output_function) + elif output_fieldname=="squared_error": + use_functions.append(lambda self.output_function) + + n_examples = len(input_dataset) + + for minibatch in input_dataset.minibatches(minibatch_size=minibatch_size, allow_odd_last_minibatch=True): + use_function( + diff -r 4b0859606d05 -r 158653a9bc7c nnet_ops.py --- a/nnet_ops.py Mon May 05 10:57:33 2008 -0400 +++ b/nnet_ops.py Mon May 05 11:02:03 2008 -0400 @@ -2,32 +2,93 @@ from theano import tensor, gof, scalar import numpy -class ScalarSigmoid(scalar.UnaryScalarOp): +############ +# +# SCALAR OPS +# + +class ScalarSigmoid(scalar.FloatUnaryScalarOp): + @staticmethod + def st_impl(x): + if x < -30.0: + return 0.0 + if x > 30.0: + return 1.0 + return 1.0 / (1.0 + numpy.exp(-x)) def impl(self, x): - return 1.0 / (1 + numpy.exp(-x)) + return ScalarSigmoid.st_impl(x) def grad(self, (x,), (gz,)): - return gz * scalar_sigmoid(x) * (1.0 - scalar_sigmoid(x)), - def c_foreach(self, (x,), (z,)): - return "%(z)s = 1.0 / (1 + exp(-%(x)s));" % locals() + y = scalar_sigmoid(x) + return [gz * y * (1.0 - y)] + def c_foreach(self, (x,), (z,), sub): + if 'float' in self.inputs[0].dtype: + return """%(z)s = + %(x)s < -30.0 + ? 0.0 + : %(x)s > 30.0 + ? 1.0 + : 1.0 /(1.0+exp(-%(x)s));""" % locals() + raise NotImplementedError('only floatingpoint is implemented') scalar_sigmoid = gof.op.constructor(ScalarSigmoid) -Sigmoid, sigmoid, SigmoidInplace, sigmoid_inplace \ - = theano.tensor.broadcast(ScalarSigmoid, 'Sigmoid') +Sigmoid, sigmoid, SigmoidInplace, sigmoid_inplace =\ + tensor.broadcast(ScalarSigmoid, 'Sigmoid') +class ScalarSoftplus(scalar.FloatUnaryScalarOp): + @staticmethod + def static_impl(x): + if x < -30.0: + return 0.0 + if x > 30.0: + return x + return numpy.log1p(numpy.exp(x)) + def impl(self, x): + return ScalarSoftplus.static_impl(x) + def grad(self, (x,), (gz,)): + return [gz * scalar_sigmoid(x)] + def c_foreach(self, (x,), (z,), sub): + if 'float' in self.inputs[0].dtype: + return """%(z)s = + %(x)s < -30.0 + ? 0.0 + : %(x)s > 30.0 + ? %(x)s + : log1p(exp(%(x)s));""" % locals() + raise NotImplementedError('only floating point x is implemented') +scalar_softplus = gof.op.constructor(ScalarSoftplus) +Softplus, softplus, SoftplusInplace, softplus_inplace =\ + tensor.broadcast(ScalarSoftplus, 'Softplus') -class CrossentropySoftmax1Hot(gof.op.Op): - """A special compound Op for the output of neural-net classifiers. +############ +# +# TENSOR OPS +# + + +class CrossentropySoftmax1HotWithBias(gof.op.Op): + """A special compound L{Op} for the output of neural-net classifiers. + + @type x: is a matrix of floats (32 or 64) + @type b: is a [row] vector of floats (32 or 64), length is number of cols in x + @type y_idx: a [column] vector of int (32 or 64), length is number of rows in x - This Op has two outputs: - - KL(softmax(x), y) - - softmax(x) + @precondition: every entry in y_idx is a valid (non-negative) column index into x + + This L{Op} has two outputs: + - KL(softmax(x+b), y) + - softmax(x+b) - x[i] is assumed to be a dense vector + softmax(x[i]) is the i'th distribution over len(x[i]) options - y[i] is an integer index, encoding a 1-hot distribution + + y_idx[i] is an integer index, encoding a 1-hot distribution. + + In practice, when we're trying to do classification, we have one row in x + and y_idx per example, and y[i] is the index of the (correct) class of the + i'th example. """ - nin=2 + nin=3 nout=2 def __init__(self, x, b, y_idx, **kwargs): x = tensor._as_tensor(x) @@ -52,7 +113,9 @@ def perform(self): x, b, y_idx = [i.data for i in self.inputs] if b.shape[0] != x.shape[1]: - raise ValueError('b must have same shape as x[0]') + raise ValueError('b must have same number of columns as x') + if y_idx.shape[0] != x.shape[0]: + raise ValueError('y_idx must have same number of rows as x') sm = numpy.zeros_like(x) # softmax nll = numpy.zeros(x.shape[0]) #nll(y | softmax(x)) @@ -66,17 +129,12 @@ def grad(self, (x, b, y_idx), (g_nll, g_sm)): if g_sm is not None: raise NotImplementedError() - nll, sm = crossentropy_softmax_1hot(x, b, y_idx) - dx = CrossentropySoftmax1HotDx(g_nll, sm, y_idx).outputs[0] + nll, sm = crossentropy_softmax_1hot_with_bias(x, b, y_idx) + dx = CrossentropySoftmax1HotWithBiasDx(g_nll, sm, y_idx).outputs[0] db = tensor.Sum(dx, axis = [0]).outputs[0] return dx, db, None - def c_validate_cleanup(self, (x, b, y_idx), (nll, sm), sub): - """Not sure...""" - return "" - def c_support_code(self): - return """ - """ + def c_headers(self): return [''] def c_code(self, (x, b, y_idx), (nll, sm), sub): # this implementation was lifted from # /u/bergstrj/cvs/bergstrj/src/feb07/nn.cxx @@ -89,25 +147,67 @@ return """ npy_intp* Nx = %(x)s->dimensions; - if (%(x)s->nd != 2) { %(fail)s } - if (%(b)s->nd != 1) { %(fail)s } - if (%(y_idx)s->nd != 1) { %(fail)s } - if (%(x)s->descr->type_num != PyArray_DOUBLE) { %(fail)s} - if (%(b)s->descr->type_num != PyArray_DOUBLE) { %(fail)s} - if (%(y_idx)s->descr->type_num != PyArray_INT64) { %(fail)s} - - %(nll)s = (PyArrayObject*)PyArray_SimpleNew(1, PyArray_DIMS(%(y_idx)s), type_num_%(x)s); - if(!%(nll)s){%(fail)s} + if (%(x)s->nd != 2) + { + PyErr_SetString(PyExc_ValueError, "a not 2d tensor"); + %(fail)s; + } + if (%(b)s->nd != 1) + { + PyErr_SetString(PyExc_ValueError, "b not 1d tensor"); + %(fail)s; + } + if (%(y_idx)s->nd != 1) + { + PyErr_SetString(PyExc_ValueError, "y_idx not 1d tensor"); + %(fail)s; + } + if (%(x)s->descr->type_num != PyArray_DOUBLE) + { + PyErr_SetString(PyExc_TypeError, "a not float64"); + %(fail)s; + } + if (%(b)s->descr->type_num != PyArray_DOUBLE) + { + PyErr_SetString(PyExc_TypeError, "b not float64"); + %(fail)s; + } + if (%(y_idx)s->descr->type_num != PyArray_INT64) + { + PyErr_SetString(PyExc_TypeError, "y_idx not int64"); + %(fail)s; + } + if ((%(x)s->dimensions[1] != %(b)s->dimensions[0]) + || (%(x)s->dimensions[0] != %(y_idx)s->dimensions[0])) + { + PyErr_SetString(PyExc_ValueError, "dimension mismatch in arguments"); + %(fail)s; + } - %(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s), type_num_%(x)s); - if(!%(sm)s) { - // The normal cleanup code will take care of %(nll)s - // Py_XDECREF(%(nll)s); %(nll)s=NULL; - %(fail)s + if ((NULL == %(nll)s) //initial condition + || (%(nll)s->dimensions[0] != %(y_idx)s->dimensions[0])) + { + if (NULL != %(nll)s) Py_XDECREF(%(nll)s); + %(nll)s = (PyArrayObject*)PyArray_SimpleNew(1, PyArray_DIMS(%(y_idx)s), type_num_%(x)s); + if(!%(nll)s) + { + PyErr_SetString(PyExc_MemoryError, "failed to alloc nll output"); + %(fail)s; + } } - if (%(x)s->dimensions[1] != %(b)s->dimensions[0]) {%(fail)s} - if (%(sm)s->dimensions[0] != %(x)s->dimensions[0]) {%(fail)s} - if (%(sm)s->dimensions[1] != %(x)s->dimensions[1]) {%(fail)s} + if ((NULL == %(sm)s) + || (%(sm)s->dimensions[0] != %(x)s->dimensions[0]) + || (%(sm)s->dimensions[1] != %(x)s->dimensions[1])) + { + if (NULL != %(sm)s) Py_XDECREF(%(sm)s); + %(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s), type_num_%(x)s); + if(!%(sm)s) { + // The normal cleanup code will take care of %(nll)s + // Py_XDECREF(%(nll)s); %(nll)s=NULL; + PyErr_SetString(PyExc_MemoryError, "failed to alloc sm output"); + %(fail)s + } + } for (size_t i = 0; i < Nx[0]; ++i) { @@ -181,11 +281,10 @@ } """ % dict(locals(), **sub) - +crossentropy_softmax_1hot_with_bias = \ + gof.op.constructor(CrossentropySoftmax1HotWithBias) -crossentropy_softmax_1hot = gof.op.constructor(CrossentropySoftmax1Hot) - -class CrossentropySoftmax1HotDx (gof.op.Op): +class CrossentropySoftmax1HotWithBiasDx (gof.op.Op): nin=3 nout=1 """Gradient wrt x of the CrossentropySoftmax1Hot Op""" @@ -204,36 +303,42 @@ self.outputs[0].data = dx def grad(self, *args): raise NotImplementedError() - def c_validate_update(self, (dnll, sm, y_idx), (dx,), sub): - """Allocate output storage""" - return """ - if (%(dnll)s->nd != 1) { %(fail)s } - if (%(sm)s->nd != 2) { %(fail)s } - if (%(y_idx)s->nd != 1) { %(fail)s } - if (%(dnll)s->descr->type_num != PyArray_DOUBLE) { %(fail)s} - if (%(sm)s->descr->type_num != PyArray_DOUBLE) { %(fail)s} - if (%(y_idx)s->descr->type_num != PyArray_INT64) { %(fail)s} - - %(dx)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(sm)s), type_num_%(sm)s); - if(!%(dx)s){%(fail)s} - - """ % dict(locals(), **sub) - def c_validate_cleanup(self, inputs, outputs, sub): - """Not sure...""" - return "" - def c_support_code(self): - return """ - """ def c_code(self, (dnll, sm, y_idx), (dx,), sub): return """ - npy_intp* shape = %(dx)s->dimensions; - if (%(dnll)s->dimensions[0] != %(sm)s->dimensions[0]) {%(fail)s} - if (%(dnll)s->dimensions[0] != %(y_idx)s->dimensions[0]) {%(fail)s} - if (%(dnll)s->dimensions[0] != %(dx)s->dimensions[0]) {%(fail)s} - if (%(sm)s->dimensions[1] != %(dx)s->dimensions[1]) {%(fail)s} + if ((%(dnll)s->descr->type_num != PyArray_DOUBLE) + || (%(sm)s->descr->type_num != PyArray_DOUBLE) + || (%(y_idx)s->descr->type_num != PyArray_INT64)) + { + PyErr_SetString(PyExc_TypeError, "types should be float64, float64, int64"); + %(fail)s; + } + if ((%(dnll)s->nd != 1) + || (%(sm)s->nd != 2) + || (%(y_idx)s->nd != 1)) + { + PyErr_SetString(PyExc_ValueError, "rank error"); + %(fail)s; + } + if ((%(dnll)s->dimensions[0] != %(sm)s->dimensions[0]) + || (%(dnll)s->dimensions[0] != %(y_idx)s->dimensions[0])) + { + PyErr_SetString(PyExc_ValueError, "dimension mismatch"); + %(fail)s; + } + if ((NULL == %(dx)s) + || (%(dx)s->dimensions[0] != %(sm)s->dimensions[0]) + || (%(dx)s->dimensions[1] != %(sm)s->dimensions[1])) + { + if (NULL != %(dx)s) Py_XDECREF(%(dx)s); + %(dx)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(sm)s), type_num_%(sm)s); + if(!%(dx)s) { + PyErr_SetString(PyExc_MemoryError, "failed to alloc dx output"); + %(fail)s + } + } - for (size_t i = 0; i < shape[0]; ++i) + for (size_t i = 0; i < %(dx)s->dimensions[0]; ++i) { const double dnll_i = ((double*)(%(dnll)s->data + %(dnll)s->strides[0] * i))[0]; @@ -245,14 +350,19 @@ double* __restrict__ dx_i = (double*)(%(dx)s->data + %(dx)s->strides[0] * i); npy_intp Sdx = %(dx)s->strides[1]/sizeof(double); - for (size_t j = 0; j < shape[1]; ++j) + for (size_t j = 0; j < %(dx)s->dimensions[1]; ++j) { dx_i[j * Sdx] = dnll_i * sm_i[j * Ssm]; } - if (y_i >= shape[1]) + if (y_i >= %(dx)s->dimensions[1]) { %(fail)s; } dx_i[y_i * Sdx] -= dnll_i; } """ % dict(locals(), **sub) + +def crossentropy_softmax_1hot(x, y_idx, **kwargs): + b = tensor.zeros_like(x[0,:]) + return crossentropy_softmax_1hot_with_bias(x, b, y_idx, **kwargs) + diff -r 4b0859606d05 -r 158653a9bc7c test_dataset.py --- a/test_dataset.py Mon May 05 10:57:33 2008 -0400 +++ b/test_dataset.py Mon May 05 11:02:03 2008 -0400 @@ -124,7 +124,6 @@ #* ds1 | ds2 | ds3 == ds.hstack([ds1,ds2,ds3]) #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3]) - # for (x,y) in (ds('x','y'),a): #don't work # haven't found a variant that work. # assert numpy.append(x,y)==z @@ -170,6 +169,7 @@ def ArrayFieldsDataSet(): raise NotImplementedError() +test1() test_LookupList() test_ArrayDataSet()