comparison transformations/pipeline.py @ 15:f6b6c74bb82f

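Fix the datatypes: convert input images to float32 in [0, 1] and reshape them to 32x32 before the transforms, scale the results back by 255 into an int8 output array, and accept the data file path as an optional third command-line argument.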
author Arnaud Bergeron <abergeron@gmail.com>
date Thu, 28 Jan 2010 12:31:21 -0500
parents faacc76d21c2
children fdb0e0870fb4
comparison
equal deleted inserted replaced
14:ebf61603489b 15:f6b6c74bb82f
12 mods = [] 12 mods = []
13 13
14 # DANGER: HIGH VOLTAGE -- DO NOT EDIT BELOW THIS LINE 14 # DANGER: HIGH VOLTAGE -- DO NOT EDIT BELOW THIS LINE
15 # ----------------------------------------------------------- 15 # -----------------------------------------------------------
16 16
17 train_data = open('/data/lisa/data/nist/by_class/all/all_train_data.ft', 'rb') 17 outf = sys.argv[1]
18 paramsf = sys.argv[2]
19 dataf = '/data/lisa/data/nist/by_class/all/all_train_data.ft'
20 if len(sys.argv) >= 4:
21 dataf = sys.argv[3]
22
23 train_data = open(dataf, 'rb')
18 24
19 dim = tuple(ft._read_header(train_data)[3]) 25 dim = tuple(ft._read_header(train_data)[3])
20 26
21 res_data = numpy.empty(dim) 27 res_data = numpy.empty(dim, dtype=numpy.int8)
22 28
23 all_settings = ['complexity'] 29 all_settings = ['complexity']
24 30
25 for mod in mods: 31 for mod in mods:
26 all_settings += mod.get_settings_names() 32 all_settings += mod.get_settings_names()
27 33
28 params = numpy.empty(((dim[0]/BATCH_SIZE)+1, len(all_settings))) 34 params = numpy.empty(((dim[0]/BATCH_SIZE)+1, len(all_settings)))
29 35
30 for i in xrange(0, dim[0], BATCH_SIZE): 36 for i in xrange(0, dim[0], BATCH_SIZE):
31 train_data.seek(0) 37 train_data.seek(0)
32 imgs = ft.read(train_data, slice(i, i+BATCH_SIZE)) 38 imgs = ft.read(train_data, slice(i, i+BATCH_SIZE)).astype(numpy.float32)/255
33 39
34 complexity = random.random() 40 complexity = random.random()
35 p = i/BATCH_SIZE 41 p = i/BATCH_SIZE
36 j = 1 42 j = 1
37 for mod in mods: 43 for mod in mods:
38 par = mod.regenerate_parameters(complexity) 44 par = mod.regenerate_parameters(complexity)
39 params[p, j:j+len(par)] = par 45 params[p, j:j+len(par)] = par
40 j += len(par) 46 j += len(par)
41 47
42 for k in range(imgs.shape[0]): 48 for k in range(imgs.shape[0]):
43 c = imgs[k] 49 c = imgs[k].reshape((32, 32))
44 for mod in mods: 50 for mod in mods:
45 c = mod.transform_image(c) 51 c = mod.transform_image(c)
46 res_data[i+k] = c 52 res_data[i+k] = c.reshape((1024,))*255
47 53
48 with open(sys.argv[1], 'wb') as f: 54 with open(outf, 'wb') as f:
49 ft.write(f, res_data) 55 ft.write(f, res_data)
50 56
51 numpy.save(sys.argv[2], params) 57 numpy.save(paramsf, params)