changeset 428:6cd30b785885

continuing connection transition
author catherine@dellzilla
date Mon, 25 Jan 2010 15:51:00 -0500
parents 8c1dec7fbd71
children 76bf7f767c10
files sqlpython/connections.py sqlpython/schemagroup.py sqlpython/sqlpyPlus.py sqlpython/sqlpython.py
diffstat 4 files changed, 81 insertions(+), 230 deletions(-) [+]
line wrap: on
line diff
--- a/sqlpython/connections.py	Mon Jan 25 14:41:28 2010 -0500
+++ b/sqlpython/connections.py	Mon Jan 25 15:51:00 2010 -0500
@@ -1,9 +1,10 @@
 import re
 import os
+import getpass
 import gerald
 import schemagroup
 
-class Connection(object):
+class DatabaseInstance(object):
     password = None
     uri = None
     connection_uri_parser = re.compile('(postgres|oracle|mysql|sqlite|mssql):/(.*$)', re.IGNORECASE)
@@ -12,7 +13,7 @@
         self.default_rdbms = default_rdbms
         if not self.parse_connect_uri(arg):
             self.parse_connect_arg(arg, opts)
-        self.reconnect()
+        self.connection = self.new_connection()
         self.discover_schemas()
     
     def parse_connect_uri(self, uri):
@@ -28,15 +29,15 @@
             return False
             
     def parse_connect_arg(self, arg, opts):
-        self.password = opts.password    
+        self.password = opts.password or getpass.getpass('Password: ')
         self.host = opts.hostname
         self.oracle_connect_mode = 0
         if opts.postgres:
-            self.__class__ = PostgresConnection
+            self.__class__ = PostgresDatabaseInstance
         elif opts.mysql:
-            self.__class__ = MySQLConnection
+            self.__class__ = MySQLDatabaseInstance
         elif opts.oracle:
-            self.__class__ = OracleConnection
+            self.__class__ = OracleDatabaseInstance
         else:
             self.__class__ = rdbms_types.get(self.default_rdbms)
         self.assign_args(arg, opts)
@@ -48,34 +49,30 @@
     def gerald_uri(self):
         return self.uri.split('?mode=')[0]
         
-    def reconnect(self):
-        self.password = self.password or getpass.getpass('Password: ')
-        self.connection = self.new_connection()
-
     def discover_schemas(self):
         self.schemas = schemagroup.SchemaDict(
             {}, rdbms = self.rdbms, user = self.username, 
             connection = self.connection, connection_string = self.gerald_uri())
         self.schemas.refresh_asynch()
     
-    def set_connection_number(self, connection_number):
-        self.connection_number = connection_number
-        self.prompt = "%d:%s@%s> " % (self.connection_number, self.username, self.db_name)        
+    def set_instance_number(self, instance_number):
+        self.instance_number = instance_number
+        self.prompt = "%d:%s@%s> " % (self.instance_number, self.username, self.db_name)        
 
-class OpenSourceConnection(Connection):
-    def assign_args(self, opts, arg):
-        self.assign_args(opts, arg)        
-        self.username = username or os.environ['USER']
+class OpenSourceDatabaseInstance(DatabaseInstance):
+    def assign_args(self, arg, opts):
+        self.assign_details(arg, opts)        
+        self.username = self.username or os.environ['USER']
         self.db_name = self.db_name or self.username
-        self.host = opts.host or self.host or 'localhost'
+        self.host = opts.hostname or self.host or 'localhost'
 
 try:
     import psycopg2
-    class PostgresConnection(OpenSourceConnection):
+    class PostgresDatabaseInstance(OpenSourceDatabaseInstance):
         rdbms = 'postgres'
         default_port = 5432
         def assign_details(self, arg, opts):
-            self.port = os.getenv('PGPORT') or self.port
+            self.port = opts.port or os.getenv('PGPORT') or self.default_port
             self.host = self.host or os.getenv('PGHOST')
             args = arg.split()
             if len(args) > 1:
@@ -87,12 +84,12 @@
                                      password = self.password, database = self.db_name,
                                      port = self.port)                
 except ImportError:
-    class PostgresConnection(OpenSourceConnection):
+    class PostgresDatabaseInstance(OpenSourceDatabaseInstance):
         pass
             
 try:
     import MySQLdb
