Mercurial > sqlpython
changeset 339:545f63b6ef42
supports bind variables in postgresql
author | Catherine Devlin <catherine.devlin@gmail.com> |
---|---|
date | Thu, 09 Apr 2009 14:42:41 -0400 |
parents | a8835fe129f6 |
children | 001d01eeac90 |
files | sqlpython/sqlpyPlus.py sqlpython/sqlpython.py |
diffstat | 2 files changed, 44 insertions(+), 40 deletions(-) [+] |
line wrap: on
line diff
--- a/sqlpython/sqlpyPlus.py Thu Apr 09 00:15:09 2009 -0400 +++ b/sqlpython/sqlpyPlus.py Thu Apr 09 14:42:41 2009 -0400 @@ -28,6 +28,7 @@ from output_templates import output_templates from metadata import metaqueries from plothandler import Plot +from sqlpython import Parser try: import pylab except (RuntimeError, ImportError): @@ -227,42 +228,6 @@ def pop(self, key, def_val=None): return dict.pop(self, key.lower(), def_val) -class Parser(object): - comment_def = "--" + ~ ('-' + pyparsing.CaselessKeyword('begin')) + pyparsing.ZeroOrMore(pyparsing.CharsNotIn("\n")) - def __init__(self, scanner, retainSeparator=True): - self.scanner = scanner - self.scanner.ignore(pyparsing.sglQuotedString) - self.scanner.ignore(pyparsing.dblQuotedString) - self.scanner.ignore(self.comment_def) - self.scanner.ignore(pyparsing.cStyleComment) - self.retainSeparator = retainSeparator - def separate(self, txt): - itms = [] - for (sqlcommand, start, end) in self.scanner.scanString(txt): - if sqlcommand: - if type(sqlcommand[0]) == pyparsing.ParseResults: - if self.retainSeparator: - itms.append("".join(sqlcommand[0])) - else: - itms.append(sqlcommand[0][0]) - else: - if sqlcommand[0]: - itms.append(sqlcommand[0]) - return itms - -bindScanner = Parser(pyparsing.Literal(':') + pyparsing.Word( pyparsing.alphanums + "_$#" )) - -def findBinds(target, existingBinds, givenBindVars = {}): - result = givenBindVars - for finding, startat, endat in bindScanner.scanner.scanString(target): - varname = finding[1] - try: - result[varname] = existingBinds[varname] - except KeyError: - if not givenBindVars.has_key(varname): - print 'Bind variable %s not defined.' % (varname) - return result - class ResultSet(list): pass @@ -690,7 +655,7 @@ self.perror("Specify desired number of rows after terminator (not '%s')" % arg.parsed.suffix) if arg.parsed.terminator == '\\t': rowlimit = rowlimit or self.maxtselctrows - self.varsUsed = findBinds(arg, self.binds, bindVarsIn) + self.varsUsed = self.findBinds(arg, self.binds, bindVarsIn) if self.wildsql: selecttext = self.expandWildSql(arg) else: @@ -1268,7 +1233,7 @@ if arg.startswith(':'): self.do_setbind(arg[1:]) else: - varsUsed = findBinds(arg, self.binds, {}) + varsUsed = self.findBinds(arg, self.binds, {}) try: self.curs.execute('begin\n%s;end;' % arg, varsUsed) except Exception, e:
--- a/sqlpython/sqlpython.py Thu Apr 09 00:15:09 2009 -0400 +++ b/sqlpython/sqlpython.py Thu Apr 09 14:42:41 2009 -0400 @@ -9,9 +9,33 @@ # See also http://twiki.cern.ch/twiki/bin/view/PSSGroup/SqlPython import cmd2,getpass,binascii,cx_Oracle,re,os -import sqlpyPlus, sqlalchemy +import sqlpyPlus, sqlalchemy, pyparsing __version__ = '1.6.4' +class Parser(object): + comment_def = "--" + ~ ('-' + pyparsing.CaselessKeyword('begin')) + pyparsing.ZeroOrMore(pyparsing.CharsNotIn("\n")) + def __init__(self, scanner, retainSeparator=True): + self.scanner = scanner + self.scanner.ignore(pyparsing.sglQuotedString) + self.scanner.ignore(pyparsing.dblQuotedString) + self.scanner.ignore(self.comment_def) + self.scanner.ignore(pyparsing.cStyleComment) + self.retainSeparator = retainSeparator + def separate(self, txt): + itms = [] + for (sqlcommand, start, end) in self.scanner.scanString(txt): + if sqlcommand: + if type(sqlcommand[0]) == pyparsing.ParseResults: + if self.retainSeparator: + itms.append("".join(sqlcommand[0])) + else: + itms.append(sqlcommand[0][0]) + else: + if sqlcommand[0]: + itms.append(sqlcommand[0]) + return itms + + class sqlpython(cmd2.Cmd): '''A python module to reproduce Oracle's command line with focus on customization and extention''' @@ -247,8 +271,23 @@ terminatorSearchString = '|'.join('\\' + d.split()[0] for d in do_terminators.__doc__.splitlines()) + bindScanner = {'oracle': Parser(pyparsing.Literal(':') + pyparsing.Word( pyparsing.alphanums + "_$#" )), + 'postgres': Parser(pyparsing.Literal('%(') + + pyparsing.Word(pyparsing.alphanums + "_$#") + ')s')} + def findBinds(self, target, existingBinds, givenBindVars = {}): + result = givenBindVars + if self.rdbms in self.bindScanner: + for finding, startat, endat in self.bindScanner[self.rdbms].scanner.scanString(target): + varname = finding[1] + try: + result[varname] = existingBinds[varname] + except KeyError: + if not givenBindVars.has_key(varname): + print 'Bind variable %s not defined.' % (varname) + return result + def default(self, arg): - self.varsUsed = sqlpyPlus.findBinds(arg, self.binds, givenBindVars={}) + self.varsUsed = self.findBinds(arg, self.binds, givenBindVars={}) ending_args = arg.lower().split()[-2:] if 'end' in ending_args: command = '%s %s;'