changeset 262:716c99f4eb3a

merge
author Xavier Glorot <glorotxa@iro.umontreal.ca>
date Wed, 17 Mar 2010 16:41:51 -0400
parents 6d16a2bf142b (current diff) 0c0f0b3f6a93 (diff)
children a0264184684e 7800be7bce66
files
diffstat 4 files changed, 120 insertions(+), 122 deletions(-) [+]
line wrap: on
line diff
--- a/datasets/defs.py	Wed Mar 17 16:41:16 2010 -0400
+++ b/datasets/defs.py	Wed Mar 17 16:41:51 2010 -0400
@@ -11,44 +11,45 @@
 NIST_PATH = os.getenv('NIST_PATH','/data/lisa/data/nist/by_class/')
 DATA_PATH = os.getenv('DATA_PATH','/data/lisa/data/ift6266h10/')
 
-nist_digits = FTDataSet(train_data = [os.path.join(NIST_PATH,'digits/digits_train_data.ft')],
+nist_digits = lambda maxsize=None: FTDataSet(train_data = [os.path.join(NIST_PATH,'digits/digits_train_data.ft')],
                         train_lbl = [os.path.join(NIST_PATH,'digits/digits_train_labels.ft')],
                         test_data = [os.path.join(NIST_PATH,'digits/digits_test_data.ft')],
                         test_lbl = [os.path.join(NIST_PATH,'digits/digits_test_labels.ft')],
-                        indtype=theano.config.floatX, inscale=255.)
-nist_lower = FTDataSet(train_data = [os.path.join(NIST_PATH,'lower/lower_train_data.ft')],
+                        indtype=theano.config.floatX, inscale=255., maxsize=maxsize)
+nist_lower = lambda maxsize=None: FTDataSet(train_data = [os.path.join(NIST_PATH,'lower/lower_train_data.ft')],
                         train_lbl = [os.path.join(NIST_PATH,'lower/lower_train_labels.ft')],
                         test_data = [os.path.join(NIST_PATH,'lower/lower_test_data.ft')],
                         test_lbl = [os.path.join(NIST_PATH,'lower/lower_test_labels.ft')],
-                        indtype=theano.config.floatX, inscale=255.)
-nist_upper = FTDataSet(train_data = [os.path.join(NIST_PATH,'upper/upper_train_data.ft')],
+                        indtype=theano.config.floatX, inscale=255., maxsize=maxsize)
+nist_upper = lambda maxsize=None: FTDataSet(train_data = [os.path.join(NIST_PATH,'upper/upper_train_data.ft')],
                         train_lbl = [os.path.join(NIST_PATH,'upper/upper_train_labels.ft')],
                         test_data = [os.path.join(NIST_PATH,'upper/upper_test_data.ft')],
                         test_lbl = [os.path.join(NIST_PATH,'upper/upper_test_labels.ft')],
-                        indtype=theano.config.floatX, inscale=255.)
+                        indtype=theano.config.floatX, inscale=255., maxsize=maxsize)
 
-nist_all = FTDataSet(train_data = [os.path.join(DATA_PATH,'train_data.ft')],
+nist_all = lambda maxsize=None: FTDataSet(train_data = [os.path.join(DATA_PATH,'train_data.ft')],
                      train_lbl = [os.path.join(DATA_PATH,'train_labels.ft')],
                      test_data = [os.path.join(DATA_PATH,'test_data.ft')],
                      test_lbl = [os.path.join(DATA_PATH,'test_labels.ft')],
                      valid_data = [os.path.join(DATA_PATH,'valid_data.ft')],
                      valid_lbl = [os.path.join(DATA_PATH,'valid_labels.ft')],
-                     indtype=theano.config.floatX, inscale=255.)
+                     indtype=theano.config.floatX, inscale=255., maxsize=maxsize)
 
