changeset 1277:bef6c5f565cd

image_patches - extended extract_random_patches to handle rgb image stacks
author James Bergstra <bergstrj@iro.umontreal.ca>
date Tue, 14 Sep 2010 22:37:00 -0400
parents 822c7691a759
children c9ec065ff736
files pylearn/datasets/image_patches.py
diffstat 1 files changed, 15 insertions(+), 11 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/datasets/image_patches.py	Tue Sep 14 22:35:51 2010 -0400
+++ b/pylearn/datasets/image_patches.py	Tue Sep 14 22:37:00 2010 -0400
@@ -83,7 +83,7 @@
 def extract_random_patches(img_stack, N, R,C, rng):
     """Return subimages from the img_stack
 
-    :param img_stack: a 3D ndarray (n_images, rows, cols) or a list of 2D images.
+    :param img_stack: a 3D[4D] ndarray (n_images, rows, cols,[channels]) or a list of images.
     :param N: number of patches to extract
     :param R: number of rows in patch
     :param C: number of cols in patch
@@ -94,21 +94,25 @@
     image in the stack therefore would be sampled less frequently than patches from a smaller
     image in the stack.
 
-    To use ZCA whitening, extract patches from the raw data, and pass it to
-    preprocessing.pca.zca_whitening.
+    :hint:
+        To use ZCA whitening, extract patches from the raw data, and pass it to
+        preprocessing.pca.zca_whitening.
     """
-    rval = numpy.empty((N,R,C), dtype=img_stack[0].dtype)
+
+    rval = numpy.empty((N,R,C)+img_stack.shape[3:], dtype=img_stack[0].dtype)
 
     L = len(img_stack)
     img_idxlist = rng.randint(L,size=N)
+    offsets_R = rng.randint(img_stack.shape[1]-R+1, size=N)
+    offsets_C = rng.randint(img_stack.shape[2]-C+1, size=N)
 
-    for n, img_idxlist in enumerate(img_idxlist):
-        img_idx = rng.randint(L)
-        img_n = img_stack[img_idx]
-        offset_R = rng.randint(img_n.shape[0]-R+1)
-        offset_C = rng.randint(img_n.shape[1]-C+1)
-        rval[n] = img_n[offset_R:offset_R+R,offset_C:offset_C+C]
-
+    for n, (l,r,c) in enumerate(zip(img_idxlist, offsets_R, offsets_C)):
+        #img_idx = rng.randint(L)
+        #offset_R = rng.randint(img_n.shape[0]-R+1)
+        #offset_C = rng.randint(img_n.shape[1]-C+1)
+        #img_n = img_stack[l]
+        #rval[n] = img_n[offset_R:offset_R+R,offset_C:offset_C+C]
+        rval[n] = img_stack[l,r:r+R,c:c+C]
     return rval