-    class MySQLConnection(OpenSourceConnection):
+    class MySQLDatabaseInstance(OpenSourceDatabaseInstance):
         rdbms = 'mysql'
         default_port = 3306        
         def assign_details(self, arg, opts):
@@ -102,13 +99,13 @@
                                     passwd = self.password, db = self.db_name,
                                     port = self.port, sql_mode = 'ANSI')
 except ImportError:
-    class MySQLConnection(OpenSourceConnection):
+    class MySQLDatabaseInstance(OpenSourceDatabaseInstance):
         pass
 
 try:
     import cx_Oracle
     
-    class OracleConnection(Connection):
+    class OracleDatabaseInstance(DatabaseInstance):
         rdbms = 'oracle'
         connection_parser = re.compile('(?P<username>[^/\s]*)(/(?P<password>[^/\s]*))?@((?P<host>[^/\s:]*)(:(?P<port>\d{1,4}))?/)?(?P<db_name>[^/\s:]*)(\s+as\s+(?P<mode>sys(dba|oper)))?',
                                             re.IGNORECASE)
@@ -137,9 +134,9 @@
             
                                            
 except ImportError:
-    class OracleConnection(Connection):
+    class OracleDatabaseInstance(DatabaseInstance):
         pass
                                        
-rdbms_types = {'oracle': OracleConnection, 'mysql': MySQLConnection, 'postgres': PostgresConnection}
+rdbms_types = {'oracle': OracleDatabaseInstance, 'mysql': MySQLDatabaseInstance, 'postgres': PostgresDatabaseInstance}
                   
         
\ No newline at end of file
--- a/sqlpython/schemagroup.py	Mon Jan 25 14:41:28 2010 -0500
+++ b/sqlpython/schemagroup.py	Mon Jan 25 15:51:00 2010 -0500
@@ -51,7 +51,7 @@
 class OracleSchemaAccess(object):        
     child_type = gerald.oracle_schema.User
     current_database_time_query = 'SELECT sysdate FROM dual'
-    def latest_ddl_timestamp_query(self, username, connection):
+    def most_recent_ddl_by_owner(self, username, connection):
         curs = connection.cursor()
         curs.execute('''SELECT   owner, MAX(last_ddl_time)
                         FROM     all_objects
@@ -59,24 +59,30 @@
                         -- sort :username to top
                         ORDER BY REPLACE(owner, :username, 'A'), owner''',
                      {'username': username.upper()})
-        return curs 
+        result = curs.fetchall()
+        curs.close()
+        return result
 
 class PostgresSchemaAccess(object):        
     child_type = gerald.PostgresSchema # we need User here, too
     current_database_time_query = 'SELECT current_time'
-    def latest_ddl_timestamp_query(self, username, connection):
+    def most_recent_ddl_by_owner(self, username, connection):
         curs = connection.cursor()
         curs.execute("""SELECT  '%s', current_time""" % username)
-        return curs 
+        result = curs.fetchall()
+        curs.close()
+        return result
         # TODO: we just assume that we always need a refresh - that's sloppy
     
 class MySQLSchemaAccess(object):        
     child_type = gerald.MySQLSchema
     current_database_time_query = 'SELECT now()'
-    def latest_ddl_timestamp_query(self, username, connection):
+    def most_recent_ddl_by_owner(self, username, connection):
         curs = connection.cursor()
         curs.execute("""SELECT  '%s', now()""" % username)
-        return curs 
+        result = curs.fetchall()
+        curs.close()
+        return result
     
 class SchemaDict(dict):
     schema_types = {'oracle': OracleSchemaAccess, 'postgres': PostgresSchemaAccess, 'mysql': MySQLSchemaAccess}
@@ -95,7 +101,9 @@
     def get_current_database_time(self):
         curs = self.connection.cursor()
         curs.execute(self.schema_access.current_database_time_query)
-        return curs.fetchone()[0]              
+        current_time = curs.fetchone()[0]
+        curs.close()
+        return current_time       
     def refresh_times(self, target_schema):
         now = self.get_current_database_time()
         result = []
@@ -107,8 +115,7 @@
             
     def refresh(self):
         current_database_time = self.get_current_database_time()
-        curs = self.schema_access.latest_ddl_timestamp_query(self.user, self.connection)
-        for (owner, last_ddl_time) in curs.fetchall():
+        for (owner, last_ddl_time) in self.schema_access.most_recent_ddl_by_owner(self.user, self.connection):
             if (owner not in self) or (self[owner].refreshed < last_ddl_time):
                 self.refresh_one(owner, current_database_time)
                 # what if a user's last object is deleted?
