diff sqlpyPlus.py @ 153:ebdd20cfba69

plots finally editable
author catherine@Elli.myhome.westell.com
date Mon, 29 Sep 2008 05:28:28 -0400
parents c26bc528cb05
children
line wrap: on
line diff
--- a/sqlpyPlus.py	Fri Sep 26 16:31:17 2008 -0400
+++ b/sqlpyPlus.py	Mon Sep 29 05:28:28 2008 -0400
@@ -25,7 +25,8 @@
 """
 import sys, os, re, sqlpython, cx_Oracle, pyparsing, re, completion, datetime, pickle
 from cmd2 import Cmd, make_option, options, Statekeeper
-from output_templates import *
+from output_templates import output_templates
+from plothandler import Plot
 try:
     import pylab
 except:
@@ -339,53 +340,7 @@
             if not givenBindVars.has_key(varname):
                 print 'Bind variable %s not defined.' % (varname)                
     return result
-
-try:
-    import pylab
-    class Plot(object):
-        plottable_types = (cx_Oracle.NUMBER, datetime.datetime)    
-        def __init__(self):
-            self.legends = []
-            self.yserieslists = []
-            self.xticks = []
-        def build(self, sqlSession):
-            self.title = sqlSession.tblname
-            self.xlabel = sqlSession.curs.description[0][0]
-            self.datatypes = [d[1] for d in sqlSession.curs.description]
-            for (colNum, datatype) in enumerate(self.datatypes):
-                if colNum > 0 and datatype in self.plottable_types:
-                    self.yserieslists.append([row[colNum] for row in sqlSession.rows])
-                    self.legends.append(sqlSession.curs.description[colNum][0])
-            if self.datatypes[0] in self.plottable_types:
-                self.xvalues = [r[0] for r in sqlSession.rows]
-            else:
-                self.xvalues = range(sqlSession.curs.rowcount)
-                self.xticks = [r[0] for r in sqlSession.rows]
-        def save(self):
-            pass
-        def draw(self):
-            if not self.yserieslists:
-                result = 'At least one quantitative column needed to plot.'
-                return result       
-            if self.xticks:
-                pylab.xticks(self.xvalues, self.xticks)
-            for yseries in self.yserieslists:
-                pylab.plot(self.xvalues, yseries, '-o')
-            pylab.xlabel(self.xlabel)
-            pylab.title(self.title)
-            pylab.legend(self.legends)
-            pylab.show()
-            return 'If your lines zigzag, you may want to ORDER BY the x axis.'
-            
-except ImportError:
-    class Plot(object):
-        def build(self, sqlSession):
-            pass
-        def save(self):
-            pass
-        def draw(self):
-            return 'Must install python-matplotlib to plot query results.'
-        
+       
 class sqlpyPlus(sqlpython.sqlpython):
     defaultExtension = 'sql'
     sqlpython.sqlpython.shortcuts.update({':': 'setbind', '\\': 'psql', '@': '_load'})
@@ -462,13 +417,9 @@
         self.colnames = [d[0] for d in self.curs.description]
         if outformat == '\\i':
             result = self.output_as_insert_statements()
-        elif outformat ==  '\\x':
-            result = xml_template.generate(**self.__dict__)
-        elif outformat == '\\g':
-            result = list_template.generate(**self.__dict__)
-        elif outformat == '\\G':
+        elif outformat in output_templates:
             self.colnamelen = max(len(colname) for colname in self.colnames)
-            result = aligned_list_template.generate(**self.__dict__)
+            result = output_templates[outformat].generate(**self.__dict__)
         elif outformat in ('\\s', '\\S', '\\c', '\\C'): #csv
             result = []
             if outformat in ('\\s', '\\c'):
@@ -476,8 +427,6 @@
             for row in self.rows:
                 result.append(','.join('"%s"' % self.str_or_empty(itm) for itm in row))
             result = '\n'.join(result)
-        elif outformat == '\\h':
-            result = html_template.generate(**self.__dict__)
         elif outformat == '\\t': # transposed
             rows = [self.colnames]
             rows.extend(list(self.rows))
@@ -493,8 +442,9 @@
         elif outformat == '\\p':
             plot = Plot()
             plot.build(self)
-            plot.save()
-            return plot.draw()
+            plot.shelve()
+            plot.draw()
+            return ''
         else:
             result = sqlpython.pmatrix(self.rows, self.curs.description, self.maxfetch)
         return result
@@ -549,7 +499,8 @@
         return completions
     
     rowlimitPattern = pyparsing.Word(pyparsing.nums)('rowlimit')
-    terminatorPattern = (pyparsing.oneOf('; \\s \\S \\c \\C \\t \\x \\h \\g \\G \\i \\p')    
+    rawTerminators = '; \\s \\S \\c \\C \\t \\i \\p ' + ' '.join(output_templates.keys())
+    terminatorPattern = (pyparsing.oneOf(rawTerminators)    
                         ^ pyparsing.Literal('\n/') ^ \
                         (pyparsing.Literal('\nEOF') + pyparsing.stringEnd)) \
                         ('terminator') + \