changeset 440:e1a962dd7139

refactored connections
author catherine@dellzilla
date Mon, 01 Feb 2010 15:56:15 -0500
parents 0a2474b76db6
children da332a670378
files sqlpython/connections.py sqlpython/sqlpython.py
diffstat 2 files changed, 192 insertions(+), 156 deletions(-) [+]
line wrap: on
line diff
--- a/sqlpython/connections.py	Mon Feb 01 10:15:28 2010 -0500
+++ b/sqlpython/connections.py	Mon Feb 01 15:56:15 2010 -0500
@@ -5,6 +5,23 @@
 import time
 import threading
 import pickle
+import optparse
+import doctest
+
+try:
+    import cx_Oracle
+except ImportError:
+    pass
+
+try:
+    import psycopg2
+except ImportError:
+    pass
+
+try:
+    import MySQLdb
+except ImportError:
+    pass
 
 class ObjectDescriptor(object):
     def __init__(self, name, dbobj):
@@ -31,8 +48,164 @@
         
 class GeraldPlaceholder(object):
     current = False
-    complete = False
+    complete = False   
+    
+class OptionTestDummy(object):
+    mysql = None
+    postgres = None
+    username = None
+    password = None
+    hostname = None
+    port = None
+    database = None
+    mode = 0
+    def __init__(self, *args, **kwargs):
+        self.__dict__.update(kwargs)
     
