changeset 1283:a73db8d65abb

cifar10 op - added an op for generating whitened patches
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 15 Sep 2010 17:46:21 -0400
parents f36f59e53c28
children 1817485d586d
files pylearn/dataset_ops/cifar10.py
diffstat 1 files changed, 70 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/dataset_ops/cifar10.py	Wed Sep 15 17:46:03 2010 -0400
+++ b/pylearn/dataset_ops/cifar10.py	Wed Sep 15 17:46:21 2010 -0400
@@ -160,6 +160,76 @@
 
 nclasses = 10
 
+import pylearn.datasets.image_patches
+import pylearn.preprocessing.pca
+
+@memo
+def random_cifar_patches(dtype, N,R,C, centered=True):
+    #These used to be arguments, but optional arguments don't work well with the cache
+    # because the cache doesn't [yet] look up what they are
+    rng_seed=89234
+    channel_rank=2
+
+    rng=numpy.random.RandomState(rng_seed)
+    imgs = train_data_labels(dtype)[0][:40000].reshape((40000,3,32,32)).transpose((0,2,3,1))
+    forget() #un-cache the original images
+    #import pdb; pdb.set_trace()
+    patches = pylearn.datasets.image_patches.extract_random_patches(imgs, N,R,C, rng)
+    orig_shape = patches.shape
+
+    # center individual examples
+    patches = patches.reshape((orig_shape[0], R*C*3))
+    patches -= patches.mean(axis=1).reshape((N, 1))
+    patches = patches.reshape(orig_shape)
+
+    if channel_rank==4:
+        pass
+    elif channel_rank==2:
+        # put the channels the cifar10 way :/
+        patches = patches.transpose((0,3,1,2)).copy()
+    else:
+        raise NotImplementedError()
+    if centered:
+        patches -= patches.mean(axis=0)
+    return patches
+
+@memo
+def random_cifar_patches_pca(max_components, max_energy_fraction, dtype, N,R,C,*args):
+    pca, _ = pylearn.preprocessing.pca.pca_from_examples(
+            random_cifar_patches(dtype,N,R,C,*args).reshape((N,R*C*3)),
+            max_components, max_energy_fraction, x_centered=True)
+    return pca
+
+@memo
+def whitened_random_cifar_patches(max_components, max_energy_fraction, dtype,N,R,C,*args):
+    pca = random_cifar_patches_pca(max_components, max_energy_fraction, dtype,N,R,C,*args)
+    patches = random_cifar_patches(dtype,N,R,C,*args).reshape((N,R*C*3))
+    random_cifar_patches.forget() #un-cache the original patches
+    return pylearn.preprocessing.pca.pca_whiten(pca, patches).astype(dtype)
+
+def cifar10_patches(s_idx, split, dtype='float32', rasterized=True, color='rgb', 
+        n_patches=1000, patch_size=(8,8), pca_components=80):
+    """
+    Return
+    """
+    if split != 'train': raise NotImplementedError()
+    if dtype != 'float32':raise NotImplementedError()
+    if color != 'rgb': raise NotImplementedError()
+    if s_idx.ndim != 1: raise NotImplementedError()
+
+    x_op = TensorFnDataset(dtype, (False,), 
+            (whitened_random_cifar_patches, (
+                pca_components,None,dtype,n_patches, patch_size[0], patch_size[1])),
+            (patch_size[0],patch_size[1],3))
+    x = x_op(s_idx%n_patches)
+
+    if rasterized:
+        x = x.flatten(2)
+    else:
+        raise NotImplementedError()
+
+    return x
+
 def glviewer(split='train'):
     from glviewer import GlViewer
     i = theano.tensor.iscalar()