Mercurial > pylearn
changeset 832:67b92a42f86b
added dataset_ops
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Fri, 16 Oct 2009 12:04:05 -0400 |
parents | 43e726898cf9 |
children | 039e93a95c20 |
files | pylearn/dataset_ops/COIL100.py pylearn/dataset_ops/MNIST.py pylearn/dataset_ops/README.txt pylearn/dataset_ops/__init__.py pylearn/dataset_ops/gldataset.py pylearn/dataset_ops/glviewer.py pylearn/dataset_ops/memo.py pylearn/dataset_ops/protocol.py |
diffstat | 8 files changed, 1167 insertions(+), 0 deletions(-) [+] |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pylearn/dataset_ops/COIL100.py Fri Oct 16 12:04:05 2009 -0400 @@ -0,0 +1,62 @@ + +""" +http://www1.cs.columbia.edu/CAVE/software/softlib/coil-100.php + +"Columbia Object Image Library (COIL-100)," + S. A. Nene, S. K. Nayar and H. Murase, + Technical Report CUCS-006-96, February 1996. + +""" + +import os, cPickle +import Image, numpy +from pylearn.datasets.config import data_root # config + +from .memo import memo + +def filenames(): + root = os.path.join(data_root(), 'COIL-100', 'coil-100', ) + for filename in os.listdir(root): + yield filename, os.path.join(root,filename ) + +def filenameidx_imgidx(filename): + if filename.startswith("obj"): + obj_idx = int(filename[3:filename.index("_")]) + img_idx = int(filename[filename.index("_")+2:filename.index(".")]) + return obj_idx, img_idx + else: + raise ValueError(filename) + +_32x32grey_path = os.path.join(data_root(), "COIL-100", "dct_32x32_grey.pkl") +_32x32grey_header = "Dictionary of COIL-100 dataset at 32x32 resolution, greyscale" +def build_32x32_grey(): + f = file(_32x32grey_path, "w") + cPickle.dump(_32x32grey_header, f, protocol=cPickle.HIGHEST_PROTOCOL) + + dct = {} + for filename, fullname in filenames(): + if filename.startswith('obj'): + obj_idx, img_idx = filenameidx_imgidx(filename) + img = numpy.asarray(Image.open(fullname)) + dct.setdefault(obj_idx, {})[img_idx] = img.mean(axis=2)[::4,::4] + rval = numpy.empty((100, 72, 32, 32), dtype='float32') + rval[...] = -1 + for obj_id, dd in dct.iteritems(): + for img_id, v in dd.iteritems(): + rval[obj_id, img_id, :, :] = v + assert numpy.all(rval >= 0.0) + + cPickle.dump(rval, f, protocol=cPickle.HIGHEST_PROTOCOL) + f.close() + +@memo +def get_32x32_grey(): + f = file(_path_32x32_grey) + if _32x32grey_header != cPickle.load(f): + raise ValueError('wrong pickle file') + rval = cPickle.load(f) + f.close() + return rval + + +
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pylearn/dataset_ops/MNIST.py Fri Oct 16 12:04:05 2009 -0400 @@ -0,0 +1,126 @@ +"""Regular MNIST using the dataset protocol +""" +import os, numpy +import theano +from pylearn.datasets.config import data_root # config +from pylearn.io.ubyte import read_ubyte_matrix +from protocol import TensorFnDataset # protocol.py __init__.py +from .memo import memo + +@memo +def get_train_img_u8_rasterized(): + """Returns 60000 x 784 MNIST train set""" + return read_ubyte_matrix( + os.path.join(data_root(), 'mnist', 'train-images-idx3-ubyte'), + 60000, 784, 16, + write=False, align=True, as_dtype='uint8') +@memo +def get_test_img_u8_rasterized(): + """Returns 10000 x 784 MNIST test set""" + return read_ubyte_matrix( + os.path.join(data_root(), 'mnist', 't10k-images-idx3-ubyte'), + 10000, 784, 16, + write=False, align=True, as_dtype='uint8') +@memo +def get_train_labels(): + # these are actually uint8, but the nnet classif code is for ints. + return read_ubyte_matrix( + os.path.join(data_root(), 'mnist', 'train-labels-idx1-ubyte'), + 60000, 1, 8, + write=False, align=True, as_dtype='int32').reshape(60000) +@memo +def get_test_labels(): + # these are actually uint8, but the nnet classif code is for ints. + return read_ubyte_matrix( + os.path.join(data_root(), 'mnist', 't10k-labels-idx1-ubyte'), + 10000, 1, 8, + write=False, align=True, as_dtype='int32').reshape(10000) + +#This will cause both the uint8 version and the float version of the dataset to be cached. +# For larger datasets, it would be better to use Theano's cast(x, dtype) to do this conversion +# on the fly. +@memo +def get_train_img_f32_rasterized(): + return get_train_img_u8_rasterized() / numpy.asarray(255, dtype='float32') +@memo +def get_train_img_f64_rasterized(): + return get_train_img_u8_rasterized() / numpy.asarray(255, dtype='float64') +@memo +def get_test_img_f32_rasterized(): + return get_test_img_u8_rasterized() / numpy.asarray(255, dtype='float32') +@memo +def get_test_img_f64_rasterized(): + return get_test_img_u8_rasterized() / numpy.asarray(255, dtype='float64') + +#@constructor +def mnist(s_idx, split, dtype='float64', rasterized=False): + """ + :param s_idx: + + :param split: + + :param dtype: + + :param rasterized: return examples as vectors (True) or 28x28 matrices (False) + + """ + if split not in ('train', 'valid', 'test'): + raise ValueError('split should be train, valid, or test', split) + + if split == 'test': + l_fn = get_test_labels + if dtype == 'uint8': + i_fn = get_test_img_u8_rasterized + elif dtype == 'float32': + i_fn = get_test_img_f32_rasterized + elif dtype == 'float64': + i_fn = get_test_img_f64_rasterized + else: + raise ValueError('invalid dtype', dtype) + else: + l_fn = get_train_labels + if dtype == 'uint8': + i_fn = get_train_img_u8_rasterized + elif dtype == 'float32': + i_fn = get_train_img_f32_rasterized + elif dtype == 'float64': + i_fn = get_train_img_f64_rasterized + else: + raise ValueError('invalid dtype', dtype) + + if split == 'test': + idx = s_idx + elif split == 'train': + idx = s_idx % 50000 + else: #valid + idx = s_idx + 50000 + + x = TensorFnDataset(dtype, (False,), i_fn, (784,))(idx) + y = TensorFnDataset('int32', (), l_fn)(idx) + if x.ndim == 1: + if not rasterized: + x = x.reshape((28,28)) + elif x.ndim == 2: + if not rasterized: + x = x.reshape(x.shape[0], (28,28)) + else: + assert False, 'what happened?' + + return x, y +nclasses = 10 + +def glviewer(part='train'): + from glviewer import GlViewer + if part == 'train': + if 0: + #hack that doesn't use the op + x = get_train_img_u8_rasterized().reshape((60000, 28, 28)) + GlViewer(x.__getitem__).main() + else: + # test that uses the op + i = theano.tensor.iscalar() + f = theano.function([i], mnist(i, 'train', dtype='uint8', rasterized=False)) + GlViewer(f).main() + + +
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pylearn/dataset_ops/README.txt Fri Oct 16 12:04:05 2009 -0400 @@ -0,0 +1,13 @@ +The dataset_ops folder contains Theano Ops that provide dataset access to theano +programs. + +The protocol.py file sets out the basic convention that is followed by the Ops +in the other files. + +For an example of how to set up a dataset whose elements are slices from some +big underlying tensor, see MNIST.py. + +For an example of how to set up a dynamically-generated dataset, see +gldataset.py. + +
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pylearn/dataset_ops/__init__.py Fri Oct 16 12:04:05 2009 -0400 @@ -0,0 +1,4 @@ +import logging +logging.getLogger('dataset_ops').setLevel(logging.INFO) + +from protocol import Dataset, TensorDataset, TensorFnDataset # protocol.py
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pylearn/dataset_ops/gldataset.py Fri Oct 16 12:04:05 2009 -0400 @@ -0,0 +1,490 @@ +"""Demonstrate a complicated dynamically-generated dataset. +""" + +# __init__.py + +import sys, copy, logging, sys + +import Image #PIL + +from OpenGL.GL import * +from OpenGL.GLU import * +from OpenGL.GLUT import * +from pyglew import * + +from glviewer import load_texture + +import numpy + +import theano +from theano.compile.sandbox import shared +from theano.compile.sandbox import pfunc as function + +_logger = logging.getLogger('gldataset') +def debug(*msg): _logger.debug(' '.join(str(m) for m in msg)) +def info(*msg): _logger.info(' '.join(str(m) for m in msg)) +def warn(*msg): _logger.warn(' '.join(str(m) for m in msg)) +def warning(*msg): _logger.warning(' '.join(str(m) for m in msg)) +def error(*msg): _logger.error(' '.join(str(m) for m in msg)) + +def init_GL(shape=(64,64), title='Offscreen rendering using FB0'): + if not init_GL.done: + w, h = shape + init_GL.done = True + info('initializing OpenGl subsystem') + glutInit (sys.argv) + glutInitDisplayMode (GLUT_DOUBLE | GLUT_RGBA | GLUT_DEPTH) + glutInitWindowSize (w,h) + init_GL.window = glutCreateWindow (title) + glewInit() + + glEnable(GL_TEXTURE_2D) + glClearColor(0.0, 0.0, 0.0, 0.0) # This Will Clear The Background Color To Black + glClearDepth(1.0) # Enables Clearing Of The Depth Buffer + glDepthFunc(GL_LESS) # The Type Of Depth Test To Do + glEnable(GL_DEPTH_TEST) # Enables Depth Testing + glShadeModel(GL_SMOOTH) # Enables Smooth Color Shading + + #glMatrixMode(GL_PROJECTION) + #glLoadIdentity() # Reset The Projection Matrix + # Calculate The Aspect Ratio Of The Window + #gluPerspective(45.0, float(64)/float(64), 0.1, 100.0) + glMatrixMode(GL_MODELVIEW) +init_GL.done = False + +class PBufRenderer(object): + """Render an OpenGL program to a framebuffer instead of the screen. + + The way to use this class is to enclose all the OpenGL commands you want to render between + a call to setup() and a call to render(). So you would render a frame like this: + + .. code-block:: python + + p = PBufRenderer(shape) + p.setup() + my_display_code() + a = p.render() + my_display_code() + b = p.render() + + After running this code, 'a' and 'b' will be numpy arrays of shape `shape` + (3,) containing an + RBG rendering of your display_code. + """ + def __init__(self, size=(128,128), upside_down=False): + """ Offscreen rendering + + Save an offscreen rendering of size (w,h) to filename. + """ + + def round2 (n): + """ Get nearest power of two superior to n """ + f = 1 + while f<n: + f*= 2 + return f + + if size == None: + size = (512,512) + w = round2 (size[0]) + h = round2 (size[1]) + + image = Image.new ("RGB", (w, h), (0, 0, 0)) + bits = image.tostring("raw", "RGBX", 0, -1) + + debug('allocating framebuffer') + framebuffer = glGenFramebuffersEXT (1) + glBindFramebufferEXT(GL_FRAMEBUFFER_EXT, framebuffer) + + debug('allocating depthbuffer') + depthbuffer = glGenRenderbuffersEXT (1) + glBindRenderbufferEXT (GL_RENDERBUFFER_EXT,depthbuffer) + glRenderbufferStorageEXT (GL_RENDERBUFFER_EXT, GL_DEPTH_COMPONENT, w, h) + + # Create texture to render to + debug('allocating dynamic texture') + texture = glGenTextures (1) + glBindTexture (GL_TEXTURE_2D, texture) + glTexParameteri (GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR) + glTexParameteri (GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR) + # Question: why do we need to upload a texture that we are rendering *to* ? + glTexImage2D (GL_TEXTURE_2D, 0, GL_RGB, w, h, 0, + GL_RGB, GL_UNSIGNED_BYTE, bits) + + # store variables for later use. + self.texture = texture + self.framebuffer = framebuffer + self.depthbuffer = depthbuffer + self.image = image + self.bits = bits + self.size = size + self.texture_size = (w,h) + self.upside_down = upside_down + + # set the screen as output + glBindRenderbufferEXT (GL_RENDERBUFFER_EXT, 0) + glBindFramebufferEXT (GL_FRAMEBUFFER_EXT, 0) + + def __del__(self): + glBindRenderbufferEXT (GL_RENDERBUFFER_EXT, 0) + glBindFramebufferEXT (GL_FRAMEBUFFER_EXT, 0) + glDeleteTextures (1,[self.texture]) + glDeleteFramebuffersEXT (1, [self.framebuffer]) + glDeleteRenderbuffersExt (1, [self.depthbuffer]) + + def setup(self): + glBindRenderbufferEXT (GL_RENDERBUFFER_EXT, self.depthbuffer) + glBindFramebufferEXT(GL_FRAMEBUFFER_EXT, self.framebuffer) + glFramebufferTexture2DEXT (GL_FRAMEBUFFER_EXT, GL_COLOR_ATTACHMENT0_EXT, + GL_TEXTURE_2D, self.texture, 0); + glFramebufferRenderbufferEXT(GL_FRAMEBUFFER_EXT, GL_DEPTH_ATTACHMENT_EXT, + GL_RENDERBUFFER_EXT, self.depthbuffer); + + status = glCheckFramebufferStatusEXT (GL_FRAMEBUFFER_EXT); + if status != GL_FRAMEBUFFER_COMPLETE_EXT: + raise RuntimeError( "Error in framebuffer activation") + + # Re-orient viewport + glViewport (0, 0, self.size[0], self.size[1]) + glMatrixMode (GL_PROJECTION) + glLoadIdentity() + gluPerspective (40.,self.size[0]/float(self.size[1]),1.,40.) + glMatrixMode (GL_MODELVIEW) + glLoadIdentity() + gluLookAt (0,0,10, 0,0,0, 0,1,0) + + def render(self): + # TODO: Can we get away with glFlush? + glFinish() #renders to our framebuffer + + # read back the framebuffer to self.image + glBindTexture (GL_TEXTURE_2D, self.texture) + w,h = self.texture_size + data = glReadPixels (0, 0, w, h, GL_RGB, GL_UNSIGNED_BYTE) + rval = numpy.fromstring(data, dtype='uint8', count=w*h*3).reshape((w,h,3)) + if self.size != self.texture_size: + rval = rval[:self.size[0], :self.size[1],:] + + # return to default state of screen rendering + glBindRenderbufferEXT (GL_RENDERBUFFER_EXT, 0) + glBindFramebufferEXT(GL_FRAMEBUFFER_EXT, 0) + if self.upside_down: + return rval + else: + return rval[::-1,:,:] + +class OpenGlMovieFromImage(theano.Op): + """Helper base class to factor code used by Ops that want to make a movie from an input + image, using OpenGL. The subclass specifies how to actually make the movie. + """ + + def __init__(self, width, height, upside_down=False): + """To set up the renderer, we need to know the frame size of the images. + Setting up the renderer for each image is much slower. + """ + init_GL() #global initialization is no-op after first call + self.width=width + self.height=height + self.upside_down=upside_down + + self.renderer = None + # Delay construction of renderer until after merge-optimization + #PBufRenderer((width, height), upside_down=upside_down) + + #TODO: put texture into output state as reusable resource + self.texture = glGenTextures(1) + + def __del__(self): + glDeleteTextures (1,[self.texture]) + + def __eq__(self, other): + return type(self) == type(other) \ + and self.width == other.width \ + and self.height == other.height \ + and self.upside_down == other.upside_down + + def __hash__(self): + return hash(type(self)) ^ hash(self.width) ^ hash(self.height) ^ hash(self.upside_down) + + def make_node(self, x, istate): + _x = theano.tensor.as_tensor_variable(x) + if _x.type.dtype != 'uint8': + raise TypeError('must be 2- or 3-tensor of uint8', x) + if _x.type.broadcastable != (False, False) \ + and _x.type.broadcastable != (False, False, False): + raise TypeError('must be a 2- or 3-tensor of uint8', x) + if not isinstance(istate, theano.Variable): + raise TypeError("variable expected", istate) + o_type = theano.tensor.TensorType(dtype='uint8', broadcastable=[False, False, False, False]) + state_type = theano.gof.type.generic + return theano.Apply(self, [x, istate], [o_type(), state_type()]) + + def perform(self, node, (x, istate), (z_storage, ostate_storage)): + if self.renderer is None: + self.renderer = PBufRenderer((self.width, self.height), upside_down=self.upside_down) + + ostate = copy.deepcopy(istate) + self.renderer.setup() + + glBindTexture(GL_TEXTURE_2D, self.texture) # 2d texture (x and y size) + load_texture(x) + + z = numpy.zeros(self.z_shape, dtype='uint8') + for i in xrange(self.n_frames): + self.perform_set_state(istate, ostate, i) + self.perform_display(x, ostate, i) + di = self.renderer.render() + z[i] = di + + # store output images + z_storage[0] = z + + # store next state ostate_storage + ostate_storage[0] = ostate + +class ImageOnSpinningCube(OpenGlMovieFromImage): + def __init__(self, (n_frames, width, height), upside_down=False): + super(ImageOnSpinningCube, self).__init__(width, height, upside_down=upside_down) + self.n_frames = n_frames + self.z_shape = (n_frames, width, height, 3) + + def __eq__(self, other): + return super(ImageOnSpinningCube, self).__eq__(other) \ + and self.n_frames == other.n_frames \ + + def __hash__(self): + return super(ImageOnSpinningCube, self).__hash__() ^ hash(self.n_frames) + + def new_state(self, speed=10): + return dict( + rot=numpy.asarray((0.,0.,0.)), + drot=numpy.asarray((speed,speed,speed)), + ) + + def perform_set_state(self, istate, ostate, iter): + ostate['rot'] = istate['rot'] + istate['drot'] * iter + + def perform_display(self, x, ostate, i): + # retrieve some state variables related to rendering + xrot,yrot,zrot = ostate['rot'] + dxrot,dyrot,dzrot = ostate['drot'] + + # load x as a texture + glBindTexture(GL_TEXTURE_2D, self.texture) # 2d texture (x and y size) + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP) + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP) + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT) + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT) + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST) + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST) + glTexEnvf(GL_TEXTURE_ENV, GL_TEXTURE_ENV_MODE, GL_DECAL) + + glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT) # Clear The Screen And The Depth Buffer + glLoadIdentity() # Reset The View + glTranslatef(0.0,0.0,-5.0) # Move Into The Screen + + glRotatef(xrot,1.0,0.0,0.0) # Rotate The Cube On It's X Axis + glRotatef(yrot,0.0,1.0,0.0) # Rotate The Cube On It's Y Axis + glRotatef(zrot,0.0,0.0,1.0) # Rotate The Cube On It's Z Axis + + glBegin(GL_QUADS) # Start Drawing The Cube + + # Front Face (note that the texture's corners have to match the quad's corners) + glTexCoord2f(0.0, 0.0); glVertex3f(-1.0, -1.0, 1.0) # Bottom Left Of The Texture and Quad + glTexCoord2f(1.0, 0.0); glVertex3f( 1.0, -1.0, 1.0) # Bottom Right Of The Texture and Quad + glTexCoord2f(1.0, 1.0); glVertex3f( 1.0, 1.0, 1.0) # Top Right Of The Texture and Quad + glTexCoord2f(0.0, 1.0); glVertex3f(-1.0, 1.0, 1.0) # Top Left Of The Texture and Quad + + # Back Face + glTexCoord2f(1.0, 0.0); glVertex3f(-1.0, -1.0, -1.0) # Bottom Right Of The Texture and Quad + glTexCoord2f(1.0, 1.0); glVertex3f(-1.0, 1.0, -1.0) # Top Right Of The Texture and Quad + glTexCoord2f(0.0, 1.0); glVertex3f( 1.0, 1.0, -1.0) # Top Left Of The Texture and Quad + glTexCoord2f(0.0, 0.0); glVertex3f( 1.0, -1.0, -1.0) # Bottom Left Of The Texture and Quad + + # Top Face + glTexCoord2f(0.0, 1.0); glVertex3f(-1.0, 1.0, -1.0) # Top Left Of The Texture and Quad + glTexCoord2f(0.0, 0.0); glVertex3f(-1.0, 1.0, 1.0) # Bottom Left Of The Texture and Quad + glTexCoord2f(1.0, 0.0); glVertex3f( 1.0, 1.0, 1.0) # Bottom Right Of The Texture and Quad + glTexCoord2f(1.0, 1.0); glVertex3f( 1.0, 1.0, -1.0) # Top Right Of The Texture and Quad + + # Bottom Face + glTexCoord2f(1.0, 1.0); glVertex3f(-1.0, -1.0, -1.0) # Top Right Of The Texture and Quad + glTexCoord2f(0.0, 1.0); glVertex3f( 1.0, -1.0, -1.0) # Top Left Of The Texture and Quad + glTexCoord2f(0.0, 0.0); glVertex3f( 1.0, -1.0, 1.0) # Bottom Left Of The Texture and Quad + glTexCoord2f(1.0, 0.0); glVertex3f(-1.0, -1.0, 1.0) # Bottom Right Of The Texture and Quad + + # Right face + glTexCoord2f(1.0, 0.0); glVertex3f( 1.0, -1.0, -1.0) # Bottom Right Of The Texture and Quad + glTexCoord2f(1.0, 1.0); glVertex3f( 1.0, 1.0, -1.0) # Top Right Of The Texture and Quad + glTexCoord2f(0.0, 1.0); glVertex3f( 1.0, 1.0, 1.0) # Top Left Of The Texture and Quad + glTexCoord2f(0.0, 0.0); glVertex3f( 1.0, -1.0, 1.0) # Bottom Left Of The Texture and Quad + + # Left Face + glTexCoord2f(0.0, 0.0); glVertex3f(-1.0, -1.0, -1.0) # Bottom Left Of The Texture and Quad + glTexCoord2f(1.0, 0.0); glVertex3f(-1.0, -1.0, 1.0) # Bottom Right Of The Texture and Quad + glTexCoord2f(1.0, 1.0); glVertex3f(-1.0, 1.0, 1.0) # Top Right Of The Texture and Quad + glTexCoord2f(0.0, 1.0); glVertex3f(-1.0, 1.0, -1.0) # Top Left Of The Texture and Quad + + glEnd(); # Done Drawing The Cube + +def image_on_spinning_cube(x, shape, upside_down=False): + op = ImageOnSpinningCube(shape, upside_down=upside_down) + istate = shared(op.new_state()) + z, ostate = op(x, istate) + return z, {istate: ostate} + +class BrownianCamera(OpenGlMovieFromImage): + def __init__(self, (n_frames, width, height), upside_down=False): + super(BrownianCamera, self).__init__(width, height, upside_down=upside_down) + self.n_frames = n_frames + self.z_shape = (n_frames, width, height, 3) + + def __eq__(self, other): + return super(self.__class__, self).__eq__(other) \ + and self.n_frames == other.n_frames \ + + def __hash__(self): + return super(self.__class__, self).__hash__() ^ hash(self.n_frames) + + def new_state(self, pos_jitter=(.01,.01,.03), rot_jitter=(4.,4.,4.), seed=23424): + return dict( + pos_jitter=numpy.asarray(pos_jitter), + rot_jitter=numpy.asarray(rot_jitter), + pos0=numpy.asarray((0.,0.,-4.0)), + rot0=numpy.asarray((0.,0.,0.)), + alpha=0.1, + # dynamic things + pos=numpy.asarray((0.,0.,-4.0)), + dpos=numpy.asarray((0.,0.,0.)), + ddpos=numpy.asarray((0.,0.,0.)), + rot=numpy.asarray((0.,0.,0.)), + drot=numpy.asarray((0.,0.,0.)), + ddrot=numpy.asarray((0.,0.,0.)), + rng = numpy.random.RandomState(seed), + ) + + def perform_set_state(self, istate, ostate, iter): + alpha = ostate['alpha'] + if iter == 0: + ostate['pos'] = ostate['pos0'].copy() + ostate['dpos'] *= 0 + ostate['rot'] = ostate['rot0'].copy() + ostate['drot'] *= 0 + ostate['ddpos'] = ostate['rng'].uniform(low=-1,high=1,size=3) * ostate['pos_jitter'] + ostate['ddrot'] = ostate['rng'].uniform(low=-1,high=1,size=3) * ostate['rot_jitter'] + ostate['dpos'] += ostate['ddpos'] + ostate['drot'] += ostate['ddrot'] + ostate['pos'] = (1-alpha)*(ostate['pos'] + ostate['dpos']) + alpha * ostate['pos0'] + ostate['rot'] = (1-alpha)*(ostate['rot'] + ostate['drot']) + alpha * ostate['rot0'] + + def perform_display(self, x, ostate, i): + # retrieve some state variables related to rendering + xrot,yrot,zrot = ostate['rot'] + xpos,ypos,zpos = ostate['pos'] + + # load x as a texture + glBindTexture(GL_TEXTURE_2D, self.texture) # 2d texture (x and y size) + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP) + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP) + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT) + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT) + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST) + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST) + glTexEnvf(GL_TEXTURE_ENV, GL_TEXTURE_ENV_MODE, GL_DECAL) + + glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT) # Clear The Screen And The Depth Buffer + glLoadIdentity() # Reset The View + glTranslatef(xpos,ypos,zpos) # Move Into The Screen + + glRotatef(xrot,1.0,0.0,0.0) # Rotate The Cube On It's X Axis + glRotatef(yrot,0.0,1.0,0.0) # Rotate The Cube On It's Y Axis + glRotatef(zrot,0.0,0.0,1.0) # Rotate The Cube On It's Z Axis + + glBegin(GL_QUADS) # Start Drawing The Cube + + # Front Face (note that the texture's corners have to match the quad's corners) + glTexCoord2f(0.0, 0.0); glVertex3f(-1.0, -1.0, 1.0) # Bottom Left Of The Texture and Quad + glTexCoord2f(1.0, 0.0); glVertex3f( 1.0, -1.0, 1.0) # Bottom Right Of The Texture and Quad + glTexCoord2f(1.0, 1.0); glVertex3f( 1.0, 1.0, 1.0) # Top Right Of The Texture and Quad + glTexCoord2f(0.0, 1.0); glVertex3f(-1.0, 1.0, 1.0) # Top Left Of The Texture and Quad + + glEnd(); # Done Drawing The Cube + +_brownian_camera_ops = {} +def brownian_camera(x, shape, upside_down=False, seed=8234, speed=1.0): + if (shape, upside_down) not in _brownian_camera_ops: + _brownian_camera_ops[(shape, upside_down)] = BrownianCamera(shape, upside_down=upside_down) + op = _brownian_camera_ops[(shape, upside_down)] + istate = shared(op.new_state(seed=seed)) + istate.value['pos_jitter'] *= speed + istate.value['rot_jitter'] *= speed + z, ostate = op(x, istate) + return z, [(istate, ostate)] + + +def _dump_to_file(fn, filename='out.pkl', nexamples=1000, n_frames=10, **kwargs): + logging.basicConfig(level=logging.INFO, stream=sys.stderr) + import cPickle, time + + from MNIST import mnist + i = theano.tensor.iscalar() + z, z_updates = fn(mnist(i%50000, 'train', rasterized=False, dtype='uint8')[0], (n_frames, 28,28), **kwargs) + f = function([i], z[:,:,:,0], updates=z_updates) + + t0 = time.time() + rval = [] + for j in xrange(nexamples): + if 0 == j % 100: print >> sys.stderr, j + rval.append(f(j)) + dt = time.time() - t0 + info('Generating ', nexamples, 'examples took', dt, 'seconds.') + info('Generation rate:', nexamples/dt, 'examples per second.') + info('Generated ', nexamples*n_frames, 'frames') + info('Generation rate:', nexamples*n_frames/dt, 'frames per second.') + + cPickle.dump(rval, file(filename, 'w'), protocol=cPickle.HIGHEST_PROTOCOL) +def spinning_cube_dump(filename='spinning_cube.pkl', *args, **kwargs): + return _dump_to_file(fn=image_on_spinning_cube, filename=filename, *args, **kwargs) +def brownian_camera_dump(filename='brownian_camera.pkl', *args, **kwargs): + return _dump_to_file(fn=brownian_camera, filename=filename, *args, **kwargs) +def brownian_camera_dumpN(filename='brownian_cameraN.pkl', nexamples=10, n_frames=5, + n_movies=10, img_shape=(28,28), **kwargs): + logging.basicConfig(level=logging.INFO, stream=sys.stderr) + import cPickle, time + from MNIST import mnist + + s_idx = theano.tensor.iscalar() + inputs_updates = [brownian_camera( + x=mnist(s_idx*n_movies+i, 'train', rasterized=False, dtype='uint8')[0], + shape=(n_frames,)+img_shape, + seed=234234+i, **kwargs) + for i in xrange(n_movies)] + s_input = theano.tensor.stack(*(input for (input,update) in inputs_updates))\ + .reshape((n_movies*n_frames,)+img_shape+(3,)) + s_updates = [] + for i,u in inputs_updates: + s_updates.extend(u) + print s_updates + f = function([s_idx], s_input, updates=s_updates) + + t0 = time.time() + rval = [] + for j in xrange(nexamples): + if 0 == j % 1000: print >> sys.stderr, j + rval.append(f(j)) + dt = time.time() - t0 + info('Generating ', nexamples, 'examples took', dt, 'seconds.') + info('Generation rate:', nexamples/dt, 'examples per second.') + info('Generated ', nexamples*n_movies*n_frames, 'frames') + info('Generation rate:', nexamples*n_movies*n_frames/dt, 'frames per second.') + + cPickle.dump(rval, file(filename, 'w')) + + +def glviewer_from_file(filename='out.pkl'): + logging.basicConfig(level=logging.DEBUG, stream=sys.stderr) + import cPickle + rval = cPickle.load(file(filename)) + from glviewer import GlViewer + GlViewer(rval.__getitem__).main() + +
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pylearn/dataset_ops/glviewer.py Fri Oct 16 12:04:05 2009 -0400 @@ -0,0 +1,367 @@ +"""This file provides a very crude image-viewing and movie-viewing mini-application. + +It is provided to assist in the development of datasets whose elements are images or movies. +For an example of how to do this, see the `glviewer` function in MNIST.py . + +Currently, the key controls that navigate the dataset are: + + j - next dataset element + k - previous dataset element + 0 - show image 0 + + J - next frame in current movie + K - previous frame in current movie + ) - show frame 0 of current movie + + q - quit. +""" +# Modified to be an image viewer by James Bergstra Sept 2009 +# +# Ported to PyOpenGL 2.0 by Tarn Weisner Burton 10May2001 +# +# This code was created by Richard Campbell '99 (ported to Python/PyOpenGL by John Ferguson 2000) +# +# The port was based on the lesson5 tutorial module by Tony Colston (tonetheman@hotmail.com). +# +# If you've found this code useful, please let me know (email John Ferguson at hakuin@voicenet.com). +# +# See original source and C based tutorial at http:#nehe.gamedev.net +# + + +import traceback +import time +import string +from OpenGL.GL import * +from OpenGL.GLUT import * +from OpenGL.GLU import * +import sys +from Image import * +import numpy + +import logging +_logger = logging.getLogger('glviewer') +_logger.setLevel(logging.INFO) +def debug(*msg): _logger.debug(' '.join(str(m) for m in msg)) +def info(*msg): _logger.info(' '.join(str(m) for m in msg)) +def warn(*msg): _logger.warn(' '.join(str(m) for m in msg)) +def warning(*msg): _logger.warning(' '.join(str(m) for m in msg)) +def error(*msg): _logger.error(' '.join(str(m) for m in msg)) + +def load_texture(x): + debug('loading texture with shape', x.shape) + if x.ndim == 2: + if x.dtype == 'uint8': + rows, cols = x.shape + buf = numpy.zeros((rows, cols, 4), dtype=x.dtype) + buf += x.reshape( (rows, cols, 1)) + glPixelStorei(GL_UNPACK_ALIGNMENT,1) + return glTexImage2D(GL_TEXTURE_2D, 0, 3, cols, rows, 0, GL_RGBA, GL_UNSIGNED_BYTE, + buf[::-1].flatten()) + else: + raise NotImplementedError() + elif x.ndim == 3: + rows, cols, channels = x.shape + if x.dtype == 'uint8': + if channels == 4: + return glTexImage2D(GL_TEXTURE_2D, 0, 3, cols, rows, 0, GL_RGBA, GL_UNSIGNED_BYTE, x[::-1].flatten()) + else: + buf = numpy.zeros((rows, cols, 4), dtype=x.dtype) + if channels == 1: + buf += x.reshape( (rows, cols, 1)) + if channels == 3: + buf[:,:,:3] = x + return glTexImage2D(GL_TEXTURE_2D, 0, 3, cols, rows, 0, GL_RGBA, GL_UNSIGNED_BYTE, buf[::-1].flatten()) + else: + raise NotImplementedError() + else: + raise NotImplementedError() + + # if you get here, it means a case was missed + assert 0 + + +class GlViewer(object): + # Number of the glut window. + window = 0 + + view_angle = 28.0 # this makes the edge of the cube match up with the viewport + + def __init__(self, texture_fn): + + # Rotations for cube. + self.xrot = self.yrot = self.zrot = 0.0 + + self.texture = 0 + + self.texture_fn = texture_fn + + self.pos = -1 + self.pos_frame = -1 + self.pos_is_movie = False + self.texture_array = None + + self.win_shape = (256, 256) + self.rot = numpy.zeros(3) + self.drot = numpy.ones(3) * .0 + + def init_LoadTextures(self): + # Create Texture + glBindTexture(GL_TEXTURE_2D, glGenTextures(1)) # 2d texture (x and y size) + self.refresh_texture(0, 0) + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP) + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP) + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT) + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT) + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST) + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST) + glTexEnvf(GL_TEXTURE_ENV, GL_TEXTURE_ENV_MODE, GL_DECAL) + + + # A general OpenGL initialization function. Sets all of the initial parameters. + def init_GL(self): + glEnable(GL_TEXTURE_2D) + glClearColor(0.0, 0.0, 0.0, 0.0) # This Will Clear The Background Color To Black + glClearDepth(1.0) # Enables Clearing Of The Depth Buffer + glDepthFunc(GL_LESS) # The Type Of Depth Test To Do + glEnable(GL_DEPTH_TEST) # Enables Depth Testing + glShadeModel(GL_SMOOTH) # Enables Smooth Color Shading + + glMatrixMode(GL_PROJECTION) + glLoadIdentity() # Reset The Projection Matrix + # Calculate The Aspect Ratio Of The Window + Width, Height = self.win_shape + gluPerspective(self.view_angle, float(Width)/float(Height), 0.1, 100.0) + + glMatrixMode(GL_MODELVIEW) + + def main(self): + # + # texture gen: an iterator over images + # + # Call this function like this: + # python -c 'import MNIST, glviewer; glviewer.main(x for (x,y) in MNIST.MNIST().train())' + # + + #TODO: this advances the iterator un-necessarily... we just want a frame to look at the + # dimensions + + global window + glutInit(sys.argv) + + # Select type of Display mode: + # Double buffer + # RGBA color + # Alpha components supported + # Depth buffer + info('initializing OpenGl subsystem') + ##glutInitDisplayMode(GLUT_RGBA | GLUT_DOUBLE | GLUT_DEPTH) + + win_width, win_height = self.win_shape + + # get a 640 x 480 window + ##glutInitWindowSize(win_width, win_height) + + # the window starts at the upper left corner of the screen + glutInitWindowPosition(0, 0) + + # Okay, like the C version we retain the window id to use when closing, but for those of you new + # to Python (like myself), remember this assignment would make the variable local and not global + # if it weren't for the global declaration at the start of main. + window = glutCreateWindow("GlViewer") + + # Register the drawing function with glut, BUT in Python land, at least using PyOpenGL, we need to + # set the function pointer and invoke a function to actually register the callback, otherwise it + # would be very much like the C version of the code. + glutDisplayFunc(self.draw_scene) + + # Uncomment this line to get full screen. + # glutFullScreen() + + # When we are doing nothing, redraw the scene. + glutIdleFunc(self.on_idle) + + # Register the function called when our window is resized. + glutReshapeFunc(self.ReSizeGLScene) + + # Register the function called when the keyboard is pressed. + glutKeyboardFunc(self.keyPressed) + + # create the texture we will use for showing images + self.init_LoadTextures() + + # Initialize our window. + self.init_GL() + + # Start Event Processing Engine + glutMainLoop() + + # The function called when our window is resized (which shouldn't happen if you enable fullscreen, below) + def ReSizeGLScene(self, Width, Height): + if Height == 0: # Prevent A Divide By Zero If The Window Is Too Small + Height = 1 + + glViewport(0, 0, Width, Height) # Reset The Current Viewport And Perspective Transformation + glMatrixMode(GL_PROJECTION) + glLoadIdentity() + gluPerspective(self.view_angle, float(Width)/float(Height), 0.1, 100.0) + glMatrixMode(GL_MODELVIEW) + + self.win_shape = (Width, Height) + + + def refresh_texture(self, new_pos, new_frame): + debug('refresh_texture', new_pos, new_frame, 'current', self.pos, self.pos_frame) + if new_pos != self.pos: + texture_array = None + try: + texture_array = self.texture_fn(new_pos) + except Exception, e: + traceback.print_exc() + + if texture_array is None: + return + # calling the texture_fn can mess up the OpenGL state + # here we set it up again + self.init_GL() + + self.pos_is_movie=False + if texture_array.ndim == 4: + self.pos_is_movie = True + if texture_array.ndim == 3 and texture_array.shape[2] > 4: + self.pos_is_movie = True + + self.pos = new_pos + self.texture_array = texture_array + pos_changed = True + if self.pos_is_movie: + info('example', new_pos, 'is movie of', texture_array.shape[0], 'frames') + else: + info('example', new_pos, 'is still frame') + else: + pos_changed = False + texture_array = self.texture_array + + if new_frame == self.pos_frame and not pos_changed: + # nothing to do + return + + if self.pos_is_movie: + n_frames = texture_array.shape[0] + if n_frames > new_frame: + self.pos_frame = new_frame + load_texture(texture_array[new_frame]) + else: + # current frame goes beyond end of movie + pass + else: + # this example is a static frame + load_texture(texture_array) + + # The main drawing function. + def on_idle(self): + # update state stuff pre-draw + self.draw_scene() + + # update state stuff post draw + self.rot += self.drot + + def draw_scene(self): + + xrot, yrot, zrot = self.rot + + glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT) # Clear The Screen And The Depth Buffer + glLoadIdentity() # Reset The View + glTranslatef(0.0,0.0,-5.0) # Move Into The Screen + + glRotatef(xrot,1.0,0.0,0.0) # Rotate The Cube On It's X Axis + glRotatef(yrot,0.0,1.0,0.0) # Rotate The Cube On It's Y Axis + glRotatef(zrot,0.0,0.0,1.0) # Rotate The Cube On It's Z Axis + + # Note there does not seem to be support for this call. + #glBindTexture(GL_TEXTURE_2D,texture) # Rotate The Pyramid On It's Y Axis + + glBegin(GL_QUADS) # Start Drawing The Cube + + # Front Face (note that the texture's corners have to match the quad's corners) + glTexCoord2f(0.0, 0.0); glVertex3f(-1.0, -1.0, 1.0) # Bottom Left Of The Texture and Quad + glTexCoord2f(1.0, 0.0); glVertex3f( 1.0, -1.0, 1.0) # Bottom Right Of The Texture and Quad + glTexCoord2f(1.0, 1.0); glVertex3f( 1.0, 1.0, 1.0) # Top Right Of The Texture and Quad + glTexCoord2f(0.0, 1.0); glVertex3f(-1.0, 1.0, 1.0) # Top Left Of The Texture and Quad + + # Back Face + glTexCoord2f(1.0, 0.0); glVertex3f(-1.0, -1.0, -1.0) # Bottom Right Of The Texture and Quad + glTexCoord2f(1.0, 1.0); glVertex3f(-1.0, 1.0, -1.0) # Top Right Of The Texture and Quad + glTexCoord2f(0.0, 1.0); glVertex3f( 1.0, 1.0, -1.0) # Top Left Of The Texture and Quad + glTexCoord2f(0.0, 0.0); glVertex3f( 1.0, -1.0, -1.0) # Bottom Left Of The Texture and Quad + + # Top Face + glTexCoord2f(0.0, 1.0); glVertex3f(-1.0, 1.0, -1.0) # Top Left Of The Texture and Quad + glTexCoord2f(0.0, 0.0); glVertex3f(-1.0, 1.0, 1.0) # Bottom Left Of The Texture and Quad + glTexCoord2f(1.0, 0.0); glVertex3f( 1.0, 1.0, 1.0) # Bottom Right Of The Texture and Quad + glTexCoord2f(1.0, 1.0); glVertex3f( 1.0, 1.0, -1.0) # Top Right Of The Texture and Quad + + # Bottom Face + glTexCoord2f(1.0, 1.0); glVertex3f(-1.0, -1.0, -1.0) # Top Right Of The Texture and Quad + glTexCoord2f(0.0, 1.0); glVertex3f( 1.0, -1.0, -1.0) # Top Left Of The Texture and Quad + glTexCoord2f(0.0, 0.0); glVertex3f( 1.0, -1.0, 1.0) # Bottom Left Of The Texture and Quad + glTexCoord2f(1.0, 0.0); glVertex3f(-1.0, -1.0, 1.0) # Bottom Right Of The Texture and Quad + + # Right face + glTexCoord2f(1.0, 0.0); glVertex3f( 1.0, -1.0, -1.0) # Bottom Right Of The Texture and Quad + glTexCoord2f(1.0, 1.0); glVertex3f( 1.0, 1.0, -1.0) # Top Right Of The Texture and Quad + glTexCoord2f(0.0, 1.0); glVertex3f( 1.0, 1.0, 1.0) # Top Left Of The Texture and Quad + glTexCoord2f(0.0, 0.0); glVertex3f( 1.0, -1.0, 1.0) # Bottom Left Of The Texture and Quad + + # Left Face + glTexCoord2f(0.0, 0.0); glVertex3f(-1.0, -1.0, -1.0) # Bottom Left Of The Texture and Quad + glTexCoord2f(1.0, 0.0); glVertex3f(-1.0, -1.0, 1.0) # Bottom Right Of The Texture and Quad + glTexCoord2f(1.0, 1.0); glVertex3f(-1.0, 1.0, 1.0) # Top Right Of The Texture and Quad + glTexCoord2f(0.0, 1.0); glVertex3f(-1.0, 1.0, -1.0) # Top Left Of The Texture and Quad + + glEnd(); # Done Drawing The Cube + + # since this is double buffered, swap the buffers to display what just got drawn. + glutSwapBuffers() + + # The function called whenever a key is pressed. Note the use of Python tuples to pass in: (key, x, y) + def keyPressed(self, *args): + ESCAPE = '\033' + + # EXAMPLE CONTROLS + + if args[0] == 'j': # down + self.refresh_texture(self.pos + 1, 0) + info( 'Current image: ', self.pos) + elif args[0] == 'k': # up + self.refresh_texture(self.pos - 1, 0) + info( 'Current image: ', self.pos) + elif args[0] == '0': # reset to position 0 + self.refresh_texture(0, 0) + info( 'Current image: ', self.pos) + + # FRAME CONTROLS + + elif args[0] == ')': # ')' is shift-0, reset to frame 0 + self.refresh_texture(self.pos, 0) + info( 'Current image: ', self.pos) + elif args[0] == 'J': # advance frame + self.refresh_texture(self.pos, self.pos_frame + 1) + info( 'Next frame') + elif args[0] == 'K': # advance frame + if self.pos_frame: + self.refresh_texture(self.pos, self.pos_frame - 1) + info( 'Previous frame') + else: + warn('Not backing up past frame 0') + + elif args[0] == ESCAPE or args[0]=='q': + sys.exit() + + +if __name__ == '__main__': + + sample_data = numpy.asarray(numpy.random.randint(low=0, high=256, size=(5, 64,64)), + dtype='uint8') + GlViewer(sample_data.__getitem__).main() +
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pylearn/dataset_ops/memo.py Fri Oct 16 12:04:05 2009 -0400 @@ -0,0 +1,22 @@ +"""Provide a decorator that caches expensive functions +""" +import logging +_logger = logging.getLogger(__file__) +info = _logger.info +def infop(*args): + info(' '.join(str(a) for a in args)) + +def memo(f): + #TODO: support kwargs to rval. This requires looking up the names of f's parameters to + # construct a proper key. + + #TODO: use weak references instead of a normal dict so that the cache doesn't prevent + # garbage collection + cache = {} + def rval(*args): + if args not in cache: + cache[args] = f(*args) + return cache[args] + rval.__name__ = 'memo@%s'%f.__name__ + return rval +
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pylearn/dataset_ops/protocol.py Fri Oct 16 12:04:05 2009 -0400 @@ -0,0 +1,83 @@ +"""Convenience base classes to help with writing Dataset ops + +""" + +__docformat__ = "restructuredtext_en" +import theano + +class Dataset(theano.Op): + """ + The basic dataset interface is an expression that maps an integer to a dataset element. + + There is also a minibatch option, in which the expression maps an array of integers to a + list or array of dataset elements. + """ + def __init__(self, single_type, batch_type): + self.single_type = single_type + self.batch_type = batch_type + + def make_node(self, idx): + _idx = theano.tensor.as_tensor_variable(idx) + if not _idx.dtype.startswith('int'): + raise TypeError() + if _idx.ndim == 0: # one example at a time + otype = self.single_type + elif _idx.ndim == 1: #many examples at a time + otype = self.batch_type + else: + raise TypeError(idx) + return theano.Apply(self, [_idx], [otype()]) + + def __eq__(self, other): + return type(self) == type(other) \ + and self.single_type == other.single_type \ + and self.batch_type == other.batch_type + + def __hash__(self): + return hash(type(self)) ^ hash(self.single_type) ^ hash(self.batch_type) + + def __str__(self): + return "%s{%s,%s}" % (self.__class__.__name__, self.single_type, self.batch_type) + + def grad(self, inputs, g_outputs): + return [None for i in inputs] + + +class TensorDataset(Dataset): + """A convenient base class for Datasets whose elements all have the same TensorType. + """ + def __init__(self, dtype, single_broadcastable, single_shape=None, batch_size=None): + single_broadcastable = tuple(single_broadcastable) + single_type = theano.tensor.Tensor( + broadcastable=single_broadcastable, + dtype=dtype, + shape=single_shape) + batch_type = theano.tensor.Tensor( + broadcastable=(False,)+single_type.broadcastable, + dtype=dtype, + shape=(batch_size,)+single_type.shape) + super(TensorDataset, self).__init__(single_type, batch_type) + +class TensorFnDataset(TensorDataset): + def __init__(self, dtype, bcast, fn, single_shape=None, batch_size=None): + super(TensorFnDataset, self).__init__(dtype, bcast, single_shape, batch_size) + self.fn = fn + + def __eq__(self, other): + return super(TensorFnDataset, self).__eq__(other) and self.fn == other.fn + + def __hash__(self): + return super(TensorFnDataset, self).__hash__() ^ hash(self.fn) + + def __str__(self): + try: + return "%s{%s}" % (self.__class__.__name__, self.fn.__name__) + except: + return "%s{%s}" % (self.__class__.__name__, self.fn) + + def perform(self, node, (idx,), (z,)): + x = self.fn() + if idx.ndim == 0: + z[0] = x[int(idx)] + else: + z[0] = x[idx]