-ocr = FTDataSet(train_data = [os.path.join(DATA_PATH,'ocr_train_data.ft')],
+ocr = lambda maxsize=None: FTDataSet(train_data = [os.path.join(DATA_PATH,'ocr_train_data.ft')],
                 train_lbl = [os.path.join(DATA_PATH,'ocr_train_labels.ft')],
                 test_data = [os.path.join(DATA_PATH,'ocr_test_data.ft')],
                 test_lbl = [os.path.join(DATA_PATH,'ocr_test_labels.ft')],
                 valid_data = [os.path.join(DATA_PATH,'ocr_valid_data.ft')],
                 valid_lbl = [os.path.join(DATA_PATH,'ocr_valid_labels.ft')],
-                indtype=theano.config.floatX, inscale=255.)
+                indtype=theano.config.floatX, inscale=255., maxsize=maxsize)
 
-nist_P07 = FTDataSet(train_data = [os.path.join(DATA_PATH,'data/P07_train'+str(i)+'_data.ft') for i in range(100)],
+nist_P07 = lambda maxsize=None: FTDataSet(train_data = [os.path.join(DATA_PATH,'data/P07_train'+str(i)+'_data.ft') for i in range(100)],
                      train_lbl = [os.path.join(DATA_PATH,'data/P07_train'+str(i)+'_labels.ft') for i in range(100)],
                      test_data = [os.path.join(DATA_PATH,'data/P07_test_data.ft')],
                      test_lbl = [os.path.join(DATA_PATH,'data/P07_test_labels.ft')],
                      valid_data = [os.path.join(DATA_PATH,'data/P07_valid_data.ft')],
                      valid_lbl = [os.path.join(DATA_PATH,'data/P07_valid_labels.ft')],
-                     indtype=theano.config.floatX, inscale=255.)
+                     indtype=theano.config.floatX, inscale=255., maxsize=maxsize)
 
-mnist = GzpklDataSet(os.path.join(DATA_PATH,'mnist.pkl.gz'))
+mnist = lambda maxsize=None: GzpklDataSet(os.path.join(DATA_PATH,'mnist.pkl.gz'),
+                                          maxsize=maxsize)
--- a/datasets/ftfile.py	Wed Mar 17 16:41:16 2010 -0400
+++ b/datasets/ftfile.py	Wed Mar 17 16:41:51 2010 -0400
@@ -89,57 +89,58 @@
         return res
 
 class FTSource(object):
-    def __init__(self, file, skip=0, size=None, dtype=None, scale=1):
+    def __init__(self, file, skip=0, size=None, maxsize=None, 
+                 dtype=None, scale=1):
         r"""
         Create a data source from a possible subset of a .ft file.
 
         Parameters:
-            `file` (string) -- the filename
-            `skip` (int, optional) -- amount of examples to skip from
-                                      the start of the file.  If
-                                      negative, skips filesize - skip.
-            `size` (int, optional) -- truncates number of examples
-                                      read (after skipping).  If
-                                      negative truncates to 
-                                      filesize - size 
-                                      (also after skipping).
-            `dtype` (dtype, optional) -- convert the data to this
-                                         dtype after reading.
-            `scale` (number, optional) -- scale (that is divide) the
-                                          data by this number (after
-                                          dtype conversion, if any).
+            `file` -- (string) the filename
+            `skip` -- (int, optional) amount of examples to skip from
+                      the start of the file.  If negative, skips
+                      filesize - skip.
+            `size` -- (int, optional) truncates number of examples
+                      read (after skipping).  If negative truncates to
+                      filesize - size (also after skipping).
+            `maxsize` -- (int, optional) the maximum size of the file
+            `dtype` -- (dtype, optional) convert the data to this
+                       dtype after reading.
+            `scale` -- (number, optional) scale (that is divide) the
+                       data by this number (after dtype conversion, if
+                       any).
 
         Tests:
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000)
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10)
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120)
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000)
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10)
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120)
         """
         self.file = file
         self.skip = skip
         self.size = size
         self.dtype = dtype
         self.scale = scale
