Mercurial > sqlpython
diff sqlpyPlus.py @ 151:802d8df993da
midway through making plots saveable
author | catherine@dellzilla |
---|---|
date | Fri, 26 Sep 2008 16:11:29 -0400 |
parents | b00a020b81c6 |
children | c26bc528cb05 |
line wrap: on
line diff
--- a/sqlpyPlus.py Fri Sep 26 13:03:10 2008 -0400 +++ b/sqlpyPlus.py Fri Sep 26 16:11:29 2008 -0400 @@ -23,9 +23,13 @@ - catherinedevlin.blogspot.com May 31, 2006 """ -import sys, os, re, sqlpython, cx_Oracle, pyparsing, re, completion +import sys, os, re, sqlpython, cx_Oracle, pyparsing, re, completion, datetime, pickle from cmd2 import Cmd, make_option, options, Statekeeper from output_templates import * +try: + import pylab +except: + pass descQueries = { 'TABLE': (""" @@ -336,6 +340,52 @@ print 'Bind variable %s not defined.' % (varname) return result +try: + import pylab + class Plot(object): + plottable_types = (cx_Oracle.NUMBER, datetime.datetime) + def __init__(self): + self.legends = [] + self.yserieslists = [] + self.xticks = [] + def build(self, sqlSession): + self.title = sqlSession.tblname + self.xlabel = sqlSession.curs.description[0][0] + self.datatypes = [d[1] for d in sqlSession.curs.description] + for (colNum, datatype) in enumerate(self.datatypes): + if colNum > 0 and datatype in self.plottable_types: + self.yserieslists.append([row[colNum] for row in sqlSession.rows]) + self.legends.append(sqlSession.curs.description[colNum][0]) + if self.datatypes[0] in self.plottable_types: + self.xvalues = [r[0] for r in sqlSession.rows] + else: + self.xvalues = range(sqlSession.curs.rowcount) + self.xticks = [r[0] for r in sqlSession.rows] + def save(self): + pass + def draw(self): + if not self.yserieslists: + result = 'At least one quantitative column needed to plot.' + return result + if self.xticks: + pylab.xticks(self.xvalues, self.xticks) + for (colName, yseries) in self.yserieslists.items(): + pylab.plot(xvalues, yseries, '-o') + pylab.xlabel(self.xlabel) + pylab.title(self.title) + pylab.legend(self.legends) + pylab.show() + return 'If your lines zigzag, you may want to ORDER BY the x axis.' + +except ImportError: + class Plot(object): + def build(self, sqlSession): + pass + def save(self): + pass + def draw(self): + return 'Must install python-matplotlib to plot query results.' + class sqlpyPlus(sqlpython.sqlpython): defaultExtension = 'sql' sqlpython.sqlpython.shortcuts.update({':': 'setbind', '\\': 'psql', '@': '_load'}) @@ -406,7 +456,6 @@ (self.tblname, ','.join(self.colnames), formatRow(row)) for row in self.rows] return '\n'.join(result) - tableNameFinder = re.compile(r'from\s+([\w$#_"]+)', re.IGNORECASE | re.MULTILINE | re.DOTALL) def output(self, outformat, rowlimit): self.tblname = self.tableNameFinder.search(self.curs.statement).group(1) @@ -441,10 +490,15 @@ transpr[x][0] = rname newdesc[0][0] = 'COLUMN NAME' result = '\n' + sqlpython.pmatrix(transpr,newdesc) + elif outformat == '\\p': + plot = Plot() + plot.build(self) + plot.save() + return plot.draw() else: result = sqlpython.pmatrix(self.rows, self.curs.description, self.maxfetch) return result - + legalOracle = re.compile('[a-zA-Z_$#]') def select_scalar_list(self, sql, binds={}): @@ -495,7 +549,7 @@ return completions rowlimitPattern = pyparsing.Word(pyparsing.nums)('rowlimit') - terminatorPattern = (pyparsing.oneOf('; \\s \\S \\c \\C \\t \\x \\h \\g \\G \\i') + terminatorPattern = (pyparsing.oneOf('; \\s \\S \\c \\C \\t \\x \\h \\g \\G \\i \\p') ^ pyparsing.Literal('\n/') ^ \ (pyparsing.Literal('\nEOF') + pyparsing.stringEnd)) \ ('terminator') + \