+class ConnectionData(object):
+    username = None
+    password = None
+    hostname = None
+    port = None
+    database = None
+    mode = 0
+    connection_uri_parser = re.compile('(postgres|oracle|mysql|sqlite|mssql):/(.*$)', re.IGNORECASE)    
+    connection_parser = re.compile('((?P<database>\S+)(\s+(?P<username>\S+))?)?')    
+    def __init__(self, arg, opts, default_rdbms = 'oracle'):
+        '''
+        >>> opts = OptionTestDummy(postgres=True, password='password')        
+        >>> ConnectionData('thedatabase theuser', opts).uri()        
+        'postgres://theuser:password@localhost:5432/thedatabase'
+        >>> opts = OptionTestDummy(password='password')
+        >>> ConnectionData('oracle://user:password@db', opts).uri()        
+        'oracle://user:password@db'
+        >>> ConnectionData('user/password@db', opts).uri()
+        'oracle://user:password@db'
+        >>> ConnectionData('user/password@db as sysdba', opts).uri()
+        'oracle://user:password@db?mode=2'
+        >>> ConnectionData('user/password@thehost/db', opts).uri()        
+        'oracle://user:password@thehost:1521/db'
+        >>> opts = OptionTestDummy(postgres=True, hostname='thehost', password='password')
+        >>> ConnectionData('thedatabase theuser', opts).uri()        
+        'postgres://theuser:password@thehost:5432/thedatabase'
+        >>> opts = OptionTestDummy(mysql=True, password='password')
+        >>> ConnectionData('thedatabase theuser', opts).uri()        
+        'mysql://theuser:password@localhost:3306/thedatabase'
+        '''
+        self.arg = arg
+        self.opts = opts
+        self.default_rdbms = default_rdbms
+        self.determine_rdbms()
+        if not self.parse_connect_uri(arg):
+            self.set_defaults()        
+            connectargs = self.connection_parser.search(self.arg)
+            if connectargs:
+                for param in ('username', 'password', 'database', 'port', 'hostname', 'mode'):
+                    if hasattr(opts, param) and getattr(opts, param):
+                        setattr(self, param, getattr(opts, param))
+                    else:
+                        try:
+                            if connectargs.group(param):
+                                setattr(self, param, connectargs.group(param))
+                        except IndexError:
+                            pass
+        self.set_corrections()     
+        if not self.password:
+            self.password = getpass.getpass()        
+    def parse_connect_uri(self, uri):
+        results = self.connection_uri_parser.search(uri)
+        if results:
+            (self.username, self.password, self.hostname, self.port, self.database
+             ) = gerald.utilities.dburi.Connection().parse_uri(results.group(2))
+            self.set_class_from_rdbms_name(results.group(1))
+            self.port = self.port or self.default_port        
+            return True
+        else:
+            return False
+    def set_class_from_rdbms_name(self, rdbms_name):
+        for cls in (OracleConnectionData, PostgresConnectionData, MySQLConnectionData):
+            if cls.rdbms == rdbms_name:
+                self.__class__ = cls        
+    def uri(self):
+        return '%s://%s:%s@%s:%s/%s' % (self.rdbms, self.username, self.password,
+                                         self.hostname, self.port, self.database)  
+    def gerald_uri(self):
+        return self.uri().split('?mode=')[0]    
+    def determine_rdbms(self):
+        if self.opts.mysql:
+            self.__class__ = MySQLConnectionData
+        elif self.opts.postgres:
+            self.__class__ = PostgresConnectionData
+        else:
+            self.set_class_from_rdbms_name(self.default_rdbms)     
+    def set_defaults(self):
+        self.port = self.default_port
+    def set_corrections(self):
+        pass
+
+class MySQLConnectionData(ConnectionData):
+    rdbms = 'mysql'
+    default_port = 3306
+    def set_defaults(self):
+        self.port = self.default_port       
+        self.hostname = 'localhost'
+    def connection(self):
+        return MySQLdb.connect(host = self.hostname, user = self.username, 
+                                passwd = self.password, db = self.database,
+                                port = self.port, sql_mode = 'ANSI')        
+
+class PostgresConnectionData(ConnectionData):
+    rdbms = 'postgres'
+    default_port = 5432
+    def set_defaults(self):
+        self.port = os.getenv('PGPORT') or self.default_port
+        self.database = os.getenv('ORACLE_SID')
+        self.hostname = os.getenv('PGHOST') or 'localhost'
+    def connection(self):
+        return psycopg2.connect(host = self.hostname, user = self.username, 
+                                 password = self.password, database = self.database,
+                                 port = self.port)          
+      
+class OracleConnectionData(ConnectionData):
+    rdbms = 'oracle'
+    default_port = 1521
+    connection_parser = re.compile('(?P<username>[^/\s@]*)(/(?P<password>[^/\s@]*))?(@((?P<hostname>[^/\s:]*)(:(?P<port>\d{1,4}))?/)?(?P<database>[^/\s:]*))?(\s+as\s+(?P<mode>sys(dba|oper)))?',
+                                     re.IGNORECASE)
+    def uri(self):
+        if self.hostname:
+            uri = '%s://%s:%s@%s:%s/%s' % (self.rdbms, self.username, self.password,
+                                           self.hostname, self.port, self.database)           
+        else:
+            uri = '%s://%s:%s@%s' % (self.rdbms, self.username, self.password, self.database)
+        if self.mode:
+            uri = '%s?mode=%d' % (uri, self.mode)
+        return uri
+    def set_defaults(self):
+        self.port = 1521
+        self.database = os.getenv('ORACLE_SID')
+    def set_corrections(self):
+        if self.mode:
+            self.mode = getattr(cx_Oracle, self.mode.upper())
+        if self.hostname:
+            self.dsn = cx_Oracle.makedsn(self.hostname, self.port, self.database)
+        else:
+            self.dsn = self.database
+    def parse_connect_uri(self, uri):
+        if ConnectionData.parse_connect_uri(self, uri):
+            if not self.database:
+                self.database = self.hostname
+                self.hostname = None
+                self.port = self.default_port
+            return True            
+        return False
+    def connection(self):
+        return cx_Oracle.connect(user = self.username, password = self.password,
+                                  dsn = self.dsn, mode = self.mode)    
+
+gerald_classes = {'oracle': gerald.oracle_schema.User,
+                  'postgres': gerald.PostgresSchema,
+                  'mysql': gerald.MySQLSchema }
+
 class DatabaseInstance(object):
     import_failure = None
     username = None
