view transformations/pipeline.py @ 15:f6b6c74bb82f

Fix the datatypes.
author Arnaud Bergeron <abergeron@gmail.com>
date Thu, 28 Jan 2010 12:31:21 -0500
parents faacc76d21c2
children fdb0e0870fb4
line wrap: on
line source

from __future__ import with_statement

import sys, os
import numpy
import filetensor as ft
import random

# Number of images read from the input file per batch; one complexity value
# (and one set of module parameters) is drawn for every batch.
BATCH_SIZE = 100

# Transformation modules, applied to each image in list order. Import the
# modules above and list them here. Each module must provide
# get_settings_names(), regenerate_parameters(complexity) and
# transform_image(image).
mods = list()

# DANGER: HIGH VOLTAGE -- DO NOT EDIT BELOW THIS LINE
# -----------------------------------------------------------

# Command line: pipeline.py OUTFILE PARAMSFILE [DATAFILE]
#   OUTFILE    -- transformed images are written here with ft.write
#   PARAMSFILE -- the per-batch parameter table is saved here with numpy.save
#   DATAFILE   -- input images (read with ft); defaults to the path below
if len(sys.argv) < 3:
    sys.exit('usage: %s OUTFILE PARAMSFILE [DATAFILE]' % sys.argv[0])

outf = sys.argv[1]
paramsf = sys.argv[2]
dataf = '/data/lisa/data/nist/by_class/all/all_train_data.ft'
if len(sys.argv) >= 4:
    dataf = sys.argv[3]

# Column layout of the params table: column 0 is the complexity drawn for the
# batch, followed by each module's settings in the order of `mods`.
all_settings = ['complexity']
for mod in mods:
    all_settings += mod.get_settings_names()

with open(dataf, 'rb') as train_data:
    # Shape of the input tensor, taken from the file header; axis 0 indexes
    # images (presumably (n_images, 1024) given the reshapes below -- confirm).
    dim = tuple(ft._read_header(train_data)[3])

    # Output pixels are in 0..255, which overflows int8 (max 127); uint8
    # holds the full range.
    res_data = numpy.empty(dim, dtype=numpy.uint8)

    # Exactly one row per batch (ceiling division), so no uninitialized
    # numpy.empty row is saved when dim[0] is a multiple of BATCH_SIZE.
    n_batches = (dim[0] + BATCH_SIZE - 1) // BATCH_SIZE
    params = numpy.empty((n_batches, len(all_settings)))

    for i in xrange(0, dim[0], BATCH_SIZE):
        # Rewind before each read: ft.read is handed the file from its start
        # (presumably it re-parses the header on every call).
        train_data.seek(0)
        imgs = ft.read(train_data, slice(i, i+BATCH_SIZE)).astype(numpy.float32)/255

        # One complexity per batch; every module derives its parameters
        # from it. Both are recorded in row p of the params table.
        complexity = random.random()
        p = i // BATCH_SIZE
        params[p, 0] = complexity
        j = 1
        for mod in mods:
            par = mod.regenerate_parameters(complexity)
            params[p, j:j+len(par)] = par
            j += len(par)

        for k in range(imgs.shape[0]):
            # Images are assumed to be 32x32, stored flattened as 1024 values.
            c = imgs[k].reshape((32, 32))
            for mod in mods:
                c = mod.transform_image(c)
            # Round to nearest instead of truncating, and clip so values a
            # transformation pushed outside [0, 1] saturate rather than wrap
            # around when stored as uint8.
            res_data[i+k] = numpy.clip(numpy.rint(c.reshape((1024,))*255), 0, 255)

with open(outf, 'wb') as f:
    ft.write(f, res_data)

numpy.save(paramsf, params)