view sqlpython/connections.py @ 427:8c1dec7fbd71

refactoring connections in progress
author catherine@dellzilla
date Mon, 25 Jan 2010 14:41:28 -0500
parents
children 6cd30b785885
line wrap: on
line source

import getpass
import os
import re

import gerald
import schemagroup

class Connection(object):
    """A database connection built from a URI or an RDBMS-specific connect argument.

    The constructor switches ``self.__class__`` to a concrete subclass (chosen
    from a ``rdbms://...`` URI or from ``opts`` flags), opens the DB-API
    connection through the subclass's ``new_connection()`` and then starts
    schema discovery.  Subclasses provide ``rdbms``, ``default_port``,
    ``assign_args(arg, opts)`` and ``new_connection()``.
    """
    # Defaults for the connection details.  Subclass argument parsers fall back
    # on existing values with ``self.x or ...``, so every detail must exist
    # (as None) before any parsing happens.
    password = None
    uri = None
    username = None
    host = None
    port = None
    db_name = None
    connection_uri_parser = re.compile('(postgres|oracle|mysql|sqlite|mssql):/(.*$)', re.IGNORECASE)

    def __init__(self, arg, opts, default_rdbms = 'oracle'):
        """Parse ``arg``, connect, and discover schemas.

        :param arg: a connection URI (e.g. ``postgres://user:pw@host/db``) or an
            RDBMS-specific connect string handled by the subclass's ``assign_args``.
        :param opts: parsed options; ``password``, ``hostname``, ``database`` and
            the ``postgres``/``mysql``/``oracle`` flags are read when ``arg`` is
            not a URI.
        :param default_rdbms: RDBMS name used when no ``opts`` flag selects one.
        """
        self.default_rdbms = default_rdbms
        if not self.parse_connect_uri(arg):
            self.parse_connect_arg(arg, opts)
        self.reconnect()
        self.discover_schemas()

    @staticmethod
    def _rdbms_class(rdbms):
        """Return the connection class registered for ``rdbms`` (case-insensitive).

        :raises ValueError: if no connection class is registered for ``rdbms``.
        """
        connection_class = rdbms_types.get(rdbms.lower())
        if connection_class is None:
            raise ValueError('Unsupported RDBMS %r; supported: %s'
                             % (rdbms, ', '.join(sorted(rdbms_types))))
        return connection_class

    def parse_connect_uri(self, uri):
        """Configure this connection from a ``rdbms://...`` URI.

        :returns: True if ``uri`` matched ``connection_uri_parser`` (the instance
            has then been switched to the matching subclass), False otherwise.
        :raises ValueError: if the URI scheme has no registered connection class.
        """
        results = self.connection_uri_parser.search(uri)
        if not results:
            return False
        # Resolve the class first: the regex is case-insensitive and also accepts
        # schemes (sqlite, mssql) that may have no connection class.
        connection_class = self._rdbms_class(results.group(1))
        (self.username, self.password, self.host, self.port, self.db_name
         ) = gerald.utilities.dburi.Connection().parse_uri(results.group(2))
        self.__class__ = connection_class
        self.uri = uri
        self.port = self.port or self.default_port
        return True

    def parse_connect_arg(self, arg, opts):
        """Configure this connection from a non-URI connect argument plus ``opts``.

        The subclass is picked from the ``postgres``/``mysql``/``oracle`` flags,
        falling back to ``self.default_rdbms``; its ``assign_args`` then parses
        ``arg``.  Finally a URI is synthesized if the subclass did not set one.
        """
        self.password = opts.password
        self.host = opts.hostname
        self.oracle_connect_mode = 0
        if opts.postgres:
            self.__class__ = PostgresConnection
        elif opts.mysql:
            self.__class__ = MySQLConnection
        elif opts.oracle:
            self.__class__ = OracleConnection
        else:
            self.__class__ = self._rdbms_class(self.default_rdbms)
        self.assign_args(arg, opts)
        self.db_name = opts.database or self.db_name
        self.port = self.port or self.default_port
        self.uri = self.uri or '%s://%s:%s@%s:%s/%s' % (self.rdbms, self.username, self.password,
                                                         self.host, self.port, self.db_name)

    def gerald_uri(self):
        """Return ``self.uri`` with any ``?mode=...`` suffix removed."""
        return self.uri.split('?mode=')[0]

    def reconnect(self):
        """(Re)open the DB-API connection, prompting for a password if none is known."""
        if not self.password:
            self.password = getpass.getpass('Password: ')
        self.connection = self.new_connection()

    def discover_schemas(self):
        """Create the ``SchemaDict`` for this connection and call its ``refresh_asynch()``."""
        self.schemas = schemagroup.SchemaDict(
            {}, rdbms = self.rdbms, user = self.username,
            connection = self.connection, connection_string = self.gerald_uri())
        self.schemas.refresh_asynch()

    def set_connection_number(self, connection_number):
        """Record ``connection_number`` and build the ``N:user@db> `` prompt."""
        self.connection_number = connection_number
        self.prompt = "%d:%s@%s> " % (self.connection_number, self.username, self.db_name)

