Mercurial > ift6266
diff transformations/pipeline.py @ 15:f6b6c74bb82f
Fix the datatypes.
author | Arnaud Bergeron <abergeron@gmail.com> |
---|---|
date | Thu, 28 Jan 2010 12:31:21 -0500 |
parents | faacc76d21c2 |
children | fdb0e0870fb4 |
line wrap: on
line diff
--- a/transformations/pipeline.py Thu Jan 28 11:50:01 2010 -0500 +++ b/transformations/pipeline.py Thu Jan 28 12:31:21 2010 -0500 @@ -14,11 +14,17 @@ # DANGER: HIGH VOLTAGE -- DO NOT EDIT BELOW THIS LINE # ----------------------------------------------------------- -train_data = open('/data/lisa/data/nist/by_class/all/all_train_data.ft', 'rb') +outf = sys.argv[1] +paramsf = sys.argv[2] +dataf = '/data/lisa/data/nist/by_class/all/all_train_data.ft' +if len(sys.argv) >= 4: + dataf = sys.argv[3] + +train_data = open(dataf, 'rb') dim = tuple(ft._read_header(train_data)[3]) -res_data = numpy.empty(dim) +res_data = numpy.empty(dim, dtype=numpy.int8) all_settings = ['complexity'] @@ -29,7 +35,7 @@ for i in xrange(0, dim[0], BATCH_SIZE): train_data.seek(0) - imgs = ft.read(train_data, slice(i, i+BATCH_SIZE)) + imgs = ft.read(train_data, slice(i, i+BATCH_SIZE)).astype(numpy.float32)/255 complexity = random.random() p = i/BATCH_SIZE @@ -40,12 +46,12 @@ j += len(par) for k in range(imgs.shape[0]): - c = imgs[k] + c = imgs[k].reshape((32, 32)) for mod in mods: c = mod.transform_image(c) - res_data[i+k] = c + res_data[i+k] = c.reshape((1024,))*255 -with open(sys.argv[1], 'wb') as f: +with open(outf, 'wb') as f: ft.write(f, res_data) -numpy.save(sys.argv[2], params) +numpy.save(paramsf, params)