diff transformations/pipeline.py @ 15:f6b6c74bb82f

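Fix the datatypes: convert the input images to float32 divided by 255, reshape each image to 32x32 before applying the transforms, and store the flattened results multiplied by 255 in an int8 array. Also accept the input data file as an optional third command-line argument.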
author Arnaud Bergeron <abergeron@gmail.com>
date Thu, 28 Jan 2010 12:31:21 -0500
parents faacc76d21c2
children fdb0e0870fb4
line wrap: on
line diff
--- a/transformations/pipeline.py	Thu Jan 28 11:50:01 2010 -0500
+++ b/transformations/pipeline.py	Thu Jan 28 12:31:21 2010 -0500
@@ -14,11 +14,17 @@
 # DANGER: HIGH VOLTAGE -- DO NOT EDIT BELOW THIS LINE
 # -----------------------------------------------------------
 
-train_data = open('/data/lisa/data/nist/by_class/all/all_train_data.ft', 'rb')
+outf = sys.argv[1]
+paramsf = sys.argv[2]
+dataf = '/data/lisa/data/nist/by_class/all/all_train_data.ft'
+if len(sys.argv) >= 4:
+    dataf = sys.argv[3]
+
+train_data = open(dataf, 'rb')
 
 dim = tuple(ft._read_header(train_data)[3])
 
-res_data = numpy.empty(dim)
+res_data = numpy.empty(dim, dtype=numpy.int8)
 
 all_settings = ['complexity']
 
@@ -29,7 +35,7 @@
 
 for i in xrange(0, dim[0], BATCH_SIZE):
     train_data.seek(0)
-    imgs = ft.read(train_data, slice(i, i+BATCH_SIZE))
+    imgs = ft.read(train_data, slice(i, i+BATCH_SIZE)).astype(numpy.float32)/255
     
     complexity = random.random()
     p = i/BATCH_SIZE
@@ -40,12 +46,12 @@
         j += len(par)
     
     for k in range(imgs.shape[0]):
-        c = imgs[k]
+        c = imgs[k].reshape((32, 32))
         for mod in mods:
             c = mod.transform_image(c)
-        res_data[i+k] = c
+        res_data[i+k] = c.reshape((1024,))*255
 
-with open(sys.argv[1], 'wb') as f:
+with open(outf, 'wb') as f:
     ft.write(f, res_data)
 
-numpy.save(sys.argv[2], params)
+numpy.save(paramsf, params)