"""This file provides a very crude image-viewing and movie-viewing mini-application.

It is provided to assist in the development of datasets whose elements are images or movies.
For an example of how to do this, see the `__main__` section at the end of this file, or the `glviewer` function.

Currently, the key controls that navigate the dataset are:

    j - next dataset element
    k - previous dataset element
    0 - show image 0

    J - next frame in current movie
    K - previous frame in current movie
    ) - show frame 0 of current movie

    q - quit.
"""
# Modified to be an image viewer by James Bergstra Sept 2009
# Ported to PyOpenGL 2.0 by Tarn Weisner Burton 10May2001
# This code was created by Richard Campbell '99 (ported to Python/PyOpenGL by John Ferguson 2000)
# The port was based on the lesson5 tutorial module by Tony Colston.
# If you've found this code useful, please let John Ferguson know.
# See the original source and C-based tutorial at http://nehe.gamedev.net

import traceback
import time
import string
from OpenGL.GL import *
from OpenGL.GLUT import *
from OpenGL.GLU import *
import sys
from Image import *
import numpy

import logging
_logger = logging.getLogger('glviewer')


def _join(msg):
    """Render positional log arguments as one space-separated string of their str()."""
    return ' '.join(str(m) for m in msg)


def debug(*msg):
    """Log the space-joined str() of the arguments at DEBUG level."""
    _logger.debug(_join(msg))


def info(*msg):
    """Log the space-joined str() of the arguments at INFO level."""
    _logger.info(_join(msg))


def warn(*msg):
    """Alias of warning(), kept for existing callers.

    Uses Logger.warning because Logger.warn is a deprecated alias.
    """
    _logger.warning(_join(msg))


def warning(*msg):
    """Log the space-joined str() of the arguments at WARNING level."""
    _logger.warning(_join(msg))


def error(*msg):
    """Log the space-joined str() of the arguments at ERROR level."""
    _logger.error(_join(msg))

def _to_rgba_uint8(x):
    """Return image array `x` as a (rows, cols, 4) uint8 RGBA array.

    Supported inputs (anything else raises NotImplementedError):

    * 2-D uint8: grey levels, replicated into all four channels.
    * 2-D float: grey levels assumed to lie in [0, 1] (values outside are
      clipped), scaled by 255 and replicated into all four channels.
    * 3-D uint8 with 1, 3 or 4 channels on the last axis: 1 channel is
      replicated into all four, 3 channels get a zero alpha channel,
      4 channels are returned unchanged.
    """
    if x.ndim == 2:
        rows, cols = x.shape
        if x.dtype == 'uint8':
            grey = x
        elif str(x.dtype).startswith('float'):
            # Clip before casting so out-of-range values saturate instead of
            # wrapping around modulo 256 (and so numpy does not reject an
            # in-place float -> uint8 add).
            grey = (numpy.clip(x, 0.0, 1.0) * 255).astype('uint8')
        else:
            raise NotImplementedError('unsupported 2-D image dtype %s' % x.dtype)
        buf = numpy.zeros((rows, cols, 4), dtype='uint8')
        buf += grey.reshape((rows, cols, 1))
        return buf

    if x.ndim == 3:
        rows, cols, channels = x.shape
        if x.dtype != 'uint8':
            raise NotImplementedError('unsupported 3-D image dtype %s' % x.dtype)
        if channels == 4:
            return x
        buf = numpy.zeros((rows, cols, 4), dtype='uint8')
        if channels == 1:
            buf += x  # broadcasts the single channel into all four
        elif channels == 3:
            buf[:, :, :3] = x
        else:
            raise NotImplementedError('unsupported channel count %d' % channels)
        return buf

    raise NotImplementedError('unsupported image rank %d' % x.ndim)


def load_texture(x):
    """Upload image array `x` into the currently bound GL_TEXTURE_2D.

    See _to_rgba_uint8 for the accepted ranks, dtypes and channel counts.
    The data is sent as GL_RGBA / GL_UNSIGNED_BYTE with internal format 3
    (RGB), so the alpha channel is not kept by the texture.

    Returns whatever glTexImage2D returns.
    Raises NotImplementedError for an unsupported rank, dtype or channel count.
    """
    debug('loading texture with shape', x.shape)
    rgba = _to_rgba_uint8(x)
    rows, cols = rgba.shape[:2]
    # Flip vertically: the first row handed to glTexImage2D is texture
    # coordinate t=0, which GlViewer.draw_scene maps to the bottom of each face.
    return glTexImage2D(GL_TEXTURE_2D, 0, 3, cols, rows, 0,
                        GL_RGBA, GL_UNSIGNED_BYTE, rgba[::-1].flatten())

