view scripts/stacked_dae/nist_sda.py @ 135:36010ff90255

Added script to test truetype font files validity (corruption)
author boulanni <nicolas_boulanger@hotmail.com>
date Sat, 20 Feb 2010 02:10:30 -0500
parents 5c79a2557f2f
children 7d8366fb90bf
line wrap: on
line source

#!/usr/bin/python
# coding: utf-8

import numpy 
import theano
import time
import theano.tensor as T
from theano.tensor.shared_randomstreams import RandomStreams

import os.path

from sgd_optimization import sgd_optimization

from jobman import DD
from pylearn.io import filetensor

from utils import produit_croise_jobs

# Directory holding the NIST filetensor (.ft) files; used by the NIST class
# when no explicit base path is supplied.
NIST_ALL_LOCATION = '/data/lisa/data/nist/by_class/all'

# Default hyperparameters, used when the caller supplies none.
# They are only meant for quick test runs: one pretraining epoch per layer
# and one finetuning epoch.
DEFAULT_HP_NIST = DD({
    'finetuning_lr': 0.1,
    'pretraining_lr': 0.1,
    'pretraining_epochs_per_layer': 1,
    'max_finetuning_epochs': 1,
    'hidden_layers_sizes': [1000, 1000],
    'corruption_levels': [0.2, 0.2],
    'minibatch_size': 20,
})

def jobman_entrypoint_nist(state, channel):
    """Run the NIST SGD optimization with ``state`` as the hyperparameters.

    state: forwarded unchanged as the ``hp`` argument of
        ``sgd_optimization_nist``.
    channel: accepted but not used by this function.

    Returns None (the result of ``sgd_optimization_nist`` is discarded).
    """
    sgd_optimization_nist(state)

def jobman_insert_nist():
    """Build a grid of hyperparameter settings and insert one job per setting.

    The candidate values for each hyperparameter are handed to
    ``produit_croise_jobs`` (presumably expanding them into their cross
    product -- confirm against ``utils``), and every resulting job is passed
    to ``insert_job``.

    NOTE(review): ``insert_job`` is neither defined nor imported in the
    visible part of this file, so this function looks like it would raise
    NameError when called -- confirm where it is supposed to come from.
    NOTE(review): the grid uses the keys 'hidden_layer_sizes' (scalars) and
    'num_hidden_layers', whereas DEFAULT_HP_NIST uses 'hidden_layers_sizes'
    (a list per job) -- verify what the consumer of these jobs expects.
    """
    search_space = {
        'finetuning_lr': [0.00001, 0.0001, 0.001, 0.01, 0.1],
        'pretraining_lr': [0.00001, 0.0001, 0.001, 0.01, 0.1],
        'pretraining_epochs_per_layer': [2, 5, 20],
        'hidden_layer_sizes': [100, 300, 1000],
        'num_hidden_layers': [1, 2, 3],
        'corruption_levels': [0.1, 0.2, 0.4],
        'minibatch_size': [5, 20, 100],
    }

    for job in produit_croise_jobs(search_space):
        insert_job(job)


class NIST:
    """Loader that reads the NIST train and test filetensor files into memory.

    Constructing an instance immediately reads the data and label files of
    the train and test sets from ``basepath`` with ``filetensor.read``.
    The validation set starts out empty; ``split_train_valid`` can be called
    to carve it out of the end of the training set.

    Attributes:
        minibatch_size: stored as given; not otherwise used by this class.
        basepath: directory containing the ``.ft`` files.
        train_files, test_files: ``[data filename, labels filename]``.
        train, valid, test: two-element lists ``[x, y]`` (data, labels).
    """

    def __init__(self, minibatch_size, basepath=None):
        """Store the settings and load the train and test sets.

        minibatch_size: kept on the instance for callers.
        basepath: directory with the ``.ft`` files; when None (or empty),
            the module-level NIST_ALL_LOCATION is used.
        """
        self.minibatch_size = minibatch_size
        self.basepath = basepath or NIST_ALL_LOCATION

        self.set_filenames()

        # Two-element lists: index 0 holds the data (x), index 1 the labels (y).
        self.train = [None, None]
        self.test = [None, None]

        self.load_train_test()

        # The train/valid split is currently disabled: the validation set
        # stays empty unless split_train_valid() is called explicitly.
        self.valid = [[], []]

    def get_tvt(self):
        """Return the ``(train, valid, test)`` tuple of ``[x, y]`` lists."""
        return self.train, self.valid, self.test

    def set_filenames(self):
        """Set the data/label file names (relative to ``basepath``)."""
        self.train_files = ['all_train_data.ft',
                                'all_train_labels.ft']

        self.test_files = ['all_test_data.ft',
                            'all_test_labels.ft']

    def load_train_test(self):
        """Read the train and test files into ``self.train`` / ``self.test``."""
        self.load_data_labels(self.train_files, self.train)
        self.load_data_labels(self.test_files, self.test)

    def load_data_labels(self, filenames, pair):
        """Read each file of ``filenames`` into the matching slot of ``pair``.

        filenames: file names relative to ``self.basepath``.
        pair: list mutated in place; ``pair[i]`` receives the content of
            ``filenames[i]`` as returned by ``filetensor.read``.
        """
        for i, fn in enumerate(filenames):
            # Filetensor files hold binary data, so open them in binary mode.
            f = open(os.path.join(self.basepath, fn), 'rb')
            try:
                pair[i] = filetensor.read(f)
            finally:
                # Release the handle even when reading fails.
                f.close()

    def split_train_valid(self):
        """Move the tail of the training set into the validation set.

        The validation set gets as many examples as the test set has; they
        are taken from the end of the training data and labels, which are
        shortened accordingly. With an empty test set nothing is moved.
        """
        test_len = len(self.test[0])

        for i in (0, 1):
            whole = self.train[i]
            # Use an explicit cut index rather than a negative slice:
            # whole[:-0] would be empty and would wipe the training set
            # when the test set is empty.
            cut = max(0, len(whole) - test_len)
            self.valid[i] = whole[cut:]
            self.train[i] = whole[:cut]

def test_load_nist():
    print "Will load NIST"

    import time
    t1 = time.time()
    nist = NIST(20)
    t2 = time.time()

    print "NIST loaded. time delta = ", t2-t1

    tr,v,te = nist.get_tvt()

    print "Lenghts: ", len(tr[0]), len(v[0]), len(te[0])

    raw_input("Press any key")

# hp for hyperparameters
def sgd_optimization_nist(hp=None, dataset_dir='/data/lisa/data/nist'):
    global DEFAULT_HP_NIST
    hp = hp and hp or DEFAULT_HP_NIST

    print "Will load NIST"

    import time
    t1 = time.time()
    nist = NIST(20)
    t2 = time.time()

    print "NIST loaded. time delta = ", t2-t1

    train,valid,test = nist.get_tvt()
    dataset = (train,valid,test)

    print "Lenghts train, valid, test: ", len(train[0]), len(valid[0]), len(test[0])

    n_ins = 32*32
    n_outs = 62 # 10 digits, 26*2 (lower, capitals)

    sgd_optimization(dataset, hp, n_ins, n_outs)

if __name__ == '__main__':

    import sys

    # The first command-line argument selects the mode: 'load_nist' only
    # exercises the dataset loading, anything else runs the optimization
    # with the default hyperparameters.
    cli_args = sys.argv[1:]
    load_only = bool(cli_args) and cli_args[0] == 'load_nist'

    if load_only:
        test_load_nist()
    else:
        sgd_optimization_nist()