changeset 285:316abf2191a4

substvar define working now
author catherine@dellzilla
date Fri, 20 Mar 2009 09:47:22 -0400
parents ad20675a17f7
children abb4c6524113 3cade02da892
files sqlpython/sqlpyPlus.py
diffstat 1 files changed, 51 insertions(+), 26 deletions(-) [+]
line wrap: on
line diff
--- a/sqlpython/sqlpyPlus.py	Fri Mar 20 09:05:40 2009 -0400
+++ b/sqlpython/sqlpyPlus.py	Fri Mar 20 09:47:22 2009 -0400
@@ -674,11 +674,6 @@
             prompt = ''
         varname = args.lower().split()[0]
         self.substvars[varname] = self.pseudo_raw_input(prompt)
-        
-    def do_define(self, args):
-        if not args:
-            for (substvar, val) in sorted(self.substvars.items()):
-                print 'DEFINE %s = %s' % (substvar, val)
                 
     def ampersand_substitution(self, raw, regexpr, isglobal):
         subst = regexpr.search(raw)
@@ -1213,32 +1208,62 @@
             for (var, val) in sorted(self.binds.items()):
                 print ':%s = %s' % (var, val)
 
+    def split_on_parser(self, parser, arg):
+        try:
+            assigner, startat, endat = parser.scanner.scanString(arg).next()
+            return (arg[:startat].strip(), arg[endat:].strip())
+        except StopIteration:
+            return ''.join(arg.split()[:1]), ''
+        
     assignmentScanner = Parser(pyparsing.Literal(':=') ^ '=')
+    def interpret_variable_assignment(self, arg):
+        '''
+        Accepts strings like `foo = 'bar'` or `baz := 22`, returning Python
+        variables as appropriate
+        '''
+        var, val = self.split_on_parser(self.assignmentScanner, arg) 
+        if not var:
+            return None, None
+        if (len(val) > 1) and ((val[0] == val[-1] == "'") or (val[0] == val[-1] == '"')):
+            return var, val[1:-1]
+        try:
+            return var, int(val)
+        except ValueError:
+            try:
+                return var, float(val)
+            except ValueError:
+                # use the conversions implicit in cx_Oracle's select to 
+                # cast the value into an appropriate type (dates, for instance)
+                try:
+                    self.curs.execute('SELECT %s FROM dual' % val)
+                    return var, self.curs.fetchone()[0]
+                except cx_Oracle.DatabaseError:
+                    return var, val  # we give up and assume it's a string
+            
     def do_setbind(self, arg):
+        '''Sets or shows values of bind (`:`) variables.'''        
         if not arg:
             return self.do_print(arg)
-        try:
-            assigner, startat, endat = self.assignmentScanner.scanner.scanString(arg).next()
-        except StopIteration:
-            self.do_print(arg)
-            return
-        var, val = arg[:startat].strip(), arg[endat:].strip()
-        if val[0] == val[-1] == "'" and len(val) > 1:
-            self.binds[var] = val[1:-1]
-            return
-        try:
-            self.binds[var] = int(val)
-            return
-        except ValueError:
-            try:
-                self.binds[var] = float(val)
-                return
-            except ValueError: 
-                statekeeper = Statekeeper(self, ('autobind',))  
-                self.autobind = True
-                self.onecmd('SELECT %s AS %s FROM dual;' % (val, var))
-                statekeeper.restore()
+        var, val = self.interpret_variable_assignment(arg)
+        if val:
+            self.binds[var] = val
+        else:
+            return self.do_print(var)
 
+    def do_define(self, arg):
+        '''Sets or shows values of substitution (`&`) variables.'''
+        if not arg:
+            for (substvar, val) in sorted(self.substvars.items()):
+                print 'DEFINE %s = "%s" (%s)' % (substvar, val, type(val))
+        var, val = self.interpret_variable_assignment(arg)
+        if val:
+            self.substvars[var] = val
+        else:
+            if var in self.substvars:
+                print 'DEFINE %s = "%s" (%s)' % (var, self.substvars[var], type(self.substvars[var]))
+
+    do_def = do_define               
+    
     def do_exec(self, arg):
         if arg.startswith(':'):
             self.do_setbind(arg[1:])