Mercurial > ift6266
comparison data_generation/pipeline/filter_nist.py @ 633:13baba8a4522
merge
author | Yoshua Bengio <bengioy@iro.umontreal.ca> |
---|---|
date | Sat, 19 Mar 2011 22:51:40 -0400 |
parents | 75dbbe409578 |
children |
comparison
equal
deleted
inserted
replaced
632:5541056d3fb0 | 633:13baba8a4522 |
---|---|
"""Keep only the digit examples of each training shard of a NIST-derived set.

For every training shard ``<base_path>train<N>_{data,labels}.ft`` this script
keeps the examples whose label is <= 9 (the output directory is named
``transformed_digits``) and writes them to
``<base_output_path><N>_{data,labels}.ft`` in filetensor format.
"""
import numpy
from pylearn.io import filetensor as ft
from ift6266 import datasets
from ift6266.datasets.ftfile import FTDataSet

dataset_str = 'P07_'  # NISTP # 'P07safe_'

# Alternative locations (kept for reference):
#base_path = '/data/lisatmp/ift6266h10/data/'+dataset_str
#base_output_path = '/data/lisatmp/ift6266h10/data/transformed_digits/'+dataset_str+'train'

base_path = '/data/lisa/data/ift6266h10/data/' + dataset_str
base_output_path = ('/data/lisatmp/ift6266h10/data/transformed_digits/'
                    + dataset_str + 'train')

NUM_FILES = 100          # number of training shards: train0 .. train99
MAX_DIGIT_LABEL = 9      # examples with label <= this value are kept
REPORT_EVERY = 100000    # print a progress line every this many examples


def _write_ft(path, array):
    """Write *array* to *path* in filetensor format, always closing the file."""
    with open(path, 'wb') as f:
        ft.write(f, array)


for fileno in range(NUM_FILES):
    # Single-argument print(...) behaves the same under Python 2 and 3.
    print("Processing file no  %d" % fileno)

    input_data_file = base_path + 'train' + str(fileno) + '_data.ft'
    input_labels_file = base_path + 'train' + str(fileno) + '_labels.ft'
    output_data_file = base_output_path + str(fileno) + '_data.ft'
    output_labels_file = base_output_path + str(fileno) + '_labels.ft'

    print("Reading from  %s" % input_data_file)

    # Only the training shard changes per iteration; test/valid files are
    # passed because the constructor is given them, but they are not read
    # here. indtype/inscale/maxsize are deliberately not passed: the intent is
    # no conversion or scaling, i.e. keep the data as is.
    ds = FTDataSet(train_data=[input_data_file],
                   train_lbl=[input_labels_file],
                   test_data=[base_path + '_test_data.ft'],
                   test_lbl=[base_path + '_test_labels.ft'],
                   valid_data=[base_path + '_valid_data.ft'],
                   valid_lbl=[base_path + '_valid_labels.ft'])

    all_x = []
    all_y = []
    all_count = 0  # number of examples seen in this shard (kept or not)

    # ds.train(1) presumably yields (x, y) minibatches of size 1 -- verify
    # against FTDataSet; [0] takes the single example out of each minibatch.
    for mb_x, mb_y in ds.train(1):
        if mb_y[0] <= MAX_DIGIT_LABEL:
            all_x.append(mb_x[0])
            all_y.append(mb_y[0])

        all_count += 1
        if all_count % REPORT_EVERY == 0:
            print("Done next 100k")

    # Data is assumed to be stored as uint8 on 0-255 (no scaling requested
    # above); values outside that range would be wrapped by this cast.
    merged_x = numpy.asarray(all_x, dtype=numpy.uint8)
    merged_y = numpy.asarray(all_y, dtype=numpy.int32)

    print("Kept %s (shape  %s ) examples from %s"
          % (len(all_x), merged_x.shape, all_count))

    _write_ft(output_data_file, merged_x)
    _write_ft(output_labels_file, merged_y)

    # Release this shard's buffers before the next shard is loaded, so two
    # shards' worth of data are not held in memory at the same time.
    del ds, all_x, all_y, merged_x, merged_y