changeset 153:ebdd20cfba69

plots finally editable
author catherine@Elli.myhome.westell.com
date Mon, 29 Sep 2008 05:28:28 -0400
parents c26bc528cb05
children 4680d0629b82
files editplot.bash editplot.py output_templates.py plothandler.py sqlpyPlus.py
diffstat 5 files changed, 89 insertions(+), 68 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/editplot.bash	Mon Sep 29 05:28:28 2008 -0400
@@ -0,0 +1,1 @@
+ipython -pylab editplot.py
\ No newline at end of file
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/editplot.py	Mon Sep 29 05:28:28 2008 -0400
@@ -0,0 +1,4 @@
+#!/bin/bash
+from plothandler import Plot
+
+Plot().unshelve()
\ No newline at end of file
--- a/output_templates.py	Fri Sep 26 16:31:17 2008 -0400
+++ b/output_templates.py	Mon Sep 29 05:28:28 2008 -0400
@@ -1,15 +1,20 @@
 import genshi.template
 
-xml_template = genshi.template.NewTextTemplate("""
+# To make more output formats available to sqlpython, just edit this
+# file, or place a copy in your local directory and edit that.
+
+output_templates = {
+
+'\\x': genshi.template.NewTextTemplate("""
 <xml>
   <${tblname}_resultset>{% for row in rows %}
     <$tblname>{% for (colname, itm) in zip(colnames, row) %}
       <${colname.lower()}>$itm</${colname.lower()}>{% end %}
     </$tblname>{% end %}
   </${tblname}_resultset>
-</xml>""")
-    
-html_template = genshi.template.MarkupTemplate("""
+</xml>"""),
+
+'\\h': genshi.template.MarkupTemplate("""
 <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Strict//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd">
 <html xmlns:py="http://genshi.edgewall.org/" xmlns="http://www.w3.org/1999/xhtml" xml:lang="en" lang="en">
   <head>
@@ -32,16 +37,18 @@
       </tr>
     </table>
   </body>
-</html>""")
+</html>"""),
 
-list_template = genshi.template.NewTextTemplate("""
+'\\g': genshi.template.NewTextTemplate("""
 {% for (rowNum, row) in enumerate(rows) %}
 **** Row: ${rowNum + 1}
 {% for (colname, itm) in zip(colnames, row) %}$colname: $itm
-{% end %}{% end %}""")
+{% end %}{% end %}"""),
 