+        self.maxsize = maxsize
     
     def open(self):
         r"""
         Returns an FTFile that corresponds to this dataset.
         
         Tests:
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
-           >>> f = s.open()
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1)
-           >>> len(s.open().read(2))
-           1
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646)
-           >>> s.open().size
-           1000
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1)
-           >>> s.open().size
-           1
-           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=-10)
-           >>> s.open().size
-           58636
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
+        >>> f = s.open()
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1)
+        >>> len(s.open().read(2))
+        1
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646)
+        >>> s.open().size
+        1000
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1)
+        >>> s.open().size
+        1
+        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=-10)
+        >>> s.open().size
+        58636
         """
         f = FTFile(self.file, scale=self.scale, dtype=self.dtype)
         if self.skip != 0:
@@ -147,19 +148,25 @@
         if self.size is not None and self.size < f.size:
             if self.size < 0:
                 f.size += self.size
+                if f.size < 0:
+                    f.size = 0
             else:
                 f.size = self.size
+        if self.maxsize is not None and f.size > self.maxsize:
+            f.size = self.maxsize
         return f
 
 class FTData(object):
     r"""
     This is a list of FTSources.
     """
-    def __init__(self, datafiles, labelfiles, skip=0, size=None,
+    def __init__(self, datafiles, labelfiles, skip=0, size=None, maxsize=None,
                  inscale=1, indtype=None, outscale=1, outdtype=None):
-        self.inputs = [FTSource(f, skip, size, scale=inscale, dtype=indtype)
+        if maxsize is not None:
+            maxsize /= len(datafiles)
+        self.inputs = [FTSource(f, skip, size, maxsize, scale=inscale, dtype=indtype)
                        for f in  datafiles]
-        self.outputs = [FTSource(f, skip, size, scale=outscale, dtype=outdtype)
+        self.outputs = [FTSource(f, skip, size, maxsize, scale=outscale, dtype=outdtype)
                         for f in labelfiles]
 
     def open_inputs(self):
@@ -170,7 +177,9 @@
     
 
 class FTDataSet(DataSet):
-    def __init__(self, train_data, train_lbl, test_data, test_lbl, valid_data=None, valid_lbl=None, indtype=None, outdtype=None, inscale=1, outscale=1):
+    def __init__(self, train_data, train_lbl, test_data, test_lbl, 
+                 valid_data=None, valid_lbl=None, indtype=None, outdtype=None,
+                 inscale=1, outscale=1, maxsize=None):
         r"""
         Defines a DataSet from a bunch of files.
         
@@ -184,6 +193,7 @@
                                            (optional)
            `indtype`, `outdtype`,  -- see FTSource.__init__()
            `inscale`, `outscale`      (optional)
+           `maxsize` -- maximum size of the set returned
                                                              
 
         If `valid_data` and `valid_labels` are not supplied then a sample
@@ -191,21 +201,26 @@
         set.
         """
         if valid_data is None:
-            total_valid_size = sum(FTFile(td).size for td in test_data)
+            total_valid_size = min(sum(FTFile(td).size for td in test_data), maxsize)
             valid_size = total_valid_size/len(train_data)
             self._train = FTData(train_data, train_lbl, size=-valid_size,
-                    inscale=inscale, outscale=outscale, indtype=indtype,
-                    outdtype=outdtype)
+                                 inscale=inscale, outscale=outscale,
+                                 indtype=indtype, outdtype=outdtype,
+                                 maxsize=maxsize)
             self._valid = FTData(train_data, train_lbl, skip=-valid_size,
-                    inscale=inscale, outscale=outscale, indtype=indtype, 
-                    outdtype=outdtype)
+                                 inscale=inscale, outscale=outscale,
+                                 indtype=indtype, outdtype=outdtype,
+                                 maxsize=maxsize)
         else:
-            self._train = FTData(train_data, train_lbl,inscale=inscale,
-                    outscale=outscale, indtype=indtype, outdtype=outdtype)
-            self._valid = FTData(valid_data, valid_lbl,inscale=inscale,
-                    outscale=outscale, indtype=indtype, outdtype=outdtype)
-        self._test = FTData(test_data, test_lbl,inscale=inscale,
-                outscale=outscale, indtype=indtype, outdtype=outdtype)
+            self._train = FTData(train_data, train_lbl, maxsize=maxsize,
+                                 inscale=inscale, outscale=outscale, 
+                                 indtype=indtype, outdtype=outdtype)
+            self._valid = FTData(valid_data, valid_lbl, maxsize=maxsize,
+                                 inscale=inscale, outscale=outscale,
+                                 indtype=indtype, outdtype=outdtype)
+        self._test = FTData(test_data, test_lbl, maxsize=maxsize,
+                            inscale=inscale, outscale=outscale,
+                            indtype=indtype, outdtype=outdtype)
 
     def _return_it(self, batchsize, bufsize, ftdata):
         return izip(DataIterator(ftdata.open_inputs(), batchsize, bufsize),
--- a/datasets/gzpklfile.py	Wed Mar 17 16:41:16 2010 -0400
+++ b/datasets/gzpklfile.py	Wed Mar 17 16:41:51 2010 -0400
@@ -19,8 +19,9 @@
         return res
 
 class GzpklDataSet(DataSet):
-    def __init__(self, fname):
+    def __init__(self, fname, maxsize):
         self._fname = fname
+        self.maxsize = maxsize
         self._train = 0
         self._valid = 1
         self._test = 2
@@ -35,5 +36,5 @@
     def _return_it(self, batchsz, bufsz, id):
         if not hasattr(self, 'datas'):
             self._load()
-        return izip(DataIterator([ArrayFile(self.datas[id][0])], batchsz, bufsz),
-                    DataIterator([ArrayFile(self.datas[id][1])], batchsz, bufsz))
+        return izip(DataIterator([ArrayFile(self.datas[id][0][:self.maxsize])], batchsz, bufsz),
+                    DataIterator([ArrayFile(self.datas[id][1][:self.maxsize])], batchsz, bufsz))
--- a/deep/convolutional_dae/stacked_convolutional_dae.py	Wed Mar 17 16:41:16 2010 -0400
+++ b/deep/convolutional_dae/stacked_convolutional_dae.py	Wed Mar 17 16:41:51 2010 -0400
@@ -4,24 +4,21 @@
 import sys
 import theano.tensor as T
 from theano.tensor.shared_randomstreams import RandomStreams
-import theano.sandbox.softsign
+#import theano.sandbox.softsign
 
 from theano.tensor.signal import downsample
 from theano.tensor.nnet import conv 
 
-sys.path.append('../../../')
-
 from ift6266 import datasets
 from ift6266.baseline.log_reg.log_reg import LogisticRegression
 
 batch_size = 100
 
-
 class SigmoidalLayer(object):
     def __init__(self, rng, input, n_in, n_out):
 
         self.input = input
-
+ 
         W_values = numpy.asarray( rng.uniform( \
               low = -numpy.sqrt(6./(n_in+n_out)), \
               high = numpy.sqrt(6./(n_in+n_out)), \
@@ -37,7 +34,8 @@
 class dA_conv(object):
  
   def __init__(self, input, filter_shape, corruption_level = 0.1, 
-               shared_W = None, shared_b = None, image_shape = None):
+               shared_W = None, shared_b = None, image_shape = None, 
+               poolsize = (2,2)):
 
     theano_rng = RandomStreams()
     
@@ -69,18 +67,12 @@
     self.tilde_x = theano_rng.binomial( self.x.shape, 1, 1 - corruption_level,dtype=theano.config.floatX) * self.x
 
     conv1_out = conv.conv2d(self.tilde_x, self.W, filter_shape=filter_shape,
-                            image_shape=image_shape,
-                            unroll_kern=4,unroll_batch=4, 
-                            border_mode='valid')
-
+                            image_shape=image_shape, border_mode='valid')
     
     self.y = T.tanh(conv1_out + self.b.dimshuffle('x', 0, 'x', 'x'))
-
     
-    da_filter_shape = [ filter_shape[1], filter_shape[0], filter_shape[2],\
-                       filter_shape[3] ]
-    da_image_shape = [ batch_size, filter_shape[0], image_shape[2]-filter_shape[2]+1, image_shape[3]-filter_shape[3]+1 ]
-    #import pdb; pdb.set_trace()
+    da_filter_shape = [ filter_shape[1], filter_shape[0], 
+                        filter_shape[2], filter_shape[3] ]
     initial_W_prime =  numpy.asarray( numpy.random.uniform( \
               low = -numpy.sqrt(6./(fan_in+fan_out)), \
               high = numpy.sqrt(6./(fan_in+fan_out)), \
@@ -88,9 +80,7 @@
     self.W_prime = theano.shared(value = initial_W_prime, name = "W_prime")
 
     conv2_out = conv.conv2d(self.y, self.W_prime,
-                            filter_shape = da_filter_shape,\
-                            image_shape = da_image_shape, \
-                            unroll_kern=4,unroll_batch=4, \
+                            filter_shape = da_filter_shape,
                             border_mode='full')
 
     self.z =  (T.tanh(conv2_out + self.b_prime.dimshuffle('x', 0, 'x', 'x'))+center) / scale
@@ -115,8 +105,7 @@
         self.b = theano.shared(value=b_values)
  
         conv_out = conv.conv2d(input, self.W,
-                filter_shape=filter_shape, image_shape=image_shape,
-                               unroll_kern=4,unroll_batch=4)
+                filter_shape=filter_shape, image_shape=image_shape)
  
 
         fan_in = numpy.prod(filter_shape[1:])
@@ -137,7 +126,7 @@
 class SdA():
     def __init__(self, input, n_ins_mlp, conv_hidden_layers_sizes,
                  mlp_hidden_layers_sizes, corruption_levels, rng, n_out, 
-                 pretrain_lr, finetune_lr):
+                 pretrain_lr, finetune_lr, img_shape):
         
         self.layers = []
         self.pretrain_functions = []
@@ -154,7 +143,7 @@
             max_poolsize=conv_hidden_layers_sizes[i][2]
                 
             if i == 0 :
-                layer_input=self.x.reshape((batch_size, 1, 32, 32))
+                layer_input=self.x.reshape((self.x.shape[0], 1) + img_shape)
             else:
                 layer_input=self.layers[-1].output
             
@@ -170,7 +159,7 @@
             da_layer = dA_conv(corruption_level = corruption_levels[0],
                                input = layer_input,
                                shared_W = layer.W, shared_b = layer.b,
-                               filter_shape=filter_shape,
+                               filter_shape = filter_shape,
                                image_shape = image_shape )
             
             gparams = T.grad(da_layer.cost, da_layer.params)
@@ -221,13 +210,13 @@
         
         self.errors = self.logLayer.errors(self.y)
 
-def sgd_optimization_mnist( learning_rate=0.1, pretraining_epochs = 0, \
-                            pretrain_lr = 0.1, training_epochs = 1000, \
-                            kernels = [ [2,5,5] , [2,3,3] ], mlp_layers=[500], \
-                            corruption_levels = [ 0.2, 0.2, 0.2], batch_size = batch_size, \
-                            max_pool_layers = [ [2,2] , [2,2] ], \
-                            dataset=datasets.nist_digits):
-    
+def sgd_optimization_mnist(learning_rate=0.1, pretraining_epochs = 1,
+                           pretrain_lr = 0.1, training_epochs = 1000,
+                           kernels = [[4,5,5], [4,3,3]], mlp_layers=[500],
+                           corruption_levels = [0.2, 0.2, 0.2], 
+                           batch_size = batch_size, img_shape=(28, 28),
+                           max_pool_layers = [[2,2], [2,2]],
+                           dataset=datasets.mnist(5000)):
  
     # allocate symbolic variables for the data
     index = T.lscalar() # index to a [mini]batch
@@ -235,31 +224,32 @@
     y = T.ivector('y') # the labels are presented as 1d vector of
     # [int] labels
 
-    layer0_input = x.reshape((batch_size,1,32,32))
+    layer0_input = x.reshape((x.shape[0],1)+img_shape)
     
     rng = numpy.random.RandomState(1234)
     conv_layers=[]
-    init_layer = [ [ kernels[0][0],1,kernels[0][1],kernels[0][2] ],\
-                   [ batch_size , 1, 32, 32 ],    
-                    max_pool_layers[0] ]
+    init_layer = [[kernels[0][0],1,kernels[0][1],kernels[0][2]],
+                  None, # do not specify the batch size since it can 
+                        # change for the last one and then theano will 
+                        # crash.
+                  max_pool_layers[0]]
     conv_layers.append(init_layer)
 
-    conv_n_out = int((32-kernels[0][2]+1)/max_pool_layers[0][0])
-    print init_layer[1]
-    
+    conv_n_out = (img_shape[0]-kernels[0][2]+1)/max_pool_layers[0][0]
+
     for i in range(1,len(kernels)):    
-        layer = [ [ kernels[i][0],kernels[i-1][0],kernels[i][1],kernels[i][2] ],\
-                  [ batch_size, kernels[i-1][0],conv_n_out,conv_n_out ],    
-                   max_pool_layers[i] ]
+        layer = [[kernels[i][0],kernels[i-1][0],kernels[i][1],kernels[i][2]],
+                 None, # same comment as for init_layer
+                 max_pool_layers[i] ]
         conv_layers.append(layer)
-        conv_n_out = int( (conv_n_out - kernels[i][2]+1)/max_pool_layers[i][0])
-        print layer [1]
+        conv_n_out =  (conv_n_out - kernels[i][2]+1)/max_pool_layers[i][0]
+
     network = SdA(input = layer0_input, n_ins_mlp = kernels[-1][0]*conv_n_out**2,
                   conv_hidden_layers_sizes = conv_layers,
                   mlp_hidden_layers_sizes = mlp_layers,
-                  corruption_levels = corruption_levels , n_out = 62,
-                  rng = rng , pretrain_lr = pretrain_lr ,
-                  finetune_lr = learning_rate )
+                  corruption_levels = corruption_levels, n_out = 62,
+                  rng = rng , pretrain_lr = pretrain_lr,
+                  finetune_lr = learning_rate, img_shape=img_shape)
 
     test_model = theano.function([network.x, network.y], network.errors)
  
@@ -267,9 +257,7 @@
     for i in xrange(len(network.layers)-len(mlp_layers)):
         for epoch in xrange(pretraining_epochs):
             for x, y in dataset.train(batch_size):
-                if x.shape[0] == batch_size:
-                    c = network.pretrain_functions[i](x)
-
+                c = network.pretrain_functions[i](x)
             print 'pre-training convolution layer %i, epoch %d, cost '%(i,epoch), c
 
     patience = 10000 # look as this many examples regardless
@@ -291,16 +279,12 @@
     while (epoch < training_epochs) and (not done_looping):
       epoch = epoch + 1
       for x, y in dataset.train(batch_size):
-        if x.shape[0] != batch_size:
-            continue
+ 
         cost_ij = network.finetune(x, y)
         iter += 1
         
         if iter % validation_frequency == 0:
-            validation_losses = []
-            for xv, yv in dataset.valid(batch_size):
-                if xv.shape[0] == batch_size:
-                    validation_losses.append(test_model(xv, yv))
+            validation_losses = [test_model(xv, yv) for xv, yv in dataset.valid(batch_size)]
             this_validation_loss = numpy.mean(validation_losses)
             print('epoch %i, iter %i, validation error %f %%' % \
                    (epoch, iter, this_validation_loss*100.))
@@ -318,10 +302,7 @@
                 best_iter = iter
                 
                 # test it on the test set
-                test_losses=[]
-                for xt, yt in dataset.test(batch_size):
-                    if xt.shape[0] == batch_size:
-                        test_losses.append(test_model(xt, yt))
+                test_losses = [test_model(xt, yt) for xt, yt in dataset.test(batch_size)]
                 test_score = numpy.mean(test_losses)
                 print((' epoch %i, iter %i, test error of best '
                       'model %f %%') %