class OpenSourceConnection(Connection):
    """Shared connect-argument handling for the PostgreSQL and MySQL classes.

    Subclasses implement ``assign_details(arg, opts)`` to parse their own
    connect-argument format; ``assign_args`` then fills in remaining defaults.
    """
    def assign_args(self, arg, opts):
        """Parse ``arg`` via ``assign_details`` and default username, db_name and host.

        :param arg: the RDBMS-specific connect argument.
        :param opts: parsed options; ``opts.hostname`` takes precedence for host.
        """
        # Delegate to the subclass parser (calling assign_args here recursed forever).
        self.assign_details(arg, opts)
        # getattr: assign_details may leave these unset.
        self.username = getattr(self, 'username', None) or getpass.getuser()
        self.db_name = getattr(self, 'db_name', None) or self.username
        self.host = opts.hostname or getattr(self, 'host', None) or 'localhost'

try:
    import psycopg2

    class PostgresConnection(OpenSourceConnection):
        """Connection to PostgreSQL through psycopg2."""
        rdbms = 'postgres'
        default_port = 5432

        def assign_details(self, arg, opts):
            """Parse a ``[db_name [username]]`` connect argument.

            PGPORT, when set, supplies the port; PGHOST is used only when no
            host is known yet.
            """
            # getattr: port/host may not have been assigned on this instance yet.
            self.port = os.getenv('PGPORT') or getattr(self, 'port', None)
            self.host = getattr(self, 'host', None) or os.getenv('PGHOST')
            args = arg.split()
            if len(args) > 1:
                self.username = args[1]
            if len(args) > 0:
                self.db_name = args[0]

        def new_connection(self):
            """Open a psycopg2 connection from the parsed connection details."""
            return psycopg2.connect(host = self.host, user = self.username,
                                    password = self.password, database = self.db_name,
                                    port = self.port)
except ImportError:
    class PostgresConnection(OpenSourceConnection):
        """Placeholder used when psycopg2 is not installed; fails with a clear error."""
        rdbms = 'postgres'
        default_port = 5432
        _missing_driver = 'psycopg2 must be installed to connect to PostgreSQL'

        def assign_details(self, arg, opts):
            raise ImportError(self._missing_driver)

        def new_connection(self):
            raise ImportError(self._missing_driver)
            
try:
    import MySQLdb

    class MySQLConnection(OpenSourceConnection):
        """Connection to MySQL through MySQLdb, in ANSI SQL mode."""
        rdbms = 'mysql'
        default_port = 3306

        def assign_details(self, arg, opts):
            """Treat the whole connect argument as the database name."""
            self.db_name = arg

        def new_connection(self):
            """Open a MySQLdb connection from the parsed connection details."""
            # MySQLdb requires an integer port; self.port may be a string
            # depending on where it came from.
            return MySQLdb.connect(host = self.host, user = self.username,
                                   passwd = self.password, db = self.db_name,
                                   port = int(self.port), sql_mode = 'ANSI')
