changeset 407:b9f545594207

Automated merge with ssh://projects@lgcm.iro.umontreal.ca/hg/pylearn
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Thu, 10 Jul 2008 09:03:11 -0400
parents c2e6a8fcc35e (diff) c2f17f231960 (current diff)
children 5175c564e37a cf22ebfc90eb
files
diffstat 25 files changed, 674 insertions(+), 254 deletions(-)
--- a/linear_regression.py	Wed Jul 09 16:55:27 2008 -0400
+++ b/linear_regression.py	Thu Jul 10 09:03:11 2008 -0400
@@ -4,16 +4,35 @@
 the use of theano.
 """
 
-from pylearn import OfflineLearningAlgorithm
+from pylearn.learner import OfflineLearningAlgorithm
 from theano import tensor as T
+from theano.others_ops import prepend_1_to_each_row
 from theano.scalar import as_scalar
 from common.autoname import AutoName
+import theano
+import numpy
 
 class LinearRegression(OfflineLearningAlgorithm):
     """
     Implement linear regression, with or without L2 regularization
     (the former is called Ridge Regression and the latter Ordinary Least Squares).
 
+    Usage:
+
+       linear_regressor=LinearRegression(L2_regularizer=0.1)
+       linear_predictor=linear_regressor(training_set)
+       all_results_dataset=linear_predictor(test_set) # creates a dataset with "output" and "squared_error" fields
+       outputs = linear_predictor.compute_outputs(inputs) # inputs and outputs are numpy arrays
+       outputs, errors = linear_predictor.compute_outputs_and_errors(inputs,targets)
+       errors = linear_predictor.compute_errors(inputs,targets)
+       mse = linear_predictor.compute_mse(inputs,targets)
+
+    The training_set must have fields "input" and "target".
+    The test_set must have field "input", and needs "target" if
+    we want to compute the squared errors.
+
     The predictor parameters are obtained analytically from the training set.
     Training can proceed sequentially (with multiple calls to update with
     different disjoint subsets of the training sets). After each call to
@@ -52,12 +71,24 @@
        - 'squared_error' (optionally produced by learned model if 'target' is provided)
           = example-wise squared error
     """
-    def __init__(self, L2_regularizer=0):
-        self.predictor = LinearPredictor(None,None
+    def __init__(self, L2_regularizer=0,minibatch_size=10000):
         self.L2_regularizer=L2_regularizer
-        self._XtX = T.matrix('XtX')
-        self._XtY = T.matrix('XtY')
-        self._extended_input = T.prepend_one_to_each_row(self._input)
+        self.equations = LinearRegressionEquations()
+        self.minibatch_size=minibatch_size
+
+    def __call__(self,trainset):
+        first_example = trainset[0]
+        n_inputs = first_example['input'].size
+        n_outputs = first_example['target'].size
+        XtX = numpy.zeros((n_inputs+1,n_inputs+1))
+        XtY = numpy.zeros((n_inputs+1,n_outputs))
+        for i in xrange(n_inputs):
+            XtX[i+1,i+1]=self.L2_regularizer
+        mbs=min(self.minibatch_size,len(trainset))
+        for inputs,targets in trainset.minibatches(["input","target"],minibatch_size=mbs):
+            XtX,XtY=self.equations.update(XtX,XtY,numpy.array(inputs),numpy.array(targets))
+        theta=numpy.linalg.solve(XtX,XtY)
+        return LinearPredictor(theta)
 
 class LinearPredictorEquations(AutoName):
     inputs = T.matrix() # minibatchsize x n_inputs
@@ -76,22 +107,37 @@
         def fn(input_vars,output_vars):
             return staticmethod(theano.function(input_vars,output_vars, linker=linker))
 
-        cls.compute_outputs = fn([inputs,theta],[outputs])
-        cls.compute_errors = fn([outputs,targets],[squared_errors])
+        cls.compute_outputs = fn([cls.inputs,cls.theta],[cls.outputs])
+        cls.compute_errors = fn([cls.outputs,cls.targets],[cls.squared_errors])
 
         cls.__compiled = True
 
-    def __init__(self)
+    def __init__(self):
         self.compile()
         
 class LinearRegressionEquations(LinearPredictorEquations):
     P = LinearPredictorEquations
     XtX = T.matrix() # (n_inputs+1) x (n_inputs+1)
     XtY = T.matrix() # (n_inputs+1) x n_outputs
-    extended_input = T.prepend_scalar_to_each_row(1.,P.inputs)
-    new_XtX = add_inplace(XtX,T.dot(extended_input.T,extended_input))
-    new_XtY = add_inplace(XtY,T.dot(extended_input.T,P.targets))
-    new_theta = T.Cholesky_solve_inplace(P.theta,XtX,XtY)  # solve linear system XtX theta = XtY 
+    extended_input = prepend_1_to_each_row(P.inputs)
+    new_XtX = T.add_inplace(XtX,T.dot(extended_input.T,extended_input))
+    new_XtY = T.add_inplace(XtY,T.dot(extended_input.T,P.targets))
+
+    __compiled = False
+    
+    @classmethod
+    def compile(cls,linker='c|py'):
+        if cls.__compiled:
+            return
+        def fn(input_vars,output_vars):
+            return staticmethod(theano.function(input_vars,output_vars, linker=linker))
+
+        cls.update = fn([cls.XtX,cls.XtY,cls.P.inputs,cls.P.targets],[cls.new_XtX,cls.new_XtY])
+
+        cls.__compiled = True
+
+    def __init__(self):
+        self.compile()
 
 class LinearPredictor(object):
     """
@@ -103,15 +149,18 @@
         self.theta=theta
         self.n_inputs=theta.shape[0]-1
         self.n_outputs=theta.shape[1]
-        self.predict_equations = LinearPredictorEquations()
+        self.equations = LinearPredictorEquations()
 
     def compute_outputs(self,inputs):
-        return self.predict_equations.compute_outputs(inputs,self.theta)
+        return self.equations.compute_outputs(inputs,self.theta)
     def compute_errors(self,inputs,targets):
-        return self.predict_equations.compute_errors(self.compute_outputs(inputs),targets)
+        return self.equations.compute_errors(self.compute_outputs(inputs),targets)
     def compute_outputs_and_errors(self,inputs,targets):
         outputs = self.compute_outputs(inputs)
-        return [outputs,self.predict_equations.compute_errors(outputs,targets)]
+        return [outputs,self.equations.compute_errors(outputs,targets)]
+    def compute_mse(self,inputs,targets):
+        errors = self.compute_errors(inputs,targets)
+        return numpy.sum(errors)/errors.size
     
     def __call__(self,dataset,output_fieldnames=None,cached_output_dataset=False):
         assert dataset.hasFields(["input"])
@@ -137,38 +186,3 @@
             return ds
         
 
-        self._XtX = T.matrix('XtX')
-        self._XtY = T.matrix('XtY')
-        self._extended_input = T.prepend_one_to_each_row(self._input)
-        self._output = T.dot(self._input,self._W.T) + self._b  # (n_examples , n_outputs) matrix
-        self._squared_error = T.sum_within_rows(T.sqr(self._output-self._target)) # (n_examples ) vector
-        self._regularizer = self._L2_regularizer * T.dot(self._W,self._W)
-        self._new_XtX = add_inplace(self._XtX,T.dot(self._extended_input.T,self._extended_input))
-        self._new_XtY = add_inplace(self._XtY,T.dot(self._extended_input.T,self._target))
-        self._new_theta = T.solve_inplace(self._theta,self._XtX,self._XtY)
-
-    def allocate(self,dataset):
-        dataset_n_inputs  = dataset["input"].shape[1]
-        dataset_n_outputs = dataset["target"].shape[1]
-        if not self._n_inputs:
-            self._n_inputs = dataset_n_inputs 
-            self._n_outputs = dataset_n_outputs
-            self.XtX = numpy.zeros((1+self._n_inputs,1+self._n_inputs))
-            self.XtY = numpy.zeros((1+self._n_inputs,self._n_outputs))
-            self.theta = numpy.zeros((self._n_outputs,1+self._n_inputs))
-            self.forget()
-        elif self._n_inputs!=dataset_n_inputs or self._n_outputs!=dataset_n_outputs:
-            # if the input or target changes dimension on the fly, we resize and forget everything
-            self.forget()
-            
-    def forget(self):
-        if self._n_inputs and self._n_outputs:
-            self.XtX.resize((1+self.n_inputs,1+self.n_inputs))
-            self.XtY.resize((1+self.n_inputs,self.n_outputs))
-            self.XtX.data[:,:]=0
-            self.XtY.data[:,:]=0
-            numpy.diag(self.XtX.data)[1:]=self.L2_regularizer
-
-    def __call__(self,dataset):
-
-
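For reference, a minimal standalone numpy sketch of the accumulate-and-solve scheme that LinearRegression.__call__ above implements (illustrative only; the dataset field handling and minibatch loop of the real class are omitted, and the function name is hypothetical):

import numpy

def ridge_theta(inputs, targets, L2_regularizer=0.1):
    # inputs: (n_examples, n_inputs), targets: (n_examples, n_outputs)
    n, d = inputs.shape
    X = numpy.hstack([numpy.ones((n, 1)), inputs])          # prepend 1 to each row (bias column)
    XtX = numpy.dot(X.T, X)
    XtY = numpy.dot(X.T, targets)
    XtX[numpy.arange(1, d + 1), numpy.arange(1, d + 1)] += L2_regularizer  # regularize all but the bias
    return numpy.linalg.solve(XtX, XtY)                      # theta: (n_inputs+1) x n_outputs
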
--- a/nnet_ops.py	Wed Jul 09 16:55:27 2008 -0400
+++ b/nnet_ops.py	Thu Jul 10 09:03:11 2008 -0400
@@ -380,3 +380,10 @@
     b = tensor.zeros_like(x[0,:])
     return crossentropy_softmax_1hot_with_bias(x, b, y_idx, **kwargs)
 
+def binary_crossentropy(output, target):
+    """
+    Compute the crossentropy of binary output wrt binary target.
+    @note: We do not sum, crossentropy is computed by component.
+    @todo: Rewrite as a scalar, and then broadcast to tensor.
+    """
+    return -(target * tensor.log(output) + (1 - target) * tensor.log(1 - output))
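A quick numpy check of the element-wise quantity the new op computes (illustrative only, not part of nnet_ops; sum the result yourself if you want a scalar loss):

import numpy
output = numpy.array([0.9, 0.2, 0.5])
target = numpy.array([1.0, 0.0, 1.0])
xent = -(target * numpy.log(output) + (1 - target) * numpy.log(1 - output))  # per-component
total = numpy.sum(xent)
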
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/sandbox/README.txt	Thu Jul 10 09:03:11 2008 -0400
@@ -0,0 +1,1 @@
+Stuff in the sandbox may be very broken and/or in flux.
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/sandbox/rbm/README.txt	Thu Jul 10 09:03:11 2008 -0400
@@ -0,0 +1,4 @@
+An RBM with binomial units trained with CD-1.
+by Joseph Turian
+    
+This seems to work fine.
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/sandbox/rbm/main.py	Thu Jul 10 09:03:11 2008 -0400
@@ -0,0 +1,26 @@
+#!/usr/bin/python
+"""
+Simple SGD RBM training.
+(An example of how to use the model.)
+"""
+
+
+import numpy
+
+nonzero_instances = []
+#nonzero_instances.append({0: 1, 1: 1})
+#nonzero_instances.append({0: 1, 2: 1})
+
+nonzero_instances.append({1: 0.1, 5: 0.5, 9: 1})
+nonzero_instances.append({2: 0.3, 5: 0.5, 8: 0.8})
+nonzero_instances.append({1: 0.2, 2: 0.3, 5: 0.5})
+
+import model
+model = model.Model(input_dimension=10, hidden_dimension=6)
+
+for i in xrange(100000):
+    # Select an instance
+    instance = nonzero_instances[i % len(nonzero_instances)]
+
+    # SGD update over instance
+    model.update([instance])
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/sandbox/rbm/model.py	Thu Jul 10 09:03:11 2008 -0400
@@ -0,0 +1,139 @@
+"""
+The model for an RBM with binomial units, trained with CD-1.
+Inputs are given as sparse instances (see pylearn.sparse_instance).
+"""
+
+import parameters
+
+import numpy
+from numpy import dot
+import random
+
+import pylearn.nnet_ops
+import pylearn.sparse_instance
+
+def sigmoid(v):
+    """
+    @todo: Move to pylearn.more_numpy
+    @todo: Fix to avoid floating point overflow.
+    """
+#    if x < -30.0: return 0.0
+#    if x > 30.0: return 1.0 
+    return 1.0 / (1.0 + numpy.exp(-v))
+
+def sample(v):
+    """
+    @todo: Move to pylearn.more_numpy
+    """
+    assert len(v.shape) == 2
+    x = numpy.zeros(v.shape)
+    for j in range(v.shape[0]):
+        for i in range(v.shape[1]):
+            assert v[j][i] >= 0 and v[j][i] <= 1
+            if random.random() < v[j][i]: x[j][i] = 1
+            else: x[j][i] = 0
+    return x
+
+def crossentropy(output, target):
+    """
+    Compute the crossentropy of binary output wrt binary target.
+    @note: We do not sum, crossentropy is computed by component.
+    @todo: Rewrite as a scalar, and then broadcast to tensor.
+    @todo: Move to pylearn.more_numpy
+    @todo: Fix to avoid floating point overflow.
+    """
+    return -(target * numpy.log(output) + (1 - target) * numpy.log(1 - output))
+
+
+class Model:
+    """
+    @todo: input dimensions should be stored here! not as a global.
+    """
+    def __init__(self, input_dimension, hidden_dimension, learning_rate = 0.1, momentum = 0.9, weight_decay = 0.0002, random_seed = 666):
+        self.input_dimension    = input_dimension
+        self.hidden_dimension   = hidden_dimension
+        self.learning_rate      = learning_rate
+        self.momentum           = momentum
+        self.weight_decay       = weight_decay
+        self.random_seed        = random_seed
+
+        random.seed(random_seed)
+
+        self.parameters = parameters.Parameters(input_dimension=self.input_dimension, hidden_dimension=self.hidden_dimension, randomly_initialize=False, random_seed=self.random_seed)
+        self.prev_dw = 0
+        self.prev_db = 0
+        self.prev_dc = 0
+
+    def deterministic_reconstruction(self, v0):
+        """
+        One up-down cycle, but a mean-field approximation (no sampling).
+        """
+        q = sigmoid(self.parameters.b + dot(v0, self.parameters.w))
+        p = sigmoid(self.parameters.c + dot(q, self.parameters.w.T))
+        return p
+
+    def deterministic_reconstruction_error(self, v0):
+        """
+        @note: According to Yoshua, -log P(V1 = v0 | tilde(h)(v0)).
+        """
+        return crossentropy(self.deterministic_reconstruction(v0), v0)
+
+    def update(self, instances):
+        """
+        Update the L{Model} using a minibatch of training instances.
+        @param instances: A list of dicts from feature index to (non-zero) value.
+        @todo: Should assert that nonzero_indices and zero_indices
+        are correct (i.e. are truly nonzero/zero).
+        @todo: Multiply L{self.weight_decay} by L{self.learning_rate}, as done in Semantic Hashing?
+        @todo: Decay the biases too?
+        """
+        minibatch = len(instances)
+        v0 = pylearn.sparse_instance.to_vector(instances, self.input_dimension)
+        print "old XENT:", numpy.sum(self.deterministic_reconstruction_error(v0))
+        q0 = sigmoid(self.parameters.b + dot(v0, self.parameters.w))
+        h0 = sample(q0)
+        p0 = sigmoid(self.parameters.c + dot(h0, self.parameters.w.T))
+        v1 = sample(p0)
+        q1 = sigmoid(self.parameters.b + dot(v1, self.parameters.w))
+
+        dw = self.learning_rate * (dot(v0.T, h0) - dot(v1.T, q1)) / minibatch + self.momentum * self.prev_dw
+        db = self.learning_rate * numpy.sum(h0 - q1, axis=0) / minibatch + self.momentum * self.prev_db
+        dc = self.learning_rate * numpy.sum(v0 - v1, axis=0) / minibatch + self.momentum * self.prev_dc
+
+        self.parameters.w *= (1 - self.weight_decay)
+
+        self.parameters.w += dw
+        self.parameters.b += db
+        self.parameters.c += dc
+
+        self.prev_dw = dw
+        self.prev_db = db
+        self.prev_dc = dc
+
+        print "new XENT:", numpy.sum(self.deterministic_reconstruction_error(v0))
+
+#        print
+#        print "v[0]:", v0
+#        print "Q(h[0][i] = 1 | v[0]):", q0
+#        print "h[0]:", h0
+#        print "P(v[1][j] = 1 | h[0]):", p0
+#        print "XENT(P(v[1][j] = 1 | h[0]) | v0):", numpy.sum(crossentropy(p0, v0))
+#        print "v[1]:", v1
+#        print "Q(h[1][i] = 1 | v[1]):", q1
+#
+#        print
+#        print v0.T.shape
+#        print h0.shape
+#        print dot(v0.T, h0).shape
+#        print self.parameters.w.shape
+#        self.parameters.w += self.learning_rate * (dot(v0.T, h0) - dot(v1.T, q1)) / minibatch
+#        print
+#        print h0.shape
+#        print q1.shape
+#        print self.parameters.b.shape
+#        self.parameters.b += self.learning_rate * numpy.sum(h0 - q1, axis=0) / minibatch
+#        print v0.shape, v1.shape
+#        print
+#        print self.parameters.c.shape
+#        self.parameters.c += self.learning_rate * numpy.sum(v0 - v1, axis=0) / minibatch
+#        print self.parameters
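A minimal usage sketch for the model above (it mirrors sandbox/rbm/main.py; the import path assumes you run from the sandbox/rbm directory):

import numpy
import model as rbm_model

m = rbm_model.Model(input_dimension=10, hidden_dimension=6)
m.update([{1: 0.1, 5: 0.5, 9: 1}])       # one CD-1 step on a single sparse instance
v0 = numpy.zeros((1, 10))
v0[0, 1] = 0.1; v0[0, 5] = 0.5; v0[0, 9] = 1.0
err = numpy.sum(m.deterministic_reconstruction_error(v0))   # summed binary cross-entropy
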
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/sandbox/rbm/parameters.py	Thu Jul 10 09:03:11 2008 -0400
@@ -0,0 +1,32 @@
+"""
+Parameters (weights) used by the L{Model}.
+"""
+
+import numpy
+
+class Parameters:
+    """
+    Parameters used by the L{Model}.
+    """
+    def __init__(self, input_dimension, hidden_dimension, randomly_initialize, random_seed):
+        """
+        Initialize L{Model} parameters.
+        @param randomly_initialize: If True, then randomly initialize
+        according to the given random_seed. If False, then just use zeroes.
+        """
+        if randomly_initialize:
+            numpy.random.seed(random_seed)
+            self.w = (numpy.random.rand(input_dimension, hidden_dimension)-0.5)/input_dimension
+            self.b = numpy.zeros((1, hidden_dimension))
+            self.c = numpy.zeros((1, input_dimension))
+        else:
+            self.w = numpy.zeros((input_dimension, hidden_dimension))
+            self.b = numpy.zeros((1, hidden_dimension))
+            self.c = numpy.zeros((1, input_dimension))
+
+    def __str__(self):
+        s = ""
+        s += "w: %s\n" % self.w
+        s += "b: %s\n" % self.b
+        s += "c: %s\n" % self.c
+        return s
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/sandbox/simple_autoassociator/README.txt	Thu Jul 10 09:03:11 2008 -0400
@@ -0,0 +1,2 @@
+This is broken. It can't even learn the simple two training instances in
+main.py.
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/sandbox/simple_autoassociator/globals.py	Thu Jul 10 09:03:11 2008 -0400
@@ -0,0 +1,12 @@
+"""
+Global variables.
+"""
+
+#INPUT_DIMENSION = 1000
+#INPUT_DIMENSION = 100
+INPUT_DIMENSION = 4
+HIDDEN_DIMENSION = 10
+#HIDDEN_DIMENSION = 4
+LEARNING_RATE = 0.1
+LR = LEARNING_RATE
+SEED = 666
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/sandbox/simple_autoassociator/graph.py	Thu Jul 10 09:03:11 2008 -0400
@@ -0,0 +1,26 @@
+"""
+Theano graph for a simple autoassociator.
+@todo: Make nearly everything private.
+"""
+
+from pylearn.nnet_ops import sigmoid, binary_crossentropy
+from theano import tensor as t
+from theano.tensor import dot
+x           = t.dvector()
+w1          = t.dmatrix()
+b1          = t.dvector()
+w2          = t.dmatrix()
+b2          = t.dvector()
+h           = sigmoid(dot(x, w1) + b1)
+y           = sigmoid(dot(h, w2) + b2)
+
+loss_unsummed = binary_crossentropy(y, x)
+loss = t.sum(loss_unsummed)
+
+(gw1, gb1, gw2, gb2, gy, gh) = t.grad(loss, [w1, b1, w2, b2, y, h])
+
+import theano.compile
+
+inputs  = [x, w1, b1, w2, b2]
+outputs = [y, h, loss, loss_unsummed, gw1, gb1, gw2, gb2, gy, gh]
+trainfn = theano.compile.function(inputs, outputs)
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/sandbox/simple_autoassociator/main.py	Thu Jul 10 09:03:11 2008 -0400
@@ -0,0 +1,34 @@
+#!/usr/bin/python
+"""
+    A simple autoassociator.
+
+    The learned model is::
+       h   = sigmoid(dot(x, w1) + b1)
+       y   = sigmoid(dot(h, w2) + b2)
+
+    Binary xent loss.
+
+    LIMITATIONS:
+       - Only does pure stochastic gradient (batchsize = 1).
+"""
+
+
+import numpy
+
+nonzero_instances = []
+nonzero_instances.append({0: 1, 1: 1})
+nonzero_instances.append({0: 1, 2: 1})
+
+#nonzero_instances.append({1: 0.1, 5: 0.5, 9: 1})
+#nonzero_instances.append({2: 0.3, 5: 0.5, 8: 0.8})
+##nonzero_instances.append({1: 0.2, 2: 0.3, 5: 0.5})
+
+import model
+model = model.Model()
+
+for i in xrange(100000):
+    # Select an instance
+    instance = nonzero_instances[i % len(nonzero_instances)]
+
+    # SGD update over instance
+    model.update(instance)
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/sandbox/simple_autoassociator/model.py	Thu Jul 10 09:03:11 2008 -0400
@@ -0,0 +1,58 @@
+"""
+The model for an autoassociator for sparse inputs, using Ronan Collobert + Jason
+Weston's sampling trick (2008).
+"""
+
+from graph import trainfn
+import parameters
+
+import globals
+from globals import LR
+
+import numpy
+import random
+random.seed(globals.SEED)
+
+class Model:
+    def __init__(self):
+        self.parameters = parameters.Parameters(randomly_initialize=True)
+
+    def update(self, instance):
+        """
+        Update the L{Model} using one training instance.
+        @param instance: A dict from feature index to (non-zero) value.
+        @todo: Should assert that nonzero_indices and zero_indices
+        are correct (i.e. are truly nonzero/zero).
+        """
+        x = numpy.zeros(globals.INPUT_DIMENSION)
+        for idx in instance.keys():
+            x[idx] = instance[idx]
+
+        (y, h, loss, loss_unsummed, gw1, gb1, gw2, gb2, gy, gh) = trainfn(x, self.parameters.w1, self.parameters.b1, self.parameters.w2, self.parameters.b2)
+        print
+        print "instance:", instance
+        print "x:", x
+        print "OLD y:", y
+        print "OLD loss (unsummed):", loss_unsummed
+        print "gy:", gy
+        print "gh:", gh
+        print "OLD total loss:", loss
+        print "gw1:", gw1
+        print "gb1:", gb1
+        print "gw2:", gw2
+        print "gb2:", gb2
+
+        # SGD update
+        self.parameters.w1  -= LR * gw1
+        self.parameters.b1  -= LR * gb1
+        self.parameters.w2  -= LR * gw2
+        self.parameters.b2  -= LR * gb2
+
+        # Recompute the loss, to make sure it's decreasing
+        (y, h, loss, loss_unsummed, gw1, gb1, gw2, gb2, gy, gh) = trainfn(x, self.parameters.w1, self.parameters.b1, self.parameters.w2, self.parameters.b2)
+        print "NEW y:", y
+        print "NEW loss (unsummed):", loss_unsummed
+        print "gy:", gy
+        print "NEW total loss:", loss
+        print "h:", h
+        print self.parameters
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/sandbox/simple_autoassociator/parameters.py	Thu Jul 10 09:03:11 2008 -0400
@@ -0,0 +1,37 @@
+"""
+Parameters (weights) used by the L{Model}.
+"""
+
+import numpy
+import globals
+
+class Parameters:
+    """
+    Parameters used by the L{Model}.
+    """
+    def __init__(self, input_dimension=globals.INPUT_DIMENSION, hidden_dimension=globals.HIDDEN_DIMENSION, randomly_initialize=False, seed=globals.SEED):
+        """
+        Initialize L{Model} parameters.
+        @param randomly_initialize: If True, then randomly initialize
+        according to the given seed. If False, then just use zeroes.
+        """
+        if randomly_initialize:
+            numpy.random.seed(seed)
+            self.w1 = (numpy.random.rand(input_dimension, hidden_dimension)-0.5)/input_dimension
+            self.w2 = (numpy.random.rand(hidden_dimension, input_dimension)-0.5)/hidden_dimension
+            self.b1 = numpy.zeros(hidden_dimension)
+            #self.b2 = numpy.zeros(input_dimension)
+            self.b2 = numpy.array([10, 0, 0, -10])
+        else:
+            self.w1 = numpy.zeros((input_dimension, hidden_dimension))
+            self.w2 = numpy.zeros((hidden_dimension, input_dimension))
+            self.b1 = numpy.zeros(hidden_dimension)
+            self.b2 = numpy.zeros(input_dimension)
+
+    def __str__(self):
+        s = ""
+        s += "w1: %s\n" % self.w1
+        s += "b1: %s\n" % self.b1
+        s += "w2: %s\n" % self.w2
+        s += "b2: %s\n" % self.b2
+        return s
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/sandbox/sparse_random_autoassociator/README.txt	Thu Jul 10 09:03:11 2008 -0400
@@ -0,0 +1,1 @@
+Since simple_aa doesn't work, this probably doesn't either.
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/sandbox/sparse_random_autoassociator/globals.py	Thu Jul 10 09:03:11 2008 -0400
@@ -0,0 +1,13 @@
+"""
+Global variables.
+"""
+
+INPUT_DIMENSION = 1000
+HIDDEN_DIMENSION = 20
+LEARNING_RATE = 0.1
+LR = LEARNING_RATE
+SEED = 666
+ZERO_SAMPLE_SIZE = 50
+#ZERO_SAMPLE_SIZE = 250
+MARGIN = 0.25
+#MARGIN = 0.0
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/sandbox/sparse_random_autoassociator/graph.py	Thu Jul 10 09:03:11 2008 -0400
@@ -0,0 +1,42 @@
+"""
+Theano graph for an autoassociator for sparse inputs, which will be trained
+using Ronan Collobert + Jason Weston's sampling trick (2008).
+@todo: Make nearly everything private.
+"""
+
+from globals import MARGIN
+
+from pylearn.nnet_ops import sigmoid, binary_crossentropy
+from theano import tensor as t
+from theano.tensor import dot
+xnonzero    = t.dvector()
+w1nonzero   = t.dmatrix()
+b1          = t.dvector()
+w2nonzero   = t.dmatrix()
+w2zero      = t.dmatrix()
+b2nonzero   = t.dvector()
+b2zero      = t.dvector()
+h           = sigmoid(dot(xnonzero, w1nonzero) + b1)
+ynonzero    = sigmoid(dot(h, w2nonzero) + b2nonzero)
+yzero       = sigmoid(dot(h, w2zero) + b2zero)
+
+# May want to weight loss wrt nonzero value? e.g. MARGIN violation for
+# 0.1 nonzero is not as bad as MARGIN violation for 0.2 nonzero.
+def hingeloss(MARGIN):
+    return -MARGIN * (MARGIN < 0)
+nonzeroloss = hingeloss(ynonzero - t.max(yzero) - MARGIN)
+zeroloss = hingeloss(-t.max(-(ynonzero)) - yzero - MARGIN)
+# xnonzero sensitive loss:
+#nonzeroloss = hingeloss(ynonzero - t.max(yzero) - MARGIN - xnonzero)
+#zeroloss = hingeloss(-t.max(-(ynonzero - xnonzero)) - yzero - MARGIN)
+loss = t.sum(nonzeroloss) + t.sum(zeroloss)
+
+#loss = t.sum(binary_crossentropy(ynonzero, xnonzero)) + t.sum(binary_crossentropy(yzero, t.constant(0)))
+
+(gw1nonzero, gb1, gw2nonzero, gw2zero, gb2nonzero, gb2zero) = t.grad(loss, [w1nonzero, b1, w2nonzero, w2zero, b2nonzero, b2zero])
+
+import theano.compile
+
+inputs  = [xnonzero, w1nonzero, b1, w2nonzero, w2zero, b2nonzero, b2zero]
+outputs = [ynonzero, yzero, loss, gw1nonzero, gb1, gw2nonzero, gw2zero, gb2nonzero, gb2zero]
+trainfn = theano.compile.function(inputs, outputs)
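For intuition, the hinge used for nonzeroloss and zeroloss above is zero once its argument clears the margin and linear otherwise; a quick check on plain floats (illustrative only, not part of the graph):

def _hinge(z):
    return -z * (z < 0)

assert _hinge(-0.1) == 0.1   # margin violated by 0.1 -> linear penalty
assert _hinge(0.3) == 0.0    # margin satisfied -> no penalty
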
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/sandbox/sparse_random_autoassociator/main.py	Thu Jul 10 09:03:11 2008 -0400
@@ -0,0 +1,48 @@
+#!/usr/bin/python
+"""
+    An autoassociator for sparse inputs, using Ronan Collobert + Jason
+    Weston's sampling trick (2008).
+
+    The learned model is::
+       h   = sigmoid(dot(x, w1) + b1)
+       y   = sigmoid(dot(h, w2) + b2)
+
+    We assume that most of the inputs are zero, and hence that
+    we can separate x into xnonzero, x's nonzero components, and
+    xzero, a sample of the zeros. We sample---randomly without
+    replacement---ZERO_SAMPLE_SIZE zero columns from x.
+
+    The desideratum is that every nonzero entry is separated from every
+    zero entry by margin at least MARGIN.
+    For each ynonzero, we want it to exceed max(yzero) by at least MARGIN.
+    For each yzero, we want it to be exceeded by min(ynonzero) by at least MARGIN.
+    The loss is a hinge loss (linear). The loss is irrespective of the
+    xnonzero magnitude (this may be a limitation). Hence, all nonzeroes
+    are equally important to exceed the maximum yzero.
+
+    (Alternately, there is a commented out binary xent loss.)
+
+    LIMITATIONS:
+       - Only does pure stochastic gradient (batchsize = 1).
+       - Loss is irrespective of the xnonzero magnitude.
+       - We will always use all nonzero entries, even if the training
+       instance is very non-sparse.
+"""
+
+
+import numpy
+
+nonzero_instances = []
+nonzero_instances.append({1: 0.1, 5: 0.5, 9: 1})
+nonzero_instances.append({2: 0.3, 5: 0.5, 8: 0.8})
+nonzero_instances.append({1: 0.2, 2: 0.3, 5: 0.5})
+
+import model
+model = model.Model()
+
+for i in xrange(100000):
+    # Select an instance
+    instance = nonzero_instances[i % len(nonzero_instances)]
+
+    # SGD update over instance
+    model.update(instance)
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/sandbox/sparse_random_autoassociator/model.py	Thu Jul 10 09:03:11 2008 -0400
@@ -0,0 +1,76 @@
+"""
+The model for an autoassociator for sparse inputs, using Ronan Collobert + Jason
+Weston's sampling trick (2008).
+"""
+
+from graph import trainfn
+import parameters
+
+import globals
+from globals import LR
+
+import numpy
+import random
+random.seed(globals.SEED)
+
+def _select_indices(instance):
+    """
+    Choose nonzero and zero indices (feature columns) of the instance.
+    We select B{all} nonzero indices.
+    We select L{globals.ZERO_SAMPLE_SIZE} zero indices randomly,
+    without replacement.
+    @bug: If there are not ZERO_SAMPLE_SIZE zeroes, we will enter
+    an endless loop.
+    @return: (nonzero_indices, zero_indices)
+    """
+    # Get the nonzero indices
+    nonzero_indices = instance.keys()
+    nonzero_indices.sort()
+
+    # Get the zero indices
+    # @bug: If there are not ZERO_SAMPLE_SIZE zeroes, we will enter an endless loop.
+    zero_indices = []
+    while len(zero_indices) < globals.ZERO_SAMPLE_SIZE:
+        idx = random.randint(0, globals.INPUT_DIMENSION - 1)
+        if idx in nonzero_indices or idx in zero_indices: continue
+        zero_indices.append(idx)
+    zero_indices.sort()
+
+    return (nonzero_indices, zero_indices)
+
+class Model:
+    def __init__(self):
+        self.parameters = parameters.Parameters(randomly_initialize=True)
+
+    def update(self, instance):
+        """
+        Update the L{Model} using one training instance.
+        @param instance: A dict from feature index to (non-zero) value.
+        @todo: Should assert that nonzero_indices and zero_indices
+        are correct (i.e. are truly nonzero/zero).
+        """
+        (nonzero_indices, zero_indices) = _select_indices(instance)
+        # No update if there aren't any non-zeros.
+        if len(nonzero_indices) == 0: return
+        xnonzero = numpy.asarray([instance[idx] for idx in nonzero_indices])
+        print
+        print "xnonzero:", xnonzero
+
+        (ynonzero, yzero, loss, gw1nonzero, gb1, gw2nonzero, gw2zero, gb2nonzero, gb2zero) = trainfn(xnonzero, self.parameters.w1[nonzero_indices, :], self.parameters.b1, self.parameters.w2[:, nonzero_indices], self.parameters.w2[:, zero_indices], self.parameters.b2[nonzero_indices], self.parameters.b2[zero_indices])
+        print "OLD ynonzero:", ynonzero
+        print "OLD yzero:", yzero
+        print "OLD total loss:", loss
+
+        # SGD update
+        self.parameters.w1[nonzero_indices, :]  -= LR * gw1nonzero
+        self.parameters.b1						-= LR * gb1
+        self.parameters.w2[:, nonzero_indices]  -= LR * gw2nonzero
+        self.parameters.w2[:, zero_indices]		-= LR * gw2zero
+        self.parameters.b2[nonzero_indices]		-= LR * gb2nonzero
+        self.parameters.b2[zero_indices]		-= LR * gb2zero
+
+        # Recompute the loss, to make sure it's decreasing
+        (ynonzero, yzero, loss, gw1nonzero, gb1, gw2nonzero, gw2zero, gb2nonzero, gb2zero) = trainfn(xnonzero, self.parameters.w1[nonzero_indices, :], self.parameters.b1, self.parameters.w2[:, nonzero_indices], self.parameters.w2[:, zero_indices], self.parameters.b2[nonzero_indices], self.parameters.b2[zero_indices])
+        print "NEW ynonzero:", ynonzero
+        print "NEW yzero:", yzero
+        print "NEW total loss:", loss
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/sandbox/sparse_random_autoassociator/parameters.py	Thu Jul 10 09:03:11 2008 -0400
@@ -0,0 +1,28 @@
+"""
+Parameters (weights) used by the L{Model}.
+"""
+
+import numpy
+import globals
+
+class Parameters:
+    """
+    Parameters used by the L{Model}.
+    """
+    def __init__(self, input_dimension=globals.INPUT_DIMENSION, hidden_dimension=globals.HIDDEN_DIMENSION, randomly_initialize=False, seed=globals.SEED):
+        """
+        Initialize L{Model} parameters.
+        @param randomly_initialize: If True, then randomly initialize
+        according to the given seed. If False, then just use zeroes.
+        """
+        if randomly_initialize:
+            numpy.random.seed(seed)
+            self.w1 = (numpy.random.rand(input_dimension, hidden_dimension)-0.5)/input_dimension
+            self.w2 = (numpy.random.rand(hidden_dimension, input_dimension)-0.5)/hidden_dimension
+            self.b1 = numpy.zeros(hidden_dimension)
+            self.b2 = numpy.zeros(input_dimension)
+        else:
+            self.w1 = numpy.zeros((input_dimension, hidden_dimension))
+            self.w2 = numpy.zeros((hidden_dimension, input_dimension))
+            self.b1 = numpy.zeros(hidden_dimension)
+            self.b2 = numpy.zeros(input_dimension)
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/sparse_instance.py	Thu Jul 10 09:03:11 2008 -0400
@@ -0,0 +1,22 @@
+"""
+Sparse instances.
+Each instance is represented as a dict keyed by dimension.
+Dimensions not present in the dict have value 0.
+"""
+
+from numpy import zeros
+
+def to_vector(instances, dimensions):
+    """
+    Convert sparse instances to vectors.
+    @type instances: list of sparse instances
+    @param dimensions: The number of dimensions in each instance.
+    @rtype: numpy matrix (instances x dimensions)
+    @todo: Allow this function to convert SINGLE instances (not lists).
+    """
+    v = zeros((len(instances), dimensions))
+    l = len(instances)
+    for i in range(l):
+        for idx in instances[i].keys():
+            v[i][idx] = instances[i][idx]
+    return v
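A minimal usage sketch (assuming the module is importable as pylearn.sparse_instance, as sandbox/rbm/model.py does):

from pylearn.sparse_instance import to_vector
instances = [{1: 0.1, 5: 0.5, 9: 1}, {2: 0.3, 5: 0.5}]
dense = to_vector(instances, 10)     # shape (2, 10); unspecified dimensions are 0
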
--- a/sparse_random_autoassociator/globals.py	Wed Jul 09 16:55:27 2008 -0400
+++ /dev/null	Thu Jan 01 00:00:00 1970 +0000
@@ -1,12 +0,0 @@
-"""
-Global variables.
-"""
-
-INPUT_DIMENSION = 20
-HIDDEN_DIMENSION = 5
-LEARNING_RATE = 0.1
-LR = LEARNING_RATE
-SEED = 666
-ZERO_SAMPLE_SIZE = 5
-MARGIN = 0.1
-#MARGIN = 0.0
--- a/sparse_random_autoassociator/graph.py	Wed Jul 09 16:55:27 2008 -0400
+++ /dev/null	Thu Jan 01 00:00:00 1970 +0000
@@ -1,40 +0,0 @@
-"""
-Theano graph for an autoassociator for sparse inputs, which will be trained
-using Ronan Collobert + Jason Weston's sampling trick (2008).
-@todo: Make nearly everything private.
-"""
-
-from globals import MARGIN
-
-from pylearn.nnet_ops import sigmoid, crossentropy_softmax_1hot
-from theano import tensor as t
-from theano.tensor import dot
-xnonzero    = t.dvector()
-w1nonzero   = t.dmatrix()
-b1          = t.dvector()
-w2nonzero   = t.dmatrix()
-w2zero      = t.dmatrix()
-b2nonzero   = t.dvector()
-b2zero      = t.dvector()
-h           = sigmoid(dot(xnonzero, w1nonzero) + b1)
-ynonzero    = sigmoid(dot(h, w2nonzero) + b2nonzero)
-yzero       = sigmoid(dot(h, w2zero) + b2zero)
-
-# May want to weight loss wrt nonzero value? e.g. MARGIN violation for
-# 0.1 nonzero is not as bad as MARGIN violation for 0.2 nonzero.
-def hingeloss(MARGIN):
-    return -MARGIN * (MARGIN < 0)
-nonzeroloss = hingeloss(ynonzero - t.max(yzero) - MARGIN)
-zeroloss = hingeloss(-t.max(-(ynonzero)) - yzero - MARGIN)
-# xnonzero sensitive loss:
-#nonzeroloss = hingeloss(ynonzero - t.max(yzero) - MARGIN - xnonzero)
-#zeroloss = hingeloss(-t.max(-(ynonzero - xnonzero)) - yzero - MARGIN)
-loss = t.sum(nonzeroloss) + t.sum(zeroloss)
-
-(gw1nonzero, gb1, gw2nonzero, gw2zero, gb2nonzero, gb2zero) = t.grad(loss, [w1nonzero, b1, w2nonzero, w2zero, b2nonzero, b2zero])
-
-import theano.compile
-
-inputs  = [xnonzero, w1nonzero, b1, w2nonzero, w2zero, b2nonzero, b2zero]
-outputs = [ynonzero, yzero, loss, gw1nonzero, gb1, gw2nonzero, gw2zero, gb2nonzero, gb2zero]
-trainfn = theano.compile.function(inputs, outputs)
--- a/sparse_random_autoassociator/main.py	Wed Jul 09 16:55:27 2008 -0400
+++ /dev/null	Thu Jan 01 00:00:00 1970 +0000
@@ -1,46 +0,0 @@
-#!/usr/bin/python
-"""
-    An autoassociator for sparse inputs, using Ronan Collobert + Jason
-    Weston's sampling trick (2008).
-
-    The learned model is::
-       h   = sigmoid(dot(x, w1) + b1)
-       y   = sigmoid(dot(h, w2) + b2)
-
-    We assume that most of the inputs are zero, and hence that
-    we can separate x into xnonzero, x's nonzero components, and
-    xzero, a sample of the zeros. We sample---randomly without
-    replacement---ZERO_SAMPLE_SIZE zero columns from x.
-
-    The desideratum is that every nonzero entry is separated from every
-    zero entry by margin at least MARGIN.
-    For each ynonzero, we want it to exceed max(yzero) by at least MARGIN.
-    For each yzero, we want it to be exceed by min(ynonzero) by at least MARGIN.
-    The loss is a hinge loss (linear). The loss is irrespective of the
-    xnonzero magnitude (this may be a limitation). Hence, all nonzeroes
-    are equally important to exceed the maximum yzero.
-
-    LIMITATIONS:
-       - Only does pure stochastic gradient (batchsize = 1).
-       - Loss is irrespective of the xnonzero magnitude.
-       - We will always use all nonzero entries, even if the training
-       instance is very non-sparse.
-"""
-
-
-import numpy
-
-nonzero_instances = []
-nonzero_instances.append({1: 0.1, 5: 0.5, 9: 1})
-nonzero_instances.append({2: 0.3, 5: 0.5, 8: 0.8})
-nonzero_instances.append({1: 0.2, 2: 0.3, 5: 0.5})
-
-import model
-model = model.Model()
-
-for i in xrange(100000):
-    # Select an instance
-    instance = nonzero_instances[i % len(nonzero_instances)]
-
-    # SGD update over instance
-    model.update(instance)
--- a/sparse_random_autoassociator/model.py	Wed Jul 09 16:55:27 2008 -0400
+++ /dev/null	Thu Jan 01 00:00:00 1970 +0000
@@ -1,76 +0,0 @@
-"""
-The model for an autoassociator for sparse inputs, using Ronan Collobert + Jason
-Weston's sampling trick (2008).
-"""
-
-from graph import trainfn
-import parameters
-
-import globals
-from globals import LR
-
-import numpy
-import random
-random.seed(globals.SEED)
-
-def _select_indices(instance):
-    """
-    Choose nonzero and zero indices (feature columns) of the instance.
-    We select B{all} nonzero indices.
-    We select L{globals.ZERO_SAMPLE_SIZE} zero indices randomly,
-    without replacement.
-    @bug: If there are not ZERO_SAMPLE_SIZE zeroes, we will enter
-    an endless loop.
-    @return: (nonzero_indices, zero_indices)
-    """
-    # Get the nonzero indices
-    nonzero_indices = instance.keys()
-    nonzero_indices.sort()
-
-    # Get the zero indices
-    # @bug: If there are not ZERO_SAMPLE_SIZE zeroes, we will enter an endless loop.
-    zero_indices = []
-    while len(zero_indices) < globals.ZERO_SAMPLE_SIZE:
-        idx = random.randint(0, globals.INPUT_DIMENSION - 1)
-        if idx in nonzero_indices or idx in zero_indices: continue
-        zero_indices.append(idx)
-    zero_indices.sort()
-
-    return (nonzero_indices, zero_indices)
-
-class Model:
-    def __init__(self):
-        self.parameters = parameters.Parameters(randomly_initialize=True)
-
-    def update(self, instance):
-        """
-        Update the L{Model} using one training instance.
-        @param instance: A dict from feature index to (non-zero) value.
-        @todo: Should assert that nonzero_indices and zero_indices
-        are correct (i.e. are truly nonzero/zero).
-        """
-        (nonzero_indices, zero_indices) = _select_indices(instance)
-        # No update if there aren't any non-zeros.
-        if len(nonzero_indices) == 0: return
-        xnonzero = numpy.asarray([instance[idx] for idx in nonzero_indices])
-        print
-        print "xnonzero:", xnonzero
-
-        (ynonzero, yzero, loss, gw1nonzero, gb1, gw2nonzero, gw2zero, gb2nonzero, gb2zero) = trainfn(xnonzero, self.parameters.w1[nonzero_indices, :], self.parameters.b1, self.parameters.w2[:, nonzero_indices], self.parameters.w2[:, zero_indices], self.parameters.b2[nonzero_indices], self.parameters.b2[zero_indices])
-        print "OLD ynonzero:", ynonzero
-        print "OLD yzero:", yzero
-        print "OLD total loss:", loss
-
-        # SGD update
-        self.parameters.w1[nonzero_indices, :]  -= LR * gw1nonzero
-        self.parameters.b1						-= LR * gb1
-        self.parameters.w2[:, nonzero_indices]  -= LR * gw2nonzero
-        self.parameters.w2[:, zero_indices]		-= LR * gw2zero
-        self.parameters.b2[nonzero_indices]		-= LR * gb2nonzero
-        self.parameters.b2[zero_indices]		-= LR * gb2zero
-
-        # Recompute the loss, to make sure it's descreasing
-        (ynonzero, yzero, loss, gw1nonzero, gb1, gw2nonzero, gw2zero, gb2nonzero, gb2zero) = trainfn(xnonzero, self.parameters.w1[nonzero_indices, :], self.parameters.b1, self.parameters.w2[:, nonzero_indices], self.parameters.w2[:, zero_indices], self.parameters.b2[nonzero_indices], self.parameters.b2[zero_indices])
-        print "NEW ynonzero:", ynonzero
-        print "NEW yzero:", yzero
-        print "NEW total loss:", loss
--- a/sparse_random_autoassociator/parameters.py	Wed Jul 09 16:55:27 2008 -0400
+++ /dev/null	Thu Jan 01 00:00:00 1970 +0000
@@ -1,28 +0,0 @@
-"""
-Parameters (weights) used by the L{Model}.
-"""
-
-import numpy
-import globals
-
-class Parameters:
-    """
-    Parameters used by the L{Model}.
-    """
-    def __init__(self, input_dimension=globals.INPUT_DIMENSION, hidden_dimension=globals.HIDDEN_DIMENSION, randomly_initialize=False, seed=globals.SEED):
-        """
-        Initialize L{Model} parameters.
-        @param randomly_initialize: If True, then randomly initialize
-        according to the given seed. If False, then just use zeroes.
-        """
-        if randomly_initialize:
-            numpy.random.seed(seed)
-            self.w1 = (numpy.random.rand(input_dimension, hidden_dimension)-0.5)/input_dimension
-            self.w2 = (numpy.random.rand(hidden_dimension, input_dimension)-0.5)/hidden_dimension
-            self.b1 = numpy.zeros(hidden_dimension)
-            self.b2 = numpy.zeros(input_dimension)
-        else:
-            self.w1 = numpy.zeros((input_dimension, hidden_dimension))
-            self.w2 = numpy.zeros((hidden_dimension, input_dimension))
-            self.b1 = numpy.zeros(hidden_dimension)
-            self.b2 = numpy.zeros(input_dimension)