changeset 339:545f63b6ef42

supports bind variables in postgresql
author Catherine Devlin <catherine.devlin@gmail.com>
date Thu, 09 Apr 2009 14:42:41 -0400
parents a8835fe129f6
children 001d01eeac90
files sqlpython/sqlpyPlus.py sqlpython/sqlpython.py
diffstat 2 files changed, 44 insertions(+), 40 deletions(-) [+]
line wrap: on
line diff
--- a/sqlpython/sqlpyPlus.py	Thu Apr 09 00:15:09 2009 -0400
+++ b/sqlpython/sqlpyPlus.py	Thu Apr 09 14:42:41 2009 -0400
@@ -28,6 +28,7 @@
 from output_templates import output_templates
 from metadata import metaqueries
 from plothandler import Plot
+from sqlpython import Parser
 try:
     import pylab
 except (RuntimeError, ImportError):
@@ -227,42 +228,6 @@
     def pop(self, key, def_val=None):
         return dict.pop(self, key.lower(), def_val)
 
-class Parser(object):
-    comment_def = "--" + ~ ('-' + pyparsing.CaselessKeyword('begin')) + pyparsing.ZeroOrMore(pyparsing.CharsNotIn("\n"))    
-    def __init__(self, scanner, retainSeparator=True):
-        self.scanner = scanner
-        self.scanner.ignore(pyparsing.sglQuotedString)
-        self.scanner.ignore(pyparsing.dblQuotedString)
-        self.scanner.ignore(self.comment_def)
-        self.scanner.ignore(pyparsing.cStyleComment)
-        self.retainSeparator = retainSeparator
-    def separate(self, txt):
-        itms = []
-        for (sqlcommand, start, end) in self.scanner.scanString(txt):
-            if sqlcommand:
-                if type(sqlcommand[0]) == pyparsing.ParseResults:
-                    if self.retainSeparator:
-                        itms.append("".join(sqlcommand[0]))
-                    else:
-                        itms.append(sqlcommand[0][0])
-                else:
-                    if sqlcommand[0]:
-                        itms.append(sqlcommand[0])
-        return itms
-
-bindScanner = Parser(pyparsing.Literal(':') + pyparsing.Word( pyparsing.alphanums + "_$#" ))   
-    
-def findBinds(target, existingBinds, givenBindVars = {}):
-    result = givenBindVars
-    for finding, startat, endat in bindScanner.scanner.scanString(target):
-        varname = finding[1]
-        try:
-            result[varname] = existingBinds[varname]
-        except KeyError:
-            if not givenBindVars.has_key(varname):
-                print 'Bind variable %s not defined.' % (varname)                
-    return result
-
 class ResultSet(list):
     pass
 
@@ -690,7 +655,7 @@
             self.perror("Specify desired number of rows after terminator (not '%s')" % arg.parsed.suffix)
         if arg.parsed.terminator == '\\t':
             rowlimit = rowlimit or self.maxtselctrows
-        self.varsUsed = findBinds(arg, self.binds, bindVarsIn)
+        self.varsUsed = self.findBinds(arg, self.binds, bindVarsIn)
         if self.wildsql:
             selecttext = self.expandWildSql(arg)
         else:
@@ -1268,7 +1233,7 @@
         if arg.startswith(':'):
             self.do_setbind(arg[1:])
         else:
-            varsUsed = findBinds(arg, self.binds, {})
+            varsUsed = self.findBinds(arg, self.binds, {})
             try:
                 self.curs.execute('begin\n%s;end;' % arg, varsUsed)
             except Exception, e:
--- a/sqlpython/sqlpython.py	Thu Apr 09 00:15:09 2009 -0400
+++ b/sqlpython/sqlpython.py	Thu Apr 09 14:42:41 2009 -0400
@@ -9,9 +9,33 @@
 # See also http://twiki.cern.ch/twiki/bin/view/PSSGroup/SqlPython
 
 import cmd2,getpass,binascii,cx_Oracle,re,os
-import sqlpyPlus, sqlalchemy
+import sqlpyPlus, sqlalchemy, pyparsing
 __version__ = '1.6.4'    
 
+class Parser(object):
+    comment_def = "--" + ~ ('-' + pyparsing.CaselessKeyword('begin')) + pyparsing.ZeroOrMore(pyparsing.CharsNotIn("\n"))    
+    def __init__(self, scanner, retainSeparator=True):
+        self.scanner = scanner
+        self.scanner.ignore(pyparsing.sglQuotedString)
+        self.scanner.ignore(pyparsing.dblQuotedString)
+        self.scanner.ignore(self.comment_def)
+        self.scanner.ignore(pyparsing.cStyleComment)
+        self.retainSeparator = retainSeparator
+    def separate(self, txt):
+        itms = []
+        for (sqlcommand, start, end) in self.scanner.scanString(txt):
+            if sqlcommand:
+                if type(sqlcommand[0]) == pyparsing.ParseResults:
+                    if self.retainSeparator:
+                        itms.append("".join(sqlcommand[0]))
+                    else:
+                        itms.append(sqlcommand[0][0])
+                else:
+                    if sqlcommand[0]:
+                        itms.append(sqlcommand[0])
+        return itms
+
+    
 class sqlpython(cmd2.Cmd):
     '''A python module to reproduce Oracle's command line with focus on customization and extention'''
 
@@ -247,8 +271,23 @@
     
     terminatorSearchString = '|'.join('\\' + d.split()[0] for d in do_terminators.__doc__.splitlines())
         
+    bindScanner = {'oracle': Parser(pyparsing.Literal(':') + pyparsing.Word( pyparsing.alphanums + "_$#" )),
+                   'postgres': Parser(pyparsing.Literal('%(') + 
+                                      pyparsing.Word(pyparsing.alphanums + "_$#") + ')s')}
+    def findBinds(self, target, existingBinds, givenBindVars = {}):
+        result = givenBindVars
+        if self.rdbms in self.bindScanner:
+            for finding, startat, endat in self.bindScanner[self.rdbms].scanner.scanString(target):
+                varname = finding[1]
+                try:
+                    result[varname] = existingBinds[varname]
+                except KeyError:
+                    if not givenBindVars.has_key(varname):
+                        print 'Bind variable %s not defined.' % (varname)
+        return result
+
     def default(self, arg):
-        self.varsUsed = sqlpyPlus.findBinds(arg, self.binds, givenBindVars={})
+        self.varsUsed = self.findBinds(arg, self.binds, givenBindVars={})
         ending_args = arg.lower().split()[-2:]
         if 'end' in ending_args:
             command = '%s %s;'