except ImportError:
    class MySQLConnection(OpenSourceConnection):
        """Placeholder used when MySQLdb is not installed; fails with a clear error."""
        rdbms = 'mysql'
        default_port = 3306
        _missing_driver = 'MySQLdb must be installed to connect to MySQL'

        def assign_details(self, arg, opts):
            raise ImportError(self._missing_driver)

        def new_connection(self):
            raise ImportError(self._missing_driver)

try:
    import cx_Oracle

    class OracleConnection(Connection):
        """Connection to Oracle through cx_Oracle.

        Connect strings look like
        ``user[/password]@[host[:port]/]db_name [as sysdba|sysoper]``.
        """
        rdbms = 'oracle'
        # Raw strings so \s and \d are regex escapes; ports may have up to 5 digits.
        connection_parser = re.compile(
            r'(?P<username>[^/\s]*)'
            r'(/(?P<password>[^/\s]*))?'
            r'@((?P<host>[^/\s:]*)(:(?P<port>\d{1,5}))?/)?'
            r'(?P<db_name>[^/\s:]*)'
            r'(\s+as\s+(?P<mode>sys(dba|oper)))?',
            re.IGNORECASE)
        connection_modes = {'SYSDBA': cx_Oracle.SYSDBA, 'SYSOPER': cx_Oracle.SYSOPER}
        oracle_connect_mode = 0
        default_port = 1521

        def assign_args(self, arg, opts):
            """Parse an Oracle connect string into connection details and a DSN.

            :raises ValueError: if ``arg`` does not match ``connection_parser``.
            """
            connectargs = self.connection_parser.search(arg)
            if connectargs is None:
                raise ValueError('Cannot parse Oracle connect string %r; expected '
                                 'user[/password]@[host[:port]/]db_name' % (arg,))
            self.username = connectargs.group('username')
            # Keep a password / host already supplied (e.g. via opts) when the
            # connect string omits them.
            self.password = connectargs.group('password') or self.password
            self.db_name = connectargs.group('db_name')
            self.port = connectargs.group('port') or self.default_port
            self.host = connectargs.group('host') or getattr(self, 'host', None)
            self.dsn = self._make_dsn()
            if not self.host:
                self.uri = '%s://%s:%s@%s' % (self.rdbms, self.username, self.password, self.db_name)
            if connectargs.group('mode'):
                self.oracle_connect_mode = self.connection_modes.get(connectargs.group('mode').upper())

        def _make_dsn(self):
            """Return ``cx_Oracle.makedsn(host, port, db_name)`` when a host is known, else ``db_name``."""
            if getattr(self, 'host', None):
                port = getattr(self, 'port', None) or self.default_port
                return cx_Oracle.makedsn(self.host, port, self.db_name)
            return self.db_name

        def new_connection(self):
            """Open a cx_Oracle connection using the DSN and connect mode."""
            # URI-based connections never run assign_args, so build the DSN on demand.
            dsn = getattr(self, 'dsn', None) or self._make_dsn()
            return cx_Oracle.connect(user = self.username,
                                     password = self.password,
                                     dsn = dsn,
                                     mode = self.oracle_connect_mode)

except ImportError:
    class OracleConnection(Connection):
        """Placeholder used when cx_Oracle is not installed; fails with a clear error."""
        rdbms = 'oracle'
        default_port = 1521
        _missing_driver = 'cx_Oracle must be installed to connect to Oracle'

        def assign_args(self, arg, opts):
            raise ImportError(self._missing_driver)

        def new_connection(self):
            raise ImportError(self._missing_driver)
                                       
# Lower-case RDBMS name -> connection class; looked up with URI schemes and
# ``default_rdbms`` values.  Note that connection_uri_parser also accepts
# 'sqlite' and 'mssql', which have no entry here.
rdbms_types = dict(
    oracle=OracleConnection,
    mysql=MySQLConnection,
    postgres=PostgresConnection,
)