comparison data_generation/pipeline/filter_nist.py @ 633:13baba8a4522

merge
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Sat, 19 Mar 2011 22:51:40 -0400
parents 75dbbe409578
children
comparison
equal deleted inserted replaced
632:5541056d3fb0 633:13baba8a4522
1 import numpy
2 from pylearn.io import filetensor as ft
3 from ift6266 import datasets
4 from ift6266.datasets.ftfile import FTDataSet
5
# Prefix of the transformed dataset to filter.  Alternatives noted by the
# original author: the NISTP prefix and 'P07safe_'.
dataset_str = 'P07_'

# Previously used input location (kept for reference):
#   base_path        = '/data/lisatmp/ift6266h10/data/' + dataset_str
#   base_output_path = '/data/lisatmp/ift6266h10/data/transformed_digits/' + dataset_str + 'train'

# Input shards are read as <base_path>train<N>_data.ft / <base_path>train<N>_labels.ft.
base_path = '/data/lisa/data/ift6266h10/data/%s' % dataset_str
# Filtered shards are written as <base_output_path><N>_data.ft / <base_output_path><N>_labels.ft.
base_output_path = '/data/lisatmp/ift6266h10/data/transformed_digits/%strain' % dataset_str
13
def _ft_dataset_for_file(fileno):
    """Return an FTDataSet whose training split is the single shard `fileno`.

    Only the training split is iterated by this script.  The test/valid
    entries are passed as the original script did (presumably FTDataSet
    expects them -- TODO confirm); their file names now follow the same
    '<base_path>...' scheme as the train shards (previously they had a
    spurious extra '_', e.g. 'P07__test_data.ft').

    No indtype/inscale is passed, so no conversion or scaling is requested:
    the data is kept as stored.
    """
    shard = base_path + 'train' + str(fileno)
    return FTDataSet(train_data=[shard + '_data.ft'],
                     train_lbl=[shard + '_labels.ft'],
                     test_data=[base_path + 'test_data.ft'],
                     test_lbl=[base_path + 'test_labels.ft'],
                     valid_data=[base_path + 'valid_data.ft'],
                     valid_lbl=[base_path + 'valid_labels.ft'])


def _keep_digits(ds):
    """Iterate `ds.train` one example at a time, keeping examples with label <= 9.

    Returns (kept_x, kept_y, n_seen): the kept examples, their labels, and
    the total number of examples read.  Labels 0-9 are assumed to be the
    digit classes -- confirm against the dataset's label encoding.
    """
    kept_x = []
    kept_y = []
    n_seen = 0
    for mb_x, mb_y in ds.train(1):
        if mb_y[0] <= 9:
            kept_x.append(mb_x[0])
            kept_y.append(mb_y[0])
        n_seen += 1
        # Progress report every 100k examples read (kept or not).
        if n_seen % 100000 == 0:
            print("Done next 100k")
    return kept_x, kept_y, n_seen


def _write_ft(path, array):
    """Write `array` to `path` in filetensor format; the file is closed even on error."""
    with open(path, 'wb') as f:
        ft.write(f, array)


# Filter every training shard down to its digit examples and write the
# result as a new pair of filetensor files per shard.
for fileno in range(100):
    print("Processing file no  %d" % fileno)

    output_data_file = base_output_path + str(fileno) + '_data.ft'
    output_labels_file = base_output_path + str(fileno) + '_labels.ft'

    print("Reading from  %s" % (base_path + 'train' + str(fileno) + '_data.ft'))

    all_x, all_y, all_count = _keep_digits(_ft_dataset_for_file(fileno))

    # Data is assumed stored as uint8 on 0-255 (no scaling was requested).
    # NOTE(review): if no example is kept, merged_x has shape (0,) rather
    # than (0, n_features).
    merged_x = numpy.asarray(all_x, dtype=numpy.uint8)
    merged_y = numpy.asarray(all_y, dtype=numpy.int32)

    print("Kept %d (shape  %s ) examples from %d"
          % (len(all_x), str(merged_x.shape), all_count))

    _write_ft(output_data_file, merged_x)
    _write_ft(output_labels_file, merged_y)