-aligned_list_template = genshi.template.NewTextTemplate("""
+'\\G': genshi.template.NewTextTemplate("""
 {% for (rowNum, row) in enumerate(rows) %}
 **** Row: ${rowNum + 1}
 {% for (colname, itm) in zip(colnames, row) %}${colname.ljust(colnamelen)}: $itm
-{% end %}{% end %}""")
+{% end %}{% end %}"""),
+
+}
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/plothandler.py	Mon Sep 29 05:28:28 2008 -0400
@@ -0,0 +1,58 @@
+import shelve, pickle, cx_Oracle, datetime, sys
+shelvename = 'plot.shelve'
+
+try:
+    import pylab
+    class Plot(object):
+        plottable_types = (cx_Oracle.NUMBER, datetime.datetime)    
+        def __init__(self):
+            self.legends = []
+            self.yserieslists = []
+            self.xticks = []
+        def build(self, sqlSession):
+            self.title = sqlSession.tblname
+            self.xlabel = sqlSession.curs.description[0][0]
+            self.datatypes = [d[1] for d in sqlSession.curs.description]
+            for (colNum, datatype) in enumerate(self.datatypes):
+                if colNum > 0 and datatype in self.plottable_types:
+                    yseries = [row[colNum] for row in sqlSession.rows]
+                    if max(yseries) is not None:
+                        self.yserieslists.append(yseries)
+                        self.legends.append(sqlSession.curs.description[colNum][0])
+            if self.datatypes[0] in self.plottable_types:
+                self.xvalues = [r[0] for r in sqlSession.rows]
+            else:
+                self.xvalues = range(sqlSession.curs.rowcount)
+                self.xticks = [r[0] for r in sqlSession.rows]
+        def shelve(self):
+            s = shelve.open(shelvename,'c')
+            for k in ('xvalues xticks yserieslists title legends xlabel'.split()):
+                s[k] = getattr(self, k)
+            s.close()
+            # reading pickles fails with EOF error, don't understand
+        def unshelve(self):
+            s = shelve.open(shelvename)
+            self.__dict__.update(s)
+            s.close()
+            self.draw()            
+        def draw(self):
+            if not self.yserieslists:
+                print 'At least one quantitative column needed to plot.'
+                return None
+            for yseries in self.yserieslists:
+                pylab.plot(self.xvalues, yseries, '-o')
+            if self.xticks:
+                pylab.xticks(self.xvalues, self.xticks)
+            pylab.xlabel(self.xlabel)
+            pylab.title(self.title)
+            pylab.legend(self.legends)
+            pylab.show()
+            
+except ImportError:
+    class Plot(object):
+        def build(self, sqlSession):
+            pass
+        def save(self):
+            pass
+        def draw(self):
+            return 'Must install python-matplotlib to plot query results.'
\ No newline at end of file
--- a/sqlpyPlus.py	Fri Sep 26 16:31:17 2008 -0400
+++ b/sqlpyPlus.py	Mon Sep 29 05:28:28 2008 -0400
@@ -25,7 +25,8 @@
 """
 import sys, os, re, sqlpython, cx_Oracle, pyparsing, re, completion, datetime, pickle
 from cmd2 import Cmd, make_option, options, Statekeeper
-from output_templates import *
+from output_templates import output_templates
+from plothandler import Plot
 try:
     import pylab
 except:
@@ -339,53 +340,7 @@
             if not givenBindVars.has_key(varname):
                 print 'Bind variable %s not defined.' % (varname)                
     return result
-
-try:
-    import pylab
-    class Plot(object):
-        plottable_types = (cx_Oracle.NUMBER, datetime.datetime)    
-        def __init__(self):
-            self.legends = []
-            self.yserieslists = []
-            self.xticks = []
-        def build(self, sqlSession):
-            self.title = sqlSession.tblname
-            self.xlabel = sqlSession.curs.description[0][0]
-            self.datatypes = [d[1] for d in sqlSession.curs.description]
-            for (colNum, datatype) in enumerate(self.datatypes):
-                if colNum > 0 and datatype in self.plottable_types:
-                    self.yserieslists.append([row[colNum] for row in sqlSession.rows])
-                    self.legends.append(sqlSession.curs.description[colNum][0])
-            if self.datatypes[0] in self.plottable_types:
-                self.xvalues = [r[0] for r in sqlSession.rows]
-            else:
-                self.xvalues = range(sqlSession.curs.rowcount)
-                self.xticks = [r[0] for r in sqlSession.rows]
-        def save(self):
-            pass
-        def draw(self):
-            if not self.yserieslists:
-                result = 'At least one quantitative column needed to plot.'
-                return result       
-            if self.xticks:
-                pylab.xticks(self.xvalues, self.xticks)
-            for yseries in self.yserieslists:
-                pylab.plot(self.xvalues, yseries, '-o')
-            pylab.xlabel(self.xlabel)
-            pylab.title(self.title)
-            pylab.legend(self.legends)
-            pylab.show()
-            return 'If your lines zigzag, you may want to ORDER BY the x axis.'
-            
-except ImportError:
-    class Plot(object):
-        def build(self, sqlSession):
-            pass
-        def save(self):
-            pass
-        def draw(self):
-            return 'Must install python-matplotlib to plot query results.'
-        
+       
 class sqlpyPlus(sqlpython.sqlpython):
     defaultExtension = 'sql'
     sqlpython.sqlpython.shortcuts.update({':': 'setbind', '\\': 'psql', '@': '_load'})
@@ -462,13 +417,9 @@
         self.colnames = [d[0] for d in self.curs.description]
         if outformat == '\\i':
             result = self.output_as_insert_statements()
-        elif outformat ==  '\\x':
-            result = xml_template.generate(**self.__dict__)
-        elif outformat == '\\g':
-            result = list_template.generate(**self.__dict__)
-        elif outformat == '\\G':
+        elif outformat in output_templates:
             self.colnamelen = max(len(colname) for colname in self.colnames)
-            result = aligned_list_template.generate(**self.__dict__)
+            result = output_templates[outformat].generate(**self.__dict__)
         elif outformat in ('\\s', '\\S', '\\c', '\\C'): #csv
             result = []
             if outformat in ('\\s', '\\c'):
@@ -476,8 +427,6 @@
             for row in self.rows:
                 result.append(','.join('"%s"' % self.str_or_empty(itm) for itm in row))
             result = '\n'.join(result)
-        elif outformat == '\\h':
-            result = html_template.generate(**self.__dict__)
         elif outformat == '\\t': # transposed
             rows = [self.colnames]
             rows.extend(list(self.rows))
@@ -493,8 +442,9 @@
         elif outformat == '\\p':
             plot = Plot()
             plot.build(self)
-            plot.save()
-            return plot.draw()
+            plot.shelve()
+            plot.draw()
+            return ''
         else:
             result = sqlpython.pmatrix(self.rows, self.curs.description, self.maxfetch)
         return result
@@ -549,7 +499,8 @@
         return completions
     
     rowlimitPattern = pyparsing.Word(pyparsing.nums)('rowlimit')
-    terminatorPattern = (pyparsing.oneOf('; \\s \\S \\c \\C \\t \\x \\h \\g \\G \\i \\p')    
+    rawTerminators = '; \\s \\S \\c \\C \\t \\i \\p ' + ' '.join(output_templates.keys())
+    terminatorPattern = (pyparsing.oneOf(rawTerminators)    
                         ^ pyparsing.Literal('\n/') ^ \
                         (pyparsing.Literal('\nEOF') + pyparsing.stringEnd)) \
                         ('terminator') + \