@@ -43,55 +216,21 @@
     connection_uri_parser = re.compile('(postgres|oracle|mysql|sqlite|mssql):/(.*$)', re.IGNORECASE)
     
     def __init__(self, arg, opts, default_rdbms = 'oracle'):
-        opts.username = opts.username or opts.user
-        self.default_rdbms = default_rdbms
-        if not self.parse_connect_uri(arg):
-            self.parse_connect_arg(arg, opts)
-        self.connection = self.new_connection()
+        #opts.username = opts.username or opts.user
+        self.conn_data = ConnectionData(arg, opts, default_rdbms)
+        for v in ('username', 'database', 'rdbms'):
+            setattr(self, v, getattr(self.conn_data, v))
+        self.connection = self.conn_data.connection()
         self.gerald = GeraldPlaceholder()
         self.discover_metadata()        
         
     def discover_metadata(self):
         self.metadata_discovery_thread = MetadataDiscoveryThread(self)
         self.metadata_discovery_thread.start()
-    
-    def parse_connect_uri(self, uri):
-        results = self.connection_uri_parser.search(uri)
-        if results:
-            (self.username, self.password, self.host, self.port, self.db_name
-             ) = gerald.utilities.dburi.Connection().parse_uri(results.group(2))
-            self.__class__ = rdbms_types.get(results.group(1))
-            self.uri = uri
-            self.port = self.port or self.default_port        
-            return True
-        else:
-            return False
-            
-    def parse_connect_arg(self, arg, opts):
-        self.host = opts.hostname
-        self.oracle_connect_mode = 0
-        if opts.postgres:
-            self.__class__ = PostgresDatabaseInstance
-        elif opts.mysql:
-            self.__class__ = MySQLDatabaseInstance
-        elif opts.oracle:
-            self.__class__ = OracleDatabaseInstance
-        else:
-            self.__class__ = rdbms_types.get(self.default_rdbms)
-        self.assign_args(arg, opts)
-        self.db_name = opts.database or self.db_name
-        self.port = self.port or self.default_port        
-        self.password = self.password or opts.password or getpass.getpass('Password: ')        
-        self.uri = self.uri or self.calculated_uri()
-    def calculated_uri(self):
-        return '%s://%s:%s@%s:%s/%s' % (self.rdbms, self.username, self.password,
-                                         self.host, self.port, self.db_name)    
-    def gerald_uri(self):
-        return self.uri.split('?mode=')[0]
            
     def set_instance_number(self, instance_number):
         self.instance_number = instance_number
-        self.prompt = "%d:%s@%s> " % (self.instance_number, self.username, self.db_name)  
+        self.prompt = "%d:%s@%s> " % (self.instance_number, self.username, self.database)  
     def pickle(self):
         try:
             os.mkdir(self.pickledir)
@@ -102,13 +241,13 @@
         picklefile.close()
     def picklefile(self):
         return os.path.join(self.pickledir, ('%s.%s.%s.%s.pickle' % 
-                             (self.rdbms, self.username, self.host, self.db_name)).lower())
+                             (self.rdbms, self.username, self.conn_data.hostname, self.database)).lower())
     def retreive_pickled_gerald(self):
         picklefile = open(self.picklefile())
         schema = pickle.load(picklefile)
         picklefile.close()
-        newgerald = rdbms_types[self.rdbms].gerald_class(self.username, None)
-        newgerald.connect(self.gerald_uri())
+        newgerald = gerald_classes[self.rdbms](self.username, None)
+        newgerald.connect(self.conn_data.gerald_uri())
         newgerald.schema = schema  
         newgerald.current = False
         newgerald.complete = True      