--- a/sqlpython/sqlpyPlus.py	Mon Jan 25 14:41:28 2010 -0500
+++ b/sqlpython/sqlpyPlus.py	Mon Jan 25 15:51:00 2010 -0500
@@ -349,7 +349,7 @@
                       create drop alter _multiline_comment'''.split()
     sqlpython.sqlpython.noSpecialParse.append('spool')
     commentGrammars = pyparsing.Or([pyparsing.cStyleComment, pyparsing.Literal('--') + pyparsing.restOfLine])
-    prefixParser = pyparsing.Optional(pyparsing.Word(pyparsing.nums)('connection_number') 
+    prefixParser = pyparsing.Optional(pyparsing.Word(pyparsing.nums)('instance_number') 
                                       + ':')
     reserved_words = [
             'alter', 'begin', 'comment', 'create', 'delete', 'drop', 'end', 'for', 'grant', 
@@ -1517,14 +1517,12 @@
         'XML SCHEMA')
         
     def metadata(self):
-        schemas = self.connections[self.connection_number]['schemas']
-        
-        username = self.connections[self.connection_number]['user']
+        username = self.conn.username
         if self.rdbms == 'oracle':
             username = username.upper()
         elif self.rdbms == 'postgres':
             username = username.lower()
-        return (username, schemas)
+        return (username, self.conn.schemas)
         
     def _to_sql_wildcards(self, original):
         return original.replace('*','%').replace('?','_')
--- a/sqlpython/sqlpython.py	Mon Jan 25 14:41:28 2010 -0500
+++ b/sqlpython/sqlpython.py	Mon Jan 25 15:51:00 2010 -0500
@@ -49,89 +49,51 @@
 
     def __init__(self):
         cmd2.Cmd.__init__(self)
-        self.no_connection()
+        self.no_instance()
         self.maxfetch = 1000
         self.terminator = ';'
         self.timeout = 30
         self.commit_on_exit = True
-        self.connections = {}
+        self.instances = {}
         
-    def no_connection(self):
+    def no_instance(self):
         self.prompt = 'SQL.No_Connection> '
         self.curs = None
         self.conn = None
-        self.connection_number = None
+        self.instance_number = None
         
-    def make_connection_current(self, connection_number):
-        conn = self.connections[connection_number]
+    def make_instance_current(self, instance_number):
+        db_instance = self.instances[instance_number]
         self.prompt = conn.prompt
         self.rdbms = conn.rdbms
-        self.connection_number = connection_number
+        self.instance_number = instance_number
         self.curs = conn.connection.cursor()
-        self.conn = conn
+        self.current_instance = db_instance
             
-    def successfully_connect_to_number(self, arg):
-        try:
-            connection_number = int(arg)
-        except ValueError:            
-            return False
-        try:
-            self.make_connection_current(connection_number)
-        except IndexError:
-            self.list_connections()
-            return False
-        if (self.rdbms == 'oracle') and self.serveroutput:
-            self.curs.callproc('dbms_output.enable', [])           
-        return True
-
-    def successful_connection_to_number(self, arg):
-        # deprecated 
-        try:
-            connection_number = int(arg)
-        except ValueError:            
-            return False
-        self.make_connection_current(connection_number)
-        if (self.rdbms == 'oracle') and self.serveroutput:
-            self.curs.callproc('dbms_output.enable', [])           
-        return True
-
-    def list_connections(self):
+    def list_instances(self):
         self.stdout.write('Existing connections:\n')
-        self.stdout.write('\n'.join('%s (%s)' % (v['prompt'], v['rdbms']) 
-                                    for (k,v) in sorted(self.connections.items())) + '\n')
+        self.stdout.write('\n'.join('%s (%s)' % (v.prompt, v.rdbms) 
+                                    for (k,v) in sorted(self.instances.items())) + '\n')
         
     def disconnect(self, arg):
         try:
-            connection_number = int(arg)
-            connection = self.connections[connection_number]
+            instance_number = int(arg)
+            instance = self.instances[instance_number]
         except (ValueError, KeyError):
-            self.list_connections()
+            self.list_instances()
             return
         if self.commit_on_exit:
-            connection['conn'].commit()
-        self.connections.pop(connection_number)
-        if connection_number == self.connection_number:
-            self.no_connection()
+            instance.connection.commit()
+        self.instances.pop(instance_number)
+        if instance_number == self.instance_number:
+            self.no_instance()
             
     def closeall(self):
-        for connection_number in self.connections.keys():
-            self.disconnect(connection_number)
+        for instance_number in self.instances.keys():
+            self.disconnect(instance_number)
         self.curs = None
-        self.no_connection()        
+        self.no_instance()        
             
-    def url_connect(self, arg):
-        eng = sqlalchemy.create_engine(arg, use_ansiquotes=True) 
-        self.conn = eng.connect().connection
-        user = eng.url.username or ''
-        rdbms = eng.url.drivername
-        conn  = {'conn': self.conn, 'prompt': self.prompt, 'dbname': eng.url.database,
-                 'rdbms': rdbms, 'user': user, 'eng': eng, 
-                 'schemas': schemagroup.SchemaDict({}, 
-                    rdbms=rdbms, user=user, connection=self.conn, connection_string=arg)}
-        s = conn['schemas']
-        s.refresh_asynch()
-        return conn
-    
     legal_sql_word = pyparsing.Word(pyparsing.alphanums + '_$#')
     legal_hostname = pyparsing.Word(pyparsing.alphanums + '_-.')('host') + pyparsing.Optional(
         ':' + pyparsing.Word(pyparsing.nums)('port'))
@@ -145,54 +107,6 @@
     postgresql_connect_parser = pyparsing.Optional(legal_sql_word('db_name') + 
                                                    pyparsing.Optional(legal_sql_word('username')))
           
-    def connect_url(self, arg, opts):               
-        if opts.oracle:
-            rdbms = 'oracle'
-        elif opts.postgres:
-            rdbms = 'postgres'
-        elif opts.mysql:
-            rdbms = 'mysql'
-        else:
-            rdbms = self.default_rdbms
-        mode = 0
-        host = None
-        port = None
-        
-        if rdbms == 'oracle':
-            result = self.oracle_connect_parser.parseString(arg)
-            if result.mode == 'sysdba':
-                mode = cx_Oracle.SYSDBA
-            elif result.mode == 'sysoper':
-                mode = cx_Oracle.SYSOPER   
-            else:
-                mode = 0
-        elif rdbms == 'postgres':
-            result = self.postgresql_connect_parser.parseString(arg)
-            port = opts.port or os.environ.get('PGPORT') or 5432            
-            host = opts.host or os.environ.get('PGHOST') or 'localhost'
-       
-        username = result.username or opts.username           
-        if not username and rdbms == 'postgres':
-            username = os.environ.get('PGUSER') or os.environ.get('USER')
-
-        db_name = result.db_name or opts.database
-        if not db_name:
-            if rdbms == 'oracle':
-                db_name = os.environ.get('ORACLE_SID')
-            elif rdbms == 'postgres':
-                db_name = os.environ.get('PGDATABASE') or username
-        
-        password = result.password or getpass.getpass('Password: ')
-               
-        if host:
-            if port:
-                host = '%s:%s' % (host, port)
-            db_name = '%s/%s' % (host, db_name)
-
-        url = '%s://%s:%s@%s' % (rdbms, username, password, db_name)
-        if mode:
-            url = '%s/?mode=%d' % mode
-        return url
 
     @cmd2.options([cmd2.make_option('-a', '--add', action='store_true', 
                                     help='add connection (keep current connection)'),
@@ -222,108 +136,43 @@
             return 
         if opts.close:
             if not arg:
-                arg = self.connection_number
+                arg = self.instance_number
             self.disconnect(arg)
             return 
         if (not arg) and (not opts.postgres):
-            self.list_connections()
+            self.list_instances()
             return 
         if self.successfully_connect_to_number(arg):
             return
         
-        conn = connections.Connection(arg, opts, default_rdbms = self.default_rdbms)
-        if opts.add or (self.connection_number is None):
+        db_instance = connections.DatabaseInstance(arg, opts, default_rdbms = self.default_rdbms)
+        if opts.add or (self.instance_number is None):
             try:
-                self.connection_number = max(self.connections.keys()) + 1
+                self.instance_number = max(self.instances.keys()) + 1
             except ValueError:
-                self.connection_number = 0            
-        conn.set_connection_number(self.connection_number)
-        self.connections[self.connection_number] = conn
-        self.make_connection_current(self.connection_number)        
+                self.instance_number = 0            
+        db_instance.set_instance_number(self.instance_number)
+        self.instances[self.instance_number] = conn
+        self.make_instance_current(self.instance_number)        
         if (self.rdbms == 'oracle') and self.serveroutput:
             self.curs.callproc('dbms_output.enable', [])        
     
-    @cmd2.options([cmd2.make_option('-a', '--add', action='store_true', 
-                                    help='add connection (keep current connection)'),
-                   cmd2.make_option('-c', '--close', action='store_true', 
-                                    help='close connection {N} (or current)'),
-                   cmd2.make_option('-C', '--closeall', action='store_true', 
-                                    help='close all connections'),
-                   cmd2.make_option('--postgres', action='store_true', help='Connect to postgreSQL: `sqlpython --postgres [DBNAME [USERNAME]]`'),
-                   cmd2.make_option('--oracle', action='store_true', help='Connect to an Oracle database'),
-                   cmd2.make_option('--mysql', action='store_true', help='Connect to a MySQL database'),                   
-                   cmd2.make_option('-r', '--rdbms', type='string', 
-                                    help='Type of database to connect to (oracle, postgres, mysql)'),
-                   cmd2.make_option('-H', '--host', type='string', 
-                                    help='Host to connect to (postgresql only)'),                                  
-                   cmd2.make_option('-p', '--port', type='int', 
-                                    help='Port to connect to (postgresql only)'),                                  
-                   cmd2.make_option('-d', '--database', type='string', 
-                                    help='Database name to connect to'),
-                   cmd2.make_option('-U', '--username', type='string', 
-                                    help='Database user name to connect as')
-                   ])
-    def _o_connect(self, arg, opts):
- 
-        '''Opens the DB connection'''
-        if opts.closeall:
-            self.closeall()
-            return 
-        if opts.close:
-            if not arg:
-                arg = self.connection_number
-            self.disconnect(arg)
-            return 
-        if (not arg) and (not opts.postgres):
-            self.list_connections()
-            return 
-        try:
-            if self.successful_connection_to_number(arg):
-                return
-        except IndexError:
-            self.list_connections()
-            return
-        try:
-            connect_info = self.url_connect(arg)
-        except sqlalchemy.exc.ArgumentError, e:
-            url = self.connect_url(arg, opts)
-            connect_info = self.url_connect(url)
-        except Exception, e:
-            self.perror(str(e))
-            self.perror(r'URL connection format: rdbms://username:password@host/database')
-            return
-        if opts.add or (self.connection_number is None):
-            try:
-                self.connection_number = max(self.connections.keys()) + 1
-            except ValueError:
-                self.connection_number = 0
-        connect_info['prompt'] = '%d:%s@%s> ' % (self.connection_number, connect_info['user'], connect_info['dbname'])
-        self.connections[self.connection_number] = connect_info
-        self.make_connection_current(self.connection_number)
-        self.curs = self.conn.cursor()
-        if (self.rdbms == 'oracle') and self.serveroutput:
-            self.curs.callproc('dbms_output.enable', [])
-        #if (self.rdbms == 'mysql'):
-        #    self.curs.execute('SET SQL_MODE=ANSI')
-        #    # this dies... if only we could set sql_mode when making the connection
-        return 
-    
     def postparsing_precmd(self, statement):
         stop = 0
-        self.saved_connection_number = None
-        if statement.parsed.connection_number:
-            saved_connection_number = self.connection_number
+        self.saved_instance_number = None
+        if statement.parsed.instance_number:
+            saved_instance_number = self.instance_number
             try:
-                if self.successful_connection_to_number(statement.parsed.connection_number):
+                if self.successfully_connect_to_number(statement.parsed.instance_number):
                     if statement.parsed.command:
-                        self.saved_connection_number = saved_connection_number
+                        self.saved_instance_number = saved_instance_number
             except KeyError:
-                self.list_connections()
-                raise KeyError, 'No connection #%s' % statement.parsed.connection_number
+                self.list_instances()
+                raise KeyError, 'No connection #%s' % statement.parsed.instance_number
         return stop, statement           
     def postparsing_postcmd(self, stop):
-        if self.saved_connection_number is not None:
-            self.successful_connection_to_number(self.saved_connection_number)
+        if self.saved_instance_number is not None:
+            self.successfully_connect_to_number(self.saved_instance_number)
         return stop
                 
     do_host = cmd2.Cmd.do_shell