diff sqlpyPlus.py @ 151:802d8df993da

midway through making plots saveable
author catherine@dellzilla
date Fri, 26 Sep 2008 16:11:29 -0400
parents b00a020b81c6
children c26bc528cb05
line wrap: on
line diff
--- a/sqlpyPlus.py	Fri Sep 26 13:03:10 2008 -0400
+++ b/sqlpyPlus.py	Fri Sep 26 16:11:29 2008 -0400
@@ -23,9 +23,13 @@
 
 - catherinedevlin.blogspot.com  May 31, 2006
 """
-import sys, os, re, sqlpython, cx_Oracle, pyparsing, re, completion
+import sys, os, re, sqlpython, cx_Oracle, pyparsing, re, completion, datetime, pickle
 from cmd2 import Cmd, make_option, options, Statekeeper
 from output_templates import *
+try:
+    import pylab
+except:
+    pass
 
 descQueries = {
 'TABLE': ("""
@@ -336,6 +340,52 @@
                 print 'Bind variable %s not defined.' % (varname)                
     return result
 
+try:
+    import pylab
+    class Plot(object):
+        plottable_types = (cx_Oracle.NUMBER, datetime.datetime)    
+        def __init__(self):
+            self.legends = []
+            self.yserieslists = []
+            self.xticks = []
+        def build(self, sqlSession):
+            self.title = sqlSession.tblname
+            self.xlabel = sqlSession.curs.description[0][0]
+            self.datatypes = [d[1] for d in sqlSession.curs.description]
+            for (colNum, datatype) in enumerate(self.datatypes):
+                if colNum > 0 and datatype in self.plottable_types:
+                    self.yserieslists.append([row[colNum] for row in sqlSession.rows])
+                    self.legends.append(sqlSession.curs.description[colNum][0])
+            if self.datatypes[0] in self.plottable_types:
+                self.xvalues = [r[0] for r in sqlSession.rows]
+            else:
+                self.xvalues = range(sqlSession.curs.rowcount)
+                self.xticks = [r[0] for r in sqlSession.rows]
+        def save(self):
+            pass
+        def draw(self):
+            if not self.yserieslists:
+                result = 'At least one quantitative column needed to plot.'
+                return result       
+            if self.xticks:
+                pylab.xticks(self.xvalues, self.xticks)
+            for (colName, yseries) in self.yserieslists.items():
+                pylab.plot(xvalues, yseries, '-o')
+            pylab.xlabel(self.xlabel)
+            pylab.title(self.title)
+            pylab.legend(self.legends)
+            pylab.show()
+            return 'If your lines zigzag, you may want to ORDER BY the x axis.'
+            
+except ImportError:
+    class Plot(object):
+        def build(self, sqlSession):
+            pass
+        def save(self):
+            pass
+        def draw(self):
+            return 'Must install python-matplotlib to plot query results.'
+        
 class sqlpyPlus(sqlpython.sqlpython):
     defaultExtension = 'sql'
     sqlpython.sqlpython.shortcuts.update({':': 'setbind', '\\': 'psql', '@': '_load'})
@@ -406,7 +456,6 @@
                   (self.tblname, ','.join(self.colnames), formatRow(row))
                   for row in self.rows]
         return '\n'.join(result)
-    
     tableNameFinder = re.compile(r'from\s+([\w$#_"]+)', re.IGNORECASE | re.MULTILINE | re.DOTALL)          
     def output(self, outformat, rowlimit):
         self.tblname = self.tableNameFinder.search(self.curs.statement).group(1)
@@ -441,10 +490,15 @@
                     transpr[x][0] = rname
             newdesc[0][0] = 'COLUMN NAME'
             result = '\n' + sqlpython.pmatrix(transpr,newdesc)            
+        elif outformat == '\\p':
+            plot = Plot()
+            plot.build(self)
+            plot.save()
+            return plot.draw()
         else:
             result = sqlpython.pmatrix(self.rows, self.curs.description, self.maxfetch)
         return result
-
+        
     legalOracle = re.compile('[a-zA-Z_$#]')
 
     def select_scalar_list(self, sql, binds={}):
@@ -495,7 +549,7 @@
         return completions
     
     rowlimitPattern = pyparsing.Word(pyparsing.nums)('rowlimit')
-    terminatorPattern = (pyparsing.oneOf('; \\s \\S \\c \\C \\t \\x \\h \\g \\G \\i')    
+    terminatorPattern = (pyparsing.oneOf('; \\s \\S \\c \\C \\t \\x \\h \\g \\G \\i \\p')    
                         ^ pyparsing.Literal('\n/') ^ \
                         (pyparsing.Literal('\nEOF') + pyparsing.stringEnd)) \
                         ('terminator') + \