diff data_generation/pipeline/filter_nist.py @ 633:13baba8a4522

merge
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Sat, 19 Mar 2011 22:51:40 -0400
parents 75dbbe409578
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/data_generation/pipeline/filter_nist.py	Sat Mar 19 22:51:40 2011 -0400
@@ -0,0 +1,62 @@
+import numpy
+from pylearn.io import filetensor as ft
+from ift6266 import datasets
+from ift6266.datasets.ftfile import FTDataSet
+
+dataset_str = 'P07_' # NISTP # 'P07safe_' 
+
+#base_path = '/data/lisatmp/ift6266h10/data/'+dataset_str
+#base_output_path = '/data/lisatmp/ift6266h10/data/transformed_digits/'+dataset_str+'train'
+
+base_path = '/data/lisa/data/ift6266h10/data/'+dataset_str
+base_output_path = '/data/lisatmp/ift6266h10/data/transformed_digits/'+dataset_str+'train'
+
+for fileno in range(100):
+    print "Processing file no ", fileno
+
+    output_data_file = base_output_path+str(fileno)+'_data.ft'
+    output_labels_file = base_output_path+str(fileno)+'_labels.ft'
+
+    print "Reading from ",base_path+'train'+str(fileno)+'_data.ft'
+
+    dataset = lambda maxsize=None, min_file=0, max_file=100: \
+                    FTDataSet(train_data = [base_path+'train'+str(fileno)+'_data.ft'],
+                       train_lbl = [base_path+'train'+str(fileno)+'_labels.ft'],
+                       test_data = [base_path+'_test_data.ft'],
+                       test_lbl = [base_path+'_test_labels.ft'],
+                       valid_data = [base_path+'_valid_data.ft'],
+                       valid_lbl = [base_path+'_valid_labels.ft'])
+                       # no conversion or scaling... keep data as is
+                       #indtype=theano.config.floatX, inscale=255., maxsize=maxsize)
+
+    ds = dataset()
+
+    all_x = []
+    all_y = []
+
+    all_count = 0
+
+    for mb_x,mb_y in ds.train(1):
+        if mb_y[0] <= 9:
+            all_x.append(mb_x[0])
+            all_y.append(mb_y[0])
+
+        if (all_count+1) % 100000 == 0:
+            print "Done next 100k"
+
+        all_count += 1
+   
+    # data is stored as uint8 on 0-255
+    merged_x = numpy.asarray(all_x, dtype=numpy.uint8)
+    merged_y = numpy.asarray(all_y, dtype=numpy.int32)
+
+    print "Kept", len(all_x), "(shape ", merged_x.shape, ") examples from", all_count
+
+    f = open(output_data_file, 'wb')
+    ft.write(f, merged_x)
+    f.close()
+
+    f = open(output_labels_file, 'wb')
+    ft.write(f, merged_y)
+    f.close()
+