changeset 315:a0a36232983a

url_connect
author catherine@Elli.myhome.westell.com
date Tue, 31 Mar 2009 09:12:30 -0400
parents 0473ad96ddb7
children ea0b0e1ab6da
files sqlpython/sqlpyPlus.py sqlpython/sqlpython.py
diffstat 2 files changed, 83 insertions(+), 64 deletions(-) [+]
line wrap: on
line diff
--- a/sqlpython/sqlpyPlus.py	Mon Mar 30 14:32:54 2009 -0400
+++ b/sqlpython/sqlpyPlus.py	Tue Mar 31 09:12:30 2009 -0400
@@ -480,7 +480,7 @@
         
     def postcmd(self, stop, line):
         """Hook method executed just after a command dispatch is finished."""        
-        if self.serveroutput:
+        if self.orcl and self.serveroutput:
             self.dbms_output()
         return stop
     
@@ -538,10 +538,11 @@
         stop = self.postcmd(stop, line)
 
     def _onchange_serveroutput(self, old, new):
-        if new:
-            self.curs.callproc('dbms_output.enable', [])        
-        else:
-            self.curs.callproc('dbms_output.disable', [])        
+        if self.orcl:
+            if new:
+                self.curs.callproc('dbms_output.enable', [])        
+            else:
+                self.curs.callproc('dbms_output.disable', [])        
         
     def do_shortcuts(self,arg):
         """Lists available first-character shortcuts
@@ -556,7 +557,7 @@
     inputStatementFormatters[cx_Oracle.CLOB] = inputStatementFormatters[cx_Oracle.STRING]
     inputStatementFormatters[cx_Oracle.TIMESTAMP] = inputStatementFormatters[cx_Oracle.DATETIME]                
     def output(self, outformat, rowlimit):
-        self.tblname = self.tableNameFinder.search(self.curs.statement).group(1)
+        self.tblname = self.tableNameFinder.search(self.querytext).group(1)
         self.colnames = [d[0] for d in self.curs.description]
         if outformat in output_templates:
             self.colnamelen = max(len(colname) for colname in self.colnames)
@@ -797,10 +798,11 @@
             selecttext = self.expandWildSql(arg)
         else:
             selecttext = arg
-        self.curs.execute('select ' + selecttext, self.varsUsed)
+        self.querytext = 'select ' + selecttext
+        self.curs.execute(self.querytext, self.varsUsed)
         self.rows = self.curs.fetchmany(min(self.maxfetch, (rowlimit or self.maxfetch)))
         self.rc = self.curs.rowcount
-        if self.rc > 0:
+        if self.rc != 0:
             resultset = ResultSet()
             resultset.colnames = [d[0].lower() for d in self.curs.description]
             resultset.pystate = self.pystate
@@ -817,7 +819,7 @@
             print '\n1 row selected.\n'
             if self.autobind:
                 self.do_bind('')
-        elif self.rc < self.maxfetch:
+        elif (self.rc < self.maxfetch and self.rc > 0):
             print '\n%d rows selected.\n' % self.rc
         else:
             print '\nSelected Max Num rows (%d)' % self.rc
--- a/sqlpython/sqlpython.py	Mon Mar 30 14:32:54 2009 -0400
+++ b/sqlpython/sqlpython.py	Tue Mar 31 09:12:30 2009 -0400
@@ -9,7 +9,7 @@
 # See also http://twiki.cern.ch/twiki/bin/view/PSSGroup/SqlPython
 
 import cmd2,getpass,binascii,cx_Oracle,re,os
-import sqlpyPlus
+import sqlpyPlus, sqlalchemy
 __version__ = '1.6.3'    
 
 class sqlpython(cmd2.Cmd):
@@ -27,17 +27,17 @@
     def no_connection(self):
         self.prompt = 'SQL.No_Connection> '
         self.curs = None
-        self.orcl = None
+        self.conn = None
         self.connection_number = None
         
     def successful_connection_to_number(self, arg):
         try:
             connection_number = int(arg)
-            self.orcl = self.connections[connection_number]['conn']
+            self.conn = self.connections[connection_number]['conn']
             self.prompt = self.connections[connection_number]['prompt']
             self.connection_number = connection_number
-            self.curs = self.orcl.cursor()
-            if self.serveroutput:
+            self.curs = self.conn.cursor()
+            if self.orcl and self.serveroutput:
                 self.curs.callproc('dbms_output.enable', [])            
         except ValueError:            
             return False
@@ -45,7 +45,8 @@
 
     def list_connections(self):
         self.stdout.write('Existing connections:\n')
-        self.stdout.write('\n'.join(v['prompt'] for (k,v) in sorted(self.connections.items())) + '\n')
+        self.stdout.write('\n'.join('%s (%s)' % (v['prompt'], v['rdbms']) 
+                                    for (k,v) in sorted(self.connections.items())) + '\n')
         
     def disconnect(self, arg):
         try:
@@ -66,6 +67,54 @@
         self.curs = None
         self.no_connection()        
             
+    def url_connect(self, arg):
+        eng = sqlalchemy.create_engine(arg)
+        self.conn = eng.connect().connection
+        conn  = {'conn': self.conn, 'prompt': self.prompt, 'dbname': eng.url.database,
+                 'rdbms': eng.url.drivername, 'user': eng.url.username or '', 
+                 'eng': eng}
+        return conn
+    def ora_connect(self, arg):
+        modeval = 0
+        oraserv = None
+        for modere, modevalue in self.connection_modes.items():
+            if modere.search(arg):
+                arg = modere.sub('', arg)
+                modeval = modevalue
+        try:
+            orauser, oraserv = arg.split('@')
+        except ValueError:
+            try:
+                oraserv = os.environ['ORACLE_SID']
+            except KeyError:
+                print 'instance not specified and environment variable ORACLE_SID not set'
+                return
+            orauser = arg
+        sid = oraserv
+        try:
+            host, sid = oraserv.split('/')
+            try:
+                host, port = host.split(':')
+                port = int(port)
+            except ValueError:
+                port = 1521
+            oraserv = cx_Oracle.makedsn(host, port, sid)
+        except ValueError:
+            pass
+        try:
+            orauser, orapass = orauser.split('/')
+        except ValueError:
+            orapass = getpass.getpass('Password: ')
+        if orauser.upper() == 'SYS' and not modeval:
+            print 'Privilege not specified for SYS, assuming SYSOPER'
+            modeval = cx_Oracle.SYSOPER
+        if modeval == 0:   # can sqlalchemy connect as sysoper, sysdba?
+            return self.url_connect('oracle://%s:%s@%s' % (orauser, orapass, oraserv))
+        else:
+            self.conn = cx_Oracle.connect(orauser,orapass,oraserv,modeval)
+            result = {'user': orauser, 'rdbms': 'oracle', 'dbname': sid, 'conn': self.conn} 
+            return result
+    
     connection_modes = {re.compile(' AS SYSDBA', re.IGNORECASE): cx_Oracle.SYSDBA, 
                         re.compile(' AS SYSOPER', re.IGNORECASE): cx_Oracle.SYSOPER}
     @cmd2.options([cmd2.make_option('-a', '--add', action='store_true', 
@@ -93,56 +142,22 @@
         except IndexError:
             self.list_connections()
             return
-        modeval = 0
-        oraserv = None
-        for modere, modevalue in self.connection_modes.items():
-            if modere.search(arg):
-                arg = modere.sub('', arg)
-                modeval = modevalue
         try:
-            orauser, oraserv = arg.split('@')
-        except ValueError:
+            connect_info = self.url_connect(arg)
+        except sqlalchemy.exc.ArgumentError, e:
+            connect_info = self.ora_connect(arg)
+        if opts.add or (self.connection_number is None):
             try:
-                oraserv = os.environ['ORACLE_SID']
-            except KeyError:
-                print 'instance not specified and environment variable ORACLE_SID not set'
-                return
-            orauser = arg
-        self.sid = oraserv
-        try:
-            host, self.sid = oraserv.split('/')
-            try:
-                host, port = host.split(':')
-                port = int(port)
+                self.connection_number = max(self.connections.keys()) + 1
             except ValueError:
-                port = 1521
-            oraserv = cx_Oracle.makedsn(host, port, self.sid)
-        except ValueError:
-            pass
-        try:
-            orauser, orapass = orauser.split('/')
-        except ValueError:
-            orapass = getpass.getpass('Password: ')
-        if orauser.upper() == 'SYS' and not modeval:
-            print 'Privilege not specified for SYS, assuming SYSOPER'
-            modeval = cx_Oracle.SYSOPER
-        try:
-            self.orcl = cx_Oracle.connect(orauser,orapass,oraserv,modeval)
-            if opts.add or (self.connection_number is None):
-                try:
-                    self.connection_number = max(self.connections.keys()) + 1
-                except ValueError:
-                    self.connection_number = 0
-                self.connections[self.connection_number] = {'conn':self.orcl}
-            else:
-                self.connections[self.connection_number] = {'conn':self.orcl}
-            self.curs = self.orcl.cursor()
-            self.prompt = '%d:%s@%s> ' % (self.connection_number, orauser, self.sid)
-            self.connections[self.connection_number]['prompt'] = self.prompt
-        except Exception, e:
-            print e
-            return
-        if self.serveroutput:
+                self.connection_number = 0
+        self.connections[self.connection_number] = connect_info
+        self.curs = self.conn.cursor()
+        self.orcl = connect_info['rdbms'] == 'oracle'
+        self.prompt = '%d:%s@%s> ' % (self.connection_number, 
+                                      connect_info['user'], connect_info['dbname'])
+        self.connections[self.connection_number]['prompt'] = self.prompt
+        if self.orcl and self.serveroutput:
             self.curs.callproc('dbms_output.enable', [])
     def postparsing_precmd(self, statement):
         stop = 0
@@ -237,10 +252,12 @@
             command = '%s %s;'
         else:
             command = '%s %s'    
-        current_time = self.current_database_time()
+        if self.orcl:
+            current_time = self.current_database_time()
         self.curs.execute(command % (arg.parsed.command, arg.parsed.args), self.varsUsed)
         executionmessage = '\nExecuted%s\n' % ((self.curs.rowcount > 0) and ' (%d rows)' % self.curs.rowcount or '')
-        self._show_errors(all_users=True, limit=1, mintime=current_time)
+        if self.orcl:
+            self._show_errors(all_users=True, limit=1, mintime=current_time)
         print executionmessage
             
     def do_commit(self, arg=''):