@@ -116,97 +255,6 @@
         for (name, obj) in newgerald.schema.items():
             newgerald.descriptions[name] = ObjectDescriptor(name, obj)                    
         self.gerald = newgerald
-
-class OpenSourceDatabaseInstance(DatabaseInstance):
-    def assign_args(self, arg, opts):
-        self.assign_details(arg, opts)        
-        self.username = self.username or os.environ['USER']
-        self.db_name = self.db_name or self.username
-        self.host = opts.hostname or self.host or 'localhost'
-
-class ImportFailure(DatabaseInstance):
-    def fail(self, *arg, **kwargs):
-        raise ImportError, 'Python DB-API2 module (MySQLdb/psycopg2/cx_Oracle) was not successfully imported'
-    assign_args = fail
-    new_connection = fail
-    
-try:
-    import psycopg2
-    class PostgresDatabaseInstance(OpenSourceDatabaseInstance):
-        gerald_class = gerald.PostgresSchema
-        rdbms = 'postgres'
-        default_port = 5432
-        def assign_details(self, arg, opts):
-            self.port = opts.port or os.getenv('PGPORT') or self.default_port
-            self.host = self.host or os.getenv('PGHOST')
-            args = arg.split()
-            if len(args) > 1:
-                self.username = args[1]
-            if len(args) > 0:
-                self.db_name = args[0]   
-        def new_connection(self):
-            return psycopg2.connect(host = self.host, user = self.username, 
-                                     password = self.password, database = self.db_name,
-                                     port = self.port)                
-except ImportError:
-    PostgresDatabaseInstance = ImportFailure
-            
-try:
-    import MySQLdb
-    class MySQLDatabaseInstance(OpenSourceDatabaseInstance):
-        gerald_class = gerald.MySQLSchema
-        rdbms = 'mysql'
-        default_port = 3306        
-        def assign_details(self, arg, opts):
-            self.db_name = arg
-        def new_connection(self):
-            return MySQLdb.connect(host = self.host, user = self.username, 
-                                    passwd = self.password, db = self.db_name,
-                                    port = self.port, sql_mode = 'ANSI')
-except ImportError:
-    MySQLDatabaseInstance = ImportFailure
-try:
-    import cx_Oracle
-    
-    class OracleDatabaseInstance(DatabaseInstance):
-        gerald_class = gerald.oracle_schema.User
-        rdbms = 'oracle'
-        connection_parser = re.compile('(?P<username>[^/\s@]*)(/(?P<password>[^/\s@]*))?(@((?P<host>[^/\s:]*)(:(?P<port>\d{1,4}))?/)?(?P<db_name>[^/\s:]*))?(\s+as\s+(?P<mode>sys(dba|oper)))?',
-                                            re.IGNORECASE)
-        connection_modes = {'SYSDBA': cx_Oracle.SYSDBA, 'SYSOPER': cx_Oracle.SYSOPER}
-        oracle_connect_mode = 0
-        default_port = 1521
-        def assign_args(self, arg, opts):
-            connectargs = self.connection_parser.search(arg)
-            self.username = connectargs.group('username')
-            self.password = connectargs.group('password')
-            self.db_name = connectargs.group('db_name') or os.getenv('ORACLE_SID')
-            self.port = connectargs.group('port') or self.default_port
-            self.host = connectargs.group('host')
-            if self.host:
-                self.dsn = cx_Oracle.makedsn(self.host, self.port, self.db_name)
-            else:
-                self.dsn = self.db_name
-            if connectargs.group('mode'):
-                self.oracle_connect_mode = self.connection_modes.get(connectargs.group('mode').upper())
-        def calculated_uri(self):
-            if self.host:
-                result = DatabaseInstance.calculated_uri(self)
-            else:
-                result = '%s://%s:%s@%s' % (self.rdbms, self.username, self.password,
-                                            self.db_name)
-            if self.oracle_connect_mode:
-                result = "%s?mode=%d" % (result, self.oracle_connect_mode)
-            return result
-        def new_connection(self):
-            return cx_Oracle.connect(user = self.username, 
-                                      password = self.password,
-                                      dsn = self.dsn,
-                                      mode = self.oracle_connect_mode)
-            
-                                           
-except ImportError:
-    OracleDatabaseInstance = ImportFailure
         
 class MetadataDiscoveryThread(threading.Thread):
     def __init__(self, db_instance):