class GlViewer(object):
    """Crude GLUT viewer that shows dataset elements as a texture on a cube.

    `texture_fn(pos)` is called with an integer position and must return a
    numpy array, or None when there is nothing to show at that position.
    A 4-D array, or a 3-D array whose last axis is longer than 4, is treated
    as a movie whose frames lie along axis 0; anything else is a still image
    handed to load_texture.
    """

    # Number of the glut window (set by main()).
    window = 0

    view_angle = 28.0  # this makes the edge of the cube match up with the viewport

    def __init__(self, texture_fn):
        # Rotations for cube.
        self.xrot = self.yrot = self.zrot = 0.0

        self.texture = 0

        self.texture_fn = texture_fn

        self.pos = -1              # dataset position currently displayed
        self.pos_frame = -1        # frame of that element currently displayed
        self.pos_is_movie = False
        self.texture_array = None  # array last returned by texture_fn

        self.win_shape = (256, 256)
        self.rot = numpy.zeros(3)       # (x, y, z) rotation in degrees
        self.drot = numpy.ones(3) * .0  # rotation added on every idle tick

    def init_LoadTextures(self):
        """Create and bind the display texture, load element 0, set sampling."""
        glBindTexture(GL_TEXTURE_2D, glGenTextures(1))
        self.refresh_texture(0, 0)
        # Texture coordinates stay within [0, 1]; the wrap mode only affects edges.
        glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT)
        glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT)
        # The default minification filter samples mipmaps, which are never
        # generated here; without a non-mipmap filter the texture is
        # incomplete and the faces render untextured.
        glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST)
        glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST)

    def _set_projection(self, width, height):
        """Load a fresh perspective projection for a width x height viewport.

        Leaves GL_MODELVIEW as the current matrix, which draw_scene relies on.
        """
        height = height or 1  # avoid dividing by zero for a zero-height window
        glMatrixMode(GL_PROJECTION)
        glLoadIdentity()
        gluPerspective(self.view_angle, float(width) / float(height), 0.1, 100.0)
        glMatrixMode(GL_MODELVIEW)

    # A general OpenGL initialization function.  Sets all of the initial parameters.
    def init_GL(self):
        glClearColor(0.0, 0.0, 0.0, 0.0)    # clear the background to black
        glClearDepth(1.0)                   # enables clearing of the depth buffer
        glDepthFunc(GL_LESS)                # the type of depth test to do
        glEnable(GL_DEPTH_TEST)             # enables depth testing
        glShadeModel(GL_SMOOTH)             # enables smooth color shading
        # Without this the glTexCoord2f calls in draw_scene have no visible effect.
        glEnable(GL_TEXTURE_2D)
        width, height = self.win_shape
        self._set_projection(width, height)

    def main(self):
        """Create the GLUT window, register the callbacks and run the event loop."""
        global window

        info('initializing OpenGl subsystem')
        glutInit(sys.argv)
        # Double buffer (draw_scene swaps buffers), RGBA colour, and a depth
        # buffer (init_GL enables depth testing; without one, hidden faces of
        # the cube are drawn over the front face).
        glutInitDisplayMode(GLUT_RGBA | GLUT_DOUBLE | GLUT_DEPTH)

        win_width, win_height = self.win_shape
        glutInitWindowSize(win_width, win_height)
        # the window starts at the upper left corner of the screen
        glutInitWindowPosition(0, 0)
        # The window id is stored on the instance (used by keyPressed) and in
        # the module-level `window` for any code that still reads it there.
        window = self.window = glutCreateWindow("GlViewer")

        glutDisplayFunc(self.draw_scene)
        # Uncomment this line to get full screen.
        # glutFullScreen()
        glutIdleFunc(self.on_idle)           # redraw when there is nothing else to do
        glutReshapeFunc(self.ReSizeGLScene)  # called when the window is resized
        glutKeyboardFunc(self.keyPressed)    # called when a key is pressed

        # create the texture we will use for showing images
        self.init_LoadTextures()
        # Initialize our window.
        self.init_GL()

        # Start Event Processing Engine
        glutMainLoop()

    # The function called when our window is resized (which shouldn't happen if you enable fullscreen)
    def ReSizeGLScene(self, Width, Height):
        if Height == 0:  # prevent a divide by zero if the window is too small
            Height = 1

        glViewport(0, 0, Width, Height)
        # Reset (rather than multiply into) the projection for the new aspect ratio.
        self._set_projection(Width, Height)

        self.win_shape = (Width, Height)

    def refresh_texture(self, new_pos, new_frame):
        """Display frame `new_frame` of dataset element `new_pos`.

        When `new_pos` differs from the current position, texture_fn(new_pos)
        is called; if it raises (the traceback is printed) or returns None,
        nothing changes. For a movie, a frame index outside the movie is
        ignored and the current frame stays on screen.
        """
        debug('refresh_texture', new_pos, new_frame, 'current', self.pos, self.pos_frame)
        if new_pos != self.pos:
            texture_array = None
            try:
                texture_array = self.texture_fn(new_pos)
            except Exception:
                error('texture_fn raised an exception for position', new_pos)
                traceback.print_exc()

            if texture_array is None:
                return

            # calling the texture_fn can mess up the OpenGL state
            # here we set it up again
            self.init_GL()

            # Recomputed for every new element so a still image that follows a
            # movie is not mistaken for one.
            self.pos_is_movie = (
                texture_array.ndim == 4
                or (texture_array.ndim == 3 and texture_array.shape[2] > 4))

            self.pos = new_pos
            self.texture_array = texture_array
            pos_changed = True
            if self.pos_is_movie:
                info('example', new_pos, 'is movie of', texture_array.shape[0], 'frames')
            else:
                info('example', new_pos, 'is still frame')
        else:
            pos_changed = False
            texture_array = self.texture_array

        if new_frame == self.pos_frame and not pos_changed:
            return  # nothing to do

        if self.pos_is_movie:
            n_frames = texture_array.shape[0]
            if 0 <= new_frame < n_frames:
                self.pos_frame = new_frame
                load_texture(texture_array[new_frame])
            else:
                # requested frame lies outside the movie: keep the current one
                debug('frame', new_frame, 'out of range for', n_frames, 'frames')
        else:
            # this example is a static frame
            self.pos_frame = new_frame
            load_texture(texture_array)

    def on_idle(self):
        """GLUT idle callback: redraw the scene, then advance the rotation."""
        self.draw_scene()
        self.rot += self.drot

    # The main drawing function.
    def draw_scene(self):
        xrot, yrot, zrot = self.rot

        glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)  # clear the screen and the depth buffer
        glLoadIdentity()                                    # reset the view
        glTranslatef(0.0, 0.0, -5.0)                        # move into the screen

        glRotatef(xrot, 1.0, 0.0, 0.0)  # rotate the cube on its X axis
        glRotatef(yrot, 0.0, 1.0, 0.0)  # rotate the cube on its Y axis
        glRotatef(zrot, 0.0, 0.0, 1.0)  # rotate the cube on its Z axis

        glBegin(GL_QUADS)  # start drawing the cube
        # Front Face (note that the texture's corners have to match the quad's corners)
        glTexCoord2f(0.0, 0.0); glVertex3f(-1.0, -1.0,  1.0)    # Bottom Left Of The Texture and Quad
        glTexCoord2f(1.0, 0.0); glVertex3f( 1.0, -1.0,  1.0)    # Bottom Right Of The Texture and Quad
        glTexCoord2f(1.0, 1.0); glVertex3f( 1.0,  1.0,  1.0)    # Top Right Of The Texture and Quad
        glTexCoord2f(0.0, 1.0); glVertex3f(-1.0,  1.0,  1.0)    # Top Left Of The Texture and Quad
        # Back Face
        glTexCoord2f(1.0, 0.0); glVertex3f(-1.0, -1.0, -1.0)    # Bottom Right Of The Texture and Quad
        glTexCoord2f(1.0, 1.0); glVertex3f(-1.0,  1.0, -1.0)    # Top Right Of The Texture and Quad
        glTexCoord2f(0.0, 1.0); glVertex3f( 1.0,  1.0, -1.0)    # Top Left Of The Texture and Quad
        glTexCoord2f(0.0, 0.0); glVertex3f( 1.0, -1.0, -1.0)    # Bottom Left Of The Texture and Quad
        # Top Face
        glTexCoord2f(0.0, 1.0); glVertex3f(-1.0,  1.0, -1.0)    # Top Left Of The Texture and Quad
        glTexCoord2f(0.0, 0.0); glVertex3f(-1.0,  1.0,  1.0)    # Bottom Left Of The Texture and Quad
        glTexCoord2f(1.0, 0.0); glVertex3f( 1.0,  1.0,  1.0)    # Bottom Right Of The Texture and Quad
        glTexCoord2f(1.0, 1.0); glVertex3f( 1.0,  1.0, -1.0)    # Top Right Of The Texture and Quad
        # Bottom Face
        glTexCoord2f(1.0, 1.0); glVertex3f(-1.0, -1.0, -1.0)    # Top Right Of The Texture and Quad
        glTexCoord2f(0.0, 1.0); glVertex3f( 1.0, -1.0, -1.0)    # Top Left Of The Texture and Quad
        glTexCoord2f(0.0, 0.0); glVertex3f( 1.0, -1.0,  1.0)    # Bottom Left Of The Texture and Quad
        glTexCoord2f(1.0, 0.0); glVertex3f(-1.0, -1.0,  1.0)    # Bottom Right Of The Texture and Quad
        # Right face
        glTexCoord2f(1.0, 0.0); glVertex3f( 1.0, -1.0, -1.0)    # Bottom Right Of The Texture and Quad
        glTexCoord2f(1.0, 1.0); glVertex3f( 1.0,  1.0, -1.0)    # Top Right Of The Texture and Quad
        glTexCoord2f(0.0, 1.0); glVertex3f( 1.0,  1.0,  1.0)    # Top Left Of The Texture and Quad
        glTexCoord2f(0.0, 0.0); glVertex3f( 1.0, -1.0,  1.0)    # Bottom Left Of The Texture and Quad
        # Left Face
        glTexCoord2f(0.0, 0.0); glVertex3f(-1.0, -1.0, -1.0)    # Bottom Left Of The Texture and Quad
        glTexCoord2f(1.0, 0.0); glVertex3f(-1.0, -1.0,  1.0)    # Bottom Right Of The Texture and Quad
        glTexCoord2f(1.0, 1.0); glVertex3f(-1.0,  1.0,  1.0)    # Top Right Of The Texture and Quad
        glTexCoord2f(0.0, 1.0); glVertex3f(-1.0,  1.0, -1.0)    # Top Left Of The Texture and Quad
        glEnd()  # done drawing the cube

        # since this is double buffered, swap the buffers to display what just got drawn.
        glutSwapBuffers()

    # The function called whenever a key is pressed. GLUT passes (key, x, y).
    def keyPressed(self, *args):
        """GLUT keyboard callback; see print_key_help for the bindings."""
        ESCAPE = '\033'

        key = args[0]
        if isinstance(key, bytes):
            # Some PyOpenGL versions (notably under Python 3) pass the key as
            # bytes; normalise it so the comparisons below work either way.
            key = key.decode('latin-1')

        if key == 'j':  # next example
            self.refresh_texture(self.pos + 1, 0)
            info('Current image: ', self.pos)
        elif key == 'k':  # previous example
            self.refresh_texture(self.pos - 1, 0)
            info('Current image: ', self.pos)
        elif key == '0':  # reset to position 0
            self.refresh_texture(0, 0)
            info('Current image: ', self.pos)
        elif key == ')':  # ')' is shift-0, reset to frame 0
            self.refresh_texture(self.pos, 0)
            info('Current image: ', self.pos)
        elif key == 'J':  # advance frame
            self.refresh_texture(self.pos, self.pos_frame + 1)
            info('Next frame')
        elif key == 'K':  # back up one frame
            # `> 0` rather than truthiness: pos_frame starts at -1.
            if self.pos_frame > 0:
                self.refresh_texture(self.pos, self.pos_frame - 1)
                info('Previous frame')
            else:
                warn('Not backing up past frame 0')
        elif key == ESCAPE or key == 'q':
            glutDestroyWindow(self.window)
            sys.exit()

    def print_key_help(self):
        """Print the keyboard bindings to stdout."""
        print("Program controls:")
        print("  q: quit")
        print("")
        print("Example controls:")
        print("  0: reset to example 0")
        print("  j: next")
        print("  k: prev")
        print("")
        print("Frame controls (for movies)")
        print("  ): reset to frame 0")
        print("  J: forward")
        print("  K: backward")
        print("Hint: Hold keys down for continuous play")
        print("")

if __name__ == '__main__':

    sample_data = numpy.asarray(numpy.random.randint(low=0, high=256, size=(5, 64,64)),