@@ -219,7 +267,7 @@
             except IOError:
                 pass
         self.db_instance.gerald.current = False
-        newgerald = self.db_instance.gerald_class(self.db_instance.username, self.db_instance.gerald_uri())
+        newgerald = gerald_classes[self.db_instance.rdbms](self.db_instance.username, self.db_instance.conn_data.gerald_uri())
         newgerald.descriptions = {}
         for (name, obj) in newgerald.schema.items():
             newgerald.descriptions[name] = ObjectDescriptor(name, obj)            
@@ -227,7 +275,6 @@
         newgerald.complete = True
         self.db_instance.gerald = newgerald
         self.db_instance.pickle()
-
-rdbms_types = {'oracle': OracleDatabaseInstance, 'mysql': MySQLDatabaseInstance, 'postgres': PostgresDatabaseInstance}
-                  
-        
\ No newline at end of file
+                 
+if __name__ == '__main__':
+    doctest.testmod()
\ No newline at end of file
--- a/sqlpython/sqlpython.py	Mon Feb 01 10:15:28 2010 -0500
+++ b/sqlpython/sqlpython.py	Mon Feb 01 15:56:15 2010 -0500
@@ -95,17 +95,6 @@
         self.no_instance()        
             
     legal_sql_word = pyparsing.Word(pyparsing.alphanums + '_$#')
-    legal_hostname = pyparsing.Word(pyparsing.alphanums + '_-.')('host') + pyparsing.Optional(
-        ':' + pyparsing.Word(pyparsing.nums)('port'))
-    oracle_connect_parser = legal_sql_word('username') + (
-                            pyparsing.Optional('/' + pyparsing.CharsNotIn('@')("password")) + 
-                            pyparsing.Optional('@' + pyparsing.Optional(legal_hostname + '/') +
-                                               legal_sql_word('db_name')) + 
-                            pyparsing.Optional(pyparsing.CaselessKeyword('as') + 
-                                               (pyparsing.CaselessKeyword('sysoper') ^ 
-                                                pyparsing.CaselessKeyword('sysdba'))('mode')))
-    postgresql_connect_parser = pyparsing.Optional(legal_sql_word('db_name') + 
-                                                   pyparsing.Optional(legal_sql_word('username')))
           
     def successfully_connect_to_number(self, arg):
         try:
@@ -131,17 +120,17 @@
                    cmd2.make_option('--oracle', action='store_true', help='Connect to an Oracle database'),
                    cmd2.make_option('--mysql', action='store_true', help='Connect to a MySQL database'),                   
                    cmd2.make_option('-H', '--hostname', type='string', 
-                                    help='Machine where database is hosted (postgresql only)'),                                  
+                                    help='Machine where database is hosted'),                                  
                    cmd2.make_option('-p', '--port', type='int', 
-                                    help='Port to connect to (postgresql only)'),                                  
+                                    help='Port to connect to'),                                  
                    cmd2.make_option('--password', type='string', 
-                                    help='Password (mysql only)'),                     
+                                    help='Password'),                     
                    cmd2.make_option('-d', '--database', type='string', 
                                     help='Database name to connect to'),
                    cmd2.make_option('-U', '--username', type='string', 
                                     help='Database user name to connect as'),
-                   cmd2.make_option('-u', '--user', type='string', 
-                                    help='Database user name to connect as')
+#                   cmd2.make_option('-u', '--user', type='string', 
+#                                    help='Database user name to connect as')
                    ])
     def do_connect(self, arg, opts):